mirror of
https://github.com/DarrylNixon/CrowdTLS-server.git
synced 2024-09-22 18:19:43 -07:00
174 lines
6.8 KiB
Python
174 lines
6.8 KiB
Python
from typing import Dict
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter
|
|
from fastapi import Depends
|
|
from fastapi import Request
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlmodel import and_
|
|
from sqlmodel import select
|
|
|
|
from crowdtls.db import get_session
|
|
from crowdtls.helpers import decode_der
|
|
from crowdtls.helpers import parse_hostname
|
|
from crowdtls.helpers import raise_HTTPException
|
|
from crowdtls.logs import logger
|
|
from crowdtls.models import AnomalyFlags
|
|
from crowdtls.models import Certificate
|
|
from crowdtls.models import CertificateAnomalyFlagsLink
|
|
from crowdtls.models import Domain
|
|
|
|
# Router holding this module's endpoints (/check and /new); presumably included
# into the FastAPI application elsewhere.
app: APIRouter = APIRouter()
|
|
|
|
|
|
async def insert_certificate(hostname: str, certificate: Certificate, session: AsyncSession = Depends(get_session)):
    """Link ``certificate`` to the domain parsed from ``hostname`` and commit.

    If a ``Domain`` with the parsed fqdn is already stored, the certificate is
    appended to it; otherwise the freshly parsed ``Domain`` is added along with
    the certificate. When ``hostname`` does not parse, nothing is staged, but the
    session is still committed (as before).

    Args:
        hostname: Raw hostname as reported by the client.
        certificate: Certificate model instance to associate with the domain.
        session: Database session. ``Depends`` is only resolved inside FastAPI
            routes, so direct callers must pass a real session explicitly.

    A failed commit is logged and the session is rolled back. The exception is
    not propagated, so the insert stays best-effort.
    """
    domain = parse_hostname(hostname)
    if domain:
        existing_domain = await session.get(Domain, domain.fqdn)
        if existing_domain:
            logger.info(f"Found existing domain in database: {existing_domain.fqdn}")
            existing_domain.certificates.append(certificate)
            session.add(existing_domain)
        else:
            logger.info(f"Did not find existing domain in database. Creating new domain: {domain.fqdn}")
            domain.certificates.append(certificate)
            session.add(domain)
    try:
        await session.commit()
    except Exception:
        # A failed commit leaves the session in an invalid state until it is rolled back.
        await session.rollback()
        # `domain` may be falsy here (unparseable hostname), so fall back to the raw hostname.
        fqdn = domain.fqdn if domain else hostname
        logger.error(f"Failed to insert certificate into database for domain {fqdn}: {certificate.fingerprint}")
|
|
|
|
|
|
async def get_domain_by_fqdn(fqdn: str, session: AsyncSession = Depends(get_session)):
    """Fetch the ``Domain`` whose primary key is ``fqdn``.

    Returns the matching ``Domain``, or ``None`` when no such row exists.
    ``session`` defaults to a FastAPI ``Depends`` marker, which is only resolved
    inside routes, so direct callers must pass a real session.
    """
    found = await session.get(Domain, fqdn)
    return found
|
|
|
|
|
|
async def _execute_or_abort(session: AsyncSession, stmt, request: Request, payload):
    """Execute ``stmt``; on failure, log the context and abort via ``raise_HTTPException``."""
    try:
        return await session.execute(stmt)
    except Exception:
        # `request.body` is a coroutine method (its repr is useless in a log), so the
        # already-parsed request payload is logged instead.
        logger.error(f"Failed to execute stmt: {stmt} (req body {payload}) and IP address: {request.client.host}")
        raise_HTTPException()


async def _link_certificates_to_hostname(parsed_hostname, certificates, session: AsyncSession) -> int:
    """Associate ``parsed_hostname`` with every certificate not yet linked to it.

    Reuses the stored ``Domain`` row for the fqdn when one exists; otherwise
    ``parsed_hostname`` itself is attached as a new domain. Commits only when at
    least one link was added. Returns the number of certificates linked.
    """
    # Read the fqdn up front: after the commit below the instance may be expired.
    fqdn = parsed_hostname.fqdn
    unlinked = [
        certificate for certificate in certificates if fqdn not in {domain.fqdn for domain in certificate.domains}
    ]
    if not unlinked:
        return 0

    # Resolve the domain once, by its parsed fqdn (the raw hostname may differ),
    # and always with the caller's session.
    domain = await get_domain_by_fqdn(fqdn, session=session)
    if domain is None:
        domain = parsed_hostname

    for certificate in unlinked:
        logger.info(f"Adding {fqdn} to {certificate.fingerprint} in the database.")
        # `certificate.domains` was eager-loaded by the caller's selectinload. Appending on
        # this side avoids touching `domain.certificates`, which may not be loaded, and
        # AsyncSession cannot lazy-load implicitly.
        certificate.domains.append(domain)
        session.add(certificate)

    await session.commit()
    logger.info(f"Added mappings between {fqdn} and {len(unlinked)} certificates in the database.")
    return len(unlinked)


