Skip to content

Commit 6d55b82

Browse files
authored
feat: load yara lazily to avoid action dispatcher error (#1162)
- Implement lazy loading of yara module with proper type hints - Add error messages when yara is not available - Add helper function _check_yara_available() for consistent error handling - make helper functions private and remove checks - add test for yara import error handling to appreciate codecov
1 parent eede896 commit 6d55b82

File tree

2 files changed

+68
-28
lines changed

2 files changed

+68
-28
lines changed

nemoguardrails/library/injection_detection/actions.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
from pathlib import Path
3535
from typing import Tuple, Union
3636

37-
import yara
37+
yara = None
38+
try:
39+
import yara
40+
except ImportError:
41+
pass
3842

3943
from nemoguardrails import RailsConfig
4044
from nemoguardrails.actions import action
@@ -45,6 +49,14 @@
4549
log = logging.getLogger(__name__)
4650

4751

52+
def _check_yara_available():
53+
if yara is None:
54+
raise ImportError(
55+
"The yara module is required for injection detection. "
56+
"Please install it using: pip install yara-python"
57+
)
58+
59+
4860
def _validate_unpack_config(config: RailsConfig) -> Tuple[str, Path, Tuple[str]]:
4961
"""
5062
Validates and unpacks the injection detection configuration.
@@ -117,7 +129,7 @@ def _validate_unpack_config(config: RailsConfig) -> Tuple[str, Path, Tuple[str]]
117129

118130

119131
@lru_cache()
120-
def load_rules(yara_path: Path, rule_names: Tuple) -> Union[yara.Rules, None]:
132+
def _load_rules(yara_path: Path, rule_names: Tuple) -> Union["yara.Rules", None]:
121133
"""
122134
Loads and compiles YARA rules from the specified path and rule names.
123135
@@ -126,12 +138,14 @@ def load_rules(yara_path: Path, rule_names: Tuple) -> Union[yara.Rules, None]:
126138
rule_names (Tuple): A tuple of YARA rule names to load.
127139
128140
Returns:
129-
Union[yara.Rules, None]: The compiled YARA rules object if successful,
141+
Union['yara.Rules', None]: The compiled YARA rules object if successful,
130142
or None if no rule names are provided.
131143
132144
Raises:
133145
yara.SyntaxError: If there is a syntax error in the YARA rules.
146+
ImportError: If the yara module is not installed.
134147
"""
148+
135149
if len(rule_names) == 0:
136150
log.warning(
137151
"Injection config was provided but no modules were specified. Returning None."
@@ -150,7 +164,7 @@ def load_rules(yara_path: Path, rule_names: Tuple) -> Union[yara.Rules, None]:
150164
return rules
151165

152166

153-
def omit_injection(text: str, matches: list[yara.Match]) -> str:
167+
def _omit_injection(text: str, matches: list["yara.Match"]) -> str:
154168
"""
155169
Attempts to strip the offending injection attempts from the provided text.
156170
@@ -160,11 +174,15 @@ def omit_injection(text: str, matches: list[yara.Match]) -> str:
160174
161175
Args:
162176
text (str): The text to check for command injection.
163-
matches (list[yara.Match]): A list of YARA rule matches.
177+
matches (list['yara.Match']): A list of YARA rule matches.
164178
165179
Returns:
166180
str: The text with the detected injections stripped out.
181+
182+
Raises:
183+
ImportError: If the yara module is not installed.
167184
"""
185+
168186
# Copy the text to a placeholder variable
169187
modified_text = text
170188
for match in matches:
@@ -180,7 +198,7 @@ def omit_injection(text: str, matches: list[yara.Match]) -> str:
180198
return modified_text
181199

182200

183-
def sanitize_injection(text: str, matches: list[yara.Match]) -> str:
201+
def _sanitize_injection(text: str, matches: list["yara.Match"]) -> str:
184202
"""
185203
Attempts to sanitize the offending injection attempts in the provided text.
186204
This is done by 'de-fanging' the offending content, transforming it into a state that will not execute
@@ -193,20 +211,22 @@ def sanitize_injection(text: str, matches: list[yara.Match]) -> str:
193211
194212
Args:
195213
text (str): The text to check for command injection.
196-
matches (list[yara.Match]): A list of YARA rule matches.
214+
matches (list['yara.Match']): A list of YARA rule matches.
197215
198216
Returns:
199217
str: The text with the detected injections sanitized.
200218
201219
Raises:
202220
NotImplementedError: If the sanitization logic is not implemented.
221+
ImportError: If the yara module is not installed.
203222
"""
223+
204224
raise NotImplementedError(
205225
"Injection sanitization is not yet implemented. Please use 'reject' or 'omit'"
206226
)
207227

208228

209-
def reject_injection(text: str, rules: yara.Rules) -> Tuple[bool, str]:
229+
def _reject_injection(text: str, rules: "yara.Rules") -> Tuple[bool, str]:
210230
"""
211231
Detects whether the provided text contains potential injection attempts.
212232
@@ -215,15 +235,17 @@ def reject_injection(text: str, rules: yara.Rules) -> Tuple[bool, str]:
215235
216236
Args:
217237
text (str): The text to check for command injection.
218-
rules (yara.Rules): The loaded YARA rules.
238+
rules ('yara.Rules'): The loaded YARA rules.
219239
220240
Returns:
221241
bool: True if attempted exploitation is detected, False otherwise.
222242
str: list of matches as a string
223243
224244
Raises:
225245
ValueError: If the `action` parameter in the configuration is invalid.
246+
ImportError: If the yara module is not installed.
226247
"""
248+
227249
if rules is None:
228250
log.warning(
229251
"reject_injection guardrail was invoked but no rules were specified in the InjectionDetection config."
@@ -258,10 +280,12 @@ async def injection_detection(text: str, config: RailsConfig) -> str:
258280
ValueError: If the `action` parameter in the configuration is invalid.
259281
NotImplementedError: If an unsupported action is encountered.
260282
"""
283+
_check_yara_available()
284+
261285
action_option, yara_path, rule_names = _validate_unpack_config(config)
262-
rules = load_rules(yara_path, rule_names)
286+
rules = _load_rules(yara_path, rule_names)
263287
if action_option == "reject":
264-
verdict, detections = reject_injection(text, rules)
288+
verdict, detections = _reject_injection(text, rules)
265289
if verdict:
266290
return f"I'm sorry, the desired output triggered rule(s) designed to mitigate exploitation of {detections}."
267291
else:
@@ -276,9 +300,9 @@ async def injection_detection(text: str, config: RailsConfig) -> str:
276300
matches_string = ", ".join([match_name.rule for match_name in matches])
277301
log.info(f"Input matched on rule {matches_string}.")
278302
if action_option == "omit":
279-
return omit_injection(text, matches)
303+
return _omit_injection(text, matches)
280304
elif action_option == "sanitize":
281-
return sanitize_injection(text, matches)
305+
return _sanitize_injection(text, matches)
282306
else:
283307
# We should never ever hit this since we inspect the action option above, but putting an error here anyway.
284308
raise NotImplementedError(

tests/test_injection_detection.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@
3939
from nemoguardrails.actions import action
4040
from nemoguardrails.actions.actions import ActionResult
4141
from nemoguardrails.library.injection_detection.actions import (
42+
_check_yara_available,
43+
_load_rules,
44+
_omit_injection,
45+
_reject_injection,
4246
_validate_unpack_config,
4347
injection_detection,
44-
load_rules,
45-
omit_injection,
46-
reject_injection,
4748
)
4849
from tests.utils import TestChat
4950

@@ -101,14 +102,14 @@ def test_load_custom_rules():
101102
""",
102103
)
103104
_action_option, yara_path, rule_names = _validate_unpack_config(config)
104-
rules = load_rules(yara_path, rule_names)
105+
rules = _load_rules(yara_path, rule_names)
105106
assert isinstance(rules, yara.Rules)
106107

107108

108109
def test_load_all_rules():
109110
config = RailsConfig.from_path(os.path.join(CONFIGS_FOLDER, "injection_detection"))
110111
_action_option, yara_path, rule_names = _validate_unpack_config(config)
111-
rules = load_rules(yara_path, rule_names)
112+
rules = _load_rules(yara_path, rule_names)
112113
assert isinstance(rules, yara.Rules)
113114

114115

@@ -204,7 +205,7 @@ def test_empty_injection_rules():
204205
"""
205206
)
206207
_action_option, yara_path, rule_names = _validate_unpack_config(config)
207-
rules = load_rules(yara_path, rule_names)
208+
rules = _load_rules(yara_path, rule_names)
208209
assert rules is None
209210

210211

@@ -220,7 +221,7 @@ async def test_omit_injection_action():
220221
create_mock_yara_match("-- comment", "sqli"),
221222
]
222223

