from typing import Dict
from typing import List

from fastapi import APIRouter
from fastapi import Depends
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlmodel import and_
from sqlmodel import select

from crowdtls.db import get_session
from crowdtls.helpers import decode_der
from crowdtls.helpers import parse_hostname
from crowdtls.helpers import raise_HTTPException
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
from crowdtls.models import Certificate
from crowdtls.models import CertificateAnomalyFlagsLink
from crowdtls.models import Domain

app = APIRouter()


def _client_host(request: Request) -> str:
    """Return the requesting client's host for log messages, or ``"unknown"`` when it is unavailable."""
    # The endpoints declare ``request: Request = None`` and ``request.client`` may be None,
    # so never dereference blindly just to build a log line.
    if request is None or request.client is None:
        return "unknown"
    return request.client.host


async def _execute_or_raise(session: AsyncSession, stmt, client_host: str):
    """Execute ``stmt`` on ``session``; on failure log it and hand off to ``raise_HTTPException``.

    NOTE(review): ``raise_HTTPException`` is assumed to raise. If it ever returns, the original
    database error is re-raised so callers never continue with an unbound result.
    """
    try:
        return await session.execute(stmt)
    except Exception:
        logger.error(f"Failed to execute stmt: {stmt} and IP address: {client_host}")
        raise_HTTPException()
        raise


async def insert_certificate(hostname: str, certificate: Certificate, session: AsyncSession = Depends(get_session)):
    """Attach ``certificate`` to the domain parsed from ``hostname`` and commit.

    If a ``Domain`` with the parsed fqdn already exists, the certificate is appended to it.
    Otherwise the freshly parsed domain is added together with the certificate. Nothing happens
    when ``parse_hostname`` returns a falsy value.

    This is best-effort: a failed commit is logged and rolled back, not raised.

    When called directly (outside FastAPI dependency injection) ``session`` must be passed
    explicitly; the ``Depends`` default is only resolved by FastAPI.
    """
    domain = parse_hostname(hostname)
    if not domain:
        return

    # Capture identifiers up front so the failure log never has to touch ORM state after a failed flush.
    fqdn = domain.fqdn
    fingerprint = certificate.fingerprint

    existing_domain = await session.get(Domain, fqdn)
    if existing_domain:
        logger.info(f"Found existing domain in database: {existing_domain.fqdn}")
        existing_domain.certificates.append(certificate)
        session.add(existing_domain)
    else:
        logger.info(f"Did not find existing domain in database. Creating new domain: {fqdn}")
        domain.certificates.append(certificate)
        session.add(domain)

    try:
        await session.commit()
    except Exception:
        logger.error(f"Failed to insert certificate into database for domain {fqdn}: {fingerprint}")
        # A failed commit leaves the session in an inactive transaction; roll back so it stays usable.
        await session.rollback()


async def get_domain_by_fqdn(fqdn: str, session: AsyncSession = Depends(get_session)):
    """Return the ``Domain`` whose primary key is ``fqdn``, or ``None`` if it is not stored.

    When called directly (not as a FastAPI dependency) pass ``session`` explicitly; the
    ``Depends`` default is only resolved by FastAPI's dependency injection.
    """
    return await session.get(Domain, fqdn)


async def _link_certificates_to_domain(session: AsyncSession, parsed_hostname: Domain, certificates) -> None:
    """Link each certificate not yet associated with ``parsed_hostname``'s fqdn to that domain, then commit.

    ``certificates`` must have ``Certificate.domains`` already loaded (the caller uses ``selectinload``).
    Does nothing (and does not commit) when every certificate is already linked.
    """
    fqdn = parsed_hostname.fqdn
    unlinked = [
        certificate for certificate in certificates if fqdn not in {domain.fqdn for domain in certificate.domains}
    ]
    if not unlinked:
        return

    # ``session.get`` looks up by primary key. Reuse the stored Domain when present, because adding the
    # freshly parsed (transient) Domain for an fqdn that already exists would collide on that key.
    domain = await session.get(Domain, fqdn) or parsed_hostname
    for certificate in unlinked:
        logger.info(f"Adding {fqdn} to {certificate.fingerprint} in the database.")
        # Append on the side that was eagerly loaded; Domain.certificates may not be loaded, and a
        # lazy load is not possible without explicit awaiting under asyncio.
        certificate.domains.append(domain)
        session.add(certificate)
    await session.commit()
    logger.info(f"Added mappings between {fqdn} and {len(unlinked)} certificates in the database.")


async def _find_anomalies(session: AsyncSession, fps: List[str], client_host: str) -> Dict[str, object]:
    """Return ``{fingerprint: anomaly details}`` for every fingerprint in ``fps`` linked to an anomaly flag.

    The first flag seen for a given fingerprint wins.
    """
    # Relevant anomalies are keyed by certificate fingerprint in the response; the client browser
    # extensions use these keys to alert the user. ``relationship.any()`` emits a correlated EXISTS
    # through the association table, so only flags attached to one of the requested certificates are
    # returned, rather than a cross product of every flag with every matching link row.
    stmt = (
        select(AnomalyFlags)
        .options(selectinload(AnomalyFlags.certificates))
        .where(AnomalyFlags.certificates.any(Certificate.fingerprint.in_(fps)))
    )
    results = await _execute_or_raise(session, stmt, client_host)
    anomalies = results.scalars().all()
    logger.info(
        f"Found {len(anomalies)} anomalies (of {len(fps)} requested) in the database for client {client_host}"
    )

    requested = set(fps)
    flagged = {}
    for anomaly in anomalies:
        for certificate in anomaly.certificates:
            # A flag may cover certificates the client did not ask about; report only the requested ones.
            if certificate.fingerprint in requested:
                flagged.setdefault(certificate.fingerprint, anomaly.details)
    return flagged


