fix(logging): attribute each query to its analyst caller
The BearerAuth middleware matched a per-analyst token but only stashed it on request.state, which the FastMCP tools never see — so the query log line logged rows/sql with no caller, defeating the per-token attribution the auth design promises. Bridge the caller name through a ContextVar (anyio propagates it into the worker thread that runs each sync tool) and include it in the query() log. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
fae40942a2
commit
af6fdbcd3f
1 changed files with 12 additions and 1 deletions
|
|
@ -26,6 +26,7 @@ Env:
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextvars
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
@ -65,6 +66,12 @@ def _get_logger(name: str) -> logging.Logger:
|
||||||
|
|
||||||
log = _get_logger("server")
|
log = _get_logger("server")
|
||||||
|
|
||||||
|
# Per-request caller name, set by BearerAuth from the matched token so the tools can
|
||||||
|
# attribute each query to an analyst in the logs. A ContextVar (not a tool arg) because
|
||||||
|
# FastMCP tools never receive the HTTP request; anyio propagates the context into the
|
||||||
|
# worker thread that runs each sync tool. Defaults to "?" if auth ever didn't run.
|
||||||
|
_caller_var: contextvars.ContextVar[str] = contextvars.ContextVar("caller", default="?")
|
||||||
|
|
||||||
DATABASE_URL = os.environ["DATABASE_URL"] # analytics_ro DSN (set by deploy)
|
DATABASE_URL = os.environ["DATABASE_URL"] # analytics_ro DSN (set by deploy)
|
||||||
MAX_ROWS_CEIL = int(os.getenv("MCP_MAX_ROWS", "10000"))
|
MAX_ROWS_CEIL = int(os.getenv("MCP_MAX_ROWS", "10000"))
|
||||||
# Schemas the introspection helpers (list_tables/describe_table/sample_table) expose.
|
# Schemas the introspection helpers (list_tables/describe_table/sample_table) expose.
|
||||||
|
|
@ -284,7 +291,10 @@ def query(sql: str, max_rows: int = 1000) -> dict:
|
||||||
truncated = len(rows) > cap
|
truncated = len(rows) > cap
|
||||||
rows = rows[:cap]
|
rows = rows[:cap]
|
||||||
dur_ms = int((time.monotonic() - t0) * 1000)
|
dur_ms = int((time.monotonic() - t0) * 1000)
|
||||||
log.info("query rows=%d trunc=%s %dms :: %s", len(rows), truncated, dur_ms, sql[:200])
|
log.info(
|
||||||
|
"query caller=%s rows=%d trunc=%s %dms :: %s",
|
||||||
|
_caller_var.get(), len(rows), truncated, dur_ms, sql[:200],
|
||||||
|
)
|
||||||
return {"row_count": len(rows), "truncated": truncated, "rows": rows}
|
return {"row_count": len(rows), "truncated": truncated, "rows": rows}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -393,6 +403,7 @@ class BearerAuth(BaseHTTPMiddleware):
|
||||||
if caller is None:
|
if caller is None:
|
||||||
return JSONResponse({"error": "unauthorized"}, status_code=401)
|
return JSONResponse({"error": "unauthorized"}, status_code=401)
|
||||||
request.state.caller = caller
|
request.state.caller = caller
|
||||||
|
_caller_var.set(caller) # so the tools can attribute the query in the logs
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue