from typing import Dict
from typing import List
from typing import Optional

from fastapi import APIRouter
from fastapi import Depends
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlmodel import select

from crowdtls.db import get_session
from crowdtls.helpers import decode_der
from crowdtls.helpers import parse_hostname
from crowdtls.helpers import raise_HTTPException
from crowdtls.logs import logger
from crowdtls.models import Certificate
from crowdtls.models import Domain
from crowdtls.models import DomainCertificateLink  # noqa: F401  (kept for existing importers)

app = APIRouter()


def _client_host(request: Optional[Request]) -> str:
    """Return the client's host for log messages, or "unknown" when it is not available."""
    # Starlette's ``Request.client`` is Optional, and the endpoints default ``request`` to None.
    if request is None or request.client is None:
        return "unknown"
    return request.client.host


async def _resolve_domain(session: AsyncSession, hostname: str) -> Optional[Domain]:
    """Parse *hostname* and return the stored Domain for it, or the freshly parsed one.

    Returns None when ``parse_hostname`` yields nothing usable. Reusing the row already in the
    database (looked up by ``fqdn``) avoids inserting a second Domain with the same identity on commit.
    """
    domain = parse_hostname(hostname)
    if not domain:
        return None
    existing = await session.get(Domain, domain.fqdn)
    return existing if existing is not None else domain


async def insert_certificate(hostname: str, certificate: Certificate, session: AsyncSession = Depends(get_session)):
    """Link *certificate* to the domain parsed from *hostname* and commit.

    An existing Domain row is reused; otherwise the parsed Domain is added. Hostnames that
    ``parse_hostname`` rejects are ignored. A failed commit is rolled back and logged, not raised,
    so this helper stays best-effort.
    """
    domain = parse_hostname(hostname)
    if not domain:
        return
    # Capture identifiers up front: after commit/rollback the ORM instances may be expired, and
    # re-reading an expired attribute under AsyncSession would need implicit IO.
    fqdn = domain.fqdn
    fingerprint = certificate.fingerprint
    # Eager-load the collection we append to; lazy loading is not supported under AsyncSession.
    existing_domain = await session.get(Domain, fqdn, options=[selectinload(Domain.certificates)])
    if existing_domain:
        logger.info(f"Found existing domain in database: {existing_domain.fqdn}")
        target = existing_domain
    else:
        logger.info(f"Did not find existing domain in database. Creating new domain: {fqdn}")
        target = domain
    target.certificates.append(certificate)
    session.add(target)
    try:
        await session.commit()
    except Exception:
        # Without a rollback the session stays in a failed state for any later use.
        await session.rollback()
        logger.error(f"Failed to insert certificate into database for domain {fqdn}: {fingerprint}")


@app.post("/check")
async def check_fingerprints(
    fingerprints: Dict[str, List[str]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Report which hostnames have fingerprints that are not yet stored.

    ``fingerprints`` maps each hostname to the certificate fingerprints the client saw for it.
    For each hostname, the certificates whose fingerprint is already stored are looked up. When at
    least one requested fingerprint is unknown, the known certificates are linked to the hostname's
    Domain (if not linked already), and ``response[hostname]`` is set to True. Hostnames whose
    fingerprints are all known are omitted from the response.
    """
    client_host = _client_host(request)
    logger.info(f"Received request to check fingerprints from client {client_host}")
    response_dict = {}
    for hostname, fps in fingerprints.items():
        logger.info(f"Received {len(fps)} fingerprints to check from client {client_host} for host {hostname}")
        # De-duplicate (order-preserving) so repeated fingerprints don't skew the known/requested count.
        unique_fps = list(dict.fromkeys(fps))
        # Select exactly the stored certificates among the requested fingerprints, with their
        # domains eagerly loaded so the membership test below needs no lazy IO.
        stmt = (
            select(Certificate)
            .where(Certificate.fingerprint.in_(unique_fps))
            .options(selectinload(Certificate.domains))
        )
        try:
            result = await session.execute(stmt)
        except Exception:
            logger.error(f"Failed to execute stmt: {stmt} (host {hostname}) and IP address: {client_host}")
            raise_HTTPException()
        certificates = result.scalars().all()
        logger.info(
            f"Found {len(certificates)} certificates (of {len(unique_fps)} requested) in the database for client {client_host}"
        )
        if len(certificates) == len(unique_fps):
            continue

        domain = await _resolve_domain(session, hostname)
        if domain is not None:
            fqdn = domain.fqdn  # read before commit, which may expire the instance
            for certificate in certificates:
                if fqdn not in {linked.fqdn for linked in certificate.domains}:
                    certificate.domains.append(domain)
                    session.add(certificate)
            try:
                await session.commit()
            except Exception:
                await session.rollback()
                logger.error(f"Failed to add domain mappings for {fqdn} and IP address: {client_host}")
                raise_HTTPException()
            logger.info(f"Added mappings between {fqdn} up to {len(unique_fps)} certificates in the database.")
        response_dict[hostname] = True
    return response_dict


@app.post("/new")
async def new_fingerprints(
    fingerprints: Dict[str, Dict[str, List[int]]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Store certificates that are not in the database yet and link them to their hostname.

    ``fingerprints`` maps hostname -> {fingerprint: raw certificate data as a list of ints}. The
    latter is presumably the DER encoding, given ``decode_der``; verify against the client.
    Fingerprints that are already stored are skipped. The rest are decoded with ``decode_der`` and
    inserted, linked to the hostname's Domain when the hostname parses (an existing Domain row is
    reused).

    Returns ``{"status": "OK"}``. Database failures are logged, and ``raise_HTTPException()`` is called.
    """
    client_host = _client_host(request)
    for hostname, certs in fingerprints.items():
        fps = list(certs)
        # Build the statement outside the ``try`` so it is always bound when logged on failure.
        stmt = select(Certificate).where(Certificate.fingerprint.in_(fps))
        try:
            result = await session.execute(stmt)
        except Exception:
            logger.error(f"Failed to execute stmt: {stmt} (host {hostname}) and IP address: {client_host}")
            raise_HTTPException()
        logger.info(f"Received {len(certs)} fingerprints to add from client {client_host} for host {hostname}")
        existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
        new_items = [(fp, raw_der) for fp, raw_der in certs.items() if fp not in existing_fingerprints]
        if not new_items:
            continue

        domain = await _resolve_domain(session, hostname)
        certificates_to_add = []
        for fp, raw_der in new_items:
            certificate = Certificate.from_orm(decode_der(fp, raw_der))
            # Never append None: an unparseable hostname leaves the certificate unlinked.
            if domain is not None:
                certificate.domains.append(domain)
            certificates_to_add.append(certificate)
        added_fps = [fp for fp, _ in new_items]
        try:
            session.add_all(certificates_to_add)
            await session.commit()
        except Exception:
            await session.rollback()
            logger.error(
                f"Failed to add certificates to db: {added_fps} after stmt: {stmt} (host {hostname}) and IP address: {client_host}"
            )
            raise_HTTPException()
    return {"status": "OK"}