Skip to content

Commit

Permalink
Improved error reporting for missing web_search() provider environm…
Browse files Browse the repository at this point in the history
…ent variables. (#497)

Co-authored-by: aisi-inspect <166920645+aisi-inspect@users.noreply.github.com>
  • Loading branch information
jjallaire-aisi and aisi-inspect authored Sep 24, 2024
1 parent e8e9257 commit 2831fa6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Unreleased

- Fix issue w/ subtasks not getting a fresh store() (regression from introduction of `fork()` in v0.3.30)
- Improved error reporting for missing `web_search()` provider environment variables.

## v0.3.31 (24 September 2024)

Expand Down
19 changes: 10 additions & 9 deletions src/inspect_ai/tool/_tools/_web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
wait_exponential_jitter,
)

from inspect_ai._util.error import PrerequisiteError
from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt
from inspect_ai.util._concurrency import concurrency

Expand Down Expand Up @@ -55,6 +56,13 @@ def web_search(
# get search client
client = httpx.AsyncClient()

if provider == "google":
search_provider = google_search_provider(client)
else:
raise ValueError(
f"Provider {provider} not supported. Only 'google' is supported."
)

# resolve provider (only google for now)
async def execute(query: str) -> ToolResult:
"""
Expand All @@ -69,13 +77,6 @@ async def execute(query: str) -> ToolResult:
snippets: list[str] = []
search_calls = 0

if provider == "google":
search_provider = google_search_provider(client)
else:
raise Exception(
f"Provider {provider} not supported. Only 'google' is supported."
)

# Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
while len(page_contents) < num_results and search_calls < max_provider_calls:
async with concurrency(f"{provider}_web_search", max_connections):
Expand Down Expand Up @@ -180,8 +181,8 @@ def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
if not google_api_key or not google_cse_id:
raise Exception(
"GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in environment"
raise PrerequisiteError(
"GOOGLE_CSE_ID and/or GOOGLE_CSE_API_KEY not set in the environment. Please ensure these variables are defined to use Google Custom Search with the web_search tool.\n\nLearn more about the Google web search provider at https://inspect.ai-safety-institute.org.uk/tools.html#google-provider"
)

async def search(query: str, start_idx: int) -> list[SearchLink]:
Expand Down

0 comments on commit 2831fa6

Please sign in to comment.