"""
Security headers for APIFromAnything.
This module provides middleware and utilities for adding security headers to API responses,
including Content Security Policy (CSP), X-XSS-Protection, and other security headers
to protect against various web vulnerabilities.
"""
import re
import uuid
from typing import Callable, Dict, List, Optional, Set, Union, Any
from apifrom.core.request import Request
from apifrom.core.response import Response
from apifrom.middleware.base import BaseMiddleware
[docs]
class CSPDirective:
    """
    Builder for a single Content Security Policy directive.

    A directive is a name (for example ``script-src``) plus a set of source
    expressions. Each ``allow_*`` method adds expressions and returns the
    directive itself, so calls can be chained. Expressions live in a set, so
    adding the same one twice has no effect.
    """

    def __init__(self, name: str):
        """
        Create an empty directive.

        Args:
            name: The directive name, e.g. 'default-src' or 'script-src'
        """
        self.name = name
        # Source expressions; to_string() emits them sorted for stable output.
        self.values: Set[str] = set()
        # Most recent nonce passed to (or generated by) allow_nonce().
        self._nonce: Optional[str] = None

    def _add(self, *expressions: str) -> 'CSPDirective':
        """Add source expressions verbatim and return self for chaining."""
        self.values.update(expressions)
        return self

    def allow_self(self) -> 'CSPDirective':
        """
        Add the 'self' keyword (content from the same origin).

        Returns:
            This directive, for chaining
        """
        return self._add("'self'")

    def allow_none(self) -> 'CSPDirective':
        """
        Add the 'none' keyword (no source is allowed).

        Returns:
            This directive, for chaining
        """
        return self._add("'none'")

    def allow_unsafe_inline(self) -> 'CSPDirective':
        """
        Add the 'unsafe-inline' keyword (not recommended for production).

        Returns:
            This directive, for chaining
        """
        return self._add("'unsafe-inline'")

    def allow_unsafe_eval(self) -> 'CSPDirective':
        """
        Add the 'unsafe-eval' keyword, permitting eval() and similar functions
        (not recommended for production).

        Returns:
            This directive, for chaining
        """
        return self._add("'unsafe-eval'")

    def allow_strict_dynamic(self) -> 'CSPDirective':
        """
        Add the 'strict-dynamic' keyword, letting nonce-approved scripts load
        further scripts.

        Returns:
            This directive, for chaining
        """
        return self._add("'strict-dynamic'")

    def allow_sources(self, *sources: str) -> 'CSPDirective':
        """
        Add arbitrary source expressions as given.

        Args:
            *sources: Sources such as 'https://example.com' or '*.example.com'

        Returns:
            This directive, for chaining
        """
        return self._add(*sources)

    def allow_nonce(self, nonce: Optional[str] = None) -> 'CSPDirective':
        """
        Add a 'nonce-...' expression and remember the nonce.

        Args:
            nonce: The nonce to use; when None a fresh one is generated

        Returns:
            This directive, for chaining
        """
        chosen = self._generate_nonce() if nonce is None else nonce
        self._nonce = chosen
        return self._add(f"'nonce-{chosen}'")

    def _generate_nonce(self) -> str:
        """
        Produce a random nonce.

        Returns:
            The 32-character hex form of a random UUID4
        """
        random_id = uuid.uuid4()
        return random_id.hex

    def get_nonce(self) -> Optional[str]:
        """
        Return the nonce most recently set via allow_nonce().

        Returns:
            The nonce, or None if allow_nonce() was never called
        """
        return self._nonce

    def to_string(self) -> str:
        """
        Serialize the directive as ``"<name> <expr> <expr> ..."``.

        Returns:
            The directive text with expressions sorted, or "" when no
            expressions were added
        """
        if self.values:
            ordered = " ".join(sorted(self.values))
            return f"{self.name} {ordered}"
        return ""
