Skip to content

Commit 1c72e4a

Browse files
committed
feat(llm): add support for langchain partner and community chat models
Introduces a new abstraction layer for initializing LLM models:
- refactor providers module
- implement model initialization logic
- clear separation between text completion and chat models
- proper error handling with dedicated ModelInitializationError
- consistent provider name handling and discovery
- type safety improvements for LangChain models
1 parent 0536619 commit 1c72e4a

File tree

11 files changed

+533
-304
lines changed

11 files changed

+533
-304
lines changed

examples/configs/content_safety/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
models:
22
- type: main
33
engine: nim
4-
model_name: meta/llama-3.3-70b-instruct
4+
model: meta/llama-3.3-70b-instruct
55

66
- type: content_safety
77
engine: nim

nemoguardrails/actions/llm/generation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
import threading
2424
from functools import lru_cache
2525
from time import time
26-
from typing import Callable, List, Optional, cast
26+
from typing import Callable, List, Optional, Union, cast
2727

2828
from jinja2 import meta
2929
from jinja2.sandbox import SandboxedEnvironment
30+
from langchain_core.language_models import BaseChatModel
3031
from langchain_core.language_models.llms import BaseLLM
3132

3233
from nemoguardrails.actions.actions import ActionResult, action
@@ -81,7 +82,7 @@ class LLMGenerationActions:
8182
def __init__(
8283
self,
8384
config: RailsConfig,
84-
llm: BaseLLM,
85+
llm: Union[BaseLLM, BaseChatModel],
8586
llm_task_manager: LLMTaskManager,
8687
get_embedding_search_provider_instance: Callable[
8788
[Optional[EmbeddingSearchProvider]], EmbeddingsIndex
@@ -417,7 +418,7 @@ async def generate_user_intent(
417418
)
418419
# We add these in reverse order so the most relevant is towards the end.
419420
for result in reversed(results):
420-
examples += f"user \"{result.text}\"\n {result.meta['intent']}\n\n"
421+
examples += f'user "{result.text}"\n {result.meta["intent"]}\n\n'
421422
if result.meta["intent"] not in potential_user_intents:
422423
potential_user_intents.append(result.meta["intent"])
423424

nemoguardrails/evaluate/utils.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,18 @@
1515

1616
import json
1717

18-
from nemoguardrails.llm.providers import get_llm_provider, get_llm_provider_names
18+
from nemoguardrails.llm.models.initializer import init_llm_model
1919
from nemoguardrails.rails.llm.config import Model
2020

2121

def initialize_llm(model_config: Model):
    """Initializes the model from LLM provider.

    Delegates to the shared model-initialization layer so evaluation code
    uses the same provider/engine resolution as the rest of the package.
    """
    # Pull the relevant fields out of the config object before delegating,
    # so the mapping to the initializer's keyword arguments is explicit.
    name = model_config.model
    engine = model_config.engine
    extra_kwargs = model_config.parameters

    return init_llm_model(
        model_name=name,
        provider_name=engine,
        kwargs=extra_kwargs,
    )
3930

4031

4132
def load_dataset(dataset_path: str):
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Module for initializing LLM models with proper error handling and type checking."""
17+
18+
from typing import Any, Dict, Optional, Union
19+
20+
from langchain_core.language_models import BaseChatModel
21+
from langchain_core.language_models.llms import BaseLLM
22+
23+
from .langchain_initializer import ModelInitializationError, init_langchain_model
24+
25+
26+
# later we can easily conver it to a class
27+
def init_llm_model(
28+
model_name: Optional[str], provider_name: str, kwargs: Dict[str, Any]
29+
) -> Union[BaseChatModel, BaseLLM]:
30+
"""Initialize an LLM model with proper error handling.
31+
32+
Currently, this function only supports LangChain models.
33+
In the future, it may support other model backends.
34+
35+
Args:
36+
model_name: Name of the model to initialize
37+
provider_name: Name of the provider to use
38+
kwargs: Additional arguments to pass to the model initialization
39+
40+
Returns:
41+
An initialized LLM model
42+
43+
Raises:
44+
ModelInitializationError: If model initialization fails
45+
"""
46+
# currently we only support LangChain models
47+
return init_langchain_model(
48+
model_name=model_name, provider_name=provider_name, kwargs=kwargs
49+
)
50+
51+
52+
__all__ = ["init_llm_model", "ModelInitializationError"]

0 commit comments

Comments
 (0)