diff --git a/haystack/components/converters/output_adapter.py b/haystack/components/converters/output_adapter.py index 0e1682c6be..50ffd53919 100644 --- a/haystack/components/converters/output_adapter.py +++ b/haystack/components/converters/output_adapter.py @@ -4,11 +4,11 @@ import ast import contextlib -from typing import Any, Callable, Dict, Optional, Set, Union +from typing import Any, Callable, Dict, Optional, Set from warnings import warn import jinja2.runtime -from jinja2 import TemplateSyntaxError, meta +from jinja2 import Environment, TemplateSyntaxError, meta from jinja2.nativetypes import NativeEnvironment from jinja2.sandbox import SandboxedEnvironment from typing_extensions import TypeAlias @@ -77,9 +77,9 @@ def __init__( "Use this only if you trust the source of the template." ) warn(msg) - self._env = NativeEnvironment() - else: - self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined) + self._env = ( + NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined) + ) try: self._env.parse(template) # Validate template syntax @@ -173,7 +173,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OutputAdapter": } return default_from_dict(cls, data) - def _extract_variables(self, env: Union[NativeEnvironment, SandboxedEnvironment]) -> Set[str]: + def _extract_variables(self, env: Environment) -> Set[str]: """ Extracts all variables from a list of Jinja template strings. diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index 3cd0feb675..ccac555d3d 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -139,9 +139,8 @@ def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callab "Use this only if you trust the source of the template." ) warn(msg) - self._env = NativeEnvironment() - else: - self._env = SandboxedEnvironment() + + self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment() self._env.filters.update(self.custom_filters) self._validate_routes(routes)