from langchain_community.chat_models.vertexai import ChatVertexAI
from langchain_community.llms import VertexAI
from langchain_core.language_models import BaseLanguageModel
+ from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult
+ from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
+ from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.llms.base import BaseOpenAI

from ragas.run_config import RunConfig, add_async_retry, add_retry
import re
+ import hashlib
import traceback

+
if t.TYPE_CHECKING:
    from langchain_core.callbacks import Callbacks

@@ -110,6 +115,17 @@ async def generate(
            )
            return await loop.run_in_executor(None, generate_text)

+ @dataclass
+ class LLMConfig:
+     stop: t.Optional[t.List[str]] = None
+     params: t.Optional[t.Dict[str, t.Any]] = None
+     prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None
+     result_callback: t.Optional[t.Callable[[LLMResult], t.Tuple[t.List[LLMResult]]]] = None
+
+     def __init__(self, stop: t.Optional[t.List[str]] = None, prompt_callback: t.Optional[t.Callable[[PromptValue], t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]]] = None, **kwargs):
+         self.stop = stop
+         self.params = kwargs
+         self.prompt_callback = prompt_callback

class LangchainLLMWrapper(BaseRagasLLM):
    """
@@ -120,12 +136,18 @@ class LangchainLLMWrapper(BaseRagasLLM):
    """

    def __init__(
-         self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
+         self,
+         langchain_llm: BaseLanguageModel,
+         run_config: t.Optional[RunConfig] = None,
+         llm_config: LLMConfig = None,
    ):
        self.langchain_llm = langchain_llm
        if run_config is None:
            run_config = RunConfig()
        self.set_run_config(run_config)
+         if llm_config is None:
+             llm_config = LLMConfig()
+         self.llm_config = llm_config

    def generate_text(
        self,
@@ -136,21 +158,38 @@ def generate_text(
        callbacks: Callbacks = None,
    ) -> LLMResult:
        temperature = self.get_temperature(n=n)
+         stop = stop or self.llm_config.stop
+
+         if self.llm_config.prompt_callback:
+             prompts, extra_params = self.llm_config.prompt_callback(prompt)
+         else:
+             prompts = [prompt]
+             extra_params = {}
+
        if is_multiple_completion_supported(self.langchain_llm):
-             return self.langchain_llm.generate_prompt(
-                 prompts=[prompt],
+             result = self.langchain_llm.generate_prompt(
+                 prompts=prompts,
                n=n,
                temperature=temperature,
-                 stop=stop,
                callbacks=callbacks,
+                 stop=stop,
+                 **self.llm_config.params,
+                 **extra_params,
            )
+             if self.llm_config.result_callback:
+                 return self.llm_config.result_callback(result)
+             return result
        else:
            result = self.langchain_llm.generate_prompt(
                prompts=[prompt] * n,
                temperature=temperature,
                stop=stop,
                callbacks=callbacks,
+                 **self.llm_config.params,
+                 **extra_params,
            )
+             if self.llm_config.result_callback:
+                 result = self.llm_config.result_callback(result)
            # make LLMResult.generation appear as if it was n_completions
            # note that LLMResult.runs is still a list that represents each run
            generations = [[g[0] for g in result.generations]]
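The prompt_callback hook lets a caller rewrite the incoming PromptValue (for example into a chat-style prompt) and attach extra per-call kwargs, which generate_text forwards alongside llm_config.params. A sketch of a callback matching the expected (prompts, extra_params) return shape, reusing the imports added in this diff; the chat wrapping and the max_tokens kwarg are illustrative and assume the wrapped model accepts them:

    # hypothetical callback, not part of the patch
    def wrap_as_chat(prompt: PromptValue) -> t.Tuple[t.List[PromptValue], t.Dict[str, t.Any]]:
        # repackage the plain prompt as a single human chat message
        chat_prompt = ChatPromptValue(messages=[HumanMessage(content=prompt.to_string())])
        # the second element is splatted into generate_prompt alongside llm_config.params
        return [chat_prompt], {"max_tokens": 512}

    config = LLMConfig(stop=["<|eot_id|>"], prompt_callback=wrap_as_chat)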
@@ -162,43 +201,56 @@ async def agenerate_text(
        prompt: PromptValue,
        n: int = 1,
        temperature: float = 1e-8,
-         stop: t.Optional[t.List[str]] = None,  # ["<|eot_id|>"], # None,
+         stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
-         # traceback.print_stack()
-         logger.debug(f"Generating text with prompt: {str(prompt).encode('utf-8').decode('unicode_escape')}...")
-         stop = ["<|eot_id|>"]
-         # ["</s>", "[/INST]"] #
-         prompt.prompt_str = f"<human>: {prompt.prompt_str}\n<bot>:"
+         # to trace request/response for multi-threaded execution
+         gen_id = hashlib.md5(str(prompt).encode('utf-8')).hexdigest()[:4]
+         stop = stop or self.llm_config.stop
+         prompt_str = prompt.prompt_str
+         logger.debug(f"Generating text for [{gen_id}] with prompt: {prompt_str}")
        temperature = self.get_temperature(n=n)
+         if self.llm_config.prompt_callback:
+             prompts, extra_params = self.llm_config.prompt_callback(prompt)
+         else:
+             prompts = [prompt] * n
+             extra_params = {}
        if is_multiple_completion_supported(self.langchain_llm):
-             response = await self.langchain_llm.agenerate_prompt(
-                 prompts=[prompt],
+             result = await self.langchain_llm.agenerate_prompt(
+                 prompts=prompts,
                n=n,
                temperature=temperature,
                stop=stop,
                callbacks=callbacks,
+                 **self.llm_config.params,
+                 **extra_params,
            )
-             logger.debug(f"got result (m): {response.generations[0][0].text}")
-             return response
+             if self.llm_config.result_callback:
+                 result = self.llm_config.result_callback(result)
+             logger.debug(f"got result (m): {result.generations[0][0].text}")
+             return result
        else:
            result = await self.langchain_llm.agenerate_prompt(
-                 prompts=[prompt] * n,
+                 prompts=prompts,
                temperature=temperature,
-                 stop=stop,
                callbacks=callbacks,
+                 **self.llm_config.params,
+                 **extra_params,
            )
+             if self.llm_config.result_callback:
+                 result = self.llm_config.result_callback(result)
            # make LLMResult.generation appear as if it was n_completions
            # note that LLMResult.runs is still a list that represents each run
            generations = [[g[0] for g in result.generations]]
            result.generations = generations
+
+             # this part should go to LLMConfig.result_callback
            if len(result.generations[0][0].text) > 0:
-                 # while the <human>/<bot> tags improve answer quality, I observed the </bot> tag sometimes leaking into the response
                result.generations[0][0].text = re.sub(r"</?bot>", '', result.generations[0][0].text)
-                 logger.debug(f"got result: {result.generations[0][0].text}")
+                 logger.debug(f"got result [{gen_id}]: {result.generations[0][0].text}")
            # todo configure on question?
            if len(result.generations[0][0].text) < 24:
-                 logger.warn(f"truncated response?: {result.generations}")
+                 logger.warning(f"truncated response?: {result.generations}")
            return result

    def set_run_config(self, run_config: RunConfig):
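The inline </bot> cleanup above is flagged by its own comment as something that should move into LLMConfig.result_callback. A sketch of such a callback, using the same tag pattern as the diff; note that in this patch __init__ does not accept result_callback, so it is assigned on the instance, and the model variable below is a placeholder for any BaseLanguageModel:

    # hypothetical cleanup callback mirroring the inline re.sub above
    def strip_bot_tags(result: LLMResult) -> LLMResult:
        for gen_list in result.generations:
            for gen in gen_list:
                gen.text = re.sub(r"</?bot>", "", gen.text)
        return result

    config = LLMConfig(stop=["<|eot_id|>"])
    config.result_callback = strip_bot_tags
    wrapper = LangchainLLMWrapper(langchain_llm=some_langchain_llm, llm_config=config)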