import asyncio
import datetime
from collections import defaultdict
from functools import cache
from typing import Callable

from cryptography import x509
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import and_
from sqlmodel import func
from sqlmodel import select

from crowdtls.db import get_session
from crowdtls.db import session_maker
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
from crowdtls.models import AnomalyTypes
from crowdtls.models import Certificate
from crowdtls.models import CertificateAnomalyFlagsLink
from crowdtls.models import DomainCertificateLink
from crowdtls.scheduler import app as schedule

# When editing, also add anomaly types to database in db.py.
ANOMALY_HTTP_CODE = {"multiple_cas": 250, "short_lifespan": 251, "many_sans": 252}

# Scanner thresholds. All comparisons are strict ("more than" / "shorter than").
MULTIPLE_CAS_CERT_LIMIT = 10  # distinct live certificates per FQDN
SHORT_LIFESPAN_LIMIT = datetime.timedelta(weeks=3)  # not_valid_after - not_valid_before
MANY_SANS_LIMIT = 80  # Subject Alternative Name entries per certificate

# Registry of scanner async-generator functions -> priority (lower runs first).
ANOMALY_SCANNERS: dict[Callable, int] = {}

# Anomaly types are looked up by response code and never change at runtime,
# so successful lookups are memoised here. The key space is the handful of
# codes in ANOMALY_HTTP_CODE, so the dict stays tiny.
_ANOMALY_TYPE_CACHE: dict[int, AnomalyTypes] = {}


def anomaly_scanner(priority: int):
    """Register the decorated async generator as an anomaly scanner.

    Args:
        priority: Ordering key for ``find_anomalies``; lower values run first.

    Returns:
        A decorator that records the function in ``ANOMALY_SCANNERS`` and
        returns it unchanged.
    """

    def decorator(func: Callable):
        ANOMALY_SCANNERS[func] = priority
        return func

    return decorator


def _currently_valid(now: datetime.datetime):
    """Return a SQL condition matching certificates whose validity window strictly contains *now*."""
    return and_(
        Certificate.not_valid_after > now,
        Certificate.not_valid_before < now,
    )


@anomaly_scanner(priority=1)
async def check_certs_for_fqdn():
    """Flag FQDNs linked to an unusually large number of currently-valid certificates.

    Yields one ``multiple_cas`` anomaly dict per FQDN that is linked (through
    ``DomainCertificateLink``) to more than ``MULTIPLE_CAS_CERT_LIMIT`` distinct
    unexpired certificates. ``certificates`` holds those certificates' fingerprints.
    """
    now = datetime.datetime.utcnow()

    # FQDNs whose number of *distinct* live certificates exceeds the limit.
    # Grouping must be by FQDN alone; grouping by (fqdn, fingerprint) counts
    # roughly one link per group and never crosses the threshold.
    crowded_fqdns = (
        select(DomainCertificateLink.fqdn)
        .join(Certificate, Certificate.fingerprint == DomainCertificateLink.fingerprint)
        .where(_currently_valid(now))
        .group_by(DomainCertificateLink.fqdn)
        .having(func.count(Certificate.fingerprint.distinct()) > MULTIPLE_CAS_CERT_LIMIT)
    )
    # The (fqdn, fingerprint) pairs for those FQDNs, so the anomaly can list the certificates.
    query = (
        select(DomainCertificateLink.fqdn, Certificate.fingerprint)
        .join(Certificate, Certificate.fingerprint == DomainCertificateLink.fingerprint)
        .where(_currently_valid(now), DomainCertificateLink.fqdn.in_(crowded_fqdns))
        .distinct()
    )
    logger.debug(query)

    # Fetch everything, then release the session before yielding to the consumer.
    async with session_maker() as session:
        rows = (await session.execute(query)).all()

    fingerprints_by_fqdn = defaultdict(list)
    for fqdn, fingerprint in rows:
        fingerprints_by_fqdn[fqdn].append(fingerprint)

    for fqdn, fingerprints in fingerprints_by_fqdn.items():
        yield {
            "anomaly": "multiple_cas",
            "details": f"Unusually high number of certificates for FQDN {fqdn}.",
            "certificates": fingerprints,
        }


@anomaly_scanner(priority=2)
async def check_cert_lifespan():
    """Flag unexpired certificates whose total validity period is shorter than ``SHORT_LIFESPAN_LIMIT``.

    Yields one ``short_lifespan`` anomaly dict per matching certificate fingerprint.
    """
    now = datetime.datetime.utcnow()
    query = (
        select(Certificate.fingerprint)
        .where(
            and_(
                Certificate.not_valid_after - Certificate.not_valid_before < SHORT_LIFESPAN_LIMIT,
                _currently_valid(now),
            )
        )
        .distinct()
    )
    async with session_maker() as session:
        # Single-column select: scalars() yields the fingerprint values themselves.
        fingerprints = (await session.execute(query)).scalars().all()

    for fingerprint in fingerprints:
        yield {
            "anomaly": "short_lifespan",
            "details": f"Unusually short lifespan for certificate {fingerprint}.",
            "certificates": [fingerprint],
        }


