Skip to content

Commit f305ca9

Browse files
committed
Fix: Handle unescaped quotes in generate_value using safe_eval
1 parent b02904c commit f305ca9

File tree

3 files changed

+77
-2
lines changed

3 files changed

+77
-2
lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@
6464
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
6565
from nemoguardrails.rails.llm.options import GenerationOptions
6666
from nemoguardrails.streaming import StreamingHandler
67-
from nemoguardrails.utils import get_or_create_event_loop, new_event_dict, new_uuid
67+
from nemoguardrails.utils import (
68+
get_or_create_event_loop,
69+
new_event_dict,
70+
new_uuid,
71+
safe_eval,
72+
)
6873

6974
log = logging.getLogger(__name__)
7075

@@ -1039,7 +1044,11 @@ async def generate_value(
10391044

10401045
log.info(f"Generated value for ${var_name}: {value}")
10411046

1042-
return literal_eval(value)
1047+
try:
1048+
return safe_eval(value)
1049+
except Exception as e:
1050+
log.error(f"Error evaluating value: {value}. Error: {str(e)}")
1051+
raise ValueError(f"Invalid LLM response: `{value}`")
10431052

10441053
@action(is_system_action=True)
10451054
async def generate_intent_steps_message(

nemoguardrails/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import random
2222
import re
2323
import uuid
24+
from ast import literal_eval
2425
from collections import namedtuple
2526
from datetime import datetime, timezone
2627
from enum import Enum
@@ -384,3 +385,26 @@ def is_ignored_by_railsignore(filename: str, ignore_patterns: str) -> bool:
384385
break
385386

386387
return ignore
388+
389+
390+
def safe_eval(input_value: str) -> str:
391+
"""
392+
Safely evaluate a string to handle unescaped quotes or invalid syntax from the async generate_value action.
393+
394+
Args:
395+
input_value (str): The input string to evaluate.
396+
397+
Returns:
398+
str: The evaluated and properly formatted string.
399+
400+
Raises:
401+
ValueError: If the input cannot be safely evaluated.
402+
"""
403+
if input_value.startswith(("'", '"')) and input_value.endswith(("'", '"')):
404+
try:
405+
return literal_eval(input_value)
406+
except (ValueError, SyntaxError):
407+
pass
408+
escaped_value = input_value.replace("'", "\\'").replace('"', '\\"')
409+
input_value = f"'{escaped_value}'"
410+
return literal_eval(input_value)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import pytest
16+
17+
from nemoguardrails.utils import safe_eval
18+
19+
20+
@pytest.mark.parametrize(
21+
"input_value, expected_result",
22+
[
23+
('"It\'s a sunny day"', "It's a sunny day"), # double quotes with single quote
24+
(
25+
"\"He said, 'Hello'\"",
26+
"He said, 'Hello'",
27+
), # double quotes with nested single quote
28+
(
29+
"It's a sunny day",
30+
"It's a sunny day",
31+
), # unquoted string containing single quote
32+
(
33+
"It is a sunny day",
34+
"It is a sunny day",
35+
), # plain string not wrapped in quotes
36+
("", ""), # empty string
37+
],
38+
)
39+
def test_safe_eval(input_value, expected_result):
40+
"""Test safe_eval with various input values."""
41+
result = safe_eval(input_value)
42+
assert result == expected_result

0 commit comments

Comments
 (0)