Skip to content

Commit ec22cb9

Browse files
Fm/stg 690 stagehand api returns camel instead of snake (#185)
* normalize camel from api * fix test * lint * add changeset
1 parent 61ade28 commit ec22cb9

File tree

5 files changed

+208
-10
lines changed

5 files changed

+208
-10
lines changed

.changeset/economic-teal-albatross.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"stagehand": patch
3+
---
4+
5+
fix mismatch between camelCase keys returned by the API and snake_case extract schemas

stagehand/handlers/extract_handler.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
ExtractOptions,
1414
ExtractResult,
1515
)
16-
from stagehand.utils import inject_urls, transform_url_strings_to_ids
16+
from stagehand.utils import (
17+
convert_dict_keys_to_snake_case,
18+
inject_urls,
19+
transform_url_strings_to_ids,
20+
)
1721

1822
T = TypeVar("T", bound=BaseModel)
1923

@@ -150,13 +154,21 @@ async def extract(
150154
if schema and isinstance(
151155
raw_data_dict, dict
152156
): # schema is the Pydantic model type
157+
# Try direct validation first
153158
try:
154159
validated_model_instance = schema.model_validate(raw_data_dict)
155-
processed_data_payload = validated_model_instance # Payload is now the Pydantic model instance
156-
except Exception as e:
157-
self.logger.error(
158-
f"Failed to validate extracted data against schema {schema.__name__}: {e}. Keeping raw data dict in .data field."
159-
)
160+
processed_data_payload = validated_model_instance
161+
except Exception as first_error:
162+
# Fallback: attempt camelCase→snake_case key normalization, then re-validate
163+
try:
164+
normalized = convert_dict_keys_to_snake_case(raw_data_dict)
165+
validated_model_instance = schema.model_validate(normalized)
166+
processed_data_payload = validated_model_instance
167+
except Exception as second_error:
168+
self.logger.error(
169+
f"Failed to validate extracted data against schema {schema.__name__}: {first_error}. "
170+
f"Normalization retry also failed: {second_error}. Keeping raw data dict in .data field."
171+
)
160172

161173
# Create ExtractResult object
162174
result = ExtractResult(

stagehand/page.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ObserveResult,
2020
)
2121
from .types import DefaultExtractSchema, EmptyExtractSchema
22+
from .utils import convert_dict_keys_to_snake_case
2223

2324
_INJECTION_SCRIPT = None
2425

@@ -412,10 +413,26 @@ async def extract(
412413
processed_data_payload
413414
)
414415
processed_data_payload = validated_model
415-
except Exception as e:
416-
self._stagehand.logger.error(
417-
f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field."
418-
)
416+
except Exception as first_error:
417+
# Fallback: normalize keys to snake_case and try once more
418+
try:
419+
normalized = convert_dict_keys_to_snake_case(
420+
processed_data_payload
421+
)
422+
if not options_obj:
423+
validated_model = EmptyExtractSchema.model_validate(
424+
normalized
425+
)
426+
else:
427+
validated_model = schema_to_validate_with.model_validate(
428+
normalized
429+
)
430+
processed_data_payload = validated_model
431+
except Exception as second_error:
432+
self._stagehand.logger.error(
433+
f"Failed to validate extracted data against schema {getattr(schema_to_validate_with, '__name__', str(schema_to_validate_with))}: {first_error}. "
434+
f"Normalization retry also failed: {second_error}. Keeping raw data dict in .data field."
435+
)
419436
return ExtractResult(data=processed_data_payload).data
420437
# Handle unexpected return types
421438
self._stagehand.logger.info(

stagehand/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,46 @@ def convert_dict_keys_to_camel_case(data: dict[str, Any]) -> dict[str, Any]:
5555
return result
5656

5757

58+
def camel_to_snake(camel_str: str) -> str:
    """
    Convert a camelCase or PascalCase string to snake_case.

    An underscore is inserted before an uppercase letter when it starts a new
    word, i.e. when:

    * the previous character is neither uppercase nor already an underscore
      (``companyName`` -> ``company_name``, ``company_Name`` -> ``company_name``), or
    * it is the last capital of an acronym and is followed by at least two
      lowercase letters (``HTTPResponse`` -> ``http_response``). Requiring two
      lowercase letters keeps plural acronyms together
      (``companyURLs`` -> ``company_urls``).

    Strings that are already snake_case are returned unchanged.

    Args:
        camel_str: The camelCase/PascalCase string to convert

    Returns:
        The converted snake_case string
    """
    result_chars: list[str] = []
    for index, char in enumerate(camel_str):
        if index and char.isupper():
            previous = camel_str[index - 1]
            if previous.isupper():
                # Inside a run of capitals: only break when this capital
                # clearly begins a new word (two lowercase letters follow).
                following = camel_str[index + 1 : index + 3]
                starts_word = len(following) == 2 and all(
                    c.islower() for c in following
                )
            else:
                # Never emit a second underscore after an existing one.
                starts_word = previous != "_"
            if starts_word:
                result_chars.append("_")
        result_chars.append(char.lower())
    return "".join(result_chars)
74+
75+
76+
def convert_dict_keys_to_snake_case(data: Any) -> Any:
    """
    Return a copy of ``data`` whose dictionary keys are snake_case.

    Dictionaries and lists are walked recursively; string keys are passed
    through ``camel_to_snake`` while non-string keys are kept untouched.
    Any value that is neither a dict nor a list is returned as-is.

    Args:
        data: Potentially nested structure with dictionaries/lists

    Returns:
        A new structure with all dict keys converted to snake_case
    """
    if isinstance(data, list):
        return [convert_dict_keys_to_snake_case(element) for element in data]
    if not isinstance(data, dict):
        # Scalars (and any other container type) pass through unchanged.
        return data
    return {
        (
            camel_to_snake(name) if isinstance(name, str) else name
        ): convert_dict_keys_to_snake_case(child)
        for name, child in data.items()
    }
96+
97+
5898
def format_simplified_tree(node: AccessibilityNode, level: int = 0) -> str:
5999
"""Formats a node and its children into a simplified string representation."""
60100
indent = " " * level
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""
2+
E2E tests to ensure extract results validate into snake_case Pydantic schemas
3+
for both LOCAL and BROWSERBASE environments, covering API responses that may
4+
use camelCase keys.
5+
"""
6+
7+
import os
8+
import pytest
9+
import pytest_asyncio
10+
from urllib.parse import urlparse
11+
from pydantic import BaseModel, Field, HttpUrl
12+
13+
from stagehand import Stagehand, StagehandConfig
14+
from stagehand.schemas import ExtractOptions
15+
16+
17+
class Company(BaseModel):
    # Single extracted company. Field names are deliberately snake_case so the
    # tests exercise validation of payloads that may arrive with camelCase keys.
    # (Documented with comments rather than a docstring so the generated
    # schema description is left untouched.)
    company_name: str = Field(
        ...,
        description="The name of the company",
    )
    company_url: HttpUrl = Field(
        ...,
        description="The URL of the company website or relevant page",
    )
20+
21+
22+
class Companies(BaseModel):
    # Top-level extraction schema: a list wrapper around Company entries.
    companies: list[Company] = Field(
        ...,
        description="List of companies extracted from the page",
    )
24+
25+
26+
@pytest.fixture(scope="class")
def local_config():
    """Build the StagehandConfig used for the LOCAL environment test."""
    # MODEL_API_KEY takes precedence; OPENAI_API_KEY is the fallback.
    model_api_key = os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")
    client_options = {"apiKey": model_api_key}
    return StagehandConfig(
        env="LOCAL",
        model_name="gpt-4o-mini",
        headless=True,
        verbose=1,
        dom_settle_timeout_ms=2000,
        model_client_options=client_options,
    )
38+
39+
40+
@pytest.fixture(scope="class")
def browserbase_config():
    """Build the StagehandConfig used for the BROWSERBASE environment test."""
    # MODEL_API_KEY takes precedence; OPENAI_API_KEY is the fallback.
    model_api_key = os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")
    client_options = {"apiKey": model_api_key}
    return StagehandConfig(
        env="BROWSERBASE",
        api_key=os.getenv("BROWSERBASE_API_KEY"),
        project_id=os.getenv("BROWSERBASE_PROJECT_ID"),
        model_name="gpt-4o",
        headless=False,
        verbose=2,
        dom_settle_timeout_ms=3000,
        model_client_options=client_options,
    )
54+
55+
56+
@pytest_asyncio.fixture
async def local_stagehand(local_config):
    """Yield an initialized Stagehand built from ``local_config``; close it afterwards."""
    client = Stagehand(config=local_config)
    await client.init()
    yield client
    # Teardown: runs once the consuming test has finished.
    await client.close()
62+
63+
64+
@pytest_asyncio.fixture
async def browserbase_stagehand(browserbase_config):
    """Yield an initialized Stagehand built from ``browserbase_config``; close it afterwards.

    The consuming test is skipped when either Browserbase credential
    environment variable is missing or empty.
    """
    if not os.getenv("BROWSERBASE_API_KEY") or not os.getenv("BROWSERBASE_PROJECT_ID"):
        pytest.skip("Browserbase credentials not available")
    client = Stagehand(config=browserbase_config)
    await client.init()
    yield client
    # Teardown: runs once the consuming test has finished.
    await client.close()
72+
73+
74+
@pytest.mark.asyncio
@pytest.mark.local
async def test_extract_companies_casing_local(local_stagehand):
    """LOCAL env: extract must return a validated ``Companies`` instance."""
    page = local_stagehand.page
    # Use stable eval site for consistency
    await page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/")

    options = ExtractOptions(
        instruction="Extract the names and URLs of up to 5 companies in batch 3",
        schema_definition=Companies,
    )
    result = await page.extract(options)

    # Should be validated into our snake_case Pydantic model
    assert isinstance(result, Companies)
    assert 0 < len(result.companies) <= 5
    for company in result.companies:
        assert isinstance(company.company_name, str) and company.company_name
        # Avoid isinstance checks with Pydantic's Annotated types; validate via parsing
        url_parts = urlparse(str(company.company_url))
        assert url_parts.scheme in ("http", "https") and bool(url_parts.netloc)
96+
97+
98+
@pytest.mark.asyncio
@pytest.mark.api
@pytest.mark.skipif(
    not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")),
    reason="Browserbase credentials not available",
)
async def test_extract_companies_casing_browserbase(browserbase_stagehand):
    """BROWSERBASE env: extract must return a validated ``Companies`` instance."""
    page = browserbase_stagehand.page
    # Use stable eval site for consistency
    await page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/")

    options = ExtractOptions(
        instruction="Extract the names and URLs of up to 5 companies in batch 3",
        schema_definition=Companies,
    )
    result = await page.extract(options)

    # Should be validated into our snake_case Pydantic model even if API returns camelCase
    assert isinstance(result, Companies)
    assert 0 < len(result.companies) <= 5
    for company in result.companies:
        assert isinstance(company.company_name, str) and company.company_name
        url_parts = urlparse(str(company.company_url))
        assert url_parts.scheme in ("http", "https") and bool(url_parts.netloc)
123+
124+

0 commit comments

Comments
 (0)