import asyncio
import datetime
from typing import Callable

from asyncstdlib.functools import lru_cache
from cryptography import x509
from sqlalchemy.orm import selectinload
from sqlmodel import func
from sqlmodel import select

from crowdtls.db import session_maker
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
from crowdtls.models import AnomalyTypes
from crowdtls.models import Certificate
from crowdtls.models import CertificateAnomalyFlagsLink
from crowdtls.models import DomainCertificateLink
from crowdtls.scheduler import app as schedule

# Maps the anomaly names yielded by the scanners below to the response_code of
# the matching AnomalyTypes row.
# When editing, also add anomaly types to database in db.py.
ANOMALY_HTTP_CODE = {"multiple_cas": 250, "short_lifespan": 251, "many_sans": 252}

# Registry of scanner functions -> priority, populated by @anomaly_scanner.
ANOMALY_SCANNERS = {}


def anomaly_scanner(priority: int):
    """Register the decorated async generator function as an anomaly scanner.

    ``find_anomalies`` runs registered scanners in ascending ``priority`` order.
    The decorated function itself is returned unchanged.
    """

    # Named ``scanner`` (not ``func``) to avoid shadowing sqlmodel's ``func``.
    def decorator(scanner: Callable):
        ANOMALY_SCANNERS[scanner] = priority
        return scanner

    return decorator


@lru_cache(maxsize=len(ANOMALY_HTTP_CODE))
async def get_anomaly_type(response_code: int):
    """Return the id of the AnomalyTypes row whose response_code matches.

    The result is cached (asyncstdlib ``lru_cache``, one slot per known anomaly
    type), so the database is only queried on the first lookup of each code.
    Raises if no matching row exists (``Result.one()``).
    """
    async with session_maker() as session:
        query = select(AnomalyTypes.id).where(AnomalyTypes.response_code == response_code).limit(1)
        return (await session.execute(query)).scalars().one()


async def anomaly_exists(name: str, anomalies: list):
    """Check if a given anomaly type exists in a list of anomalies.

    ``name`` must be a key of ``ANOMALY_HTTP_CODE``; each element of
    ``anomalies`` must expose ``anomaly_type_id`` (e.g. ``Certificate.anomalies``).
    """
    anomaly_id = await get_anomaly_type(ANOMALY_HTTP_CODE[name])
    return any(a.anomaly_type_id == anomaly_id for a in anomalies)


@anomaly_scanner(priority=1)
async def check_certs_for_fqdn():
    """Check certificates for a given FQDN for anomalies.

    Yields one ``multiple_cas`` anomaly per FQDN that is linked to more than 10
    distinct certificates, unless any of those certificates already carries a
    ``multiple_cas`` flag.
    """
    # Group by fqdn only: grouping by (fqdn, fingerprint) would count each pair
    # separately, so the HAVING clause could never exceed 1 per group.
    # NOTE(review): expired certificates are not excluded here; add a
    # not_valid_after filter if only unexpired certificates should count.
    fqdn_query = (
        select(DomainCertificateLink.fqdn)
        .group_by(DomainCertificateLink.fqdn)
        .having(func.count(DomainCertificateLink.fingerprint.distinct()) > 10)
    )
    logger.info(fqdn_query)

    candidates = []
    async with session_maker() as session:
        fqdns = (await session.execute(fqdn_query)).scalars().all()
        for fqdn in fqdns:
            # IN-subquery rather than a join so a certificate is returned once
            # even if the link table holds duplicate rows for the same pair.
            cert_query = (
                select(Certificate)
                .options(selectinload(Certificate.anomalies))
                .where(
                    Certificate.fingerprint.in_(
                        select(DomainCertificateLink.fingerprint).where(DomainCertificateLink.fqdn == fqdn)
                    )
                )
            )
            certificates = (await session.execute(cert_query)).scalars().all()
            candidates.append((fqdn, certificates))

    if not candidates:
        return

    # Resolve the type id once so the "already flagged" test can be a plain
    # synchronous any(); any() cannot consume an async generator expression.
    multiple_cas_id = await get_anomaly_type(ANOMALY_HTTP_CODE["multiple_cas"])
    for fqdn, certificates in candidates:
        if any(a.anomaly_type_id == multiple_cas_id for cert in certificates for a in cert.anomalies):
            continue
        yield {
            "anomaly": "multiple_cas",
            "details": f"Unusually high number of certificates for FQDN {fqdn}.",
            "certificates": [cert.fingerprint for cert in certificates],
        }


