134 lines
4 KiB
Python
134 lines
4 KiB
Python
import hashlib
|
|
import secrets
|
|
from collections.abc import Awaitable, Callable
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Annotated, Any
|
|
|
|
import bcrypt
|
|
import jwt
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from pydantic import BaseModel
|
|
|
|
from app.config import get_settings
|
|
from app.db import get_pool
|
|
|
|
# FastAPI dependency that extracts the bearer token from incoming requests.
# tokenUrl is the endpoint advertised to clients (e.g. in the OpenAPI docs) for
# obtaining a token. NOTE(review): presumably served by the auth router; confirm the path.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
|
|
|
|
# Token-kind markers. ACCESS is written into (and checked against) the JWT
# "typ" claim; REFRESH names the opaque refresh-token kind.
ACCESS, REFRESH = "access", "refresh"
|
|
|
|
|
|
class TokenPair(BaseModel):
    """Token response: a signed access JWT plus an opaque refresh token."""

    # Signed JWT, as produced by issue_access_token.
    access_token: str
    # Opaque random token; only its SHA-256 hash is meant to be persisted.
    refresh_token: str
    # OAuth2 token type; always "bearer" unless overridden.
    token_type: str = "bearer"
    # Access-token lifetime in seconds (the ttl returned by issue_access_token).
    expires_in: int
|
|
|
|
|
|
class AuthAccount(BaseModel):
    """Identity of the authenticated caller, as reconstructed from an access token."""

    # Numeric account id taken from the token's "sub" claim.
    account_id: int
    # Not carried in the access token; current_account fills it with "".
    username: str
    # Scopes granted to the caller, taken from the token's "scopes" claim.
    scopes: list[str]
|
|
|
|
|
|
def hash_password(plain: str) -> str:
    """Hash *plain* with bcrypt (cost factor 12) and return the hash as text.

    Args:
        plain: The cleartext password.

    Returns:
        The bcrypt hash, including its embedded salt, decoded as UTF-8.
    """
    salt = bcrypt.gensalt(rounds=12)
    hashed_bytes = bcrypt.hashpw(plain.encode("utf-8"), salt)
    return hashed_bytes.decode("utf-8")
|
|
|
|
|
|
def verify_password(plain: str, hashed: str) -> bool:
    """Return True iff *plain* matches the stored bcrypt hash *hashed*.

    Malformed hashes, or inputs that fail to encode, are reported as a
    non-match instead of raising.
    """
    # The encodes stay inside the try: a UnicodeEncodeError is a ValueError
    # and must also map to False.
    try:
        matched = bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
    except (ValueError, TypeError):
        matched = False
    return matched
|
|
|
|
|
|
def _now() -> datetime:
|
|
return datetime.now(UTC)
|
|
|
|
|
|
def issue_access_token(account_id: int, scopes: list[str]) -> tuple[str, int]:
    """Sign an HS256 access JWT for *account_id*.

    Args:
        account_id: Account identifier; stored as a string in the ``sub`` claim.
        scopes: Granted scopes, written to the ``scopes`` claim.

    Returns:
        ``(token, ttl_seconds)``, where ``ttl_seconds`` is
        ``settings.jwt_access_ttl_min * 60`` and matches ``exp - iat``.
    """
    settings = get_settings()
    ttl_seconds = settings.jwt_access_ttl_min * 60
    # Read the clock once so iat and exp are exactly ttl_seconds apart.
    issued_at = _now()
    payload: dict[str, Any] = {
        "sub": str(account_id),
        # Copy so later mutation of the caller's list cannot affect the payload.
        "scopes": list(scopes),
        "typ": ACCESS,
        "iat": issued_at,
        "exp": issued_at + timedelta(seconds=ttl_seconds),
    }
    token = jwt.encode(payload, settings.jwt_secret, algorithm="HS256")
    return token, ttl_seconds
|
|
|
|
|
|
def issue_refresh_token(account_id: int) -> tuple[str, datetime, str]:
    """Returns (opaque_token, expires_at, token_hash). Persist only the hash.

    The token is 48 bytes of URL-safe randomness from ``secrets``. The hash is
    its hex SHA-256 digest. The expiry is ``settings.jwt_refresh_ttl_days``
    from now (UTC). *account_id* is currently unused by this function.
    """
    del account_id  # not embedded in the opaque token
    settings = get_settings()
    opaque = secrets.token_urlsafe(48)
    lifetime = timedelta(days=settings.jwt_refresh_ttl_days)
    expiry = _now() + lifetime
    digest = hashlib.sha256(opaque.encode("utf-8")).hexdigest()
    return opaque, expiry, digest
