diff --git a/python/.env.example b/python/.env.example index 7d9a407dc877..c3820e5a72ff 100644 --- a/python/.env.example +++ b/python/.env.example @@ -20,6 +20,7 @@ POSTGRES_CONNECTION_STRING="" WEAVIATE_URL="" WEAVIATE_API_KEY="" GOOGLE_SEARCH_ENGINE_ID="" +BRAVE_API_KEY="" REDIS_CONNECTION_STRING="" AZCOSMOS_API="" AZCOSMOS_CONNSTR="" diff --git a/python/samples/concepts/README.md b/python/samples/concepts/README.md index 9f2512aa7561..9f9834889e27 100644 --- a/python/samples/concepts/README.md +++ b/python/samples/concepts/README.md @@ -197,6 +197,7 @@ ### Search - Using [`Search`](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/search) services information - [Bing Text Search as Plugin](./search/bing_text_search_as_plugin.py) +- [Brave Text Search as Plugin](./search/brave_text_search_as_plugin.py) - [Google Text Search as Plugin](./search/google_text_search_as_plugin.py) ### Service Selector - Shows how to create and use a custom service selector class diff --git a/python/samples/concepts/search/brave_text_search_as_plugin.py b/python/samples/concepts/search/brave_text_search_as_plugin.py new file mode 100644 index 000000000000..326ee853283e --- /dev/null +++ b/python/samples/concepts/search/brave_text_search_as_plugin.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import Awaitable, Callable + +from semantic_kernel import Kernel +from semantic_kernel.connectors.ai import FunctionChoiceBehavior +from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAIChatPromptExecutionSettings +from semantic_kernel.connectors.search.brave import BraveSearch +from semantic_kernel.contents import ChatHistory +from semantic_kernel.filters import FilterTypes, FunctionInvocationContext +from semantic_kernel.functions import KernelArguments, KernelParameterMetadata, KernelPlugin + +""" +This project demonstrates how to integrate the Brave Search API as a plugin into the Semantic Kernel +framework to enable conversational AI capabilities with real-time web information. + +To use Brave Search, you need an API key, which can be obtained by login to +https://api-dashboard.search.brave.com/ and creating a subscription key. +After that store it under the name `BRAVE_API_KEY` in a .env file or your environment variables. +""" + +kernel = Kernel() +kernel.add_service(OpenAIChatCompletion(service_id="chat")) +kernel.add_plugin( + KernelPlugin.from_text_search_with_search( + BraveSearch(), + plugin_name="brave", + description="Get details about Semantic Kernel concepts.", + parameters=[ + KernelParameterMetadata( + name="query", + description="The search query.", + type="str", + is_required=True, + type_object=str, + ), + KernelParameterMetadata( + name="top", + description="The number of results to return.", + type="int", + is_required=False, + default_value=2, + type_object=int, + ), + KernelParameterMetadata( + name="skip", + description="The number of results to skip.", + type="int", + is_required=False, + default_value=0, + type_object=int, + ), + ], + ) +) +chat_function = kernel.add_function( + prompt="{{$chat_history}}{{$user_input}}", + plugin_name="ChatBot", + function_name="Chat", +) +execution_settings = OpenAIChatPromptExecutionSettings( + service_id="chat", + max_tokens=2000, + temperature=0.7, + top_p=0.8, + function_choice_behavior=FunctionChoiceBehavior.Auto(auto_invoke=True), +) + +history = ChatHistory() +system_message = """ +You are a chat bot, specialized in Semantic Kernel, Microsoft LLM orchestration SDK. +Assume questions are related to that, and use the Brave search plugin to find answers. +""" +history.add_system_message(system_message) +history.add_user_message("Hi there, who are you?") +history.add_assistant_message("I am Mosscap, a chat bot. I'm trying to figure out what people need.") + +arguments = KernelArguments(settings=execution_settings) + + +@kernel.filter(filter_type=FilterTypes.FUNCTION_INVOCATION) +async def log_brave_filter( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] +): + if context.function.plugin_name == "brave": + print("Calling Brave search with arguments:") + if "query" in context.arguments: + print(f' Query: "{context.arguments["query"]}"') + if "count" in context.arguments: + print(f' Count: "{context.arguments["count"]}"') + if "skip" in context.arguments: + print(f' Skip: "{context.arguments["skip"]}"') + await next(context) + print("Brave search completed.") + else: + await next(context) + + +async def chat() -> bool: + try: + user_input = input("User:> ") + except KeyboardInterrupt: + print("\n\nExiting chat...") + return False + except EOFError: + print("\n\nExiting chat...") + return False + + if user_input == "exit": + print("\n\nExiting chat...") + return False + arguments["user_input"] = user_input + arguments["chat_history"] = history + result = await kernel.invoke(chat_function, arguments=arguments) + print(f"Mosscap:> {result}") + history.add_user_message(user_input) + history.add_assistant_message(str(result)) + return True + + +async def main(): + chatting = True + print( + "Welcome to the chat bot!\ + \n Type 'exit' to exit.\ + \n Try to find out more about the inner workings of Semantic Kernel." + ) + while chatting: + chatting = await chat() + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/python/semantic_kernel/connectors/search/brave.py b/python/semantic_kernel/connectors/search/brave.py new file mode 100644 index 000000000000..38bbb5cddf30 --- /dev/null +++ b/python/semantic_kernel/connectors/search/brave.py @@ -0,0 +1,286 @@ +# Copyright (c) Microsoft. All rights reserved. + +import logging +from collections.abc import AsyncIterable +from typing import TYPE_CHECKING, Any, ClassVar, Final + +from httpx import AsyncClient, HTTPStatusError, RequestError +from pydantic import Field, SecretStr, ValidationError + +from semantic_kernel.data.text_search import ( + AnyTagsEqualTo, + EqualTo, + KernelSearchResults, + SearchFilter, + TextSearch, + TextSearchOptions, + TextSearchResult, +) +from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError +from semantic_kernel.kernel_pydantic import KernelBaseModel, KernelBaseSettings +from semantic_kernel.utils.feature_stage_decorator import experimental +from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT + +if TYPE_CHECKING: + from semantic_kernel.data.text_search import SearchOptions + +logger: logging.Logger = logging.getLogger(__name__) + +# region Constants +DEFAULT_URL: Final[str] = "https://api.search.brave.com/res/v1/web/search" +QUERY_PARAMETERS: Final[list[str]] = [ + "country", + "search_lang", + "ui_lang", + "safesearch", + "text_decorations", + "spellcheck", + "result_filter", + "units", +] + + +# endregion Constants + + +# region BraveSettings +class BraveSettings(KernelBaseSettings): + """Brave Connector settings. + + The settings are first loaded from environment variables with the prefix 'BRAVE_'. If the + environment variables are not found, the settings can be loaded from a .env file with the + encoding 'utf-8'. If the settings are not found in the .env file, the settings are ignored; + however, validation will fail alerting that the settings are missing. + + Optional settings for prefix 'BRAVE_' are: + - api_key: SecretStr - The Brave API key (Env var BRAVE_API_KEY) + + """ + + env_prefix: ClassVar[str] = "BRAVE_" + + api_key: SecretStr + + +# endregion BraveSettings + + +# region BraveWeb +@experimental +class BraveWebPage(KernelBaseModel): + """A Brave web page.""" + + type: str | None = None + title: str | None = None + url: str | None = None + thumbnail: dict[str, str | bool] | None = None + description: str | None = None + age: str | None = None + language: str | None = None + family_friendly: bool | None = None + extra_snippets: list[str] | None = None + meta_ur: dict[str, str] | None = None + source: str | None = None + + +@experimental +class BraveWebPages(KernelBaseModel): + """THe web pages from a Brave search.""" + + type: str | None = Field(default="webpage") + family_friendly: bool | None = Field(default=None) + results: list[BraveWebPage] = Field(default_factory=list) + + +@experimental +class BraveSearchResponse(KernelBaseModel): + """The response from a Brave search.""" + + type: str = Field(default="search", alias="type") + query_context: dict[str, Any] = Field(default_factory=dict, validation_alias="query") + web_pages: BraveWebPages | None = Field(default=None, validation_alias="web") + + +# endregion BraveWeb + + +@experimental +class BraveSearch(KernelBaseModel, TextSearch): + """A search engine connector that uses the Brave Search API to perform a web search.""" + + settings: BraveSettings + + def __init__( + self, + api_key: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initializes a new instance of the Brave Search class. + + Args: + api_key: The Brave Search API key. If provided, will override + the value in the env vars or .env file. + env_file_path: The optional path to the .env file. If provided, + the settings are read from this file path location. + env_file_encoding: The optional encoding of the .env file. If provided, + the settings are read from this file path location. + """ + try: + settings = BraveSettings( + api_key=api_key, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as ex: + raise ServiceInitializationError("Failed to create Brave settings.") from ex + + super().__init__(settings=settings) # type: ignore[call-arg] + + async def search( + self, query: str, options: "SearchOptions | None" = None, **kwargs: Any + ) -> "KernelSearchResults[str]": + """Search for text, returning a KernelSearchResult with a list of strings.""" + options = self._get_options(options, **kwargs) + results = await self._inner_search(query=query, options=options) + return KernelSearchResults( + results=self._get_result_strings(results), + total_count=self._get_total_count(results, options), + metadata=self._get_metadata(results), + ) + + async def get_text_search_results( + self, query: str, options: "SearchOptions | None" = None, **kwargs + ) -> "KernelSearchResults[TextSearchResult]": + """Search for text, returning a KernelSearchResult with TextSearchResults.""" + options = self._get_options(options, **kwargs) + results = await self._inner_search(query=query, options=options) + return KernelSearchResults( + results=self._get_text_search_results(results), + total_count=self._get_total_count(results, options), + metadata=self._get_metadata(results), + ) + + async def get_search_results( + self, query: str, options: "SearchOptions | None" = None, **kwargs + ) -> "KernelSearchResults[BraveWebPage]": + """Search for text, returning a KernelSearchResult with the results directly from the service.""" + options = self._get_options(options, **kwargs) + results = await self._inner_search(query=query, options=options) + return KernelSearchResults( + results=self._get_brave_web_pages(results), + total_count=self._get_total_count(results, options), + metadata=self._get_metadata(results), + ) + + async def _get_result_strings(self, response: BraveSearchResponse) -> AsyncIterable[str]: + if response.web_pages is None: + return + for web_page in response.web_pages.results: + yield web_page.description or "" + + async def _get_text_search_results(self, response: BraveSearchResponse) -> AsyncIterable[TextSearchResult]: + if response.web_pages is None: + return + for web_page in response.web_pages.results: + yield TextSearchResult( + name=web_page.title, + value=web_page.description, + link=web_page.url, + ) + + async def _get_brave_web_pages(self, response: BraveSearchResponse) -> AsyncIterable[BraveWebPage]: + if response.web_pages is None: + return + for val in response.web_pages.results: + yield val + + def _get_metadata(self, response: BraveSearchResponse) -> dict[str, Any]: + return { + "original": response.query_context.get("original"), + "altered": response.query_context.get("altered", ""), + "spellcheck_off": response.query_context.get("spellcheck_off"), + "show_strict_warning": response.query_context.get("show_strict_warning"), + "country": response.query_context.get("country"), + } + + def _get_total_count(self, response: BraveSearchResponse, options: TextSearchOptions) -> int | None: + if options.include_total_count and response.web_pages is not None: + return len(response.web_pages.results) + return None + + def _get_options(self, options: "SearchOptions | None", **kwargs: Any) -> TextSearchOptions: + if options is not None and isinstance(options, TextSearchOptions): + return options + try: + return TextSearchOptions(**kwargs) + except ValidationError: + return TextSearchOptions() + + async def _inner_search(self, query: str, options: TextSearchOptions) -> BraveSearchResponse: + self._validate_options(options) + + logger.info( + f"Received request for brave web search with \ + params:\nnum_results: {options.top}\noffset: {options.skip}" + ) + + url = self._get_url() + params = self._build_request_parameters(query, options) + + logger.info(f"Sending GET request to {url}") + + headers = { + "X-Subscription-Token": self.settings.api_key.get_secret_value(), + "user_agent": SEMANTIC_KERNEL_USER_AGENT, + } + try: + async with AsyncClient(timeout=5) as client: + response = await client.get(url, headers=headers, params=params) + response.raise_for_status() + return BraveSearchResponse.model_validate_json(response.text) + except HTTPStatusError as ex: + logger.error(f"Failed to get search results: {ex}") + raise ServiceInvalidRequestError("Failed to get search results.") from ex + except RequestError as ex: + logger.error(f"Client error occurred: {ex}") + raise ServiceInvalidRequestError("A client error occurred while getting search results.") from ex + except Exception as ex: + logger.error(f"An unexpected error occurred: {ex}") + raise ServiceInvalidRequestError("An unexpected error occurred while getting search results.") from ex + + def _validate_options(self, options: TextSearchOptions) -> None: + if options.top <= 0: + raise ServiceInvalidRequestError("count value must be greater than 0.") + if options.top >= 21: + raise ServiceInvalidRequestError("count value must be less than 21.") + + if options.skip < 0: + raise ServiceInvalidRequestError("offset must be greater than or equal to 0.") + if options.skip > 9: + raise ServiceInvalidRequestError("offset must be less than 10.") + + def _get_url(self) -> str: + return DEFAULT_URL + + def _build_request_parameters(self, query: str, options: TextSearchOptions) -> dict[str, str | int | bool]: + params: dict[str, str | int] = {"q": query or "", "count": options.top, "offset": options.skip} + if not options.filter: + return params + for filter in options.filter.filters: + if isinstance(filter, EqualTo): + if filter.field_name in QUERY_PARAMETERS: + params[filter.field_name] = filter.value + else: + raise ServiceInvalidRequestError( + f"Observed an unwanted parameter named {filter.field_name} with value {filter.value} ." + ) + elif isinstance(filter, SearchFilter): + logger.warning("Groups are not supported by Brave search, ignored.") + continue + elif isinstance(filter, AnyTagsEqualTo): + logger.debug("Any tag equals to filter is not supported by Brave Search API.") + return params + + +__all__ = ["BraveSearch", "BraveSearchResponse", "BraveWebPage"] diff --git a/python/semantic_kernel/connectors/search_engine/google_connector.py b/python/semantic_kernel/connectors/search_engine/google_connector.py index ba2cc50d6dba..964584911a3f 100644 --- a/python/semantic_kernel/connectors/search_engine/google_connector.py +++ b/python/semantic_kernel/connectors/search_engine/google_connector.py @@ -85,8 +85,6 @@ async def search(self, query: str, num_results: int = 1, offset: int = 0) -> lis logger.info("Sending GET request to Google Search API.") - logger.info("Sending GET request to Google Search API.") - try: async with AsyncClient(timeout=5) as client: response = await client.get(request_url) diff --git a/python/tests/unit/connectors/conftest.py b/python/tests/unit/connectors/conftest.py index 80d0aa8737b3..bd9111a70c55 100644 --- a/python/tests/unit/connectors/conftest.py +++ b/python/tests/unit/connectors/conftest.py @@ -103,6 +103,28 @@ def bing_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): return env_vars +@fixture() +def brave_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): + """Fixture to set environment variables for BraveConnector.""" + if exclude_list is None: + exclude_list = [] + + if override_env_param_dict is None: + override_env_param_dict = {} + + env_vars = {"BRAVE_API_KEY": "test_api_key"} + + env_vars.update(override_env_param_dict) + + for key, value in env_vars.items(): + if key not in exclude_list: + monkeypatch.setenv(key, value) + else: + monkeypatch.delenv(key, raising=False) + + return env_vars + + @fixture() def google_search_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): """Fixture to set environment variables for the Google Search Connector.""" diff --git a/python/tests/unit/connectors/search/brave/test_brave_search.py b/python/tests/unit/connectors/search/brave/test_brave_search.py new file mode 100644 index 000000000000..c490fbf634ff --- /dev/null +++ b/python/tests/unit/connectors/search/brave/test_brave_search.py @@ -0,0 +1,290 @@ +# Copyright (c) Microsoft. All rights reserved. + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from semantic_kernel.connectors.search.brave import BraveSearch, BraveSearchResponse, BraveWebPage, BraveWebPages +from semantic_kernel.data.text_search import KernelSearchResults, SearchFilter, TextSearchOptions, TextSearchResult +from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError + + +@pytest.fixture +def brave_search(brave_unit_test_env): + """Set up the fixture to configure the brave Search for these tests.""" + return BraveSearch() + + +@pytest.fixture +def async_client_mock(): + """Set up the fixture to mock AsyncClient.""" + async_client_mock = AsyncMock() + with patch("semantic_kernel.connectors.search.brave.AsyncClient.__aenter__", return_value=async_client_mock): + yield async_client_mock + + +@pytest.fixture +def mock_brave_search_response(): + """Set up the fixture to mock braveSearchResponse.""" + mock_web_page = BraveWebPage(name="Page Name", snippet="Page Snippet", url="test") + mock_response = BraveSearchResponse( + query_context={}, + webPages=MagicMock(spec=BraveWebPages, value=[mock_web_page], total_estimated_matches=3), + ) + + with ( + patch.object(BraveSearchResponse, "model_validate_json", return_value=mock_response), + ): + yield mock_response + + +async def test_brave_search_init_success(brave_search): + """Test that braveSearch initializes successfully with valid env.""" + # Should not raise any exception + assert brave_search.settings.api_key.get_secret_value() == "test_api_key" + + +@pytest.mark.parametrize("exclude_list", [["BRAVE_API_KEY"]], indirect=True) +async def test_brave_search_init_validation_error(brave_unit_test_env): + """Test that braveSearch raises ServiceInitializationError if BraveSettings creation fails.""" + with pytest.raises(ServiceInitializationError): + BraveSearch(env_file_path="invalid.env") + + +async def test_search_success(brave_unit_test_env, async_client_mock): + """Test that search method returns KernelSearchResults successfully on valid response.""" + # Arrange + mock_web_pages = BraveWebPage(description="Test snippet") + mock_response = BraveSearchResponse( + web_pages=MagicMock(spec=BraveWebPages, results=[mock_web_pages]), + query_context={ + "original": "original", + "altered": "altered something", + "show_strict_warning": False, + "spellcheck_off": False, + "country": "us", + }, + ) + + mock_result = MagicMock() + mock_result.text = """ + {"query": {'original':'original',"altered": + "altered something","show_strict_warning":False,"spellcheck_off":False,'country':"us"}, + "results": [{"description": "Test snippet"}]} + }""" + async_client_mock.get.return_value = mock_result + + # Act + with ( + patch.object(BraveSearchResponse, "model_validate_json", return_value=mock_response), + ): + search_instance = BraveSearch() + options = TextSearchOptions(include_total_count=True) + kernel_results: KernelSearchResults[str] = await search_instance.search("Test query", options) + + # Assert + results_list = [] + async for res in kernel_results.results: + results_list.append(res) + + assert len(results_list) == 1 + assert results_list[0] == "Test snippet" + assert kernel_results.total_count == 1 + assert kernel_results.metadata == { + "original": "original", + "altered": "altered something", + "show_strict_warning": False, + "spellcheck_off": False, + "country": "us", + } + + +async def test_search_http_status_error(brave_unit_test_env, async_client_mock): + """Test that search method raises ServiceInvalidRequestError on HTTPStatusError.""" + # Arrange + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Error", request=MagicMock(), response=MagicMock() + ) + async_client_mock.get.return_value = mock_response + + # Act + search_instance = BraveSearch() + + # Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await search_instance.search("Test query") + assert "Failed to get search results." in str(exc_info.value) + + +async def test_search_request_error(brave_unit_test_env, async_client_mock): + """Test that search method raises ServiceInvalidRequestError on RequestError.""" + # Arrange + async_client_mock.get.side_effect = httpx.RequestError("Client error") + + # Act + search_instance = BraveSearch() + + # Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await search_instance.search("Test query") + assert "A client error occurred while getting search results." in str(exc_info.value) + + +async def test_search_generic_exception(brave_unit_test_env, async_client_mock): + """Test that search method raises ServiceInvalidRequestError on unexpected exception.""" + # Arrange + async_client_mock.get.side_effect = Exception("Something unexpected") + + search_instance = BraveSearch() + # Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await search_instance.search("Test query") + assert "An unexpected error occurred while getting search results." in str(exc_info.value) + + +async def test_validate_options_raises_error_for_large_top(brave_search): + """Test that _validate_options raises ServiceInvalidRequestError when top >= 21.""" + # Arrange + options = TextSearchOptions(top=21) + + # Act / Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await brave_search.search("test", options) + assert "count value must be less than 21." in str(exc_info.value) + + +async def test_get_text_search_results_success(brave_unit_test_env, async_client_mock): + """Test that get_text_search_results returns KernelSearchResults[TextSearchResult].""" + + # Arrange + mock_web_pages = BraveWebPage(title="Result Name", description="Test snippet", url="test") + mock_response = BraveSearchResponse( + web_pages=MagicMock(spec=BraveWebPages, results=[mock_web_pages]), + query_context={}, + ) + + mock_result = MagicMock() + mock_result.text = """ + { "results": [{"description": "Test snippet","title":"Result Name","url":"test"}] , + "query": {} + } + """ + async_client_mock.get.return_value = mock_result + + # Act + with ( + patch.object(BraveSearchResponse, "model_validate_json", return_value=mock_response), + ): + search_instance = BraveSearch() + options = TextSearchOptions(include_total_count=True) + kernel_results: KernelSearchResults[TextSearchResult] = await search_instance.get_text_search_results( + "Test query", options + ) + + # Assert + results_list = [] + async for res in kernel_results.results: + results_list.append(res) + + assert len(results_list) == 1 + assert isinstance(results_list[0], TextSearchResult) + assert results_list[0].name == "Result Name" + assert results_list[0].value == "Test snippet" + assert results_list[0].link == "test" + assert kernel_results.total_count == 1 + + +async def test_get_search_results_success(brave_unit_test_env, async_client_mock, mock_brave_search_response): + """Test that get_search_results returns KernelSearchResults[braveWebPage].""" + # Arrange + mock_web_pages = BraveWebPage(title="Result Name", description="Page snippet", url="test") + mock_response = BraveSearchResponse( + web_pages=MagicMock(spec=BraveWebPages, results=[mock_web_pages]), + query_context={}, + ) + mock_result = MagicMock() + mock_result.text = """ + { "results": [{"description": "Page snippet","title":"Result Name","url":"test"}] , + +}""" + + async_client_mock.get.return_value = mock_result + + # Act + with ( + patch.object(BraveSearchResponse, "model_validate_json", return_value=mock_response), + ): + # Act + search_instance = BraveSearch() + options = TextSearchOptions(include_total_count=True) + kernel_results = await search_instance.get_search_results("Another query", options) + + # Assert + results_list = [] + async for res in kernel_results.results: + results_list.append(res) + + assert len(results_list) == 1 + assert isinstance(results_list[0], BraveWebPage) + assert results_list[0].title == "Result Name" + assert results_list[0].description == "Page snippet" + assert results_list[0].url == "test" + assert kernel_results.total_count == 1 + + +async def test_search_no_filter(brave_search, async_client_mock, mock_brave_search_response): + """Test that search properly sets params when no filter is provided.""" + # Arrange + options = TextSearchOptions() + + # Act + await brave_search.search("test query", options) + + # Assert + params = async_client_mock.get.call_args.kwargs["params"] + + assert params["count"] == options.top + assert params["offset"] == options.skip + + # TODO check: shouldn't this output be "test query" instead of "test query+"? + assert params["q"] == "test query" + + +async def test_search_equal_to_filter(brave_search, async_client_mock, mock_brave_search_response): + """Test that search properly sets params with an EqualTo filter.""" + + # Arrange + my_filter = SearchFilter.equal_to(field_name="spellcheck", value=True) + options = TextSearchOptions(filter=my_filter) + + # Act + await brave_search.search("test query", options) + + # Assert + params = async_client_mock.get.call_args.kwargs["params"] + + assert params["count"] == options.top + assert params["offset"] == options.skip + # 'spellcheck' is recognized in QUERY_PARAMETERS, so 'spellcheck' should be set + assert "spellcheck" in params + assert params["spellcheck"] + + assert params["q"] == "test query" + + +async def test_search_not_recognized_filter(brave_search, async_client_mock, mock_brave_search_response): + """Test that search properly appends non-recognized filters to the q parameter.""" + + # Arrange + # 'customProperty' is presumably not in QUERY_PARAMETERS + my_filter = SearchFilter.equal_to(field_name="customProperty", value="customValue") + options = TextSearchOptions(filter=my_filter) + + # Act + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await brave_search.search("test query", options) + + # Assert + assert "Observed an unwanted parameter named customProperty with value customValue ." in str(exc_info.value)