Commit 9ae8a0f
feat(llm): Add automatic provider inference for LangChain LLMs
Add automatic provider name detection from LLM module paths to eliminate manual provider specification. The implementation extracts provider names from LangChain package naming conventions (e.g., langchain_openai → openai) and handles edge cases, including community packages, wrapped classes, and multiple inheritance (via MRO traversal).
1 parent 861ec38 commit 9ae8a0f
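
A quick usage illustration (not part of the commit): a minimal sketch, assuming the langchain_openai package is installed and using a placeholder API key, since the class only needs to be instantiated, never called:

    from langchain_openai import ChatOpenAI

    from nemoguardrails.actions.llm.utils import get_llm_provider

    # Instantiation only; no request is sent, so a placeholder key suffices.
    llm = ChatOpenAI(api_key="placeholder")

    # ChatOpenAI is defined in langchain_openai.chat_models, so the provider
    # falls out of stripping the "langchain_" prefix from the package name.
    print(get_llm_provider(llm))  # -> "openai"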

File tree

- nemoguardrails/actions/llm/utils.py
- tests/test_actions_llm_utils.py

2 files changed: +190 −1 lines changed


nemoguardrails/actions/llm/utils.py

Lines changed: 65 additions & 1 deletion
@@ -46,6 +46,70 @@ def __init__(self, inner_exception: Any):
         self.inner_exception = inner_exception
 
 
+def _infer_provider_from_module(llm: BaseLanguageModel) -> Optional[str]:
+    """Infer provider name from the LLM's module path.
+
+    This function extracts the provider name from LangChain package naming conventions:
+    - langchain_openai -> openai
+    - langchain_anthropic -> anthropic
+    - langchain_google_genai -> google_genai
+    - langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints
+    - langchain_community.chat_models.ollama -> ollama
+
+    For patched/wrapped classes, checks base classes as well.
+
+    Args:
+        llm: The LLM instance
+
+    Returns:
+        The inferred provider name, or None if it cannot be determined
+    """
+    module = type(llm).__module__
+
+    if module.startswith("langchain_"):
+        package = module.split(".")[0]
+        provider = package.replace("langchain_", "")
+
+        if provider == "community":
+            parts = module.split(".")
+            if len(parts) >= 3:
+                provider = parts[-1]
+            return provider
+        else:
+            return provider
+
+    for base_class in type(llm).__mro__[1:]:
+        base_module = base_class.__module__
+        if base_module.startswith("langchain_"):
+            package = base_module.split(".")[0]
+            provider = package.replace("langchain_", "")
+
+            if provider == "community":
+                parts = base_module.split(".")
+                if len(parts) >= 3:
+                    provider = parts[-1]
+                return provider
+            else:
+                return provider
+
+    return None
+
+
+def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
+    """Get the provider name for an LLM instance by inferring from module path.
+
+    This function extracts the provider name from LangChain package naming conventions.
+    See _infer_provider_from_module for details on the inference logic.
+
+    Args:
+        llm: The LLM instance
+
+    Returns:
+        The provider name if it can be inferred, None otherwise
+    """
+    return _infer_provider_from_module(llm)
+
+
 def _infer_model_name(llm: BaseLanguageModel):
     """Helper to infer the model name based from an LLM instance.
 
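
One subtlety in the added function: for community models, the module path looks like langchain_community.chat_models.ollama, and stripping the prefix alone would yield the unhelpful "community", so the last dotted segment is used instead. A plain-Python trace of the string handling, mirroring the code above:

    module = "langchain_community.chat_models.ollama"

    package = module.split(".")[0]                # "langchain_community"
    provider = package.replace("langchain_", "")  # "community"

    if provider == "community":
        parts = module.split(".")
        if len(parts) >= 3:                       # package, submodule, model
            provider = parts[-1]                  # "ollama"

    print(provider)  # -> "ollama"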
@@ -126,7 +190,7 @@ def _setup_llm_call_info(
     llm_call_info_var.set(llm_call_info)
 
     llm_call_info.llm_model_name = model_name or _infer_model_name(llm)
-    llm_call_info.llm_provider_name = model_provider
+    llm_call_info.llm_provider_name = model_provider or _infer_provider_from_module(llm)
 
 
 def _prepare_callbacks(
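
The one-line change above means an explicitly configured provider always wins: Python's or operator returns its left operand when that operand is truthy, so inference runs only when model_provider is None or empty. A small standalone sketch of the fallback semantics, using a hypothetical mock class much like the tests below:

    from nemoguardrails.actions.llm.utils import _infer_provider_from_module

    class FakeAnthropicChat:
        # Hypothetical stand-in whose module mimics langchain_anthropic.
        __module__ = "langchain_anthropic.chat_models"

    llm = FakeAnthropicChat()

    # An explicit value short-circuits; inference is only the fallback.
    assert ("my_provider" or _infer_provider_from_module(llm)) == "my_provider"
    assert (None or _infer_provider_from_module(llm)) == "anthropic"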

tests/test_actions_llm_utils.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@

# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemoguardrails.actions.llm.utils import _infer_provider_from_module


class MockOpenAILLM:
    __module__ = "langchain_openai.chat_models"


class MockAnthropicLLM:
    __module__ = "langchain_anthropic.chat_models"


class MockNVIDIALLM:
    __module__ = "langchain_nvidia_ai_endpoints.chat_models"


class MockCommunityOllama:
    __module__ = "langchain_community.chat_models.ollama"


class MockUnknownLLM:
    __module__ = "some_custom_package.models"


class MockNVIDIAOriginal:
    __module__ = "langchain_nvidia_ai_endpoints.chat_models"


class MockPatchedNVIDIA(MockNVIDIAOriginal):
    __module__ = "nemoguardrails.llm.providers._langchain_nvidia_ai_endpoints_patch"


def test_infer_provider_openai():
    llm = MockOpenAILLM()
    provider = _infer_provider_from_module(llm)
    assert provider == "openai"


def test_infer_provider_anthropic():
    llm = MockAnthropicLLM()
    provider = _infer_provider_from_module(llm)
    assert provider == "anthropic"


def test_infer_provider_nvidia_ai_endpoints():
    llm = MockNVIDIALLM()
    provider = _infer_provider_from_module(llm)
    assert provider == "nvidia_ai_endpoints"


def test_infer_provider_community_ollama():
    llm = MockCommunityOllama()
    provider = _infer_provider_from_module(llm)
    assert provider == "ollama"


def test_infer_provider_unknown():
    llm = MockUnknownLLM()
    provider = _infer_provider_from_module(llm)
    assert provider is None


def test_infer_provider_from_patched_class():
    llm = MockPatchedNVIDIA()
    provider = _infer_provider_from_module(llm)
    assert provider == "nvidia_ai_endpoints"


def test_infer_provider_checks_base_classes():
    class BaseOpenAI:
        __module__ = "langchain_openai.chat_models"

    class CustomWrapper(BaseOpenAI):
        __module__ = "my_custom_wrapper.llms"

    llm = CustomWrapper()
    provider = _infer_provider_from_module(llm)
    assert provider == "openai"


def test_infer_provider_multiple_inheritance():
    class BaseNVIDIA:
        __module__ = "langchain_nvidia_ai_endpoints.chat_models"

    class Mixin:
        __module__ = "some_mixin.utils"

    class MultipleInheritance(Mixin, BaseNVIDIA):
        __module__ = "custom_package.models"

    llm = MultipleInheritance()
    provider = _infer_provider_from_module(llm)
    assert provider == "nvidia_ai_endpoints"


def test_infer_provider_deeply_nested_inheritance():
    class Original:
        __module__ = "langchain_anthropic.chat_models"

    class Wrapper1(Original):
        __module__ = "wrapper1.models"

    class Wrapper2(Wrapper1):
        __module__ = "wrapper2.models"

    class Wrapper3(Wrapper2):
        __module__ = "wrapper3.models"

    llm = Wrapper3()
    provider = _infer_provider_from_module(llm)
    assert provider == "anthropic"
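
Why the wrapper tests pass: type(llm).__mro__ lists the class itself followed by every base class in method resolution order, so the loop in _infer_provider_from_module skips the wrapper's own module and keeps walking until it reaches a base defined in a langchain_* package. A quick illustration with hypothetical classes mirroring the fixtures above:

    class Original:
        __module__ = "langchain_anthropic.chat_models"

    class Wrapper(Original):
        __module__ = "wrapper.models"

    # __mro__ is (Wrapper, Original, object); the traversal skips entry 0.
    print([c.__module__ for c in Wrapper.__mro__[1:]])
    # -> ['langchain_anthropic.chat_models', 'builtins']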
