Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add current date in UTC to PromptBuilder #8233

Merged
merged 16 commits into from
Sep 9, 2024
Merged
5 changes: 3 additions & 2 deletions haystack/components/builders/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_to_dict
from haystack.utils import Jinja2TimeExtension


@component
Expand Down Expand Up @@ -160,10 +161,10 @@ def __init__(
self._required_variables = required_variables
self.required_variables = required_variables or []

self._env = SandboxedEnvironment()
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
self.template = self._env.from_string(template)
if not variables:
# infere variables from template
# infer variables from template
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)
Expand Down
2 changes: 2 additions & 0 deletions haystack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .docstore_deserialization import deserialize_document_store_in_init_params_inplace
from .expit import expit
from .filters import document_matches_filter
from .jinja2_extensions import Jinja2TimeExtension
from .jupyter import is_in_jupyter
from .requests_utils import request_with_retry
from .type_serialization import deserialize_type, serialize_type
Expand All @@ -28,4 +29,5 @@
"serialize_type",
"deserialize_type",
"deserialize_document_store_in_init_params_inplace",
"Jinja2TimeExtension",
]
91 changes: 91 additions & 0 deletions haystack/utils/jinja2_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, List, Optional, Union

import arrow
from jinja2 import Environment, nodes
from jinja2.ext import Extension


class Jinja2TimeExtension(Extension):
# Syntax for current date
tags = {"now"}

def __init__(self, environment: Environment): # pylint: disable=useless-parent-delegation
"""
Initializes the Jinja2TimeExtension object.

:param environment: The Jinja2 environment to initialize the extension with.
It provides the context where the extension will operate.
"""
super().__init__(environment)

@staticmethod
def _get_datetime(
    timezone: str,
    operator: Optional[str] = None,
    offset: Optional[str] = None,
    datetime_format: Optional[str] = None,
) -> str:
    """
    Get the current datetime based on timezone, apply any offset if provided, and format the result.

    :param timezone: The timezone string (e.g., 'UTC' or 'America/New_York') for which the current
        time should be fetched.
    :param operator: The operator ('+' or '-') to apply to the offset (used for adding/subtracting intervals).
        Ignored when no offset is given; defaults to '+' when an offset is given without an operator.
    :param offset: The offset string in the format 'interval=value' (e.g., 'hours=2,days=1') specifying how much
        to adjust the datetime. The intervals can be any valid interval accepted
        by Arrow (e.g., hours, days, weeks, months). Defaults to None if no adjustment is needed.
    :param datetime_format: The format string to use for formatting the output datetime.
        Defaults to '%Y-%m-%d %H:%M:%S' if not provided.
    :returns: The (optionally shifted) current datetime rendered with `datetime_format`.
    :raises ValueError: If the timezone cannot be resolved, or if the offset/operator cannot be parsed
        or applied.
    """
    try:
        dt = arrow.now(timezone)
    except Exception as e:
        # Keep the original exception as the cause so the underlying parser error stays visible
        raise ValueError(f"Invalid timezone {timezone}: {e}") from e

    if offset:
        # An offset without an explicit operator is treated as an addition
        sign = operator or "+"
        try:
            # Reject anything but a bare sign: a value such as "1" would otherwise be glued onto
            # the number ("1" + "2" -> 12.0) and silently shift by the wrong amount
            if sign not in ("+", "-"):
                raise ValueError(f"operator must be '+' or '-', got {sign!r}")

            # Parse each 'interval=value' pair into a signed keyword argument for Arrow
            shift_kwargs = {}
            for part in offset.split(","):
                interval, value = part.split("=")
                shift_kwargs[interval.strip()] = float(sign + value.strip())

            # Shift the datetime fields based on the parsed offset
            dt = dt.shift(**shift_kwargs)
        except (ValueError, AttributeError) as e:
            raise ValueError(f"Invalid offset or operator {offset}, {operator}: {e}") from e

    # Use the provided format or fallback to the default one
    datetime_format = datetime_format or "%Y-%m-%d %H:%M:%S"

    return dt.strftime(datetime_format)

def parse(self, parser: Any) -> Union[nodes.Node, List[nodes.Node]]:
"""
Parse the template expression to determine how to handle the datetime formatting.

:param parser: The parser object that processes the template expressions and manages the syntax tree.
It's used to interpret the template's structure.
"""
lineno = next(parser.stream).lineno
node = parser.parse_expression()
# Check if a custom datetime format is provided after a comma
datetime_format = parser.parse_expression() if parser.stream.skip_if("comma") else nodes.Const(None)

# Sign for the offset: "+" for an Add node, "-" otherwise (only used below when node is Add or Sub)
operator = "+" if isinstance(node, nodes.Add) else "-"
Comment on lines +80 to +81
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this read "# Default Minus when ..." since the else statement is "-"?

# Call the _get_datetime method with the appropriate operator and offset, if exist
call_method = self.call_method(
"_get_datetime",
[node.left, nodes.Const(operator), node.right, datetime_format]
if isinstance(node, (nodes.Add, nodes.Sub))
else [node, nodes.Const(None), nodes.Const(None), datetime_format],
lineno=lineno,
)

