ghostforge/ghostforge/ghosts.py

251 lines
7.8 KiB
Python

import uuid
from datetime import datetime
from operator import truth
from typing import Annotated
from typing import Optional
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from sqlalchemy import asc
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import and_
from sqlmodel import AutoString
from sqlmodel import Column
from sqlmodel import Field
from sqlmodel import SQLModel
from sqlmodel import TIMESTAMP
from ghostforge.db import get_session
from ghostforge.db import User
from ghostforge.helpers.stringies import age_in_human
from ghostforge.htmljson import HtmlJson
from ghostforge.users import get_current_user
# Module router: every ghost endpoint below registers itself on ``gf``.
# ``hj`` supplies the ``html_or_json`` decorator used by the read endpoints.
gf, hj = APIRouter(), HtmlJson()
class Ghost(SQLModel, table=True):
    """A ghost record owned by a user.

    Ownership is recorded in ``owner_id`` (a foreign key to ``user.id``);
    other users gain access through ``GhostPermissions`` rows.
    """

    # Explicit ``default=None`` is the documented SQLModel pattern. It keeps
    # these fields optional on input regardless of pydantic's implicit
    # Optional-default behaviour. ``id`` stays None until the row is inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    first_name: Optional[str] = Field(default=None)
    last_name: Optional[str] = Field(default=None)
    middle_name: Optional[str] = Field(default=None)
    # Stored in a timezone-aware TIMESTAMP column.
    birthdate: Optional[datetime] = Field(
        default=None,
        sa_column=Column(TIMESTAMP(timezone=True)),
    )
    owner_id: Optional[uuid.UUID] = Field(default=None, foreign_key="user.id")
class GhostCreate(Ghost, table=False):
    """Request body for creating a ghost: Ghost's fields, without a table."""
class GhostUpdate(Ghost, table=False):
    """Request body for updating a ghost: Ghost's fields, without a table."""
class GhostRead(Ghost, table=False):
    """Non-table schema with Ghost's fields (presumably for responses)."""
class GhostPermissions(SQLModel, table=True):
    """Grants a user access to a ghost they do not own.

    Any matching row lets the user view the ghost; a truthy ``can_edit``
    additionally lets them edit it.
    """

    # The primary key is assigned by the database, so like ``Ghost.id`` it is
    # Optional with an explicit None default.
    id: Optional[int] = Field(default=None, primary_key=True)
    user_id: Optional[uuid.UUID] = Field(default=None, foreign_key="user.id")
    ghost_id: Optional[int] = Field(default=None, foreign_key="ghost.id")
    can_edit: Optional[bool] = Field(default=None)
async def can_view_ghost(
    ghost_id: int,
    current_user: Annotated[User, Depends(get_current_user())],
    session: AsyncSession = Depends(get_session),
):
    """Return whether ``current_user`` may view ghost ``ghost_id``.

    The owner can always view it. Any other user needs a ``GhostPermissions``
    row linking them to the ghost.

    Raises:
        HTTPException: 404 if no ghost with ``ghost_id`` exists.
    """
    ghost = await session.get(Ghost, ghost_id)
    # session.get returns None for an unknown id; dereferencing it would
    # surface as an AttributeError (HTTP 500) instead of a clean 404.
    if ghost is None:
        raise HTTPException(status_code=404, detail="No ghost with that ID")
    if ghost.owner_id == current_user.id:
        return True
    grants = await session.execute(
        select(GhostPermissions).where(
            and_(GhostPermissions.ghost_id == ghost_id, GhostPermissions.user_id == current_user.id)
        )
    )
    return truth(grants.scalars().first())
async def can_edit_ghost(
    ghost_id: int,
    current_user: Annotated[User, Depends(get_current_user())],
    session: AsyncSession = Depends(get_session),
):
    """Return whether ``current_user`` may edit ghost ``ghost_id``.

    The owner can always edit it. Any other user needs a ``GhostPermissions``
    row for the ghost whose ``can_edit`` column is true.

    Raises:
        HTTPException: 404 if no ghost with ``ghost_id`` exists.
    """
    ghost = await session.get(Ghost, ghost_id)
    # session.get returns None for an unknown id; dereferencing it would
    # surface as an AttributeError (HTTP 500) instead of a clean 404.
    if ghost is None:
        raise HTTPException(status_code=404, detail="No ghost with that ID")
    if ghost.owner_id == current_user.id:
        return True
    grants = await session.execute(
        select(GhostPermissions).where(
            and_(
                GhostPermissions.ghost_id == ghost_id,
                GhostPermissions.user_id == current_user.id,
                GhostPermissions.can_edit,
            )
        )
    )
    return truth(grants.scalars().first())
