5 changes: 5 additions & 0 deletions .changeset/economic-teal-albatross.md
@@ -0,0 +1,5 @@
---
"stagehand": patch
---

fix schema mismatch when the extract API returns camelCase keys for snake_case Pydantic schemas
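To make the mismatch concrete, here is a minimal sketch (the `Item` model and the payload values are purely illustrative): a model declared with snake_case fields rejects a camelCase payload, but validates once the keys are normalized with the `convert_dict_keys_to_snake_case` helper introduced in `stagehand/utils.py` below.

```python
from pydantic import BaseModel, ValidationError

from stagehand.utils import convert_dict_keys_to_snake_case  # added in this patch


class Item(BaseModel):
    # Hypothetical example model with snake_case field names
    company_name: str
    company_url: str


raw = {"companyName": "Browserbase", "companyUrl": "https://www.browserbase.com"}

try:
    Item.model_validate(raw)  # fails: keys are camelCase, fields are snake_case
except ValidationError:
    normalized = convert_dict_keys_to_snake_case(raw)
    item = Item.model_validate(normalized)  # succeeds after key normalization
```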
24 changes: 18 additions & 6 deletions stagehand/handlers/extract_handler.py
@@ -13,7 +13,11 @@
ExtractOptions,
ExtractResult,
)
from stagehand.utils import inject_urls, transform_url_strings_to_ids
from stagehand.utils import (
convert_dict_keys_to_snake_case,
inject_urls,
transform_url_strings_to_ids,
)

T = TypeVar("T", bound=BaseModel)

@@ -150,13 +154,21 @@ async def extract(
if schema and isinstance(
raw_data_dict, dict
): # schema is the Pydantic model type
# Try direct validation first
try:
validated_model_instance = schema.model_validate(raw_data_dict)
processed_data_payload = validated_model_instance # Payload is now the Pydantic model instance
except Exception as e:
self.logger.error(
f"Failed to validate extracted data against schema {schema.__name__}: {e}. Keeping raw data dict in .data field."
)
processed_data_payload = validated_model_instance
except Exception as first_error:
# Fallback: attempt camelCase→snake_case key normalization, then re-validate
try:
normalized = convert_dict_keys_to_snake_case(raw_data_dict)
validated_model_instance = schema.model_validate(normalized)
processed_data_payload = validated_model_instance
except Exception as second_error:
self.logger.error(
f"Failed to validate extracted data against schema {schema.__name__}: {first_error}. "
f"Normalization retry also failed: {second_error}. Keeping raw data dict in .data field."
)

# Create ExtractResult object
result = ExtractResult(
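The handler's new control flow reduces to a small pattern: validate directly, and only if that fails, normalize the keys to snake_case and validate once more, keeping the raw dict when both attempts fail. A minimal sketch of that pattern follows; `validate_with_casing_fallback` is an illustrative name, not the handler's actual method, and logging is omitted.

```python
from typing import Any, Union

from pydantic import BaseModel

from stagehand.utils import convert_dict_keys_to_snake_case


def validate_with_casing_fallback(
    schema: type[BaseModel], raw: dict[str, Any]
) -> Union[BaseModel, dict[str, Any]]:
    """Sketch of the validate, then normalize-and-retry flow used above."""
    try:
        return schema.model_validate(raw)  # direct validation first
    except Exception:
        try:
            # Fallback: camelCase/PascalCase keys -> snake_case, then re-validate
            return schema.model_validate(convert_dict_keys_to_snake_case(raw))
        except Exception:
            return raw  # both attempts failed; keep the raw data dict
```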
25 changes: 21 additions & 4 deletions stagehand/page.py
@@ -19,6 +19,7 @@
ObserveResult,
)
from .types import DefaultExtractSchema, EmptyExtractSchema
from .utils import convert_dict_keys_to_snake_case

_INJECTION_SCRIPT = None

@@ -412,10 +413,26 @@ async def extract(
processed_data_payload
)
processed_data_payload = validated_model
except Exception as e:
self._stagehand.logger.error(
f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field."
)
except Exception as first_error:
# Fallback: normalize keys to snake_case and try once more
try:
normalized = convert_dict_keys_to_snake_case(
processed_data_payload
)
if not options_obj:
validated_model = EmptyExtractSchema.model_validate(
normalized
)
else:
validated_model = schema_to_validate_with.model_validate(
normalized
)
processed_data_payload = validated_model
except Exception as second_error:
self._stagehand.logger.error(
f"Failed to validate extracted data against schema {getattr(schema_to_validate_with, '__name__', str(schema_to_validate_with))}: {first_error}. "
f"Normalization retry also failed: {second_error}. Keeping raw data dict in .data field."
)
return ExtractResult(data=processed_data_payload).data
# Handle unexpected return types
self._stagehand.logger.info(
40 changes: 40 additions & 0 deletions stagehand/utils.py
Expand Up @@ -55,6 +55,46 @@ def convert_dict_keys_to_camel_case(data: dict[str, Any]) -> dict[str, Any]:
return result


def camel_to_snake(camel_str: str) -> str:
"""
Convert a camelCase or PascalCase string to snake_case.

Args:
camel_str: The camelCase/PascalCase string to convert

Returns:
The converted snake_case string
"""
result_chars = []
for index, char in enumerate(camel_str):
if char.isupper() and index != 0 and (not camel_str[index - 1].isupper()):
result_chars.append("_")
result_chars.append(char.lower())
return "".join(result_chars)


def convert_dict_keys_to_snake_case(data: Any) -> Any:
"""
Convert all dictionary keys from camelCase/PascalCase to snake_case.
Works recursively for nested dictionaries and lists. Non-dict/list inputs are returned as-is.

Args:
data: Potentially nested structure with dictionaries/lists

Returns:
A new structure with all dict keys converted to snake_case
"""
if isinstance(data, dict):
converted: dict[str, Any] = {}
for key, value in data.items():
converted_key = camel_to_snake(key) if isinstance(key, str) else key
converted[converted_key] = convert_dict_keys_to_snake_case(value)
return converted
if isinstance(data, list):
return [convert_dict_keys_to_snake_case(item) for item in data]
return data


def format_simplified_tree(node: AccessibilityNode, level: int = 0) -> str:
"""Formats a node and its children into a simplified string representation."""
indent = " " * level
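A quick usage sketch of the two helpers added above (the payload is illustrative): `camel_to_snake` inserts an underscore before an uppercase letter only when the preceding character is not itself uppercase, and `convert_dict_keys_to_snake_case` applies that conversion recursively through nested dicts and lists while leaving values untouched.

```python
from stagehand.utils import camel_to_snake, convert_dict_keys_to_snake_case

assert camel_to_snake("companyName") == "company_name"
assert camel_to_snake("CompanyURL") == "company_url"  # runs of capitals stay grouped
assert camel_to_snake("already_snake") == "already_snake"

payload = {"pageTitle": "Home", "links": [{"linkText": "Docs", "hrefValue": "/docs"}]}
assert convert_dict_keys_to_snake_case(payload) == {
    "page_title": "Home",
    "links": [{"link_text": "Docs", "href_value": "/docs"}],
}
```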
124 changes: 124 additions & 0 deletions tests/e2e/test_extract_casing_normalization.py
@@ -0,0 +1,124 @@
"""
E2E tests to ensure that extract return values validate into snake_case Pydantic
schemas for both LOCAL and BROWSERBASE environments, covering API responses that
may use camelCase keys.
"""

import os
import pytest
import pytest_asyncio
from urllib.parse import urlparse
from pydantic import BaseModel, Field, HttpUrl

from stagehand import Stagehand, StagehandConfig
from stagehand.schemas import ExtractOptions


class Company(BaseModel):
company_name: str = Field(..., description="The name of the company")
company_url: HttpUrl = Field(..., description="The URL of the company website or relevant page")


class Companies(BaseModel):
companies: list[Company] = Field(..., description="List of companies extracted from the page")


@pytest.fixture(scope="class")
def local_config():
return StagehandConfig(
env="LOCAL",
model_name="gpt-4o-mini",
headless=True,
verbose=1,
dom_settle_timeout_ms=2000,
model_client_options={
"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")
},
)


@pytest.fixture(scope="class")
def browserbase_config():
return StagehandConfig(
env="BROWSERBASE",
api_key=os.getenv("BROWSERBASE_API_KEY"),
project_id=os.getenv("BROWSERBASE_PROJECT_ID"),
model_name="gpt-4o",
headless=False,
verbose=2,
dom_settle_timeout_ms=3000,
model_client_options={
"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")
},
)


@pytest_asyncio.fixture
async def local_stagehand(local_config):
stagehand = Stagehand(config=local_config)
await stagehand.init()
yield stagehand
await stagehand.close()


@pytest_asyncio.fixture
async def browserbase_stagehand(browserbase_config):
if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")):
pytest.skip("Browserbase credentials not available")
stagehand = Stagehand(config=browserbase_config)
await stagehand.init()
yield stagehand
await stagehand.close()


@pytest.mark.asyncio
@pytest.mark.local
async def test_extract_companies_casing_local(local_stagehand):
stagehand = local_stagehand
# Use stable eval site for consistency
await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/")

extract_options = ExtractOptions(
instruction="Extract the names and URLs of up to 5 companies in batch 3",
schema_definition=Companies,
)

result = await stagehand.page.extract(extract_options)

# Should be validated into our snake_case Pydantic model
assert isinstance(result, Companies)
assert 0 < len(result.companies) <= 5
for c in result.companies:
assert isinstance(c.company_name, str) and c.company_name
# Avoid isinstance checks with Pydantic's Annotated types; validate via parsing
parsed = urlparse(str(c.company_url))
assert parsed.scheme in ("http", "https") and bool(parsed.netloc)


@pytest.mark.asyncio
@pytest.mark.api
@pytest.mark.skipif(
not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")),
reason="Browserbase credentials not available",
)
async def test_extract_companies_casing_browserbase(browserbase_stagehand):
stagehand = browserbase_stagehand
# Use stable eval site for consistency
await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/")

extract_options = ExtractOptions(
instruction="Extract the names and URLs of up to 5 companies in batch 3",
schema_definition=Companies,
)

result = await stagehand.page.extract(extract_options)

# Should be validated into our snake_case Pydantic model even if API returns camelCase
assert isinstance(result, Companies)
assert 0 < len(result.companies) <= 5
for c in result.companies:
assert isinstance(c.company_name, str) and c.company_name
parsed = urlparse(str(c.company_url))
assert parsed.scheme in ("http", "https") and bool(parsed.netloc)