Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
warm_up_tools,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client
Expand Down Expand Up @@ -201,6 +202,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.async_client = AsyncAzureOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
)
self._is_warmed_up = False

def warm_up(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah @HamidOna do we need this one as its parent already implements the warm_up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, Azure needs both the flag initialization and the warm_up() method because:

Azure does not call super().__init__() - See line 73 and the comment on lines 152-153. It explicitly skips the parent's initialization because it only needs to instantiate the Azure-specific client.

Since Azure's __init__() doesn't call the parent's __init__():

  1. The self._is_warmed_up = False flag from OpenAI's __init__() is never set
  2. Without this flag, calling any inherited warm_up() would fail with AttributeError making it a requirement for Azure

The warm_up() method implementation is identical to OpenAI's, so technically we could remove Azure's warm_up() method and let it inherit from OpenAI (since the flag is now initialized). However, it makes sense to keep both for explicitness. Let me know if you'd prefer I remove the warm_up() method from Azure and just keep the flag initialization!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, right this is a tension between DRY and explicitness you advocate. We could simply add:
self._is_warmed_up = False
in Azure init and be done with it. Let me ask my colleague @sjrl for an opinion

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather reimplement the warm_up method so it's more explicit.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks @sjrl then we are gtg here @HamidOna

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏 @HamidOna keep these PRs coming. Great work!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HamidOna you can, if you want, also add this method to all Chat Generators in https://github.com/deepset-ai/haystack-core-integrations/ Just lmk

"""
Warm up the Azure OpenAI chat generator.

This will warm up the tools registered in the chat generator.
This method is idempotent and will only warm up the tools once.
"""
if not self._is_warmed_up:
warm_up_tools(self.tools)
self._is_warmed_up = True

def to_dict(self) -> dict[str, Any]:
"""
Expand Down
10 changes: 10 additions & 0 deletions haystack/components/generators/chat/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def from_dict(cls, data: dict[str, Any]) -> FallbackChatGenerator:
data["init_parameters"] = init_params
return default_from_dict(cls, data)

def warm_up(self) -> None:
    """
    Warm up all underlying chat generators.

    Each wrapped generator that exposes a callable warm_up() is warmed up;
    generators without one are skipped silently.
    """
    for candidate in self.chat_generators:
        warm_up_fn = getattr(candidate, "warm_up", None)
        if callable(warm_up_fn):
            warm_up_fn()

def _run_single_sync( # pylint: disable=too-many-positional-arguments
self,
gen: Any,
Expand Down
13 changes: 13 additions & 0 deletions haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
warm_up_tools,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
Expand Down Expand Up @@ -384,6 +385,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
model_or_url, token=token.resolve_value() if token else None, **resolved_api_params
)
self.tools = tools
self._is_warmed_up = False

def warm_up(self):
    """
    Warm up the Hugging Face API chat generator.

    Warms up the tools registered in the chat generator. Idempotent:
    only the first call performs any work.
    """
    if self._is_warmed_up:
        return
    warm_up_tools(self.tools)
    self._is_warmed_up = True

def to_dict(self) -> dict[str, Any]:
"""
Expand Down
14 changes: 13 additions & 1 deletion haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
)
from haystack.tools.utils import warm_up_tools
from haystack.utils import (
ComponentDevice,
Secret,
Expand Down Expand Up @@ -249,6 +250,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
if async_executor is None
else async_executor
)
self._is_warmed_up = False

