Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 113 additions & 102 deletions python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]],
description: Optional[str] = None,
api_version: str = "2023-11-01",
query_type: Literal["keyword", "fulltext", "vector", "hybrid"] = "keyword",
query_type: Literal["keyword", "fulltext", "vector", "semantic"] = "keyword",
search_fields: Optional[List[str]] = None,
select_fields: Optional[List[str]] = None,
vector_fields: Optional[List[str]] = None,
Expand All @@ -188,7 +188,7 @@ def __init__(
credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Azure credential for authentication (API key or token)
description (Optional[str]): Optional description explaining the tool's purpose
api_version (str): Azure AI Search API version to use
query_type (Literal["keyword", "fulltext", "vector", "hybrid"]): Type of search to perform
query_type (Literal["keyword", "fulltext", "vector", "semantic"]): Type of search to perform
search_fields (Optional[List[str]]): Fields to search within documents
select_fields (Optional[List[str]]): Fields to return in search results
vector_fields (Optional[List[str]]): Fields to use for vector search
Expand Down Expand Up @@ -230,6 +230,7 @@ def __init__(
vector_fields=vector_fields,
top=top,
filter=filter,
semantic_config_name=semantic_config_name,
enable_caching=enable_caching,
cache_ttl_seconds=cache_ttl_seconds,
)
Expand Down Expand Up @@ -361,7 +362,9 @@ async def run(
if self.search_config.vector_fields:
vector_fields_list = self.search_config.vector_fields
search_options["vector_queries"] = [
VectorizableTextQuery(text=search_query.query, k=int(self.search_config.top or 5), fields=field)
VectorizableTextQuery(
text=search_query.query, k_nearest_neighbors=int(self.search_config.top or 5), fields=field
)
for field in vector_fields_list
]

Expand Down Expand Up @@ -517,15 +520,15 @@ def _from_config(cls, config: Any) -> "BaseAzureAISearchTool":
query_type_str = getattr(config, "query_type", "keyword")

query_type_mapping = {
"simple": "keyword",
"keyword": "keyword",
"simple": "fulltext",
"fulltext": "fulltext",
"vector": "vector",
"hybrid": "hybrid",
"semantic": "semantic",
}

query_type = cast(
Literal["keyword", "fulltext", "vector", "hybrid"], query_type_mapping.get(query_type_str, "vector")
Literal["keyword", "fulltext", "vector", "semantic"], query_type_mapping.get(query_type_str, "vector")
)

openai_client_attr = getattr(config, "openai_client", None)
Expand All @@ -536,6 +539,8 @@ def _from_config(cls, config: Any) -> "BaseAzureAISearchTool":
if not embedding_model_attr:
raise ValueError("embedding_model must be specified in config")

# If query_type="semantic", you must provide a valid semantic_config_name.
# If query_type is anything else, semantic_config_name is ignored.
return cls(
name=getattr(config, "name", ""),
endpoint=getattr(config, "endpoint", ""),
Expand All @@ -549,6 +554,7 @@ def _from_config(cls, config: Any) -> "BaseAzureAISearchTool":
vector_fields=getattr(config, "vector_fields", None),
top=getattr(config, "top", None),
filter=getattr(config, "filter", None),
semantic_config_name=getattr(config, "semantic_config_name", None),
enable_caching=getattr(config, "enable_caching", False),
cache_ttl_seconds=getattr(config, "cache_ttl_seconds", 300),
)
Expand Down Expand Up @@ -653,7 +659,7 @@ class AzureAISearchTool(BaseAzureAISearchTool):
1. Keyword Search: Traditional text-based search using Azure's text analysis
2. Full-Text Search: Enhanced text search with language-specific analyzers
3. Vector Search: Semantic similarity search using vector embeddings
4. Hybrid Search: Combines text and vector search for comprehensive results
4. Hybrid Search: Combines fulltext and vector search for comprehensive results

You should use the factory methods to create instances for specific search types:
- create_keyword_search()
Expand All @@ -668,7 +674,7 @@ def __init__(
endpoint: str,
index_name: str,
credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]],
query_type: Literal["keyword", "fulltext", "vector", "hybrid"],
query_type: Literal["keyword", "fulltext", "vector", "semantic"],
search_fields: Optional[List[str]] = None,
select_fields: Optional[List[str]] = None,
vector_fields: Optional[List[str]] = None,
Expand Down Expand Up @@ -731,15 +737,15 @@ def load_component(
query_type_str = config.get("query_type", "keyword")

query_type_mapping = {
"simple": "keyword",
"keyword": "keyword",
"simple": "fulltext",
"fulltext": "fulltext",
"vector": "vector",
"hybrid": "hybrid",
"semantic": "semantic",
}

query_type = cast(
Literal["keyword", "fulltext", "vector", "hybrid"], query_type_mapping.get(query_type_str, "vector")
Literal["keyword", "fulltext", "vector", "semantic"], query_type_mapping.get(query_type_str, "vector")
)

instance = cls(
Expand Down Expand Up @@ -919,17 +925,12 @@ def create_full_text_search(

token = _allow_private_constructor.set(True)
try:
query_type = cast(
Literal["keyword", "fulltext", "vector", "hybrid"],
"fulltext",
)

return cls(
name=name,
endpoint=endpoint,
index_name=index_name,
credential=credential,
query_type=query_type,
query_type="fulltext",
search_fields=search_fields,
select_fields=select_fields,
filter=filter,
Expand Down Expand Up @@ -1016,6 +1017,101 @@ def create_vector_search(
finally:
_allow_private_constructor.reset(token)

@classmethod
def create_hybrid_search(
    cls,
    name: str,
    endpoint: str,
    index_name: str,
    credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]],
    vector_fields: List[str],
    search_fields: Optional[List[str]] = None,
    select_fields: Optional[List[str]] = None,
    filter: Optional[str] = None,
    top: Optional[int] = 5,
    **kwargs: Any,
) -> "AzureAISearchTool":
    """Factory method to create a hybrid search tool (text + vector).

    Hybrid search combines text search (fulltext or semantic) with vector similarity
    search to provide more comprehensive results. This is the recommended entrypoint
    for hybrid (text + vector) search.

    There is no dedicated "hybrid" query type. The tool is built with the given
    ``vector_fields`` and a text ``query_type`` that is chosen automatically:
    ``"semantic"`` when a non-empty ``semantic_config_name`` is supplied through
    ``**kwargs``, otherwise ``"fulltext"``.

    Args:
        name (str): The name of the tool
        endpoint (str): The URL of your Azure AI Search service
        index_name (str): The name of the search index
        credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials
        vector_fields (List[str]): Fields containing vector embeddings for similarity search
        search_fields (Optional[List[str]]): Fields to search within for text search
        select_fields (Optional[List[str]]): Fields to include in results
        filter (Optional[str]): OData filter expression to filter results
        top (Optional[int]): Maximum number of results to return
        **kwargs (Any): Additional configuration options forwarded unchanged to the
            constructor. ``semantic_config_name`` is the one key inspected here: a
            truthy value switches the text query type to ``"semantic"``.

    Returns:
        An initialized hybrid search tool

    Raises:
        ValueError: If ``vector_fields`` is empty. The common parameters are also
            checked by ``_validate_common_params`` before anything is constructed.

    Example Usage:
        .. code-block:: python

            # type: ignore
            # Example of using hybrid search with Azure AI Search
            from autogen_ext.tools.azure import AzureAISearchTool
            from azure.core.credentials import AzureKeyCredential

            # Create a hybrid search tool
            hybrid_search = AzureAISearchTool.create_hybrid_search(
                name="hybrid_search",
                endpoint="https://your-search-service.search.windows.net",
                index_name="your-index",
                credential=AzureKeyCredential("your-api-key"),
                vector_fields=["embedding_field"],
                search_fields=["title", "content"],
                select_fields=["title", "content", "url", "date"],
                top=10,
            )

            # The search tool can be used with an Agent
            # assistant = Agent("researcher", tools=[hybrid_search])

    .. warning::

        Do not pass ``query_type`` through ``**kwargs``: this factory always sets it
        explicitly, so a second value would conflict with the keyword it already
        supplies. To get semantic ranking, pass a valid ``semantic_config_name``
        instead; without one the text part of the query runs as ``"fulltext"``.

    """
    cls._validate_common_params(name, endpoint, index_name, credential)

    if not vector_fields:
        raise ValueError("vector_fields must contain at least one field name")

    # query_type only describes the text half of the hybrid query; the vector half
    # is enabled by passing vector_fields. Annotating the local with the Literal
    # type lets type checkers accept it without a cast.
    text_query_type: Literal["keyword", "fulltext", "vector", "semantic"] = (
        "semantic" if kwargs.get("semantic_config_name") else "fulltext"
    )

    # Allow the otherwise private constructor for the duration of this call; the
    # previous value is restored in `finally` even if construction raises.
    token = _allow_private_constructor.set(True)
    try:
        return cls(
            name=name,
            endpoint=endpoint,
            index_name=index_name,
            credential=credential,
            query_type=text_query_type,
            search_fields=search_fields,
            select_fields=select_fields,
            vector_fields=vector_fields,
            filter=filter,
            top=top,
            **kwargs,
        )
    finally:
        _allow_private_constructor.reset(token)

async def _get_embedding(self, query: str) -> List[float]:
"""Generate embedding vector for the query text.

Expand Down Expand Up @@ -1095,88 +1191,3 @@ def get_token() -> str:
f"Unsupported embedding provider: {embedding_provider}. "
"Currently supported providers are 'azure_openai' and 'openai'."
) from None

@classmethod
def create_hybrid_search(
    cls,
    name: str,
    endpoint: str,
    index_name: str,
    credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]],
    vector_fields: List[str],
    search_fields: Optional[List[str]] = None,
    select_fields: Optional[List[str]] = None,
    filter: Optional[str] = None,
    top: Optional[int] = 5,
    **kwargs: Any,
) -> "AzureAISearchTool":
    """Factory method to create a hybrid search tool.

    Hybrid search combines text search with vector similarity search to provide
    more comprehensive results.

    The tool is constructed with ``query_type="hybrid"`` together with the given
    ``vector_fields``; all remaining arguments are forwarded to the constructor
    unchanged.

    Args:
        name (str): The name of the tool
        endpoint (str): The URL of your Azure AI Search service
        index_name (str): The name of the search index
        credential (Union[AzureKeyCredential, TokenCredential, Dict[str, str]]): Authentication credentials
        vector_fields (List[str]): Fields containing vector embeddings for similarity search
        search_fields (Optional[List[str]]): Fields to search within for text search
        select_fields (Optional[List[str]]): Fields to include in results
        filter (Optional[str]): OData filter expression to filter results
        top (Optional[int]): Maximum number of results to return
        **kwargs (Any): Additional configuration options

    Returns:
        An initialized hybrid search tool

    Raises:
        ValueError: If ``vector_fields`` is empty. The common parameters are also
            checked by ``_validate_common_params`` before anything is constructed.

    Example Usage:
        .. code-block:: python

            # type: ignore
            # Example of using hybrid search with Azure AI Search
            from autogen_ext.tools.azure import AzureAISearchTool
            from azure.core.credentials import AzureKeyCredential

            # Create a hybrid search tool
            hybrid_search = AzureAISearchTool.create_hybrid_search(
                name="hybrid_search",
                endpoint="https://your-search-service.search.windows.net",
                index_name="your-index",
                credential=AzureKeyCredential("your-api-key"),
                vector_fields=["embedding_field"],
                search_fields=["title", "content"],
                select_fields=["title", "content", "url", "date"],
                top=10,
            )

            # The search tool can be used with an Agent
            # assistant = Agent("researcher", tools=[hybrid_search])
    """
    cls._validate_common_params(name, endpoint, index_name, credential)

    if not vector_fields:
        raise ValueError("vector_fields must contain at least one field name")

    # Allow the otherwise private constructor for the duration of this call; the
    # previous value is restored in `finally` even if construction raises.
    token = _allow_private_constructor.set(True)
    try:
        return cls(
            name=name,
            endpoint=endpoint,
            index_name=index_name,
            credential=credential,
            # cast() is a runtime no-op; it only satisfies the Literal annotation.
            query_type=cast(Literal["keyword", "fulltext", "vector", "hybrid"], "hybrid"),
            search_fields=search_fields,
            select_fields=select_fields,
            vector_fields=vector_fields,
            filter=filter,
            top=top,
            **kwargs,
        )
    finally:
        _allow_private_constructor.reset(token)
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ class AzureAISearchConfig(BaseModel):
credential (Union[AzureKeyCredential, TokenCredential]): Azure authentication credential:
- AzureKeyCredential: For API key authentication (admin/query key)
- TokenCredential: For Azure AD authentication (e.g., DefaultAzureCredential)
query_type (Literal["keyword", "fulltext", "vector", "hybrid"]): The search query mode to use:
query_type (Literal["keyword", "fulltext", "vector", "semantic"]): The search query mode to use:
- 'keyword': Basic keyword search (default)
- 'full': Full Lucene query syntax
- 'fulltext': Full Lucene query syntax
- 'vector': Vector similarity search
- 'hybrid': Hybrid search combining multiple techniques
- 'semantic': Semantic search using semantic configuration
search_fields (Optional[List[str]]): List of index fields to search within. If not specified,
searches all searchable fields. Example: ['title', 'content'].
select_fields (Optional[List[str]]): Fields to return in search results. If not specified,
Expand All @@ -104,8 +104,9 @@ class AzureAISearchConfig(BaseModel):
credential: Union[AzureKeyCredential, TokenCredential] = Field(
description="The credential to use for authentication"
)
query_type: Literal["keyword", "fulltext", "vector", "hybrid"] = Field(
default="keyword", description="Type of query to perform"
query_type: Literal["keyword", "fulltext", "vector", "semantic"] = Field(
default="keyword",
description="Type of query to perform (keyword for classic, fulltext for Lucene, vector for embedding, semantic for semantic/AI search)",
)
search_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to search in")
select_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to return in results")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def test_create_vector_search() -> None:

@pytest.mark.asyncio
async def test_create_hybrid_search() -> None:
"""Test the create_hybrid_search factory method."""
"""Test the create_hybrid_search factory method (hybrid = text + vector, query_type will be 'fulltext' or 'semantic')."""
tool = ConcreteAzureAISearchTool.create_hybrid_search(
name="hybrid_search",
endpoint="https://test.search.windows.net",
Expand All @@ -217,7 +217,7 @@ async def test_create_hybrid_search() -> None:
)

assert tool.name == "hybrid_search"
assert tool.search_config.query_type == "hybrid"
assert tool.search_config.query_type in ("fulltext", "semantic")
assert tool.search_config.vector_fields == ["embedding"]
assert tool.search_config.search_fields == ["title", "content"]

Expand Down Expand Up @@ -843,7 +843,7 @@ async def test_factory_method_validation() -> None:
)

with pytest.raises(ValueError, match="vector_fields must contain at least one field name"):
ConcreteAzureAISearchTool.create_vector_search(
ConcreteAzureAISearchTool.create_hybrid_search(
name="test",
endpoint="https://test.search.windows.net",
index_name="test-index",
Expand Down Expand Up @@ -1001,16 +1001,16 @@ async def test_fallback_vectorizable_text_query() -> None:
"""Test the fallback VectorizableTextQuery class when Azure SDK is not available."""

class MockVectorizableTextQuery:
def __init__(self, text: str, k: int, fields: str) -> None:
def __init__(self, text: str, k_nearest_neighbors: int, fields: str) -> None:
self.text = text
self.k = k
self.k_nearest_neighbors = k_nearest_neighbors
self.fields = fields

query1 = MockVectorizableTextQuery(text="test query", k=5, fields="title")
query1 = MockVectorizableTextQuery(text="test query", k_nearest_neighbors=5, fields="title")
assert query1.text == "test query"
assert query1.fields == "title"

query2 = MockVectorizableTextQuery(text="test query", k=3, fields="title,content")
query2 = MockVectorizableTextQuery(text="test query", k_nearest_neighbors=3, fields="title,content")
assert query2.text == "test query"
assert query2.fields == "title,content"

Expand Down
Loading