Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 155 additions & 20 deletions python/x402/src/x402/fastapi/middleware.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import base64
import json
import logging
from typing import Any, Callable, Optional, get_args, cast

from fastapi import Request
from fastapi.responses import JSONResponse, HTMLResponse
from pydantic import validate_call

from x402.common import (
process_price_to_atomic_amount,
Expand All @@ -24,37 +24,97 @@
PaywallConfig,
SupportedNetworks,
HTTPInputSchema,
PriceOrHook,
StringOrHook,
)

logger = logging.getLogger(__name__)


@validate_call
async def _resolve_value(
value: Any,
request: Request,
timeout: float = 5.0,
field_name: str = "value",
) -> Any:
"""
Resolve a value that could be a static value or an async hook.

Args:
value: Static value or async callable hook
request: The incoming request object
timeout: Maximum seconds to wait for hook execution (default: 5.0)
field_name: Name of the field being resolved (for error messages)

Returns:
The resolved value

Raises:
asyncio.TimeoutError: If hook execution exceeds timeout
ValueError: If hook returns invalid value
"""
if callable(value):
try:
return await asyncio.wait_for(value(request), timeout=timeout)
except asyncio.TimeoutError:
logger.error(f"Hook for '{field_name}' timed out after {timeout}s")
raise
except Exception as e:
logger.error(f"Hook for '{field_name}' failed: {e}")
raise
return value


def _validate_and_process_price(price: Price, network: str):
"""
Validate and process a price into atomic amounts.

Args:
price: The price to validate and process
network: The blockchain network

Returns:
Tuple of (max_amount_required, asset_address, eip712_domain)

Raises:
ValueError: If price is invalid
"""
try:
return process_price_to_atomic_amount(price, network)
except Exception as e:
raise ValueError(f"Invalid price: {price}. Error: {e}")


def require_payment(
price: Price,
price: PriceOrHook,
pay_to_address: str,
path: str | list[str] = "*",
description: str = "",
description: StringOrHook = "",
mime_type: str = "",
max_deadline_seconds: int = 60,
input_schema: Optional[HTTPInputSchema] = None,
output_schema: Optional[Any] = None,
discoverable: Optional[bool] = True,
facilitator_config: Optional[FacilitatorConfig] = None,
network: str = "base-sepolia",
resource: Optional[str] = None,
resource: Optional[StringOrHook] = None,
paywall_config: Optional[PaywallConfig] = None,
custom_paywall_html: Optional[str] = None,
):
"""Generate a FastAPI middleware that gates payments for an endpoint.

This middleware supports both static payment requirements and dynamic requirements
via async hooks that compute values at runtime based on the request context.

Args:
price (Price): Payment price. Can be:
- Money: USD amount as string/int (e.g., "$3.10", 0.10, "0.001") - defaults to USDC
- TokenAmount: Custom token amount with asset information
price: Payment price. Can be:
- Static: Money (USD string/int like "$3.10", 0.10, "0.001") or TokenAmount
- Dynamic: async def get_price(request: Request) -> Price
pay_to_address (str): Ethereum address to receive the payment
path (str | list[str], optional): Path to gate with payments. Defaults to "*" for all paths.
description (str, optional): Description of what is being purchased. Defaults to "".
description: Human-readable description. Can be:
- Static: "Access to premium content"
- Dynamic: async def get_description(request: Request) -> str
mime_type (str, optional): MIME type of the resource. Defaults to "".
max_deadline_seconds (int, optional): Maximum time allowed for payment. Defaults to 60.
input_schema (Optional[HTTPInputSchema], optional): Schema for the request structure. Defaults to None.
Expand All @@ -63,13 +123,48 @@ def require_payment(
facilitator_config (Optional[Dict[str, Any]], optional): Configuration for the payment facilitator.
If not provided, defaults to the public x402.org facilitator.
network (str, optional): Ethereum network ID. Defaults to "base-sepolia" (Base Sepolia testnet).
resource (Optional[str], optional): Resource URL. Defaults to None (uses request URL).
resource: Resource identifier. Can be:
- Static: "https://example.com/resource"
- Dynamic: async def get_resource(request: Request) -> str
- None: defaults to request.url
paywall_config (Optional[PaywallConfig], optional): Configuration for paywall UI customization.
Includes options like cdp_client_key, app_name, app_logo, session_token_endpoint.
custom_paywall_html (Optional[str], optional): Custom HTML to display for paywall instead of default.

Returns:
Callable: FastAPI middleware function that checks for valid payment before processing requests

Raises:
ValueError: If price or network configuration is invalid

Example - Static pricing:
>>> app.middleware("http")(
... require_payment(
... price="$1.00",
... pay_to_address="0x...",
... description="Access to API"
... )
... )

Example - Dynamic pricing:
>>> async def get_price(request: Request) -> str:
... if "premium" in request.url.path:
... return "$10.00"
... return "$1.00"
>>>
>>> app.middleware("http")(
... require_payment(
... price=get_price,
... pay_to_address="0x...",
... description="Dynamic content pricing"
... )
... )

Note:
- Async hooks have a 5-second timeout by default
- Hook failures return 500 Internal Server Error
- Static prices are validated at middleware creation time
- Dynamic prices are validated per-request
"""

