CrowdTLS-server/crowdtls/api/v1/api.py
2023-06-07 14:35:48 -07:00

121 lines
4.6 KiB
Python

from typing import Dict
from typing import List
from typing import Union
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlmodel import select
from crowdtls.db import get_session
from crowdtls.helpers import decode_der
from crowdtls.helpers import parse_hostname
from crowdtls.helpers import raise_HTTPException
from crowdtls.logs import logger
from crowdtls.models import Certificate
from crowdtls.models import Domain
from crowdtls.models import DomainCertificateLink
# Router for this module's endpoints; /check and /new below register on it via @app.post(...).
app = APIRouter()
async def insert_certificate(hostname: str, certificate: Certificate, session: AsyncSession = Depends(get_session)):
    """Link ``certificate`` to the domain parsed from ``hostname`` and commit.

    If a ``Domain`` row with the parsed FQDN already exists it is reused;
    otherwise the domain object returned by ``parse_hostname`` is added as a
    new row. Nothing is written when ``parse_hostname`` returns a falsy value.

    Args:
        hostname: Raw hostname string, parsed with ``parse_hostname``.
        certificate: Certificate to associate with the domain.
        session: Async database session (a FastAPI dependency by default).

    A failed commit is rolled back and logged; the error is not re-raised.
    """
    domain = parse_hostname(hostname)
    if not domain:
        return

    # Capture these before committing: after a commit/rollback ORM attributes
    # may be expired, and reloading them implicitly is not possible under asyncio.
    fqdn = domain.fqdn
    fingerprint = certificate.fingerprint

    # Eager-load the collection we append to so the append does not trigger a
    # lazy load (implicit IO), which AsyncSession does not support.
    existing_domain = await session.get(Domain, fqdn, options=[selectinload(Domain.certificates)])
    if existing_domain:
        logger.info(f"Found existing domain in database: {existing_domain.fqdn}")
        existing_domain.certificates.append(certificate)
        session.add(existing_domain)
    else:
        logger.info(f"Did not find existing domain in database. Creating new domain: {fqdn}")
        domain.certificates.append(certificate)
        session.add(domain)

    try:
        await session.commit()
    except Exception:
        # Leave the session usable for later work instead of stuck in a failed transaction.
        await session.rollback()
        logger.exception(f"Failed to insert certificate into database for domain {fqdn}: {fingerprint}")