@anomaly_scanner(priority=4)
async def check_cert_sans():
    """Flag unexpired certificates carrying more than ``MANY_SANS_LIMIT`` Subject Alternative Names.

    Each stored DER blob is parsed with ``cryptography``. Certificates without a
    SAN extension are skipped (they cannot have many SANs). Unparseable blobs
    are logged and skipped so that one bad row does not abort the scan.
    Yields one ``many_sans`` anomaly dict per matching certificate.
    """
    now = datetime.datetime.utcnow()
    query = (
        select(Certificate.fingerprint, Certificate.raw_der_certificate)
        .where(_currently_valid(now))
        .distinct()
    )
    async with session_maker() as session:
        # Two columns are needed, so use all() (rows), not scalars() (first column only).
        rows = (await session.execute(query)).all()

    for fingerprint, raw_der in rows:
        try:
            cert = x509.load_der_x509_certificate(raw_der)
            sans = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
        except x509.ExtensionNotFound:
            continue
        except ValueError:
            logger.warning(f"Could not parse certificate {fingerprint}; skipping SAN check.")
            continue

        num_sans = len(sans)
        if num_sans > MANY_SANS_LIMIT:
            yield {
                "anomaly": "many_sans",
                "details": f"Unusually high number of SANs ({num_sans}) for certificate {fingerprint}.",
                "certificates": [fingerprint],
            }


async def get_anomaly_type(response_code: int):
    """Return the ``AnomalyTypes`` row whose ``response_code`` matches, or ``None``.

    Successful lookups are memoised in ``_ANOMALY_TYPE_CACHE``. ``functools.cache``
    cannot be used on a coroutine function: it would cache the coroutine object,
    and awaiting it a second time raises ``RuntimeError``. Misses are not cached,
    so a type added to the database later is picked up.
    """
    cached = _ANOMALY_TYPE_CACHE.get(response_code)
    if cached is not None:
        return cached

    async with session_maker() as session:
        result = await session.execute(
            select(AnomalyTypes).where(AnomalyTypes.response_code == response_code)
        )
        anomaly_type = result.scalars().first()

    if anomaly_type is not None:
        _ANOMALY_TYPE_CACHE[response_code] = anomaly_type
    return anomaly_type


async def create_anomaly_flag(anomaly, session: AsyncSession = Depends(get_session)):
    """Persist one scanner result as an ``AnomalyFlags`` row plus certificate links.

    Args:
        anomaly: Dict produced by a scanner, with keys ``anomaly`` (a key of
            ``ANOMALY_HTTP_CODE``), ``details`` (str) and ``certificates``
            (list of certificate fingerprints).
        session: Kept for interface compatibility and not used. When this is
            called directly (not through FastAPI), the default is an unresolved
            ``Depends`` marker, so a dedicated session is opened instead.
    """
    anomaly_type = await get_anomaly_type(ANOMALY_HTTP_CODE[anomaly["anomaly"]])
    if anomaly_type is None:
        logger.error(f"Anomaly type {anomaly['anomaly']!r} is not registered in the database; not flagging.")
        return

    now = datetime.datetime.utcnow()
    async with session_maker() as db:
        flag = AnomalyFlags(
            details=anomaly["details"],
            anomaly_type_id=anomaly_type.id,
            date_flagged=now,
        )
        db.add(flag)
        # Flush so the database assigns flag.id before the link rows reference it.
        await db.flush()
        # NOTE(review): anomaly_flag_id now references the new AnomalyFlags row.
        # The previous code stored the anomaly *type* id here; confirm the FK in models.
        db.add_all(
            [
                CertificateAnomalyFlagsLink(
                    certificate_fingerprint=fingerprint,
                    anomaly_flag_id=flag.id,
                    first_linked=now,
                )
                for fingerprint in anomaly["certificates"]
            ]
        )
        await db.commit()


@schedule.task("every 30 seconds")
async def find_anomalies():
    """Run every registered anomaly scanner in ascending priority order and persist its findings."""
    # Initial delay before touching the database; presumably lets application
    # start-up finish first. TODO confirm.
    await asyncio.sleep(3)
    for scanner in sorted(ANOMALY_SCANNERS, key=ANOMALY_SCANNERS.get):
        logger.info(f"Running {scanner.__name__}")
        try:
            async for anomaly in scanner():
                logger.info(f"{scanner.__name__} found {anomaly['anomaly']}")
                await create_anomaly_flag(anomaly)
        except Exception:
            # Top-level scheduled boundary: log it, then let the remaining scanners run.
            logger.exception(f"{scanner.__name__} failed")