fleet-platform/app/auth.py

195 lines
6.3 KiB
Python
Raw Normal View History

import hashlib
import secrets
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime, timedelta
from typing import Annotated, Any
import bcrypt
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from app.config import get_settings
from app.db import get_pool
# Token-type labels. ACCESS is written into the "typ" claim of access JWTs and
# checked again when they are decoded; REFRESH is the matching label for
# refresh tokens.
ACCESS = "access"
REFRESH = "refresh"

# FastAPI dependency that extracts the bearer token from the request;
# tokenUrl is the endpoint clients are pointed to for obtaining a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
class TokenPair(BaseModel):
    """Access/refresh token pair handed back to a client.

    Produced, for example, by ``rotate_refresh_token``: a signed JWT access
    token together with an opaque refresh token.
    """

    # Signed HS256 JWT minted by issue_access_token.
    access_token: str
    # Opaque random string; the server persists only its SHA-256 hash.
    refresh_token: str
    # OAuth2 token type reported to the client.
    token_type: str = "bearer"
    # Lifetime of access_token, in seconds.
    expires_in: int
class AuthAccount(BaseModel):
    """Identity of the calling account, rebuilt from a verified access token."""

    # Numeric account id, taken from the token's "sub" claim.
    account_id: int
    # Login name. current_account leaves this empty because the access
    # token does not carry it.
    username: str
    # Granted scope strings; require_scope checks membership in this list.
    scopes: list[str]