def __del__(self) -> None:
"""
Expand All @@ -274,11 +276,21 @@ def _get_telemetry_data(self) -> dict[str, Any]:

def warm_up(self) -> None:
    """
    Initializes the component and warms up tools if provided.

    Builds the underlying Hugging Face pipeline on first use and warms up
    any registered tools. Idempotent: repeated calls have no further effect.
    """
    if not self._is_warmed_up:
        # Lazily construct the transformers pipeline on first warm-up.
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
        # Warm up registered tools, if any.
        if self.tools:
            warm_up_tools(self.tools)
        self._is_warmed_up = True

def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand Down
13 changes: 13 additions & 0 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
deserialize_tools_or_toolset_inplace,
flatten_tools_or_toolsets,
serialize_tools_or_toolset,
warm_up_tools,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client
Expand Down Expand Up @@ -200,6 +201,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.async_client = AsyncOpenAI(
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
)
self._is_warmed_up = False

def warm_up(self):
    """
    Warm up the OpenAI chat generator.

    Warms up the tools registered in the chat generator. Idempotent:
    only the first call performs any work.
    """
    if self._is_warmed_up:
        return
    warm_up_tools(self.tools)
    self._is_warmed_up = True

def _get_telemetry_data(self) -> dict[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Added warm_up() method to all ChatGenerator components (OpenAIChatGenerator,
AzureOpenAIChatGenerator, HuggingFaceAPIChatGenerator, HuggingFaceLocalChatGenerator,
and FallbackChatGenerator) to properly initialize tools that require warm-up before
pipeline execution. The warm_up() method is idempotent and follows the same pattern
used in Agent and ToolInvoker components. This enables proper tool initialization
in pipelines that use ChatGenerators with tools but without an Agent component.
104 changes: 104 additions & 0 deletions test/components/generators/chat/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,110 @@ def test_to_dict_with_toolset(self, tools, monkeypatch):
}
assert data["init_parameters"]["tools"] == expected_tools_data

def test_warm_up_with_tools(self, monkeypatch):
    """Test that warm_up() calls warm_up on tools and is idempotent."""
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")

    class MockTool(Tool):
        # Class-level counter recording how many times warm_up() ran.
        warm_up_call_count = 0

        def __init__(self):
            super().__init__(
                name="mock_tool",
                description="A mock tool for testing",
                parameters={"x": {"type": "string"}},
                function=lambda x: x,
            )

        def warm_up(self):
            MockTool.warm_up_call_count += 1

    # Reset the counter before exercising the component.
    MockTool.warm_up_call_count = 0
    component = AzureOpenAIChatGenerator(
        azure_endpoint="some-non-existing-endpoint", tools=[MockTool()]
    )

    # Nothing has been warmed up yet.
    assert MockTool.warm_up_call_count == 0
    assert not component._is_warmed_up

    # First warm_up() must reach the tool exactly once.
    component.warm_up()
    assert MockTool.warm_up_call_count == 1
    assert component._is_warmed_up

    # Second call is a no-op (idempotency).
    component.warm_up()
    assert MockTool.warm_up_call_count == 1
    assert component._is_warmed_up

def test_warm_up_with_no_tools(self, monkeypatch):
    """Test that warm_up() works when no tools are provided."""
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")

    component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")

    # Fresh component: no tools registered and not yet warmed up.
    assert component.tools is None
    assert not component._is_warmed_up

    # warm_up() must succeed without tools and flip the flag.
    component.warm_up()
    assert component._is_warmed_up

    # Idempotent: a second call leaves the component warmed up.
    component.warm_up()
    assert component._is_warmed_up

def test_warm_up_with_multiple_tools(self, monkeypatch):
    """Test that warm_up() works with multiple tools."""
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")

    # Names of tools whose warm_up() was invoked, in call order.
    warm_up_calls = []

    class MockTool(Tool):
        def __init__(self, tool_name):
            super().__init__(
                name=tool_name,
                description=f"Mock tool {tool_name}",
                parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]},
                function=lambda x: f"{tool_name} result: {x}",
            )

        def warm_up(self):
            warm_up_calls.append(self.name)

    # Pass a plain LIST of tools (not a Toolset).
    component = AzureOpenAIChatGenerator(
        azure_endpoint="some-non-existing-endpoint",
        tools=[MockTool("tool1"), MockTool("tool2")],
    )

    component.warm_up()

    # Both tools must have been warmed up and the flag set.
    assert "tool1" in warm_up_calls
    assert "tool2" in warm_up_calls
    assert component._is_warmed_up

    # Idempotency: a repeated warm_up() adds no new calls.
    calls_before = len(warm_up_calls)
    component.warm_up()
    assert len(warm_up_calls) == calls_before


class TestAzureOpenAIChatGeneratorAsync:
def test_init_should_also_create_async_client_with_same_args(self, tools):
Expand Down
68 changes: 68 additions & 0 deletions test/components/generators/chat/test_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,71 @@ async def test_failover_trigger_401_authentication_async():
assert result["replies"][0].text == "success_after_auth"
assert result["meta"]["successful_chat_generator_index"] == 1
assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]


@component
class _DummyGenWithWarmUp:
    """Dummy generator that records whether warm_up() was invoked."""

    def __init__(self, text: str = "ok"):
        self.text = text
        self.warm_up_called = False

    def warm_up(self) -> None:
        self.warm_up_called = True

    def run(
        self,
        messages: list[ChatMessage],
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[ToolsType] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> dict[str, Any]:
        reply = ChatMessage.from_assistant(self.text)
        return {"replies": [reply], "meta": {}}


def test_warm_up_delegates_to_generators():
    """Test that warm_up() is called on each underlying generator."""
    generators = [_DummyGenWithWarmUp(text=label) for label in ("A", "B", "C")]

    fallback = FallbackChatGenerator(chat_generators=generators)
    fallback.warm_up()

    # Every wrapped generator must have been warmed up.
    assert all(gen.warm_up_called for gen in generators)


def test_warm_up_with_no_warm_up_method():
    """Test that warm_up() handles generators without warm_up() gracefully."""
    fallback = FallbackChatGenerator(
        chat_generators=[_DummySuccessGen(text="A"), _DummySuccessGen(text="B")]
    )

    # Must not raise even though neither generator defines warm_up().
    fallback.warm_up()

    # The fallback chain still runs normally afterwards.
    result = fallback.run([ChatMessage.from_user("test")])
    assert result["replies"][0].text == "A"


def test_warm_up_mixed_generators():
    """Test warm_up() with a mix of generators with and without warm_up()."""
    warmable_first = _DummyGenWithWarmUp(text="A")
    plain = _DummySuccessGen(text="B")
    warmable_second = _DummyGenWithWarmUp(text="C")
    failing = _DummyFailGen()

    fallback = FallbackChatGenerator(
        chat_generators=[warmable_first, plain, warmable_second, failing]
    )
    fallback.warm_up()

    # Only generators exposing warm_up() should have been touched.
    assert warmable_first.warm_up_called
    assert warmable_second.warm_up_called

    # The fallback still returns the first generator's reply.
    result = fallback.run([ChatMessage.from_user("test")])
    assert result["replies"][0].text == "A"
Loading
Loading