223-
result = omit_injection(text=text, matches=mock_matches)
224+
result = _omit_injection(text=text, matches=mock_matches)
224225

225226
# all sql injection should be removed
226227
# NOTE: following rule does not get removed using sqli.yara
@@ -243,11 +244,11 @@ async def test_reject_injection_with_mismatched_action():
243244

244245
# pathcing the load_rules function to return our mock rules
245246
with patch(
246-
"nemoguardrails.library.injection_detection.actions.load_rules",
247+
"nemoguardrails.library.injection_detection.actions._load_rules",
247248
return_value=mock_rules,
248249
):
249250
sql_injection = "' OR 1 = 1"
250-
result, _ = reject_injection(sql_injection, mock_rules)
251+
result, _ = _reject_injection(sql_injection, mock_rules)
251252
assert result is True
252253

253254

@@ -263,11 +264,11 @@ async def test_multiple_injection_types():
263264
mock_rules = create_mock_rules(mock_matches)
264265

265266
with patch(
266-
"nemoguardrails.library.injection_detection.actions.load_rules",
267+
"nemoguardrails.library.injection_detection.actions._load_rules",
267268
return_value=mock_rules,
268269
):
269270
multi_injection = "' OR 1 = 1 <script>alert('xss')</script>"
270-
result, _ = reject_injection(multi_injection, mock_rules)
271+
result, _ = _reject_injection(multi_injection, mock_rules)
271272
assert result is True
272273

