fleet-platform/app/auth.py

135 lines
4 KiB
Python
Raw Normal View History

import hashlib
import secrets
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime, timedelta
from typing import Annotated, Any
import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from app.config import get_settings
from app.db import get_pool
# FastAPI security dependency that extracts the bearer token from requests;
# tokenUrl is the token endpoint advertised in the OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")

# Token-type discriminators. ACCESS is written to and checked against the JWT
# "typ" claim. NOTE(review): REFRESH is not referenced in this section
# (refresh tokens are opaque strings) — confirm whether it is used elsewhere.
ACCESS, REFRESH = "access", "refresh"
class TokenPair(BaseModel):
    # Token pair returned to a client: a signed access JWT plus an opaque
    # refresh token. Documented with comments rather than a docstring because
    # pydantic would publish a class docstring into the OpenAPI schema.
    access_token: str
    refresh_token: str
    # OAuth2 token type; always "bearer" unless overridden.
    token_type: str = "bearer"
    # Presumably the access-token lifetime in seconds (issue_access_token
    # returns ttl_seconds) — NOTE(review): confirm at the construction site.
    expires_in: int
class AuthAccount(BaseModel):
    # Caller identity resolved from a verified access token; produced by
    # current_account and consumed by the require_scope dependencies.
    # Integer account id taken from the token's "sub" claim.
    account_id: int
    # current_account sets this to "" because the access token carries no
    # username; do not rely on it for authorization.
    username: str
    # Scope strings granted to the caller; checked by require_scope.
    scopes: list[str]