def hash_password(plain: str) -> str:
    """Return a bcrypt hash (cost factor 12) of *plain*, decoded as UTF-8.

    A new random salt is generated on every call, so hashing the same
    password twice gives different strings. Compare with verify_password,
    never with ``==``.

    NOTE(review): bcrypt only considers the first 72 bytes of input, and
    recent bcrypt releases reject longer input. Confirm the password-length
    policy upstream.
    """
    secret = plain.encode("utf-8")
    salt = bcrypt.gensalt(rounds=12)
    hashed = bcrypt.hashpw(secret, salt)
    return hashed.decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a stored bcrypt hash.

    Any ValueError or TypeError is reported as a mismatch (False) rather
    than raised. This covers a malformed stored hash, and also a string that
    cannot be UTF-8 encoded (UnicodeEncodeError is a ValueError).
    """
    try:
        # Encoding stays inside the try so encode failures map to False too.
        candidate = plain.encode("utf-8")
        stored = hashed.encode("utf-8")
        matched = bcrypt.checkpw(candidate, stored)
    except (ValueError, TypeError):
        return False
    return matched
def _now() -> datetime:
return datetime.now(UTC)
def issue_access_token(account_id: int, scopes: list[str]) -> tuple[str, int]:
    """Mint a signed HS256 access JWT.

    Args:
        account_id: Account the token is issued to; stored (as a string) in
            the "sub" claim.
        scopes: Scope strings embedded in the "scopes" claim.

    Returns:
        ``(token, ttl_seconds)``, where ``ttl_seconds`` is
        ``settings.jwt_access_ttl_min * 60`` and equals ``exp - iat``.
    """
    settings = get_settings()
    ttl_seconds = settings.jwt_access_ttl_min * 60
    # Read the clock once, so "exp" is exactly ttl_seconds after "iat".
    # Two separate reads could straddle a second boundary, and PyJWT
    # truncates datetimes to whole seconds.
    issued_at = _now()
    payload: dict[str, Any] = {
        "sub": str(account_id),
        "scopes": list(scopes),
        "typ": ACCESS,
        "iat": issued_at,
        "exp": issued_at + timedelta(seconds=ttl_seconds),
    }
    token = jwt.encode(payload, settings.jwt_secret, algorithm="HS256")
    return token, ttl_seconds
def issue_refresh_token(account_id: int) -> tuple[str, datetime, str]:
    """Generate a new opaque refresh token.

    Returns ``(opaque_token, expires_at, token_hash)``. Persist only the
    hash: ``token_hash`` is the hex SHA-256 of the UTF-8 token, and the raw
    token goes to the client. ``expires_at`` is ``jwt_refresh_ttl_days``
    days from now (UTC).

    ``account_id`` does not influence the token. It is presumably kept for
    interface stability; confirm before removing.
    """
    del account_id  # intentionally unused; see docstring
    settings = get_settings()
    # 48 random bytes, URL-safe base64 encoded.
    opaque = secrets.token_urlsafe(48)
    expiry = _now() + timedelta(days=settings.jwt_refresh_ttl_days)
    digest = hashlib.sha256(opaque.encode("utf-8")).hexdigest()
    return opaque, expiry, digest
def decode_access_token(token: str) -> dict[str, Any]:
    """Verify an access JWT and return its claims.

    The token must:
    - be signed with HS256 using ``settings.jwt_secret``;
    - carry the "exp" and "sub" claims, and not be expired;
    - have a "typ" claim equal to ACCESS, so other token kinds are refused.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            any of these checks fails.
    """
    settings = get_settings()
    try:
        claims: dict[str, Any] = jwt.decode(
            token,
            settings.jwt_secret,
            algorithms=["HS256"],
            # Without "exp" a token would never expire. Without "sub",
            # callers indexing claims["sub"] would crash with a 500 instead
            # of a clean 401.
            options={"require": ["exp", "sub"]},
        )
    except jwt.PyJWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"invalid token: {exc.__class__.__name__}",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    if claims.get("typ") != ACCESS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="not an access token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return claims
async def fetch_account(username: str) -> tuple[int, str, list[str]] | None:
    """Look up an active account by username.

    Returns ``(account_id, password_hash, scopes)``, or None when no row in
    auth.accounts matches *username* with ``is_active = true``.
    """
    query = (
        "SELECT account_id, password_hash, scopes "
        "FROM auth.accounts "
        "WHERE username = %s AND is_active = true"
    )
    pool = await get_pool()
    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(query, (username,))
        found = await cur.fetchone()
        if found is not None:
            return int(found[0]), str(found[1]), list(found[2])
    return None
async def store_refresh_token(account_id: int, token_hash: str, expires_at: datetime) -> None:
    """Insert a refresh-token row (hash only) for *account_id*.

    The row is written to auth.tokens with token_type 'refresh'. This
    function takes its own pooled connection, so it is not part of any
    caller's transaction.

    NOTE(review): presumably the pool commits when the connection block
    exits; confirm against the pool type returned by get_pool.
    """
    statement = (
        "INSERT INTO auth.tokens (account_id, token_type, token_hash, expires_at) "
        "VALUES (%s, 'refresh', %s, %s)"
    )
    params = (account_id, token_hash, expires_at)
    pool = await get_pool()
    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(statement, params)
async def touch_last_login(account_id: int) -> None:
    """Set ``last_login_at`` to the database's ``now()`` for *account_id*.

    Runs on its own pooled connection. The row count is not checked, so an
    unknown account_id is silently a no-op.
    """
    sql = "UPDATE auth.accounts SET last_login_at = now() WHERE account_id = %s"
    pool = await get_pool()
    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(sql, (account_id,))
async def rotate_refresh_token(raw_token: str) -> TokenPair | None:
    """Redeem a refresh token and rotate it.

    Returns a fresh access+refresh pair, or None when the token is unknown,
    expired, already revoked, or belongs to a deactivated account.

    The old token is revoked in the same transaction that stores its
    replacement (single-use rotation), so a replayed refresh token never
    yields a second valid pair. Redeemed tokens are marked revoked rather
    than deleted. This leaves room for reuse detection later, if we want to
    act on replays.

    Args:
        raw_token: The opaque refresh token exactly as presented by the client.

    Returns:
        A new TokenPair, or None if the token cannot be redeemed.
    """
    # Only SHA-256 digests are stored, so look the token up by its digest.
    token_hash = hashlib.sha256(raw_token.encode("utf-8")).hexdigest()
    pool = await get_pool()
    # All statements below share one transaction: the revocation, the new
    # token row and the last-login update take effect together or not at
    # all. If any step raises, nothing is committed and no pair is returned.
    async with pool.connection() as conn, conn.transaction(), conn.cursor() as cur:
        # FOR UPDATE OF t locks only the token row.
        # NOTE(review): under PostgreSQL's default READ COMMITTED isolation,
        # a concurrent redeemer of the same token blocks here. After this
        # transaction commits, it re-evaluates the WHERE clause, sees
        # revoked_at set, and gets no row. Confirm the pool's isolation level.
        await cur.execute(
            """
