Commit 2c4f4e5

Update VertexAILLM
1 parent 5ff611b commit 2c4f4e5

2 files changed: 47 additions & 60 deletions
Lines changed: 15 additions & 2 deletions

@@ -1,6 +1,20 @@
 from neo4j_graphrag.llm import LLMResponse, VertexAILLM
 from vertexai.generative_models import GenerationConfig
 
+from neo4j_graphrag.types import LLMMessage
+
+messages: list[LLMMessage] = [
+    {
+        "role": "system",
+        "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
+    },
+    {
+        "role": "user",
+        "content": "say something",
+    },
+]
+
+
 generation_config = GenerationConfig(temperature=1.0)
 llm = VertexAILLM(
     model_name="gemini-2.0-flash-001",
@@ -9,7 +23,6 @@
     # vertexai.generative_models.GenerativeModel client
 )
 res: LLMResponse = llm.invoke(
-    "say something",
-    system_instruction="You are living in 3000 where AI rules the world",
+    input=messages,
 )
 print(res.content)
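
With this change the system prompt travels inside the message list as a "system" entry instead of through the removed system_instruction keyword. A multi-turn call would presumably follow the same pattern, replaying earlier model output as "assistant" messages. A minimal sketch, with hypothetical conversation content:

from neo4j_graphrag.llm import LLMResponse, VertexAILLM
from neo4j_graphrag.types import LLMMessage

# Hypothetical second turn: the previous reply is replayed with the
# "assistant" role so the model sees the conversation so far.
follow_up: list[LLMMessage] = [
    {"role": "system", "content": "You are a seasoned actor and expert performer."},
    {"role": "user", "content": "say something"},
    {"role": "assistant", "content": "Ladies and gentlemen: something!"},
    {"role": "user", "content": "now take a bow"},
]

llm = VertexAILLM(model_name="gemini-2.0-flash-001")
res: LLMResponse = llm.invoke(input=follow_up)
print(res.content)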

src/neo4j_graphrag/llm/vertexai_llm.py

Lines changed: 32 additions & 58 deletions
@@ -13,21 +13,16 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, List, Optional, Union, cast, Sequence
+from typing import Any, List, Optional, Union, Sequence
 
-from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
 from neo4j_graphrag.llm.rate_limit import (
     RateLimitHandler,
-    rate_limit_handler,
-    async_rate_limit_handler,
 )
 from neo4j_graphrag.llm.types import (
-    BaseMessage,
     LLMResponse,
-    MessageList,
     ToolCall,
     ToolCallResponse,
 )
@@ -98,92 +93,73 @@ def __init__(
 
     def get_messages(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-    ) -> list[Content]:
+        input: list[LLMMessage],
+    ) -> tuple[str | None, list[Content]]:
         messages = []
-        if message_history:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            try:
-                MessageList(messages=cast(list[BaseMessage], message_history))
-            except ValidationError as e:
-                raise LLMGenerationError(e.errors()) from e
-
-            for message in message_history:
-                if message.get("role") == "user":
-                    messages.append(
-                        Content(
-                            role="user",
-                            parts=[Part.from_text(message.get("content", ""))],
-                        )
+        system_instruction = self.system_instruction
+        for message in input:
+            if message.get("role") == "system":
+                system_instruction = message.get("content")
+                continue
+            if message.get("role") == "user":
+                messages.append(
+                    Content(
+                        role="user",
+                        parts=[Part.from_text(message.get("content", ""))],
                     )
-                elif message.get("role") == "assistant":
-                    messages.append(
-                        Content(
-                            role="model",
-                            parts=[Part.from_text(message.get("content", ""))],
-                        )
+                )
+                continue
+            if message.get("role") == "assistant":
+                messages.append(
+                    Content(
+                        role="model",
+                        parts=[Part.from_text(message.get("content", ""))],
                    )
+                )
+                continue
+        return system_instruction, messages
 
-            messages.append(Content(role="user", parts=[Part.from_text(input)]))
-        return messages
-
-    @rate_limit_handler
-    def invoke(
+    def _invoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
-            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
-                with each message having a specific role assigned.
-            system_instruction (Optional[str]): An option to override the llm system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
+        system_instruction, messages = self.get_messages(input)
         model = self._get_model(
             system_instruction=system_instruction,
         )
         try:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            options = self._get_call_params(input, message_history, tools=None)
+            options = self._get_call_params(messages, tools=None)
             response = model.generate_content(**options)
             return self._parse_content_response(response)
         except ResponseValidationError as e:
             raise LLMGenerationError("Error calling VertexAILLM") from e
 
-    @async_rate_limit_handler
-    async def ainvoke(
+    async def _ainvoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
-            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
-                with each message having a specific role assigned.
-            system_instruction (Optional[str]): An option to override the llm system message for this invocation.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
+            system_instruction, messages = self.get_messages(input)
             model = self._get_model(
                 system_instruction=system_instruction,
             )
-            options = self._get_call_params(input, message_history, tools=None)
+            options = self._get_call_params(messages, tools=None)
             response = await model.generate_content_async(**options)
             return self._parse_content_response(response)
         except ResponseValidationError as e:
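
The reworked get_messages now does two jobs: it lifts any "system" message out of the input (falling back to self.system_instruction when none is present) and converts the remaining turns into Vertex AI Content objects, mapping the "assistant" role onto Vertex AI's "model" role. A minimal sketch of the resulting contract, assuming the client object can be built without a live Vertex AI session and using illustrative values:

from neo4j_graphrag.llm import VertexAILLM
from neo4j_graphrag.types import LLMMessage

llm = VertexAILLM(model_name="gemini-2.0-flash-001")

conversation: list[LLMMessage] = [
    {"role": "system", "content": "Answer tersely."},
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hi"},
]

# The system turn is consumed into the instruction; the remaining turns
# become Content objects with roles "user" and "model" respectively.
system_instruction, contents = llm.get_messages(conversation)
assert system_instruction == "Answer tersely."
assert [content.role for content in contents] == ["user", "model"]
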
@@ -222,8 +198,7 @@ def _get_model(
 
     def _get_call_params(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]],
+        messages: list[Content],
         tools: Optional[Sequence[Tool]],
     ) -> dict[str, Any]:
         options = dict(self.options)
@@ -241,7 +216,6 @@ def _get_call_params(
             # no tools, remove tool_config if defined
             options.pop("tool_config", None)
 
-        messages = self.get_messages(input, message_history)
         options["contents"] = messages
         return options
 
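
Dropping the rate_limit_handler imports together with the invoke -> _invoke rename suggests the public entry points now live on the LLMInterface base class, which would apply rate limiting once and delegate to the provider-specific methods. A sketch of that presumed shape; the base class itself is not part of this diff:

from abc import ABC, abstractmethod

from neo4j_graphrag.llm.rate_limit import (
    async_rate_limit_handler,
    rate_limit_handler,
)
from neo4j_graphrag.llm.types import LLMResponse
from neo4j_graphrag.types import LLMMessage


# Presumed base-class pattern implied by this refactor; the actual
# LLMInterface definition is not shown in this commit.
class LLMInterface(ABC):
    @abstractmethod
    def _invoke(self, input: list[LLMMessage]) -> LLMResponse: ...

    @abstractmethod
    async def _ainvoke(self, input: list[LLMMessage]) -> LLMResponse: ...

    @rate_limit_handler
    def invoke(self, input: list[LLMMessage]) -> LLMResponse:
        # Rate limiting is applied here once, for every provider.
        return self._invoke(input)

    @async_rate_limit_handler
    async def ainvoke(self, input: list[LLMMessage]) -> LLMResponse:
        return await self._ainvoke(input)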