"""Authentication helpers.

Covers password hashing (bcrypt), short-lived JWT access tokens, opaque
refresh tokens (only their SHA-256 hash is meant to be persisted), account
lookup/refresh-token persistence, and the FastAPI dependencies that guard
routes.
"""

import hashlib
import secrets
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime, timedelta
from typing import Annotated, Any

import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

from app.config import get_settings
from app.db import get_pool

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")

# Values of the "typ" claim / token kind.
ACCESS = "access"
REFRESH = "refresh"

# Signing algorithm shared by issue_access_token and decode_access_token.
_JWT_ALGORITHM = "HS256"
# Holding this scope satisfies every require_scope() check.
_ADMIN_SCOPE = "admin:fleet"


class TokenPair(BaseModel):
    """Response body returned to a client after successful authentication."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    # Lifetime of the access token, in seconds.
    expires_in: int


class AuthAccount(BaseModel):
    """Identity attached to a request once its access token is validated."""

    account_id: int
    username: str
    scopes: list[str]


def hash_password(plain: str) -> str:
    """Return a bcrypt hash (cost factor 12) of *plain* as a UTF-8 string."""
    salt = bcrypt.gensalt(rounds=12)
    return bcrypt.hashpw(plain.encode("utf-8"), salt).decode("utf-8")


def verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the bcrypt hash *hashed*.

    A ``ValueError``/``TypeError`` raised by bcrypt (e.g. for a malformed
    stored hash) is treated as a failed verification rather than propagated.
    """
    try:
        return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
    except (ValueError, TypeError):
        return False


def _now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(UTC)


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 response carrying the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def issue_access_token(account_id: int, scopes: list[str]) -> tuple[str, int]:
    """Create a signed JWT access token.

    Args:
        account_id: Account the token is issued for; stored as the string
            ``sub`` claim.
        scopes: Scopes embedded in the ``scopes`` claim.

    Returns:
        ``(token, ttl_seconds)`` where ``ttl_seconds`` is the token lifetime
        derived from ``settings.jwt_access_ttl_min``.
    """
    settings = get_settings()
    ttl_seconds = settings.jwt_access_ttl_min * 60
    # Take a single timestamp so that exp - iat equals ttl_seconds exactly.
    issued_at = _now()
    payload: dict[str, Any] = {
        "sub": str(account_id),
        "scopes": scopes,
        "typ": ACCESS,
        "iat": issued_at,
        "exp": issued_at + timedelta(seconds=ttl_seconds),
    }
    token = jwt.encode(payload, settings.jwt_secret, algorithm=_JWT_ALGORITHM)
    return token, ttl_seconds


def issue_refresh_token(account_id: int) -> tuple[str, datetime, str]:
    """Generate an opaque refresh token.

    Args:
        account_id: Currently unused; kept so the signature stays stable
            for existing callers.

    Returns:
        ``(opaque_token, expires_at, token_hash)``. ``token_hash`` is the
        hex SHA-256 digest of the opaque token. Persist only the hash.
    """
    del account_id  # accepted for interface compatibility only
    settings = get_settings()
    raw = secrets.token_urlsafe(48)
    expires_at = _now() + timedelta(days=settings.jwt_refresh_ttl_days)
    token_hash = hashlib.sha256(raw.encode("utf-8")).hexdigest()
    return raw, expires_at, token_hash


def decode_access_token(token: str) -> dict[str, Any]:
    """Verify *token* and return its claims.

    Raises:
        HTTPException: 401 if the signature/claims fail PyJWT validation or
            the ``typ`` claim is not ``"access"``.
    """
    settings = get_settings()
    try:
        claims: dict[str, Any] = jwt.decode(
            token, settings.jwt_secret, algorithms=[_JWT_ALGORITHM]
        )
    except jwt.PyJWTError as exc:
        # Only the exception class name is exposed to the client.
        raise _unauthorized(f"invalid token: {exc.__class__.__name__}") from exc
    if claims.get("typ") != ACCESS:
        raise _unauthorized("not an access token")
    return claims


async def fetch_account(username: str) -> tuple[int, str, list[str]] | None:
    """Look up an active account by username.

    Returns:
        ``(account_id, password_hash, scopes)``, or ``None`` when no active
        account has that username.
    """
    pool = await get_pool()
    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(
            "SELECT account_id, password_hash, scopes "
            "FROM auth.accounts "
            "WHERE username = %s AND is_active = true",
            (username,),
        )
        row = await cur.fetchone()
    if row is None:
        return None
    # A NULL scopes column is treated as "no scopes" instead of crashing.
    return int(row[0]), str(row[1]), list(row[2] or [])


async def store_refresh_token(account_id: int, token_hash: str, expires_at: datetime) -> None:
    """Insert a refresh-token row (hash only) for *account_id*.

    NOTE(review): no explicit commit is issued here; presumably the pool's
    connection context manager commits on clean exit - confirm against
    ``app.db.get_pool``.
    """
    pool = await get_pool()
    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(
            "INSERT INTO auth.tokens (account_id, token_type, token_hash, expires_at) "
            "VALUES (%s, 'refresh', %s, %s)",
            (account_id, token_hash, expires_at),
        )


async def current_account(
    token: Annotated[str, Depends(oauth2_scheme)],
) -> AuthAccount:
    """FastAPI dependency resolving the bearer token to an ``AuthAccount``.

    ``username`` is left empty because the access token does not carry it.

    Raises:
        HTTPException: 401 if the token is invalid, or if its ``sub`` /
            ``scopes`` claims are missing or malformed.
    """
    claims = decode_access_token(token)
    try:
        account_id = int(claims["sub"])
        scopes = list(claims.get("scopes") or [])
    except (KeyError, TypeError, ValueError) as exc:
        # A correctly signed token with unusable claims is a client error,
        # not a server fault.
        raise _unauthorized("invalid token claims") from exc
    return AuthAccount(account_id=account_id, username="", scopes=scopes)


def require_scope(scope: str) -> Callable[[AuthAccount], Awaitable[AuthAccount]]:
    """Return a dependency that enforces *scope* on the current account.

    Accounts holding ``admin:fleet`` pass every check.

    Raises (from the returned dependency):
        HTTPException: 403 if the account has neither *scope* nor
            ``admin:fleet``.
    """

    async def _checker(
        account: Annotated[AuthAccount, Depends(current_account)],
    ) -> AuthAccount:
        if scope not in account.scopes and _ADMIN_SCOPE not in account.scopes:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"missing scope: {scope}",
            )
        return account

    return _checker