CrowdTLS-server/crowdtls/analytics/main.py
2023-06-18 08:48:03 -07:00

172 lines
6.4 KiB
Python

import asyncio
import datetime
from typing import Callable
from asyncstdlib.functools import lru_cache
from cryptography import x509
from sqlalchemy.orm import selectinload
from sqlmodel import func
from sqlmodel import select
from crowdtls.db import session_maker
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
from crowdtls.models import AnomalyTypes
from crowdtls.models import Certificate
from crowdtls.models import CertificateAnomalyFlagsLink
from crowdtls.models import DomainCertificateLink
from crowdtls.scheduler import app as schedule
# Anomaly name -> response code used to look the anomaly type up in the database.
# When editing, also add anomaly types to database in db.py.
ANOMALY_HTTP_CODE = {
    "multiple_cas": 250,
    "short_lifespan": 251,
    "many_sans": 252,
}

# Registry of scanner callables mapped to their priority; populated by
# the @anomaly_scanner decorator.
ANOMALY_SCANNERS = {}
def anomaly_scanner(priority: int):
    """Build a decorator that registers a scanner under ``priority``.

    The decorated callable becomes a key of ``ANOMALY_SCANNERS`` with
    ``priority`` as its value, and is handed back unchanged so the module
    attribute still refers to the original function.
    """

    def register(scanner: Callable):
        ANOMALY_SCANNERS.update({scanner: priority})
        return scanner

    return register
@lru_cache(maxsize=len(ANOMALY_HTTP_CODE))
async def get_anomaly_type(response_code: int):
    """Return the ``AnomalyTypes.id`` whose ``response_code`` matches.

    Results are memoised by ``lru_cache``; the cache is sized to the number
    of entries in ``ANOMALY_HTTP_CODE``. The lookup uses ``.one()``, so it
    raises when no matching row exists.
    """
    statement = (
        select(AnomalyTypes.id)
        .where(AnomalyTypes.response_code == response_code)
        .limit(1)
    )
    async with session_maker() as session:
        outcome = await session.execute(statement)
        return outcome.scalars().one()
async def anomaly_exists(name: str, anomalies: list):
    """Check if a given anomaly type exists in a list of anomalies.

    Args:
        name: Key of ``ANOMALY_HTTP_CODE`` naming the anomaly type.
        anomalies: Objects exposing an ``anomaly_type_id`` attribute.

    Returns:
        True when at least one entry's ``anomaly_type_id`` equals the id
        resolved for ``name``; False otherwise (including an empty list).

    Raises:
        KeyError: If ``name`` is not a key of ``ANOMALY_HTTP_CODE``.
    """
    anomaly_id = await get_anomaly_type(ANOMALY_HTTP_CODE[name])
    return any(a.anomaly_type_id == anomaly_id for a in anomalies)
@anomaly_scanner(priority=1)
async def check_certs_for_fqdn():
    """Scan for FQDNs with an unusually high number of certificates.

    Yields:
        One dict per affected FQDN with the keys ``anomaly`` (always
        ``"multiple_cas"``), ``details`` (human-readable message) and
        ``certificates`` (fingerprints of the certificates returned for that
        FQDN). An FQDN is skipped when any of its certificates already
        carries a ``multiple_cas`` anomaly.
    """
    # Select (fqdn, certificate) pairs whose group has more than 10
    # DomainCertificateLink rows.
    # NOTE(review): grouping by (fqdn, fingerprint) counts link rows per pair
    # rather than certificates per FQDN, and nothing here filters on expiry --
    # confirm the intended semantics against the schema.
    query = (
        select(DomainCertificateLink.fqdn, Certificate)
        .options(selectinload(Certificate.anomalies))
        .join(Certificate, Certificate.fingerprint == DomainCertificateLink.fingerprint)
        .group_by(DomainCertificateLink.fqdn, Certificate.fingerprint)
        .having(func.count(DomainCertificateLink.fqdn) > 10)
    )
    logger.info(query)
    async with session_maker() as session:
        # Each row is a (fqdn, Certificate) pair. `.scalars()` would keep only
        # the first column (the fqdn string), so read the full rows instead.
        rows = (await session.execute(query)).all()

        # Collect the certificates belonging to each FQDN, preserving row order.
        certificates_by_fqdn = {}
        for fqdn, certificate in rows:
            certificates_by_fqdn.setdefault(fqdn, []).append(certificate)

        for fqdn, certificates in certificates_by_fqdn.items():
            # A generator expression containing `await` is an async generator,
            # which `any()` cannot iterate; check each certificate explicitly.
            already_flagged = False
            for certificate in certificates:
                if await anomaly_exists("multiple_cas", certificate.anomalies):
                    already_flagged = True
                    break
            if already_flagged:
                continue
            yield {
                "anomaly": "multiple_cas",
                "details": f"Unusually high number of certificates for FQDN {fqdn}.",
                "certificates": [cert.fingerprint for cert in certificates],
            }
