Skip to content

Commit 05ef8e9

Browse files
fix: lazy load jailbreak detection dependencies (#1223)
1 parent 5f9974a commit 05ef8e9

File tree

3 files changed

+219
-17
lines changed

3 files changed

+219
-17
lines changed

nemoguardrails/library/jailbreak_detection/model_based/checks.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,32 @@
1414
# limitations under the License.
1515

1616
import os
17-
import pickle
1817
from functools import lru_cache
1918
from pathlib import Path
2019
from typing import Tuple, Union
2120

2221
import numpy as np
23-
from sklearn.ensemble import RandomForestClassifier
24-
25-
from nemoguardrails.library.jailbreak_detection.model_based.models import (
26-
JailbreakClassifier,
27-
)
2822

2923
models_path = os.environ.get("EMBEDDING_CLASSIFIER_PATH")
3024

31-
# When we add NIM support, will need to remove this check.
32-
if models_path is None:
33-
raise EnvironmentError(
34-
"Please set the EMBEDDING_CLASSIFIER_PATH environment variable to point to the Classifier model_based folder"
35-
)
36-
3725

3826
@lru_cache()
39-
def initialize_model(classifier_path: str = models_path) -> JailbreakClassifier:
27+
def initialize_model(classifier_path: str = models_path) -> "JailbreakClassifier":
4028
"""
4129
Initialize the global classifier model according to the configuration provided.
4230
Args
4331
classifier_path: Path to the classifier model
4432
Returns
4533
jailbreak_classifier: JailbreakClassifier object combining embedding model and NemoGuard JailbreakDetect RF
4634
"""
35+
if classifier_path is None:
36+
raise EnvironmentError(
37+
"Please set the EMBEDDING_CLASSIFIER_PATH environment variable to point to the Classifier model_based folder"
38+
)
39+
40+
from nemoguardrails.library.jailbreak_detection.model_based.models import (
41+
JailbreakClassifier,
42+
)
4743

4844
jailbreak_classifier = JailbreakClassifier(
4945
str(Path(classifier_path).joinpath("snowflake.pkl"))
@@ -54,7 +50,7 @@ def initialize_model(classifier_path: str = models_path) -> JailbreakClassifier:
5450

5551
def check_jailbreak(
5652
prompt: str,
57-
classifier: JailbreakClassifier = None,
53+
classifier=None,
5854
) -> dict:
5955
"""
6056
Use embedding-based jailbreak detection model to check for the presence of a jailbreak

nemoguardrails/library/jailbreak_detection/model_based/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@
1414
# limitations under the License.
1515

1616
import os
17-
import pickle
1817
from typing import Tuple
1918

2019
import numpy as np
21-
import torch
22-
from transformers import AutoModel, AutoTokenizer
2320

2421

2522
class SnowflakeEmbed:
2623
def __init__(self):
24+
import torch
25+
from transformers import AutoModel, AutoTokenizer
26+
2727
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
2828
self.tokenizer = AutoTokenizer.from_pretrained(
2929
"snowflake/snowflake-arctic-embed-m-long"
@@ -71,6 +71,8 @@ def __call__(self, text: str):
7171

7272
class JailbreakClassifier:
7373
def __init__(self, random_forest_path: str):
74+
import pickle
75+
7476
self.embed = SnowflakeEmbed()
7577
with open(random_forest_path, "rb") as fd:
7678
self.classifier = pickle.load(fd)
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import sys
17+
import types
18+
from unittest import mock
19+
20+
import pytest
21+
22+
# Test 1: Lazy import behavior
23+
24+
25+
def test_lazy_import_does_not_require_heavy_deps():
    """
    Importing the checks module should not require torch, transformers, or sklearn unless model-based classifier is used.

    The module is reloaded while the heavy dependencies are blocked, so its
    top-level code really runs under those conditions. A plain ``import``
    would be served from the ``sys.modules`` cache whenever another test had
    already imported the module, and would then prove nothing.
    """
    import importlib

    module_name = "nemoguardrails.library.jailbreak_detection.model_based.checks"

    # A ``None`` entry in sys.modules makes ``import <name>`` raise ImportError.
    # patch.dict restores sys.modules to its previous state on exit.
    with mock.patch.dict(
        sys.modules, {"torch": None, "transformers": None, "sklearn": None}
    ):
        checks = importlib.import_module(module_name)
        # Re-execute the module body with the heavy deps blocked, even if the
        # module object was already cached before this test started.
        checks = importlib.reload(checks)

        # Just importing the module should not raise ImportError
        assert hasattr(checks, "initialize_model")
36+
37+
38+
# Test 2: Model-based classifier instantiation requires dependencies
39+
40+
41+
def test_model_based_classifier_imports(monkeypatch):
    """
    JailbreakClassifier should be constructible and callable when sklearn, pickle, torch and transformers are all replaced by stubs.
    """
    # Stand-ins for the unpickled random forest and the embedding callable.
    random_forest_stub = mock.MagicMock()
    embedder_stub = mock.MagicMock(return_value=[0.0])

    # Every dependency is swapped for a stub in sys.modules; pickle.load hands
    # back the random-forest stub instead of reading a real model.
    stubbed_modules = {
        "sklearn.ensemble": types.SimpleNamespace(
            RandomForestClassifier=mock.MagicMock()
        ),
        "pickle": types.SimpleNamespace(
            load=mock.MagicMock(return_value=random_forest_stub)
        ),
        "torch": mock.MagicMock(),
        "transformers": mock.MagicMock(),
    }
    for module_name, stub in stubbed_modules.items():
        monkeypatch.setitem(sys.modules, module_name, stub)

    import nemoguardrails.library.jailbreak_detection.model_based.models as models

    # Replace SnowflakeEmbed so no real model is loaded.
    monkeypatch.setattr(
        models, "SnowflakeEmbed", mock.MagicMock(return_value=embedder_stub)
    )

    # File access is mocked to avoid Windows permission issues.
    with mock.patch("builtins.open", mock.mock_open()):
        # Construction should not raise
        classifier = models.JailbreakClassifier("fake_model_path.pkl")
        assert classifier is not None
        # The instance should be callable and yield a tuple
        result = classifier("test")
        assert isinstance(result, tuple)
74+
75+
76+
# Test 3: Error if dependencies missing when instantiating model-based classifier
77+
78+
79+
def test_model_based_classifier_missing_deps(monkeypatch):
    """
    If the heavy dependencies are missing, instantiating JailbreakClassifier should raise ImportError.

    JailbreakClassifier builds a SnowflakeEmbed, which lazily imports torch and
    transformers. Those are blocked explicitly here (together with
    sklearn.ensemble) so the ImportError does not depend on which packages
    happen to be installed, and so no real model is ever loaded.
    """
    # A ``None`` entry in sys.modules makes the corresponding import raise ImportError.
    for blocked in ("sklearn.ensemble", "torch", "transformers"):
        monkeypatch.setitem(sys.modules, blocked, None)

    import nemoguardrails.library.jailbreak_detection.model_based.models as models

    # to avoid Windows permission issues
    mock_open = mock.mock_open()
    with mock.patch("builtins.open", mock_open):
        with pytest.raises(ImportError):
            models.JailbreakClassifier("fake_model_path.pkl")
92+
93+
94+
# Test 4: Error when classifier_path is None
95+
96+
97+
def test_initialize_model_with_none_classifier_path():
    """
    Calling initialize_model with classifier_path=None must raise EnvironmentError whose message points at EMBEDDING_CLASSIFIER_PATH.
    """
    import nemoguardrails.library.jailbreak_detection.model_based.checks as checks

    expected_fragment = "Please set the EMBEDDING_CLASSIFIER_PATH environment variable"

    with pytest.raises(EnvironmentError) as raised:
        checks.initialize_model(classifier_path=None)

    message = str(raised.value)
    assert expected_fragment in message
109+
110+
111+
# Test 5: SnowflakeEmbed initialization and call with torch imports
112+
113+
114+
def test_snowflake_embed_torch_imports(monkeypatch):
    """
    SnowflakeEmbed should pick up torch and transformers at construction time and return the embedding produced by the (stubbed) model.
    """
    import numpy as np

    # Stub torch so that CUDA is reported as unavailable.
    torch_stub = mock.MagicMock()
    torch_stub.cuda.is_available.return_value = False

    # Stub transformers so both from_pretrained calls return our own mocks.
    tokenizer_stub = mock.MagicMock()
    model_stub = mock.MagicMock()
    transformers_stub = mock.MagicMock()
    transformers_stub.AutoTokenizer.from_pretrained.return_value = tokenizer_stub
    transformers_stub.AutoModel.from_pretrained.return_value = model_stub

    for module_name, stub in (("torch", torch_stub), ("transformers", transformers_stub)):
        monkeypatch.setitem(sys.modules, module_name, stub)

    import nemoguardrails.library.jailbreak_detection.model_based.models as models

    embed = models.SnowflakeEmbed()
    assert embed.device == "cpu"  # cuda.is_available() was mocked to False

    # Tokenizer output whose .to(...) hands back the same object.
    token_batch = mock.MagicMock()
    token_batch.to.return_value = token_batch
    tokenizer_stub.return_value = token_batch

    expected_embedding = np.array([1.0, 2.0, 3.0])

    # Per the note in the original test, the embedder indexes the model output
    # as self.model(**tokens)[0][:, 0]; the mocks below mirror that access and
    # the detach/cpu/squeeze/numpy chain that follows it.
    sliced_output = mock.MagicMock()
    chain_tail = sliced_output
    for step in ("detach", "cpu", "squeeze"):
        chain_tail = getattr(chain_tail, step).return_value
    chain_tail.numpy.return_value = expected_embedding

    first_element = mock.MagicMock()
    first_element.__getitem__.return_value = sliced_output  # for [:, 0]

    model_output = mock.MagicMock()
    model_output.__getitem__.return_value = first_element  # for [0]
    model_stub.return_value = model_output

    result = embed("test text")
    assert isinstance(result, np.ndarray)
    assert np.array_equal(result, expected_embedding)
162+
163+
164+
# Test 6: Check jailbreak function with classifier parameter
165+
166+
167+
def test_check_jailbreak_with_classifier():
    """
    check_jailbreak should call an explicitly supplied classifier with the prompt and report its verdict as a dict.
    """
    import nemoguardrails.library.jailbreak_detection.model_based.checks as checks

    # Classifier stub reporting a jailbreak with score 0.9.
    classifier_stub = mock.MagicMock(return_value=(True, 0.9))

    outcome = checks.check_jailbreak("test prompt", classifier=classifier_stub)

    classifier_stub.assert_called_once_with("test prompt")
    assert outcome == {"jailbreak": True, "score": 0.9}
181+
182+
183+
# Test 7: Check jailbreak function without classifier parameter (uses initialize_model)
184+
185+
186+
def test_check_jailbreak_without_classifier(monkeypatch):
    """
    When no classifier is passed, check_jailbreak should obtain one via initialize_model and use it on the prompt.
    """
    import nemoguardrails.library.jailbreak_detection.model_based.checks as checks

    # Classifier stub reporting "no jailbreak" with score -0.5.
    classifier_stub = mock.MagicMock(return_value=(False, -0.5))

    # Swap initialize_model for a factory that returns the stub above.
    initializer_stub = mock.MagicMock(return_value=classifier_stub)
    monkeypatch.setattr(checks, "initialize_model", initializer_stub)

    outcome = checks.check_jailbreak("safe prompt")

    assert outcome == {"jailbreak": False, "score": -0.5}
    initializer_stub.assert_called_once()
    classifier_stub.assert_called_once_with("safe prompt")

0 commit comments

Comments
 (0)