fleetanalytics_mcp/analytics_mcp.py

309 lines
13 KiB
Python
Raw Permalink Normal View History

"""
analytics_mcp_rev.py Fireside Communications · Read-only Analytics MCP Server
Hosted MCP server for the decision & analytics team. Exposes the fleet reporting
data (reporting.* + tracksolid.*) to Claude as READ-ONLY query + introspection
tools for reporting and decisions, never edit/delete.
It is a STANDALONE Traefik-labelled bridge (not Coolify-managed), the same shape
as the dashboard_api staging bridge: it reuses the webhook_receiver image, joins
the `coolify` network, and connects to the internal DB over psycopg2 as the
dedicated read-only `analytics_ro` role (deploy_analytics_mcp.sh sets DATABASE_URL
to that DSN). Served over streamable HTTP with Bearer-token auth.
READ-ONLY is enforced at FOUR layers:
1. the analytics_ro GRANTs (no INSERT/UPDATE/DELETE; not the matview owner)
2. role + connection default_transaction_read_only = on
3. every query runs in a transaction that is ROLLED BACK (never committed)
4. the `query` tool's single-statement / keyword guard (clean errors, not DB faults)
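Env:
DATABASE_URL analytics_ro DSN (set by the deploy script; required)
MCP_AUTH_TOKENS "alice:tok1,bob:tok2" per-analyst Bearer tokens (revocable + audited)
MCP_MAX_ROWS hard ceiling on rows returned (default 10000)
MCP_POOL_MAX max read-only pool connections (default 8)
MCP_READABLE_SCHEMAS comma-separated schemas the introspection tools expose (default "reporting,tracksolid,tickets,fuel")
MCP_DNS_REBINDING_PROTECTION "1" turns the MCP SDK's DNS-rebinding protection on (default "0" = off)
MCP_ALLOWED_HOSTS comma-separated Host allowlist used when that protection is on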
"""
from __future__ import annotations
import logging
import os
import re
import time
from contextlib import contextmanager
import psycopg2
import psycopg2.extras
import psycopg2.pool
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
def _get_logger(name: str) -> logging.Logger:
    """Return the ``analytics_mcp.<name>`` child logger, configuring the parent once.

    The first call attaches a single ``StreamHandler`` to the shared
    ``analytics_mcp`` logger and sets that logger to INFO. Later calls find the
    handler already present and only return a child, so records are never
    emitted twice.

    Standalone logger modelled on ts_shared_rev's format. Intentionally NOT
    importing ts_shared_rev: that module eagerly requires the Tracksolid
    ingestion secrets (APP_KEY/SECRET/PWD), which this read-only analytics
    server has no business holding.

    Args:
        name: Suffix appended to ``analytics_mcp.`` to form the logger name.

    Returns:
        The child logger ``analytics_mcp.<name>``.
    """
    root = logging.getLogger("analytics_mcp")
    if not root.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter(
                # The ": " separator keeps the logger name from running
                # straight into the message text.
                # NOTE(review): confirm the separator matches ts_shared_rev.
                "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )
        )
        root.addHandler(handler)
        root.setLevel(logging.INFO)
    return root.getChild(name)
# Module-wide logger; records are emitted under "analytics_mcp.server".
log = _get_logger("server")

# analytics_ro DSN (set by deploy). Read with [] on purpose: a missing variable
# raises KeyError at import time instead of starting without a database.
DATABASE_URL = os.environ["DATABASE_URL"]

# Upper bound on the rows one query() call may return, whatever max_rows says.
MAX_ROWS_CEIL = int(os.getenv("MCP_MAX_ROWS", "10000"))

# Schemas the introspection helpers (list_tables/describe_table/sample_table) expose.
# Override with MCP_READABLE_SCHEMAS="reporting,tracksolid,tickets,fuel" — these must
# stay in sync with the GRANTs in scripts/analytics_ro_role.sql. The raw query() tool
# is bounded by the analytics_ro role's GRANTs, not by this list.
_schemas_csv = os.getenv("MCP_READABLE_SCHEMAS", "reporting,tracksolid,tickets,fuel")
# Trim each comma-separated entry and drop the ones that end up empty.
READABLE_SCHEMAS = tuple(filter(None, map(str.strip, _schemas_csv.split(","))))
# ── Read-only connection pool ────────────────────────────────────────────────
# Every pooled connection is opened with read-only transactions and a statement
# timeout forced through the libpq `options` string (belt + braces; the
# analytics_ro role already sets these, but a self-contained server is safer in
# case it is ever pointed at a less-restricted DSN).
_POOL_MIN = 1
_POOL_MAX = int(os.getenv("MCP_POOL_MAX", "8"))
_pool = psycopg2.pool.ThreadedConnectionPool(
    _POOL_MIN,
    _POOL_MAX,
    DATABASE_URL,
    options="-c default_transaction_read_only=on -c statement_timeout=30000 -c client_encoding=UTF8",
)
@contextmanager
def _ro_conn():
    """Yield a pooled read-only connection; the transaction is ALWAYS rolled back.

    A connection is borrowed from ``_pool``, switched to a read-only,
    non-autocommit session and handed to the caller. On exit — normal or via an
    exception — the open transaction is rolled back (never committed) and the
    connection is returned to the pool.

    If the rollback itself fails (for example because the server dropped the
    connection), the failure is logged and the connection is closed instead of
    being recycled, so it neither replaces the exception already propagating
    from the caller nor resurfaces on a later borrow.

    Yields:
        The borrowed psycopg2 connection.
    """
    conn = _pool.getconn()
    discard = False
    try:
        conn.set_session(readonly=True, autocommit=False)
        yield conn
    finally:
        try:
            conn.rollback()
        except psycopg2.Error as exc:
            # Deliberately swallowed: nothing was ever going to be committed,
            # and re-raising here would mask the caller's own error.
            discard = True
            log.warning("rollback failed; discarding pooled connection: %s", exc)
        finally:
            # close=True makes the pool drop the connection rather than reuse it.
            _pool.putconn(conn, close=discard)
def _rows(cur) -> list[dict]:
    """Fetch every remaining row of *cur* as a list of JSON-safe dicts.

    Each row becomes a dict keyed by column name, with every value passed
    through ``_jsonable``. A cursor without a result set
    (``cur.description is None``) produces an empty list.
    """
    description = cur.description
    if description is None:
        return []
    # The first item of each description entry is the column name.
    names = [column[0] for column in description]
    return [
        {name: _jsonable(value) for name, value in zip(names, record)}
        for record in cur.fetchall()
    ]
def _jsonable(v):
    """Coerce a database value into something JSON can represent.

    * ``None``, ``bool``, ``int``, ``str`` and finite ``float`` pass through.
    * Non-finite floats (NaN, +/-Infinity) become their string form, because
      JSON has no literal for them.
    * ``dict`` / ``list`` / ``tuple`` values are converted element by element,
      so nested structures stay structured instead of being flattened to a
      Python ``repr`` string. Dict keys are stringified.
    * Everything else (dates, Decimal, etc.) is converted with ``str()``.

    Args:
        v: Any value taken from a result row.

    Returns:
        A value built only from None, bool, int, float, str, list and dict.
    """
    if isinstance(v, float):
        # v != v is true only for NaN; the abs() test catches both infinities.
        if v != v or abs(v) == float("inf"):
            return str(v)
        return v
    if v is None or isinstance(v, (bool, int, str)):
        return v
    if isinstance(v, dict):
        return {str(key): _jsonable(item) for key, item in v.items()}
    if isinstance(v, (list, tuple)):
        return [_jsonable(item) for item in v]
    return str(v)
# ── SQL guard for the general query tool ─────────────────────────────────────
# The analytics_ro role + read-only txn already make writes impossible; this guard
# exists to return CLEAN errors (and block multi-statements / SET that could relax
# read-only) instead of letting the DB raise.
#
# Matching is by whole word and is deliberately conservative: a keyword that
# appears inside a string literal or as a bare column name is rejected too.
# Besides the statement keywords, two SELECT-shaped escapes are listed:
#   * `into`       — SELECT ... INTO creates a table although it starts with SELECT
#   * `set_config` — the function form of SET (`\bset\b` alone does not match it,
#                    because `_` counts as a word character)
_FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|create|grant|revoke|truncate|copy|call|do|merge|"
    r"vacuum|reindex|refresh|comment|lock|set|reset|into|set_config)\b",
    re.IGNORECASE,
)
def _strip_comments(sql: str) -> str:
    """Return *sql* with SQL comments blanked out and outer whitespace trimmed.

    Block comments (``/* ... */``, possibly spanning lines) are replaced first,
    then line comments (``--`` up to the end of the line); each becomes a single
    space. The replacement is purely textual: comment markers inside a quoted
    string are treated as comments as well.
    """
    without_block = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
    without_line = re.sub(r"--[^\n]*", " ", without_block)
    return without_line.strip()
def _guard(sql: str) -> str:
    """Check that *sql* is one read-only statement and return it cleaned.

    Comments are removed first. The remaining text must then hold exactly one
    non-empty ``;``-separated statement, start with SELECT or WITH, and contain
    none of the ``_FORBIDDEN`` keywords.

    Returns:
        The statement without comments, surrounding whitespace or trailing ``;``.

    Raises:
        ValueError: if any of the checks above fails.
    """
    cleaned = _strip_comments(sql)
    if not cleaned:
        raise ValueError("Empty query.")

    # Blank pieces are ignored, which is what lets a trailing ';' through.
    statements = [piece for piece in cleaned.split(";") if piece.strip()]
    if len(statements) != 1:
        raise ValueError("Only a single statement is allowed.")

    candidate = statements[0].strip()
    if re.match(r"^(select|with)\b", candidate, re.IGNORECASE) is None:
        raise ValueError("Only SELECT / WITH queries are allowed.")
    if _FORBIDDEN.search(candidate) is not None:
        raise ValueError("Query contains a forbidden (write/DDL) keyword.")
    return candidate
# ── MCP server + tools ───────────────────────────────────────────────────────
# The MCP SDK ships DNS-rebinding protection that, by default, only accepts a
# localhost Host header and returns 421 for anything else — which breaks this
# service behind Traefik (Host = fleetmcp.*). That protection targets browser
# attacks on localhost-bound servers; it does not apply to a public, TLS-terminated,
# Bearer-authenticated service. So it is OFF by default here, and re-enableable via
# MCP_DNS_REBINDING_PROTECTION=1 with an explicit MCP_ALLOWED_HOSTS allowlist.
_DNS_PROT = os.getenv("MCP_DNS_REBINDING_PROTECTION", "0") == "1"

_allowed_hosts_csv = os.getenv(
    "MCP_ALLOWED_HOSTS",
    "fleetmcp.fivetitude.com,fleetmcp.rahamafresh.com,localhost,127.0.0.1",
)
# Trim each comma-separated host and skip blanks.
_ALLOWED_HOSTS = [
    host for host in (part.strip() for part in _allowed_hosts_csv.split(",")) if host
]

# The allowed origins are the same hosts, each given an https:// prefix.
_transport_security = TransportSecuritySettings(
    enable_dns_rebinding_protection=_DNS_PROT,
    allowed_hosts=_ALLOWED_HOSTS,
    allowed_origins=[f"https://{h}" for h in _ALLOWED_HOSTS],
)

mcp = FastMCP(
    "fireside-analytics",
    stateless_http=True,
    transport_security=_transport_security,
)
@mcp.tool()
def query(sql: str, max_rows: int = 1000) -> dict:
    """Run a read-only SELECT/WITH query against the fleet database.

    Readable schemas are whatever the server's read-only database role has been
    granted (list_schemas shows the ones the introspection tools cover). Single
    statement only; write/DDL is rejected. Returns up to `max_rows` rows
    (default 1000; capped by the server's MCP_MAX_ROWS ceiling, 10000 unless
    overridden). The row cap is always enforced, even when the query carries its
    own LIMIT. Result: {row_count, truncated, rows}.

    Raises:
        ValueError: if the statement is empty, is not a single SELECT/WITH, or
            contains a forbidden keyword.
    """
    stmt = _guard(sql)
    cap = max(1, min(int(max_rows), MAX_ROWS_CEIL))
    # Always bound the result in the database by wrapping the statement in an
    # outer LIMIT. Looking for the word "limit" in the text is not reliable: a
    # LIMIT inside a subquery/CTE or the word in a literal would switch the cap
    # off, and a caller-supplied LIMIT may exceed the cap.
    # One extra row is requested purely to detect truncation.
    capped = f"SELECT * FROM (\n{stmt}\n) AS _mcp_capped\nLIMIT {cap + 1}"
    t0 = time.monotonic()
    with _ro_conn() as conn, conn.cursor() as cur:
        cur.execute(capped)
        rows = _rows(cur)
    truncated = len(rows) > cap
    rows = rows[:cap]
    dur_ms = int((time.monotonic() - t0) * 1000)
    # Collapse whitespace so a multi-line query is one log line.
    sql_for_log = " ".join(sql.split())[:200]
    log.info("query rows=%d trunc=%s %dms :: %s", len(rows), truncated, dur_ms, sql_for_log)
    return {"row_count": len(rows), "truncated": truncated, "rows": rows}
@mcp.tool()
def list_schemas() -> list[dict]:
    """List the readable schemas with their object counts (tables, views, materialized views).

    Only schemas that contain at least one object readable by the server's
    database role are returned. Result rows: {schema, objects}.
    """
    # pg_catalog is queried instead of information_schema.tables because the
    # information schema does not list materialized views.
    # relkind: r = table, p = partitioned table, v = view,
    #          m = materialized view, f = foreign table.
    with _ro_conn() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT n.nspname AS schema, count(*) AS objects "
            "FROM pg_catalog.pg_class c "
            "JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
            "WHERE n.nspname = ANY(%s) "
            "AND c.relkind IN ('r', 'p', 'v', 'm', 'f') "
            "AND has_table_privilege(c.oid, 'SELECT') "
            "GROUP BY 1 ORDER BY 1",
            (list(READABLE_SCHEMAS),),
        )
        return _rows(cur)
@mcp.tool()
def list_tables(schema: str) -> list[dict]:
    """List tables, views and materialized views in a schema.

    `schema` must be one of the server's readable schemas (see list_schemas).
    Result rows: {name, kind}, where kind is BASE TABLE, VIEW,
    MATERIALIZED VIEW or FOREIGN.

    Raises:
        ValueError: if `schema` is not one of the readable schemas.
    """
    if schema not in READABLE_SCHEMAS:
        raise ValueError(f"schema must be one of {READABLE_SCHEMAS}")
    # pg_catalog is queried instead of information_schema.tables because the
    # information schema does not list materialized views. The kind labels for
    # tables, views and foreign tables mirror information_schema's table_type.
    with _ro_conn() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT c.relname AS name, "
            "CASE c.relkind "
            "WHEN 'v' THEN 'VIEW' "
            "WHEN 'm' THEN 'MATERIALIZED VIEW' "
            "WHEN 'f' THEN 'FOREIGN' "
            "ELSE 'BASE TABLE' END AS kind "
            "FROM pg_catalog.pg_class c "
            "JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
            "WHERE n.nspname = %s "
            "AND c.relkind IN ('r', 'p', 'v', 'm', 'f') "
            "AND has_table_privilege(c.oid, 'SELECT') "
            "ORDER BY 1",
            (schema,),
        )
        return _rows(cur)
@mcp.tool()
def describe_table(schema: str, table: str) -> list[dict]:
    """Describe a table/view/materialized view: columns, types, nullability, defaults.

    `schema` must be one of the server's readable schemas (see list_schemas).
    Result rows, in column order: {column, type, nullable, default}. `type` is
    the full SQL type (e.g. numeric(10,2), text[]); `nullable` is YES or NO.
    An unknown table yields an empty list.

    Raises:
        ValueError: if `schema` is not one of the readable schemas.
    """
    if schema not in READABLE_SCHEMAS:
        raise ValueError(f"schema must be one of {READABLE_SCHEMAS}")
    # pg_catalog is queried instead of information_schema.columns because the
    # information schema does not cover materialized views.
    # attnum > 0 skips system columns; attisdropped skips dropped ones.
    with _ro_conn() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT a.attname AS column, "
            "pg_catalog.format_type(a.atttypid, a.atttypmod) AS type, "
            "CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS nullable, "
            "pg_catalog.pg_get_expr(d.adbin, d.adrelid) AS default "
            "FROM pg_catalog.pg_attribute a "
            "JOIN pg_catalog.pg_class c ON c.oid = a.attrelid "
            "JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
            "LEFT JOIN pg_catalog.pg_attrdef d "
            "ON d.adrelid = a.attrelid AND d.adnum = a.attnum "
            "WHERE n.nspname = %s AND c.relname = %s "
            "AND c.relkind IN ('r', 'p', 'v', 'm', 'f') "
            "AND a.attnum > 0 AND NOT a.attisdropped "
            "AND has_column_privilege(c.oid, a.attnum, 'SELECT') "
            "ORDER BY a.attnum",
            (schema, table),
        )
        return _rows(cur)
@mcp.tool()
def list_functions(schema: str = "reporting") -> list[dict]:
    """List callable functions (e.g. reporting.fn_*) with their argument signatures."""
    # Refuse any schema outside the configured allowlist before touching the DB.
    if schema not in READABLE_SCHEMAS:
        raise ValueError(f"schema must be one of {READABLE_SCHEMAS}")

    # One row per pg_proc entry in the schema; name is passed as a parameter.
    statement = (
        "SELECT p.proname AS name, pg_get_function_arguments(p.oid) AS args "
        "FROM pg_proc p JOIN pg_namespace n ON n.oid = p.pronamespace "
        "WHERE n.nspname = %s ORDER BY 1"
    )
    with _ro_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(statement, (schema,))
            return _rows(cur)
# A "simple identifier": a letter or underscore, then letters, digits or
# underscores (case-insensitive). Anchored with \Z rather than $, because $ also
# matches just before a trailing newline and would accept "trips\n".
_IDENT = re.compile(r"^[a-z_][a-z0-9_]*\Z", re.IGNORECASE)
@mcp.tool()
def sample_table(schema: str, table: str, n: int = 20) -> dict:
    """Return the first `n` rows of a table/view (convenience over query)."""
    # The schema must be on the allowlist and the table name must match _IDENT,
    # because both are interpolated into the SQL text below.
    if schema not in READABLE_SCHEMAS:
        raise ValueError(f"schema must be one of {READABLE_SCHEMAS}")
    if _IDENT.match(table) is None:
        raise ValueError("table must be a simple identifier")

    # Delegates to query(), so its guard, row cap and logging apply here too.
    select_all = f'SELECT * FROM "{schema}"."{table}"'
    return query(select_all, max_rows=n)
# ── Bearer-token auth ─────────────────────────────────────────────────────────
# MCP_AUTH_TOKENS = "alice:tok1,bob:tok2" → {token: name}. Per-analyst tokens make
# access revocable (edit the env + redeploy) and attributable in the logs.
def _parse_tokens(raw: str) -> dict[str, str]:
    """Parse ``"alice:tok1,bob:tok2"`` into ``{token: name}``.

    Each comma-separated entry is split on its first ``:``; whitespace around
    the name and the token is trimmed. Entries without a ``:``, with an empty
    name, or with an empty token are ignored. Dropping empty tokens matters:
    an entry such as ``"alice:"`` would otherwise register ``""`` as a valid
    token, which is exactly what a request without credentials presents.

    Args:
        raw: The raw MCP_AUTH_TOKENS value.

    Returns:
        Mapping from token to analyst name.
    """
    tokens: dict[str, str] = {}
    for entry in raw.split(","):
        name, sep, token = entry.partition(":")
        name, token = name.strip(), token.strip()
        if sep and name and token:
            tokens[token] = name
    return tokens


_TOKENS = _parse_tokens(os.getenv("MCP_AUTH_TOKENS", ""))
class BearerAuth(BaseHTTPMiddleware):
    """Reject every request that lacks a known Bearer token.

    ``/healthz`` is let through without credentials. For any other path the
    ``Authorization: Bearer <token>`` header is looked up in ``_TOKENS``; a
    match stores the analyst name on ``request.state.caller`` and logs it, so
    each request is attributable. Anything else gets a 401 JSON response.
    """

    async def dispatch(self, request, call_next):
        """Authenticate *request*, then forward it or answer 401."""
        if request.url.path == "/healthz":
            return await call_next(request)

        scheme, _, credential = request.headers.get("authorization", "").partition(" ")
        token = credential.strip() if scheme.lower() == "bearer" else ""
        # An empty token is never looked up: it must not authenticate even if
        # the token table were to contain an empty key.
        caller = _TOKENS.get(token) if token else None
        if caller is None:
            # The presented token is never logged.
            log.warning("auth rejected %s %s", request.method, request.url.path)
            return JSONResponse({"error": "unauthorized"}, status_code=401)

        request.state.caller = caller
        log.info("auth ok caller=%s %s %s", caller, request.method, request.url.path)
        return await call_next(request)
async def healthz(_request):
    """Liveness endpoint: always answers ``{"ok": true, "tokens": <count>}``.

    The request is ignored and no database access takes place.
    """
    # NOTE(review): BearerAuth.dispatch exempts /healthz from auth, yet this
    # reports how many tokens are configured — confirm that is intended.
    payload = {"ok": True, "tokens": len(_TOKENS)}
    return JSONResponse(payload)
# ASGI application: the MCP streamable-HTTP app, wrapped in Bearer auth, plus an
# unauthenticated health route.
app = mcp.streamable_http_app()
app.add_middleware(BearerAuth)
# Starlette exposes add_route (not a Flask-style @app.route decorator).
app.add_route("/healthz", healthz, methods=["GET"])

# Startup diagnostics: an empty token table is legal but locks everyone out.
_token_count = len(_TOKENS)
if _token_count == 0:
    log.warning("MCP_AUTH_TOKENS is empty — every request will be rejected with 401.")
log.info(
    "Analytics MCP starting. Tokens loaded=%d. Readable schemas=%s.",
    _token_count,
    READABLE_SCHEMAS,
)