Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent
from haystack.lazy_imports import LazyImport
from haystack.utils import Jinja2TimeExtension
from haystack.utils import Jinja2TimeExtension, extract_declared_variables
from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -171,21 +171,30 @@ def __init__(

extracted_variables = []
if template and not variables:

def _extract_from_text(
text: Optional[str], role: Optional[str] = None, is_filter_allowed: bool = False
) -> list:
if text is None:
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=role or "unknown", message=text))
if is_filter_allowed and "templatize_part" in text:
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)

ast = self._env.parse(text)
template_variables = meta.find_undeclared_variables(ast)
assigned_variables = extract_declared_variables(text, env=self._env)
return list(template_variables - assigned_variables)

if isinstance(template, list):
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infer variables from template
if message.text is None:
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message))
if message.text and "templatize_part" in message.text:
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(ast)
extracted_variables += list(template_variables)
extracted_variables += _extract_from_text(
message.text, role=message.role.value, is_filter_allowed=True
)
elif isinstance(template, str):
ast = self._env.parse(template)
extracted_variables = list(meta.find_undeclared_variables(ast))
extracted_variables = _extract_from_text(template, is_filter_allowed=False)

extracted_variables = extracted_variables or []
self.variables = variables or extracted_variables
self.required_variables = required_variables or []

Expand Down
9 changes: 6 additions & 3 deletions haystack/components/builders/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_to_dict, logging
from haystack.utils import Jinja2TimeExtension
from haystack.utils import Jinja2TimeExtension, extract_declared_variables

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -174,11 +174,14 @@ def __init__(
self._env = SandboxedEnvironment()

self.template = self._env.from_string(template)

if not variables:
# infer variables from template
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)
assigned_variables = extract_declared_variables(template, env=self._env)

variables = list(template_variables - assigned_variables)

variables = variables or []
self.variables = variables

Expand Down
2 changes: 2 additions & 0 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"device": ["ComponentDevice", "Device", "DeviceMap", "DeviceType"],
"deserialization": ["deserialize_document_store_in_init_params_inplace", "deserialize_chatgenerator_inplace"],
"filters": ["document_matches_filter", "raise_on_invalid_filter_syntax"],
"jinja2": ["extract_declared_variables"],
"jinja2_extensions": ["Jinja2TimeExtension"],
"jupyter": ["is_in_jupyter"],
"misc": ["expit", "expand_page_range"],
Expand All @@ -40,6 +41,7 @@
from .device import DeviceType as DeviceType
from .filters import document_matches_filter as document_matches_filter
from .filters import raise_on_invalid_filter_syntax as raise_on_invalid_filter_syntax
from .jinja2 import extract_declared_variables as extract_declared_variables
from .jinja2_extensions import Jinja2TimeExtension as Jinja2TimeExtension
from .jupyter import is_in_jupyter as is_in_jupyter
from .misc import expand_page_range as expand_page_range
Expand Down
39 changes: 39 additions & 0 deletions haystack/utils/jinja2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

from jinja2 import Environment, nodes


def extract_declared_variables(template_str: str, env: Optional[Environment] = None) -> set:
"""
Extract declared variables from a Jinja2 template string.

Args:
template_str (str): The Jinja2 template string to analyze.
env (Environment, optional): The Jinja2 Environment. Defaults to None.

Returns:
A list of variable names used in the template.
"""
env = env or Environment()

try:
ast = env.parse(template_str)
except Exception as e:
raise RuntimeError(f"Failed to parse Jinja2 template: {e}")

# Collect all variables assigned inside the template via {% set %}
assigned_variables = set()

for node in ast.find_all(nodes.Assign):
if isinstance(node.target, nodes.Name):
assigned_variables.add(node.target.name)
elif isinstance(node.target, (nodes.List, nodes.Tuple)):
for name_node in node.target.items:
if isinstance(name_node, nodes.Name):
assigned_variables.add(name_node.name)

return assigned_variables
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
fixes:
- |
Fixed an issue where Jinja2 variable assignments using the `set` directive
were not being parsed correctly in certain contexts. This fix ensures that
variables assigned with `{% set var = value %}` are now properly recognized
and can be used as expected within templates inside `PromptBuilder` and
`ChatPromptBuilder`.
28 changes: 28 additions & 0 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,3 +957,31 @@ def test_from_dict(self):
assert builder.template == template
assert builder.variables == ["name", "assistant_name"]
assert builder.required_variables == ["name"]

def test_variables_correct_with_tuple_assignment(self):
template = """{% if existing_documents is not none %}
{% set x, y = (existing_documents|length, 1) %}
{% else %}
{% set x, y = (0, 1) %}
{% endif %}
{% message role="user" %}x={{ x }}, y={{ y }}{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert set(builder.variables) == {"existing_documents"}
res = builder.run(existing_documents=None)
prompt = res["prompt"]
assert any("x=0, y=1" in msg.text for msg in prompt)

def test_variables_correct_with_list_assignment(self):
template = """{% if existing_documents is not none %}
{% set x, y = [existing_documents|length, 1] %}
{% else %}
{% set x, y = [0, 1] %}
{% endif %}
{% message role="user" %}x={{ x }}, y={{ y }}{% endmessage %}
"""
builder = ChatPromptBuilder(template=template, required_variables="*")
assert set(builder.variables) == {"existing_documents"}
res = builder.run(existing_documents=None)
prompt = res["prompt"]
assert any("x=0, y=1" in msg.text for msg in prompt)
44 changes: 44 additions & 0 deletions test/components/builders/test_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,47 @@ def test_warning_no_required_variables(self, caplog):
with caplog.at_level(logging.WARNING):
_ = PromptBuilder(template="This is a {{ variable }}")
assert "but `required_variables` is not set." in caplog.text

def test_template_assigned_variables_from_required_inputs(self) -> None:
template = """{% if existing_documents is not none %}
{% set existing_doc_len = existing_documents|length %}
{% else %}
{% set existing_doc_len = 0 %}
{% endif %}
{% for doc in docs %}
<document reference="{{loop.index + existing_doc_len}}">
{{ doc.content }}
</document>
{% endfor %}
"""

builder = PromptBuilder(template=template, required_variables="*")

builder = PromptBuilder(template=template, required_variables="*")
assert set(builder.variables) == {"docs", "existing_documents"}

def test_variables_correct_with_tuple_assignment(self):
template = """{% if existing_documents is not none %}
{% set x, y = (existing_documents|length, 1) %}
{% else %}
{% set x, y = (0, 1) %}
{% endif %}
x={{ x }}, y={{ y }}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert set(builder.variables) == {"existing_documents"}
res = builder.run(existing_documents=None)
assert "x=0, y=1" in res["prompt"]

def test_variables_correct_with_list_assignment(self):
template = """{% if existing_documents is not none %}
{% set x, y = [existing_documents|length, 1] %}
{% else %}
{% set x, y = [0, 1] %}
{% endif %}
x={{ x }}, y={{ y }}
"""
builder = PromptBuilder(template=template, required_variables="*")
assert set(builder.variables) == {"existing_documents"}
res = builder.run(existing_documents=None)
assert "x=0, y=1" in res["prompt"]
Loading