return nodes.Output([call_method], lineno=lineno)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies = [
"numpy<2",
"python-dateutil",
"haystack-experimental",
"arrow>=1.3.0" # Jinja2TimeExtension
]

[tool.hatch.envs.default]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add the current date inside a template in `PromptBuilder` using the `{% now 'UTC' %}` tag.
Users can also specify the date format, such as `{% now 'UTC', '%Y-%m-%d' %}`.
59 changes: 58 additions & 1 deletion test/components/builders/test_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Any, Dict, List, Optional
from jinja2 import TemplateSyntaxError
import pytest

from unittest.mock import patch, MagicMock
import arrow
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack import component
from haystack.core.pipeline.pipeline import Pipeline
Expand Down Expand Up @@ -254,3 +255,59 @@ def test_example_in_pipeline_simple(self):
"prompt_builder": {"prompt": "This is the dynamic prompt:\n Query: Where does the speaker live?"}
}
assert result == expected_dynamic

def test_with_custom_dateformat(self) -> None:
    """The second argument of the `now` tag selects the strftime format of the rendered date."""
    # Freeze the clock: rendering and the expectation would otherwise read the current time
    # at two different instants and could disagree when the date rolls over in between.
    frozen = arrow.get("2024-09-09T12:34:56+00:00")
    template = "Formatted date: {% now 'UTC', '%Y-%m-%d' %}"
    builder = PromptBuilder(template=template)

    with patch("arrow.now", side_effect=lambda tz=None: frozen.to(tz or "local")):
        result = builder.run()["prompt"]

    now_formatted = f"Formatted date: {frozen.to('UTC').strftime('%Y-%m-%d')}"

    assert now_formatted == result

def test_with_different_timezone(self) -> None:
    """The `now` tag renders the current time in the requested timezone."""
    # Freeze the clock: rendering and the expectation would otherwise read the current time
    # at two different instants and could disagree when a second ticks over in between.
    frozen = arrow.get("2024-09-09T12:34:56+00:00")
    template = "Current time in New York is: {% now 'America/New_York' %}"
    builder = PromptBuilder(template=template)

    with patch("arrow.now", side_effect=lambda tz=None: frozen.to(tz or "local")):
        result = builder.run()["prompt"]

    now_ny = f"Current time in New York is: {frozen.to('America/New_York').strftime('%Y-%m-%d %H:%M:%S')}"

    assert now_ny == result

def test_date_with_addition_offset(self) -> None:
    """`now 'UTC' + 'hours=2'` renders the current UTC time shifted two hours forward."""
    # Freeze the clock: rendering and the expectation would otherwise read the current time
    # at two different instants and could disagree when a second ticks over in between.
    frozen = arrow.get("2024-09-09T12:34:56+00:00")
    template = "Time after 2 hours is: {% now 'UTC' + 'hours=2' %}"
    builder = PromptBuilder(template=template)

    with patch("arrow.now", side_effect=lambda tz=None: frozen.to(tz or "local")):
        result = builder.run()["prompt"]

    now_plus_2 = f"Time after 2 hours is: {frozen.to('UTC').shift(hours=+2).strftime('%Y-%m-%d %H:%M:%S')}"

    assert now_plus_2 == result

def test_date_with_substraction_offset(self) -> None:
    """`now 'UTC' - 'days=12'` renders the current UTC time shifted twelve days back."""
    # Freeze the clock: rendering and the expectation would otherwise read the current time
    # at two different instants and could disagree when a second ticks over in between.
    frozen = arrow.get("2024-09-09T12:34:56+00:00")
    template = "Time after 12 days is: {% now 'UTC' - 'days=12' %}"
    builder = PromptBuilder(template=template)

    with patch("arrow.now", side_effect=lambda tz=None: frozen.to(tz or "local")):
        result = builder.run()["prompt"]

    # Named for what it holds: the instant twelve days *before* the frozen clock
    now_minus_12_days = (
        f"Time after 12 days is: {frozen.to('UTC').shift(days=-12).strftime('%Y-%m-%d %H:%M:%S')}"
    )

    assert now_minus_12_days == result

def test_invalid_timezone(self) -> None:
    """Rendering a `now` tag with an unknown timezone name raises a ValueError."""
    prompt_builder = PromptBuilder(template="Current time is: {% now 'Invalid/Timezone' %}")

    # The error surfaces at render time, not when the builder is constructed
    with pytest.raises(ValueError, match="Invalid timezone"):
        prompt_builder.run()

def test_invalid_offset(self) -> None:
    """Rendering a `now` tag whose offset is not 'interval=value' raises a ValueError."""
    prompt_builder = PromptBuilder(template="Time after invalid offset is: {% now 'UTC' + 'invalid_offset' %}")

    # The error surfaces at render time, not when the builder is constructed
    with pytest.raises(ValueError, match="Invalid offset or operator"):
        prompt_builder.run()
