Working MVP

This commit is contained in:
Darryl Nixon 2023-06-18 08:48:03 -07:00
parent 60e249bd31
commit 475ed05e16
7 changed files with 147 additions and 180 deletions

View file

@ -1,16 +1,13 @@
import asyncio
import datetime
from functools import cache
from typing import Callable
from asyncstdlib.functools import lru_cache
from cryptography import x509
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import and_
from sqlalchemy.orm import selectinload
from sqlmodel import func
from sqlmodel import select
from crowdtls.db import get_session
from crowdtls.db import session_maker
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
@ -34,30 +31,39 @@ def anomaly_scanner(priority: int):
return decorator
@lru_cache(maxsize=len(ANOMALY_HTTP_CODE))
async def get_anomaly_type(response_code: int):
async with session_maker() as session:
query = select(AnomalyTypes.id).where(AnomalyTypes.response_code == response_code).limit(1)
return (await session.execute(query)).scalars().one()
async def anomaly_exists(name: str, anomalies: list):
anomaly_id = await get_anomaly_type(ANOMALY_HTTP_CODE[name])
"""Check if a given anomaly type exists in a list of anomalies."""
return any((a.anomaly_type_id == anomaly_id for a in anomalies))
@anomaly_scanner(priority=1)
async def check_certs_for_fqdn():
"""Check certificates for a given FQDN for anomalies."""
now = datetime.datetime.utcnow()
# Query for all certificates and domains which have at least 10 DomainCertificateLink entries for unexpired certificates.
query = (
select(DomainCertificateLink.fqdn, Certificate)
.options(selectinload(Certificate.anomalies))
.join(Certificate, Certificate.fingerprint == DomainCertificateLink.fingerprint)
.where(
and_(
Certificate.not_valid_after > now,
Certificate.not_valid_before < now,
)
)
.group_by(DomainCertificateLink.fqdn, Certificate.fingerprint)
.having(func.count(DomainCertificateLink.fqdn) > 10)
)
logger.info(query)
async with session_maker() as session:
results = await session.execute(query)
results = (await session.execute(query)).scalars().all()
if results:
for result in results.scalars().all():
for result in results:
if any(await anomaly_exists("multiple_cas", certificate.anomalies) for certificate in result.certificates):
continue
yield {
"anomaly": "multiple_cas",
"details": f"Unusually high number of certificates for FQDN {result.fqdn}.",
@ -71,90 +77,83 @@ async def check_cert_lifespan():
# Query for fingerprints of all unexpired certificates which have a lifespan of less than 3 weeks.
query = (
select(Certificate.fingerprint)
.where(
and_(
Certificate.not_valid_after - Certificate.not_valid_before < datetime.timedelta(weeks=3),
Certificate.not_valid_after > datetime.datetime.utcnow(),
Certificate.not_valid_before < datetime.datetime.utcnow(),
)
)
.distinct()
select(Certificate)
.options(selectinload(Certificate.anomalies))
.where(Certificate.not_valid_after - Certificate.not_valid_before < datetime.timedelta(weeks=3))
)
async with session_maker() as session:
results = await session.execute(query)
certificates = (await session.execute(query)).scalars().all()
if certificates:
for certificate in certificates:
if await anomaly_exists("short_lifespan", certificate.anomalies):
continue
if results:
for result in results.scalars().all():
yield {
"anomaly": "short_lifespan",
"details": f"Unusually short lifespan for certificate {result.fingerprint}.",
"certificates": [result.fingerprint],
"details": f"Unusually short lifespan observed. Check certificate for {certificate.subject}",
"certificates": [certificate.fingerprint],
}
@anomaly_scanner(priority=4)
async def check_cert_sans():
"""Check the number of Subject Alternative Names (SANs) in a certificate."""
"""Check for a high number of Subject Alternative Names (SANs) in a certificate."""
# Query for raw certificate data.
query = (
select(Certificate.fingerprint, Certificate.raw_der_certificate)
.where(
and_(
Certificate.not_valid_after > datetime.datetime.utcnow(),
Certificate.not_valid_before < datetime.datetime.utcnow(),
)
)
.distinct()
)
query = select(Certificate).options(selectinload(Certificate.anomalies))
# Execute the query and iterate over the results. Load each certificate with
# x509 and yield for each result which has more than 8 SAN entries.
async with session_maker() as session:
results = await session.execute(query)
if results:
for result in results.scalars().all():
cert = x509.load_der_x509_certificate(result.raw_der_certificate)
if (num_sans := len(cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value)) > 80:
certificates = (await session.execute(query)).scalars().all()
if certificates:
for certificate in certificates:
if await anomaly_exists("many_sans", certificate.anomalies):
continue
try:
certificate_x509 = x509.load_der_x509_certificate(certificate.raw_der_certificate)
except Exception as e:
logger.error(f"Error loading certificate {certificate.fingerprint}: {e}")
continue
try:
num_sans = len(certificate_x509.extensions.get_extension_for_class(x509.SubjectAlternativeName).value)
except x509.extensions.ExtensionNotFound:
num_sans = 0
if num_sans > 80:
logger.info(f"Certificate has {num_sans} SANs.")
yield {
"anomaly": "many_sans",
"details": f"Unusually high number of SANs ({num_sans}) for certificate {result.fingerprint}.",
"certificates": [result.fingerprint],
"details": f"Unusually high number of SANs ({num_sans}) observed.",
"certificates": [certificate.fingerprint],
}
@cache
async def get_anomaly_type(response_code: int):
async with session_maker() as session:
anomaly_type = await session.fetch_val(
query=select([AnomalyTypes]).where(AnomalyTypes.response_code == response_code)
)
return anomaly_type
async def create_anomaly_flag(anomaly, session: AsyncSession = Depends(get_session)):
async def create_anomaly_flag(anomaly):
"""Create a new AnomalyFlag in the database."""
anomaly_type = await get_anomaly_type(ANOMALY_HTTP_CODE[anomaly["anomaly"]])
logger.info("Creating anomaly flag.")
anomaly_id = await get_anomaly_type(ANOMALY_HTTP_CODE[anomaly["anomaly"]])
logger.info(f"Creating anomaly flag for {anomaly_id}.")
async with session_maker() as session:
for fingerprint in anomaly["certificates"]:
await session.execute(
query=CertificateAnomalyFlagsLink.insert(),
values={
"certificate_fingerprint": fingerprint,
"anomaly_flag_id": anomaly_type.id,
"first_linked": datetime.utcnow(),
},
logger.info(f"Creating anomaly flag for certificate {fingerprint}.")
new_anomaly = AnomalyFlags(
details=anomaly["details"],
anomaly_type_id=anomaly_id,
)
await session.execute(
query=AnomalyFlags.insert(),
values={
"details": anomaly["details"],
"anomaly_type_id": anomaly_type.id,
"date_flagged": datetime.utcnow(),
},
session.add(new_anomaly)
await session.flush()
session.add(
CertificateAnomalyFlagsLink(
certificate_fingerprint=fingerprint,
anomaly_flag_id=new_anomaly.id,
)
)
await session.commit()
@ -163,8 +162,11 @@ async def create_anomaly_flag(anomaly, session: AsyncSession = Depends(get_sessi
async def find_anomalies():
"""Run all registered anomaly scanners in priority order"""
await asyncio.sleep(3)
for scanner, priority in sorted(ANOMALY_SCANNERS.items(), key=lambda x: x[1]):
logger.info(f"Running {scanner.__name__}")
async for anomaly in scanner():
logger.info(f"{scanner.__name__} found {anomaly['anomaly']}")
await create_anomaly_flag(anomaly)
for analytic, priority in sorted(ANOMALY_SCANNERS.items(), key=lambda x: x[1]):
logger.info(f"Running {analytic.__name__}")
try:
async for anomaly in analytic():
logger.warning(f"{analytic.__name__} found {anomaly['anomaly']}")
await create_anomaly_flag(anomaly)
except Exception as e:
logger.error(f"Error running {analytic.__name__}: {e}")