Skip to content

Commit 6c78f10

Browse files
HamidOna and HamidOna13 authored
feat: Add warm_up() method to ChatGenerators for tool initialization (#9942)
* Add warm_up() method to OpenAIChatGenerator

  - Add warm_up() method that calls warm_up_tools()
  - Add _is_warmed_up flag for idempotency
  - Import warm_up_tools from haystack.tools
  - Add comprehensive tests:
    - test_warm_up_with_tools: single tool case
    - test_warm_up_with_no_tools: no tools case
    - test_warm_up_with_multiple_tools: multiple tools case
  - All tests passing

  Part of issue #9907

* Add warm_up() method to AzureOpenAIChatGenerator

  - Add warm_up() method that calls warm_up_tools()
  - Add _is_warmed_up flag for idempotency
  - Import warm_up_tools from haystack.tools
  - Add comprehensive tests:
    - test_warm_up_with_tools: single tool case
    - test_warm_up_with_no_tools: no tools case
    - test_warm_up_with_multiple_tools: multiple tools case
  - All tests passing

  Part of issue #9907

* Add warm_up() method to HuggingFaceAPIChatGenerator

  - Add warm_up() method that calls warm_up_tools()
  - Add _is_warmed_up flag for idempotency
  - Import warm_up_tools from haystack.tools
  - Add comprehensive tests:
    - test_warm_up_with_tools: single tool case
    - test_warm_up_with_no_tools: no tools case
    - test_warm_up_with_multiple_tools: multiple tools case
  - All tests passing

  Part of issue #9907

* Enhance warm_up() method in HuggingFaceLocalChatGenerator

  - Add warm_up_tools import from haystack.tools.utils
  - Add _is_warmed_up flag for idempotency
  - Enhance existing warm_up() to also warm up tools
  - Preserve existing pipeline initialization logic
  - Add comprehensive tests:
    - test_warm_up_with_tools: single tool case
    - test_warm_up_with_no_tools: no tools case
    - test_warm_up_with_multiple_tools: multiple tools case

  Part of issue #9907

* Add warm_up() method to FallbackChatGenerator

  - Add warm_up() method that delegates to underlying generators
  - Uses hasattr check to gracefully handle generators without warm_up
  - Add comprehensive tests:
    - test_warm_up_delegates_to_generators: verify delegation works
    - test_warm_up_with_no_warm_up_method: handle missing warm_up gracefully
    - test_warm_up_mixed_generators: mix of generators with/without warm_up
  - All tests passing

  Part of issue #9907

* docs: Add release notes for warm_up() feature

---------

Co-authored-by: HamidOna13 <abdulhamid.onawole@aizatron.com>
1 parent deb5a95 commit 6c78f10

File tree

11 files changed

+593
-1
lines changed

11 files changed

+593
-1
lines changed

haystack/components/generators/chat/azure.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
deserialize_tools_or_toolset_inplace,
1919
flatten_tools_or_toolsets,
2020
serialize_tools_or_toolset,
21+
warm_up_tools,
2122
)
2223
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
2324
from haystack.utils.http_client import init_http_client
@@ -201,6 +202,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
201202
self.async_client = AsyncAzureOpenAI(
202203
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
203204
)
205+
self._is_warmed_up = False
206+
207+
def warm_up(self):
208+
"""
209+
Warm up the Azure OpenAI chat generator.
210+
211+
This will warm up the tools registered in the chat generator.
212+
This method is idempotent and will only warm up the tools once.
213+
"""
214+
if not self._is_warmed_up:
215+
warm_up_tools(self.tools)
216+
self._is_warmed_up = True
204217

205218
def to_dict(self) -> dict[str, Any]:
206219
"""

haystack/components/generators/chat/fallback.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,16 @@ def from_dict(cls, data: dict[str, Any]) -> FallbackChatGenerator:
8181
data["init_parameters"] = init_params
8282
return default_from_dict(cls, data)
8383

84+
def warm_up(self) -> None:
85+
"""
86+
Warm up all underlying chat generators.
87+
88+
This method calls warm_up() on each underlying generator that supports it.
89+
"""
90+
for gen in self.chat_generators:
91+
if hasattr(gen, "warm_up") and callable(gen.warm_up):
92+
gen.warm_up()
93+
8494
def _run_single_sync( # pylint: disable=too-many-positional-arguments
8595
self,
8696
gen: Any,

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
deserialize_tools_or_toolset_inplace,
2727
flatten_tools_or_toolsets,
2828
serialize_tools_or_toolset,
29+
warm_up_tools,
2930
)
3031
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
3132
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
@@ -384,6 +385,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
384385
model_or_url, token=token.resolve_value() if token else None, **resolved_api_params
385386
)
386387
self.tools = tools
388+
self._is_warmed_up = False
389+
390+
def warm_up(self):
391+
"""
392+
Warm up the Hugging Face API chat generator.
393+
394+
This will warm up the tools registered in the chat generator.
395+
This method is idempotent and will only warm up the tools once.
396+
"""
397+
if not self._is_warmed_up:
398+
warm_up_tools(self.tools)
399+
self._is_warmed_up = True
387400

388401
def to_dict(self) -> dict[str, Any]:
389402
"""

