Working MVP

This commit is contained in:
Darryl Nixon 2023-06-18 08:48:03 -07:00
parent 60e249bd31
commit 475ed05e16
7 changed files with 147 additions and 180 deletions

View file

@ -6,6 +6,7 @@ from fastapi import Depends
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlmodel import and_
from sqlmodel import select
from crowdtls.db import get_session
@ -13,9 +14,10 @@ from crowdtls.helpers import decode_der
from crowdtls.helpers import parse_hostname
from crowdtls.helpers import raise_HTTPException
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
from crowdtls.models import Certificate
from crowdtls.models import CertificateAnomalyFlagsLink
from crowdtls.models import Domain
from crowdtls.models import DomainCertificateLink
app = APIRouter()
@ -38,48 +40,8 @@ async def insert_certificate(hostname: str, certificate: Certificate, session: A
logger.error(f"Failed to insert certificate into database for domain {domain.fqdn}: {certificate.fingerprint}")
# @app.post("/check")
# async def check_fingerprints(
# fingerprints: Dict[str, Union[str, List[str]]],
# request: Request = None,
# session: AsyncSession = Depends(get_session),
# ):
# logger.info("Received request to check fingerprints from client {request.client.host}")
# hostname = parse_hostname(fingerprints.get("host"))
# fps = fingerprints.get("fps")
# logger.info(f"Received {len(fps)} fingerprints to check from client {request.client.host}")
# subquery = select(DomainCertificateLink.fqdn).join(Certificate).where(Certificate.fingerprint.in_(fps)).subquery()
# stmt = (
# select(Certificate)
# .join(DomainCertificateLink)
# .join(subquery, DomainCertificateLink.fqdn == subquery.c.fqdn)
# .options(selectinload(Certificate.domains))
# .where(DomainCertificateLink.fqdn == hostname.fqdn if hostname else True)
# )
# try:
# result = await session.execute(stmt)
# except Exception:
# logger.error(f"Failed to execute stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}")
# raise_HTTPException()
# certificates = result.scalars().all()
# logger.info(
# f"Found {len(certificates)} certificates (of {len(fps)} requested) in the database for client {request.client.host}"
# )
# if len(certificates) == len(fps):
# return {"send": False}
# for certificate in certificates:
# if hostname and hostname.fqdn not in [domain.fqdn for domain in certificate.domains]:
# certificate.domains.append(hostname)
# session.add(certificate)
# await session.commit()
# logger.info(f"Added mappings between {hostname.fqdn} up to {len(fps)} certificates in the database.")
# return {"send": True}
async def get_domain_by_fqdn(fqdn: str, session: AsyncSession = Depends(get_session)):
return await session.get(Domain, fqdn)
@app.post("/check")
@ -88,49 +50,82 @@ async def check_fingerprints(
request: Request = None,
session: AsyncSession = Depends(get_session),
):
logger.info(f"Received request to check fingerprints from client {request.client.host}")
response_dict = {}
for hostname, fps in fingerprints.items():
parsed_hostname = parse_hostname(hostname)
logger.info(f"Received {len(fps)} fingerprints to check from client {request.client.host} for host {hostname}")
logger.info(f"{request.client.host} requested {hostname}: {len(fps)}")
subquery = (
select(DomainCertificateLink.fqdn).join(Certificate).where(Certificate.fingerprint.in_(fps)).subquery()
)
stmt = (
select(Certificate)
.join(DomainCertificateLink)
.join(subquery, DomainCertificateLink.fqdn == subquery.c.fqdn)
.options(selectinload(Certificate.domains))
.where(DomainCertificateLink.fqdn == parsed_hostname.fqdn if parsed_hostname else True)
)
# Query for all certificates and associated domains (from links) with the given fingerprints
stmt = select(Certificate).options(selectinload(Certificate.domains)).where(Certificate.fingerprint.in_(fps))
try:
result = await session.execute(stmt)
results = await session.execute(stmt)
except Exception:
logger.error(
f"Failed to execute stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}"
)
raise_HTTPException()
certificates = result.scalars().all()
certificates = results.scalars().all()
logger.info(
f"Found {len(certificates)} certificates (of {len(fps)} requested) in the database for client {request.client.host}"
)
if len(certificates) != len(fps):
for certificate in certificates:
if parsed_hostname and parsed_hostname.fqdn not in [domain.fqdn for domain in certificate.domains]:
count = 0
for certificate in certificates:
if parsed_hostname and parsed_hostname.fqdn not in [domain.fqdn for domain in certificate.domains]:
count += 1
logger.info(f"Adding {parsed_hostname.fqdn} to {certificate.fingerprint} in the database.")
if existing_domain := await get_domain_by_fqdn(hostname):
existing_domain.certificates.append(certificate)
session.add(existing_domain)
else:
certificate.domains.append(parsed_hostname)
session.add(certificate)
if count:
await session.commit()
logger.info(f"Added mappings between {parsed_hostname.fqdn} up to {len(fps)} certificates in the database.")
logger.info(f"Added mappings between {parsed_hostname.fqdn} and {count} certificates in the database.")
if any(fp for fp in fps if fp not in [cert.fingerprint for cert in certificates]):
logger.info(f"Requesting new certs for {hostname}.")
response_dict[hostname] = True
# Query for relevant anomalies and the associated certificate fingerprints which are used as keys
# in the response dict and are used by the client browser extensions to alert the user
stmt = (
select(AnomalyFlags, Certificate.fingerprint)
.options(selectinload(AnomalyFlags.certificates))
.where(
and_(
Certificate.fingerprint.in_(fps),
Certificate.fingerprint == CertificateAnomalyFlagsLink.certificate_fingerprint,
)
)
)
try:
results = await session.execute(stmt)
except Exception:
logger.error(
f"Failed to execute stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}"
)
raise_HTTPException()
anomalies = results.scalars().all()
logger.info(
f"Found {len(anomalies)} anomalies (of {len(fps)} requested) in the database for client {request.client.host}"
)
if anomalies:
response_dict["anomalies"] = {}
for anomaly in anomalies:
for certificate in anomaly.certificates:
if certificate.fingerprint not in response_dict["anomalies"]:
response_dict["anomalies"][certificate.fingerprint] = anomaly.details
return response_dict
@ -153,12 +148,16 @@ async def new_fingerprints(
)
raise_HTTPException()
logger.info(f"Received {len(certs)} fingerprints to add from client {request.client.host} for host {hostname}")
existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
logger.info(f"Received {len(certs)} fingerprints to add from client {request.client.host} for host {hostname}")
logger.info(f"Found {len(existing_fingerprints)} existing fingerprints in the database.")
logger.info(f"{existing_fingerprints=}")
certificates_to_add = []
for fp, rawDER in certs.items():
if fp not in existing_fingerprints:
logger.info(f"Adding {fp}")
decoded = decode_der(fp, rawDER)
certificate = Certificate.from_orm(decoded)
certificate.domains.append(parsed_hostname)
@ -173,41 +172,3 @@ async def new_fingerprints(
)
raise_HTTPException()
return {"status": "OK"}
# @app.post("/new")
# async def new_fingerprints(
# fingerprints: Dict[str, Union[str, Dict[str, List[int]]]],
# request: Request = None,
# session: AsyncSession = Depends(get_session),
# ):
# try:
# hostname = parse_hostname(fingerprints.get("host"))
# certs = fingerprints.get("certs")
# fps = certs.keys()
# stmt = select(Certificate).where(Certificate.fingerprint.in_(fps))
# result = await session.execute(stmt)
# except Exception:
# logger.error(f"Failed to execute stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}")
# raise_HTTPException()
# logger.info(f"Received {len(fingerprints)} fingerprints to add from client {request.client.host}")
# existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
# certificates_to_add = []
# for fp, rawDER in certs.items():
# if fp not in existing_fingerprints:
# decoded = decode_der(fp, rawDER)
# certificate = Certificate.from_orm(decoded)
# certificate.domains.append(hostname)
# certificates_to_add.append(certificate)
# try:
# session.add_all(certificates_to_add)
# await session.commit()
# except Exception:
# logger.error(
# f"Failed to add certificates to db: {certificates_to_add} after stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}"
# )
# raise_HTTPException()
# return {"status": "OK"}