diff --git a/python/x402/src/x402/fastapi/middleware.py b/python/x402/src/x402/fastapi/middleware.py index c16f74c0d..3fa3acdeb 100644 --- a/python/x402/src/x402/fastapi/middleware.py +++ b/python/x402/src/x402/fastapi/middleware.py @@ -1,3 +1,4 @@ +import asyncio import base64 import json import logging @@ -5,7 +6,6 @@ from fastapi import Request from fastapi.responses import JSONResponse, HTMLResponse -from pydantic import validate_call from x402.common import ( process_price_to_atomic_amount, @@ -24,17 +24,72 @@ PaywallConfig, SupportedNetworks, HTTPInputSchema, + PriceOrHook, + StringOrHook, ) logger = logging.getLogger(__name__) -@validate_call +async def _resolve_value( + value: Any, + request: Request, + timeout: float = 5.0, + field_name: str = "value", +) -> Any: + """ + Resolve a value that could be a static value or an async hook. + + Args: + value: Static value or async callable hook + request: The incoming request object + timeout: Maximum seconds to wait for hook execution (default: 5.0) + field_name: Name of the field being resolved (for error messages) + + Returns: + The resolved value + + Raises: + asyncio.TimeoutError: If hook execution exceeds timeout + ValueError: If hook returns invalid value + """ + if callable(value): + try: + return await asyncio.wait_for(value(request), timeout=timeout) + except asyncio.TimeoutError: + logger.error(f"Hook for '{field_name}' timed out after {timeout}s") + raise + except Exception as e: + logger.error(f"Hook for '{field_name}' failed: {e}") + raise + return value + + +def _validate_and_process_price(price: Price, network: str): + """ + Validate and process a price into atomic amounts. + + Args: + price: The price to validate and process + network: The blockchain network + + Returns: + Tuple of (max_amount_required, asset_address, eip712_domain) + + Raises: + ValueError: If price is invalid + """ + try: + return process_price_to_atomic_amount(price, network) + except Exception as e: + raise ValueError(f"Invalid price: {price}. Error: {e}") + + def require_payment( - price: Price, + price: PriceOrHook, pay_to_address: str, path: str | list[str] = "*", - description: str = "", + description: StringOrHook = "", mime_type: str = "", max_deadline_seconds: int = 60, input_schema: Optional[HTTPInputSchema] = None, @@ -42,19 +97,24 @@ def require_payment( discoverable: Optional[bool] = True, facilitator_config: Optional[FacilitatorConfig] = None, network: str = "base-sepolia", - resource: Optional[str] = None, + resource: Optional[StringOrHook] = None, paywall_config: Optional[PaywallConfig] = None, custom_paywall_html: Optional[str] = None, ): """Generate a FastAPI middleware that gates payments for an endpoint. + This middleware supports both static payment requirements and dynamic requirements + via async hooks that compute values at runtime based on the request context. + Args: - price (Price): Payment price. Can be: - - Money: USD amount as string/int (e.g., "$3.10", 0.10, "0.001") - defaults to USDC - - TokenAmount: Custom token amount with asset information + price: Payment price. Can be: + - Static: Money (USD string/int like "$3.10", 0.10, "0.001") or TokenAmount + - Dynamic: async def get_price(request: Request) -> Price pay_to_address (str): Ethereum address to receive the payment path (str | list[str], optional): Path to gate with payments. Defaults to "*" for all paths. - description (str, optional): Description of what is being purchased. Defaults to "". + description: Human-readable description. Can be: + - Static: "Access to premium content" + - Dynamic: async def get_description(request: Request) -> str mime_type (str, optional): MIME type of the resource. Defaults to "". max_deadline_seconds (int, optional): Maximum time allowed for payment. Defaults to 60. input_schema (Optional[HTTPInputSchema], optional): Schema for the request structure. Defaults to None. @@ -63,13 +123,48 @@ def require_payment( facilitator_config (Optional[Dict[str, Any]], optional): Configuration for the payment facilitator. If not provided, defaults to the public x402.org facilitator. network (str, optional): Ethereum network ID. Defaults to "base-sepolia" (Base Sepolia testnet). - resource (Optional[str], optional): Resource URL. Defaults to None (uses request URL). + resource: Resource identifier. Can be: + - Static: "https://example.com/resource" + - Dynamic: async def get_resource(request: Request) -> str + - None: defaults to request.url paywall_config (Optional[PaywallConfig], optional): Configuration for paywall UI customization. Includes options like cdp_client_key, app_name, app_logo, session_token_endpoint. custom_paywall_html (Optional[str], optional): Custom HTML to display for paywall instead of default. Returns: Callable: FastAPI middleware function that checks for valid payment before processing requests + + Raises: + ValueError: If price or network configuration is invalid + + Example - Static pricing: + >>> app.middleware("http")( + ... require_payment( + ... price="$1.00", + ... pay_to_address="0x...", + ... description="Access to API" + ... ) + ... ) + + Example - Dynamic pricing: + >>> async def get_price(request: Request) -> str: + ... if "premium" in request.url.path: + ... return "$10.00" + ... return "$1.00" + >>> + >>> app.middleware("http")( + ... require_payment( + ... price=get_price, + ... pay_to_address="0x...", + ... description="Dynamic content pricing" + ... ) + ... ) + + Note: + - Async hooks have a 5-second timeout by default + - Hook failures return 500 Internal Server Error + - Static prices are validated at middleware creation time + - Dynamic prices are validated per-request """ # Validate network is supported @@ -79,22 +174,62 @@ def require_payment( f"Unsupported network: {network}. Must be one of: {supported_networks}" ) - try: - max_amount_required, asset_address, eip712_domain = ( - process_price_to_atomic_amount(price, network) - ) - except Exception as e: - raise ValueError(f"Invalid price: {price}. Error: {e}") + # Fail-fast for static price if it's not a hook + if not callable(price): + _validate_and_process_price(price, network) facilitator = FacilitatorClient(facilitator_config) + # Cache which values are hooks for performance optimization + is_price_hook = callable(price) + is_description_hook = callable(description) + is_resource_hook = callable(resource) + async def middleware(request: Request, call_next: Callable): # Skip if the path is not the same as the path in the middleware if not path_is_match(path, request.url.path): return await call_next(request) - # Get resource URL if not explicitly provided - resource_url = resource or str(request.url) + # Resolve dynamic values with optimized conditional execution + try: + if is_price_hook: + current_price = await _resolve_value( + price, request, field_name="price" + ) + else: + current_price = price + + if is_description_hook: + current_description = await _resolve_value( + description, request, field_name="description" + ) + else: + current_description = description + + if is_resource_hook: + current_resource = await _resolve_value( + resource, request, field_name="resource" + ) + else: + current_resource = resource or str(request.url) + + max_amount_required, asset_address, eip712_domain = ( + _validate_and_process_price(current_price, network) + ) + except asyncio.TimeoutError: + logger.error("Payment requirement hook timed out") + return JSONResponse( + status_code=500, + content={"error": "Request timeout processing payment requirements"}, + ) + except Exception as e: + logger.error(f"Failed to resolve payment requirements: {e}") + return JSONResponse( + status_code=500, + content={ + "error": "Internal server error resolving payment requirements" + }, + ) # Construct payment details payment_requirements = [ @@ -103,8 +238,8 @@ async def middleware(request: Request, call_next: Callable): network=cast(SupportedNetworks, network), asset=asset_address, max_amount_required=max_amount_required, - resource=resource_url, - description=description, + resource=current_resource, + description=current_description, mime_type=mime_type, pay_to=pay_to_address, max_timeout_seconds=max_deadline_seconds, diff --git a/python/x402/src/x402/types.py b/python/x402/src/x402/types.py index 236417aff..bc75c6ff3 100644 --- a/python/x402/src/x402/types.py +++ b/python/x402/src/x402/types.py @@ -2,13 +2,14 @@ from datetime import datetime from enum import Enum -from typing import Any, Optional, Union, Dict, Literal, List +from typing import Any, Optional, Union, Dict, Literal, List, Callable, Awaitable from typing_extensions import ( TypedDict, ) # use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12 from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic.alias_generators import to_camel +from starlette.requests import Request from x402.networks import SupportedNetworks @@ -92,6 +93,12 @@ class EIP712Domain(BaseModel): Money = Union[str, int] # e.g., "$0.01", 0.01, "0.001" Price = Union[Money, TokenAmount] +# Dynamic Hook types for middleware +PriceHook = Callable[[Request], Awaitable[Price]] +StringHook = Callable[[Request], Awaitable[str]] +PriceOrHook = Union[Price, PriceHook] +StringOrHook = Union[str, StringHook] + class PaymentRequirements(BaseModel): scheme: str diff --git a/python/x402/tests/fastapi_tests/test_dynamic_hooks.py b/python/x402/tests/fastapi_tests/test_dynamic_hooks.py new file mode 100644 index 000000000..fec97009e --- /dev/null +++ b/python/x402/tests/fastapi_tests/test_dynamic_hooks.py @@ -0,0 +1,260 @@ +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient +from x402.fastapi.middleware import require_payment + + +async def mock_endpoint(): + return {"message": "success"} + + +def test_dynamic_price_hook(): + app = FastAPI() + + async def get_price(request: Request): + # Manual path parsing since path_params are not available in middleware + if "/items/premium" in request.url.path: + return "$10.00" + return "$1.00" + + app.get("/items/{item_id}")(mock_endpoint) + app.middleware("http")( + require_payment( + price=get_price, + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + description="Dynamic item", + ) + ) + + client = TestClient(app) + + # Test standard item + response = client.get("/items/standard") + assert response.status_code == 402 + # 1.00 USD on base-sepolia (USDC 6 decimals) -> 1,000,000 + assert response.json()["accepts"][0]["maxAmountRequired"] == "1000000" + + # Test premium item + response = client.get("/items/premium") + assert response.status_code == 402 + assert response.json()["accepts"][0]["maxAmountRequired"] == "10000000" + + +def test_dynamic_description_hook(): + app = FastAPI() + + async def get_desc(request: Request): + item_id = request.url.path.split("/")[-1] + return f"Buying item {item_id}" + + app.get("/items/{item_id}")(mock_endpoint) + app.middleware("http")( + require_payment( + price="$1.00", + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + description=get_desc, + ) + ) + + client = TestClient(app) + + response = client.get("/items/apple") + assert response.status_code == 402 + assert response.json()["accepts"][0]["description"] == "Buying item apple" + + response = client.get("/items/orange") + assert response.status_code == 402 + assert response.json()["accepts"][0]["description"] == "Buying item orange" + + +def test_hook_failure_returns_500(): + app = FastAPI() + + async def failing_hook(request: Request): + raise ValueError("Something went wrong") + + app.get("/test")(mock_endpoint) + app.middleware("http")( + require_payment( + price=failing_hook, + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + ) + ) + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 500 + assert "error" in response.json() + + +def test_dynamic_resource_hook(): + """Test that resource can be dynamically generated from request.""" + app = FastAPI() + + async def get_resource(request: Request): + item_id = request.url.path.split("/")[-1] + return f"https://example.com/items/{item_id}" + + app.get("/items/{item_id}")(mock_endpoint) + app.middleware("http")( + require_payment( + price="$1.00", + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + resource=get_resource, + ) + ) + + client = TestClient(app) + + response = client.get("/items/abc123") + assert response.status_code == 402 + assert ( + response.json()["accepts"][0]["resource"] + == "https://example.com/items/abc123" + ) + + +def test_hook_with_invalid_price_format(): + """Test that hooks returning invalid price formats are handled properly.""" + app = FastAPI() + + async def invalid_price_hook(request: Request): + return "not-a-valid-price" + + app.get("/test")(mock_endpoint) + app.middleware("http")( + require_payment( + price=invalid_price_hook, + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + ) + ) + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 500 + assert "error" in response.json() + + +def test_hook_timeout(): + """Test that hooks exceeding timeout are handled properly.""" + import asyncio + + app = FastAPI() + + async def slow_hook(request: Request): + await asyncio.sleep(10) + return "$1.00" + + app.get("/test")(mock_endpoint) + app.middleware("http")( + require_payment( + price=slow_hook, + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + ) + ) + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 500 + assert "error" in response.json() + assert "timeout" in response.json()["error"].lower() + + +def test_hook_returns_none(): + """Test that hooks returning None are handled properly.""" + app = FastAPI() + + async def none_hook(request: Request): + return None + + app.get("/test")(mock_endpoint) + app.middleware("http")( + require_payment( + price=none_hook, + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + ) + ) + + client = TestClient(app) + response = client.get("/test") + assert response.status_code == 500 + + +def test_concurrent_requests_with_different_prices(): + """Test that concurrent requests get correct independent pricing.""" + import asyncio + from concurrent.futures import ThreadPoolExecutor + + app = FastAPI() + + async def get_price(request: Request): + if "expensive" in request.url.path: + return "$100.00" + return "$1.00" + + app.get("/items/{item_id}")(mock_endpoint) + app.middleware("http")( + require_payment( + price=get_price, + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + ) + ) + + client = TestClient(app) + + def make_request(path): + return client.get(path) + + with ThreadPoolExecutor(max_workers=2) as executor: + future1 = executor.submit(make_request, "/items/cheap") + future2 = executor.submit(make_request, "/items/expensive") + + response1 = future1.result() + response2 = future2.result() + + assert response1.status_code == 402 + assert response2.status_code == 402 + assert response1.json()["accepts"][0]["maxAmountRequired"] == "1000000" + assert response2.json()["accepts"][0]["maxAmountRequired"] == "100000000" + + +def test_all_hooks_combined(): + """Test using price, description, and resource hooks simultaneously.""" + app = FastAPI() + + async def get_price(request: Request): + return "$5.00" + + async def get_description(request: Request): + item = request.url.path.split("/")[-1] + return f"Payment for {item}" + + async def get_resource(request: Request): + return f"custom-resource://{request.url.path}" + + app.get("/items/{item_id}")(mock_endpoint) + app.middleware("http")( + require_payment( + price=get_price, + pay_to_address="0x1111111111111111111111111111111111111111", + network="base-sepolia", + description=get_description, + resource=get_resource, + ) + ) + + client = TestClient(app) + response = client.get("/items/widget") + + assert response.status_code == 402 + payment_info = response.json()["accepts"][0] + assert payment_info["maxAmountRequired"] == "5000000" + assert payment_info["description"] == "Payment for widget" + assert payment_info["resource"] == "custom-resource:///items/widget" diff --git a/python/x402/tests/fastapi_tests/test_middleware.py b/python/x402/tests/fastapi_tests/test_middleware.py index 15c64706c..af5e589cf 100644 --- a/python/x402/tests/fastapi_tests/test_middleware.py +++ b/python/x402/tests/fastapi_tests/test_middleware.py @@ -4,13 +4,13 @@ from x402.types import PaywallConfig -async def test_endpoint(): +async def mock_endpoint(): return {"message": "success"} def test_middleware_invalid_payment(): app_with_middleware = FastAPI() - app_with_middleware.get("/test")(test_endpoint) + app_with_middleware.get("/test")(mock_endpoint) app_with_middleware.middleware("http")( require_payment( price="$1.00", @@ -30,8 +30,8 @@ def test_middleware_invalid_payment(): def test_app_middleware_path_matching(): app_with_middleware = FastAPI() - app_with_middleware.get("/test")(test_endpoint) - app_with_middleware.get("/unprotected")(test_endpoint) + app_with_middleware.get("/test")(mock_endpoint) + app_with_middleware.get("/unprotected")(mock_endpoint) app_with_middleware.middleware("http")( require_payment( @@ -57,9 +57,9 @@ def test_app_middleware_path_matching(): def test_middleware_path_list_matching(): app_with_middleware = FastAPI() - app_with_middleware.get("/test1")(test_endpoint) - app_with_middleware.get("/test2")(test_endpoint) - app_with_middleware.get("/unprotected")(test_endpoint) + app_with_middleware.get("/test1")(mock_endpoint) + app_with_middleware.get("/test2")(mock_endpoint) + app_with_middleware.get("/unprotected")(mock_endpoint) app_with_middleware.middleware("http")( require_payment( @@ -320,7 +320,7 @@ def test_abusive_url_paths(): def test_browser_request_returns_html(): """Test that browser requests return HTML paywall instead of JSON.""" app = FastAPI() - app.get("/protected")(test_endpoint) + app.get("/protected")(mock_endpoint) app.middleware("http")( require_payment( price="$1.00", @@ -350,7 +350,7 @@ def test_browser_request_returns_html(): def test_api_client_request_returns_json(): """Test that API client requests return JSON response.""" app = FastAPI() - app.get("/protected")(test_endpoint) + app.get("/protected")(mock_endpoint) app.middleware("http")( require_payment( price="$1.00", @@ -386,7 +386,7 @@ def test_paywall_config_injection(): } app = FastAPI() - app.get("/protected")(test_endpoint) + app.get("/protected")(mock_endpoint) app.middleware("http")( require_payment( price="$2.50", @@ -431,7 +431,7 @@ def test_custom_paywall_html(): """ app = FastAPI() - app.get("/protected")(test_endpoint) + app.get("/protected")(mock_endpoint) app.middleware("http")( require_payment( price="$1.00", @@ -462,7 +462,7 @@ def test_mainnet_vs_testnet_config(): """Test that mainnet vs testnet is properly configured.""" # Test testnet (base-sepolia) app_testnet = FastAPI() - app_testnet.get("/protected")(test_endpoint) + app_testnet.get("/protected")(mock_endpoint) app_testnet.middleware("http")( require_payment( price="$1.00", @@ -474,7 +474,7 @@ def test_mainnet_vs_testnet_config(): # Test mainnet (base) app_mainnet = FastAPI() - app_mainnet.get("/protected")(test_endpoint) + app_mainnet.get("/protected")(mock_endpoint) app_mainnet.middleware("http")( require_payment( price="$1.00", @@ -508,7 +508,7 @@ def test_mainnet_vs_testnet_config(): def test_payment_amount_conversion(): """Test that payment amounts are properly converted to display values.""" app = FastAPI() - app.get("/protected")(test_endpoint) + app.get("/protected")(mock_endpoint) app.middleware("http")( require_payment( price="$0.001", # Small amount