Skip to content

Commit a305317

Browse files
committed
fix: review comments
1 parent b102f58 commit a305317

File tree

4 files changed

+68
-15
lines changed

4 files changed

+68
-15
lines changed

haystack/components/builders/prompt_builder.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from jinja2.sandbox import SandboxedEnvironment
99

1010
from haystack import component, default_to_dict, logging
11-
from haystack.utils import Jinja2TimeExtension
11+
from haystack.utils import Jinja2TimeExtension, extract_declared_variables
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -174,17 +174,9 @@ def __init__(
174174
self._env = SandboxedEnvironment()
175175

176176
self.template = self._env.from_string(template)
177-
if not variables:
178-
# infer variables from template
179-
ast = self._env.parse(template)
180-
template_variables = meta.find_undeclared_variables(ast)
181177

182-
assigned_variables = set()
183-
for node in ast.find_all((nodes.Assign, nodes.For)):
184-
if hasattr(node.target, "name"):
185-
assigned_variables.add(node.target.name)
178+
variables = extract_declared_variables(template, env=self._env)
186179

187-
variables = list(template_variables - assigned_variables)
188180
variables = variables or []
189181
self.variables = variables
190182

haystack/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"device": ["ComponentDevice", "Device", "DeviceMap", "DeviceType"],
1616
"deserialization": ["deserialize_document_store_in_init_params_inplace", "deserialize_chatgenerator_inplace"],
1717
"filters": ["document_matches_filter", "raise_on_invalid_filter_syntax"],
18+
"jinja2": ["extract_declared_variables"],
1819
"jinja2_extensions": ["Jinja2TimeExtension"],
1920
"jupyter": ["is_in_jupyter"],
2021
"misc": ["expit", "expand_page_range"],
@@ -40,6 +41,7 @@
4041
from .device import DeviceType as DeviceType
4142
from .filters import document_matches_filter as document_matches_filter
4243
from .filters import raise_on_invalid_filter_syntax as raise_on_invalid_filter_syntax
44+
from .jinja2 import extract_declared_variables as extract_declared_variables
4345
from .jinja2_extensions import Jinja2TimeExtension as Jinja2TimeExtension
4446
from .jupyter import is_in_jupyter as is_in_jupyter
4547
from .misc import expand_page_range as expand_page_range

haystack/utils/jinja2.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from jinja2 import Environment, meta, nodes
2+
3+
4+
def extract_declared_variables(template_str: str, env: Environment | None = None) -> list[str]:
5+
"""
6+
Extract declared variables from a Jinja2 template string.
7+
8+
Args:
9+
template_str (str): The Jinja2 template string to analyze.
10+
env (Environment, optional): The Jinja2 Environment. Defaults to None.
11+
12+
Returns:
13+
A list of variable names used in the template.
14+
"""
15+
env = env or Environment()
16+
17+
try:
18+
ast = env.parse(template_str)
19+
except Exception as e:
20+
raise RuntimeError(f"Failed to parse Jinja2 template: {e}")
21+
22+
# Find undeclared variables
23+
template_variables = meta.find_undeclared_variables(ast)
24+
25+
# Collect all variables assigned inside the template via {% set %}
26+
assigned_variables = set()
27+
28+
for node in ast.find_all(nodes.Assign):
29+
if isinstance(node.target, nodes.Name):
30+
assigned_variables.add(node.target.name)
31+
elif isinstance(node.target, (nodes.List, nodes.Tuple)):
32+
for name_node in node.target.items:
33+
if isinstance(name_node, nodes.Name):
34+
assigned_variables.add(name_node.name)
35+
36+
variables = list(template_variables - assigned_variables)
37+
return variables

test/components/builders/test_prompt_builder.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,32 @@ def test_template_assigned_variables_from_required_inputs(self) -> None:
352352
"""
353353

354354
builder = PromptBuilder(template=template, required_variables="*")
355-
docs = [Document(content="Doc 1"), Document(content="Doc 2")]
356355

357-
res = builder.run(docs=docs, existing_documents=None)
356+
builder = PromptBuilder(template=template, required_variables="*")
357+
assert set(builder.variables) == {"docs", "existing_documents"}
358+
359+
def test_variables_correct_with_tuple_assignment(self):
360+
template = """{% if existing_documents is not none %}
361+
{% set x, y = (existing_documents|length, 1) %}
362+
{% else %}
363+
{% set x, y = (0, 1) %}
364+
{% endif %}
365+
x={{ x }}, y={{ y }}
366+
"""
367+
builder = PromptBuilder(template=template, required_variables="*")
368+
assert set(builder.variables) == {"existing_documents"}
369+
res = builder.run(existing_documents=None)
370+
assert "x=0, y=1" in res["prompt"]
358371

359-
assert "<document reference=" in res["prompt"]
360-
assert "Doc 1" in res["prompt"]
361-
assert "Doc 2" in res["prompt"]
372+
def test_variables_correct_with_list_assignment(self):
373+
template = """{% if existing_documents is not none %}
374+
{% set x, y = [existing_documents|length, 1] %}
375+
{% else %}
376+
{% set x, y = [0, 1] %}
377+
{% endif %}
378+
x={{ x }}, y={{ y }}
379+
"""
380+
builder = PromptBuilder(template=template, required_variables="*")
381+
assert set(builder.variables) == {"existing_documents"}
382+
res = builder.run(existing_documents=None)
383+
assert "x=0, y=1" in res["prompt"]

0 commit comments

Comments
 (0)