[docs]
class ContentSecurityPolicy:
    """
    Content Security Policy builder.

    Holds directives keyed by name, an optional report URI, and a flag that
    selects the report-only header variant. Directive insertion order is
    preserved and determines serialization order. The per-directive accessors
    (default_src(), script_src(), ...) return the existing directive or create
    and register an empty one. The same directive can therefore be configured
    across several calls.
    """

    def __init__(self):
        """
        Initialize an empty, enforcing Content Security Policy.
        """
        self.directives: Dict[str, CSPDirective] = {}
        self.report_only = False
        self.report_uri: Optional[str] = None

    def _directive(self, name: str) -> CSPDirective:
        """
        Return the directive called ``name``, creating and registering it first
        if it does not exist yet.

        Args:
            name: The directive name, e.g. 'script-src'

        Returns:
            The registered directive
        """
        if name not in self.directives:
            self.directives[name] = CSPDirective(name)
        return self.directives[name]

    def add_directive(self, directive: CSPDirective) -> 'ContentSecurityPolicy':
        """
        Add a directive to the policy, replacing any directive of the same name.

        Args:
            directive: The directive to add

        Returns:
            The Content Security Policy instance for method chaining
        """
        self.directives[directive.name] = directive
        return self

    def default_src(self) -> CSPDirective:
        """
        Get or create the default-src directive.

        Returns:
            The default-src directive
        """
        return self._directive("default-src")

    def script_src(self) -> CSPDirective:
        """
        Get or create the script-src directive.

        Returns:
            The script-src directive
        """
        return self._directive("script-src")

    def style_src(self) -> CSPDirective:
        """
        Get or create the style-src directive.

        Returns:
            The style-src directive
        """
        return self._directive("style-src")

    def img_src(self) -> CSPDirective:
        """
        Get or create the img-src directive.

        Returns:
            The img-src directive
        """
        return self._directive("img-src")

    def connect_src(self) -> CSPDirective:
        """
        Get or create the connect-src directive.

        Returns:
            The connect-src directive
        """
        return self._directive("connect-src")

    def font_src(self) -> CSPDirective:
        """
        Get or create the font-src directive.

        Returns:
            The font-src directive
        """
        return self._directive("font-src")

    def object_src(self) -> CSPDirective:
        """
        Get or create the object-src directive.

        Returns:
            The object-src directive
        """
        return self._directive("object-src")

    def media_src(self) -> CSPDirective:
        """
        Get or create the media-src directive.

        Returns:
            The media-src directive
        """
        return self._directive("media-src")

    def frame_src(self) -> CSPDirective:
        """
        Get or create the frame-src directive.

        Returns:
            The frame-src directive
        """
        return self._directive("frame-src")

    def worker_src(self) -> CSPDirective:
        """
        Get or create the worker-src directive.

        Returns:
            The worker-src directive
        """
        return self._directive("worker-src")

    def manifest_src(self) -> CSPDirective:
        """
        Get or create the manifest-src directive.

        Returns:
            The manifest-src directive
        """
        return self._directive("manifest-src")

    def frame_ancestors(self) -> CSPDirective:
        """
        Get or create the frame-ancestors directive.

        Returns:
            The frame-ancestors directive
        """
        return self._directive("frame-ancestors")

    def form_action(self) -> CSPDirective:
        """
        Get or create the form-action directive.

        Returns:
            The form-action directive
        """
        return self._directive("form-action")

    def base_uri(self) -> CSPDirective:
        """
        Get or create the base-uri directive.

        Returns:
            The base-uri directive
        """
        return self._directive("base-uri")

    def set_report_only(self, report_only: bool = True) -> 'ContentSecurityPolicy':
        """
        Set whether the policy is report-only.

        Args:
            report_only: Whether the policy is report-only

        Returns:
            The Content Security Policy instance for method chaining
        """
        self.report_only = report_only
        return self

    def set_report_uri(self, uri: str) -> 'ContentSecurityPolicy':
        """
        Set the report URI, emitted as a trailing ``report-uri`` directive.

        Args:
            uri: The report URI

        Returns:
            The Content Security Policy instance for method chaining
        """
        self.report_uri = uri
        return self

    def to_string(self) -> str:
        """
        Serialize the policy.

        Directives with no source expressions are omitted. Directives are
        joined with "; " in insertion order, followed by ``report-uri`` when
        one is set.

        Returns:
            The policy as a string ("" if nothing is configured)
        """
        rendered = (directive.to_string() for directive in self.directives.values())
        parts = [text for text in rendered if text]
        if self.report_uri:
            parts.append(f"report-uri {self.report_uri}")
        return "; ".join(parts)

    def get_header_name(self) -> str:
        """
        Get the header name for the policy.

        Returns:
            "Content-Security-Policy-Report-Only" when report_only is set,
            otherwise "Content-Security-Policy"
        """
        if self.report_only:
            return "Content-Security-Policy-Report-Only"
        return "Content-Security-Policy"

    @classmethod
    def create_strict_policy(cls) -> 'ContentSecurityPolicy':
        """
        Create a strict Content Security Policy.

        Returns:
            A strict Content Security Policy
        """
        policy = cls()
        # Set default-src to 'self'
        policy.default_src().allow_self()
        # Disallow object-src to prevent plugin execution
        policy.object_src().allow_none()
        # Restrict base-uri to 'self'
        policy.base_uri().allow_self()
        # Prevent framing from other sites
        policy.frame_ancestors().allow_self()
        # Restrict form submissions to 'self'
        policy.form_action().allow_self()
        return policy

    @classmethod
    def create_api_policy(cls) -> 'ContentSecurityPolicy':
        """
        Create a Content Security Policy suitable for APIs.

        Returns:
            A policy of ``default-src 'none'; frame-ancestors 'none'``
        """
        policy = cls()
        # API responses are not meant to load sub-resources.
        policy.default_src().allow_none()
        # frame-ancestors does NOT fall back to default-src, so framing must be
        # forbidden explicitly; without this, the policy places no restriction on framing.
        policy.frame_ancestors().allow_none()
        return policy