273274

@@ -279,21 +280,21 @@ async def test_edge_cases():
279280
mock_rules = create_mock_rules([])
280281

281282
with patch(
282-
"nemoguardrails.library.injection_detection.actions.load_rules",
283+
"nemoguardrails.library.injection_detection.actions._load_rules",
283284
return_value=mock_rules,
284285
):
285286
# Test with empty string
286-
result, _ = reject_injection("", mock_rules)
287+
result, _ = _reject_injection("", mock_rules)
287288
assert result is False
288289

289290
# no issue with very long str
290291
long_string = "a" * 10000
291-
result, _ = reject_injection(long_string, mock_rules)
292+
result, _ = _reject_injection(long_string, mock_rules)
292293
assert result is False
293294

294295
# no issue with special chars
295296
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
296-
result, _ = reject_injection(special_chars, mock_rules)
297+
result, _ = _reject_injection(special_chars, mock_rules)
297298
assert result is False
298299

299300

@@ -495,3 +496,18 @@ async def test_sanitize_action_not_implemented():
495496
496497
"""
497498
)
499+
500+
501+
def test_yara_import_error():
502+
"""Test that appropriate error is raised when yara module is not available."""
503+
504+
with patch("nemoguardrails.library.injection_detection.actions.yara", None):
505+
with pytest.raises(ImportError) as exc_info:
506+
_check_yara_available()
507+
assert str(exc_info.value) == (
508+
"The yara module is required for injection detection. "
509+
"Please install it using: pip install yara-python"
510+
)
511+
512+
with patch("nemoguardrails.library.injection_detection.actions.yara", yara):
513+
_check_yara_available()

0 commit comments

Comments
 (0)