Skip to content

Commit

Permalink
Merge pull request #784 from NVIDIA/feat/bump-langchain-version
Browse files Browse the repository at this point in the history
Feat: Upgrade LangChain to Version 0.3
  • Loading branch information
Pouyanpi authored Oct 14, 2024
2 parents eef798d + b988b5a commit 3865845
Show file tree
Hide file tree
Showing 23 changed files with 44 additions and 37 deletions.
2 changes: 1 addition & 1 deletion examples/configs/rag/custom_rag_output_rails/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain_core.language_models.llms import BaseLLM
from langchain_core.output_parsers import StrOutputParser

from nemoguardrails import LLMRails
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/rag/multi_kb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
from gpt4pandas import GPT4Pandas
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import BaseLLM
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_core.language_models.llms import BaseLLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from nemoguardrails import LLMRails, RailsConfig
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/rag/pinecone/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import BaseLLM
from langchain.vectorstores import Pinecone
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import LLMRails
from nemoguardrails.actions import action
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/demo_llama_index_guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from typing import Any, Callable, Coroutine

from langchain.llms.base import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import LLMRails, RailsConfig

Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import Callable, List, Optional, cast

from jinja2 import Environment, meta
from langchain.llms import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails.actions.actions import ActionResult, action
from nemoguardrails.actions.llm.utils import (
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/actions/v2_x/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ast import literal_eval
from typing import Any, List, Optional, Tuple

from langchain.llms import BaseLLM
from langchain_core.language_models.llms import BaseLLM
from rich.text import Text

from nemoguardrails.actions.actions import action
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/content_safety/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import Dict, Optional

from langchain.llms.base import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails.actions.actions import action
from nemoguardrails.actions.llm.utils import llm_call
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/factchecking/align_score/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import Optional

from langchain.llms import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import RailsConfig
from nemoguardrails.actions import action
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/hallucination/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from typing import Optional

from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import RailsConfig
from nemoguardrails.actions import action
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/llama_guard/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import List, Optional, Tuple

from langchain.llms import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails.actions import action
from nemoguardrails.actions.llm.utils import llm_call
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/patronusai/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
from typing import List, Optional, Tuple, Union

from langchain.llms.base import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails.actions import action
from nemoguardrails.actions.llm.utils import llm_call
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/self_check/facts/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import Optional

from langchain.llms import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import RailsConfig
from nemoguardrails.actions import action
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/self_check/input_check/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import Optional

from langchain.llms.base import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import RailsConfig
from nemoguardrails.actions.actions import ActionResult, action
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/library/self_check/output_check/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import Optional

from langchain.llms.base import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import RailsConfig
from nemoguardrails.actions import action
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/llm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM, BaseLLM
from langchain_core.language_models.llms import LLM, BaseLLM


def get_llm_instance_wrapper(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from langchain_core.language_models.chat_models import generate_from_stream
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import Field
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
from pydantic.v1 import Field

log = logging.getLogger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions nemoguardrails/llm/providers/nemollm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import Generation
from langchain.schema.output import GenerationChunk, LLMResult
from langchain_core.language_models.llms import BaseLLM
from pydantic.v1 import root_validator

log = logging.getLogger(__name__)


class NeMoLLM(BaseLLM, BaseModel):
class NeMoLLM(BaseLLM):
"""Wrapper around NeMo LLM large language models.
If NGC_API_HOST, NGC_API_KEY and NGC_ORGANIZATION_ID environment variables are set,
Expand Down
25 changes: 16 additions & 9 deletions nemoguardrails/llm/providers/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,37 @@

import asyncio
import logging
from importlib.metadata import PackageNotFoundError, version
import warnings
from importlib.metadata import version
from typing import Any, Dict, List, Optional, Type

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk
from langchain_community import llms
from langchain_community.llms import HuggingFacePipeline
from packaging import version as pkg_version
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails.rails.llm.config import Model

from .nemollm import NeMoLLM
from .trtllm.llm import TRTLLM

# NOTE: this is temp
# Suppress specific warnings related to protected namespaces in Pydantic models, they must update their code.
warnings.filterwarnings(
"ignore",
message=r'Field "model_.*" in .* has conflict with protected namespace "model_"',
category=UserWarning,
module=r"pydantic\._internal\._fields",
)
log = logging.getLogger(__name__)

# Initialize the providers with the default ones, for now only NeMo LLM.
# Initialize the providers with the default ones
# We set nvidia_ai_endpoints provider to None because it's only supported if `langchain_nvidia_ai_endpoints` is installed.
_providers: Dict[str, Type[BaseLanguageModel]] = {
_providers: Dict[str, Optional[Type[BaseLLM]]] = {
"nemollm": NeMoLLM,
"trt_llm": TRTLLM,
"nvidia_ai_endpoints": None,
Expand Down Expand Up @@ -193,7 +200,7 @@ def discover_langchain_providers():
# If the "_acall" method is not defined, we add it.
if (
provider_cls
and issubclass(provider_cls, LLM)
and issubclass(provider_cls, BaseLLM)
and "_acall" not in provider_cls.__dict__
):
log.debug("Adding async support to %s", provider_cls.__name__)
Expand All @@ -204,12 +211,12 @@ def discover_langchain_providers():
discover_langchain_providers()


def register_llm_provider(name: str, provider_cls: Type[BaseLanguageModel]):
def register_llm_provider(name: str, provider_cls: Type[BaseLLM]):
"""Register an additional LLM provider."""
_providers[name] = provider_cls


def get_llm_provider(model_config: Model) -> Type[BaseLanguageModel]:
def get_llm_provider(model_config: Model) -> Type[BaseLLM]:
if model_config.engine not in _providers:
raise RuntimeError(f"Could not find LLM provider '{model_config.engine}'")

Expand Down
6 changes: 3 additions & 3 deletions nemoguardrails/llm/providers/trtllm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain_core.language_models.llms import BaseLLM
from pydantic.v1 import Field, root_validator

from nemoguardrails.llm.providers.trtllm.client import TritonClient

Expand All @@ -31,7 +31,7 @@
RANDOM_SEED = 0


class TRTLLM(LLM):
class TRTLLM(BaseLLM):
"""A custom Langchain LLM class that integrates with TRTLLM triton models.
Arguments:
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import warnings
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Type, Union, cast

from langchain.llms.base import BaseLLM
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import get_colang_history
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ dependencies = [
"httpx>=0.24.1",
"jinja2>=3.1.4",
# The 0.1.9 has a bug related to SparkLLM which breaks everything.
"langchain>=0.2.14,<0.3.0,!=0.1.9",
"langchain-core>=0.2.14,<0.3.0,!=0.1.26",
"langchain-community>=0.0.16,<0.3.0",
"langchain>=0.2.14,<0.4.0,!=0.1.9",
"langchain-core>=0.2.14,<0.4.0,!=0.1.26",
"langchain-community>=0.0.16,<0.4.0",
"lark~=1.1.7",
"nest-asyncio>=1.5.6",
"prompt-toolkit>=3.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ast import literal_eval
from typing import Optional

from langchain.llms import BaseLLM
from langchain_core.language_models.llms import BaseLLM

from nemoguardrails.actions import action
from nemoguardrails.actions.llm.utils import llm_call
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain_core.language_models.llms import LLM

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.colang import parse_colang_file
Expand Down

0 comments on commit 3865845

Please sign in to comment.