Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix issue that could lead to RCE if using unsecure Jinja templates #8095

Merged
merged 4 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from typing import Any, Dict, List, Optional, Set

from jinja2 import Template, meta
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
Expand Down Expand Up @@ -123,12 +124,12 @@ def __init__(
self.required_variables = required_variables or []
self.template = template
variables = variables or []
self._env = SandboxedEnvironment()
if template and not variables:
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infere variables from template
msg_template = Template(message.content)
ast = msg_template.environment.parse(message.content)
ast = self._env.parse(message.content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)

Expand Down Expand Up @@ -194,7 +195,8 @@ def run(
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
self._validate_variables(set(template_variables_combined.keys()))
compiled_template = Template(message.content)

compiled_template = self._env.from_string(message.content)
rendered_content = compiled_template.render(template_variables_combined)
rendered_message = (
ChatMessage.from_user(rendered_content)
Expand Down
13 changes: 8 additions & 5 deletions haystack/components/builders/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from typing import Any, Dict, List, Optional, Set

from jinja2 import Template, meta
from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_to_dict

Expand Down Expand Up @@ -158,10 +159,12 @@ def __init__(
self._variables = variables
self._required_variables = required_variables
self.required_variables = required_variables or []
self.template = Template(template)

self._env = SandboxedEnvironment()
self.template = self._env.from_string(template)
if not variables:
# infere variables from template
ast = self.template.environment.parse(template)
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)

Expand Down Expand Up @@ -216,8 +219,8 @@ def run(self, template: Optional[str] = None, template_variables: Optional[Dict[
self._validate_variables(set(template_variables_combined.keys()))

compiled_template = self.template
if isinstance(template, str):
compiled_template = Template(template)
if template is not None:
compiled_template = self._env.from_string(template)

result = compiled_template.render(template_variables_combined)
return {"prompt": result}
Expand Down
30 changes: 18 additions & 12 deletions haystack/components/converters/output_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

import ast
import contextlib
from typing import Any, Callable, Dict, Optional, Set

import jinja2.runtime
from jinja2 import TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment
from typing_extensions import TypeAlias

from haystack import component, default_from_dict, default_to_dict
Expand Down Expand Up @@ -58,18 +60,18 @@ def __init__(self, template: str, output_type: TypeAlias, custom_filters: Option

# Create a Jinja native environment, we need it to:
# a) add custom filters to the environment for filter compilation stage
env = NativeEnvironment()
self._env = SandboxedEnvironment(undefined=jinja2.runtime.StrictUndefined)
try:
env.parse(template) # Validate template syntax
self._env.parse(template) # Validate template syntax
self.template = template
except TemplateSyntaxError as e:
raise ValueError(f"Invalid Jinja template '{template}': {e}") from e

for name, filter_func in self.custom_filters.items():
env.filters[name] = filter_func
self._env.filters[name] = filter_func

# b) extract variables in the template
route_input_names = self._extract_variables(env)
route_input_names = self._extract_variables(self._env)
input_types.update(route_input_names)

# the env is not needed, discarded automatically
Expand All @@ -92,16 +94,22 @@ def run(self, **kwargs):
# check if kwargs are empty
if not kwargs:
raise ValueError("No input data provided for output adaptation")
env = NativeEnvironment()
for name, filter_func in self.custom_filters.items():
env.filters[name] = filter_func
self._env.filters[name] = filter_func
adapted_outputs = {}
try:
adapted_output_template = env.from_string(self.template)
adapted_output_template = self._env.from_string(self.template)
output_result = adapted_output_template.render(**kwargs)
if isinstance(output_result, jinja2.runtime.Undefined):
raise OutputAdaptationException(f"Undefined variable in the template {self.template}; kwargs: {kwargs}")

# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
output_result = ast.literal_eval(output_result)

adapted_outputs["output"] = output_result
except Exception as e:
raise OutputAdaptationException(f"Error adapting {self.template} with {kwargs}: {e}") from e
Expand Down Expand Up @@ -135,14 +143,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "OutputAdapter":
init_params["custom_filters"][name] = deserialize_callable(filter_func) if filter_func else None
return default_from_dict(cls, data)

def _extract_variables(self, env: NativeEnvironment) -> Set[str]:
def _extract_variables(self, env: SandboxedEnvironment) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.

:param env: A Jinja native environment.
:return: A set of variable names extracted from the template strings.
"""
variables = set()
ast = env.parse(self.template)
variables.update(meta.find_undeclared_variables(ast))
return variables
return meta.find_undeclared_variables(ast)
27 changes: 17 additions & 10 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

import ast
import contextlib
from typing import Any, Callable, Dict, List, Optional, Set

from jinja2 import Environment, TemplateSyntaxError, meta
from jinja2.nativetypes import NativeEnvironment
from jinja2.sandbox import SandboxedEnvironment

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
Expand Down Expand Up @@ -125,16 +128,16 @@ def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callab
self.custom_filters = custom_filters or {}

# Create a Jinja native environment to inspect variables in the condition templates
env = NativeEnvironment()
env.filters.update(self.custom_filters)
self._env = SandboxedEnvironment()
self._env.filters.update(self.custom_filters)

# Inspect the routes to determine input and output types.
input_types: Set[str] = set() # let's just store the name, type will always be Any
output_types: Dict[str, str] = {}

for route in routes:
# extract inputs
route_input_names = self._extract_variables(env, [route["output"], route["condition"]])
route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]])
input_types.update(route_input_names)

# extract outputs
Expand Down Expand Up @@ -194,16 +197,20 @@ def run(self, **kwargs):
routes.
"""
# Create a Jinja native environment to evaluate the condition templates as Python expressions
env = NativeEnvironment()
env.filters.update(self.custom_filters)

for route in self.routes:
try:
t = env.from_string(route["condition"])
if t.render(**kwargs):
t = self._env.from_string(route["condition"])
rendered = t.render(**kwargs)
if ast.literal_eval(rendered):
# We now evaluate the `output` expression to determine the route output
t_output = env.from_string(route["output"])
t_output = self._env.from_string(route["output"])
output = t_output.render(**kwargs)
# We suppress the exception in case the output is already a string, otherwise
# we try to evaluate it and would fail.
# This must be done cause the output could be different literal structures.
# This doesn't support any user types.
with contextlib.suppress(Exception):
output = ast.literal_eval(output)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
except Exception as e:
Expand Down Expand Up @@ -234,7 +241,7 @@ def _validate_routes(self, routes: List[Dict]):
if not self._validate_template(env, route[field]):
raise ValueError(f"Invalid template for field '{field}': {route[field]}")

def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]:
def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.

Expand Down
5 changes: 3 additions & 2 deletions haystack/core/pipeline/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

from jinja2 import Environment, PackageLoader, TemplateSyntaxError, meta
from jinja2 import PackageLoader, TemplateSyntaxError, meta
from jinja2.sandbox import SandboxedEnvironment

TEMPLATE_FILE_EXTENSION = ".yaml.jinja2"
TEMPLATE_HOME_DIR = Path(__file__).resolve().parent / "predefined"
Expand Down Expand Up @@ -74,7 +75,7 @@ def __init__(self, template_content: str):

:param template_content: The raw template source to use in the template.
"""
env = Environment(
env = SandboxedEnvironment(
loader=PackageLoader("haystack.core.pipeline", "predefined"), trim_blocks=True, lstrip_blocks=True
)
try:
Expand Down
14 changes: 14 additions & 0 deletions releasenotes/notes/fix-jinja-env-81c98225b22dc827.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
upgrade:
- |
`OutputAdapter` and `ConditionalRouter` can't return users inputs anymore.
security:
- |
Fix issue that could lead to remote code execution when using insecure Jinja template in the following Components:

- `PromptBuilder`
- `ChatPromptBuilder`
- `OutputAdapter`
- `ConditionalRouter`

The same issue has been fixed in the `PipelineTemplate` class too.
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
64 changes: 17 additions & 47 deletions test/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,23 @@ def test_router_initialized(self, routes):
assert set(router.__haystack_input__._sockets_dict.keys()) == {"query", "streams"}
assert set(router.__haystack_output__._sockets_dict.keys()) == {"query", "streams"}

def test_router_evaluate_condition_expressions(self, router):
def test_router_evaluate_condition_expressions(self):
router = ConditionalRouter(
[
{
"condition": "{{streams|length < 2}}",
"output": "{{query}}",
"output_type": str,
"output_name": "query",
},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
)
# first route should be selected
kwargs = {"streams": [1, 2, 3], "query": "test"}
result = router.run(**kwargs)
Expand Down Expand Up @@ -227,52 +243,6 @@ def test_router_de_serialization(self):
# check that the result is the same and correct
assert result1 == result2 and result1 == {"streams": [1, 2, 3]}

def test_router_de_serialization_user_type(self):
routes = [
{
"condition": "{{streams|length < 2}}",
"output": "{{message}}",
"output_type": ChatMessage,
"output_name": "message",
},
{
"condition": "{{streams|length >= 2}}",
"output": "{{streams}}",
"output_type": List[int],
"output_name": "streams",
},
]
router = ConditionalRouter(routes)
router_dict = router.to_dict()

# assert that the router dict is correct, with all keys and values being strings
for route in router_dict["init_parameters"]["routes"]:
for key in route.keys():
assert isinstance(key, str)
assert isinstance(route[key], str)

# check that the output_type is a string and a proper class name
assert (
router_dict["init_parameters"]["routes"][0]["output_type"]
== "haystack.dataclasses.chat_message.ChatMessage"
)

# deserialize the router
new_router = ConditionalRouter.from_dict(router_dict)

# check that the output_type is the right class
assert new_router.routes[0]["output_type"] == ChatMessage
assert router.routes == new_router.routes

# now use both routers to run the same message
message = ChatMessage.from_user("ciao")
kwargs = {"streams": [1], "message": message}
result1 = router.run(**kwargs)
result2 = new_router.run(**kwargs)

# check that the result is the same and correct
assert result1 == result2 and result1["message"].content == message.content

def test_router_serialization_idempotence(self):
routes = [
{
Expand Down