Implementing Custom Middleware for Request Tracing
Key takeaways:
- Accept an incoming `x-request-id` or mint a new UUID so one ID spans a multi-service request.
- Store the ID in a `contextvar` so any code running in the request, including logging, can read it.
- Attach a logging filter that injects `request_id` into every record automatically.
- Return the ID in the response header so it appears in client bug reports.
- Forward the ID on outbound calls to extend the trace across services.
This guide builds the tracing piece described in Middleware Implementation. It is the foundation that the consistent error envelope and dashboards rely on.
The Problem This Solves
When a production incident arrives, you need to find every log line for one request across one or many services. Without a shared identifier, you are grepping by timestamp and guessing. A correlation ID assigned at the edge and threaded through logs turns that into a single filter.
Prerequisites
- A FastAPI application whose factory registers middleware.
- Python 3.11+ and the standard `logging` and `contextvars` modules.
- `httpx` for the outbound calls in step 4.
Step-by-Step Implementation
1. Define a context variable for the ID
```python
# app/tracing.py
from contextvars import ContextVar

# Readable from anywhere in the request's async context.
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")
```
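A context variable is used instead of a module-level global because each asyncio task works on its own copy of the context. Concurrent requests therefore cannot overwrite each other's ID. The stdlib-only sketch below simulates two overlapping requests. The `handle` and `main` names are illustrative only.

```python
import asyncio
from contextvars import ContextVar

# Same declaration as app/tracing.py, repeated so this sketch runs on its own.
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")


async def handle(rid: str, delay: float) -> str:
    # gather() wraps each coroutine in a task with a copied context,
    # so this set() is invisible to the other simulated request.
    request_id_ctx.set(rid)
    await asyncio.sleep(delay)  # yield so the two tasks interleave
    return request_id_ctx.get()


async def main() -> list[str]:
    # Two overlapping "requests"; B finishes first but still reads its own ID.
    return await asyncio.gather(handle("req-A", 0.02), handle("req-B", 0.01))


results = asyncio.run(main())
print(results)  # ['req-A', 'req-B']
print(request_id_ctx.get())  # '-' : the caller's context was never modified
```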
2. Set it in middleware
```python
import uuid

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from app.tracing import request_id_ctx


class TracingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Reuse the caller's ID when present, otherwise mint a new one.
        rid = request.headers.get("x-request-id") or str(uuid.uuid4())
        token = request_id_ctx.set(rid)  # Bind for the duration of this request.
        try:
            response = await call_next(request)
        finally:
            request_id_ctx.reset(token)  # Clean up so the value never leaks.
        response.headers["x-request-id"] = rid
        return response

# Register it in the application factory: app.add_middleware(TracingMiddleware)
```
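The middleware echoes the inbound header into logs and the response. A forged value containing newlines or megabytes of text can therefore end up in both places. One option is to validate the header before trusting it as a trace hint. The sketch below is a hypothetical helper. The `resolve_request_id` name, the 128-character limit and the allowed charset are all assumptions rather than part of this guide's design.

```python
import re
import uuid
from collections.abc import Mapping

# Hypothetical policy: 1 to 128 characters from a conservative charset.
# Anything else is discarded instead of being echoed into logs and headers.
_SAFE_ID = re.compile(r"[A-Za-z0-9._-]{1,128}")


def resolve_request_id(headers: Mapping[str, str]) -> str:
    """Return the inbound x-request-id if it passes the policy, else a new UUID4."""
    candidate = headers.get("x-request-id")
    if candidate and _SAFE_ID.fullmatch(candidate):
        return candidate
    return str(uuid.uuid4())
```

Inside `dispatch`, `rid = resolve_request_id(request.headers)` would replace the `or` expression. Starlette's `Headers` is a case-insensitive mapping, so the lookup does not depend on how the client capitalized the header. UUIDs minted by an upstream service match the charset and pass through unchanged.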
3. Inject the ID into every log record
```python
import logging

from app.tracing import request_id_ctx


class RequestIdFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        record.request_id = request_id_ctx.get()  # Always present on every record.
        return True


def configure_logging() -> None:
    handler = logging.StreamHandler()
    # A handler-level filter also covers records propagated from library loggers.
    handler.addFilter(RequestIdFilter())
    # Caveat: %(message)s is not JSON-escaped, so a quote in a message breaks the line.
    handler.setFormatter(logging.Formatter(
        '{"level":"%(levelname)s","request_id":"%(request_id)s","msg":"%(message)s"}'
    ))
    # No-op if the root logger already has handlers; pass force=True to replace them.
    logging.basicConfig(level=logging.INFO, handlers=[handler])
```
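The format-string template above does not escape its fields. A message containing a double quote or a newline therefore produces a line that is no longer valid JSON. A minimal alternative, assuming you want to keep the same three fields, is a formatter that builds a dict and serializes it with `json.dumps`. `JsonFormatter` and the extra `exc` field are hypothetical names, not part of the original setup.

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Render each record as one JSON object; json.dumps escapes quotes and newlines."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "level": record.levelname,
            # getattr covers records that never passed through RequestIdFilter.
            "request_id": getattr(record, "request_id", "-"),
            # getMessage() merges %-style args, matching what %(message)s shows.
            "msg": record.getMessage(),
        }
        if record.exc_info:
            # Keeping the multi-line traceback inside a JSON string keeps one line per record.
            payload["exc"] = self.formatException(record.exc_info)
        return json.dumps(payload)
```

In `configure_logging`, `handler.setFormatter(JsonFormatter())` would replace the template formatter. The filter stays exactly as it is.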
4. Propagate the ID on outbound calls
```python
import httpx

from app.tracing import request_id_ctx


async def call_downstream(client: httpx.AsyncClient, url: str) -> httpx.Response:
    # Forward the same ID so the trace continues into the next service.
    return await client.get(url, headers={"x-request-id": request_id_ctx.get()})
```
Edge Cases and Gotchas
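- Background tasks. Work scheduled to run after the response may execute outside the request's context. Capture the ID explicitly and pass it into the task rather than reading the context variable there.
- Thread-pool dependencies. A plain `def` dependency runs in a worker thread. The current context is copied into that thread for the call, so the ID is readable, but values set inside the thread do not flow back. A thread you start yourself with `threading.Thread` does not inherit the request's context by default, so it reads `"-"`. Pass the ID in, or run the target through `contextvars.copy_context().run`.
- Header spoofing. Treat an inbound `x-request-id` as a trace hint only, never as a trusted value for authorization.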
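The remedies for threads and background work can be sketched with the standard library alone. This assumes a standard CPython 3.11 to 3.13 build, where a new `threading.Thread` starts with an empty context. The sketch contrasts a plain thread, a thread running inside a copied context, and a worker that receives the ID as an ordinary argument. The helper names are illustrative. A `Context` object cannot be entered by two threads at once, so take a fresh copy per thread.

```python
import threading
from contextvars import ContextVar, copy_context

# Same declaration as app/tracing.py, repeated so this sketch runs on its own.
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")


def read_id(out: list[str]) -> None:
    # Reads whatever the current context holds in this thread.
    out.append(request_id_ctx.get())


def record_id(rid: str, out: list[str]) -> None:
    # Receives the ID as an argument, so no context lookup is needed.
    out.append(rid)


def run_in_thread(target, *args) -> None:
    worker = threading.Thread(target=target, args=args)
    worker.start()
    worker.join()


token = request_id_ctx.set("abc-123")  # simulate being inside a request
try:
    plain: list[str] = []
    run_in_thread(read_id, plain)  # typically sees the default "-"

    copied: list[str] = []
    ctx = copy_context()  # snapshot taken while the ID is bound
    run_in_thread(ctx.run, read_id, copied)  # sees "abc-123"

    explicit: list[str] = []
    run_in_thread(record_id, request_id_ctx.get(), explicit)  # also "abc-123"
finally:
    request_id_ctx.reset(token)
```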
Verification
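```python
import logging

def test_trace_id_roundtrips(client, caplog):
    # Assumes configure_logging() has run and /health emits at least one log line.
    caplog.set_level(logging.INFO)
    resp = client.get("/health", headers={"x-request-id": "abc-123"})
    assert resp.headers["x-request-id"] == "abc-123"  # Preserved, not regenerated.
    assert any(getattr(r, "request_id", None) == "abc-123" for r in caplog.records)
```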
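The integration test above only proves the log wiring if the endpoint logs something. A narrower unit test can check the filter directly by logging one record outside a request and one inside. It is sketched here with the standard library only. `ListHandler` and `capture_request_ids` are hypothetical test helpers. The context variable and filter are copied from steps 1 and 3 so the sketch runs standalone.

```python
import logging
from contextvars import ContextVar

# Copies of the step 1 and step 3 definitions so this sketch runs standalone.
request_id_ctx: ContextVar[str] = ContextVar("request_id", default="-")


class RequestIdFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        record.request_id = request_id_ctx.get()
        return True


class ListHandler(logging.Handler):
    """Keeps emitted records in memory for assertions."""

    def __init__(self) -> None:
        super().__init__()
        self.records: list[logging.LogRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.records.append(record)


def capture_request_ids() -> list[str]:
    logger = logging.getLogger("tracing-test")
    logger.propagate = False  # isolate the test from root-logger handlers
    logger.setLevel(logging.INFO)
    handler = ListHandler()
    handler.addFilter(RequestIdFilter())
    logger.addHandler(handler)
    try:
        logger.info("outside any request")
        token = request_id_ctx.set("abc-123")
        try:
            logger.info("inside a request")
        finally:
            request_id_ctx.reset(token)
    finally:
        logger.removeHandler(handler)
    return [r.request_id for r in handler.records]


print(capture_request_ids())  # ['-', 'abc-123']
```

In a pytest suite, the final `print` would become `assert capture_request_ids() == ["-", "abc-123"]`.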
Related Reading
- Up to the topic: Middleware Implementation.
- Related patterns: Observability and Tracing for full distributed tracing, and Global Exception Handlers for Consistent API Responses which stamps this ID into error responses.