[docs]
class ReferrerPolicy:
    """
    String constants for the Referrer-Policy response header.

    Each attribute holds one policy token exactly as it is written in the
    header value.
    """

    NO_REFERRER: str = "no-referrer"
    NO_REFERRER_WHEN_DOWNGRADE: str = "no-referrer-when-downgrade"
    ORIGIN: str = "origin"
    ORIGIN_WHEN_CROSS_ORIGIN: str = "origin-when-cross-origin"
    SAME_ORIGIN: str = "same-origin"
    STRICT_ORIGIN: str = "strict-origin"
    STRICT_ORIGIN_WHEN_CROSS_ORIGIN: str = "strict-origin-when-cross-origin"
    UNSAFE_URL: str = "unsafe-url"
[docs]
class XSSProtection:
    """
    Ready-made values for the X-XSS-Protection response header.

    ENABLED_REPORT is only a prefix. It ends in ``report=``, and the caller
    must append the reporting URI.

    NOTE(review): current browsers no longer ship the XSS auditor this header
    controls, and OWASP recommends "0" (or omitting the header) in favour of a
    Content Security Policy. Confirm which value you actually want.
    """

    DISABLED: str = "0"  # filter off
    ENABLED: str = "1"  # filter on
    ENABLED_BLOCK: str = "1; mode=block"  # filter on, block rendering on detection
    ENABLED_REPORT: str = "1; report="  # filter on with reporting; append the URI