@anomaly_scanner(priority=2)
async def check_cert_lifespan():
    """Flag certificates whose validity window is shorter than three weeks.

    Yields one ``short_lifespan`` anomaly dict (keys ``anomaly``, ``details``,
    ``certificates``) per matching certificate that does not already carry
    that anomaly.
    """
    # Validity window computed in SQL: not_valid_after - not_valid_before.
    lifespan = Certificate.not_valid_after - Certificate.not_valid_before
    shortest_normal = datetime.timedelta(weeks=3)
    statement = (
        select(Certificate)
        .options(selectinload(Certificate.anomalies))
        .where(lifespan < shortest_normal)
    )
    async with session_maker() as session:
        outcome = await session.execute(statement)
        for cert in outcome.scalars().all():
            if not await anomaly_exists("short_lifespan", cert.anomalies):
                yield {
                    "anomaly": "short_lifespan",
                    "details": f"Unusually short lifespan observed. Check certificate for {cert.subject}",
                    "certificates": [cert.fingerprint],
                }
@anomaly_scanner(priority=4)
async def check_cert_sans(san_threshold: int = 80):
    """Check for a high number of Subject Alternative Names (SANs) in a certificate.

    Args:
        san_threshold: A certificate is reported when its SAN count is
            strictly greater than this value. Defaults to 80.

    Yields:
        One ``many_sans`` anomaly dict (keys ``anomaly``, ``details``,
        ``certificates``) per certificate over the threshold that does not
        already carry that anomaly.
    """
    # Load every certificate together with its existing anomalies.
    query = select(Certificate).options(selectinload(Certificate.anomalies))
    async with session_maker() as session:
        certificates = (await session.execute(query)).scalars().all()
        for certificate in certificates:
            if await anomaly_exists("many_sans", certificate.anomalies):
                continue
            # Parse the stored DER bytes; an unparsable certificate is logged
            # and skipped so one bad row does not abort the whole scan.
            try:
                certificate_x509 = x509.load_der_x509_certificate(certificate.raw_der_certificate)
            except Exception as e:
                logger.error(f"Error loading certificate {certificate.fingerprint}: {e}")
                continue
            # A certificate without a SAN extension counts as having zero SANs.
            try:
                san_extension = certificate_x509.extensions.get_extension_for_class(
                    x509.SubjectAlternativeName
                )
            except x509.extensions.ExtensionNotFound:
                num_sans = 0
            else:
                num_sans = len(san_extension.value)
            if num_sans > san_threshold:
                logger.info(f"Certificate has {num_sans} SANs.")
                yield {
                    "anomaly": "many_sans",
                    "details": f"Unusually high number of SANs ({num_sans}) observed.",
                    "certificates": [certificate.fingerprint],
                }
async def create_anomaly_flag(anomaly):
    """Persist an anomaly: one AnomalyFlags row plus link row per certificate.

    ``anomaly`` is a dict with the keys ``anomaly`` (a key of
    ``ANOMALY_HTTP_CODE``), ``details`` and ``certificates`` (an iterable of
    fingerprints). All rows are committed together after the loop.
    """
    logger.info("Creating anomaly flag.")
    type_code = ANOMALY_HTTP_CODE[anomaly["anomaly"]]
    anomaly_id = await get_anomaly_type(type_code)
    logger.info(f"Creating anomaly flag for {anomaly_id}.")
    async with session_maker() as session:
        for fingerprint in anomaly["certificates"]:
            logger.info(f"Creating anomaly flag for certificate {fingerprint}.")
            flag = AnomalyFlags(details=anomaly["details"], anomaly_type_id=anomaly_id)
            session.add(flag)
            # Flush before reading flag.id -- presumably so the generated
            # primary key is populated for the link row below.
            await session.flush()
            link = CertificateAnomalyFlagsLink(
                certificate_fingerprint=fingerprint,
                anomaly_flag_id=flag.id,
            )
            session.add(link)
        await session.commit()
@schedule.task("every 30 seconds")
async def find_anomalies():
    """Run every registered scanner, lowest priority value first.

    Each anomaly a scanner yields is logged and stored through
    ``create_anomaly_flag``. An exception raised by one scanner is logged and
    does not prevent the remaining scanners from running.
    """
    await asyncio.sleep(3)
    # Sorting the registry's keys by their stored priority is stable, so
    # scanners with equal priority keep their registration order.
    ordered_scanners = sorted(ANOMALY_SCANNERS, key=ANOMALY_SCANNERS.get)
    for scanner in ordered_scanners:
        scanner_name = scanner.__name__
        logger.info(f"Running {scanner_name}")
        try:
            async for finding in scanner():
                logger.warning(f"{scanner_name} found {finding['anomaly']}")
                await create_anomaly_flag(finding)
        except Exception as exc:
            logger.error(f"Error running {scanner_name}: {exc}")