haystack/components/generators/chat/hugging_face_local.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
flatten_tools_or_toolsets,
2424
serialize_tools_or_toolset,
2525
)
26+
from haystack.tools.utils import warm_up_tools
2627
from haystack.utils import (
2728
ComponentDevice,
2829
Secret,
@@ -249,6 +250,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
249250
if async_executor is None
250251
else async_executor
251252
)
253+
self._is_warmed_up = False
252254

253255
def __del__(self) -> None:
254256
"""
@@ -274,11 +276,21 @@ def _get_telemetry_data(self) -> dict[str, Any]:
274276

275277
def warm_up(self) -> None:
276278
"""
277-
Initializes the component.
279+
Initializes the component and warms up tools if provided.
278280
"""
281+
if self._is_warmed_up:
282+
return
283+
284+
# Initialize the pipeline (existing logic)
279285
if self.pipeline is None:
280286
self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
281287

288+
# Warm up tools (new logic)
289+
if self.tools:
290+
warm_up_tools(self.tools)
291+
292+
self._is_warmed_up = True
293+
282294
def to_dict(self) -> dict[str, Any]:
283295
"""
284296
Serializes the component to a dictionary.

haystack/components/generators/chat/openai.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
deserialize_tools_or_toolset_inplace,
4242
flatten_tools_or_toolsets,
4343
serialize_tools_or_toolset,
44+
warm_up_tools,
4445
)
4546
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
4647
from haystack.utils.http_client import init_http_client
@@ -200,6 +201,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
200201
self.async_client = AsyncOpenAI(
201202
http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
202203
)
204+
self._is_warmed_up = False
205+
206+
def warm_up(self):
207+
"""
208+
Warm up the OpenAI chat generator.
209+
210+
This will warm up the tools registered in the chat generator.
211+
This method is idempotent and will only warm up the tools once.
212+
"""
213+
if not self._is_warmed_up:
214+
warm_up_tools(self.tools)
215+
self._is_warmed_up = True
203216

