"""
Error handling middleware for APIFromAnything.
This module provides middleware for catching and formatting exceptions
in a consistent way for API responses.
"""
import sys
import traceback
import logging
import json
from typing import Dict, Any, Optional, Callable, Awaitable, Type, List, Union
import inspect
from apifrom.middleware.base import Middleware
from apifrom.core.request import Request
from apifrom.core.response import Response
# Module-level logger used by the error handling middleware
_LOGGER_NAME = "apifrom.middleware.error_handling"
logger = logging.getLogger(_LOGGER_NAME)
class APIError(Exception):
    """
    Root exception for errors that are safe to expose to API clients.

    Instances carry ``message``, ``status_code``, ``error_code`` and
    ``details`` attributes describing the failure.
    """

    def __init__(
        self,
        message: str,
        status_code: int = 400,
        error_code: str = "bad_request",
        details: Optional[Dict[str, Any]] = None
    ):
        """
        Build an API error.

        Args:
            message: Human-readable description of the failure
            status_code: HTTP status code associated with the error
            error_code: Application-specific error identifier
            details: Optional extra data; a falsy value is replaced by an
                empty dict, otherwise the given dict is stored as-is
        """
        super().__init__(message)
        if not details:
            details = {}
        self.message, self.status_code = message, status_code
        self.error_code, self.details = error_code, details
class BadRequestError(APIError):
    """API error for invalid client requests (status code 400)."""

    def __init__(
        self,
        message: str = "Bad request",
        error_code: str = "bad_request",
        details: Optional[Dict[str, Any]] = None
    ):
        """
        Build a bad-request error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
        """
        # Only the status code is fixed here; the rest is forwarded untouched.
        super().__init__(message, 400, error_code, details)
class UnauthorizedError(APIError):
    """API error for unauthorized requests (status code 401)."""

    def __init__(
        self,
        message: str = "Unauthorized",
        error_code: str = "unauthorized",
        details: Optional[Dict[str, Any]] = None
    ):
        """
        Build an unauthorized error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
        """
        # Only the status code is fixed here; the rest is forwarded untouched.
        super().__init__(message, 401, error_code, details)
class ForbiddenError(APIError):
    """API error for forbidden requests (status code 403)."""

    def __init__(
        self,
        message: str = "Forbidden",
        error_code: str = "forbidden",
        details: Optional[Dict[str, Any]] = None
    ):
        """
        Build a forbidden error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
        """
        # Only the status code is fixed here; the rest is forwarded untouched.
        super().__init__(message, 403, error_code, details)
class NotFoundError(APIError):
    """API error for resources that don't exist (status code 404)."""

    def __init__(
        self,
        message: str = "Not found",
        error_code: str = "not_found",
        details: Optional[Dict[str, Any]] = None
    ):
        """
        Build a not-found error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
        """
        # Only the status code is fixed here; the rest is forwarded untouched.
        super().__init__(message, 404, error_code, details)
class MethodNotAllowedError(APIError):
    """API error for disallowed HTTP methods (status code 405)."""

    def __init__(
        self,
        message: str = "Method not allowed",
        error_code: str = "method_not_allowed",
        details: Optional[Dict[str, Any]] = None,
        allowed_methods: Optional[List[str]] = None
    ):
        """
        Build a method-not-allowed error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
            allowed_methods: If non-empty, stored under
                ``details["allowed_methods"]``
        """
        # Work on a shallow copy so the caller's dict is never mutated when
        # the extra key is added below.
        details = dict(details) if details else {}
        if allowed_methods:
            details["allowed_methods"] = allowed_methods
        super().__init__(
            message=message,
            status_code=405,
            error_code=error_code,
            details=details
        )
class ConflictError(APIError):
    """API error for resource conflicts (status code 409)."""

    def __init__(
        self,
        message: str = "Conflict",
        error_code: str = "conflict",
        details: Optional[Dict[str, Any]] = None
    ):
        """
        Build a conflict error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
        """
        # Only the status code is fixed here; the rest is forwarded untouched.
        super().__init__(message, 409, error_code, details)
class UnprocessableEntityError(APIError):
    """API error for validation failures (status code 422)."""

    def __init__(
        self,
        message: str = "Validation error",
        error_code: str = "validation_error",
        details: Optional[Dict[str, Any]] = None,
        validation_errors: Optional[Dict[str, List[str]]] = None
    ):
        """
        Build a validation error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
            validation_errors: If non-empty, stored under
                ``details["validation_errors"]``
        """
        # Work on a shallow copy so the caller's dict is never mutated when
        # the extra key is added below.
        details = dict(details) if details else {}
        if validation_errors:
            details["validation_errors"] = validation_errors
        super().__init__(
            message=message,
            status_code=422,
            error_code=error_code,
            details=details
        )
