Review notes. Text before the first "diff --git" line is commentary, not patch content.

This patch marks three analytics as complete in README.md, adds the anomaly
scanners and their scheduler task in crowdtls/analytics/main.py, seeds the
AnomalyTypes table in crowdtls/db.py, and renames Domain.root to
Domain.domain_root and AnomalyTypes.anomalyString to AnomalyTypes.description.

Changes made to the added code during review:
- The scanners unpack result rows; they used to read attributes off
  scalars(), which only yields the first selected column.
- check_certs_for_fqdn counts distinct certificates per FQDN; it used to group
  by FQDN and fingerprint, which counts link rows per certificate instead.
- check_cert_sans skips certificates without a SAN extension, and the
  thresholds are named constants (a comment said 8 while the code used 80).
- get_anomaly_type uses a dict cache and AsyncSession.execute; functools.cache
  on a coroutine function caches a coroutine that can only be awaited once.
- create_anomaly_flag calls datetime.datetime.utcnow, inserts through the
  model tables, and links certificates to the new AnomalyFlags row.
- crowdtls/db.py seeds AnomalyTypes inside the engine.begin() block.
- crowdtls/main.py no longer adds the bare "app_rocketry.task" expression.
- Hunk line counts were updated to match. Blank context lines and indentation
  were reconstructed, and the "index" hashes are those of the original patch.

diff --git a/README.md b/README.md
index ada6498..e7fda4b 100644
--- a/README.md
+++ b/README.md
@@ -33,11 +33,11 @@ Below is an enumeration of analytics that are run on the resulting data set to t
 
 | Analytic Name | Description | Completeness |
 | --- | --- | --- |
-| Multiple Active Certificates | Flag an unusually high number of active certificates for a single FQDN, especially if they're from multiple CAs. | ❌ |
-| Certificate Lifespan Analysis | Flag certificates with unusually short or long lifespans. | ❌ |
+| Multiple Active Certificates | Flag an unusually high number of active certificates for a single FQDN, especially if they're from multiple CAs. | ✅ |
+| Certificate Lifespan Analysis | Flag certificates with unusually short or long lifespans. | ✅ |
 | Changes in Certificate Details | Track historical data of certificates for each FQDN and flag abrupt changes. | ❌ |
 | Certificates from Untrusted CAs | Flag certificates issued by untrusted or less common CAs. | ❌ |
-| Uncommon SAN Usage | Flag certificates with an unusually high number of SAN entries. | ❌ |
+| Uncommon SAN Usage | Flag certificates with an unusually high number of SAN entries. | ✅ |
 | Use of Deprecated or Weak Encryption | Flag certificates that use deprecated or weak cryptographic algorithms. | ❌ |
 | New Certificate Detection | Alert users when a certificate for a known domain changes unexpectedly. | ❌ |
 | Mismatched Issuer and Subject | Flag certificates where the issuer and subject fields match (self-signed) and are not trusted roots. | ❌ |
diff --git a/crowdtls/analytics/main.py b/crowdtls/analytics/main.py
index 85aa802..8588158 100644
--- a/crowdtls/analytics/main.py
+++ b/crowdtls/analytics/main.py
@@ -1,11 +1,192 @@
+import asyncio
+import datetime
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from cryptography import x509
-from fastapi import Depends
-from rocketry.conds import hourly
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import and_
+from sqlmodel import func
+from sqlmodel import select
 
-from crowdtls.db import get_session
+from crowdtls.db import session_maker
+from crowdtls.logs import logger
+from crowdtls.models import AnomalyFlags
+from crowdtls.models import AnomalyTypes
+from crowdtls.models import Certificate
+from crowdtls.models import CertificateAnomalyFlagsLink
+from crowdtls.models import DomainCertificateLink
 from crowdtls.scheduler import app as schedule
 
+# When editing, also add anomaly types to database in db.py.
+ANOMALY_HTTP_CODE = {"multiple_cas": 250, "short_lifespan": 251, "many_sans": 252}
 