def hash_password(plain: str) -> str:
    """Return a bcrypt hash of *plain* (cost factor 12) as a UTF-8 string.

    A fresh random salt is generated for every call, so hashing the same
    password twice yields different strings; use ``verify_password`` to
    compare.

    NOTE(review): bcrypt only considers the first 72 bytes of input; depending
    on the installed bcrypt version, longer passwords are either silently
    truncated or rejected with ValueError — confirm the pinned version.
    """
    password_bytes = plain.encode("utf-8")
    salt = bcrypt.gensalt(rounds=12)
    hashed_bytes = bcrypt.hashpw(password_bytes, salt)
    return hashed_bytes.decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a stored bcrypt hash.

    Returns False, rather than raising, when the stored hash is malformed or
    either input cannot be encoded. For example, a lone surrogate in *plain*
    raises UnicodeEncodeError, which is a ValueError subclass.
    """
    try:
        # Encoding stays inside the try so encode errors map to False too.
        candidate = plain.encode("utf-8")
        stored = hashed.encode("utf-8")
        return bcrypt.checkpw(candidate, stored)
    except (TypeError, ValueError):
        return False
def _now() -> datetime:
return datetime.now(UTC)
def issue_access_token(account_id: int, scopes: list[str]) -> tuple[str, int]:
    """Sign a short-lived HS256 access JWT for *account_id*.

    The payload carries ``sub`` (account id as a string), ``scopes``, ``typ``
    (set to ``ACCESS``), ``iat`` and ``exp``. The lifetime comes from
    ``settings.jwt_access_ttl_min``.

    Returns:
        ``(token, ttl_seconds)`` — the encoded JWT and its lifetime in seconds.
    """
    settings = get_settings()
    ttl_seconds = settings.jwt_access_ttl_min * 60
    # Take a single timestamp so that exp - iat is exactly the configured TTL;
    # calling _now() twice let the two claims drift apart.
    issued_at = _now()
    payload: dict[str, Any] = {
        "sub": str(account_id),
        # Copy so later mutation of the caller's list cannot affect the payload.
        "scopes": list(scopes),
        "typ": ACCESS,
        "iat": issued_at,
        "exp": issued_at + timedelta(seconds=ttl_seconds),
    }
    token = jwt.encode(payload, settings.jwt_secret, algorithm="HS256")
    return token, ttl_seconds
def issue_refresh_token(account_id: int) -> tuple[str, datetime, str]:
    """Mint a new opaque refresh token.

    Returns ``(opaque_token, expires_at, token_hash)``. Persist only the hash:
    ``token_hash`` is the hex SHA-256 digest of the UTF-8 encoded token.
    ``expires_at`` is ``settings.jwt_refresh_ttl_days`` days from now (UTC).
    """
    # NOTE(review): account_id does not influence the token; it is accepted
    # only to keep the call signature — confirm that is intentional.
    del account_id
    ttl_days = get_settings().jwt_refresh_ttl_days
    opaque = secrets.token_urlsafe(48)
    expiry = _now() + timedelta(days=ttl_days)
    digest = hashlib.sha256(opaque.encode("utf-8")).hexdigest()
    return opaque, expiry, digest
def decode_access_token(token: str) -> dict[str, Any]:
    """Verify an HS256 access JWT and return its claims.

    The signature and expiry are checked by PyJWT, and the ``exp``, ``iat``
    and ``sub`` claims must be present. Tokens whose ``typ`` claim is not
    ``ACCESS`` are rejected.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token is invalid, expired, missing required claims, or is not
            an access token.
    """
    settings = get_settings()
    # RFC 6750: a 401 for a bearer-protected resource should carry a challenge.
    challenge = {"WWW-Authenticate": "Bearer"}
    try:
        claims: dict[str, Any] = jwt.decode(
            token,
            settings.jwt_secret,
            algorithms=["HS256"],
            # Every token we issue has these claims. Requiring them makes a
            # token without "sub" fail here with 401, not later with a 500.
            options={"require": ["exp", "iat", "sub"]},
        )
    except jwt.PyJWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"invalid token: {exc.__class__.__name__}",
            headers=challenge,
        ) from exc
    if claims.get("typ") != ACCESS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="not an access token",
            headers=challenge,
        )
    return claims
async def fetch_account(username: str) -> tuple[int, str, list[str]] | None:
    """Look up an active account by username.

    Returns:
        ``(account_id, password_hash, scopes)`` for an active account with
        that username, or ``None`` if there is no such row. A NULL ``scopes``
        column is returned as an empty list.
    """
    pool = await get_pool()
    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(
            "SELECT account_id, password_hash, scopes "
            "FROM auth.accounts "
            "WHERE username = %s AND is_active = true",
            (username,),
        )
        row = await cur.fetchone()
    # Process the row after leaving the context so the connection returns
    # to the pool as early as possible.
    if row is None:
        return None
    account_id, password_hash, scopes = row
    # list(None) would raise TypeError (a 500 at login) on a NULL scopes column.
    return int(account_id), str(password_hash), list(scopes or [])
async def store_refresh_token(account_id: int, token_hash: str, expires_at: datetime) -> None:
    """Insert a refresh-token row (hash only) into ``auth.tokens``.

    The row's ``token_type`` is the literal ``'refresh'``.
    NOTE(review): there is no explicit commit here; this presumably relies on
    the pool's connection context committing on a clean exit — confirm.
    """
    insert_sql = (
        "INSERT INTO auth.tokens (account_id, token_type, token_hash, expires_at) "
        "VALUES (%s, 'refresh', %s, %s)"
    )
    params = (account_id, token_hash, expires_at)
    pool = await get_pool()
    async with pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(insert_sql, params)
async def current_account(
    token: Annotated[str, Depends(oauth2_scheme)],
) -> AuthAccount:
    """FastAPI dependency: resolve the caller from the bearer access token.

    Only the token's claims are used; no database lookup is made.

    Raises:
        HTTPException: 401 if the token fails ``decode_access_token`` or its
            ``sub`` claim is missing or not an integer.
    """
    claims = decode_access_token(token)
    try:
        account_id = int(claims["sub"])
    except (KeyError, TypeError, ValueError) as exc:
        # A malformed subject is an authentication failure, not a server error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="invalid token: malformed subject",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    # A null/absent "scopes" claim means no scopes, rather than a TypeError.
    scopes = list(claims.get("scopes") or [])
    return AuthAccount(
        account_id=account_id,
        # The access token carries no username, so it is left empty here.
        username="",
        scopes=scopes,
    )
def require_scope(scope: str) -> Callable[[AuthAccount], Awaitable[AuthAccount]]:
    """Build a FastAPI dependency that enforces *scope* on the caller.

    The returned dependency resolves the caller through ``current_account``.
    It lets the request through if the account holds *scope* or the
    ``"admin:fleet"`` scope; otherwise it raises HTTP 403.
    """

    async def _checker(
        account: Annotated[AuthAccount, Depends(current_account)],
    ) -> AuthAccount:
        granted = account.scopes
        # "admin:fleet" acts as a superscope that satisfies any requirement.
        if scope in granted or "admin:fleet" in granted:
            return account
        raise HTTPException(status_code=403, detail=f"missing scope: {scope}")

    return _checker