class TooManyRequestsError(APIError):
    """API error for an exceeded rate limit (status code 429)."""

    def __init__(
        self,
        message: str = "Too many requests",
        error_code: str = "rate_limit_exceeded",
        details: Optional[Dict[str, Any]] = None,
        retry_after: Optional[int] = None
    ):
        """
        Build a rate-limit error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
            retry_after: If not None (0 is kept), stored under
                ``details["retry_after"]``; presumably a delay in seconds —
                TODO confirm the unit with callers
        """
        # Work on a shallow copy so the caller's dict is never mutated when
        # the extra key is added below.
        details = dict(details) if details else {}
        if retry_after is not None:
            details["retry_after"] = retry_after
        super().__init__(
            message=message,
            status_code=429,
            error_code=error_code,
            details=details
        )
class InternalServerError(APIError):
    """API error for internal server errors (status code 500)."""

    def __init__(
        self,
        message: str = "Internal server error",
        error_code: str = "internal_error",
        details: Optional[Dict[str, Any]] = None
    ):
        """
        Build an internal-server error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
        """
        # Only the status code is fixed here; the rest is forwarded untouched.
        super().__init__(message, 500, error_code, details)
class ServiceUnavailableError(APIError):
    """API error for unavailable services (status code 503)."""

    def __init__(
        self,
        message: str = "Service unavailable",
        error_code: str = "service_unavailable",
        details: Optional[Dict[str, Any]] = None,
        retry_after: Optional[int] = None
    ):
        """
        Build a service-unavailable error.

        Args:
            message: The error message
            error_code: An application-specific error code
            details: Additional error details
            retry_after: If not None (0 is kept), stored under
                ``details["retry_after"]``; presumably a delay in seconds —
                TODO confirm the unit with callers
        """
        # Work on a shallow copy so the caller's dict is never mutated when
        # the extra key is added below.
        details = dict(details) if details else {}
        if retry_after is not None:
            details["retry_after"] = retry_after
        super().__init__(
            message=message,
            status_code=503,
            error_code=error_code,
            details=details
        )
# Map exception types to handlers
class ExceptionHandler:
    """Callable that renders an exception as a JSON error response with a fixed status and code."""

    def __init__(self, exception_class: Type[Exception], status_code: int, error_code: str):
        """
        Configure the handler.

        Args:
            exception_class: The exception class this handler is meant for
            status_code: The HTTP status code to return
            error_code: The error code to include in the response
        """
        self.exception_class = exception_class
        self.status_code = status_code
        self.error_code = error_code

    def __call__(self, exception: Exception) -> Response:
        """
        Build the error response for ``exception``.

        Args:
            exception: The exception to handle

        Returns:
            A response whose body holds the stringified exception together
            with the configured error code and status
        """
        payload = {
            "error": {
                "message": str(exception),
                "code": self.error_code,
                "status": self.status_code
            }
        }
        return Response(
            content=payload,
            status_code=self.status_code,
            headers={"Content-Type": "application/json"}
        )