@gf.get("/ghosts/{ghost_id}")
@hj.html_or_json("ghosts/ghost.html")
async def read_ghost(
    ghost_id: int,
    current_user: Annotated[User, Depends(get_current_user())],
    session: AsyncSession = Depends(get_session),
    request: Request = None,
    can_view: bool = Depends(can_view_ghost),
):
    """Fetch one ghost together with its owner's username and id.

    The same data is also published on ``request.state.ghostforge``, under
    the keys ``ghost``, ``owner``, ``user``, ``computed`` and ``crumbs``.
    Presumably this is for the ``ghosts/ghost.html`` rendering. Keys already
    present on the request state take precedence over the ones built here.

    Raises:
        HTTPException: 403 when ``can_view`` is false; 404 when the
            ghost/owner join yields no row.
    """
    if not can_view:
        raise HTTPException(status_code=403, detail="You're not authorized to see this ghost")

    ghost_with_owner = (
        select(Ghost, User.username.label("owner_username"), User.id.label("owner_guid"))
        .join(User, Ghost.owner_id == User.id)
        .where(Ghost.id == ghost_id)
    )
    row = (await session.execute(ghost_with_owner)).first()
    if not row:
        raise HTTPException(status_code=404, detail="No ghost with that ID")

    page_state = dict(
        ghost=row.Ghost,
        owner=row.owner_username,
        user=current_user,
        computed={"age": age_in_human(row.Ghost.birthdate)},
        crumbs=[("ghosts", "/ghosts"), (row.Ghost.id, False)],
    )
    already_set = getattr(request.state, "ghostforge", {})
    request.state.ghostforge = page_state | already_set
    return row
@gf.put("/ghosts/{ghost_id}")
async def update_ghost(
    ghost_id: int,
    ghost: GhostUpdate,
    current_user: Annotated[User, Depends(get_current_user())],
    session: AsyncSession = Depends(get_session),
    request: Request = None,
    can_edit: bool = Depends(can_edit_ghost),
):
    """Apply the fields sent in the request body to ghost ``ghost_id``.

    Only fields the client explicitly sent are written. The primary key and
    owner are never changed through this endpoint.

    Returns:
        The refreshed ghost row.

    Raises:
        HTTPException: 403 when ``can_edit`` is false; 404 when the ghost
            does not exist.
    """
    if not can_edit:
        raise HTTPException(status_code=403, detail="You're not authorized to edit this ghost")
    db_ghost = await session.get(Ghost, ghost_id)
    if db_ghost is None:
        raise HTTPException(status_code=404, detail="No ghost with that ID")
    # GhostUpdate inherits every Ghost column, including ``id`` and
    # ``owner_id``. Never copy those from the request body. Otherwise an editor
    # who is not the owner could reassign ownership to themselves, or the
    # primary key could be rewritten.
    changes = ghost.dict(exclude_unset=True, exclude={"id", "owner_id"})
    for field_name, value in changes.items():
        setattr(db_ghost, field_name, value)
    await session.commit()
    await session.refresh(db_ghost)
    return db_ghost
@gf.post("/ghosts/new")
async def create_ghost(
    ghost: GhostCreate,
    current_user: Annotated[User, Depends(get_current_user())],
    session: AsyncSession = Depends(get_session),
    request: Request = None,
):
    """Create a ghost from the request body, owned by ``current_user``.

    Returns:
        The newly inserted ghost, refreshed from the database.
    """
    new_ghost = Ghost.from_orm(ghost)
    # GhostCreate inherits Ghost's primary key. Drop any client-supplied id so
    # the database assigns one; a chosen id could collide with an existing row
    # and fail on commit.
    new_ghost.id = None
    # The owner is always the caller, whatever the body said.
    new_ghost.owner_id = current_user.id
    session.add(new_ghost)
    await session.commit()
    await session.refresh(new_ghost)
    return new_ghost