SELECT t.token_id, t.account_id, a.scopes
FROM auth.tokens t
JOIN auth.accounts a ON a.account_id = t.account_id
WHERE t.token_hash = %s
AND t.token_type = 'refresh'
AND t.revoked_at IS NULL
AND t.expires_at > now()
AND a.is_active = true
FOR UPDATE OF t
""",
            (token_hash,),
        )
        row = await cur.fetchone()
        if row is None:
            # Unknown, expired, revoked, or inactive account; nothing written.
            return None
        token_id, account_id, scopes = int(row[0]), int(row[1]), list(row[2])
        # Single use: retire the presented token before issuing its successor.
        await cur.execute(
            "UPDATE auth.tokens SET revoked_at = now() WHERE token_id = %s",
            (token_id,),
        )
        # Scopes come from the account row, not the old token, so scope
        # changes made since the last grant apply to the new access token.
        access, ttl = issue_access_token(account_id, scopes)
        new_raw, expires_at, new_hash = issue_refresh_token(account_id)
        # Same INSERT as store_refresh_token, but on this connection so it
        # commits or rolls back together with the revocation above.
        await cur.execute(
            "INSERT INTO auth.tokens (account_id, token_type, token_hash, expires_at) "
            "VALUES (%s, 'refresh', %s, %s)",
            (account_id, new_hash, expires_at),
        )
        # Inline equivalent of touch_last_login, kept inside the transaction.
        await cur.execute(
            "UPDATE auth.accounts SET last_login_at = now() WHERE account_id = %s",
            (account_id,),
        )
        return TokenPair(access_token=access, refresh_token=new_raw, expires_in=ttl)
async def current_account(
    token: Annotated[str, Depends(oauth2_scheme)],
) -> AuthAccount:
    """FastAPI dependency that resolves the bearer token into an AuthAccount.

    The account is built purely from the verified access-token claims; no
    database lookup is made. As a result, ``username`` is left empty, and
    ``scopes`` reflect what was granted when the token was issued.

    Raises:
        HTTPException: 401 from decode_access_token for an invalid token.
            Also 401 here when a validly signed token has malformed claims:
            "sub" missing or not an integer, or "scopes" not a list of
            strings.
    """
    claims = decode_access_token(token)
    raw_scopes = claims.get("scopes", [])
    try:
        # int(None) raises TypeError, so a missing "sub" lands here as well.
        account_id: int | None = int(claims.get("sub"))
    except (TypeError, ValueError):
        account_id = None
    # Reject non-list scopes explicitly. list() on a string would otherwise
    # split it into single characters, and null would raise a TypeError (500).
    scopes_ok = isinstance(raw_scopes, list) and all(
        isinstance(item, str) for item in raw_scopes
    )
    if account_id is None or not scopes_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="malformed token claims",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return AuthAccount(
        account_id=account_id,
        username="",
        scopes=list(raw_scopes),
    )
def require_scope(scope: str) -> Callable[[AuthAccount], Awaitable[AuthAccount]]:
    """Build a FastAPI dependency that only admits callers holding *scope*.

    Accounts holding "admin:fleet" pass every check. Any other account
    lacking *scope* is rejected with HTTP 403. The returned coroutine
    function resolves the caller via current_account and hands the
    AuthAccount through unchanged on success.
    """

    async def _guard(
        account: Annotated[AuthAccount, Depends(current_account)],
    ) -> AuthAccount:
        granted = account.scopes
        if scope in granted or "admin:fleet" in granted:
            return account
        raise HTTPException(status_code=403, detail=f"missing scope: {scope}")

    return _guard