Skip to content

Commit 01cd760

Browse files
Support NIM API key for NemoGuard JailbreakDetect (#1214)
* Update jailbreak detection compatibility for NIM to allow providing an API key.
* Allow configurable classification path.
* Clean up unused dependencies. Update `JailbreakDetectionConfig` object to use base_url and endpoints. Refactor checks to align with base_uri and api_key_env_var approaches. Add additional error handling and logging. Fix tests to reflect changes.

  Signed-off-by: Erick Galinkin <egalinkin@nvidia.com>

* apply black

  Signed-off-by: Erick Galinkin <egalinkin@nvidia.com>

* style: apply pre-commits
* Support deprecated `nim_url` and `nim_port` fields.

  Signed-off-by: Erick Galinkin <egalinkin@nvidia.com>

* Push test update for deprecated parameters

  Signed-off-by: Erick Galinkin <egalinkin@nvidia.com>

* fix: improve error handling in check_jailbreak function

  - Fix TypeError when classifier is None by adding defensive programming
  - Replace silent failure with clear RuntimeError and descriptive message
  - Simplify calling code by removing redundant null checks from actions.py and server.py
  - Update tests to match new function signature and behavior
  - Add test coverage for new RuntimeError path

  This resolves the critical bug where check_jailbreak(prompt) would crash with "TypeError: 'NoneType' object is not callable" when EMBEDDING_CLASSIFIER_PATH is not set. Now it raises a clear RuntimeError with guidance on how to fix it.

* fix fix
* fix(request): make nim_auth_token optional in request
* test: add more tests
* fix model path mocking and assertion for windows

---------

Signed-off-by: Erick Galinkin <egalinkin@nvidia.com>
Co-authored-by: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
1 parent 05ef8e9 commit 01cd760

File tree

12 files changed

+927
-71
lines changed

12 files changed

+927
-71
lines changed

nemoguardrails/library/jailbreak_detection/actions.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
# limitations under the License.
3030

3131
import logging
32+
import os
3233
from typing import Optional
3334

3435
from nemoguardrails.actions import action
@@ -94,13 +95,22 @@ async def jailbreak_detection_model(
9495
jailbreak_config = llm_task_manager.config.rails.config.jailbreak_detection
9596

9697
jailbreak_api_url = jailbreak_config.server_endpoint
97-
nim_url = jailbreak_config.nim_url
98-
nim_port = jailbreak_config.nim_port
98+
nim_base_url = jailbreak_config.nim_base_url
99+
nim_classification_path = jailbreak_config.nim_server_endpoint
100+
if jailbreak_config.api_key_env_var is not None:
101+
nim_auth_token = os.getenv(jailbreak_config.api_key_env_var)
102+
if nim_auth_token is None:
103+
log.warning(
104+
"Specified a value for jailbreak config api_key_env var at %s but the environment variable was not set!"
105+
% jailbreak_config.api_key_env_var
106+
)
107+
else:
108+
nim_auth_token = None
99109

100110
if context is not None:
101111
prompt = context.get("user_message", "")
102112

103-
if not jailbreak_api_url and not nim_url:
113+
if not jailbreak_api_url and not nim_base_url:
104114
from nemoguardrails.library.jailbreak_detection.model_based.checks import (
105115
check_jailbreak,
106116
initialize_model,
@@ -109,14 +119,26 @@ async def jailbreak_detection_model(
109119
log.warning(
110120
"No jailbreak detection endpoint set. Running in-process, NOT RECOMMENDED FOR PRODUCTION."
111121
)
112-
classifier = initialize_model()
113-
jailbreak = check_jailbreak(prompt=prompt, classifier=classifier)
114-
115-
return jailbreak["jailbreak"]
116-
117-
if nim_url:
122+
try:
123+
jailbreak = check_jailbreak(prompt=prompt)
124+
log.info(f"Local model jailbreak detection result: {jailbreak}")
125+
return jailbreak["jailbreak"]
126+
except RuntimeError as e:
127+
log.error(f"Jailbreak detection model not available: {e}")
128+
return False
129+
except ImportError as e:
130+
log.error(
131+
f"Failed to import required dependencies for local model. Install scikit-learn and torch, or use NIM-based approach",
132+
exc_info=e,
133+
)
134+
return False
135+
136+
if nim_base_url:
118137
jailbreak = await jailbreak_nim_request(
119-
prompt=prompt, nim_url=nim_url, nim_port=nim_port
138+
prompt=prompt,
139+
nim_url=nim_base_url,
140+
nim_auth_token=nim_auth_token,
141+
nim_classification_path=nim_classification_path,
120142
)
121143
elif jailbreak_api_url:
122144
jailbreak = await jailbreak_detection_model_request(

nemoguardrails/library/jailbreak_detection/model_based/checks.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,33 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import logging
1617
import os
1718
from functools import lru_cache
1819
from pathlib import Path
19-
from typing import Tuple, Union
20+
from typing import Union
2021

21-
import numpy as np
22-
23-
models_path = os.environ.get("EMBEDDING_CLASSIFIER_PATH")
22+
logger = logging.getLogger(__name__)
2423

2524

2625
@lru_cache()
27-
def initialize_model(classifier_path: str = models_path) -> "JailbreakClassifier":
26+
def initialize_model() -> Union[None, "JailbreakClassifier"]:
2827
"""
2928
Initialize the global classifier model according to the configuration provided.
3029
Args
3130
classifier_path: Path to the classifier model
3231
Returns
3332
jailbreak_classifier: JailbreakClassifier object combining embedding model and NemoGuard JailbreakDetect RF
3433
"""
34+
35+
classifier_path = os.environ.get("EMBEDDING_CLASSIFIER_PATH")
36+
3537
if classifier_path is None:
36-
raise EnvironmentError(
37-
"Please set the EMBEDDING_CLASSIFIER_PATH environment variable to point to the Classifier model_based folder"
38+
# Log a warning, but do not throw an exception
39+
logger.warning(
40+
"No embedding classifier path set. Server /model endpoint will not work."
3841
)
42+
return None
3943

4044
from nemoguardrails.library.jailbreak_detection.model_based.models import (
4145
JailbreakClassifier,
@@ -57,10 +61,19 @@ def check_jailbreak(
5761
Args:
5862
prompt: User utterance to classify
5963
classifier: Instantiated JailbreakClassifier object
64+
65+
Raises:
66+
RuntimeError: If no classifier is available and EMBEDDING_CLASSIFIER_PATH is not set
6067
"""
6168
if classifier is None:
6269
classifier = initialize_model()
6370

71+
if classifier is None:
72+
raise RuntimeError(
73+
"No jailbreak classifier available. Please set the EMBEDDING_CLASSIFIER_PATH "
74+
"environment variable to point to the classifier model directory."
75+
)
76+
6477
classification, score = classifier(prompt)
6578
# classification will be 1 or 0 -- cast to boolean.
6679
return {"jailbreak": classification, "score": score}

nemoguardrails/library/jailbreak_detection/model_based/models.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import os
1716
from typing import Tuple
1817

1918
import numpy as np
@@ -46,29 +45,6 @@ def __call__(self, text: str):
4645
return embeddings.detach().cpu().squeeze(0).numpy()
4746

4847

49-
class NvEmbedE5:
50-
def __init__(self):
51-
self.api_key = os.environ.get("NVIDIA_API_KEY", None)
52-
if self.api_key is None:
53-
raise ValueError("No NVIDIA API key set!")
54-
55-
from openai import OpenAI
56-
57-
self.client = OpenAI(
58-
api_key=self.api_key,
59-
base_url="https://integrate.api.nvidia.com/v1",
60-
)
61-
62-
def __call__(self, text: str):
63-
response = self.client.embeddings.create(
64-
input=[text],
65-
model="nvidia/nv-embedqa-e5-v5",
66-
encoding_format="float",
67-
extra_body={"input_type": "query", "truncate": "END"},
68-
)
69-
return np.array(response.data[0].embedding, dtype="float32")
70-
71-
7248
class JailbreakClassifier:
7349
def __init__(self, random_forest_path: str):
7450
import pickle

nemoguardrails/library/jailbreak_detection/request.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,25 @@ async def jailbreak_detection_model_request(
9898
async def jailbreak_nim_request(
9999
prompt: str,
100100
nim_url: str,
101-
nim_port: int,
101+
nim_auth_token: Optional[str],
102+
nim_classification_path: str,
102103
):
104+
from urllib.parse import urljoin
105+
106+
headers = {"Content-Type": "application/json", "Accept": "application/json"}
103107
payload = {
104108
"input": prompt,
105109
}
106110

107-
endpoint = f"http://{nim_url}:{nim_port}/v1/classify"
111+
endpoint = urljoin(nim_url, nim_classification_path)
108112
try:
109113
async with aiohttp.ClientSession() as session:
110114
try:
111-
async with session.post(endpoint, json=payload, timeout=30) as resp:
115+
if nim_auth_token is not None:
116+
headers["Authorization"] = f"Bearer {nim_auth_token}"
117+
async with session.post(
118+
endpoint, json=payload, headers=headers, timeout=30
119+
) as resp:
112120
if resp.status != 200:
113121
log.error(
114122
f"NemoGuard JailbreakDetect NIM request failed with status {resp.status}"

nemoguardrails/library/jailbreak_detection/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,7 @@ def run_all_heuristics(request: JailbreakHeuristicRequest):
111111

112112
@app.post("/model")
113113
def run_model_check(request: JailbreakModelRequest):
114-
classifier = mc.initialize_model()
115-
result = mc.check_jailbreak(request.prompt, classifier=classifier)
114+
result = mc.check_jailbreak(request.prompt)
116115
jailbreak = result["jailbreak"]
117116
score = result["score"]
118117
model_checks = {"jailbreak": jailbreak, "score": score}

nemoguardrails/rails/llm/config.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -556,28 +556,50 @@ class JailbreakDetectionConfig(BaseModel):
556556

557557
server_endpoint: Optional[str] = Field(
558558
default=None,
559-
description="The endpoint for the jailbreak detection heuristics server.",
559+
description="The endpoint for the jailbreak detection heuristics/model container.",
560560
)
561561
length_per_perplexity_threshold: float = Field(
562562
default=89.79, description="The length/perplexity threshold."
563563
)
564564
prefix_suffix_perplexity_threshold: float = Field(
565565
default=1845.65, description="The prefix/suffix perplexity threshold."
566566
)
567+
nim_base_url: Optional[str] = Field(
568+
default=None,
569+
description="Base URL for jailbreak detection model. Example: http://localhost:8000/v1",
570+
)
571+
nim_server_endpoint: Optional[str] = Field(
572+
default="classify",
573+
description="Classification path uri. Defaults to 'classify' for NemoGuard JailbreakDetect.",
574+
)
575+
api_key_env_var: Optional[str] = Field(
576+
default=None,
577+
description="Environment variable containing API key for jailbreak detection model",
578+
)
579+
# legacy fields, keep for backward comp with deprecation warnings
567580
nim_url: Optional[str] = Field(
568581
default=None,
569-
description="Location of the NemoGuard JailbreakDetect NIM.",
582+
deprecated="Use 'nim_base_url' instead. This field will be removed in a future version.",
583+
description="DEPRECATED: Use nim_base_url instead",
570584
)
571-
nim_port: int = Field(
572-
default=8000,
573-
description="Port the NemoGuard JailbreakDetect NIM is listening on.",
585+
nim_port: Optional[int] = Field(
586+
default=None,
587+
deprecated="Include port in 'nim_base_url' instead. This field will be removed in a future version.",
588+
description="DEPRECATED: Include port in nim_base_url instead",
574589
)
575590
embedding: Optional[str] = Field(
576-
default="nvidia/nv-embedqa-e5-v5",
577-
description="DEPRECATED: Model to use for embedding-based detections. Use NIM instead.",
578-
deprecated=True,
591+
default=None,
592+
deprecated="This field is no longer used.",
579593
)
580594

595+
@model_validator(mode="after")
596+
def migrate_deprecated_fields(self) -> "JailbreakDetectionConfig":
597+
"""Migrate deprecated nim_url/nim_port fields to nim_base_url format."""
598+
if self.nim_url and not self.nim_base_url:
599+
port = self.nim_port or 8000
600+
self.nim_base_url = f"http://{self.nim_url}:{port}/v1"
601+
return self
602+
581603

582604
class AutoAlignOptions(BaseModel):
583605
"""List of guardrails that are activated"""

tests/test_configs/jailbreak_nim/config.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ rails:
22
config:
33
jailbreak_detection:
44
server_endpoint: ""
5-
nim_url: "0.0.0.0"
6-
nim_port: 8000
5+
nim_base_url: "http://0.0.0.0:8000/v1"
6+
nim_server_endpoint: "classify"
7+
api_key_env_var: "JB_NIM_TEST"
78

89
input:
910
flows:

0 commit comments

Comments
 (0)