@app.post("/check")
async def check_fingerprints(
    fingerprints: Dict[str, List[str]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Report, per hostname, which certificate fingerprints are unknown or flagged.

    Args:
        fingerprints: Mapping of hostname -> certificate fingerprints the client saw.
        request: The incoming request; its client address is used in log messages.
        session: Database session injected by FastAPI.

    Returns:
        A dict containing:

        * ``{hostname: True}`` for every hostname with at least one fingerprint
          missing from the database. This is logged as "Requesting new certs",
          presumably answered by the client via ``/new``.
        * An ``"anomalies"`` entry, when any requested certificate is flagged,
          mapping the fingerprint to the details of its first anomaly.

    Side effect: known certificates not yet linked to the requested hostname are
    linked to it and committed.
    """
    response_dict = {}

    for hostname, fps in fingerprints.items():
        parsed_hostname = parse_hostname(hostname)
        logger.info(f"{request.client.host} requested {hostname}: {len(fps)}")

        # Query for all certificates and associated domains (from links) with the given fingerprints
        stmt = select(Certificate).options(selectinload(Certificate.domains)).where(Certificate.fingerprint.in_(fps))
        results = await _execute_or_abort(session, stmt, request, fingerprints)
        certificates = results.scalars().all()
        logger.info(
            f"Found {len(certificates)} certificates (of {len(fps)} requested) in the database for client {request.client.host}"
        )
        # Capture fingerprints before the commit in the linking step. With expire_on_commit,
        # the commit expires these instances, and AsyncSession cannot refresh them implicitly.
        known_fingerprints = {certificate.fingerprint for certificate in certificates}

        if parsed_hostname:
            await _link_certificates_to_hostname(parsed_hostname, certificates, session)

        if any(fp not in known_fingerprints for fp in fps):
            logger.info(f"Requesting new certs for {hostname}.")
            response_dict[hostname] = True

        # Query anomalies linked to at least one requested certificate. Filtering through
        # the AnomalyFlags.certificates relationship joins the anomaly table to the link
        # table. Without that join condition the result is a cartesian product containing
        # every anomaly in the database.
        stmt = (
            select(AnomalyFlags)
            .options(selectinload(AnomalyFlags.certificates))
            .where(AnomalyFlags.certificates.any(Certificate.fingerprint.in_(fps)))
        )
        results = await _execute_or_abort(session, stmt, request, fingerprints)
        anomalies = results.scalars().all()
        logger.info(
            f"Found {len(anomalies)} anomalies (of {len(fps)} requested) in the database for client {request.client.host}"
        )

        # Fingerprints are the keys the client browser extensions use to alert the user.
        # An anomaly may also cover certificates this client did not ask about, so only
        # requested ones are reported. The first anomaly seen wins for each fingerprint,
        # and results accumulate across hostnames instead of being reset.
        requested = set(fps)
        for anomaly in anomalies:
            for certificate in anomaly.certificates:
                if certificate.fingerprint in requested:
                    response_dict.setdefault("anomalies", {}).setdefault(certificate.fingerprint, anomaly.details)

    return response_dict
|
|
|
|
|
|
@app.post("/new")
async def new_fingerprints(
    fingerprints: Dict[str, Dict[str, List[int]]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Store certificates uploaded by the client that are not yet in the database.

    Args:
        fingerprints: Mapping of hostname -> {fingerprint: raw DER certificate}.
            The DER arrives as a list of ints, presumably byte values; it is handed
            to ``decode_der``.
        request: The incoming request; its client address is used in log messages.
        session: Database session injected by FastAPI.

    Each new certificate is linked to the hostname's ``Domain``, reusing the
    stored row when it exists. Fingerprints already in the database are skipped.
    Each hostname's batch is committed separately.

    Returns:
        ``{"status": "OK"}`` once every hostname has been processed.
    """
    for hostname, certs in fingerprints.items():
        fps = list(certs.keys())
        # Build the statement before the try block so the error log can always reference it.
        stmt = select(Certificate).where(Certificate.fingerprint.in_(fps))
        try:
            parsed_hostname = parse_hostname(hostname)
            result = await session.execute(stmt)
        except Exception:
            # `request.body` is a coroutine method, so log the request's hostnames instead.
            logger.error(
                f"Failed to execute stmt: {stmt} (req hostnames {list(fingerprints)}) and IP address: {request.client.host}"
            )
            raise_HTTPException()

        existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
        logger.info(f"Received {len(certs)} fingerprints to add from client {request.client.host} for host {hostname}")
        logger.info(f"Found {len(existing_fingerprints)} existing fingerprints in the database.")

        logger.info(f"{existing_fingerprints=}")

        # Reuse the persisted Domain when it exists. Attaching a fresh parse_hostname()
        # object for an fqdn that is already stored would INSERT a duplicate primary key
        # on commit. An unparseable hostname gets no domain link, not a None entry.
        domain = None
        if parsed_hostname:
            domain = await session.get(Domain, parsed_hostname.fqdn)
            if domain is None:
                domain = parsed_hostname

        certificates_to_add = []
        for fp, rawDER in certs.items():
            if fp in existing_fingerprints:
                continue
            logger.info(f"Adding {fp}")
            decoded = decode_der(fp, rawDER)
            certificate = Certificate.from_orm(decoded)
            if domain is not None:
                certificate.domains.append(domain)
            certificates_to_add.append(certificate)

        try:
            session.add_all(certificates_to_add)
            await session.commit()
        except Exception:
            # Roll back so the session is usable again if the error is handled upstream.
            await session.rollback()
            logger.error(
                f"Failed to add certificates to db: {certificates_to_add} after stmt: {stmt} (req hostnames {list(fingerprints)}) and IP address: {request.client.host}"
            )
            raise_HTTPException()

    return {"status": "OK"}
|