class ErrorHandlingMiddleware(Middleware):
    """
    Middleware for handling errors in API requests.

    Exceptions raised by the downstream chain are caught in ``dispatch``,
    optionally logged, and converted into a JSON error response by the
    handler registered for the closest matching exception class.
    """

    def __init__(
        self,
        debug: bool = False,
        include_traceback: bool = False,
        include_exception_class: bool = False,
        log_exceptions: bool = True,
        json_encoder: Optional[Type[json.JSONEncoder]] = None,
        **kwargs
    ):
        """
        Initialize the error handling middleware.

        Args:
            debug: Whether to include debug information in error responses
            include_traceback: Whether to include tracebacks (only honoured
                when ``debug`` is also true)
            include_exception_class: Whether to include exception class names
                (only honoured when ``debug`` is also true)
            log_exceptions: Whether to log exceptions
            json_encoder: A custom JSON encoder for error responses; it is
                stored on the instance and not otherwise used by this class
            **kwargs: Additional options for the base middleware
        """
        super().__init__(**kwargs)
        self.debug = debug
        self.include_traceback = include_traceback
        self.include_exception_class = include_exception_class
        self.log_exceptions = log_exceptions
        self.json_encoder = json_encoder
        # Maps exception classes to callables producing a Response
        self.exception_handlers: Dict[Type[Exception], Callable[[Exception], Response]] = {}
        self._setup_default_handlers()

    def _message_handler(
        self,
        fallback_message: str,
        status_code: int,
        error_code: str
    ) -> Callable[[Exception], Response]:
        """Return a handler that reports ``str(exc)``, or ``fallback_message`` when that is empty."""
        def handler(exc: Exception) -> Response:
            return self._create_error_response(
                str(exc) or fallback_message,
                status_code,
                error_code,
                exc
            )
        return handler

    def _setup_default_handlers(self) -> None:
        """Set up default exception handlers."""
        # API errors carry their own status, code and details. Registering
        # the APIError base class as well means direct instances and
        # user-defined subclasses are rendered with their own data instead of
        # falling through to the generic 500 handler.
        for api_error_class in (
            APIError,
            BadRequestError,
            UnauthorizedError,
            ForbiddenError,
            NotFoundError,
            MethodNotAllowedError,
            ConflictError,
            UnprocessableEntityError,
            TooManyRequestsError,
            InternalServerError,
            ServiceUnavailableError,
        ):
            self.add_exception_handler(api_error_class, self._handle_api_error)

        # Handlers for common standard exceptions
        self.add_exception_handler(
            ValueError,
            self._message_handler("Invalid value", 400, "bad_request")
        )
        self.add_exception_handler(
            TypeError,
            self._message_handler("Type error", 400, "bad_request")
        )
        self.add_exception_handler(
            KeyError,
            lambda e: self._create_error_response(
                f"Missing key: {str(e)}",
                400,
                "bad_request",
                e
            )
        )
        self.add_exception_handler(
            IndexError,
            self._message_handler("Index error", 400, "bad_request")
        )
        self.add_exception_handler(
            AttributeError,
            self._message_handler("Attribute error", 500, "internal_error")
        )
        self.add_exception_handler(
            NotImplementedError,
            self._message_handler("Not implemented", 501, "not_implemented")
        )
        self.add_exception_handler(
            PermissionError,
            self._message_handler("Permission denied", 403, "forbidden")
        )

        # Fallback for all other exceptions: always a fixed message, so the
        # text of an unexpected exception is never sent to the client.
        self.add_exception_handler(
            Exception,
            lambda e: self._create_error_response(
                "Internal server error",
                500,
                "internal_error",
                e
            )
        )

    def add_exception_handler(
        self,
        exception_class: Type[Exception],
        handler: Callable[[Exception], Response]
    ) -> None:
        """
        Add a custom exception handler.

        A handler registered for an already-known class replaces the
        previous one.

        Args:
            exception_class: The exception class to handle
            handler: A function that converts the exception to a response
        """
        self.exception_handlers[exception_class] = handler

    def _handle_api_error(self, exception: APIError) -> Response:
        """
        Handle an API error.

        Args:
            exception: The API error

        Returns:
            An API response built from the error's own message, status code,
            error code and details
        """
        return self._create_error_response(
            exception.message,
            exception.status_code,
            exception.error_code,
            exception,
            exception.details
        )

    def _create_error_response(
        self,
        message: str,
        status_code: int,
        error_code: str,
        exception: Exception,
        details: Optional[Dict[str, Any]] = None
    ) -> Response:
        """
        Create an error response.

        Args:
            message: The error message
            status_code: The HTTP status code
            error_code: The error code
            exception: The original exception
            details: Additional error details (omitted from the body when
                empty)

        Returns:
            An API response whose content is ``{"error": {...}}``
        """
        error: Dict[str, Any] = {
            "message": message,
            "code": error_code,
            "status": status_code
        }
        if details:
            error["details"] = details

        # Debug extras are only emitted when debug mode is on AND the
        # corresponding flag is set.
        if self.debug:
            if self.include_exception_class:
                error["exception"] = exception.__class__.__name__
            if self.include_traceback:
                tb = traceback.format_exception(
                    type(exception),
                    exception,
                    exception.__traceback__
                )
                error["traceback"] = "".join(tb)

        return Response(
            content={"error": error},
            status_code=status_code,
            headers={"Content-Type": "application/json"}
        )

    def _find_handler(self, exception: Exception) -> Callable[[Exception], Response]:
        """
        Find the appropriate handler for an exception.

        The exception type's MRO is walked from most to least specific, so
        the handler of the closest registered class wins regardless of the
        order in which handlers were registered. (Scanning the registry in
        insertion order would let the ``Exception`` fallback, registered
        during construction, shadow every handler added afterwards.)

        Args:
            exception: The exception to handle

        Returns:
            A handler function

        Raises:
            KeyError: If no class in the MRO is registered and the generic
                ``Exception`` handler has been removed
        """
        handlers = self.exception_handlers
        for exception_class in type(exception).__mro__:
            handler = handlers.get(exception_class)
            if handler is not None:
                return handler
        # Default to the generic exception handler
        return handlers[Exception]

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """
        Dispatch a request, catching and handling any exceptions.

        Args:
            request: The request to process
            call_next: The next middleware or route handler

        Returns:
            The downstream response, or an error response if an exception
            was raised
        """
        try:
            return await call_next(request)
        except Exception as e:
            if self.log_exceptions:
                # logger.exception records the active traceback as well
                logger.exception(
                    "Error processing request: %s %s",
                    request.method,
                    request.url.path
                )
            handler = self._find_handler(e)
            return handler(e)

    async def process_request(self, request: Request) -> Request:
        """
        Process a request.

        Args:
            request: The request to process

        Returns:
            The same request, unmodified
        """
        # This middleware doesn't modify the request, just passes it through
        return request
# Public API of this module
__all__ = [
    # Middleware
    "ErrorHandlingMiddleware",
    # Error hierarchy
    "APIError",
    "BadRequestError",
    "UnauthorizedError",
    "ForbiddenError",
    "NotFoundError",
    "MethodNotAllowedError",
    "ConflictError",
    "UnprocessableEntityError",
    "TooManyRequestsError",
    "InternalServerError",
    "ServiceUnavailableError",
    # Generic exception-to-response handler
    "ExceptionHandler",
]