104 changes: 104 additions & 0 deletions test/utils/test_jinja2_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from jinja2 import Environment
import arrow
from haystack.utils import Jinja2TimeExtension


class TestJinja2TimeExtension:
    """Tests for the `now` template tag and its `_get_datetime` helper, run against a frozen clock."""

    @pytest.fixture(autouse=True)
    def frozen_now(self, monkeypatch: pytest.MonkeyPatch) -> arrow.Arrow:
        """
        Pin `arrow.now` to a fixed instant for every test in this class.

        Most tests obtain the value under test and the expected value from two separate
        `arrow.now` calls and compare them at one-second resolution; without a frozen clock
        they fail whenever a second ticks over between the two calls.
        """
        instant = arrow.get("2024-09-09T12:34:56+00:00")

        def fake_now(tz=None):
            # Timezone strings go through Arrow's own conversion, so unknown names still raise
            return instant.to("local" if tz is None else tz)

        monkeypatch.setattr(arrow, "now", fake_now)
        return instant

    @pytest.fixture
    def jinja_env(self) -> Environment:
        return Environment(extensions=[Jinja2TimeExtension])

    @pytest.fixture
    def jinja_extension(self, jinja_env: Environment) -> Jinja2TimeExtension:
        return Jinja2TimeExtension(jinja_env)

    def test_valid_datetime(self, jinja_extension: Jinja2TimeExtension) -> None:
        result = jinja_extension._get_datetime(
            "UTC", operator="+", offset="hours=2", datetime_format="%Y-%m-%d %H:%M:%S"
        )
        assert isinstance(result, str)
        assert len(result) == 19

    def test_parse_valid_expression(self, jinja_env: Environment) -> None:
        template = "{% now 'UTC' + 'hours=2', '%Y-%m-%d %H:%M:%S' %}"
        result = jinja_env.from_string(template).render()
        assert isinstance(result, str)
        assert len(result) == 19

    def test_get_datetime_no_offset(self, jinja_extension: Jinja2TimeExtension) -> None:
        result = jinja_extension._get_datetime("UTC")
        expected = arrow.now("UTC").strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_get_datetime_with_offset_add(self, jinja_extension: Jinja2TimeExtension) -> None:
        result = jinja_extension._get_datetime("UTC", operator="+", offset="hours=1")
        expected = arrow.now("UTC").shift(hours=1).strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_get_datetime_with_offset_subtract(self, jinja_extension: Jinja2TimeExtension) -> None:
        result = jinja_extension._get_datetime("UTC", operator="-", offset="days=1")
        expected = arrow.now("UTC").shift(days=-1).strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_get_datetime_with_offset_subtract_days_hours(self, jinja_extension: Jinja2TimeExtension) -> None:
        result = jinja_extension._get_datetime("UTC", operator="-", offset="days=1, hours=2")
        expected = arrow.now("UTC").shift(days=-1, hours=-2).strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_get_datetime_with_custom_format(self, jinja_extension: Jinja2TimeExtension) -> None:
        result = jinja_extension._get_datetime("UTC", datetime_format="%d-%m-%Y")
        expected = arrow.now("UTC").strftime("%d-%m-%Y")
        assert result == expected

    def test_get_datetime_new_york_timezone(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'America/New_York' %}")
        result = template.render()
        expected = arrow.now("America/New_York").strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_parse_no_operator(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC' %}")
        result = template.render()
        expected = arrow.now("UTC").strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_parse_with_add(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC' + 'hours=2' %}")
        result = template.render()
        expected = arrow.now("UTC").shift(hours=2).strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_parse_with_subtract(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC' - 'days=1' %}")
        result = template.render()
        expected = arrow.now("UTC").shift(days=-1).strftime("%Y-%m-%d %H:%M:%S")
        assert result == expected

    def test_parse_with_custom_format(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC', '%d-%m-%Y' %}")
        result = template.render()
        expected = arrow.now("UTC").strftime("%d-%m-%Y")
        assert result == expected

    def test_default_format(self, jinja_env: Environment) -> None:
        template = jinja_env.from_string("{% now 'UTC'%}")
        result = template.render()
        expected = arrow.now("UTC").strftime("%Y-%m-%d %H:%M:%S")  # default format
        assert result == expected

    def test_invalid_timezone(self, jinja_extension: Jinja2TimeExtension) -> None:
        with pytest.raises(ValueError, match="Invalid timezone"):
            jinja_extension._get_datetime("Invalid/Timezone")

    def test_invalid_offset(self, jinja_extension: Jinja2TimeExtension) -> None:
        with pytest.raises(ValueError, match="Invalid offset or operator"):
            jinja_extension._get_datetime("UTC", operator="+", offset="invalid_format")

    def test_invalid_operator(self, jinja_extension: Jinja2TimeExtension) -> None:
        with pytest.raises(ValueError, match="Invalid offset or operator"):
            jinja_extension._get_datetime("UTC", operator="*", offset="hours=2")