|
|
|
|
|
|
def decode_access_token(token: str) -> dict[str, Any]:
    """Verify an HS256 access JWT and return its claims.

    Args:
        token: The raw bearer token.

    Returns:
        The decoded claims. ``exp`` and ``sub`` are guaranteed present, and
        ``typ`` equals ``ACCESS``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` in three cases:
            the signature is invalid, the token is expired, a required claim
            is missing, or the token is not an access token.
    """
    settings = get_settings()
    try:
        claims: dict[str, Any] = jwt.decode(
            token,
            settings.jwt_secret,
            algorithms=["HS256"],
            # Reject tokens lacking an expiry (they would never expire) or a
            # subject (callers index claims["sub"]). A missing claim raises
            # MissingRequiredClaimError, a PyJWTError subclass.
            options={"require": ["exp", "sub"]},
        )
    except jwt.PyJWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"invalid token: {exc.__class__.__name__}",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    # Refuse other token kinds even when they carry a valid signature.
    if claims.get("typ") != ACCESS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="not an access token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return claims
|
|
|
|
|
|
async def fetch_account(username: str) -> tuple[int, str, list[str]] | None:
    """Look up an active account by username.

    Args:
        username: Exact username to match (bound as a query parameter).

    Returns:
        ``(account_id, password_hash, scopes)`` for a row in ``auth.accounts``
        with ``is_active = true``, or ``None`` if none matches.
    """
    pool = await get_pool()
    async with pool.connection() as conn, conn.cursor() as cur:
        await cur.execute(
            "SELECT account_id, password_hash, scopes "
            "FROM auth.accounts "
            "WHERE username = %s AND is_active = true",
            (username,),
        )
        row = await cur.fetchone()
    if row is None:
        return None
    # A NULL scopes column would make list() raise TypeError; treat it as no scopes.
    return int(row[0]), str(row[1]), list(row[2] or [])
|
|
|
|
|
|
async def store_refresh_token(account_id: int, token_hash: str, expires_at: datetime) -> None:
    """Insert a refresh-token record (hash only) into ``auth.tokens``.

    Args:
        account_id: Owner of the token.
        token_hash: Hex SHA-256 of the opaque token, as from issue_refresh_token.
        expires_at: Expiry timestamp to store.

    NOTE(review): no explicit commit here; this relies on the pool's
    connection context to commit on exit. Confirm for the pool in use.
    """
    insert_sql = (
        "INSERT INTO auth.tokens (account_id, token_type, token_hash, expires_at) "
        "VALUES (%s, 'refresh', %s, %s)"
    )
    bind = (account_id, token_hash, expires_at)
    pool = await get_pool()
    async with pool.connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(insert_sql, bind)
|
|
|
|
|
|
async def current_account(
    token: Annotated[str, Depends(oauth2_scheme)],
) -> AuthAccount:
    """FastAPI dependency resolving the bearer token to an AuthAccount.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) in two cases:
            the token fails decode_access_token, or its ``sub``/``scopes``
            claims are malformed.
    """
    claims = decode_access_token(token)

    def _reject(reason: str, cause: BaseException | None = None) -> HTTPException:
        # Malformed claims are an authentication failure, not a server error.
        exc = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"invalid token: {reason}",
            headers={"WWW-Authenticate": "Bearer"},
        )
        exc.__cause__ = cause
        return exc

    try:
        account_id = int(claims["sub"])
    except (KeyError, TypeError, ValueError) as exc:
        raise _reject("malformed subject", exc) from exc

    raw_scopes = claims.get("scopes", [])
    # Require a real list of strings: list("admin") would silently turn a
    # string claim into single-character scopes.
    if not isinstance(raw_scopes, list) or not all(isinstance(s, str) for s in raw_scopes):
        raise _reject("malformed scopes")

    # The access token carries no username, so it is left empty here.
    return AuthAccount(
        account_id=account_id,
        username="",
        scopes=list(raw_scopes),
    )
|
|
|
|
|
|
def require_scope(scope: str) -> Callable[[AuthAccount], Awaitable[AuthAccount]]:
    """Build a dependency that admits only callers holding *scope*.

    Holders of ``"admin:fleet"`` pass regardless of *scope*. Any other caller
    gets a 403.
    """

    async def _checker(
        account: Annotated[AuthAccount, Depends(current_account)],
    ) -> AuthAccount:
        granted = account.scopes
        # Guard clause: either the exact scope or the admin superscope suffices.
        if scope in granted or "admin:fleet" in granted:
            return account
        raise HTTPException(status_code=403, detail=f"missing scope: {scope}")

    return _checker
|