-@schedule.task(hourly)
-async def find_anomalies(session: AsyncSession = Depends(get_session)):
-    pass
+# Detection thresholds; keep the descriptions seeded in db.py in sync.
+MAX_ACTIVE_CERTS_PER_FQDN = 10
+MIN_CERT_LIFESPAN = datetime.timedelta(weeks=3)
+MAX_SAN_ENTRIES = 80
+
+# Scanner coroutine functions mapped to their priority (lowest runs first).
+ANOMALY_SCANNERS = {}
+
+# AnomalyTypes rows keyed by response code, filled by get_anomaly_type().
+_ANOMALY_TYPES: Dict[int, AnomalyTypes] = {}
+
+
+def anomaly_scanner(priority: int):
+    """Register the decorated scanner in ANOMALY_SCANNERS with *priority*."""
+
+    def decorator(scanner: Callable):
+        ANOMALY_SCANNERS[scanner] = priority
+        return scanner
+
+    return decorator
+
+
+@anomaly_scanner(priority=1)
+async def check_certs_for_fqdn():
+    """Yield an anomaly for each FQDN with too many active certificates."""
+    now = datetime.datetime.utcnow()
+    is_active = and_(Certificate.not_valid_after > now, Certificate.not_valid_before < now)
+    on_fingerprint = Certificate.fingerprint == DomainCertificateLink.fingerprint
+
+    # FQDNs linked to more than MAX_ACTIVE_CERTS_PER_FQDN currently valid certificates.
+    busy_fqdns = (
+        select(DomainCertificateLink.fqdn)
+        .join(Certificate, on_fingerprint)
+        .where(is_active)
+        .group_by(DomainCertificateLink.fqdn)
+        .having(func.count(Certificate.fingerprint.distinct()) > MAX_ACTIVE_CERTS_PER_FQDN)
+    )
+    # One row per (FQDN, fingerprint) pair so every anomaly can list its certificates.
+    query = (
+        select(DomainCertificateLink.fqdn, Certificate.fingerprint)
+        .join(Certificate, on_fingerprint)
+        .where(and_(is_active, DomainCertificateLink.fqdn.in_(busy_fqdns)))
+        .distinct()
+    )
+
+    async with session_maker() as session:
+        rows = (await session.execute(query)).all()
+
+    fingerprints_by_fqdn: Dict[str, List[str]] = {}
+    for fqdn, fingerprint in rows:
+        fingerprints_by_fqdn.setdefault(fqdn, []).append(fingerprint)
+
+    for fqdn, fingerprints in fingerprints_by_fqdn.items():
+        yield {
+            "anomaly": "multiple_cas",
+            "details": f"Unusually high number of certificates for FQDN {fqdn}.",
+            "certificates": fingerprints,
+        }
+
+
+@anomaly_scanner(priority=2)
+async def check_cert_lifespan():
+    """Yield an anomaly for each active certificate with a lifespan under 3 weeks."""
+    now = datetime.datetime.utcnow()
+    query = (
+        select(Certificate.fingerprint)
+        .where(
+            and_(
+                Certificate.not_valid_after - Certificate.not_valid_before < MIN_CERT_LIFESPAN,
+                Certificate.not_valid_after > now,
+                Certificate.not_valid_before < now,
+            )
+        )
+        .distinct()
+    )
+
+    async with session_maker() as session:
+        # Only one column is selected, so scalars() yields the fingerprints themselves.
+        fingerprints = (await session.execute(query)).scalars().all()
+
+    for fingerprint in fingerprints:
+        yield {
+            "anomaly": "short_lifespan",
+            "details": f"Unusually short lifespan for certificate {fingerprint}.",
+            "certificates": [fingerprint],
+        }
+
+
+@anomaly_scanner(priority=4)
+async def check_cert_sans():
+    """Yield an anomaly for each active certificate with too many SAN entries."""
+    now = datetime.datetime.utcnow()
+    query = (
+        select(Certificate.fingerprint, Certificate.raw_der_certificate)
+        .where(and_(Certificate.not_valid_after > now, Certificate.not_valid_before < now))
+        .distinct()
+    )
+
+    async with session_maker() as session:
+        rows = (await session.execute(query)).all()
+
+    for fingerprint, raw_der_certificate in rows:
+        cert = x509.load_der_x509_certificate(raw_der_certificate)
+        try:
+            sans = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
+        except x509.ExtensionNotFound:
+            # A certificate without a SAN extension cannot have too many entries.
+            continue
+        if (num_sans := len(sans)) > MAX_SAN_ENTRIES:
+            yield {
+                "anomaly": "many_sans",
+                "details": f"Unusually high number of SANs ({num_sans}) for certificate {fingerprint}.",
+                "certificates": [fingerprint],
+            }
+
+
+async def get_anomaly_type(response_code: int) -> Optional[AnomalyTypes]:
+    """Return the AnomalyTypes row for *response_code*, or None if it is not in the database."""
+    if response_code not in _ANOMALY_TYPES:
+        query = select(AnomalyTypes).where(AnomalyTypes.response_code == response_code)
+        async with session_maker() as session:
+            anomaly_type = (await session.execute(query)).scalars().first()
+        if anomaly_type is None:
+            return None
+        # functools.cache would store the coroutine, which can only be awaited once.
+        _ANOMALY_TYPES[response_code] = anomaly_type
+    return _ANOMALY_TYPES[response_code]
+
+
+async def create_anomaly_flag(anomaly, session: Optional[AsyncSession] = None):
+    """Store *anomaly* as one AnomalyFlags row linked to each of its certificates."""
+    if session is None:
+        # Called from the scheduler, outside FastAPI's dependency injection.
+        async with session_maker() as own_session:
+            await create_anomaly_flag(anomaly, own_session)
+        return
+
+    anomaly_type = await get_anomaly_type(ANOMALY_HTTP_CODE[anomaly["anomaly"]])
+    if anomaly_type is None:
+        logger.error(f"Anomaly type {anomaly['anomaly']} is missing from the database.")
+        return
+
+    now = datetime.datetime.utcnow()
+    inserted = await session.execute(
+        AnomalyFlags.__table__.insert(),
+        {"details": anomaly["details"], "anomaly_type_id": anomaly_type.id, "date_flagged": now},
+    )
+    # Link certificates to the new flag row, not to the anomaly type.
+    anomaly_flag_id = inserted.inserted_primary_key[0]
+    links = [
+        {"certificate_fingerprint": fingerprint, "anomaly_flag_id": anomaly_flag_id, "first_linked": now}
+        for fingerprint in anomaly["certificates"]
+    ]
+    if links:
+        await session.execute(CertificateAnomalyFlagsLink.__table__.insert(), links)
+    await session.commit()
+
+
+@schedule.task("every 30 seconds")
+async def find_anomalies():
+    """Run all registered anomaly scanners in priority order"""
+    await asyncio.sleep(3)
+    for scanner in sorted(ANOMALY_SCANNERS, key=ANOMALY_SCANNERS.get):
+        logger.info(f"Running {scanner.__name__}")
+        async for anomaly in scanner():
+            logger.info(f"{scanner.__name__} found {anomaly['anomaly']}")
+            await create_anomaly_flag(anomaly)
diff --git a/crowdtls/db.py b/crowdtls/db.py
index 31b674f..27d68a6 100644
--- a/crowdtls/db.py
+++ b/crowdtls/db.py
@@ -5,6 +5,9 @@ from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import sessionmaker
 from sqlmodel import SQLModel
 
