CrowdTLS-server/crowdtls/apis/v1.py

214 lines
8.3 KiB
Python
Raw Normal View History

2023-06-07 14:35:48 -07:00
from typing import Dict
from typing import List
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlmodel import select
from crowdtls.db import get_session
from crowdtls.helpers import decode_der
from crowdtls.helpers import parse_hostname
from crowdtls.helpers import raise_HTTPException
from crowdtls.logs import logger
from crowdtls.models import Certificate
from crowdtls.models import Domain
from crowdtls.models import DomainCertificateLink
# Router collecting the endpoints declared in this module via @app.post(...).
app = APIRouter()
async def insert_certificate(hostname: str, certificate: Certificate, session: AsyncSession = Depends(get_session)):
    """Link ``certificate`` to the domain parsed from ``hostname`` and commit.

    Does nothing when ``parse_hostname`` returns a falsy value. If a ``Domain``
    row with the parsed fqdn already exists it is reused; otherwise the parsed
    ``Domain`` object is added as a new row. The certificate is appended only
    when the domain is not already linked to a certificate with the same
    fingerprint. A failed commit is rolled back and logged, not re-raised
    (best-effort insert).

    Args:
        hostname: Hostname string as received from the client.
        certificate: Certificate to attach to the domain.
        session: Database session; ``Depends(get_session)`` only takes effect
            when FastAPI resolves this function as a dependency.
    """
    domain = parse_hostname(hostname)
    if not domain:
        return
    # Capture identifiers up front: after a failed commit + rollback, ORM
    # attributes may be expired, and refreshing them requires implicit IO
    # that an AsyncSession does not allow.
    fqdn = domain.fqdn
    fingerprint = certificate.fingerprint
    # Eager-load the collection we are about to append to; lazy loading is
    # not available on an AsyncSession.
    existing_domain = await session.get(Domain, fqdn, options=[selectinload(Domain.certificates)])
    if existing_domain is not None:
        logger.info(f"Found existing domain in database: {existing_domain.fqdn}")
        target = existing_domain
    else:
        logger.info(f"Did not find existing domain in database. Creating new domain: {fqdn}")
        target = domain
    # Avoid inserting a duplicate domain/certificate link.
    if fingerprint not in {linked.fingerprint for linked in target.certificates}:
        target.certificates.append(certificate)
    session.add(target)
    try:
        await session.commit()
    except Exception:
        # A failed commit leaves the session unusable until it is rolled back.
        await session.rollback()
        logger.exception(f"Failed to insert certificate into database for domain {fqdn}: {fingerprint}")