@anomaly_scanner(priority=2)
async def check_cert_lifespan():
    """Check the lifespan of a certificate.

    Yields one ``short_lifespan`` anomaly per certificate whose validity window
    (not_valid_after - not_valid_before) is shorter than three weeks and which
    is not already flagged ``short_lifespan``.
    """
    # NOTE(review): expired certificates are not excluded by this query.
    query = (
        select(Certificate)
        .options(selectinload(Certificate.anomalies))
        .where(Certificate.not_valid_after - Certificate.not_valid_before < datetime.timedelta(weeks=3))
    )
    async with session_maker() as session:
        certificates = (await session.execute(query)).scalars().all()

    for certificate in certificates:
        if await anomaly_exists("short_lifespan", certificate.anomalies):
            continue
        yield {
            "anomaly": "short_lifespan",
            "details": f"Unusually short lifespan observed. Check certificate for {certificate.subject}",
            "certificates": [certificate.fingerprint],
        }


@anomaly_scanner(priority=4)
async def check_cert_sans():
    """Check for a high number of Subject Alternative Names (SANs) in a certificate.

    Parses every stored certificate's DER bytes with ``cryptography.x509`` and
    yields one ``many_sans`` anomaly per certificate with more than 80 SAN
    entries that is not already flagged ``many_sans``. Certificates that cannot
    be parsed are logged and skipped.
    """
    query = select(Certificate).options(selectinload(Certificate.anomalies))
    async with session_maker() as session:
        certificates = (await session.execute(query)).scalars().all()

    for certificate in certificates:
        if await anomaly_exists("many_sans", certificate.anomalies):
            continue
        try:
            certificate_x509 = x509.load_der_x509_certificate(certificate.raw_der_certificate)
        except Exception as e:
            logger.error(f"Error loading certificate {certificate.fingerprint}: {e}")
            continue
        try:
            num_sans = len(certificate_x509.extensions.get_extension_for_class(x509.SubjectAlternativeName).value)
        except x509.ExtensionNotFound:
            num_sans = 0
        except Exception as e:
            # Malformed/duplicate extensions raise while parsing; skip this
            # certificate instead of aborting the scan of all remaining ones.
            logger.error(f"Error reading SANs of certificate {certificate.fingerprint}: {e}")
            continue
        if num_sans > 80:
            logger.info(f"Certificate has {num_sans} SANs.")
            yield {
                "anomaly": "many_sans",
                "details": f"Unusually high number of SANs ({num_sans}) observed.",
                "certificates": [certificate.fingerprint],
            }


async def create_anomaly_flag(anomaly):
    """Persist a scanner-reported anomaly.

    ``anomaly`` is a dict as yielded by the scanners (``anomaly``, ``details``,
    ``certificates``). One AnomalyFlags row is created per certificate
    fingerprint, each linked via CertificateAnomalyFlagsLink; everything is
    committed in a single transaction.
    """
    logger.info("Creating anomaly flag.")
    anomaly_id = await get_anomaly_type(ANOMALY_HTTP_CODE[anomaly["anomaly"]])
    logger.info(f"Creating anomaly flag for {anomaly_id}.")
    async with session_maker() as session:
        for fingerprint in anomaly["certificates"]:
            logger.info(f"Creating anomaly flag for certificate {fingerprint}.")
            new_anomaly = AnomalyFlags(
                details=anomaly["details"],
                anomaly_type_id=anomaly_id,
            )
            session.add(new_anomaly)
            # Flush so the database assigns new_anomaly.id before linking it.
            await session.flush()
            session.add(
                CertificateAnomalyFlagsLink(
                    certificate_fingerprint=fingerprint,
                    anomaly_flag_id=new_anomaly.id,
                )
            )
        await session.commit()


@schedule.task("every 30 seconds")
async def find_anomalies():
    """Run all registered anomaly scanners in priority order.

    Each anomaly a scanner yields is persisted via ``create_anomaly_flag``. An
    exception in one scanner is logged and the next scanner still runs.
    """
    # NOTE(review): purpose of this fixed delay is not evident here; presumably
    # it lets startup (e.g. DB seeding) finish — confirm.
    await asyncio.sleep(3)
    for analytic, priority in sorted(ANOMALY_SCANNERS.items(), key=lambda x: x[1]):
        logger.info(f"Running {analytic.__name__}")
        try:
            async for anomaly in analytic():
                logger.warning(f"{analytic.__name__} found {anomaly['anomaly']}")
                await create_anomaly_flag(anomaly)
        except Exception as e:
            logger.error(f"Error running {analytic.__name__}: {e}")