# Validate network is supported
Expand All @@ -79,22 +174,62 @@ def require_payment(
f"Unsupported network: {network}. Must be one of: {supported_networks}"
)

try:
max_amount_required, asset_address, eip712_domain = (
process_price_to_atomic_amount(price, network)
)
except Exception as e:
raise ValueError(f"Invalid price: {price}. Error: {e}")
# Fail-fast for static price if it's not a hook
if not callable(price):
_validate_and_process_price(price, network)

facilitator = FacilitatorClient(facilitator_config)

# Cache which values are hooks for performance optimization
is_price_hook = callable(price)
is_description_hook = callable(description)
is_resource_hook = callable(resource)

async def middleware(request: Request, call_next: Callable):
# Skip if the path is not the same as the path in the middleware
if not path_is_match(path, request.url.path):
return await call_next(request)

# Get resource URL if not explicitly provided
resource_url = resource or str(request.url)
# Resolve dynamic values with optimized conditional execution
try:
if is_price_hook:
current_price = await _resolve_value(
price, request, field_name="price"
)
else:
current_price = price

if is_description_hook:
current_description = await _resolve_value(
description, request, field_name="description"
)
else:
current_description = description

if is_resource_hook:
current_resource = await _resolve_value(
resource, request, field_name="resource"
)
else:
current_resource = resource or str(request.url)

max_amount_required, asset_address, eip712_domain = (
_validate_and_process_price(current_price, network)
)
except asyncio.TimeoutError:
logger.error("Payment requirement hook timed out")
return JSONResponse(
status_code=500,
content={"error": "Request timeout processing payment requirements"},
)
except Exception as e:
logger.error(f"Failed to resolve payment requirements: {e}")
return JSONResponse(
status_code=500,
content={
"error": "Internal server error resolving payment requirements"
},
)

# Construct payment details
payment_requirements = [
Expand All @@ -103,8 +238,8 @@ async def middleware(request: Request, call_next: Callable):
network=cast(SupportedNetworks, network),
asset=asset_address,
max_amount_required=max_amount_required,
resource=resource_url,
description=description,
resource=current_resource,
description=current_description,
mime_type=mime_type,
pay_to=pay_to_address,
max_timeout_seconds=max_deadline_seconds,
Expand Down
9 changes: 8 additions & 1 deletion python/x402/src/x402/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from datetime import datetime
from enum import Enum
from typing import Any, Optional, Union, Dict, Literal, List
from typing import Any, Optional, Union, Dict, Literal, List, Callable, Awaitable
from typing_extensions import (
TypedDict,
) # use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic.alias_generators import to_camel
from starlette.requests import Request

from x402.networks import SupportedNetworks

Expand Down Expand Up @@ -92,6 +93,12 @@ class EIP712Domain(BaseModel):
Money = Union[str, int] # e.g., "$0.01", 0.01, "0.001"
Price = Union[Money, TokenAmount]

# Dynamic Hook types for middleware
PriceHook = Callable[[Request], Awaitable[Price]]
StringHook = Callable[[Request], Awaitable[str]]
PriceOrHook = Union[Price, PriceHook]
StringOrHook = Union[str, StringHook]


class PaymentRequirements(BaseModel):
scheme: str
Expand Down
Loading