diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py index 07da00e1b9e9..cf0570d82727 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py @@ -169,7 +169,7 @@ def __init__( credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], description: Optional[str] = None, api_version: str = "2023-11-01", - query_type: Literal["keyword", "fulltext", "vector", "hybrid"] = "keyword", + query_type: Literal["keyword", "fulltext", "vector", "semantic"] = "keyword", search_fields: Optional[List[str]] = None, select_fields: Optional[List[str]] = None, vector_fields: Optional[List[str]] = None, @@ -188,7 +188,7 @@ def __init__( credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Azure credential for authentication (API key or token) description (Optional[str]): Optional description explaining the tool's purpose api_version (str): Azure AI Search API version to use - query_type (Literal["keyword", "fulltext", "vector", "hybrid"]): Type of search to perform + query_type (Literal["keyword", "fulltext", "vector", "semantic"]): Type of search to perform search_fields (Optional[List[str]]): Fields to search within documents select_fields (Optional[List[str]]): Fields to return in search results vector_fields (Optional[List[str]]): Fields to use for vector search @@ -230,6 +230,7 @@ def __init__( vector_fields=vector_fields, top=top, filter=filter, + semantic_config_name=semantic_config_name, enable_caching=enable_caching, cache_ttl_seconds=cache_ttl_seconds, ) @@ -361,7 +362,9 @@ async def run( if self.search_config.vector_fields: vector_fields_list = self.search_config.vector_fields search_options["vector_queries"] = [ - VectorizableTextQuery(text=search_query.query, k=int(self.search_config.top or 5), fields=field) + VectorizableTextQuery( + text=search_query.query, k_nearest_neighbors=int(self.search_config.top or 5), fields=field + ) for field in vector_fields_list ] @@ -517,15 +520,15 @@ def _from_config(cls, config: Any) -> "BaseAzureAISearchTool": query_type_str = getattr(config, "query_type", "keyword") query_type_mapping = { - "simple": "keyword", "keyword": "keyword", + "simple": "fulltext", "fulltext": "fulltext", "vector": "vector", - "hybrid": "hybrid", + "semantic": "semantic", } query_type = cast( - Literal["keyword", "fulltext", "vector", "hybrid"], query_type_mapping.get(query_type_str, "vector") + Literal["keyword", "fulltext", "vector", "semantic"], query_type_mapping.get(query_type_str, "vector") ) openai_client_attr = getattr(config, "openai_client", None) @@ -536,6 +539,8 @@ def _from_config(cls, config: Any) -> "BaseAzureAISearchTool": if not embedding_model_attr: raise ValueError("embedding_model must be specified in config") + # If query_type="semantic", you must provide a valid semantic_config_name. + # If query_type is anything else, semantic_config_name is ignored. return cls( name=getattr(config, "name", ""), endpoint=getattr(config, "endpoint", ""), @@ -549,6 +554,7 @@ def _from_config(cls, config: Any) -> "BaseAzureAISearchTool": vector_fields=getattr(config, "vector_fields", None), top=getattr(config, "top", None), filter=getattr(config, "filter", None), + semantic_config_name=getattr(config, "semantic_config_name", None), enable_caching=getattr(config, "enable_caching", False), cache_ttl_seconds=getattr(config, "cache_ttl_seconds", 300), ) @@ -653,7 +659,7 @@ class AzureAISearchTool(BaseAzureAISearchTool): 1. Keyword Search: Traditional text-based search using Azure's text analysis 2. Full-Text Search: Enhanced text search with language-specific analyzers 3. Vector Search: Semantic similarity search using vector embeddings - 4. Hybrid Search: Combines text and vector search for comprehensive results + 4. Hybrid Search: Combines fulltext and vector search for comprehensive results You should use the factory methods to create instances for specific search types: - create_keyword_search() @@ -668,7 +674,7 @@ def __init__( endpoint: str, index_name: str, credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], - query_type: Literal["keyword", "fulltext", "vector", "hybrid"], + query_type: Literal["keyword", "fulltext", "vector", "semantic"], search_fields: Optional[List[str]] = None, select_fields: Optional[List[str]] = None, vector_fields: Optional[List[str]] = None, @@ -731,15 +737,15 @@ def load_component( query_type_str = config.get("query_type", "keyword") query_type_mapping = { - "simple": "keyword", "keyword": "keyword", + "simple": "fulltext", "fulltext": "fulltext", "vector": "vector", - "hybrid": "hybrid", + "semantic": "semantic", } query_type = cast( - Literal["keyword", "fulltext", "vector", "hybrid"], query_type_mapping.get(query_type_str, "vector") + Literal["keyword", "fulltext", "vector", "semantic"], query_type_mapping.get(query_type_str, "vector") ) instance = cls( @@ -919,17 +925,12 @@ def create_full_text_search( token = _allow_private_constructor.set(True) try: - query_type = cast( - Literal["keyword", "fulltext", "vector", "hybrid"], - "fulltext", - ) - return cls( name=name, endpoint=endpoint, index_name=index_name, credential=credential, - query_type=query_type, + query_type="fulltext", search_fields=search_fields, select_fields=select_fields, filter=filter, @@ -1016,6 +1017,101 @@ def create_vector_search( finally: _allow_private_constructor.reset(token) + @classmethod + def create_hybrid_search( + cls, + name: str, + endpoint: str, + index_name: str, + credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], + vector_fields: List[str], + search_fields: Optional[List[str]] = None, + select_fields: Optional[List[str]] = None, + filter: Optional[str] = None, + top: Optional[int] = 5, + **kwargs: Any, + ) -> "AzureAISearchTool": + """Factory method to create a hybrid search tool (text + vector). + + Hybrid search combines text search (fulltext or semantic) with vector similarity + search to provide more comprehensive results. This is the recommended entrypoint for hybrid (text + vector) search. + The query_type will be 'semantic' if semantic_config_name is provided, otherwise 'fulltext'. + + Args: + name (str): The name of the tool + endpoint (str): The URL of your Azure AI Search service + index_name (str): The name of the search index + credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials + vector_fields (List[str]): Fields containing vector embeddings for similarity search + search_fields (Optional[List[str]]): Fields to search within for text search + select_fields (Optional[List[str]]): Fields to include in results + filter (Optional[str]): OData filter expression to filter results + top (Optional[int]): Maximum number of results to return + **kwargs (Any): Additional configuration options + + Returns: + An initialized hybrid search tool + + Example Usage: + .. code-block:: python + + # type: ignore + # Example of using hybrid search with Azure AI Search + from autogen_ext.tools.azure import AzureAISearchTool + from azure.core.credentials import AzureKeyCredential + + # Create a hybrid search tool + hybrid_search = AzureAISearchTool.create_hybrid_search( + name="hybrid_search", + endpoint="https://your-search-service.search.windows.net", + index_name="your-index", + credential=AzureKeyCredential("your-api-key"), + vector_fields=["embedding_field"], + search_fields=["title", "content"], + select_fields=["title", "content", "url", "date"], + top=10, + ) + + # The search tool can be used with an Agent + # assistant = Agent("researcher", tools=[hybrid_search]) + + + .. warning:: + + If you set ``query_type=\"semantic\"``, you must also provide a valid ``semantic_config_name``. + If you do not, the tool will default to the config name ``\"semantic\"``. + + """ + cls._validate_common_params(name, endpoint, index_name, credential) + + if not vector_fields or len(vector_fields) == 0: + raise ValueError("vector_fields must contain at least one field name") + + token = _allow_private_constructor.set(True) + try: + if kwargs.get("semantic_config_name"): + text_query_type = "semantic" + else: + text_query_type = "fulltext" + + from typing import cast + + return cls( + name=name, + endpoint=endpoint, + index_name=index_name, + credential=credential, + query_type=cast(Literal["keyword", "fulltext", "vector", "semantic"], text_query_type), + search_fields=search_fields, + select_fields=select_fields, + vector_fields=vector_fields, + filter=filter, + top=top, + **kwargs, + ) + finally: + _allow_private_constructor.reset(token) + async def _get_embedding(self, query: str) -> List[float]: """Generate embedding vector for the query text. @@ -1095,88 +1191,3 @@ def get_token() -> str: f"Unsupported embedding provider: {embedding_provider}. " "Currently supported providers are 'azure_openai' and 'openai'." ) from None - - @classmethod - def create_hybrid_search( - cls, - name: str, - endpoint: str, - index_name: str, - credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]], - vector_fields: List[str], - search_fields: Optional[List[str]] = None, - select_fields: Optional[List[str]] = None, - filter: Optional[str] = None, - top: Optional[int] = 5, - **kwargs: Any, - ) -> "AzureAISearchTool": - """Factory method to create a hybrid search tool. - - Hybrid search combines text search (keyword or semantic) with vector similarity - search to provide more comprehensive results. - - This method doesn't use a separate "hybrid" type but instead configures either - a "keyword" or "semantic" text search and combines it with vector search. - - Args: - name (str): The name of the tool - endpoint (str): The URL of your Azure AI Search service - index_name (str): The name of the search index - credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials - vector_fields (List[str]): Fields containing vector embeddings for similarity search - search_fields (Optional[List[str]]): Fields to search within for text search - select_fields (Optional[List[str]]): Fields to include in results - filter (Optional[str]): OData filter expression to filter results - top (Optional[int]): Maximum number of results to return - **kwargs (Any): Additional configuration options - - Returns: - An initialized hybrid search tool - - Example Usage: - .. code-block:: python - - # type: ignore - # Example of using hybrid search with Azure AI Search - from autogen_ext.tools.azure import AzureAISearchTool - from azure.core.credentials import AzureKeyCredential - - # Create a hybrid search tool - hybrid_search = AzureAISearchTool.create_hybrid_search( - name="hybrid_search", - endpoint="https://your-search-service.search.windows.net", - index_name="your-index", - credential=AzureKeyCredential("your-api-key"), - vector_fields=["embedding_field"], - search_fields=["title", "content"], - select_fields=["title", "content", "url", "date"], - top=10, - ) - - # The search tool can be used with an Agent - # assistant = Agent("researcher", tools=[hybrid_search]) - """ - cls._validate_common_params(name, endpoint, index_name, credential) - - if not vector_fields or len(vector_fields) == 0: - raise ValueError("vector_fields must contain at least one field name") - - token = _allow_private_constructor.set(True) - try: - text_query_type = cast(Literal["keyword", "fulltext", "vector", "hybrid"], "hybrid") - - return cls( - name=name, - endpoint=endpoint, - index_name=index_name, - credential=credential, - query_type=text_query_type, - search_fields=search_fields, - select_fields=select_fields, - vector_fields=vector_fields, - filter=filter, - top=top, - **kwargs, - ) - finally: - _allow_private_constructor.reset(token) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py index a80a2d07983c..a27fdd6776af 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/azure/_config.py @@ -75,11 +75,11 @@ class AzureAISearchConfig(BaseModel): credential (Union[AzureKeyCredential, TokenCredential]): Azure authentication credential: - AzureKeyCredential: For API key authentication (admin/query key) - TokenCredential: For Azure AD authentication (e.g., DefaultAzureCredential) - query_type (Literal["keyword", "fulltext", "vector", "hybrid"]): The search query mode to use: + query_type (Literal["keyword", "fulltext", "vector", "semantic"]): The search query mode to use: - 'keyword': Basic keyword search (default) - - 'full': Full Lucene query syntax + - 'fulltext': Full Lucene query syntax - 'vector': Vector similarity search - - 'hybrid': Hybrid search combining multiple techniques + - 'semantic': Semantic search using semantic configuration search_fields (Optional[List[str]]): List of index fields to search within. If not specified, searches all searchable fields. Example: ['title', 'content']. select_fields (Optional[List[str]]): Fields to return in search results. If not specified, @@ -104,8 +104,9 @@ class AzureAISearchConfig(BaseModel): credential: Union[AzureKeyCredential, TokenCredential] = Field( description="The credential to use for authentication" ) - query_type: Literal["keyword", "fulltext", "vector", "hybrid"] = Field( - default="keyword", description="Type of query to perform" + query_type: Literal["keyword", "fulltext", "vector", "semantic"] = Field( + default="keyword", + description="Type of query to perform (keyword for classic, fulltext for Lucene, vector for embedding, semantic for semantic/AI search)", ) search_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to search in") select_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to return in results") diff --git a/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py b/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py index 60084c6b2a39..13b70e823d62 100644 --- a/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py +++ b/python/packages/autogen-ext/tests/tools/azure/test_ai_search_tool.py @@ -204,7 +204,7 @@ async def test_create_vector_search() -> None: @pytest.mark.asyncio async def test_create_hybrid_search() -> None: - """Test the create_hybrid_search factory method.""" + """Test the create_hybrid_search factory method (hybrid = text + vector, query_type will be 'fulltext' or 'semantic').""" tool = ConcreteAzureAISearchTool.create_hybrid_search( name="hybrid_search", endpoint="https://test.search.windows.net", @@ -217,7 +217,7 @@ async def test_create_hybrid_search() -> None: ) assert tool.name == "hybrid_search" - assert tool.search_config.query_type == "hybrid" + assert tool.search_config.query_type in ("fulltext", "semantic") assert tool.search_config.vector_fields == ["embedding"] assert tool.search_config.search_fields == ["title", "content"] @@ -843,7 +843,7 @@ async def test_factory_method_validation() -> None: ) with pytest.raises(ValueError, match="vector_fields must contain at least one field name"): - ConcreteAzureAISearchTool.create_vector_search( + ConcreteAzureAISearchTool.create_hybrid_search( name="test", endpoint="https://test.search.windows.net", index_name="test-index", @@ -1001,16 +1001,16 @@ async def test_fallback_vectorizable_text_query() -> None: """Test the fallback VectorizableTextQuery class when Azure SDK is not available.""" class MockVectorizableTextQuery: - def __init__(self, text: str, k: int, fields: str) -> None: + def __init__(self, text: str, k_nearest_neighbors: int, fields: str) -> None: self.text = text - self.k = k + self.k_nearest_neighbors = k_nearest_neighbors self.fields = fields - query1 = MockVectorizableTextQuery(text="test query", k=5, fields="title") + query1 = MockVectorizableTextQuery(text="test query", k_nearest_neighbors=5, fields="title") assert query1.text == "test query" assert query1.fields == "title" - query2 = MockVectorizableTextQuery(text="test query", k=3, fields="title,content") + query2 = MockVectorizableTextQuery(text="test query", k_nearest_neighbors=3, fields="title,content") assert query2.text == "test query" assert query2.fields == "title,content"