class SecurityHeadersMiddleware(BaseMiddleware):
    """
    Middleware for adding security headers to responses.

    Requests pass through unchanged. Every response that is not exempt (by
    request path or response content type) receives the configured headers.
    Passing a falsy value for a header parameter disables that header.
    """

    # Permissions-Policy allowlist keywords. The header's structured-field syntax
    # requires these as bare tokens; a quoted "self" is parsed as an (invalid)
    # origin string and silently ignored by the browser.
    _PERMISSIONS_POLICY_TOKENS = frozenset({"self", "*"})

    def __init__(
        self,
        content_security_policy: Optional[ContentSecurityPolicy] = None,
        x_frame_options: str = "DENY",
        x_content_type_options: str = "nosniff",
        referrer_policy: str = ReferrerPolicy.STRICT_ORIGIN_WHEN_CROSS_ORIGIN,
        x_xss_protection: str = XSSProtection.ENABLED_BLOCK,
        strict_transport_security: str = "max-age=31536000; includeSubDomains",
        permissions_policy: Optional[Dict[str, List[str]]] = None,
        cache_control: Optional[str] = None,
        exempt_paths: Optional[List[str]] = None,
        exempt_content_types: Optional[List[str]] = None,
    ):
        """
        Initialize the security headers middleware.

        Args:
            content_security_policy: The Content Security Policy to use
            x_frame_options: The X-Frame-Options header value
            x_content_type_options: The X-Content-Type-Options header value
            referrer_policy: The Referrer-Policy header value
            x_xss_protection: The X-XSS-Protection header value
            strict_transport_security: The Strict-Transport-Security header value
            permissions_policy: Mapping of feature name -> allowlist used to
                build the Permissions-Policy header (an empty list disables the
                feature; use "self" or "*" for the keywords)
            cache_control: The Cache-Control header value
            exempt_paths: Regex patterns; a request whose path matches one of
                them from its start (re.match) gets no headers
            exempt_content_types: Substrings; a response whose Content-Type
                contains one of them (case-insensitively) gets no headers
        """
        super().__init__()
        self.content_security_policy = content_security_policy
        self.x_frame_options = x_frame_options
        self.x_content_type_options = x_content_type_options
        self.referrer_policy = referrer_policy
        self.x_xss_protection = x_xss_protection
        self.strict_transport_security = strict_transport_security
        self.permissions_policy = permissions_policy or {}
        self.cache_control = cache_control
        self.exempt_paths = exempt_paths or []
        self.exempt_content_types = exempt_content_types or []

    def _is_exempt(self, request: Optional[Request], response: Response) -> bool:
        """
        Check if a request/response is exempt from security headers.

        Args:
            request: The request, or None when the response carries none (the
                path check is then skipped)
            response: The response

        Returns:
            True if the request/response is exempt, False otherwise
        """
        if request is not None:
            for pattern in self.exempt_paths:
                if re.match(pattern, request.path):
                    return True
        # Media types are case-insensitive, so compare in lower case.
        content_type = (response.headers.get("Content-Type") or "").lower()
        return any(exempt.lower() in content_type for exempt in self.exempt_content_types)

    @classmethod
    def _format_allowlist_member(cls, origin: str) -> str:
        """
        Render one Permissions-Policy allowlist entry.

        ``self`` and ``*`` (also accepted in CSP style as ``'self'`` / ``'*'``)
        are emitted as bare tokens. Anything else is emitted as a quoted origin.

        Args:
            origin: The allowlist entry as configured

        Returns:
            The entry in header syntax
        """
        keyword = origin.strip("'")
        if keyword in cls._PERMISSIONS_POLICY_TOKENS:
            return keyword
        return f'"{origin}"'

    def _build_permissions_policy(self) -> str:
        """
        Build the Permissions-Policy header value.

        Each feature becomes ``feature=(member member ...)``. An empty (or None)
        allowlist yields ``feature=()``. Features are joined with ", ".

        Returns:
            The Permissions-Policy header value
        """
        directives = []
        for feature, origins in self.permissions_policy.items():
            members = " ".join(self._format_allowlist_member(origin) for origin in origins or ())
            directives.append(f"{feature}=({members})")
        return ", ".join(directives)

    def _add_security_headers(self, response: Response) -> None:
        """
        Add the configured security headers to a response, in place.

        Args:
            response: The response to add headers to
        """
        csp = self.content_security_policy
        if csp is not None:
            policy_value = csp.to_string()
            # An empty policy is not a meaningful header value; don't send a blank CSP.
            if policy_value:
                response.headers[csp.get_header_name()] = policy_value

        simple_headers = (
            ("X-Frame-Options", self.x_frame_options),
            ("X-Content-Type-Options", self.x_content_type_options),
            ("Referrer-Policy", self.referrer_policy),
            ("X-XSS-Protection", self.x_xss_protection),
            ("Strict-Transport-Security", self.strict_transport_security),
        )
        for header_name, value in simple_headers:
            if value:
                response.headers[header_name] = value

        if self.permissions_policy:
            response.headers["Permissions-Policy"] = self._build_permissions_policy()
        if self.cache_control:
            response.headers["Cache-Control"] = self.cache_control

    async def process_request(self, request: Request) -> Request:
        """
        Process a request through the security headers middleware.

        Args:
            request: The request to process

        Returns:
            The request, unchanged
        """
        # This middleware doesn't modify the request, just passes it through
        return request

    async def process_response(self, response: Response) -> Response:
        """
        Process a response through the security headers middleware.

        Args:
            response: The response to process

        Returns:
            The same response, with security headers added unless exempt
        """
        # NOTE(review): assumes the framework attaches the originating request as
        # response.request; if it is missing, only the content-type exemption applies.
        request = getattr(response, "request", None)
        if not self._is_exempt(request, response):
            self._add_security_headers(response)
        return response
