diff --git a/.changeset/economic-teal-albatross.md b/.changeset/economic-teal-albatross.md new file mode 100644 index 0000000..74b3825 --- /dev/null +++ b/.changeset/economic-teal-albatross.md @@ -0,0 +1,5 @@ +--- +"stagehand": patch +--- + +fix camelCase and snake_case return api extract schema mismatch diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py index 89e37b8..ec9b9fa 100644 --- a/stagehand/handlers/extract_handler.py +++ b/stagehand/handlers/extract_handler.py @@ -13,7 +13,11 @@ ExtractOptions, ExtractResult, ) -from stagehand.utils import inject_urls, transform_url_strings_to_ids +from stagehand.utils import ( + convert_dict_keys_to_snake_case, + inject_urls, + transform_url_strings_to_ids, +) T = TypeVar("T", bound=BaseModel) @@ -150,13 +154,21 @@ async def extract( if schema and isinstance( raw_data_dict, dict ): # schema is the Pydantic model type + # Try direct validation first try: validated_model_instance = schema.model_validate(raw_data_dict) - processed_data_payload = validated_model_instance # Payload is now the Pydantic model instance - except Exception as e: - self.logger.error( - f"Failed to validate extracted data against schema {schema.__name__}: {e}. Keeping raw data dict in .data field." - ) + processed_data_payload = validated_model_instance + except Exception as first_error: + # Fallback: attempt camelCase→snake_case key normalization, then re-validate + try: + normalized = convert_dict_keys_to_snake_case(raw_data_dict) + validated_model_instance = schema.model_validate(normalized) + processed_data_payload = validated_model_instance + except Exception as second_error: + self.logger.error( + f"Failed to validate extracted data against schema {schema.__name__}: {first_error}. " + f"Normalization retry also failed: {second_error}. Keeping raw data dict in .data field." + ) # Create ExtractResult object result = ExtractResult( diff --git a/stagehand/page.py b/stagehand/page.py index c3e8831..3f2738e 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -19,6 +19,7 @@ ObserveResult, ) from .types import DefaultExtractSchema, EmptyExtractSchema +from .utils import convert_dict_keys_to_snake_case _INJECTION_SCRIPT = None @@ -412,10 +413,26 @@ async def extract( processed_data_payload ) processed_data_payload = validated_model - except Exception as e: - self._stagehand.logger.error( - f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field." - ) + except Exception as first_error: + # Fallback: normalize keys to snake_case and try once more + try: + normalized = convert_dict_keys_to_snake_case( + processed_data_payload + ) + if not options_obj: + validated_model = EmptyExtractSchema.model_validate( + normalized + ) + else: + validated_model = schema_to_validate_with.model_validate( + normalized + ) + processed_data_payload = validated_model + except Exception as second_error: + self._stagehand.logger.error( + f"Failed to validate extracted data against schema {getattr(schema_to_validate_with, '__name__', str(schema_to_validate_with))}: {first_error}. " + f"Normalization retry also failed: {second_error}. Keeping raw data dict in .data field." + ) return ExtractResult(data=processed_data_payload).data # Handle unexpected return types self._stagehand.logger.info( diff --git a/stagehand/utils.py b/stagehand/utils.py index 37f4978..4c1bb85 100644 --- a/stagehand/utils.py +++ b/stagehand/utils.py @@ -55,6 +55,46 @@ def convert_dict_keys_to_camel_case(data: dict[str, Any]) -> dict[str, Any]: return result +def camel_to_snake(camel_str: str) -> str: + """ + Convert a camelCase or PascalCase string to snake_case. + + Args: + camel_str: The camelCase/PascalCase string to convert + + Returns: + The converted snake_case string + """ + result_chars = [] + for index, char in enumerate(camel_str): + if char.isupper() and index != 0 and (not camel_str[index - 1].isupper()): + result_chars.append("_") + result_chars.append(char.lower()) + return "".join(result_chars) + + +def convert_dict_keys_to_snake_case(data: Any) -> Any: + """ + Convert all dictionary keys from camelCase/PascalCase to snake_case. + Works recursively for nested dictionaries and lists. Non-dict/list inputs are returned as-is. + + Args: + data: Potentially nested structure with dictionaries/lists + + Returns: + A new structure with all dict keys converted to snake_case + """ + if isinstance(data, dict): + converted: dict[str, Any] = {} + for key, value in data.items(): + converted_key = camel_to_snake(key) if isinstance(key, str) else key + converted[converted_key] = convert_dict_keys_to_snake_case(value) + return converted + if isinstance(data, list): + return [convert_dict_keys_to_snake_case(item) for item in data] + return data + + def format_simplified_tree(node: AccessibilityNode, level: int = 0) -> str: """Formats a node and its children into a simplified string representation.""" indent = " " * level diff --git a/tests/e2e/test_extract_casing_normalization.py b/tests/e2e/test_extract_casing_normalization.py new file mode 100644 index 0000000..be19600 --- /dev/null +++ b/tests/e2e/test_extract_casing_normalization.py @@ -0,0 +1,124 @@ +""" +E2E tests to ensure extract returns validate into snake_case Pydantic schemas +for both LOCAL and BROWSERBASE environments, covering API responses that may +use camelCase keys. +""" + +import os +import pytest +import pytest_asyncio +from urllib.parse import urlparse +from pydantic import BaseModel, Field, HttpUrl + +from stagehand import Stagehand, StagehandConfig +from stagehand.schemas import ExtractOptions + + +class Company(BaseModel): + company_name: str = Field(..., description="The name of the company") + company_url: HttpUrl = Field(..., description="The URL of the company website or relevant page") + + +class Companies(BaseModel): + companies: list[Company] = Field(..., description="List of companies extracted from the page") + + +@pytest.fixture(scope="class") +def local_config(): + return StagehandConfig( + env="LOCAL", + model_name="gpt-4o-mini", + headless=True, + verbose=1, + dom_settle_timeout_ms=2000, + model_client_options={ + "apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY") + }, + ) + + +@pytest.fixture(scope="class") +def browserbase_config(): + return StagehandConfig( + env="BROWSERBASE", + api_key=os.getenv("BROWSERBASE_API_KEY"), + project_id=os.getenv("BROWSERBASE_PROJECT_ID"), + model_name="gpt-4o", + headless=False, + verbose=2, + dom_settle_timeout_ms=3000, + model_client_options={ + "apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY") + }, + ) + + +@pytest_asyncio.fixture +async def local_stagehand(local_config): + stagehand = Stagehand(config=local_config) + await stagehand.init() + yield stagehand + await stagehand.close() + + +@pytest_asyncio.fixture +async def browserbase_stagehand(browserbase_config): + if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")): + pytest.skip("Browserbase credentials not available") + stagehand = Stagehand(config=browserbase_config) + await stagehand.init() + yield stagehand + await stagehand.close() + + +@pytest.mark.asyncio +@pytest.mark.local +async def test_extract_companies_casing_local(local_stagehand): + stagehand = local_stagehand + # Use stable eval site for consistency + await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/") + + extract_options = ExtractOptions( + instruction="Extract the names and URLs of up to 5 companies in batch 3", + schema_definition=Companies, + ) + + result = await stagehand.page.extract(extract_options) + + # Should be validated into our snake_case Pydantic model + assert isinstance(result, Companies) + assert 0 < len(result.companies) <= 5 + for c in result.companies: + assert isinstance(c.company_name, str) and c.company_name + # Avoid isinstance checks with Pydantic's Annotated types; validate via parsing + parsed = urlparse(str(c.company_url)) + assert parsed.scheme in ("http", "https") and bool(parsed.netloc) + + +@pytest.mark.asyncio +@pytest.mark.api +@pytest.mark.skipif( + not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), + reason="Browserbase credentials not available", +) +async def test_extract_companies_casing_browserbase(browserbase_stagehand): + stagehand = browserbase_stagehand + # Use stable eval site for consistency + await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/") + + extract_options = ExtractOptions( + instruction="Extract the names and URLs of up to 5 companies in batch 3", + schema_definition=Companies, + ) + + result = await stagehand.page.extract(extract_options) + + # Should be validated into our snake_case Pydantic model even if API returns camelCase + assert isinstance(result, Companies) + assert 0 < len(result.companies) <= 5 + for c in result.companies: + assert isinstance(c.company_name, str) and c.company_name + parsed = urlparse(str(c.company_url)) + assert parsed.scheme in ("http", "https") and bool(parsed.netloc) + +