1313# limitations under the License.
1414from __future__ import annotations
1515
16- from typing import Any , List , Optional , Union , cast , Sequence
16+ from typing import Any , List , Optional , Union , Sequence
1717
18- from pydantic import ValidationError
1918
2019from neo4j_graphrag .exceptions import LLMGenerationError
2120from neo4j_graphrag .llm .base import LLMInterface
2221from neo4j_graphrag .llm .rate_limit import (
2322 RateLimitHandler ,
24- rate_limit_handler ,
25- async_rate_limit_handler ,
2623)
2724from neo4j_graphrag .llm .types import (
28- BaseMessage ,
2925 LLMResponse ,
30- MessageList ,
3126 ToolCall ,
3227 ToolCallResponse ,
3328)
@@ -98,92 +93,73 @@ def __init__(
9893
9994 def get_messages (
10095 self ,
101- input : str ,
102- message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
103- ) -> list [Content ]:
96+ input : list [LLMMessage ],
97+ ) -> tuple [str | None , list [Content ]]:
10498 messages = []
105- if message_history :
106- if isinstance (message_history , MessageHistory ):
107- message_history = message_history .messages
108- try :
109- MessageList (messages = cast (list [BaseMessage ], message_history ))
110- except ValidationError as e :
111- raise LLMGenerationError (e .errors ()) from e
112-
113- for message in message_history :
114- if message .get ("role" ) == "user" :
115- messages .append (
116- Content (
117- role = "user" ,
118- parts = [Part .from_text (message .get ("content" , "" ))],
119- )
99+ system_instruction = self .system_instruction
100+ for message in input :
101+ if message .get ("role" ) == "system" :
102+ system_instruction = message .get ("content" )
103+ continue
104+ if message .get ("role" ) == "user" :
105+ messages .append (
106+ Content (
107+ role = "user" ,
108+ parts = [Part .from_text (message .get ("content" , "" ))],
120109 )
121- elif message .get ("role" ) == "assistant" :
122- messages .append (
123- Content (
124- role = "model" ,
125- parts = [Part .from_text (message .get ("content" , "" ))],
126- )
110+ )
111+ continue
112+ if message .get ("role" ) == "assistant" :
113+ messages .append (
114+ Content (
115+ role = "model" ,
116+ parts = [Part .from_text (message .get ("content" , "" ))],
127117 )
118+ )
119+ continue
120+ return system_instruction , messages
128121
129- messages .append (Content (role = "user" , parts = [Part .from_text (input )]))
130- return messages
131-
132- @rate_limit_handler
133- def invoke (
122+ def _invoke (
134123 self ,
135- input : str ,
136- message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
137- system_instruction : Optional [str ] = None ,
124+ input : list [LLMMessage ],
138125 ) -> LLMResponse :
139126 """Sends text to the LLM and returns a response.
140127
141128 Args:
142129 input (str): The text to send to the LLM.
143- message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
144- with each message having a specific role assigned.
145- system_instruction (Optional[str]): An option to override the llm system message for this invocation.
146130
147131 Returns:
148132 LLMResponse: The response from the LLM.
149133 """
134+ system_instruction , messages = self .get_messages (input )
150135 model = self ._get_model (
151136 system_instruction = system_instruction ,
152137 )
153138 try :
154- if isinstance (message_history , MessageHistory ):
155- message_history = message_history .messages
156- options = self ._get_call_params (input , message_history , tools = None )
139+ options = self ._get_call_params (messages , tools = None )
157140 response = model .generate_content (** options )
158141 return self ._parse_content_response (response )
159142 except ResponseValidationError as e :
160143 raise LLMGenerationError ("Error calling VertexAILLM" ) from e
161144
162- @async_rate_limit_handler
163- async def ainvoke (
145+ async def _ainvoke (
164146 self ,
165- input : str ,
166- message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
167- system_instruction : Optional [str ] = None ,
147+ input : list [LLMMessage ],
168148 ) -> LLMResponse :
169149 """Asynchronously sends text to the LLM and returns a response.
170150
171151 Args:
172152 input (str): The text to send to the LLM.
173- message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
174- with each message having a specific role assigned.
175- system_instruction (Optional[str]): An option to override the llm system message for this invocation.
176153
177154 Returns:
178155 LLMResponse: The response from the LLM.
179156 """
180157 try :
181- if isinstance (message_history , MessageHistory ):
182- message_history = message_history .messages
158+ system_instruction , messages = self .get_messages (input )
183159 model = self ._get_model (
184160 system_instruction = system_instruction ,
185161 )
186- options = self ._get_call_params (input , message_history , tools = None )
162+ options = self ._get_call_params (messages , tools = None )
187163 response = await model .generate_content_async (** options )
188164 return self ._parse_content_response (response )
189165 except ResponseValidationError as e :
@@ -222,8 +198,7 @@ def _get_model(
222198
223199 def _get_call_params (
224200 self ,
225- input : str ,
226- message_history : Optional [Union [List [LLMMessage ], MessageHistory ]],
201+ messages : list [Content ],
227202 tools : Optional [Sequence [Tool ]],
228203 ) -> dict [str , Any ]:
229204 options = dict (self .options )
@@ -241,7 +216,6 @@ def _get_call_params(
241216 # no tools, remove tool_config if defined
242217 options .pop ("tool_config" , None )
243218
244- messages = self .get_messages (input , message_history )
245219 options ["contents" ] = messages
246220 return options
247221
0 commit comments