Commit d9fe689

better support for other langchain llm clients
1 parent be87de6 commit d9fe689

6 files changed, +91 -33 lines changed


Diff for: src/ragas/embeddings/base.py (-3)

@@ -14,12 +14,9 @@
 from ragas.run_config import RunConfig, add_async_retry, add_retry
 import logging
 
-# logging.basicConfig(level=logging.DEBUG)
-
 DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"
 
 
-
 class BaseRagasEmbeddings(Embeddings, ABC):
     run_config: RunConfig
 

Diff for: src/ragas/evaluation.py (+3 -2)

@@ -18,7 +18,7 @@
 from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
 from ragas.llms import llm_factory
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LLMConfig
 from ragas.metrics._answer_correctness import AnswerCorrectness
 from ragas.metrics.base import Metric, MetricWithEmbeddings, MetricWithLLM
 from ragas.metrics.critique import AspectCritique
@@ -41,6 +41,7 @@ def evaluate(
     dataset: Dataset,
     metrics: list[Metric] | None = None,
     llm: t.Optional[BaseRagasLLM | LangchainLLM] = None,
+    llm_config: t.Optional[LLMConfig] = None,
     embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
     callbacks: Callbacks = None,
     is_async: bool = True,
@@ -148,7 +149,7 @@ def evaluate(
 
     # set the llm and embeddings
     if isinstance(llm, LangchainLLM):
-        llm = LangchainLLMWrapper(llm, run_config=run_config)
+        llm = LangchainLLMWrapper(llm, llm_config=llm_config, run_config=run_config)
     if isinstance(embeddings, LangchainEmbeddings):
         embeddings = LangchainEmbeddingsWrapper(embeddings)
 

Diff for: src/ragas/llms/__init__.py (+2 -1)

@@ -1,7 +1,8 @@
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, llm_factory
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LLMConfig, llm_factory
 
 __all__ = [
     "BaseRagasLLM",
     "LangchainLLMWrapper",
+    "LLMConfig",
     "llm_factory",
 ]

Diff for: src/ragas/llms/base.py (+71 -19)

@@ -10,15 +10,20 @@
 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.llms import VertexAI
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import HumanMessage
 from langchain_core.outputs import LLMResult
+from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
+from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain_openai.llms import AzureOpenAI, OpenAI
 from langchain_openai.llms.base import BaseOpenAI
 
 from ragas.run_config import RunConfig, add_async_retry, add_retry
 import re
+import hashlib
 import traceback
 
+
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
 
@@ -110,6 +115,17 @@ async def generate(
         )
         return await loop.run_in_executor(None, generate_text)
 
+@dataclass
+class LLMConfig:
+    stop: t.Optional[t.List[str]] = None
+    params: t.Optional[t.Dict[str, t.Any]] = None
+    prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None
+    result_callback: t.Optional[t.Callable[[LLMResult], t.Tuple[t.List[LLMResult]]]] = None
+
+    def __init__(self, stop: t.Optional[t.List[str]] = None, prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None, **kwargs):
+        self.stop = stop
+        self.params = kwargs
+        self.prompt_callback = prompt_callback
 
 class LangchainLLMWrapper(BaseRagasLLM):
     """
@@ -120,12 +136,18 @@ class LangchainLLMWrapper(BaseRagasLLM):
     """
 
     def __init__(
-        self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
+        self,
+        langchain_llm: BaseLanguageModel,
+        run_config: t.Optional[RunConfig] = None,
+        llm_config: LLMConfig = None,
     ):
         self.langchain_llm = langchain_llm
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)
+        if llm_config is None:
+            llm_config = LLMConfig()
+        self.llm_config = llm_config
 
     def generate_text(
         self,
@@ -136,21 +158,38 @@ def generate_text(
         callbacks: Callbacks = None,
     ) -> LLMResult:
         temperature = self.get_temperature(n=n)
+        stop = stop or self.llm_config.stop
+
+        if self.llm_config.prompt_callback:
+            prompts, extra_params = self.llm_config.prompt_callback(prompt)
+        else:
+            prompts = [prompt]
+            extra_params = {}
+
         if is_multiple_completion_supported(self.langchain_llm):
-            return self.langchain_llm.generate_prompt(
-                prompts=[prompt],
+            result = self.langchain_llm.generate_prompt(
+                prompts=prompts,
                 n=n,
                 temperature=temperature,
-                stop=stop,
                 callbacks=callbacks,
+                stop=stop,
+                **self.llm_config.params,
+                **extra_params,
             )
+            if self.llm_config.result_callback:
+                return self.llm_config.result_callback(result)
+            return result
         else:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt] * n,
                 temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
+                **self.llm_config.params,
+                **extra_params,
             )
+            if self.llm_config.result_callback:
+                result = self.llm_config.result_callback(result)
             # make LLMResult.generation appear as if it was n_completions
             # note that LLMResult.runs is still a list that represents each run
             generations = [[g[0] for g in result.generations]]
@@ -162,43 +201,56 @@ async def agenerate_text(
         prompt: PromptValue,
         n: int = 1,
         temperature: float = 1e-8,
-        stop: t.Optional[t.List[str]] = None, #["<|eot_id|>"], #None,
+        stop: t.Optional[t.List[str]] = None,
         callbacks: Callbacks = None,
     ) -> LLMResult:
-        # traceback.print_stack()
-        logger.debug(f"Generating text with prompt: {str(prompt).encode('utf-8').decode('unicode_escape')}...")
-        stop = ["<|eot_id|>"]
-        # ["</s>", "[/INST]"] #
-        prompt.prompt_str =f"<human>: {prompt.prompt_str}\n<bot>:"
+        # to trace request/response for multi-threaded execution
+        gen_id = hashlib.md5(str(prompt).encode('utf-8')).hexdigest()[:4]
+        stop = stop or self.llm_config.stop
+        prompt_str = prompt.prompt_str
+        logger.debug(f"Generating text for [{gen_id}] with prompt: {prompt_str}")
         temperature = self.get_temperature(n=n)
+        if self.llm_config.prompt_callback:
+            prompts, extra_params = self.llm_config.prompt_callback(prompt)
+        else:
+            prompts = [prompt] * n
+            extra_params = {}
         if is_multiple_completion_supported(self.langchain_llm):
-            response = await self.langchain_llm.agenerate_prompt(
-                prompts=[prompt],
+            result = await self.langchain_llm.agenerate_prompt(
+                prompts=prompts,
                 n=n,
                 temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
+                **self.llm_config.params,
+                **extra_params,
             )
-            logger.debug(f"got result (m): {response.generations[0][0].text}")
-            return response
+            if self.llm_config.result_callback:
+                result = self.llm_config.result_callback(result)
+            logger.debug(f"got result (m): {result.generations[0][0].text}")
+            return result
         else:
             result = await self.langchain_llm.agenerate_prompt(
-                prompts=[prompt] * n,
+                prompts=prompts,
                 temperature=temperature,
-                stop=stop,
                 callbacks=callbacks,
+                **self.llm_config.params,
+                **extra_params,
             )
+            if self.llm_config.result_callback:
+                result = self.llm_config.result_callback(result)
             # make LLMResult.generation appear as if it was n_completions
             # note that LLMResult.runs is still a list that represents each run
             generations = [[g[0] for g in result.generations]]
             result.generations = generations
+
+            # this part should go to LLMConfig.result_callback
             if len(result.generations[0][0].text) > 0:
-                # while the <human>/<bot> tags improves answer quality, I observed sometimes the </bit> to leak into the response
                 result.generations[0][0].text = re.sub(r"</?bot>", '', result.generations[0][0].text)
-            logger.debug(f"got result: {result.generations[0][0].text}")
+            logger.debug(f"got result [{gen_id}]: {result.generations[0][0].text}")
             # todo configure on question?
             if len(result.generations[0][0].text) < 24:
-                logger.warn(f"truncated response?: {result.generations}")
+                logger.warning(f"truncated response?: {result.generations}")
             return result
 
     def set_run_config(self, run_config: RunConfig):
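
LLMConfig is the new extension point: stop supplies default stop tokens, extra keyword arguments are kept in params and passed through to the underlying LangChain client, prompt_callback may rewrite the prompt (and return per-call parameters) before the request, and result_callback post-processes the raw LLMResult. A sketch of how a client with special prompt formatting might use it; the tags, stop token, model, and helper names are illustrative, not taken from the commit:

# Hypothetical sketch: adapt prompts for an instruction-tuned model and strip
# leaked tags from its output, replacing the hard-coded <human>/<bot> handling
# that this commit removes from agenerate_text().
import re

from langchain_core.outputs import LLMResult
from langchain_core.prompt_values import StringPromptValue
from langchain_openai import ChatOpenAI

from ragas.llms import LangchainLLMWrapper, LLMConfig


def wrap_prompt(prompt):
    # Receives the PromptValue; must return (list_of_prompts, extra_kwargs).
    text = f"<human>: {prompt.to_string()}\n<bot>:"
    return [StringPromptValue(text=text)], {}


def clean_result(result: LLMResult) -> LLMResult:
    # The wrapper uses the return value as the LLMResult, so return it.
    for gens in result.generations:
        for gen in gens:
            gen.text = re.sub(r"</?bot>", "", gen.text)
    return result


# The custom __init__ only accepts stop, prompt_callback and **kwargs
# (which end up in LLMConfig.params), so result_callback is set as an attribute.
config = LLMConfig(stop=["<|eot_id|>"], prompt_callback=wrap_prompt, max_tokens=512)
config.result_callback = clean_result

llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-3.5-turbo"), llm_config=config)

Note that result_callback is annotated as returning a tuple, but both generate_text and agenerate_text use its return value directly as the LLMResult, so returning the (possibly modified) LLMResult is what works in practice.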

Diff for: src/ragas/testset/evolutions.py (+9 -5)

@@ -186,9 +186,13 @@ async def generate_datarow(
     ):
         assert self.generator_llm is not None, "generator_llm cannot be None"
 
-        node_content = [
-            f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
-        ]
+        # clear index distinction helps in getting it more clear for LLM - especially for long, complex contexts
+        node_content = {
+            str(i + 1): n.page_content for i, n in enumerate(current_nodes.nodes)
+        }
+        # node_content = [
+        #     f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
+        # ]
         results = await self.generator_llm.generate(
             prompt=self.find_relevant_context_prompt.format(
                 question=question, contexts=node_content
@@ -208,9 +212,9 @@ async def generate_datarow(
             )
         else:
             selected_nodes = [
-                current_nodes.nodes[i - 1]
+                current_nodes.nodes[int(i) - 1]
                 for i in relevant_context_indices
-                if i - 1 < len(current_nodes.nodes)
+                if int(i) - 1 < len(current_nodes.nodes)
             ]
             relevant_context = (
                 CurrentNodes(root_node=selected_nodes[0], nodes=selected_nodes)
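
The contexts sent to find_relevant_context_prompt are now a dict keyed by a 1-based string index rather than a tab-prefixed list, and the indices returned by the LLM are cast with int() before indexing back into the nodes. A small self-contained illustration of the new mapping (plain strings stand in for the Node objects in current_nodes.nodes):

# Illustrative only: plain strings stand in for node page_content.
page_contents = ["First chunk of context ...", "Second chunk of context ..."]

# New representation sent to the LLM: {"1": "...", "2": "..."}
node_content = {str(i + 1): text for i, text in enumerate(page_contents)}

# The LLM now returns string keys such as ["2"], hence the int() casts.
relevant_context_indices = ["2"]
selected = [
    page_contents[int(i) - 1]
    for i in relevant_context_indices
    if int(i) - 1 < len(page_contents)
]
print(selected)  # ['Second chunk of context ...']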

Diff for: src/ragas/testset/generator.py (+6 -3)

@@ -16,7 +16,7 @@
 from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
 from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
-from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
+from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LLMConfig
 from ragas.run_config import RunConfig
 from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
 from ragas.testset.evolutions import (
@@ -32,6 +32,7 @@
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.utils import check_if_sum_is_close, deprecated, get_feature_language, is_nan
 
+
 if t.TYPE_CHECKING:
     from langchain_core.documents import Document as LCDocument
     from llama_index.core.schema import Document as LlamaindexDocument
@@ -81,9 +82,11 @@ def from_langchain(
         docstore: t.Optional[DocumentStore] = None,
         run_config: t.Optional[RunConfig] = None,
         chunk_size: int = 1024,
+        generator_llm_config: t.Optional[LLMConfig] = None,
+        critic_llm_config: t.Optional[LLMConfig] = None,
     ) -> "TestsetGenerator":
-        generator_llm_model = LangchainLLMWrapper(generator_llm)
-        critic_llm_model = LangchainLLMWrapper(critic_llm)
+        generator_llm_model = LangchainLLMWrapper(generator_llm, llm_config=generator_llm_config)
+        critic_llm_model = LangchainLLMWrapper(critic_llm, llm_config=critic_llm_config)
         embeddings_model = LangchainEmbeddingsWrapper(embeddings)
 
         keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)
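
from_langchain() now threads separate LLMConfig objects through to the wrappers it creates for the generator and critic LLMs. A sketch of what a caller might pass (model names and stop tokens are illustrative):

# Hypothetical usage sketch for the new per-LLM config parameters.
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from ragas.llms import LLMConfig
from ragas.testset.generator import TestsetGenerator

generator = TestsetGenerator.from_langchain(
    generator_llm=ChatOpenAI(model="gpt-3.5-turbo"),
    critic_llm=ChatOpenAI(model="gpt-4"),
    embeddings=OpenAIEmbeddings(),
    generator_llm_config=LLMConfig(stop=["<|eot_id|>"]),  # wraps the generator LLM
    critic_llm_config=LLMConfig(stop=["<|eot_id|>"]),     # wraps the critic LLM
)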
