CrowdTLS-server/crowdtls/analytics/main.py

171 lines
6 KiB
Python
Raw Normal View History

2023-06-16 17:02:57 -07:00
import asyncio
import datetime
from functools import cache
from typing import Callable
from cryptography import x509
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
2023-06-16 17:02:57 -07:00
from sqlmodel import and_
from sqlmodel import func
from sqlmodel import select
from crowdtls.db import get_session
2023-06-16 17:02:57 -07:00
from crowdtls.db import session_maker
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
from crowdtls.models import AnomalyTypes
from crowdtls.models import Certificate
from crowdtls.models import CertificateAnomalyFlagsLink
from crowdtls.models import DomainCertificateLink
from crowdtls.scheduler import app as schedule
2023-06-16 17:02:57 -07:00
# When editing, also add anomaly types to database in db.py.
# Anomaly name -> numeric code; the code is what gets matched against
# AnomalyTypes.response_code when a flag is stored.
ANOMALY_HTTP_CODE = dict(multiple_cas=250, short_lifespan=251, many_sans=252)
# Scanner callable -> priority. Filled in by the @anomaly_scanner decorator and
# walked in ascending priority order by find_anomalies().
ANOMALY_SCANNERS = dict()
def anomaly_scanner(priority: int):
    """Build a decorator that registers a scanner with the given priority.

    The decorated callable is recorded in ``ANOMALY_SCANNERS`` (callable as
    key, ``priority`` as value) and returned unchanged, so it stays directly
    callable under its own name.
    """

    def register(scanner: Callable):
        ANOMALY_SCANNERS.update({scanner: priority})
        return scanner

    return register
@anomaly_scanner(priority=1)
async def check_certs_for_fqdn():
    """Flag FQDNs that currently have an unusually high number of certificates.

    Yields:
        One dict per FQDN that is linked to more than 10 distinct, currently
        valid certificates. Keys: ``anomaly`` (always ``"multiple_cas"``),
        ``details`` (human-readable text) and ``certificates`` (list of the
        fingerprints linked to that FQDN).
    """
    now = datetime.datetime.utcnow()
    currently_valid = and_(
        Certificate.not_valid_after > now,
        Certificate.not_valid_before < now,
    )
    cert_join = Certificate.fingerprint == DomainCertificateLink.fingerprint

    # FQDNs with more than 10 distinct unexpired certificates. The grouping is
    # by FQDN only: grouping by (fqdn, fingerprint) would count each pair on
    # its own and never measure "certificates per FQDN".
    flagged_fqdns = (
        select(DomainCertificateLink.fqdn)
        .join(Certificate, cert_join)
        .where(currently_valid)
        .group_by(DomainCertificateLink.fqdn)
        .having(func.count(DomainCertificateLink.fingerprint.distinct()) > 10)
    )
    # Fetch the (fqdn, fingerprint) pairs for just those FQDNs.
    query = (
        select(DomainCertificateLink.fqdn, DomainCertificateLink.fingerprint)
        .join(Certificate, cert_join)
        .where(and_(currently_valid, DomainCertificateLink.fqdn.in_(flagged_fqdns)))
        .distinct()
    )
    logger.info(query)

    # Use .all() (row tuples), not .scalars(), which would drop every column
    # but the first. Materialise the rows so the session is released before
    # control is handed back to the consumer.
    async with session_maker() as session:
        rows = (await session.execute(query)).all()

    fingerprints_by_fqdn = {}
    for fqdn, fingerprint in rows:
        fingerprints_by_fqdn.setdefault(fqdn, []).append(fingerprint)

    for fqdn, fingerprints in fingerprints_by_fqdn.items():
        yield {
            "anomaly": "multiple_cas",
            "details": f"Unusually high number of certificates for FQDN {fqdn}.",
            "certificates": fingerprints,
        }
@anomaly_scanner(priority=2)
async def check_cert_lifespan():
    """Flag currently valid certificates whose total lifespan is under 3 weeks.

    Yields:
        One dict per matching certificate. Keys: ``anomaly`` (always
        ``"short_lifespan"``), ``details`` (human-readable text) and
        ``certificates`` (single-element list with the fingerprint).
    """
    # Take one timestamp so both validity bounds compare against the same instant.
    now = datetime.datetime.utcnow()
    query = (
        select(Certificate.fingerprint)
        .where(
            and_(
                Certificate.not_valid_after - Certificate.not_valid_before < datetime.timedelta(weeks=3),
                Certificate.not_valid_after > now,
                Certificate.not_valid_before < now,
            )
        )
        .distinct()
    )
    # The query selects a single column, so .scalars() yields the fingerprint
    # values themselves (not row objects with a .fingerprint attribute).
    async with session_maker() as session:
        fingerprints = (await session.execute(query)).scalars().all()

    for fingerprint in fingerprints:
        yield {
            "anomaly": "short_lifespan",
            "details": f"Unusually short lifespan for certificate {fingerprint}.",
            "certificates": [fingerprint],
        }