# @app.post("/check")
# async def check_fingerprints(
# fingerprints: Dict[str, Union[str, List[str]]],
# request: Request = None,
# session: AsyncSession = Depends(get_session),
# ):
# logger.info("Received request to check fingerprints from client {request.client.host}")
# hostname = parse_hostname(fingerprints.get("host"))
# fps = fingerprints.get("fps")
# logger.info(f"Received {len(fps)} fingerprints to check from client {request.client.host}")
# subquery = select(DomainCertificateLink.fqdn).join(Certificate).where(Certificate.fingerprint.in_(fps)).subquery()
# stmt = (
# select(Certificate)
# .join(DomainCertificateLink)
# .join(subquery, DomainCertificateLink.fqdn == subquery.c.fqdn)
# .options(selectinload(Certificate.domains))
# .where(DomainCertificateLink.fqdn == hostname.fqdn if hostname else True)
# )
# try:
# result = await session.execute(stmt)
# except Exception:
# logger.error(f"Failed to execute stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}")
# raise_HTTPException()
# certificates = result.scalars().all()
# logger.info(
# f"Found {len(certificates)} certificates (of {len(fps)} requested) in the database for client {request.client.host}"
# )
# if len(certificates) == len(fps):
# return {"send": False}
# for certificate in certificates:
# if hostname and hostname.fqdn not in [domain.fqdn for domain in certificate.domains]:
# certificate.domains.append(hostname)
# session.add(certificate)
# await session.commit()
# logger.info(f"Added mappings between {hostname.fqdn} up to {len(fps)} certificates in the database.")
# return {"send": True}
2023-06-07 14:35:48 -07:00
@app.post("/check")
async def check_fingerprints(
    fingerprints: Dict[str, List[str]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Report which hosts have certificate fingerprints the server does not know.

    The request body maps each hostname to the fingerprints of the
    certificates the client observed for that host. For every hostname:

    * requested fingerprints that are already stored are linked to the
      hostname's ``Domain`` (an existing row is reused) when that link is
      missing;
    * if at least one requested fingerprint is not stored, the hostname is
      mapped to ``True`` in the response. Hostnames whose fingerprints are
      all known are left out of the response.

    Returns:
        Dict mapping each hostname with unknown fingerprints to ``True``.
    """
    client_host = request.client.host if request is not None and request.client else "unknown"
    logger.info(f"Received request to check fingerprints from client {client_host}")
    response_dict = {}
    for hostname, fps in fingerprints.items():
        parsed_hostname = parse_hostname(hostname)
        requested = set(fps)
        logger.info(f"Received {len(fps)} fingerprints to check from client {client_host} for host {hostname}")
        # Select exactly the requested certificates so known/unknown is decided
        # per fingerprint. Domains are eager-loaded because lazy loading is not
        # available on an AsyncSession.
        stmt = (
            select(Certificate)
            .where(Certificate.fingerprint.in_(list(requested)))
            .options(selectinload(Certificate.domains))
        )
        try:
            result = await session.execute(stmt)
        except Exception:
            logger.exception(f"Failed to execute stmt: {stmt} (host {hostname}) and IP address: {client_host}")
            raise_HTTPException()
        certificates = result.scalars().all()
        known_fps = {certificate.fingerprint for certificate in certificates}
        logger.info(
            f"Found {len(known_fps)} certificates (of {len(requested)} requested) in the database for client {client_host}"
        )

        if parsed_hostname:
            # Read before committing: attributes may be expired afterwards.
            fqdn = parsed_hostname.fqdn
            unlinked = [
                certificate
                for certificate in certificates
                if fqdn not in {domain.fqdn for domain in certificate.domains}
            ]
            if unlinked:
                # Prefer the stored Domain row; parse_hostname's result is
                # assumed to be a new, unsaved Domain whose primary key would
                # collide with an existing row on flush.
                domain = await session.get(Domain, fqdn)
                if domain is None:
                    domain = parsed_hostname
                for certificate in unlinked:
                    certificate.domains.append(domain)
                try:
                    await session.commit()
                except Exception:
                    # The session is unusable until a failed commit is rolled back.
                    await session.rollback()
                    logger.exception(f"Failed to add mappings between {fqdn} and {len(unlinked)} certificates")
                    raise_HTTPException()
                logger.info(f"Added mappings between {fqdn} and {len(unlinked)} certificates in the database.")

        if requested - known_fps:
            response_dict[hostname] = True

    return response_dict
2023-06-07 14:35:48 -07:00
@app.post("/new")
async def new_fingerprints(
    fingerprints: Dict[str, Dict[str, List[int]]],
    request: Request = None,
    session: AsyncSession = Depends(get_session),
):
    """Store uploaded certificates that are not yet in the database.

    The request body maps each hostname to ``{fingerprint: raw_der}``, where
    ``raw_der`` is a list of ints handed to ``decode_der`` together with the
    fingerprint (presumably the DER bytes; confirm against ``decode_der``).
    For every hostname, fingerprints already stored are skipped. The rest are
    decoded, linked to the hostname's ``Domain`` (an existing row is reused)
    and committed. When ``parse_hostname`` returns a falsy value, the
    certificates are stored without a domain link.

    Returns:
        ``{"status": "OK"}`` once every hostname has been processed.
    """
    client_host = request.client.host if request is not None and request.client else "unknown"
    for hostname, certs in fingerprints.items():
        parsed_hostname = parse_hostname(hostname)
        stmt = select(Certificate).where(Certificate.fingerprint.in_(list(certs)))
        # Keep only the query inside the try so the handler never refers to
        # names that were not bound yet.
        try:
            result = await session.execute(stmt)
        except Exception:
            logger.exception(f"Failed to execute stmt: {stmt} (host {hostname}) and IP address: {client_host}")
            raise_HTTPException()
        logger.info(f"Received {len(certs)} fingerprints to add from client {client_host} for host {hostname}")
        existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
        new_certs = {fp: raw_der for fp, raw_der in certs.items() if fp not in existing_fingerprints}
        if not new_certs:
            continue

        domain = None
        if parsed_hostname:
            # Prefer the stored Domain row; parse_hostname's result is assumed
            # to be a new, unsaved Domain whose primary key would collide with
            # an existing row on flush.
            domain = await session.get(Domain, parsed_hostname.fqdn)
            if domain is None:
                domain = parsed_hostname

        certificates_to_add = []
        for fp, raw_der in new_certs.items():
            certificate = Certificate.from_orm(decode_der(fp, raw_der))
            if domain is not None:
                certificate.domains.append(domain)
            certificates_to_add.append(certificate)
        try:
            session.add_all(certificates_to_add)
            await session.commit()
        except Exception:
            # The session is unusable until a failed commit is rolled back.
            await session.rollback()
            logger.exception(
                f"Failed to add certificates to db: {list(new_certs)} (host {hostname}) and IP address: {client_host}"
            )
            raise_HTTPException()
    return {"status": "OK"}
2023-06-07 14:35:48 -07:00
# @app.post("/new")
# async def new_fingerprints(
# fingerprints: Dict[str, Union[str, Dict[str, List[int]]]],
# request: Request = None,
# session: AsyncSession = Depends(get_session),
# ):
# try:
# hostname = parse_hostname(fingerprints.get("host"))
# certs = fingerprints.get("certs")
# fps = certs.keys()
# stmt = select(Certificate).where(Certificate.fingerprint.in_(fps))
# result = await session.execute(stmt)
# except Exception:
# logger.error(f"Failed to execute stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}")
# raise_HTTPException()
# logger.info(f"Received {len(fingerprints)} fingerprints to add from client {request.client.host}")
# existing_fingerprints = {certificate.fingerprint for certificate in result.scalars().all()}
# certificates_to_add = []
# for fp, rawDER in certs.items():
# if fp not in existing_fingerprints:
# decoded = decode_der(fp, rawDER)
# certificate = Certificate.from_orm(decoded)
# certificate.domains.append(hostname)
# certificates_to_add.append(certificate)
# try:
# session.add_all(certificates_to_add)
# await session.commit()
# except Exception:
# logger.error(
# f"Failed to add certificates to db: {certificates_to_add} after stmt: {stmt} (req body {request.body}) and IP address: {request.client.host}"
# )
# raise_HTTPException()
# return {"status": "OK"}