CrowdTLS-server/crowdtls/apis/v1.py
2023-06-18 08:48:03 -07:00

174 lines
6.8 KiB
Python

from typing import Dict
from typing import List
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlmodel import and_
from sqlmodel import select
from crowdtls.db import get_session
from crowdtls.helpers import decode_der
from crowdtls.helpers import parse_hostname
from crowdtls.helpers import raise_HTTPException
from crowdtls.logs import logger
from crowdtls.models import AnomalyFlags
from crowdtls.models import Certificate
from crowdtls.models import CertificateAnomalyFlagsLink
from crowdtls.models import Domain
# Router holding the v1 API endpoints; the handlers in this module register
# themselves on it via ``@app.post(...)``.
app = APIRouter()
async def insert_certificate(hostname: str, certificate: Certificate, session: AsyncSession = Depends(get_session)):
    """Attach ``certificate`` to the domain parsed from ``hostname`` and commit.

    If a ``Domain`` with the parsed FQDN is already stored, the certificate is
    appended to that row; otherwise the freshly parsed ``Domain`` is persisted
    together with the certificate. Nothing is written when ``parse_hostname``
    returns a falsy value.

    This coroutine is not registered as a route, so the ``Depends`` default is
    only resolved when FastAPI injects it; direct callers must pass ``session``.

    The insert is best-effort: a failed commit is logged and the session is
    rolled back (so it stays usable), but the exception is not propagated.
    """
    domain = parse_hostname(hostname)
    if not domain:
        return

    # Eager-load the collection we append to below: with AsyncSession an
    # implicit lazy load of ``Domain.certificates`` would need IO outside await.
    existing_domain = await session.get(
        Domain, domain.fqdn, options=[selectinload(Domain.certificates)]
    )
    if existing_domain:
        logger.info(f"Found existing domain in database: {existing_domain.fqdn}")
        target = existing_domain
    else:
        logger.info(f"Did not find existing domain in database. Creating new domain: {domain.fqdn}")
        target = domain
    target.certificates.append(certificate)
    session.add(target)

    try:
        await session.commit()
    except Exception:
        logger.error(f"Failed to insert certificate into database for domain {domain.fqdn}: {certificate.fingerprint}")
        # Without a rollback the session stays in a failed-transaction state.
        await session.rollback()
async def get_domain_by_fqdn(fqdn: str, session: AsyncSession = Depends(get_session)):
    """Fetch the ``Domain`` whose primary key is ``fqdn``.

    Returns the stored ``Domain`` instance, or ``None`` when no such row exists.
    This helper is not a route, so callers outside FastAPI's dependency
    injection must supply ``session`` explicitly.
    """
    found = await session.get(Domain, fqdn)
    return found
async def _execute_logged(stmt, session: AsyncSession, request: Request, hostname: str):
    """Execute ``stmt``; on failure log it with the client IP and abort the request."""
    try:
        return await session.execute(stmt)
    except Exception:
        logger.error(
            f"Failed to execute stmt: {stmt} (hostname {hostname}) and IP address: {request.client.host}"
        )
        raise_HTTPException()
        # NOTE(review): raise_HTTPException is assumed to raise; if it ever
        # returns, propagate the original error instead of continuing.
        raise


async def _link_certificates_to_domain(parsed_hostname: Domain, certificates, session: AsyncSession) -> int:
    """Associate ``parsed_hostname`` with every certificate not yet linked to it.

    ``parsed_hostname`` may be falsy (unparseable hostname), in which case
    nothing happens. Commits when at least one link was added and returns the
    number of newly linked certificates.
    """
    if not parsed_hostname:
        return 0
    fqdn = parsed_hostname.fqdn
    unlinked = [
        certificate
        for certificate in certificates
        if all(domain.fqdn != fqdn for domain in certificate.domains)
    ]
    if not unlinked:
        return 0

    # Reuse the stored Domain when present: persisting the transient parsed
    # object for an fqdn that already exists would INSERT a duplicate key.
    domain = await get_domain_by_fqdn(fqdn, session) or parsed_hostname
    for certificate in unlinked:
        logger.info(f"Adding {fqdn} to {certificate.fingerprint} in the database.")
        # Append on the certificate side: ``certificate.domains`` was
        # selectin-loaded by the caller, so no implicit lazy load is needed.
        certificate.domains.append(domain)
        session.add(certificate)
    await session.commit()
    logger.info(f"Added mappings between {fqdn} and {len(unlinked)} certificates in the database.")
    return len(unlinked)


async def _collect_anomalies(fps: List[str], session: AsyncSession, request: Request, hostname: str) -> Dict[str, object]:
    """Return ``{fingerprint: anomaly.details}`` for requested fingerprints that carry anomalies.

    Only fingerprints listed in ``fps`` appear as keys; when a certificate has
    several anomalies the first one encountered wins.
    """
    # Join through the AnomalyFlags<->Certificate relationship so each anomaly
    # is matched only to the certificates it is actually linked to.
    stmt = (
        select(AnomalyFlags)
        .join(AnomalyFlags.certificates)
        .where(Certificate.fingerprint.in_(fps))
        .options(selectinload(AnomalyFlags.certificates))
    )
    results = await _execute_logged(stmt, session, request, hostname)
    # The join yields one row per matching certificate; de-duplicate anomalies.
    anomalies = results.scalars().unique().all()
    logger.info(
        f"Found {len(anomalies)} anomalies (of {len(fps)} requested) in the database for client {request.client.host}"
    )
    requested = set(fps)
    details_by_fp: Dict[str, object] = {}
    for anomaly in anomalies:
        for certificate in anomaly.certificates:
            if certificate.fingerprint in requested:
                details_by_fp.setdefault(certificate.fingerprint, anomaly.details)
    return details_by_fp


