"""
Caching middleware for APIFromAnything.
This module provides middleware for caching API responses to improve performance.
"""
import hashlib
import json
import time
from typing import Any, Callable, Dict, List, Optional, Union
from apifrom.core.request import Request
from apifrom.core.response import Response
from apifrom.middleware.base import BaseMiddleware
[docs]
class MemoryCache:
"""
Simple in-memory cache implementation.
"""
def __init__(self, max_size: int = 1000, ttl: int = 60):
"""
Initialize the memory cache.
Args:
max_size: Maximum number of items to store in the cache
ttl: Time-to-live in seconds for cached items
"""
self.cache: Dict[str, Dict[str, Any]] = {}
self.max_size = max_size
self.ttl = ttl
def get(self, key: str) -> Optional[Any]:
"""
Get a value from the cache.
Args:
key: The cache key
Returns:
The cached value, or None if not found or expired
"""
if key not in self.cache:
return None
item = self.cache[key]
if time.time() > item["expires_at"]:
# Item has expired
del self.cache[key]
return None
return item["value"]
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""
Set a value in the cache.
Args:
key: The cache key
value: The value to cache
ttl: Time-to-live in seconds (overrides the default)
"""
# If cache is full, remove the oldest item
if len(self.cache) >= self.max_size and key not in self.cache:
oldest_key = min(self.cache.keys(), key=lambda k: self.cache[k]["expires_at"])
del self.cache[oldest_key]
# Set the new item
self.cache[key] = {
"value": value,
"expires_at": time.time() + (ttl or self.ttl)
}
def delete(self, key: str) -> None:
"""
Delete a value from the cache.
Args:
key: The cache key
"""
if key in self.cache:
del self.cache[key]
def clear(self) -> None:
"""
Clear the entire cache.
"""
self.cache.clear()
class CacheMiddleware(BaseMiddleware):
"""
Middleware for caching API responses.
"""
def __init__(
self,
cache_backend: Any = None,
ttl: int = 60,
methods: Optional[list] = None,
exclude_routes: Optional[list] = None,
vary_headers: Optional[list] = None,
key_prefix: str = "apifrom-cache:",
):
"""
Initialize the cache middleware.
Args:
cache_backend: The cache backend to use (defaults to MemoryCache)
ttl: Default time-to-live in seconds for cached items
methods: HTTP methods to cache (defaults to ["GET"])
exclude_routes: Routes to exclude from caching
vary_headers: Headers to include in the cache key
key_prefix: Prefix for cache keys
"""
super().__init__(
ttl=ttl,
methods=methods,
exclude_routes=exclude_routes,
vary_headers=vary_headers,
key_prefix=key_prefix
)
self.cache = cache_backend or MemoryCache(ttl=ttl)
self.ttl = ttl
self.methods = methods or ["GET"]
self.exclude_routes = exclude_routes or []
self.vary_headers = vary_headers or []
self.key_prefix = key_prefix
def _should_cache(self, request: Request) -> bool:
"""
Determine if a request should be cached.
Args:
request: The request object
Returns:
True if the request should be cached, False otherwise
"""
# Only cache specified methods
if request.method not in self.methods:
return False
# Don't cache excluded routes
for route in self.exclude_routes:
if request.path.startswith(route):
return False
# Don't cache requests with Cache-Control: no-cache or no-store
cache_control = request.headers.get("Cache-Control", "")
if "no-cache" in cache_control or "no-store" in cache_control:
return False
return True
def _generate_cache_key(self, request: Request) -> str:
"""
Generate a cache key for a request.
Args:
request: The request object
Returns:
The cache key
"""
# Start with the method and path
key_parts = [request.method, request.path]
# Add query parameters
if request.query_params:
# Sort query params to ensure consistent keys
sorted_params = sorted(request.query_params.items())
key_parts.append(json.dumps(sorted_params))
# Add vary headers
for header in self.vary_headers:
if header in request.headers:
key_parts.append(f"{header}:{request.headers[header]}")
# Generate a hash of the key parts
key_str = ":".join(key_parts)
key_hash = hashlib.md5(key_str.encode()).hexdigest()
return f"{self.key_prefix}{key_hash}"
async def process_request(self, request: Request) -> Request:
"""
Process a request through the cache middleware.
Args:
request: The request object
Returns:
The request object
"""
# Store whether we should cache this request in the request state
# so we can use it in process_response
request.state.should_cache = self._should_cache(request)
if request.state.should_cache:
request.state.cache_key = self._generate_cache_key(request)
# Try to get from cache
cached_response = self.cache.get(request.state.cache_key)
if cached_response is not None:
# Store the cached response in the request state
request.state.cached_response = Response.from_dict(cached_response)
request.state.cached_response.headers["X-Cache"] = "HIT"
return request
async def process_response(self, response: Response) -> Response:
"""
Process a response through the cache middleware.
Args:
response: The response object
Returns:
The response object
"""
# If we have a cached response in the request state, return it
if hasattr(response.request.state, 'cached_response'):
return response.request.state.cached_response
# If we should cache this response and it's cacheable
if (hasattr(response.request.state, 'should_cache') and
response.request.state.should_cache and
response.status_code >= 200 and response.status_code < 300):
self.cache.set(response.request.state.cache_key, response.to_dict(), self.ttl)
response.headers["X-Cache"] = "MISS"
return response
async def __call__(self, scope, receive, send):
"""
ASGI callable.
Args:
scope: The ASGI scope.
receive: The ASGI receive function.
send: The ASGI send function.
"""
# Only process HTTP requests
if scope["type"] != "http":
await self.app(scope, receive, send)
return
# Extract request information from scope
method = scope["method"]
path = scope["path"]
headers = dict(scope["headers"])
# Check if we should cache this request
should_cache = method in self.methods
for route in self.exclude_routes:
if path.startswith(route):
should_cache = False
break
if not should_cache:
await self.app(scope, receive, send)
return
# Generate cache key
key_parts = [method, path]
if "query_string" in scope and scope["query_string"]:
key_parts.append(scope["query_string"].decode("utf-8"))
for header in self.vary_headers:
header_bytes = header.encode("utf-8")
for h_key, h_value in scope["headers"]:
if h_key == header_bytes:
key_parts.append(f"{header}:{h_value.decode('utf-8')}")
break
key_str = ":".join(key_parts)
key_hash = hashlib.md5(key_str.encode()).hexdigest()
cache_key = f"{self.key_prefix}{key_hash}"
# Try to get from cache
cached_response = self.cache.get(cache_key)
if cached_response is not None:
# Return cached response
await send({
"type": "http.response.start",
"status": cached_response["status_code"],
"headers": [(k.encode("utf-8"), v.encode("utf-8")) for k, v in cached_response["headers"].items()] + [(b"X-Cache", b"HIT")],
})
await send({
"type": "http.response.body",
"body": cached_response["content"].encode("utf-8") if isinstance(cached_response["content"], str) else cached_response["content"],
})
return
# Not in cache, process request
# Capture the response to cache it
original_send = send
response_started = False
response_status = 0
response_headers = []
response_body = b""
async def send_wrapper(message):
nonlocal response_started, response_status, response_headers, response_body
if message["type"] == "http.response.start":
response_started = True
response_status = message["status"]
response_headers = message["headers"]
elif message["type"] == "http.response.body":
response_body += message.get("body", b"")
# If this is the last message and we should cache the response
if not message.get("more_body", False) and should_cache and 200 <= response_status < 300:
# Cache the response
headers_dict = {}
for key, value in response_headers:
headers_dict[key.decode("utf-8")] = value.decode("utf-8")
cached_data = {
"status_code": response_status,
"headers": headers_dict,
"content": response_body,
}
self.cache.set(cache_key, cached_data, self.ttl)
# Add X-Cache header
new_headers = []
has_cache_header = False
for key, value in response_headers:
if key.lower() == b"x-cache":
has_cache_header = True
new_headers.append((key, b"MISS"))
else:
new_headers.append((key, value))
if not has_cache_header:
new_headers.append((b"X-Cache", b"MISS"))
message["headers"] = new_headers
await original_send(message)
await self.app(scope, receive, send_wrapper)
[docs]
class CacheControl:
"""
Decorator for controlling cache behavior on specific endpoints.
"""
@staticmethod
def cache(ttl: int = 60):
"""
Cache an endpoint for the specified TTL.
Args:
ttl: Time-to-live in seconds
Returns:
A decorator function
"""
def decorator(func):
func._cache_ttl = ttl
return func
return decorator
@staticmethod
def no_cache(func):
"""
Prevent an endpoint from being cached.
Args:
func: The function to decorate
Returns:
The decorated function
"""
func._no_cache = True
return func