Skip to content

Commit a7ab4f9

Browse files
committed
Do not force users to implement _invoke/_ainvoke to prevent breaking changes + add deprecation warning
1 parent 7b4e2d3 commit a7ab4f9

File tree

3 files changed

+39
-18
lines changed

3 files changed

+39
-18
lines changed

docs/source/user_guide_rag.rst

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,17 @@ Here's an example using the Python Ollama client:
265265
266266
import ollama
267267
from neo4j_graphrag.llm import LLMInterface, LLMResponse
268+
from neo4j_graphrag.types import LLMMessage
268269
269270
class OllamaLLM(LLMInterface):
270271
271-
def invoke(self, input: str) -> LLMResponse:
272-
response = ollama.chat(model=self.model_name, messages=[
273-
{
274-
'role': 'user',
275-
'content': input,
276-
},
277-
])
272+
def _invoke(self, input: list[LLMMessage]) -> LLMResponse:
273+
response = ollama.chat(model=self.model_name, messages=input)
278274
return LLMResponse(
279275
content=response["message"]["content"]
280276
)
281277
282-
async def ainvoke(self, input: str) -> LLMResponse:
278+
async def _ainvoke(self, input: list[LLMMessage]) -> LLMResponse:
283279
return self.invoke(input) # TODO: implement async with ollama.AsyncClient
284280
285281

src/neo4j_graphrag/llm/base.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from abc import ABC, abstractmethod
17+
import warnings
1818
from typing import Any, List, Optional, Sequence, Union
1919

2020
from pydantic import ValidationError
@@ -36,7 +36,7 @@
3636
from ..exceptions import LLMGenerationError
3737

3838

39-
class LLMInterface(ABC):
39+
class LLMInterface:
4040
"""Interface for large language models.
4141
4242
Args:
@@ -68,6 +68,16 @@ def invoke(
6868
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
6969
system_instruction: Optional[str] = None,
7070
) -> LLMResponse:
71+
if message_history:
72+
warnings.warn(
73+
"Using 'message_history' in the llm.invoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
74+
DeprecationWarning,
75+
)
76+
if system_instruction:
77+
warnings.warn(
78+
"Using 'system_instruction' in the llm.invoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
79+
DeprecationWarning,
80+
)
7181
try:
7282
messages = legacy_inputs_to_messages(
7383
input, message_history, system_instruction
@@ -76,7 +86,6 @@ def invoke(
7686
raise LLMGenerationError("Input validation failed") from e
7787
return self._invoke(messages)
7888

79-
@abstractmethod
8089
def _invoke(
8190
self,
8291
input: list[LLMMessage],
@@ -92,6 +101,7 @@ def _invoke(
92101
Raises:
93102
LLMGenerationError: If anything goes wrong.
94103
"""
104+
raise NotImplementedError()
95105

96106
@async_rate_limit_handler
97107
async def ainvoke(
@@ -100,10 +110,19 @@ async def ainvoke(
100110
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
101111
system_instruction: Optional[str] = None,
102112
) -> LLMResponse:
113+
if message_history:
114+
warnings.warn(
115+
"Using 'message_history' in the llm.ainvoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
116+
DeprecationWarning,
117+
)
118+
if system_instruction:
119+
warnings.warn(
120+
"Using 'system_instruction' in the llm.ainvoke method is deprecated. Please use invoke(list[LLMMessage]) instead.",
121+
DeprecationWarning,
122+
)
103123
messages = legacy_inputs_to_messages(input, message_history, system_instruction)
104124
return await self._ainvoke(messages)
105125

106-
@abstractmethod
107126
async def _ainvoke(
108127
self,
109128
input: list[LLMMessage],
@@ -119,6 +138,7 @@ async def _ainvoke(
119138
Raises:
120139
LLMGenerationError: If anything goes wrong.
121140
"""
141+
raise NotImplementedError()
122142

123143
@rate_limit_handler
124144
def invoke_with_tools(

tests/unit/llm/test_base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,11 @@
2020

2121
@fixture(scope="module") # type: ignore[misc]
2222
def llm_interface() -> Generator[Type[LLMInterface], None, None]:
23-
real_abstract_methods = LLMInterface.__abstractmethods__
24-
LLMInterface.__abstractmethods__ = frozenset()
25-
2623
class CustomLLMInterface(LLMInterface):
2724
pass
2825

2926
yield CustomLLMInterface
3027

31-
LLMInterface.__abstractmethods__ = real_abstract_methods
32-
3328

3429
@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")
3530
def test_base_llm_interface_invoke_with_input_as_str(
@@ -52,7 +47,8 @@ def test_base_llm_interface_invoke_with_input_as_str(
5247
system_instruction = "You are a genius."
5348

5449
with patch.object(llm, "_invoke") as mock_invoke:
55-
llm.invoke(question, message_history, system_instruction)
50+
with pytest.warns(DeprecationWarning) as record:
51+
llm.invoke(question, message_history, system_instruction)
5652
mock_invoke.assert_called_once_with(
5753
[
5854
LLMMessage(
@@ -66,6 +62,15 @@ def test_base_llm_interface_invoke_with_input_as_str(
6662
message_history,
6763
system_instruction,
6864
)
65+
assert len(record) == 2
66+
assert (
67+
"Using 'message_history' in the llm.invoke method is deprecated"
68+
in record[0].message.args[0] # type: ignore[union-attr]
69+
)
70+
assert (
71+
"Using 'system_instruction' in the llm.invoke method is deprecated"
72+
in record[1].message.args[0] # type: ignore[union-attr]
73+
)
6974

7075

7176
@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages")

0 commit comments

Comments
 (0)