@anomaly_scanner(priority=4)
async def check_cert_sans():
    """Flag currently valid certificates with more than 80 SAN entries.

    Each certificate's DER blob is parsed with ``cryptography.x509`` and the
    entries of its Subject Alternative Name extension are counted.
    Certificates without a SAN extension are skipped.

    Yields:
        One dict per matching certificate. Keys: ``anomaly`` (always
        ``"many_sans"``), ``details`` (human-readable text including the SAN
        count) and ``certificates`` (single-element list with the fingerprint).
    """
    now = datetime.datetime.utcnow()
    query = (
        select(Certificate.fingerprint, Certificate.raw_der_certificate)
        .where(
            and_(
                Certificate.not_valid_after > now,
                Certificate.not_valid_before < now,
            )
        )
        .distinct()
    )
    # Two columns are selected, so iterate full rows; .scalars() would keep
    # only the fingerprint and lose the DER data.
    async with session_maker() as session:
        rows = (await session.execute(query)).all()

    for fingerprint, raw_der_certificate in rows:
        cert = x509.load_der_x509_certificate(raw_der_certificate)
        try:
            san_extension = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
        except x509.ExtensionNotFound:
            # No SAN extension means zero SANs; nothing to flag.
            continue
        if (num_sans := len(san_extension.value)) > 80:
            yield {
                "anomaly": "many_sans",
                "details": f"Unusually high number of SANs ({num_sans}) for certificate {fingerprint}.",
                "certificates": [fingerprint],
            }
# response_code -> AnomalyTypes row. A plain dict is used instead of
# functools.cache because caching an ``async def`` stores the coroutine object,
# which can be awaited only once; the second cached call would raise
# RuntimeError.
_ANOMALY_TYPE_CACHE: dict = {}


async def get_anomaly_type(response_code: int):
    """Return the AnomalyTypes row whose ``response_code`` matches.

    Successful lookups are memoised in-process, so each code hits the database
    at most once. A miss is not cached, so a row added later is still found.

    Args:
        response_code: Code to match against ``AnomalyTypes.response_code``.

    Returns:
        The matching ``AnomalyTypes`` instance, or ``None`` if no row matches.
    """
    cached = _ANOMALY_TYPE_CACHE.get(response_code)
    if cached is not None:
        return cached

    query = select(AnomalyTypes).where(AnomalyTypes.response_code == response_code)
    async with session_maker() as session:
        anomaly_type = (await session.execute(query)).scalars().first()

    if anomaly_type is not None:
        _ANOMALY_TYPE_CACHE[response_code] = anomaly_type
    return anomaly_type
async def create_anomaly_flag(anomaly, session: AsyncSession = Depends(get_session)):
    """Persist one detected anomaly and link it to the affected certificates.

    Args:
        anomaly: Dict as yielded by the scanners, with keys ``anomaly`` (a key
            of ``ANOMALY_HTTP_CODE``), ``details`` (text) and ``certificates``
            (iterable of certificate fingerprints).
        session: Unused. Kept so the signature stays compatible; the function
            always opens its own session via ``session_maker()``, because the
            ``Depends`` default is only resolved inside a FastAPI request.

    Raises:
        KeyError: If ``anomaly["anomaly"]`` is not in ``ANOMALY_HTTP_CODE``.
    """
    anomaly_type = await get_anomaly_type(ANOMALY_HTTP_CODE[anomaly["anomaly"]])
    if anomaly_type is None:
        # Without a type row there is nothing valid to reference; skip this
        # anomaly rather than abort the whole scan.
        logger.error(f"No anomaly type found for {anomaly['anomaly']}; anomaly flag not stored.")
        return

    # `datetime` is the module here, so the class must be spelled out.
    now = datetime.datetime.utcnow()
    async with session_maker() as db:
        flag = AnomalyFlags(
            details=anomaly["details"],
            anomaly_type_id=anomaly_type.id,
            date_flagged=now,
        )
        db.add(flag)
        # Flush so the new flag row gets its primary key before the link rows
        # reference it.
        # NOTE(review): assumes AnomalyFlags has a database-generated `id`
        # primary key and that anomaly_flag_id refers to it -- confirm in models.
        await db.flush()
        db.add_all(
            [
                CertificateAnomalyFlagsLink(
                    certificate_fingerprint=fingerprint,
                    anomaly_flag_id=flag.id,
                    first_linked=now,
                )
                for fingerprint in anomaly["certificates"]
            ]
        )
        await db.commit()
2023-06-16 17:02:57 -07:00
@schedule.task("every 30 seconds")
async def find_anomalies():
    """Run every registered anomaly scanner, lowest priority value first.

    Each anomaly a scanner yields is logged and then stored through
    ``create_anomaly_flag``.
    """
    # NOTE(review): the purpose of this fixed 3 second delay is not evident
    # from this file -- confirm before changing it.
    await asyncio.sleep(3)
    ordered_scanners = sorted(ANOMALY_SCANNERS, key=ANOMALY_SCANNERS.get)
    for scanner in ordered_scanners:
        scanner_name = scanner.__name__
        logger.info(f"Running {scanner_name}")
        async for anomaly in scanner():
            logger.info(f"{scanner_name} found {anomaly['anomaly']}")
            await create_anomaly_flag(anomaly)