"""
Advanced caching middleware for APIFromAnything.
This module provides enhanced caching functionality with various backends,
eviction policies, and optimization features.
"""
import time
import json
import logging
import hashlib
import inspect
import asyncio
import redis
from typing import Dict, List, Any, Optional, Union, Callable, Tuple, Set, TypeVar
import threading
import os
import pickle
from datetime import datetime
from functools import wraps
from apifrom.core.request import Request
from apifrom.core.response import Response
from apifrom.middleware.base import Middleware
# Set up logging
logger = logging.getLogger("apifrom.middleware.cache_advanced")
[docs]
class CacheItem:
"""
Represents an item in the cache with metadata.
Attributes:
key: The cache key
value: The cached value
expires_at: The expiration timestamp
created_at: The creation timestamp
last_accessed_at: The last access timestamp
access_count: The number of times the item has been accessed
size_bytes: The size of the item in bytes
"""
def __init__(
self,
key: str,
value: Any,
ttl: int = 60,
created_at: Optional[float] = None,
):
"""
Initialize a cache item.
Args:
key: The cache key
value: The value to cache
ttl: Time-to-live in seconds
created_at: Creation timestamp (defaults to now)
"""
self.key = key
self.value = value
self.created_at = created_at or time.time()
self.expires_at = self.created_at + ttl
self.last_accessed_at = self.created_at
self.access_count = 0
# Estimate the size in bytes
try:
self.size_bytes = len(pickle.dumps(value))
except:
self.size_bytes = 0
def access(self) -> None:
"""Update the last accessed time and access count."""
self.last_accessed_at = time.time()
self.access_count += 1
def is_expired(self) -> bool:
"""
Check if the item is expired.
Returns:
True if the item is expired, False otherwise
"""
return time.time() > self.expires_at
def get_age_seconds(self) -> float:
"""
Get the age of the item in seconds.
Returns:
The age in seconds
"""
return time.time() - self.created_at
def get_idle_time_seconds(self) -> float:
"""
Get the idle time of the item in seconds.
Returns:
The idle time in seconds
"""
return time.time() - self.last_accessed_at
def to_dict(self) -> Dict[str, Any]:
"""
Convert the cache item to a dictionary.
Returns:
A dictionary representation of the cache item
"""
return {
"key": self.key,
"expires_at": self.expires_at,
"created_at": self.created_at,
"last_accessed_at": self.last_accessed_at,
"access_count": self.access_count,
"size_bytes": self.size_bytes,
"ttl_seconds": self.expires_at - self.created_at,
"remaining_ttl_seconds": max(0, self.expires_at - time.time()),
"age_seconds": self.get_age_seconds(),
"idle_time_seconds": self.get_idle_time_seconds(),
}
[docs]
class CacheEvictionPolicy:
"""Base class for cache eviction policies."""
def select_items_to_evict(self, items: List[CacheItem], target_size: int) -> List[CacheItem]:
"""
Select items to evict to reach the target size.
Args:
items: The list of cache items
target_size: The target number of items to keep
Returns:
A list of items to evict
"""
raise NotImplementedError("Subclasses must implement select_items_to_evict")
class LRUEvictionPolicy(CacheEvictionPolicy):
"""Least Recently Used (LRU) eviction policy."""
def select_items_to_evict(self, items: List[CacheItem], target_size: int) -> List[CacheItem]:
"""
Select items to evict based on the LRU policy.
Args:
items: The list of cache items
target_size: The target number of items to keep
Returns:
A list of items to evict
"""
# Sort items by last_accessed_at (oldest first)
sorted_items = sorted(items, key=lambda item: item.last_accessed_at)
# Return items to evict to reach the target size
items_to_keep = min(target_size, len(sorted_items))
return sorted_items[:-items_to_keep] if items_to_keep > 0 else sorted_items
class LFUEvictionPolicy(CacheEvictionPolicy):
"""Least Frequently Used (LFU) eviction policy."""
def select_items_to_evict(self, items: List[CacheItem], target_size: int) -> List[CacheItem]:
"""
Select items to evict based on the LFU policy.
Args:
items: The list of cache items
target_size: The target number of items to keep
Returns:
A list of items to evict
"""
# Sort items by access count (fewest first)
sorted_items = sorted(items, key=lambda item: item.access_count)
# Return items to evict to reach the target size
items_to_keep = min(target_size, len(sorted_items))
return sorted_items[:-items_to_keep] if items_to_keep > 0 else sorted_items
class TTLEvictionPolicy(CacheEvictionPolicy):
"""Time-To-Live (TTL) eviction policy."""
def select_items_to_evict(self, items: List[CacheItem], target_size: int) -> List[CacheItem]:
"""
Select items to evict based on the TTL policy.
Args:
items: The list of cache items
target_size: The target number of items to keep
Returns:
A list of items to evict
"""
# Sort items by expiration time (soonest first)
sorted_items = sorted(items, key=lambda item: item.expires_at)
# Return items to evict to reach the target size
items_to_keep = min(target_size, len(sorted_items))
return sorted_items[:-items_to_keep] if items_to_keep > 0 else sorted_items
class SizeEvictionPolicy(CacheEvictionPolicy):
"""Size-based eviction policy."""
def select_items_to_evict(self, items: List[CacheItem], target_size: int) -> List[CacheItem]:
"""
Select items to evict based on size.
Args:
items: The list of cache items
target_size: The target number of items to keep
Returns:
A list of items to evict
"""
# Sort items by size (largest first)
sorted_items = sorted(items, key=lambda item: item.size_bytes, reverse=True)
# Return items to evict to reach the target size
items_to_keep = min(target_size, len(sorted_items))
return sorted_items[:-items_to_keep] if items_to_keep > 0 else sorted_items
class HybridEvictionPolicy(CacheEvictionPolicy):
"""Hybrid eviction policy combining multiple factors."""
def select_items_to_evict(self, items: List[CacheItem], target_size: int) -> List[CacheItem]:
"""
Select items to evict based on a hybrid policy.
Args:
items: The list of cache items
target_size: The target number of items to keep
Returns:
A list of items to evict
"""
# Calculate a score for each item (lower is better to keep)
now = time.time()
for item in items:
# Normalize factors to [0, 1] range
ttl_factor = max(0, min(1, (item.expires_at - now) / 3600)) # 1 hour max
recency_factor = max(0, min(1, (now - item.last_accessed_at) / 3600)) # 1 hour max
frequency_factor = max(0, min(1, 1 / (item.access_count + 1))) # Inverse of count
size_factor = max(0, min(1, item.size_bytes / (1024 * 1024))) # 1 MB max
# Calculate weighted score
item.score = (
0.3 * ttl_factor +
0.3 * recency_factor +
0.2 * frequency_factor +
0.2 * size_factor
)
# Sort items by score (highest first, as they are worse to keep)
sorted_items = sorted(items, key=lambda item: getattr(item, "score", 0), reverse=True)
# Return items to evict to reach the target size
items_to_keep = min(target_size, len(sorted_items))
return sorted_items[:-items_to_keep] if items_to_keep > 0 else sorted_items
[docs]
class CacheBackend:
"""Base class for cache backends."""
def get(self, key: str) -> Optional[Any]:
"""
Get a value from the cache.
Args:
key: The cache key
Returns:
The cached value, or None if not found
"""
raise NotImplementedError("Subclasses must implement get")
def set(self, key: str, value: Any, ttl: int = 60) -> None:
"""
Set a value in the cache.
Args:
key: The cache key
value: The value to cache
ttl: Time-to-live in seconds
"""
raise NotImplementedError("Subclasses must implement set")
def delete(self, key: str) -> bool:
"""
Delete a value from the cache.
Args:
key: The cache key
Returns:
True if the key was deleted, False otherwise
"""
raise NotImplementedError("Subclasses must implement delete")
def clear(self) -> None:
"""Clear the cache."""
raise NotImplementedError("Subclasses must implement clear")
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
A dictionary of cache statistics
"""
raise NotImplementedError("Subclasses must implement get_stats")
class MemoryCacheBackend(CacheBackend):
"""In-memory cache backend."""
def __init__(
self,
max_items: int = 1000,
max_size_bytes: int = 50 * 1024 * 1024, # 50 MB
eviction_policy: Optional[CacheEvictionPolicy] = None,
):
"""
Initialize the memory cache backend.
Args:
max_items: Maximum number of items to store
max_size_bytes: Maximum cache size in bytes
eviction_policy: The eviction policy to use
"""
self.cache: Dict[str, CacheItem] = {}
self.max_items = max_items
self.max_size_bytes = max_size_bytes
self.eviction_policy = eviction_policy or LRUEvictionPolicy()
self.lock = threading.RLock()
self.hit_count = 0
self.miss_count = 0
self.set_count = 0
self.eviction_count = 0
self.last_cleanup_time = time.time()
def get(self, key: str) -> Optional[Any]:
"""
Get a value from the cache.
Args:
key: The cache key
Returns:
The cached value, or None if not found
"""
with self.lock:
item = self.cache.get(key)
if item is None:
self.miss_count += 1
return None
if item.is_expired():
self.cache.pop(key, None)
self.miss_count += 1
return None
item.access()
self.hit_count += 1
return item.value
def set(self, key: str, value: Any, ttl: int = 60) -> None:
"""
Set a value in the cache.
Args:
key: The cache key
value: The value to cache
ttl: Time-to-live in seconds
"""
with self.lock:
# Create a new cache item
item = CacheItem(key, value, ttl)
self.cache[key] = item
self.set_count += 1
# Check if we need to evict items
self._evict_if_needed()
# Periodically clean up expired items
self._cleanup_if_needed()
def delete(self, key: str) -> bool:
"""
Delete a value from the cache.
Args:
key: The cache key
Returns:
True if the key was deleted, False otherwise
"""
with self.lock:
if key in self.cache:
del self.cache[key]
return True
return False
def clear(self) -> None:
"""Clear the cache."""
with self.lock:
self.cache.clear()
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
A dictionary of cache statistics
"""
with self.lock:
total_size_bytes = sum(item.size_bytes for item in self.cache.values())
hit_rate = self.hit_count / (self.hit_count + self.miss_count) if (self.hit_count + self.miss_count) > 0 else 0
# Get the top 10 largest items
largest_items = sorted(
[(item.key, item.size_bytes) for item in self.cache.values()],
key=lambda x: x[1],
reverse=True
)[:10]
# Get the top 10 most accessed items
most_accessed = sorted(
[(item.key, item.access_count) for item in self.cache.values()],
key=lambda x: x[1],
reverse=True
)[:10]
return {
"item_count": len(self.cache),
"max_items": self.max_items,
"total_size_bytes": total_size_bytes,
"max_size_bytes": self.max_size_bytes,
"hit_count": self.hit_count,
"miss_count": self.miss_count,
"set_count": self.set_count,
"eviction_count": self.eviction_count,
"hit_rate": hit_rate,
"large_keys": largest_items,
"popular_keys": most_accessed,
}
def _evict_if_needed(self) -> None:
"""Evict items if the cache exceeds its limits."""
# Check if we need to evict items
if len(self.cache) <= self.max_items:
total_size = sum(item.size_bytes for item in self.cache.values())
if total_size <= self.max_size_bytes:
return
# Select items to evict
items = list(self.cache.values())
target_size = max(1, int(0.8 * self.max_items)) # Aim for 80% capacity
items_to_evict = self.eviction_policy.select_items_to_evict(items, target_size)
# Evict the selected items
for item in items_to_evict:
self.cache.pop(item.key, None)
self.eviction_count += 1
def _cleanup_if_needed(self, force: bool = False) -> None:
"""
Clean up expired items if needed.
Args:
force: Whether to force cleanup regardless of the time since last cleanup
"""
now = time.time()
# Clean up every 5 minutes
if not force and (now - self.last_cleanup_time) < 300:
return
self.last_cleanup_time = now
# Remove expired items
expired_keys = [key for key, item in self.cache.items() if item.is_expired()]
for key in expired_keys:
self.cache.pop(key, None)
class RedisCacheBackend(CacheBackend):
"""
Redis cache backend.
"""
def __init__(
self,
redis_url: str = "redis://localhost:6379/0",
prefix: str = "apifrom:",
serializer: Optional[Callable[[Any], bytes]] = None,
deserializer: Optional[Callable[[bytes], Any]] = None,
):
"""
Initialize the Redis cache backend.
Args:
redis_url: The Redis connection URL
prefix: The key prefix to use
serializer: Function to serialize values to bytes
deserializer: Function to deserialize bytes to values
"""
try:
import redis
from redis.exceptions import RedisError
except ImportError:
raise ImportError(
"Redis cache backend requires redis-py package. "
"Install it with: pip install redis"
)
self.redis_url = redis_url
self.prefix = prefix
self._redis = None
self.serializer = serializer or (lambda v: pickle.dumps(v))
self.deserializer = deserializer or (lambda v: pickle.loads(v))
self.hit_count = 0
self.miss_count = 0
self.set_count = 0
# Connect to Redis
try:
self._redis = redis.from_url(redis_url)
self._redis.ping()
except (RedisError, ConnectionError) as e:
logger.warning(f"Failed to connect to Redis: {e}")
self._redis = None
@property
def redis(self):
"""Get the Redis client, reconnecting if necessary."""
if self._redis is None:
try:
import redis
self._redis = redis.from_url(self.redis_url)
self._redis.ping()
except Exception as e:
logger.warning(f"Failed to reconnect to Redis: {e}")
return self._redis
def _make_key(self, key: str) -> str:
"""
Create a prefixed key.
Args:
key: The original key
Returns:
The prefixed key
"""
return f"{self.prefix}{key}"
def get(self, key: str) -> Optional[Any]:
"""
Get a value from the cache.
Args:
key: The cache key
Returns:
The cached value, or None if not found
"""
if self.redis is None:
self.miss_count += 1
return None
try:
value = self.redis.get(self._make_key(key))
if value is None:
self.miss_count += 1
return None
self.hit_count += 1
return self.deserializer(value)
except Exception as e:
logger.warning(f"Error getting value from Redis: {e}")
self.miss_count += 1
return None
def set(self, key: str, value: Any, ttl: int = 60) -> None:
"""
Set a value in the cache.
Args:
key: The cache key
value: The value to cache
ttl: Time-to-live in seconds
"""
if self.redis is None:
return
try:
serialized = self.serializer(value)
self.redis.setex(self._make_key(key), ttl, serialized)
self.set_count += 1
except Exception as e:
logger.warning(f"Error setting value in Redis: {e}")
def delete(self, key: str) -> bool:
"""
Delete a value from the cache.
Args:
key: The cache key
Returns:
True if the key was deleted, False otherwise
"""
if self.redis is None:
return False
try:
return bool(self.redis.delete(self._make_key(key)))
except Exception as e:
logger.warning(f"Error deleting value from Redis: {e}")
return False
def clear(self) -> None:
"""Clear the cache."""
if self.redis is None:
return
try:
# Get all keys with the prefix
keys = self.redis.keys(f"{self.prefix}*")
# Delete all keys in batches
if keys:
for i in range(0, len(keys), 100):
batch = keys[i:i+100]
self.redis.delete(*batch)
except Exception as e:
logger.warning(f"Error clearing Redis cache: {e}")
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
A dictionary of cache statistics
"""
if self.redis is None:
return {
"connected": False,
"hit_count": self.hit_count,
"miss_count": self.miss_count,
"set_count": self.set_count,
"hit_rate": 0,
}
try:
# Get Redis info
info = self.redis.info()
# Get all keys with the prefix
keys = self.redis.keys(f"{self.prefix}*")
item_count = len(keys)
# Calculate hit rate
hit_rate = self.hit_count / (self.hit_count + self.miss_count) if (self.hit_count + self.miss_count) > 0 else 0
return {
"connected": True,
"item_count": item_count,
"hit_count": self.hit_count,
"miss_count": self.miss_count,
"set_count": self.set_count,
"hit_rate": hit_rate,
"redis_version": info.get("redis_version", "unknown"),
"used_memory": info.get("used_memory", 0),
"used_memory_human": info.get("used_memory_human", "unknown"),
"connected_clients": info.get("connected_clients", 0),
}
except Exception as e:
logger.warning(f"Error getting Redis stats: {e}")
return {
"connected": False,
"error": str(e),
"hit_count": self.hit_count,
"miss_count": self.miss_count,
"set_count": self.set_count,
"hit_rate": 0,
}
[docs]
class TagBasedInvalidation:
"""
Tag-based cache invalidation strategy.
This class provides a way to invalidate cache entries based on tags.
Tags are arbitrary strings that can be associated with cache entries.
When a tag is invalidated, all cache entries associated with that tag
are also invalidated.
Example:
```python
# Create a cache backend
cache_backend = MemoryCacheBackend()
# Create a tag-based invalidation strategy
invalidation = TagBasedInvalidation(cache_backend)
# Set a cache entry with tags
cache_backend.set("user:123", {"name": "John"})
invalidation.tag("user:123", ["user", "user:123"])
# Invalidate all cache entries with the "user" tag
invalidation.invalidate_tag("user")
```
"""
def __init__(self, cache_backend: CacheBackend):
"""
Initialize the tag-based invalidation strategy.
Args:
cache_backend: The cache backend to use
"""
self.cache_backend = cache_backend
self._tag_prefix = "_tag:"
def tag(self, key: str, tags: List[str]) -> None:
"""
Tag a cache entry with one or more tags.
Args:
key: The cache key
tags: The tags to associate with the key
"""
for tag in tags:
# Add key to tag's entry list
tag_key = f"tag:{tag}"
keys = self.cache_backend.get(tag_key) or []
if key not in keys:
keys.append(key)
self.cache_backend.set(tag_key, keys)
# Add tag to key's tag list
key_tags_key = f"tags:{key}"
key_tags = self.cache_backend.get(key_tags_key) or []
if tag not in key_tags:
key_tags.append(tag)
self.cache_backend.set(key_tags_key, key_tags)
# Alias for tag method to match the test expectations
add_tags = tag
def invalidate_tag(self, tag: str) -> None:
"""
Invalidate all cache entries associated with a tag.
Args:
tag: The tag to invalidate
"""
tag_key = f"tag:{tag}"
keys = self.cache_backend.get(tag_key) or []
# Delete all cache entries associated with this tag
for key in keys:
self.cache_backend.delete(key)
# Delete the tag itself
self.cache_backend.delete(tag_key)
def invalidate_tags(self, tags: List[str]) -> None:
"""
Invalidate all cache entries associated with any of the given tags.
Args:
tags: The tags to invalidate
"""
for tag in tags:
self.invalidate_tag(tag)
def get_keys_for_tag(self, tag: str) -> List[str]:
"""
Get all cache keys associated with a tag.
Args:
tag: The tag to get keys for
Returns:
A list of cache keys
"""
tag_key = f"{self._tag_prefix}{tag}"
return self.cache_backend.get(tag_key) or []
[docs]
class DependencyBasedInvalidation:
"""
Dependency-based cache invalidation strategy.
This class provides a way to invalidate cache entries based on dependencies.
Dependencies are relationships between cache entries, where invalidating one
entry will also invalidate all entries that depend on it.
Example:
```python
# Create a cache backend
cache_backend = MemoryCacheBackend()
# Create a dependency-based invalidation strategy
invalidation = DependencyBasedInvalidation(cache_backend)
# Set a cache entry with dependencies
cache_backend.set("user:123", {"name": "John"})
invalidation.add_dependency("user:123", "users")
# Set another cache entry with dependencies
cache_backend.set("post:456", {"title": "Hello"})
invalidation.add_dependency("post:456", "posts")
invalidation.add_dependency("post:456", "user:123")
# Invalidate all cache entries that depend on "user:123"
invalidation.invalidate("user:123")
```
"""
def __init__(self, cache_backend: CacheBackend):
"""
Initialize the dependency-based invalidation strategy.
Args:
cache_backend: The cache backend to use
"""
self.cache_backend = cache_backend
self._dep_prefix = "dep:"
self._rev_prefix = "revdep:"
def add_dependency(self, key: str, dependency: str) -> None:
"""
Add a dependency to a cache entry.
Args:
key: The cache key
dependency: The dependency key
"""
# Add forward dependency (dependency -> keys)
dep_key = f"{self._dep_prefix}{dependency}"
keys = self.cache_backend.get(dep_key) or []
if key not in keys:
keys.append(key)
self.cache_backend.set(dep_key, keys)
# Add reverse dependency (key -> dependencies)
rev_key = f"{self._rev_prefix}{key}"
deps = self.cache_backend.get(rev_key) or []
if dependency not in deps:
deps.append(dependency)
self.cache_backend.set(rev_key, deps)
def add_dependencies(self, key: str, dependencies: List[str]) -> None:
"""
Add multiple dependencies to a cache entry.
Args:
key: The cache key
dependencies: List of dependency keys
"""
for dependency in dependencies:
self.add_dependency(key, dependency)
def invalidate(self, key: str) -> None:
"""
Invalidate a cache entry and all entries that depend on it.
Args:
key: The key to invalidate
"""
# Get all keys that depend on this key
dep_key = f"{self._dep_prefix}{key}"
keys = self.cache_backend.get(dep_key) or []
# Invalidate all dependent keys
for dependent_key in keys:
self.cache_backend.delete(dependent_key)
# Also invalidate reverse dependencies
rev_key = f"{self._rev_prefix}{dependent_key}"
self.cache_backend.delete(rev_key)
# Invalidate the dependency key itself
self.cache_backend.delete(dep_key)
# Invalidate the key itself
self.cache_backend.delete(key)
# Invalidate reverse dependencies
rev_key = f"{self._rev_prefix}{key}"
self.cache_backend.delete(rev_key)
def invalidate_dependency(self, dependency: str) -> None:
"""
Invalidate all cache entries that depend on a specific dependency.
Args:
dependency: The dependency key to invalidate
"""
# Get all keys that depend on this dependency
dep_key = f"{self._dep_prefix}{dependency}"
keys = self.cache_backend.get(dep_key) or []
# Invalidate all dependent keys
for key in keys:
self.cache_backend.delete(key)
# Also invalidate reverse dependencies
rev_key = f"{self._rev_prefix}{key}"
self.cache_backend.delete(rev_key)
# Invalidate the dependency key itself
self.cache_backend.delete(dep_key)
def get_dependencies(self, key: str) -> List[str]:
"""
Get all dependencies of a cache entry.
Args:
key: The cache key
Returns:
A list of dependency keys
"""
rev_key = f"{self._rev_prefix}{key}"
return self.cache_backend.get(rev_key) or []
def get_dependents(self, key: str) -> List[str]:
"""
Get all cache entries that depend on a key.
Args:
key: The dependency key
Returns:
A list of cache keys
"""
dep_key = f"{self._dep_prefix}{key}"
return self.cache_backend.get(dep_key) or []
[docs]
class CacheControl:
"""
Cache control decorators for API endpoints.
This class provides decorators to control caching behavior for API endpoints.
It includes decorators to cache responses, prevent caching, and invalidate cache entries.
Example:
```python
from apifrom import API, api
from apifrom.middleware.cache_advanced import CacheControl
app = API()
@api(route="/users/{user_id}", method="GET")
@CacheControl.cache(ttl=60, tags=["user"])
def get_user(user_id: str):
# This response will be cached for 60 seconds
return {"id": user_id, "name": "John"}
@api(route="/users/{user_id}", method="PUT")
@CacheControl.invalidate(["user"])
def update_user(user_id: str, name: str):
# This will invalidate all cache entries with the "user" tag
return {"id": user_id, "name": name}
@api(route="/users/{user_id}/sensitive", method="GET")
@CacheControl.no_cache
def get_sensitive_user_data(user_id: str):
# This response will not be cached
return {"id": user_id, "ssn": "123-45-6789"}
```
"""
@staticmethod
def cache(ttl: int = 60, tags: Optional[List[str]] = None, dependencies: Optional[List[str]] = None):
"""
Decorator to cache the response of an API endpoint.
Args:
ttl: Time to live in seconds
tags: Tags to associate with the cache entry
dependencies: Dependencies to associate with the cache entry
Returns:
A decorator function
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# Call the original function
result = await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)
# Store cache metadata on the result
if hasattr(result, "headers"):
result.headers["X-Cache-TTL"] = str(ttl)
if tags:
result.headers["X-Cache-Tags"] = ",".join(tags)
if dependencies:
result.headers["X-Cache-Dependencies"] = ",".join(dependencies)
return result
# Store cache metadata on the wrapper function
wrapper.__cache_control__ = {
"cache": True,
"ttl": ttl,
"tags": tags or [],
"dependencies": dependencies or []
}
return wrapper
return decorator
@staticmethod
def no_cache(func):
"""
Decorator to prevent caching of an API endpoint.
Args:
func: The function to decorate
Returns:
The decorated function
"""
@wraps(func)
async def wrapper(*args, **kwargs):
# Call the original function
result = await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)
# Set cache control headers
if hasattr(result, "headers"):
result.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
result.headers["Pragma"] = "no-cache"
result.headers["Expires"] = "0"
return result
# Store cache metadata on the wrapper function
wrapper.__cache_control__ = {
"cache": False
}
return wrapper
@staticmethod
def invalidate(patterns: List[str]):
"""
Decorator to invalidate cache entries matching the given patterns.
Args:
patterns: Patterns to match cache keys
Returns:
A decorator function
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# Call the original function
result = await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)
# Find the request object in args or kwargs
request = None
for arg in args:
if hasattr(arg, "app") and hasattr(arg.app, "middleware"):
request = arg
break
if request is not None:
# Find the cache middleware
for middleware in request.app.middleware:
if hasattr(middleware, "__class__") and middleware.__class__.__name__ in ["CacheMiddleware", "AdvancedCacheMiddleware"]:
# Invalidate the patterns
for pattern in patterns:
middleware.invalidate(pattern)
break
# Store invalidation patterns on the result
if hasattr(result, "headers"):
result.headers["X-Cache-Invalidate"] = ",".join(patterns)
return result
# Store cache metadata on the wrapper function
wrapper.__cache_control__ = {
"invalidate": True,
"patterns": patterns
}
return wrapper
return decorator
class CacheMiddleware(Middleware):
"""
Middleware for caching API responses.
This middleware caches API responses based on the request method, path, and query parameters.
It supports various cache backends and configuration options.
"""
def __init__(
self,
cache_backend: Optional[CacheBackend] = None,
ttl: int = 60,
cache_methods: Optional[Set[str]] = None,
ignore_paths: Optional[Set[str]] = None,
vary_headers: Optional[Set[str]] = None,
cache_control_header: bool = True,
):
"""
Initialize the cache middleware.
Args:
cache_backend: The cache backend to use (defaults to MemoryCacheBackend)
ttl: Default time-to-live in seconds
cache_methods: Set of HTTP methods to cache (defaults to {"GET"})
ignore_paths: Set of paths to exclude from caching
vary_headers: Set of headers to include in the cache key
cache_control_header: Whether to set Cache-Control headers
"""
self.cache = cache_backend or MemoryCacheBackend()
self.ttl = ttl
self.cache_methods = cache_methods or {"GET"}
self.ignore_paths = ignore_paths or set()
self.vary_headers = vary_headers or {"Accept", "Accept-Encoding"}
self.cache_control_header = cache_control_header
async def process_request(self, request: Request, call_next: Callable) -> Response:
"""
Process a request and potentially return a cached response.
Args:
request: The request object
call_next: The next middleware function
Returns:
The response object
"""
# Skip caching for non-cacheable requests
if not self._should_cache(request):
return await call_next(request)
# Generate a cache key for the request
cache_key = self._generate_cache_key(request)
# Try to get the response from the cache
cached_response = self.cache.get(cache_key)
if cached_response is not None:
# Return the cached response
if self.cache_control_header:
cached_response.headers["X-Cache"] = "HIT"
return cached_response
# Get the response from the next middleware
try:
response = await call_next(request)
except TypeError:
# Handle non-awaitable responses
response = call_next(request)
# Skip caching for non-cacheable responses
if not self._should_cache_response(response):
return response
# Cache the response if it's cacheable
if self._should_cache_response(response):
# Clone the response to avoid modifying the original
response_to_cache = response.copy()
# Set cache headers
if self.cache_control_header:
response.headers["Cache-Control"] = f"max-age={self.ttl}"
response.headers["X-Cache"] = "MISS"
# Cache the response
self.cache.set(cache_key, response_to_cache, self.ttl)
return response
def _should_cache(self, request: Request) -> bool:
"""
Check if a request should be cached.
Args:
request: The request object
Returns:
True if the request should be cached, False otherwise
"""
# Check if the method is cacheable
if request.method not in self.cache_methods:
return False
# Check if the path is ignored
for path in self.ignore_paths:
if request.path.startswith(path):
return False
# Check if the request has cache control headers
cache_control = request.headers.get("Cache-Control", "")
if "no-cache" in cache_control or "no-store" in cache_control:
return False
return True
def _should_cache_response(self, response: Response) -> bool:
"""
Check if a response should be cached.
Args:
response: The response object
Returns:
True if the response should be cached, False otherwise
"""
# Only cache successful responses
if response.status_code < 200 or response.status_code >= 300:
return False
# Check if the response has cache control headers
cache_control = response.headers.get("Cache-Control", "")
if "no-cache" in cache_control or "no-store" in cache_control:
return False
return True
def _generate_cache_key(self, request: Request) -> str:
"""
Generate a cache key for a request.
Args:
request: The request object
Returns:
A cache key string
"""
# Get the method and path
key_parts = [request.method, request.path]
# Add query parameters
if request.query_params:
sorted_params = sorted(request.query_params.items())
key_parts.append(str(sorted_params))
# Add vary headers
for header in self.vary_headers:
value = request.headers.get(header)
if value:
key_parts.append(f"{header}:{value}")
# Generate a hash of the key parts
key_str = ":".join(key_parts)
return hashlib.md5(key_str.encode()).hexdigest()
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
A dictionary of cache statistics
"""
return self.cache.get_stats()
def clear(self) -> None:
"""Clear the cache."""
self.cache.clear()
class AdvancedCacheMiddleware(CacheMiddleware):
"""
Advanced middleware for caching API responses.
This middleware extends the basic CacheMiddleware with additional features
such as per-endpoint TTL, response compression, and automatic cache key generation.
"""
def __init__(
self,
cache_backend: Optional[CacheBackend] = None,
ttl: int = 60,
cache_methods: Optional[Set[str]] = None,
ignore_paths: Optional[Set[str]] = None,
vary_headers: Optional[Set[str]] = None,
cache_control_header: bool = True,
compress_responses: bool = False,
endpoint_ttls: Optional[Dict[str, int]] = None,
auto_vary: bool = True,
invalidation_strategy: Optional[Union[TagBasedInvalidation, DependencyBasedInvalidation]] = None,
):
"""
Initialize the advanced cache middleware.
Args:
cache_backend: The cache backend to use (defaults to MemoryCacheBackend)
ttl: Default time-to-live in seconds
cache_methods: Set of HTTP methods to cache (defaults to {"GET"})
ignore_paths: Set of paths to exclude from caching
vary_headers: Set of headers to include in the cache key
cache_control_header: Whether to set Cache-Control headers
compress_responses: Whether to compress responses before caching
endpoint_ttls: Dictionary mapping endpoints to TTL values
auto_vary: Whether to automatically determine vary headers
invalidation_strategy: Strategy for cache invalidation
"""
super().__init__(
cache_backend=cache_backend,
ttl=ttl,
cache_methods=cache_methods,
ignore_paths=ignore_paths,
vary_headers=vary_headers,
cache_control_header=cache_control_header,
)
self.compress_responses = compress_responses
self.endpoint_ttls = endpoint_ttls or {}
self.auto_vary = auto_vary
self.invalidation_strategy = invalidation_strategy
# Set up compression if enabled
if self.compress_responses:
try:
import zlib
self._compress = lambda data: zlib.compress(data)
self._decompress = lambda data: zlib.decompress(data)
except ImportError:
logger.warning("zlib not available, compression disabled")
self.compress_responses = False
async def process_request(self, request: Request, call_next: Callable) -> Response:
"""
Process a request and potentially return a cached response.
Args:
request: The request object
call_next: The next middleware function
Returns:
The response object
"""
# Skip caching for non-cacheable requests
if not self._should_cache(request):
return await call_next(request)
# Generate a cache key for the request
cache_key = self._generate_cache_key(request)
# Try to get the response from the cache
cached_data = self.cache.get(cache_key)
if cached_data is not None:
# Decompress if necessary
if self.compress_responses and isinstance(cached_data, tuple) and len(cached_data) == 2:
compressed, response_dict = cached_data
if compressed:
try:
response_dict = json.loads(self._decompress(response_dict).decode())
except Exception as e:
logger.warning(f"Error decompressing cached response: {e}")
return await call_next(request)
else:
response_dict = cached_data
# Reconstruct the response
response = Response.from_dict(response_dict)
# Set cache headers
if self.cache_control_header:
response.headers["X-Cache"] = "HIT"
return response
# Get the response from the next middleware
try:
response = await call_next(request)
except TypeError:
# Handle non-awaitable responses
response = call_next(request)
# Skip caching for non-cacheable responses
if not self._should_cache_response(response):
return response
# Cache the response if it's cacheable
if self._should_cache_response(response):
# Get the TTL for this endpoint
endpoint_ttl = self._get_endpoint_ttl(request.path)
# Convert the response to a dictionary
response_dict = response.to_dict()
# Compress if enabled
if self.compress_responses and len(json.dumps(response_dict)) > 1024: # Only compress if > 1KB
try:
compressed_data = self._compress(json.dumps(response_dict).encode())
cache_data = (True, compressed_data)
except Exception as e:
logger.warning(f"Error compressing response: {e}")
cache_data = response_dict
else:
cache_data = response_dict
# Set cache headers
if self.cache_control_header:
response.headers["Cache-Control"] = f"max-age={endpoint_ttl}"
response.headers["X-Cache"] = "MISS"
# Store in cache
self.cache.set(cache_key, cache_data, ttl=endpoint_ttl)
return response
def _get_endpoint_ttl(self, path: str) -> int:
"""
Get the TTL for an endpoint.
Args:
path: The endpoint path
Returns:
The TTL in seconds
"""
# Check for exact path match
if path in self.endpoint_ttls:
return self.endpoint_ttls[path]
# Check for prefix match
for prefix, ttl in self.endpoint_ttls.items():
if prefix.endswith("*") and path.startswith(prefix[:-1]):
return ttl
# Return default TTL
return self.ttl
def _generate_cache_key(self, request: Request) -> str:
"""
Generate a cache key for a request.
Args:
request: The request object
Returns:
A cache key string
"""
# Get the method and path
key_parts = [request.method, request.path]
# Add query parameters
if request.query_params:
sorted_params = sorted(request.query_params.items())
key_parts.append(str(sorted_params))
# Add vary headers
vary_headers = set(self.vary_headers)
# Add auto vary headers if enabled
if self.auto_vary:
# Add content negotiation headers
for header in ["Accept", "Accept-Encoding", "Accept-Language"]:
if header in request.headers:
vary_headers.add(header)
# Add authorization if present
if "Authorization" in request.headers:
vary_headers.add("Authorization")
for header in vary_headers:
value = request.headers.get(header)
if value:
key_parts.append(f"{header}:{value}")
# Generate a hash of the key parts
key_str = ":".join(key_parts)
return hashlib.md5(key_str.encode()).hexdigest()
def invalidate(self, pattern: str) -> None:
"""
Invalidate cache entries matching a pattern.
Args:
pattern: The pattern to match
"""
if self.invalidation_strategy:
if isinstance(self.invalidation_strategy, TagBasedInvalidation):
self.invalidation_strategy.invalidate_tag(pattern)
elif isinstance(self.invalidation_strategy, DependencyBasedInvalidation):
self.invalidation_strategy.invalidate_dependency(pattern)
else:
# Fallback to direct cache invalidation
# This is a simplified implementation that only works with exact matches
self.cache_backend.delete(pattern)
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
A dictionary of cache statistics
"""
return self.cache.get_stats()
def clear(self) -> None:
"""Clear the cache."""
self.cache.clear()
# Export public classes
__all__ = [
"CacheItem",
"CacheEvictionPolicy",
"LRUEvictionPolicy",
"LFUEvictionPolicy",
"TTLEvictionPolicy",
"SizeEvictionPolicy",
"HybridEvictionPolicy",
"CacheBackend",
"MemoryCacheBackend",
"RedisCacheBackend",
"CacheMiddleware",
"AdvancedCacheMiddleware",
"TagBasedInvalidation",
"DependencyBasedInvalidation",
"CacheControl",
]