[docs]
class XSSFilter:
    """
    Filter for preventing Cross-Site Scripting (XSS) attacks.

    All methods are static; the class only groups them. HTML sanitization uses
    ``bleach`` when it can be imported. Otherwise it falls back to a best-effort
    regex filter. That filter is NOT a substitute for a real HTML sanitizer: for
    example, it does not decode character references inside attribute values.
    """

    # Allow-list used by the bleach code path when the caller supplies none.
    _DEFAULT_ALLOWED_TAGS = frozenset({
        "a", "abbr", "acronym", "b", "blockquote", "code", "em", "i", "li",
        "ol", "p", "strong", "ul", "br", "hr", "span", "div", "h1", "h2",
        "h3", "h4", "h5", "h6", "table", "thead", "tbody", "tr", "th", "td",
    })
    _DEFAULT_ALLOWED_ATTRIBUTES = {
        "a": frozenset({"href", "title", "target", "rel"}),
        "abbr": frozenset({"title"}),
        "acronym": frozenset({"title"}),
        "*": frozenset({"class", "id", "style"}),
    }

    # --- Regex fallback patterns (compiled once, at class creation) ---
    # <script>/<style> elements are dropped together with their contents.
    _SCRIPT_STYLE_BLOCK_RE = re.compile(
        r"<(script|style)\b[^>]*>.*?</\1\s*>", re.IGNORECASE | re.DOTALL
    )
    # Any "<...>" construct. Quoted attribute values are treated as opaque, so a
    # ">" inside quotes cannot end the tag early and smuggle attributes past us.
    _MARKUP_RE = re.compile(r"""<(?:[^>"']|"[^"]*"|'[^']*')*>""")
    # Splits a start/end tag into (slash, tag name, attribute text).
    _TAG_PARTS_RE = re.compile(r"<(/?)([a-zA-Z][\w-]*)(.*)>", re.DOTALL)
    # An attribute value: double-quoted, single-quoted, or unquoted.
    _ATTR_VALUE = r"""(?:"[^"]*"|'[^']*'|[^\s"'>]+)"""
    # "/" is accepted as a separator because browsers parse <a/onclick=...>.
    _EVENT_HANDLER_ATTR_RE = re.compile(r"[\s/]+on[\w-]+\s*=\s*" + _ATTR_VALUE, re.IGNORECASE)
    _STYLE_ATTR_RE = re.compile(r"[\s/]+style\s*=\s*" + _ATTR_VALUE, re.IGNORECASE)
    _SCRIPT_URL_ATTR_RE = re.compile(
        r"[\s/]+[\w:-]+\s*=\s*"
        r"""(?:"\s*(?:java|vb)script:[^"]*"|'\s*(?:java|vb)script:[^']*'|(?:java|vb)script:[^\s"'>]*)""",
        re.IGNORECASE,
    )

    @staticmethod
    def sanitize_html(html: str, allowed_tags: Optional[Set[str]] = None, allowed_attributes: Optional[Dict[str, Set[str]]] = None) -> str:
        """
        Sanitize HTML to prevent XSS attacks.

        Args:
            html: The HTML to sanitize
            allowed_tags: The allowed HTML tags. With bleach, None selects a
                default allow-list. In the regex fallback, None or an empty
                set strips every tag.
            allowed_attributes: The allowed HTML attributes per tag ("*" for all
                tags); only used when bleach is available

        Returns:
            The sanitized HTML
        """
        try:
            import bleach
        except ImportError:
            return XSSFilter._sanitize_html_fallback(html, allowed_tags)

        tags = XSSFilter._DEFAULT_ALLOWED_TAGS if allowed_tags is None else allowed_tags
        attributes = (
            XSSFilter._DEFAULT_ALLOWED_ATTRIBUTES if allowed_attributes is None else allowed_attributes
        )
        # bleach expects lists of attribute names per tag.
        bleach_attrs = {tag: list(attrs) for tag, attrs in attributes.items()}
        return bleach.clean(
            html,
            tags=list(tags),
            attributes=bleach_attrs,
            strip=True,
        )

    @staticmethod
    def _strip_dangerous_attributes(attributes: str) -> str:
        """
        Remove these from a tag's attribute text:
        - event-handler attributes (``on*``);
        - ``style`` attributes;
        - attributes whose value is a ``javascript:`` or ``vbscript:`` URL.
        """
        for pattern in (
            XSSFilter._EVENT_HANDLER_ATTR_RE,
            XSSFilter._STYLE_ATTR_RE,
            XSSFilter._SCRIPT_URL_ATTR_RE,
        ):
            attributes = pattern.sub("", attributes)
        return attributes

    @staticmethod
    def _sanitize_html_fallback(html: str, allowed_tags: Optional[Set[str]] = None) -> str:
        """
        Best-effort regex sanitizer, used only when bleach is unavailable.

        Script and style elements are removed along with their contents.
        Without an allow-list every remaining tag is stripped and the text is
        kept. With an allow-list, disallowed tags are stripped, tag names are
        matched case-insensitively, and allowed tags keep only their safe
        attributes.

        Args:
            html: The HTML to sanitize
            allowed_tags: Tag names to keep; None or empty keeps none

        Returns:
            The filtered HTML
        """
        sanitized = XSSFilter._SCRIPT_STYLE_BLOCK_RE.sub("", html)
        if not allowed_tags:
            return XSSFilter._MARKUP_RE.sub("", sanitized)

        allowed = {tag.lower() for tag in allowed_tags}

        def keep_allowed(match):
            # Rebuild allowed tags with cleaned attributes; drop everything else.
            parts = XSSFilter._TAG_PARTS_RE.fullmatch(match.group(0))
            if parts is None or parts.group(2).lower() not in allowed:
                return ""
            slash, name, attributes = parts.groups()
            return f"<{slash}{name}{XSSFilter._strip_dangerous_attributes(attributes)}>"

        return XSSFilter._MARKUP_RE.sub(keep_allowed, sanitized)

    @staticmethod
    def escape_html(text: str) -> str:
        """
        Escape HTML special characters to prevent XSS attacks.

        Args:
            text: The text to escape

        Returns:
            The text with & < > " ' replaced by character references
        """
        # "&" must go first so the references produced below are not re-escaped.
        return (
            text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&#x27;")
        )

    @staticmethod
    def sanitize_json(data: Any) -> Any:
        """
        Sanitize JSON data to prevent XSS attacks.

        Strings are HTML-escaped. Dict values, list items and tuple items are
        sanitized recursively; dict keys are left unchanged. All other values
        are returned as-is.

        Args:
            data: The JSON data to sanitize

        Returns:
            The sanitized JSON data
        """
        if isinstance(data, str):
            return XSSFilter.escape_html(data)
        if isinstance(data, dict):
            return {key: XSSFilter.sanitize_json(value) for key, value in data.items()}
        if isinstance(data, list):
            return [XSSFilter.sanitize_json(item) for item in data]
        if isinstance(data, tuple):
            # Tuples serialize as JSON arrays too; don't let their strings slip through.
            return tuple(XSSFilter.sanitize_json(item) for item in data)
        return data