204217
def _get_telemetry_data(self) -> dict[str, Any]:
205218
"""
Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,9 @@
1+
---
2+
features:
3+
- |
4+
Added warm_up() method to all ChatGenerator components (OpenAIChatGenerator,
5+
AzureOpenAIChatGenerator, HuggingFaceAPIChatGenerator, HuggingFaceLocalChatGenerator,
6+
and FallbackChatGenerator) to properly initialize tools that require warm-up before
7+
pipeline execution. The warm_up() method is idempotent and follows the same pattern
8+
used in Agent and ToolInvoker components. This enables proper tool initialization
9+
in pipelines that use ChatGenerators with tools but without an Agent component.

test/components/generators/chat/test_azure.py

Lines changed: 104 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -428,6 +428,110 @@ def test_to_dict_with_toolset(self, tools, monkeypatch):
428428
}
429429
assert data["init_parameters"]["tools"] == expected_tools_data
430430

431+
def test_warm_up_with_tools(self, monkeypatch):
432+
"""Test that warm_up() calls warm_up on tools and is idempotent."""
433+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
434+
435+
# Create a mock tool that tracks if warm_up() was called
436+
class MockTool(Tool):
437+
warm_up_call_count = 0 # Class variable to track calls
438+
439+
def __init__(self):
440+
super().__init__(
441+
name="mock_tool",
442+
description="A mock tool for testing",
443+
parameters={"x": {"type": "string"}},
444+
function=lambda x: x,
445+
)
446+
447+
def warm_up(self):
448+
MockTool.warm_up_call_count += 1
449+
450+
# Reset the class variable before test
451+
MockTool.warm_up_call_count = 0
452+
mock_tool = MockTool()
453+
454+
# Create AzureOpenAIChatGenerator with the mock tool
455+
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=[mock_tool])
456+
457+
# Verify initial state - warm_up not called yet
458+
assert MockTool.warm_up_call_count == 0
459+
assert not component._is_warmed_up
460+
461+
# Call warm_up() on the generator
462+
component.warm_up()
463+
464+
# Assert that the tool's warm_up() was called
465+
assert MockTool.warm_up_call_count == 1
466+
assert component._is_warmed_up
467+
468+
# Call warm_up() again and verify it's idempotent (only warms up once)
469+
component.warm_up()
470+
471+
# The tool's warm_up should still only have been called once
472+
assert MockTool.warm_up_call_count == 1
473+
assert component._is_warmed_up
474+
475+
def test_warm_up_with_no_tools(self, monkeypatch):
476+
"""Test that warm_up() works when no tools are provided."""
477+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
478+
479+
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
480+
481+
# Verify initial state
482+
assert not component._is_warmed_up
483+
assert component.tools is None
484+
485+
# Call warm_up() - should not raise an error
486+
component.warm_up()
487+
488+
# Verify the component is warmed up
489+
assert component._is_warmed_up
490+
491+
# Call warm_up() again - should be idempotent
492+
component.warm_up()
493+
assert component._is_warmed_up
494+
495+
def test_warm_up_with_multiple_tools(self, monkeypatch):
496+
"""Test that warm_up() works with multiple tools."""
497+
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
498+
499+
# Track warm_up calls
500+
warm_up_calls = []
501+
502+
class MockTool(Tool):
503+
def __init__(self, tool_name):
504+
super().__init__(
505+
name=tool_name,
506+
description=f"Mock tool {tool_name}",
507+
parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]},
508+
function=lambda x: f"{tool_name} result: {x}",
509+
)
510+
511+
def warm_up(self):
512+
warm_up_calls.append(self.name)
513+
514+
mock_tool1 = MockTool("tool1")
515+
mock_tool2 = MockTool("tool2")
516+
517+
# Use a LIST of tools, not a Toolset
518+
component = AzureOpenAIChatGenerator(
519+
azure_endpoint="some-non-existing-endpoint", tools=[mock_tool1, mock_tool2]
520+
)
521+
522+
# Call warm_up()
523+
component.warm_up()
524+
525+
# Assert that both tools' warm_up() were called
526+
assert "tool1" in warm_up_calls
527+
assert "tool2" in warm_up_calls
528+
assert component._is_warmed_up
529+
530+
# Test idempotency - warm_up should not call tools again
531+
initial_count = len(warm_up_calls)
532+
component.warm_up()
533+
assert len(warm_up_calls) == initial_count
534+
431535

432536
class TestAzureOpenAIChatGeneratorAsync:
433537
def test_init_should_also_create_async_client_with_same_args(self, tools):

test/components/generators/chat/test_fallback.py

Lines changed: 68 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -354,3 +354,71 @@ async def test_failover_trigger_401_authentication_async():
354354
assert result["replies"][0].text == "success_after_auth"
355355
assert result["meta"]["successful_chat_generator_index"] == 1
356356
assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"]
357+
358+
359+
@component
360+
class _DummyGenWithWarmUp:
361+
"""Dummy generator that tracks warm_up calls."""
362+
363+
def __init__(self, text: str = "ok"):
364+
self.text = text
365+
self.warm_up_called = False
366+
367+
def warm_up(self) -> None:
368+
self.warm_up_called = True
369+
370+
def run(
371+
self,
372+
messages: list[ChatMessage],
373+
generation_kwargs: Optional[dict[str, Any]] = None,
374+
tools: Optional[ToolsType] = None,
375+
streaming_callback: Optional[StreamingCallbackT] = None,
376+
) -> dict[str, Any]:
377+
return {"replies": [ChatMessage.from_assistant(self.text)], "meta": {}}
378+
379+
380+
def test_warm_up_delegates_to_generators():
381+
"""Test that warm_up() is called on each underlying generator."""
382+
gen1 = _DummyGenWithWarmUp(text="A")
383+
gen2 = _DummyGenWithWarmUp(text="B")
384+
gen3 = _DummyGenWithWarmUp(text="C")
385+
386+
fallback = FallbackChatGenerator(chat_generators=[gen1, gen2, gen3])
387+
fallback.warm_up()
388+
389+
assert gen1.warm_up_called
390+
assert gen2.warm_up_called
391+
assert gen3.warm_up_called
392+
393+
394+
def test_warm_up_with_no_warm_up_method():
395+
"""Test that warm_up() handles generators without warm_up() gracefully."""
396+
gen1 = _DummySuccessGen(text="A")
397+
gen2 = _DummySuccessGen(text="B")
398+
399+
fallback = FallbackChatGenerator(chat_generators=[gen1, gen2])
400+
# Should not raise any error
401+
fallback.warm_up()
402+
403+
# Verify generators still work
404+
result = fallback.run([ChatMessage.from_user("test")])
405+
assert result["replies"][0].text == "A"
406+
407+
408+
def test_warm_up_mixed_generators():
409+
"""Test warm_up() with a mix of generators with and without warm_up()."""
410+
gen1 = _DummyGenWithWarmUp(text="A")
411+
gen2 = _DummySuccessGen(text="B")
412+
gen3 = _DummyGenWithWarmUp(text="C")
413+
gen4 = _DummyFailGen()
414+
415+
fallback = FallbackChatGenerator(chat_generators=[gen1, gen2, gen3, gen4])
416+
fallback.warm_up()
417+
418+
# Only generators with warm_up() should have been called
419+
assert gen1.warm_up_called
420+
assert gen3.warm_up_called
421+
422+
# Verify the fallback still works correctly
423+
result = fallback.run([ChatMessage.from_user("test")])
424+
assert result["replies"][0].text == "A"

0 commit comments

Comments
 (0)