@app.post("/check")
async def check_fingerprints(
    fingerprints: Dict[str, List[str]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Check client-observed certificate fingerprints against the database.

    ``fingerprints`` maps each hostname to the fingerprints the client saw for
    it. For every hostname this:

    1. loads the already-stored certificates among those fingerprints,
    2. links the parsed hostname's domain to any of them not yet associated,
    3. sets ``response[hostname] = True`` when at least one fingerprint is not
       stored (logged as "Requesting new certs"),
    4. records anomaly details under ``response["anomalies"][fingerprint]``,
       which client browser extensions use to alert the user. Anomalies
       accumulate across hostnames; the first one seen per fingerprint wins.

    Returns the response dict.
    """
    response_dict = {}
    for hostname, fps in fingerprints.items():
        client_host = request.client.host
        parsed_hostname = parse_hostname(hostname)
        logger.info(f"{client_host} requested {hostname}: {len(fps)}")

        # All stored certificates with the given fingerprints, plus their domains.
        stmt = select(Certificate).options(selectinload(Certificate.domains)).where(Certificate.fingerprint.in_(fps))
        results = await _execute_logged(stmt, session, request, hostname)
        certificates = results.scalars().all()
        logger.info(
            f"Found {len(certificates)} certificates (of {len(fps)} requested) in the database for client {client_host}"
        )
        # Snapshot before committing: committed instances may be expired and
        # reloading them implicitly is not possible under AsyncSession.
        known_fps = {certificate.fingerprint for certificate in certificates}

        await _link_certificates_to_domain(parsed_hostname, certificates, session)

        if any(fp not in known_fps for fp in fps):
            logger.info(f"Requesting new certs for {hostname}.")
            response_dict[hostname] = True

        anomalies = await _collect_anomalies(fps, session, request, hostname)
        if anomalies:
            merged = response_dict.setdefault("anomalies", {})
            for fp, details in anomalies.items():
                merged.setdefault(fp, details)
    return response_dict
@app.post("/new")
async def new_fingerprints(
    fingerprints: Dict[str, Dict[str, List[int]]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Store certificates submitted by a client that the database does not yet have.

    ``fingerprints`` maps each hostname to ``{fingerprint: raw DER as a list of
    ints}``. For every hostname, certificates whose fingerprint is not stored
    yet are decoded with ``decode_der``. They are linked to the hostname's
    domain when ``parse_hostname`` succeeds, then committed. Returns
    ``{"status": "OK"}`` once every hostname is processed; database failures
    are logged and reported through ``raise_HTTPException``.
    """
    for hostname, certs in fingerprints.items():
        parsed_hostname = parse_hostname(hostname)
        stmt = select(Certificate).where(Certificate.fingerprint.in_(list(certs)))
        try:
            result = await session.execute(stmt)
        except Exception:
            logger.error(
                f"Failed to execute stmt: {stmt} (hostname {hostname}) and IP address: {request.client.host}"
            )
            raise_HTTPException()
            # NOTE(review): raise_HTTPException is assumed to raise; if it
            # returns, re-raise rather than continue without a result.
            raise

        existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
        logger.info(f"Received {len(certs)} fingerprints to add from client {request.client.host} for host {hostname}")
        logger.info(f"Found {len(existing_fingerprints)} existing fingerprints in the database.")
        logger.info(f"{existing_fingerprints=}")

        new_certs = {fp: raw_der for fp, raw_der in certs.items() if fp not in existing_fingerprints}

        domain = None
        if new_certs and parsed_hostname:
            # Reuse the stored Domain when present: adding the transient parsed
            # object for an existing fqdn would INSERT a duplicate primary key
            # and make the whole commit fail.
            domain = await session.get(Domain, parsed_hostname.fqdn) or parsed_hostname

        certificates_to_add = []
        for fp, raw_der in new_certs.items():
            logger.info(f"Adding {fp}")
            # NOTE(review): malformed DER from the client propagates as an
            # unhandled error here — confirm what decode_der raises.
            certificate = Certificate.from_orm(decode_der(fp, raw_der))
            if domain is not None:
                certificate.domains.append(domain)
            certificates_to_add.append(certificate)

        try:
            session.add_all(certificates_to_add)
            await session.commit()
        except Exception:
            logger.error(
                f"Failed to add certificates to db: {list(new_certs)} after stmt: {stmt} (hostname {hostname}) and IP address: {request.client.host}"
            )
            # Leave the session usable instead of stuck in a failed transaction.
            await session.rollback()
            raise_HTTPException()
            raise
    return {"status": "OK"}