class XSSProtectionMiddleware(BaseMiddleware):
    """
    Middleware for preventing Cross-Site Scripting (XSS) attacks.

    Requests pass through unchanged. Response bodies of non-exempt responses
    are sanitized according to their Content-Type:
    - JSON bodies go through XSSFilter.sanitize_json;
    - HTML string bodies go through XSSFilter.sanitize_html.
    """

    def __init__(
        self,
        sanitize_json_response: bool = True,
        sanitize_html_response: bool = False,
        allowed_html_tags: Optional[Set[str]] = None,
        allowed_html_attributes: Optional[Dict[str, Set[str]]] = None,
        exempt_paths: Optional[List[str]] = None,
        exempt_content_types: Optional[List[str]] = None,
    ):
        """
        Initialize the XSS protection middleware.

        Args:
            sanitize_json_response: Whether to sanitize JSON responses
            sanitize_html_response: Whether to sanitize HTML responses
            allowed_html_tags: The allowed HTML tags for sanitization
            allowed_html_attributes: The allowed HTML attributes for sanitization
            exempt_paths: Regex patterns; a request whose path matches one of
                them from its start (re.match) is not sanitized
            exempt_content_types: Substrings; a response whose Content-Type
                contains one of them (case-insensitively) is not sanitized
        """
        super().__init__()
        self.sanitize_json_response = sanitize_json_response
        self.sanitize_html_response = sanitize_html_response
        self.allowed_html_tags = allowed_html_tags
        self.allowed_html_attributes = allowed_html_attributes
        self.exempt_paths = exempt_paths or []
        self.exempt_content_types = exempt_content_types or []

    def _is_exempt(self, request: Optional[Request], response: Response) -> bool:
        """
        Check if a request/response is exempt from XSS protection.

        Args:
            request: The request, or None when the response carries none (the
                path check is then skipped)
            response: The response

        Returns:
            True if the request/response is exempt, False otherwise
        """
        if request is not None:
            for pattern in self.exempt_paths:
                if re.match(pattern, request.path):
                    return True
        # Media types are case-insensitive, so compare in lower case.
        content_type = (response.headers.get("Content-Type") or "").lower()
        return any(exempt.lower() in content_type for exempt in self.exempt_content_types)

    @staticmethod
    def _is_json_content_type(content_type: str) -> bool:
        """Return True for application/json and structured-syntax ``+json`` types (lower-cased input)."""
        return "application/json" in content_type or "+json" in content_type

    def _sanitize_response(self, response: Response) -> None:
        """
        Sanitize a response body in place to prevent XSS attacks.

        Args:
            response: The response to sanitize
        """
        content_type = (response.headers.get("Content-Type") or "").lower()
        body = getattr(response, "body", None)
        if self.sanitize_json_response and self._is_json_content_type(content_type):
            # NOTE(review): if body can be an already-serialized JSON str, escaping it
            # would also escape its structural quotes; confirm body holds Python data.
            if body:
                response.body = XSSFilter.sanitize_json(body)
        elif self.sanitize_html_response and "text/html" in content_type:
            if isinstance(body, str):
                response.body = XSSFilter.sanitize_html(
                    body,
                    allowed_tags=self.allowed_html_tags,
                    allowed_attributes=self.allowed_html_attributes,
                )

    async def process_request(self, request: Request) -> Request:
        """
        Process a request through the XSS protection middleware.

        Args:
            request: The request to process

        Returns:
            The request, unchanged
        """
        # This middleware doesn't modify the request, just passes it through
        return request

    async def process_response(self, response: Response) -> Response:
        """
        Process a response through the XSS protection middleware.

        Args:
            response: The response to process

        Returns:
            The same response, with its body sanitized unless exempt
        """
        # NOTE(review): assumes the framework attaches the originating request as
        # response.request; if it is missing, only the content-type exemption applies.
        request = getattr(response, "request", None)
        if not self._is_exempt(request, response):
            self._sanitize_response(response)
        return response