@app.post("/check")
async def check_fingerprints(
    fingerprints: Dict[str, Union[str, List[str]]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Tell the client whether it needs to upload any of its certificates.

    Expected body keys (from how they are read below):
        host: hostname the client connected to, passed to ``parse_hostname``.
        fps: fingerprints of the certificates the client saw for that host.

    Every requested certificate that is already stored gets linked to ``host``
    if it is not linked yet.

    Returns:
        ``{"send": False}`` if every requested fingerprint is already stored,
        otherwise ``{"send": True}`` so the client submits the unknown
        certificates (presumably via ``/new``).
    """
    # Starlette documents request.client as possibly None; don't dereference blindly.
    client_host = request.client.host if request is not None and request.client is not None else "unknown"
    logger.info(f"Received request to check fingerprints from client {client_host}")
    hostname = parse_hostname(fingerprints.get("host"))
    fps = fingerprints.get("fps") or []
    if isinstance(fps, str):
        # The body schema also admits a plain string; treat it as one fingerprint.
        fps = [fps]
    requested = set(fps)
    logger.info(f"Received {len(fps)} fingerprints to check from client {client_host}")
    if not requested:
        return {"send": False}

    # Match on the fingerprint itself. Going through the link table filtered by
    # host would count unrelated certificates of that host, miss certificates
    # stored under another domain, and repeat rows once per matching link.
    stmt = (
        select(Certificate)
        .where(Certificate.fingerprint.in_(list(requested)))
        .options(selectinload(Certificate.domains))
    )
    try:
        result = await session.execute(stmt)
    except Exception:
        logger.exception(f"Failed to execute stmt: {stmt} (req body {fingerprints}) and IP address: {client_host}")
        # NOTE(review): raise_HTTPException is expected to raise; otherwise `result` is unbound below.
        raise_HTTPException()
    certificates = result.scalars().all()
    known = {certificate.fingerprint for certificate in certificates}
    logger.info(
        f"Found {len(known)} certificates (of {len(requested)} requested) in the database for client {client_host}"
    )

    if hostname:
        # Read before commit: expired attributes cannot be reloaded implicitly under asyncio.
        fqdn = hostname.fqdn
        unlinked = [c for c in certificates if fqdn not in {d.fqdn for d in c.domains}]
        if unlinked:
            try:
                # Reuse the stored Domain row when present; adding a second object
                # with the same primary key would fail on flush.
                domain = await session.get(Domain, fqdn) or hostname
                for certificate in unlinked:
                    certificate.domains.append(domain)
                await session.commit()
            except Exception:
                await session.rollback()
                logger.exception(f"Failed to link {fqdn} to {len(unlinked)} certificates for client {client_host}")
                raise_HTTPException()
            logger.info(f"Added mappings between {fqdn} and {len(unlinked)} certificates in the database.")

    return {"send": not requested <= known}
@app.post("/new")
async def new_fingerprints(
    fingerprints: Dict[str, Union[str, Dict[str, List[int]]]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Store uploaded certificates that are not yet in the database.

    Expected body keys (from how they are read below):
        host: hostname the client connected to, passed to ``parse_hostname``.
        certs: mapping of fingerprint -> raw DER certificate as a list of ints,
            decoded with ``decode_der``.

    Certificates whose fingerprint is already stored are skipped. Each newly
    stored certificate is linked to ``host`` when it parses to a domain.

    Returns:
        ``{"status": "OK"}``. Malformed input and database failures are logged
        and reported through ``raise_HTTPException`` (expected to raise).
    """
    # Starlette documents request.client as possibly None; don't dereference blindly.
    client_host = request.client.host if request is not None and request.client is not None else "unknown"
    certs = fingerprints.get("certs")
    if not isinstance(certs, dict):
        # Validate before building the query so the error log never touches unbound names.
        logger.error(f"Missing or malformed 'certs' (req body {fingerprints}) from IP address: {client_host}")
        raise_HTTPException()
    logger.info(f"Received {len(certs)} fingerprints to add from client {client_host}")
    hostname = parse_hostname(fingerprints.get("host"))

    # Only the fingerprint column is needed to decide what already exists.
    stmt = select(Certificate.fingerprint).where(Certificate.fingerprint.in_(list(certs)))
    try:
        result = await session.execute(stmt)
        existing_fingerprints = set(result.scalars().all())
        # Reuse the stored Domain row when present; a new object with an
        # existing primary key would fail on flush.
        domain = (await session.get(Domain, hostname.fqdn) or hostname) if hostname else None
    except Exception:
        logger.exception(f"Failed to execute stmt: {stmt} (req body {fingerprints}) and IP address: {client_host}")
        raise_HTTPException()

    certificates_to_add = []
    added_fingerprints = []
    for fp, raw_der in certs.items():
        if fp in existing_fingerprints:
            continue
        certificate = Certificate.from_orm(decode_der(fp, raw_der))
        if domain is not None:
            certificate.domains.append(domain)
        certificates_to_add.append(certificate)
        added_fingerprints.append(fp)

    try:
        session.add_all(certificates_to_add)
        await session.commit()
    except Exception:
        # Roll back so the session is reusable; log fingerprints rather than ORM
        # objects, whose repr could require reloading expired state.
        await session.rollback()
        logger.exception(
            f"Failed to add certificates to db: {added_fingerprints} after stmt: {stmt} (req body {fingerprints}) and IP address: {client_host}"
        )
        raise_HTTPException()
    return {"status": "OK"}