mirror of
https://github.com/DarrylNixon/ghostforge
synced 2024-04-22 06:27:20 -07:00
Replaced rolled-own/argon2 with fastapi-users, oof
This commit is contained in:
parent
216d2ac42b
commit
6f81ef699d
14 changed files with 310 additions and 142 deletions
|
@ -1,109 +1,76 @@
|
|||
from datetime import datetime
|
||||
from typing import List
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from argon2 import PasswordHasher
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from sqlmodel import Field
|
||||
from sqlmodel import select
|
||||
from sqlmodel import Session
|
||||
from sqlmodel import SQLModel
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import BearerTransport
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
|
||||
from ghostforge.db import get_session
|
||||
from ghostforge.htmljson import HtmlJson
|
||||
from ghostforge.db import get_user_db
|
||||
from ghostforge.db import User
|
||||
|
||||
SECRET = os.environ.get("GHOSTFORGE_JWT_SECRET")
|
||||
|
||||
gf = APIRouter()
|
||||
ph = PasswordHasher()
|
||||
hj = HtmlJson()
|
||||
|
||||
|
||||
class UserBase(SQLModel):
|
||||
name: str = Field()
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
pass
|
||||
|
||||
|
||||
class User(UserBase, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
creation: datetime = Field(default=datetime.now())
|
||||
password_hash: Optional[str]
|
||||
|
||||
def verify_password(self, password: str):
|
||||
return ph.verify(self.password_hash, password)
|
||||
|
||||
def set_password(self, password: str):
|
||||
self.password_hash = ph.hash(password)
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
pass
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
name: str
|
||||
password: str
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
pass
|
||||
|
||||
|
||||
class UserRead(UserBase):
|
||||
id: int
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = SECRET
|
||||
verification_token_secret = SECRET
|
||||
|
||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||
print(f"User {user.id} has registered.")
|
||||
|
||||
async def on_after_forgot_password(self, user: User, token: str, request: Optional[Request] = None):
|
||||
print(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(self, user: User, token: str, request: Optional[Request] = None):
|
||||
print(f"Verification requested for user {user.id}. Verification token: {token}")
|
||||
|
||||
|
||||
class UserUpdate(SQLModel):
|
||||
name: Optional[str] = None
|
||||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
||||
yield UserManager(user_db)
|
||||
|
||||
|
||||
@gf.get("/users", response_model=List[UserRead])
|
||||
async def read_users(
|
||||
offset: int = 0, limit: int = Query(default=100, lte=100), session: Session = Depends(get_session)
|
||||
):
|
||||
users = await session.execute(select(User).offset(offset).limit(limit))
|
||||
return [UserRead.from_orm(user[0]) for user in users.all()]
|
||||
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
|
||||
cookie_transport = CookieTransport(cookie_httponly=True, cookie_name="ghostforge", cookie_samesite="strict")
|
||||
|
||||
|
||||
@gf.post("/users", response_model=UserRead)
|
||||
async def create_hero(user: UserCreate, session: Session = Depends(get_session)):
|
||||
new_user = User.from_orm(user)
|
||||
if len(user.password) < 12:
|
||||
raise HTTPException(status_code=442, detail="Password must be at least 12 characters")
|
||||
new_user.set_password(user.password)
|
||||
session.add(new_user)
|
||||
await session.commit()
|
||||
await session.refresh(new_user)
|
||||
return new_user
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return JWTStrategy(secret=SECRET, lifetime_seconds=604800)
|
||||
|
||||
|
||||
@gf.get("/users/{user_id}", response_model=UserRead)
|
||||
@hj.html_or_json("user.html")
|
||||
async def read_user(user_id: int, session: Session = Depends(get_session), request: Request = None):
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
jwt_backend = AuthenticationBackend(
|
||||
name="jwt",
|
||||
transport=bearer_transport,
|
||||
get_strategy=get_jwt_strategy,
|
||||
)
|
||||
|
||||
data = {"crumbs": [("settings", False), ("users", "/users"), (user_id, False)]}
|
||||
request.state.ghostforge = data | getattr(request.state, "ghostforge", {})
|
||||
return user
|
||||
web_backend = AuthenticationBackend(name="cookie", transport=cookie_transport, get_strategy=get_jwt_strategy)
|
||||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [web_backend, jwt_backend])
|
||||
|
||||
|
||||
@gf.patch("/users/{user_id}", response_model=UserRead)
|
||||
async def update_user(user_id: int, user: UserUpdate, session: Session = Depends(get_session), request: Request = None):
|
||||
edit_user = await session.get(User, user_id)
|
||||
if not edit_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
data = user.dict(exclude_unset=True)
|
||||
for key, value in data.items():
|
||||
setattr(edit_user, key, value)
|
||||
session.add(edit_user)
|
||||
await session.commit()
|
||||
await session.refresh(edit_user)
|
||||
return edit_user
|
||||
|
||||
|
||||
@gf.delete("/users/{user_id}")
|
||||
async def delete_user(user_id: int, session: Session = Depends(get_session), request: Request = None):
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
session.delete(user)
|
||||
await session.commit()
|
||||
return {"ok": True}
|
||||
def get_current_user(active: bool = True, optional: bool = False) -> User:
|
||||
return fastapi_users.current_user(active=active, optional=optional)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue