CrowdTLS-server/crowdtls/db.py

41 lines
1.4 KiB
Python
Raw Normal View History

2023-06-07 14:35:48 -07:00
import os
from typing import AsyncIterator
from urllib.parse import quote

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel

from crowdtls.logs import logger
from crowdtls.models import AnomalyTypes
2023-06-06 15:51:54 -07:00
2023-06-07 14:35:48 -07:00
def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is not set. Failing here gives a
            readable message instead of a connection URL containing the
            literal text ``None`` (or an opaque ``TypeError`` from string
            concatenation).
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    return value


# User and password are percent-encoded (``safe=""`` so even "/" is
# escaped) so that credentials containing URL metacharacters such as
# "@", ":", "/" or spaces cannot corrupt the URL; SQLAlchemy decodes them
# again when it parses the URL.
DATABASE_URL = (
    "postgresql+asyncpg://"
    f"{quote(_require_env('POSTGRES_USER'), safe='')}:"
    f"{quote(_require_env('POSTGRES_PASSWORD'), safe='')}@"
    f"{_require_env('POSTGRES_CONTAINER')}:5432/"
    f"{_require_env('POSTGRES_DB')}"
)

# echo=True makes the engine log every SQL statement it emits.
engine = create_async_engine(DATABASE_URL, echo=True, future=True)

# Shared AsyncSession factory. expire_on_commit=False because, with an
# AsyncSession, reading an attribute that commit() expired would require
# implicit lazy IO, which is not allowed under asyncio; get_session() uses
# the same setting.
session_maker = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)
2023-06-07 14:35:48 -07:00
# Built once at import time rather than on every get_session() call.
_get_session_factory = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)


async def get_session() -> AsyncIterator[AsyncSession]:
    """Yield a new ``AsyncSession`` and close it once the consumer is done.

    This is an async generator: the session is opened by ``async with`` and
    closed when the generator is resumed or closed after the single
    ``yield``. NOTE(review): presumably consumed as a generator-style
    request dependency; verify against callers.

    Sessions are created with ``expire_on_commit=False``, so ORM objects
    keep their loaded attribute values after ``commit()`` instead of being
    expired.
    """
    async with _get_session_factory() as session:
        yield session
2023-06-16 17:02:57 -07:00
async def create_db_and_tables(*, drop_existing: bool = True) -> None:
    """Create every ``SQLModel`` table and seed the AnomalyTypes table.

    Args:
        drop_existing: When True (the default), every table registered in
            ``SQLModel.metadata`` is dropped first, which DELETES ALL STORED
            DATA before the schema is recreated and re-seeded. Pass False to
            keep existing tables and rows. In that case only missing tables
            are created, and only anomaly types whose ``response_code`` is
            not already stored are inserted.

    All statements run inside one ``engine.begin()`` block, which commits
    when the block exits normally and rolls back if a statement raises.
    """
    anomalies = [
        dict(response_code=250, description="Too many active certificates >10"),
        dict(response_code=251, description="Lifespan too short <3 weeks"),
        dict(response_code=252, description="Too many SANs >80"),
    ]

    async with engine.begin() as conn:
        if drop_existing:
            await conn.run_sync(SQLModel.metadata.drop_all)
        await conn.run_sync(SQLModel.metadata.create_all)

        logger.info("Populating AnomalyTypes table")
        table = AnomalyTypes.__table__
        if not drop_existing:
            # Tables were kept, so skip codes that are already seeded to
            # avoid inserting duplicates on repeated startups.
            result = await conn.execute(table.select())
            existing = {row.response_code for row in result}
            anomalies = [
                row for row in anomalies if row["response_code"] not in existing
            ]
        # Guard against running the INSERT with an empty parameter list.
        if anomalies:
            await conn.execute(table.insert(), anomalies)