
Commit eaaa58c

refactor(llm): reorganize HuggingFace provider structure (#1083)
1 parent: d05fd8d

File tree

8 files changed: +138, -24 lines


examples/configs/llm/hf_pipeline_dolly/config.py

Lines changed: 2 additions & 4 deletions
@@ -17,10 +17,8 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
-from nemoguardrails.llm.providers import (
-    HuggingFacePipelineCompatible,
-    register_llm_provider,
-)
+from nemoguardrails.llm.providers import register_llm_provider
+from nemoguardrails.llm.providers.huggingface import HuggingFacePipelineCompatible


 @lru_cache
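
Every example config in this commit makes the same two-line change: `register_llm_provider` keeps its top-level import while `HuggingFacePipelineCompatible` now comes from the `huggingface` subpackage. For orientation, a minimal sketch of how such a config wires these pieces together follows; the model name, provider name, and the exact `get_llm_instance_wrapper`/`register_llm_provider` call shapes are illustrative assumptions, only the import paths are taken from the diff.

from functools import lru_cache

from transformers import pipeline

from nemoguardrails.llm.helpers import get_llm_instance_wrapper
from nemoguardrails.llm.providers import register_llm_provider
from nemoguardrails.llm.providers.huggingface import HuggingFacePipelineCompatible


@lru_cache
def get_dolly_llm():
    # Build a plain HF text-generation pipeline (model choice is illustrative).
    hf_pipeline = pipeline("text-generation", model="databricks/dolly-v2-3b")
    # Wrap it with the compatibility class from its new location.
    return HuggingFacePipelineCompatible(pipeline=hf_pipeline)


# Expose the wrapped instance as a named LLM provider usable from config.yml.
HFPipelineDolly = get_llm_instance_wrapper(
    llm_instance=get_dolly_llm(), llm_type="hf_pipeline_dolly"
)
register_llm_provider("hf_pipeline_dolly", HFPipelineDolly)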

examples/configs/llm/hf_pipeline_falcon/config.py

Lines changed: 2 additions & 4 deletions
@@ -17,10 +17,8 @@
 from torch import bfloat16

 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
-from nemoguardrails.llm.providers import (
-    HuggingFacePipelineCompatible,
-    register_llm_provider,
-)
+from nemoguardrails.llm.providers import register_llm_provider
+from nemoguardrails.llm.providers.huggingface import HuggingFacePipelineCompatible


 @lru_cache

examples/configs/llm/hf_pipeline_llama2/config.py

Lines changed: 2 additions & 4 deletions
@@ -20,10 +20,8 @@

 from nemoguardrails import LLMRails, RailsConfig
 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
-from nemoguardrails.llm.providers import (
-    HuggingFacePipelineCompatible,
-    register_llm_provider,
-)
+from nemoguardrails.llm.providers import register_llm_provider
+from nemoguardrails.llm.providers.huggingface import HuggingFacePipelineCompatible


 def _get_model_config(config: RailsConfig, type: str):

examples/configs/llm/hf_pipeline_mosaic/config.py

Lines changed: 2 additions & 4 deletions
@@ -18,10 +18,8 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
-from nemoguardrails.llm.providers import (
-    HuggingFacePipelineCompatible,
-    register_llm_provider,
-)
+from nemoguardrails.llm.providers import register_llm_provider
+from nemoguardrails.llm.providers.huggingface import HuggingFacePipelineCompatible


 @lru_cache

examples/configs/llm/hf_pipeline_vicuna/config.py

Lines changed: 2 additions & 4 deletions
@@ -18,10 +18,8 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline

 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
-from nemoguardrails.llm.providers import (
-    HuggingFacePipelineCompatible,
-    register_llm_provider,
-)
+from nemoguardrails.llm.providers import register_llm_provider
+from nemoguardrails.llm.providers.huggingface import HuggingFacePipelineCompatible


 @lru_cache

examples/configs/rag/multi_kb/config.py

Lines changed: 2 additions & 4 deletions
@@ -32,10 +32,8 @@
 from nemoguardrails.actions import action
 from nemoguardrails.actions.actions import ActionResult
 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
-from nemoguardrails.llm.providers import (
-    HuggingFacePipelineCompatible,
-    register_llm_provider,
-)
+from nemoguardrails.llm.providers import register_llm_provider
+from nemoguardrails.llm.providers.huggingface import HuggingFacePipelineCompatible

 from .tabular_llm import TabularLLM


nemoguardrails/llm/providers/huggingface/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -13,4 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .pipeline import HuggingFacePipelineCompatible
 from .streamers import AsyncTextIteratorStreamer
+
+__all__ = [
+    "HuggingFacePipelineCompatible",
+    "AsyncTextIteratorStreamer",
+]
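
With the updated `__init__.py`, both the pipeline wrapper and the async streamer are importable directly from the subpackage. A quick, purely illustrative sanity check of the advertised surface:

import nemoguardrails.llm.providers.huggingface as hf_providers
from nemoguardrails.llm.providers.huggingface import (
    AsyncTextIteratorStreamer,
    HuggingFacePipelineCompatible,
)

# The subpackage advertises exactly the two names listed in __all__ above.
assert set(hf_providers.__all__) == {
    "HuggingFacePipelineCompatible",
    "AsyncTextIteratorStreamer",
}
print(HuggingFacePipelineCompatible.__name__, AsyncTextIteratorStreamer.__name__)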
nemoguardrails/llm/providers/huggingface/pipeline.py

Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from typing import Any, List, Optional
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.schema.output import GenerationChunk
+from langchain_community.llms import HuggingFacePipeline
+
+
+class HuggingFacePipelineCompatible(HuggingFacePipeline):
+    """
+    Hackish way to add backward-compatibility functions to the LangChain class.
+    TODO: Planning to add this fix directly to the LangChain repo.
+    """
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Hackish way to perform a single LLM call since LangChain dropped support.
+        """
+        if not isinstance(prompt, str):
+            raise ValueError(
+                "Argument `prompt` is expected to be a string. Instead found "
+                f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
+                "`generate` instead."
+            )
+
+        # Streaming for NeMo Guardrails is not supported in sync calls.
+        if self.model_kwargs and self.model_kwargs.get("streaming"):
+            raise Exception(
+                "Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
+            )
+
+        llm_result = self._generate(
+            [prompt],
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
+        )
+        return llm_result.generations[0][0].text
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Hackish way to add async support.
+        """
+        if not isinstance(prompt, str):
+            raise ValueError(
+                "Argument `prompt` is expected to be a string. Instead found "
+                f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
+                "`generate` instead."
+            )
+
+        # Handle streaming, if the flag is set.
+        if self.model_kwargs and self.model_kwargs.get("streaming"):
+            # Retrieve the streamer object; it needs to be set in model_kwargs.
+            streamer = self.model_kwargs.get("streamer")
+            if not streamer:
+                raise Exception(
+                    "Cannot stream, please add HuggingFace streamer object to model_kwargs!"
+                )
+
+            loop = asyncio.get_running_loop()
+
+            # Pass the asyncio loop to the streamer so that it can send back
+            # the chunks in the queue.
+            streamer.loop = loop
+
+            # Launch the generation in a separate task.
+            generation_kwargs = dict(
+                prompts=[prompt],
+                stop=stop,
+                run_manager=run_manager,
+                **kwargs,
+            )
+            loop.create_task(self._agenerate(**generation_kwargs))
+
+            # And start waiting for the chunks to come in.
+            completion = ""
+            async for item in streamer:
+                completion += item
+                chunk = GenerationChunk(text=item)
+                if run_manager:
+                    await run_manager.on_llm_new_token(item, chunk=chunk)
+
+            return completion
+
+        llm_result = await self._agenerate(
+            [prompt],
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
+        )
+        return llm_result.generations[0][0].text
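
The streaming branch of `_acall` above follows a producer/consumer pattern: generation is launched as a background task, the streamer collects chunks, and the caller consumes them with `async for` while forwarding each chunk to the callback manager. The self-contained sketch below illustrates just that control flow with stand-in objects; it does not use the real `AsyncTextIteratorStreamer`, a real model, or any NeMo Guardrails API.

import asyncio


class ToyStreamer:
    """Minimal async-iterable chunk queue, standing in for the HF streamer."""

    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()

    def put(self, chunk: str) -> None:
        self.queue.put_nowait(chunk)

    def end(self) -> None:
        self.queue.put_nowait(None)  # sentinel: generation finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> str:
        item = await self.queue.get()
        if item is None:
            raise StopAsyncIteration
        return item


async def fake_generate(prompt: str, streamer: ToyStreamer) -> None:
    # Stand-in for the model producing tokens incrementally.
    for token in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0.05)
        streamer.put(token)
    streamer.end()


async def acall_like(prompt: str) -> str:
    streamer = ToyStreamer()
    loop = asyncio.get_running_loop()

    # Launch the generation in a separate task (mirrors loop.create_task above).
    loop.create_task(fake_generate(prompt, streamer))

    # Consume chunks as they arrive (mirrors `async for item in streamer`).
    completion = ""
    async for item in streamer:
        completion += item
    return completion


print(asyncio.run(acall_like("hi")))  # -> "Hello, world!"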
