Commit 5af9758

Clean up failing tests, add new ones with api key checks in kwargs
1 parent 3a432c7 commit 5af9758

File tree

6 files changed, +136 -73 lines changed


nemoguardrails/llm/models/initializer.py

Lines changed: 1 addition & 4 deletions
@@ -23,12 +23,11 @@
 from .langchain_initializer import ModelInitializationError, init_langchain_model
 
 
-# later we can easily convert it to a class
+# later we can easily conver it to a class
 def init_llm_model(
     model_name: Optional[str],
     provider_name: str,
     mode: Literal["chat", "text"],
-    api_key: str,
     kwargs: Dict[str, Any],
 ) -> Union[BaseChatModel, BaseLLM]:
     """Initialize an LLM model with proper error handling.
@@ -40,7 +39,6 @@ def init_llm_model(
         model_name: Name of the model to initialize
         provider_name: Name of the provider to use
         mode: Literal taking either "chat" or "text" values
-        api_key: String with LLM API Key to use by client
         kwargs: Additional arguments to pass to the model initialization
 
     Returns:
@@ -54,7 +52,6 @@ def init_llm_model(
         model_name=model_name,
         provider_name=provider_name,
         mode=mode,
-        api_key=api_key,
         kwargs=kwargs,
     )
 
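For callers of init_llm_model, the visible change is the signature: the API key is no longer a dedicated argument and instead rides along inside kwargs. A minimal sketch of the new call shape, with illustrative provider, model, and key values (not taken from the repository's tests):

import os
from nemoguardrails.llm.models.initializer import init_llm_model

# Illustrative values only: the key travels inside `kwargs` rather than as a
# separate `api_key` argument.
llm = init_llm_model(
    model_name="gpt-3.5-turbo",
    provider_name="openai",
    mode="chat",
    kwargs={"temperature": 0.0, "api_key": os.environ.get("OPENAI_API_KEY")},
)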

nemoguardrails/llm/models/langchain_initializer.py

Lines changed: 14 additions & 32 deletions
@@ -47,7 +47,7 @@ class ModelInitializationError(Exception):
 
 
 ModelInitMethod = Callable[
-    [str, str, str, Dict[str, Any]], Optional[Union[BaseChatModel, BaseLLM]]
+    [str, str, Dict[str, Any]], Optional[Union[BaseChatModel, BaseLLM]]
 ]
 
 
@@ -67,10 +67,10 @@ def supports_mode(self, mode: Literal["chat", "text"]) -> bool:
         return mode in self.supported_modes
 
     def execute(
-        self, model_name: str, provider_name: str, api_key: str, kwargs: Dict[str, Any]
+        self, model_name: str, provider_name: str, kwargs: Dict[str, Any]
     ) -> Optional[Union[BaseChatModel, BaseLLM]]:
         """Execute this initializer to initialize a model."""
-        return self.init_method(model_name, provider_name, api_key, kwargs)
+        return self.init_method(model_name, provider_name, kwargs)
 
     def __str__(self) -> str:
         return f"{self.init_method.__name__}(modes={self.supported_modes})"
@@ -81,7 +81,6 @@ def try_initialization_method(
     model_name: str,
     provider_name: str,
     mode: Literal["chat", "text"],
-    api_key: str,
     kwargs: Dict[str, Any],
 ):
     """Wrap an initialization method execution with a try/except to capture errors.
@@ -106,7 +105,6 @@ def try_initialization_method(
         result = initializer.execute(
             model_name=model_name,
             provider_name=provider_name,
-            api_key=api_key,
             kwargs=kwargs,
         )
         log.debug(f"Initializer {initializer.init_method.__name__} returned: {result}")
@@ -124,7 +122,6 @@ def init_langchain_model(
    model_name: str,
    provider_name: str,
    mode: Literal["chat", "text"],
-    api_key: str,
    kwargs: Dict[str, Any],
 ) -> Union[BaseChatModel, BaseLLM]:
     """Initialize a LangChain model using a series of initialization methods.
@@ -165,7 +162,6 @@ def init_langchain_model(
             model_name=model_name,
             provider_name=provider_name,
             mode=mode,
-            api_key=api_key,
             kwargs=kwargs,
         )
         if result is not None:
@@ -201,14 +197,13 @@ def init_langchain_model(
 
 
 def _init_chat_completion_model(
-    model_name: str, provider_name: str, api_key: str, kwargs: Dict[str, Any]
+    model_name: str, provider_name: str, kwargs: Dict[str, Any]
 ) -> BaseChatModel:  # noqa #type: ignore
     """Initialize a chat completion model.
 
     Args:
         model_name: Name of the model to initialize
         provider_name: Name of the provider to use
-        api_key: LLM API key to initialize client with
         kwargs: Additional arguments to pass to the model initialization
 
     Returns:
@@ -223,10 +218,6 @@ def _init_chat_completion_model(
     # line with our pyproject.toml
     package_version = version("langchain-core")
 
-    # Langchain's `init_chat_model()` doesn't have an argument for api_key, so
-    # copy kwargs and include api_key there instead
-    kwargs["api_key"] = api_key
-
     if _parse_version(package_version) < (0, 2, 7):
         raise RuntimeError(
             "this feature is supported from v0.2.7 of langchain-core."
@@ -243,14 +234,13 @@ def _init_chat_completion_model(
 
 
 def _init_text_completion_model(
-    model_name: str, provider_name: str, api_key: str, kwargs: Dict[str, Any]
+    model_name: str, provider_name: str, kwargs: Dict[str, Any]
 ) -> BaseLLM:
     """Initialize a text completion model.
 
     Args:
         model_name: Name of the model to initialize
         provider_name: Name of the provider to use
-        api_key: API Key to use for LLM call
         kwargs: Additional arguments to pass to the model initialization
 
     Returns:
@@ -262,20 +252,18 @@ def _init_text_completion_model(
     provider_cls = _get_text_completion_provider(provider_name)
     if provider_cls is None:
         raise ValueError()
-    kwargs = _update_model_kwargs(provider_cls, model_name, api_key, kwargs)
-    kwargs["api_key"] = api_key
+    kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
     return provider_cls(**kwargs)
 
 
 def _init_community_chat_models(
-    model_name: str, provider_name: str, api_key: str, kwargs: Dict[str, Any]
+    model_name: str, provider_name: str, kwargs: Dict[str, Any]
 ) -> BaseChatModel:
     """Initialize community chat models.
 
     Args:
         provider_name: Name of the provider to use
         model_name: Name of the model to initialize
-        api_key: API Key to use for LLM call
         kwargs: Additional arguments to pass to the model initialization
 
     Returns:
@@ -288,12 +276,12 @@ def _init_community_chat_models(
     provider_cls = _get_chat_completion_provider(provider_name)
     if provider_cls is None:
         raise ValueError()
-    kwargs = _update_model_kwargs(provider_cls, model_name, api_key, kwargs)
+    kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
     return provider_cls(**kwargs)
 
 
 def _init_gpt35_turbo_instruct(
-    model_name: str, provider_name: str, api_key: str, kwargs: Dict[str, Any]
+    model_name: str, provider_name: str, kwargs: Dict[str, Any]
 ) -> BaseLLM:
     """Initialize GPT-3.5 Turbo Instruct model.
 
@@ -305,7 +293,6 @@ def _init_gpt35_turbo_instruct(
     Args:
         model_name: Name of the model to initialize
         provider_name: Name of the provider to use
-        api_key: API key value for LLM call
         kwargs: Additional arguments to pass to the model initialization
 
     Returns:
@@ -318,7 +305,6 @@ def _init_gpt35_turbo_instruct(
         return _init_text_completion_model(
             model_name=model_name,
             provider_name=provider_name,
-            api_key=api_key,
             kwargs=kwargs,
         )
     except Exception as e:
@@ -328,14 +314,13 @@ def _init_gpt35_turbo_instruct(
 
 
 def _init_nvidia_model(
-    model_name: str, provider_name: str, api_key: str, kwargs: Dict[str, Any]
+    model_name: str, provider_name: str, kwargs: Dict[str, Any]
 ) -> BaseChatModel:
     """Initialize NVIDIA AI Endpoints model.
 
     Args:
         model_name: Name of the model to initialize
         provider_name: Name of the provider to use
-        api_key: API key
         **kwargs: Additional arguments to pass to the model initialization
 
     Returns:
@@ -358,7 +343,7 @@ def _init_nvidia_model(
                 " Please upgrade it with `pip install langchain-nvidia-ai-endpoints --upgrade`."
             )
 
-        return ChatNVIDIA(model=model_name, api_key=api_key, **kwargs)
+        return ChatNVIDIA(model=model_name, **kwargs)
     except ImportError as e:
         raise ImportError(
             "Could not import langchain_nvidia_ai_endpoints, please install it with "
@@ -379,7 +364,7 @@ def _init_nvidia_model(
 
 
 def _handle_model_special_cases(
-    model_name: str, provider_name: str, api_key: str, kwargs: Dict[str, Any]
+    model_name: str, provider_name: str, kwargs: Dict[str, Any]
 ) -> Optional[Union[BaseChatModel, BaseLLM]]:
     """Handle model initialization for special cases that need custom logic.
 
@@ -408,15 +393,13 @@ def _handle_model_special_cases(
     if initializer is None:
         return None
 
-    result = initializer(model_name, provider_name, api_key, kwargs)
+    result = initializer(model_name, provider_name, kwargs)
     if not isinstance(result, (BaseChatModel, BaseLLM)):
         raise TypeError("Initializer returned an invalid type")
     return result
 
 
-def _update_model_kwargs(
-    provider_cls: type, model_name: str, api_key: str, kwargs: dict
-) -> Dict:
+def _update_model_kwargs(provider_cls: type, model_name: str, kwargs: dict) -> Dict:
     """Update kwargs with the model name based on the provider's expected fields.
 
     If provider_cls.model_fields contains 'model' or 'model_name',
@@ -425,5 +408,4 @@ def _update_model_kwargs(
     for key in ("model", "model_name"):
         if key in getattr(provider_cls, "model_fields", {}):
             kwargs[key] = model_name
-    kwargs["api_key"] = api_key
     return kwargs
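The net effect in this module is that every initializer now receives the API key, when there is one, inside kwargs. For instance, _update_model_kwargs only maps the model name onto the provider's expected field and passes everything else through. A small, self-contained sketch of that behaviour, using a stand-in provider class rather than a real LangChain one:

from typing import Dict


def _update_model_kwargs(provider_cls: type, model_name: str, kwargs: dict) -> Dict:
    # Same shape as the helper above: set the model name on whichever field
    # the provider declares; any api_key is expected to be in kwargs already.
    for key in ("model", "model_name"):
        if key in getattr(provider_cls, "model_fields", {}):
            kwargs[key] = model_name
    return kwargs


class FakeProvider:
    # Stand-in for a pydantic-style provider class (illustration only).
    model_fields = {"model_name": None, "api_key": None}


print(_update_model_kwargs(FakeProvider, "example-model", {"api_key": "sk-example"}))
# {'api_key': 'sk-example', 'model_name': 'example-model'}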

nemoguardrails/llm/providers/providers.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def _patch_acall_method_to(llm_providers: Dict[str, Type[BaseLLM]]):
         setattr(provider_cls, "_acall", _acall)
 
 
-# Initialize the providers with the default onesBTW
+# Initialize the providers with the default ones
 _llm_providers: Dict[str, Type[BaseLLM]] = {
     "trt_llm": TRTLLM,
 }

nemoguardrails/rails/llm/config.py

Lines changed: 3 additions & 3 deletions
@@ -102,9 +102,9 @@ class Model(BaseModel):
         default=None,
         description="The name of the model. If not specified, it should be specified through the parameters attribute.",
     )
-    api_key_env_var: str = Field(
+    api_key_env_var: Optional[str] = Field(
         default=None,
-        description='The environment variable containing the model\'s API Key. Do not include "$".',
+        description='Optional environment variable with model\'s API Key. Do not include "$".',
     )
     reasoning_config: Optional[ReasoningModelConfig] = Field(
         default_factory=ReasoningModelConfig,
@@ -1245,7 +1245,7 @@ def validate_models_api_key_env_var(cls, models):
         """Model API Key Env var must be set to make LLM calls"""
         api_keys = [m.api_key_env_var for m in models]
         for api_key in api_keys:
-            if not os.environ.get(api_key):
+            if api_key and not os.environ.get(api_key):
                 raise ValueError(
                     f"Model API Key environment variable '{api_key}' not set."
                 )
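With api_key_env_var now Optional, the validator only rejects models that declare an environment variable that is not actually set. A sketch of that check in isolation; EXAMPLE_API_KEY is an illustrative variable name, not one the library requires:

import os
from typing import List, Optional


def check_api_key_env_vars(api_key_env_vars: List[Optional[str]]) -> None:
    # Mirrors the updated validator: entries of None are skipped, while a
    # declared-but-unset environment variable still raises.
    for api_key in api_key_env_vars:
        if api_key and not os.environ.get(api_key):
            raise ValueError(
                f"Model API Key environment variable '{api_key}' not set."
            )


check_api_key_env_vars([None])               # passes: no env var declared
os.environ["EXAMPLE_API_KEY"] = "sk-example"
check_api_key_env_vars(["EXAMPLE_API_KEY"])  # passes: declared and set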

nemoguardrails/rails/llm/llmrails.py

Lines changed: 6 additions & 6 deletions
@@ -376,18 +376,18 @@ def _init_llms(self):
             provider_name = llm_config.engine
             kwargs = llm_config.parameters or {}
             mode = llm_config.mode
-            api_key = os.environ.get(llm_config.api_key_env_var)
 
-            # Add the api_key to kwargs if it's set
-            api_key = os.environ.get(llm_config.api_key_env_var)
-            if api_key:
-                kwargs["api_key"] = api_key
+            # If the optional API Key Environment Variable is set, store
+            # this in the `kwargs` for the current model
+            if llm_config.api_key_env_var:
+                api_key = os.environ.get(llm_config.api_key_env_var)
+                if api_key:
+                    kwargs["api_key"] = api_key
 
             llm_model = init_llm_model(
                 model_name=model_name,
                 provider_name=provider_name,
                 mode=mode,
-                api_key=api_key,
                 kwargs=kwargs,
             )
 
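The commit message mentions new tests that check for the API key in kwargs; the repository's actual tests are not shown here, but the behaviour they would exercise reduces to the branch added above. A rough, self-contained sketch of that kind of check, with an invented environment-variable name:

import os
from typing import Any, Dict, Optional


def build_kwargs(
    parameters: Optional[Dict[str, Any]], api_key_env_var: Optional[str]
) -> Dict[str, Any]:
    # Same logic as the branch added in _init_llms above: only consult the
    # environment when an env var is configured, and only add api_key to
    # kwargs when that variable is actually set.
    kwargs: Dict[str, Any] = dict(parameters or {})
    if api_key_env_var:
        api_key = os.environ.get(api_key_env_var)
        if api_key:
            kwargs["api_key"] = api_key
    return kwargs


os.environ["EXAMPLE_API_KEY"] = "sk-example"
assert build_kwargs({"temperature": 0.1}, "EXAMPLE_API_KEY")["api_key"] == "sk-example"
assert "api_key" not in build_kwargs({"temperature": 0.1}, None)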
