"""
Middleware for metrics collection in APIFromAnything.
This module provides middleware components for collecting metrics
during API request processing.
"""
from typing import Callable, Dict, Any, Optional
from apifrom.monitoring.metrics import MetricsCollector
[docs]
class MetricsMiddleware:
    """
    Framework-agnostic middleware that reports per-endpoint metrics.

    Request counts, durations and error occurrences are forwarded to a
    ``MetricsCollector``. A timer opened in :meth:`pre_request` is remembered
    per request object (keyed by ``id(request)``) until :meth:`post_request`
    or :meth:`on_error` closes it.
    """

    def __init__(self, collector: Optional[MetricsCollector] = None):
        """
        Set up the middleware.

        Args:
            collector: Collector that receives the metrics. When omitted
                (or falsy), a fresh ``MetricsCollector`` is created.
        """
        self.collector = collector if collector else MetricsCollector()
        # id(request) -> (timer id, endpoint) for requests still in flight.
        self._request_timers: Dict[int, Any] = {}

    async def pre_request(self, request: Any, endpoint: str) -> None:
        """
        Open a timer for ``request`` before the endpoint handles it.

        Args:
            request: The incoming request object.
            endpoint: Name of the endpoint about to handle the request.
        """
        started = self.collector.track_request(endpoint)
        self._request_timers[id(request)] = (started, endpoint)

    async def post_request(self, request: Any, response: Any) -> None:
        """
        Close the timer for ``request`` once a response has been produced.

        Requests with no open timer (e.g. ``pre_request`` was never called
        for them) are ignored.

        Args:
            request: The request object previously seen by ``pre_request``.
            response: The produced response; its ``status_code`` attribute is
                reported, defaulting to 200 when the attribute is absent.
        """
        pending = self._request_timers.pop(id(request), None)
        if pending is None:
            return
        timer_id, started_endpoint = pending
        status = getattr(response, "status_code", 200)
        self.collector.track_request_end(timer_id, started_endpoint, status)

    async def on_error(self, request: Any, error: Exception, endpoint: str) -> None:
        """
        Record an exception raised while handling ``request``.

        The error is always reported by its class name. If a timer is still
        open for the request, it is closed with status 500, using the
        endpoint recorded when the timer was opened.

        Args:
            request: The request being handled when the error occurred.
            error: The exception that was raised.
            endpoint: Endpoint the error is attributed to.
        """
        self.collector.track_error(type(error).__name__, endpoint)
        pending = self._request_timers.pop(id(request), None)
        if pending is not None:
            timer_id, started_endpoint = pending
            self.collector.track_request_end(timer_id, started_endpoint, 500)

    def register(self, app: Any) -> None:
        """
        Attach the middleware to an application.

        The base class is framework-agnostic; framework-specific subclasses
        must override this.

        Args:
            app: The application to attach to.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError("This method should be implemented by framework-specific subclasses")
class FlaskMetricsMiddleware(MetricsMiddleware):
    """Metrics middleware for Flask applications."""

    def register(self, app: Any) -> None:
        """
        Register the middleware with a Flask application.

        Installs three hooks on ``app``:

        * ``before_request`` starts a timer for the matched endpoint name,
          falling back to the raw request path when no endpoint matched.
        * ``after_request`` ends that timer with the response status code.
        * ``teardown_request`` records any unhandled exception via
          ``track_error``. If ``after_request`` never ran for the request
          (e.g. the exception propagated because the app is in
          debug/testing mode), it also ends the timer with status 500.

        No ``errorhandler(Exception)`` is registered. Such a handler also
        receives ``HTTPException`` instances (404, 405, ...). Re-raising from
        it does not return the error to Flask's default handling: it turns
        those responses into 500s and skips ``after_request``. HTTP errors
        that Flask renders as normal responses are therefore recorded only
        through their status code.

        Args:
            app: The Flask application to register with.
        """
        from flask import request, g

        collector = self.collector
        # Sentinel distinguishing "no open timer" from any timer-id value.
        missing = object()

        @app.before_request
        def before_request():
            endpoint = request.endpoint or request.path
            g.metrics_timer_id = collector.track_request(endpoint)
            g.metrics_endpoint = endpoint

        @app.after_request
        def after_request(response):
            # Pop the timer so teardown_request knows this request was closed.
            timer_id = g.pop('metrics_timer_id', missing)
            endpoint = g.get('metrics_endpoint')
            if timer_id is not missing and endpoint is not None:
                collector.track_request_end(timer_id, endpoint, response.status_code)
            return response

        @app.teardown_request
        def teardown_request(exc):
            endpoint = g.pop('metrics_endpoint', None)
            if endpoint is None:
                # before_request did not run (or failed) for this request.
                return
            if exc is not None:
                collector.track_error(type(exc).__name__, endpoint)
            timer_id = g.pop('metrics_timer_id', missing)
            if timer_id is not missing:
                # after_request never ran, so the request ended in an error.
                collector.track_request_end(timer_id, endpoint, 500)
class FastAPIMetricsMiddleware(MetricsMiddleware):
    """Metrics middleware for FastAPI applications."""

    def register(self, app: Any) -> None:
        """
        Register the middleware with a FastAPI application.

        Adds a Starlette ``BaseHTTPMiddleware`` that times every request,
        labelled by its URL path. Exceptions escaping the downstream app are
        recorded with ``track_error`` and close the timer with status 500
        before being re-raised.

        Args:
            app: The FastAPI application to register with.
        """
        from fastapi import Request
        from starlette.middleware.base import BaseHTTPMiddleware

        # Bind the collector here. Inside ``dispatch`` below, ``self`` is the
        # BaseHTTPMiddleware instance Starlette constructs, not this object,
        # so ``self.collector`` would raise AttributeError on every request.
        collector = self.collector

        class FastAPIMetricsMiddlewareImpl(BaseHTTPMiddleware):
            async def dispatch(self, request: Request, call_next):
                endpoint = request.url.path
                timer_id = collector.track_request(endpoint)
                # Only the downstream call is guarded, so a failure inside the
                # collector itself is not mis-recorded as a request error.
                try:
                    response = await call_next(request)
                except Exception as exc:
                    collector.track_error(type(exc).__name__, endpoint)
                    collector.track_request_end(timer_id, endpoint, 500)
                    raise
                collector.track_request_end(timer_id, endpoint, response.status_code)
                return response

        app.add_middleware(FastAPIMetricsMiddlewareImpl)
[docs]
class DjangoMetricsMiddleware:
    """
    Metrics middleware for Django applications.

    An instance acts as a middleware factory: calling it with
    ``get_response`` returns the per-request middleware callable.
    """

    def __init__(self, collector: Optional[MetricsCollector] = None):
        """
        Initialize the Django metrics middleware.

        Args:
            collector: Metrics collector to use for tracking metrics. A new
                ``MetricsCollector`` is created when omitted (or falsy).
        """
        self.collector = collector or MetricsCollector()

    def __call__(self, get_response):
        """
        Build the middleware callable wrapping ``get_response``.

        Every request is timed, labelled by ``request.path``, and closed with
        the response status code.

        Django wraps each middleware's ``get_response`` so that exceptions
        raised by views reach the middleware as error responses, not as
        exceptions. View errors are therefore recorded through a
        ``process_exception`` hook attached to the returned callable. Django
        registers that hook whenever the middleware object has such an
        attribute. The ``except`` branch only covers exceptions that still
        escape ``get_response``, e.g. when used outside Django's handler.

        Args:
            get_response: The next middleware or view in the chain.

        Returns:
            A callable taking a request and returning a response, carrying a
            ``process_exception(request, exception)`` attribute.
        """
        collector = self.collector

        def middleware(request):
            endpoint = request.path
            timer_id = collector.track_request(endpoint)
            # Only the downstream call is guarded, so a failure inside the
            # collector itself is not mis-recorded as a request error.
            try:
                response = get_response(request)
            except Exception as exc:
                collector.track_error(type(exc).__name__, endpoint)
                collector.track_request_end(timer_id, endpoint, 500)
                raise
            collector.track_request_end(timer_id, endpoint, response.status_code)
            return response

        def process_exception(request, exception):
            collector.track_error(type(exception).__name__, request.path)
            # Returning None lets Django's normal exception handling build the
            # response. The request timer is then closed in ``middleware``
            # with that response's status code.
            return None

        middleware.process_exception = process_exception
        return middleware