@gf.post("/ghosts")
async def get_ghosts(
    current_user: Annotated[User, Depends(get_current_user())],
    start: int = Form(None, alias="start"),
    length: int = Form(None, alias="length"),
    search: Optional[str] = Form(None, alias="search[value]"),
    order_col: int = Form(0, alias="order[0][column]"),
    order_dir: str = Form("asc", alias="order[0][dir]"),
    session: AsyncSession = Depends(get_session),
    request: Request = None,
):
    """Return one page of the ghosts ``current_user`` may see.

    Form fields use DataTables-style names: ``start``, ``length``,
    ``search[value]``, ``order[0][column]`` and ``order[0][dir]``.

    Returns:
        A dict with three keys:
        - ``recordsTotal``: the number of ghosts the user may see.
        - ``recordsFiltered``: how many of those also match ``search``.
        - ``data``: rows of ``(Ghost, owner_username, owner_guid)``.

    Raises:
        HTTPException: 400 if ``order_col`` is not a valid column index.
    """
    # Map the client's column index onto a Ghost column. The index is
    # untrusted, so reject out-of-range values instead of failing with a 500.
    # Negatives are rejected too, since Python indexing would silently wrap.
    ghost_columns = [c.key for c in inspect(Ghost).c]
    if not 0 <= order_col < len(ghost_columns):
        raise HTTPException(status_code=400, detail="Invalid order column")
    order_attr = getattr(Ghost, ghost_columns[order_col])

    # Ghosts the user owns, or has any GhostPermissions row for.
    permission_filter = or_(
        Ghost.owner_id == current_user.id,
        exists(
            select(GhostPermissions).where(
                and_(GhostPermissions.ghost_id == Ghost.id, GhostPermissions.user_id == current_user.id)
            )
        ),
    )

    # Case-insensitive substring match against every string-typed column.
    search_filter = None
    if search:
        text_attrs = [
            attr
            for attr in (getattr(Ghost, col) for col in ghost_columns)
            if isinstance(attr.type, (AutoString, Text))
        ]
        if text_attrs:
            search_filter = or_(*(attr.ilike(f"%{search}%") for attr in text_attrs))
    if search_filter is None:
        row_filter = permission_filter
    else:
        row_filter = and_(permission_filter, search_filter)

    # Both counts use the same join as the data query, so they agree with it.
    count_query = select(func.count(Ghost.id)).join(User, Ghost.owner_id == User.id)
    total_ghosts = (await session.execute(count_query.where(permission_filter))).scalar()
    if search_filter is None:
        filtered_ghosts = total_ghosts
    else:
        filtered_ghosts = (await session.execute(count_query.where(row_filter))).scalar()

    direction = asc if order_dir == "asc" else desc
    query = (
        select(Ghost, User.username.label("owner_username"), User.id.label("owner_guid"))
        .join(User, Ghost.owner_id == User.id)
        .where(row_filter)
        .order_by(direction(order_attr))
    )
    # A missing or negative ``start`` means "from the beginning". A missing or
    # negative ``length`` means "no limit": the DataTables protocol uses
    # length=-1 for "all rows", and databases such as PostgreSQL reject a
    # negative OFFSET/LIMIT.
    if start is not None and start > 0:
        query = query.offset(start)
    if length is not None and length >= 0:
        query = query.limit(length)

    ghosts = (await session.execute(query)).all()
    return {
        "recordsTotal": total_ghosts,
        "recordsFiltered": filtered_ghosts,
        "data": ghosts,
    }
@gf.get("/ghosts")
@hj.html_or_json("ghosts/ghosts.html")
async def read_users(
    current_user: Annotated[User, Depends(get_current_user())],
    offset: int = Query(default=0, ge=0),
    # ``le`` (not ``lte``) is FastAPI's upper-bound keyword. An unrecognised
    # keyword is not enforced, which would leave ``limit`` uncapped.
    limit: int = Query(default=100, ge=0, le=100),
    session: AsyncSession = Depends(get_session),
    request: Request = None,
):
    """List the ghosts ``current_user`` owns or has a permission row for.

    Despite its name, this endpoint lists ghosts; the name is kept for
    existing references. Results are ordered by ghost id and paginated with
    ``offset``/``limit``. ``limit`` is capped at 100.

    The list is also published on ``request.state.ghostforge`` under
    ``ghosts``, ``user`` and ``crumbs``. Keys already present on the request
    state take precedence.
    """
    # Ids of ghosts shared with the user, plus ids of ghosts they own. The
    # union select goes straight into IN(); wrapping it in .subquery() makes
    # SQLAlchemy coerce it back with a warning.
    visible_ids = (
        select(GhostPermissions.ghost_id)
        .where(GhostPermissions.user_id == current_user.id)
        .union(select(Ghost.id).where(Ghost.owner_id == current_user.id))
    )
    result = await session.execute(
        select(Ghost)
        .where(Ghost.id.in_(visible_ids))
        # A stable ORDER BY makes offset/limit pages deterministic.
        .order_by(Ghost.id)
        .offset(offset)
        .limit(limit)
    )
    ghosts = result.scalars().all()
    data = {"ghosts": ghosts, "user": current_user, "crumbs": [("ghosts", False)]}
    request.state.ghostforge = data | getattr(request.state, "ghostforge", {})
    return ghosts