CrowdTLS-server/crowdtls/db.py
2023-06-18 08:48:03 -07:00

41 lines
1.4 KiB
Python

import os
from typing import AsyncIterator, Mapping
from urllib.parse import quote

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel

from crowdtls.logs import logger
from crowdtls.models import AnomalyTypes
def _build_database_url(environ: Mapping[str, str]) -> str:
    """Assemble the asyncpg PostgreSQL URL from ``POSTGRES_*`` variables.

    Required keys: ``POSTGRES_USER``, ``POSTGRES_PASSWORD``,
    ``POSTGRES_CONTAINER`` (used as the host) and ``POSTGRES_DB``.
    Optional key: ``POSTGRES_PORT`` (defaults to ``"5432"``).

    The user name and password are percent-encoded so characters such as
    ``@``, ``:``, ``/`` or ``%`` cannot corrupt the URL structure.

    Args:
        environ: Mapping to read the settings from (normally ``os.environ``).

    Returns:
        A ``postgresql+asyncpg://`` connection URL string.

    Raises:
        RuntimeError: If any required variable is unset.
    """
    required = (
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_CONTAINER",
        "POSTGRES_DB",
    )
    missing = [name for name in required if environ.get(name) is None]
    if missing:
        # Fail fast with an actionable message. Otherwise a missing value
        # becomes the literal text "None" in the URL or an opaque TypeError.
        raise RuntimeError(
            "Missing required database environment variable(s): "
            + ", ".join(missing)
        )

    # safe="" also escapes "/" and ":"; plain percent-encoding round-trips
    # through the URL parser's unquoting.
    user = quote(environ["POSTGRES_USER"], safe="")
    password = quote(environ["POSTGRES_PASSWORD"], safe="")
    host = environ["POSTGRES_CONTAINER"]
    port = environ.get("POSTGRES_PORT", "5432")
    database = environ["POSTGRES_DB"]
    return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"


DATABASE_URL = _build_database_url(os.environ)
# Process-wide async engine for the PostgreSQL database.
# NOTE(review): echo=True logs every SQL statement; consider disabling it in
# production deployments.
engine = create_async_engine(DATABASE_URL, echo=True, future=True)

# Shared AsyncSession factory. expire_on_commit=False matches get_session().
# With AsyncSession, expiring attributes on commit means the next attribute
# access would need an implicit lazy load, which SQLAlchemy's asyncio
# extension does not support.
session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
# Built once at import time instead of on every get_session() call.
# expire_on_commit=False keeps ORM attributes readable after commit without
# implicit lazy loads (SQLAlchemy's asyncio docs recommend this setting).
_request_session_factory = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)


async def get_session() -> AsyncIterator[AsyncSession]:
    """Yield a single AsyncSession bound to the module's engine.

    The session is opened with ``async with``, so it is closed once the
    consumer of this async generator finishes. This function does not commit
    or roll back; that is left to the caller.
    """
    async with _request_session_factory() as session:
        yield session
async def create_db_and_tables(*, drop_existing: bool = True) -> None:
    """Create all SQLModel tables and seed the AnomalyTypes lookup table.

    Every statement runs inside the single transaction opened by
    ``engine.begin()``.

    Args:
        drop_existing: When True (the default, and the previous behaviour),
            every table registered on ``SQLModel.metadata`` is dropped first,
            destroying all stored data. Pass False to keep existing tables and
            only create the ones that are missing.

    The AnomalyTypes rows are inserted only when that table is empty. After a
    drop this is always the case. On a non-destructive start it avoids
    inserting the same rows a second time.
    """
    anomalies = [
        {"response_code": 250, "description": "Too many active certificates >10"},
        {"response_code": 251, "description": "Lifespan too short <3 weeks"},
        {"response_code": 252, "description": "Too many SANs >80"},
    ]
    async with engine.begin() as conn:
        if drop_existing:
            # Destructive: removes all tables known to SQLModel.metadata.
            await conn.run_sync(SQLModel.metadata.drop_all)
        await conn.run_sync(SQLModel.metadata.create_all)

        anomaly_table = AnomalyTypes.__table__
        existing_rows = await conn.scalar(
            select(func.count()).select_from(anomaly_table)
        )
        if existing_rows:
            logger.info("AnomalyTypes table already populated; skipping seed")
            return

        logger.info("Populating AnomalyTypes table")
        await conn.execute(anomaly_table.insert(), anomalies)