Skip to content

Commit

Permalink
Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Aug 9, 2024
1 parent ab64c34 commit ea23454
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
12 changes: 6 additions & 6 deletions haystack/components/converters/output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

import ast
import contextlib
from typing import Any, Callable, Dict, Optional, Set, Union
from typing import Any, Callable, Dict, Optional, Set
from warnings import warn

import jinja2.runtime
from jinja2 import TemplateSyntaxError, meta
from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -77,9 +77,9 @@ def __init__(
"Use this only if you trust the source of the template."
)
warn(msg)
self._env = NativeEnvironment()
else:
self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
self._env = (
NativeEnvironment() if self._unsafe else SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
)

try:
self._env.parse(template) # Validate template syntax
Expand Down Expand Up @@ -173,7 +173,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OutputAdapter":
}
return default_from_dict(cls, data)

def _extract_variables(self, env: Union[NativeEnvironment, SandboxedEnvironment]) -> Set[str]:
def _extract_variables(self, env: Environment) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.
Expand Down
5 changes: 2 additions & 3 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callab
"Use this only if you trust the source of the template."
)
warn(msg)
self._env = NativeEnvironment()
else:
self._env = SandboxedEnvironment()

self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment()
self._env.filters.update(self.custom_filters)

self._validate_routes(routes)
Expand Down

0 comments on commit ea23454

Please sign in to comment.