@app.post("/check")
async def check_fingerprints(
    fingerprints: Dict[str, List[str]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Report which submitted certificates are unknown and which are flagged as anomalous.

    The body maps each hostname to the certificate fingerprints the client observed for it.
    Known certificates that are not yet linked to the parsed hostname's domain get linked
    (and committed) as a side effect.

    Returns a dict with these entries:

    - ``response[hostname] = True`` when at least one fingerprint for that hostname is not in the
      database. Presumably the client then uploads it via ``/new``.
    - ``response["anomalies"]``, present only when something is flagged, mapping
      ``fingerprint -> anomaly details``.
    """
    response_dict = {}
    client_host = _client_host(request)
    for hostname, fps in fingerprints.items():
        parsed_hostname = parse_hostname(hostname)
        logger.info(f"{client_host} requested {hostname}: {len(fps)}")

        # Query for all certificates and associated domains (from links) with the given fingerprints
        stmt = select(Certificate).options(selectinload(Certificate.domains)).where(Certificate.fingerprint.in_(fps))
        results = await _execute_or_raise(session, stmt, client_host)
        certificates = results.scalars().all()
        # Read fingerprints before the commit below. Depending on the session's expire_on_commit setting,
        # committing expires loaded instances, and touching them afterwards would need a lazy refresh.
        known_fingerprints = {certificate.fingerprint for certificate in certificates}
        logger.info(
            f"Found {len(certificates)} certificates (of {len(fps)} requested) in the database for client {client_host}"
        )

        if parsed_hostname:
            await _link_certificates_to_domain(session, parsed_hostname, certificates)

        if any(fp not in known_fingerprints for fp in fps):
            logger.info(f"Requesting new certs for {hostname}.")
            response_dict[hostname] = True

        flagged = await _find_anomalies(session, fps, client_host)
        if flagged:
            # Merge instead of overwriting so anomalies found for earlier hostnames are kept.
            anomalies = response_dict.setdefault("anomalies", {})
            for fingerprint, details in flagged.items():
                anomalies.setdefault(fingerprint, details)
    return response_dict


@app.post("/new")
async def new_fingerprints(
    fingerprints: Dict[str, Dict[str, List[int]]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Store the submitted certificates that the database does not know yet.

    The body maps each hostname to ``{fingerprint: raw DER encoding as a list of ints}``.
    Unknown fingerprints are decoded with ``decode_der``, converted with ``Certificate.from_orm``
    and inserted. They are linked to the hostname's domain when ``parse_hostname`` succeeds.
    Each hostname's batch is committed separately.

    Returns ``{"status": "OK"}``. Database failures are logged, rolled back and handed to
    ``raise_HTTPException``.
    """
    client_host = _client_host(request)
    # Iterate over each hostname and its fingerprints
    for hostname, certs in fingerprints.items():
        parsed_hostname = parse_hostname(hostname)
        # Pass a plain list of fingerprints (the dict's keys) to IN.
        stmt = select(Certificate).where(Certificate.fingerprint.in_(list(certs)))
        result = await _execute_or_raise(session, stmt, client_host)
        existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
        logger.info(f"Received {len(certs)} fingerprints to add from client {client_host} for host {hostname}")
        logger.info(f"Found {len(existing_fingerprints)} existing fingerprints in the database.")
        logger.info(f"{existing_fingerprints=}")

        new_certs = {fp: raw_der for fp, raw_der in certs.items() if fp not in existing_fingerprints}
        if not new_certs:
            continue

        domain = None
        if parsed_hostname:
            # Reuse the stored Domain (looked up by primary key) so a second row with the same fqdn is never
            # inserted; fall back to the freshly parsed one when this hostname is new.
            domain = await session.get(Domain, parsed_hostname.fqdn) or parsed_hostname

        certificates_to_add = []
        for fp, raw_der in new_certs.items():
            logger.info(f"Adding {fp}")
            certificate = Certificate.from_orm(decode_der(fp, raw_der))
            # Skip linking when the hostname could not be parsed instead of appending None.
            if domain is not None:
                certificate.domains.append(domain)
            certificates_to_add.append(certificate)

        try:
            session.add_all(certificates_to_add)
            await session.commit()
        except Exception:
            logger.error(
                f"Failed to add certificates to db: {list(new_certs)} after stmt: {stmt} and IP address: {client_host}"
            )
            # Roll back so the failed transaction does not linger on the session.
            await session.rollback()
            raise_HTTPException()
            raise
    return {"status": "OK"}