+from crowdtls.logs import logger
+from crowdtls.models import AnomalyTypes
+
 
 DATABASE_URL = (
     f'postgresql+asyncpg://{os.environ.get("POSTGRES_USER")}:'
@@ -14,6 +17,7 @@ DATABASE_URL = (
 )
 
 engine = create_async_engine(DATABASE_URL, echo=True, future=True)
+session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=True)
 
 
 async def get_session() -> AsyncSession:
@@ -22,7 +26,15 @@ async def get_session() -> AsyncSession:
         yield session
 
 
-async def create_db_and_tables():
+async def create_db_and_tables() -> None:
     async with engine.begin() as conn:
         await conn.run_sync(SQLModel.metadata.drop_all)
         await conn.run_sync(SQLModel.metadata.create_all)
+
+        logger.info("Populating AnomalyTypes table")
+        anomalies = [
+            dict(response_code=250, description="Too many active certificates >10"),
+            dict(response_code=251, description="Lifespan too short <3 weeks"),
+            dict(response_code=252, description="Too many SANs >80"),
+        ]
+        await conn.execute(AnomalyTypes.__table__.insert(), anomalies)
diff --git a/crowdtls/helpers.py b/crowdtls/helpers.py
index 57a73db..ba6341f 100644
--- a/crowdtls/helpers.py
+++ b/crowdtls/helpers.py
@@ -36,7 +36,7 @@ def decode_der(fingerprint: str, raw_der_certificate: List[int]) -> Certificate:
 def parse_hostname(hostname: str) -> Domain:
     try:
         parsed_domain = tldextract.extract(hostname)
-        return Domain(fqdn=hostname, root=parsed_domain.domain, tld=parsed_domain.suffix)
+        return Domain(fqdn=hostname, domain_root=parsed_domain.domain, tld=parsed_domain.suffix)
     except Exception:
         logger.error(f"Failed to parse hostname: {hostname}")
 
diff --git a/crowdtls/main.py b/crowdtls/main.py
index e72772a..a1fde3c 100644
--- a/crowdtls/main.py
+++ b/crowdtls/main.py
@@ -43,6 +43,9 @@ def run(env: Path) -> None:
 
     from crowdtls.scheduler import app as app_rocketry
     from crowdtls.webserver import app as app_fastapi
+    from crowdtls.analytics.main import ANOMALY_SCANNERS
+
+    logger.info(f"Anomaly scanners loaded: {', '.join(x.__name__ for x in ANOMALY_SCANNERS)}")
 
     if sys.version_info >= (3, 11):
         with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
diff --git a/crowdtls/models.py b/crowdtls/models.py
index 3f066cf..b3ba158 100644
--- a/crowdtls/models.py
+++ b/crowdtls/models.py
@@ -23,7 +23,7 @@ class CertificateAnomalyFlagsLink(SQLModel, table=True):
 
 class Domain(SQLModel, table=True):
     fqdn: str = Field(primary_key=True)
-    root: str
+    domain_root: str
     tld: str
     first_seen: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
     last_seen: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
@@ -70,7 +70,7 @@ def certificate_loaded(target, context):
 class AnomalyTypes(SQLModel, table=True):
     id: int = Field(primary_key=True)
     response_code: int
-    anomalyString: str
+    description: str
 
 
 class AnomalyFlags(SQLModel, table=True):