From 2831fa6cee6e981f59696863da6db24f717347cd Mon Sep 17 00:00:00 2001 From: jjallaire-aisi Date: Tue, 24 Sep 2024 16:28:01 +0100 Subject: [PATCH] Improved error reporting for missing `web_search()` provider environment variables. (#497) Co-authored-by: aisi-inspect <166920645+aisi-inspect@users.noreply.github.com> --- CHANGELOG.md | 1 + src/inspect_ai/tool/_tools/_web_search.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e8e0ad324..468510752 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Unreleased - Fix issue w/ subtasks not getting a fresh store() (regression from introduction of `fork()` in v0.3.30) +- Improved error reporting for missing `web_search()` provider environment variables. ## v0.3.31 (24 September 2024) diff --git a/src/inspect_ai/tool/_tools/_web_search.py b/src/inspect_ai/tool/_tools/_web_search.py index 00a27c138..9ba1a3fed 100644 --- a/src/inspect_ai/tool/_tools/_web_search.py +++ b/src/inspect_ai/tool/_tools/_web_search.py @@ -12,6 +12,7 @@ wait_exponential_jitter, ) +from inspect_ai._util.error import PrerequisiteError from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt from inspect_ai.util._concurrency import concurrency @@ -55,6 +56,13 @@ def web_search( # get search client client = httpx.AsyncClient() + if provider == "google": + search_provider = google_search_provider(client) + else: + raise ValueError( + f"Provider {provider} not supported. Only 'google' is supported." + ) + # resolve provider (only google for now) async def execute(query: str) -> ToolResult: """ @@ -69,13 +77,6 @@ async def execute(query: str) -> ToolResult: snippets: list[str] = [] search_calls = 0 - if provider == "google": - search_provider = google_search_provider(client) - else: - raise Exception( - f"Provider {provider} not supported. Only 'google' is supported." - ) - # Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls while len(page_contents) < num_results and search_calls < max_provider_calls: async with concurrency(f"{provider}_web_search", max_connections): @@ -180,8 +181,8 @@ def google_search_provider(client: httpx.AsyncClient) -> SearchProvider: google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None) google_cse_id = os.environ.get("GOOGLE_CSE_ID", None) if not google_api_key or not google_cse_id: - raise Exception( - "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in environment" + raise PrerequisiteError( + "GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.ai-safety-institute.org.uk/tools.html#google-provider" ) async def search(query: str, start_idx: int) -> list[SearchLink]: