
Commit 5ff611b: Mypy / ruff

1 parent 3836143

6 files changed: 21 additions & 36 deletions

examples/customize/llms/custom_llm.py

Lines changed: 6 additions & 19 deletions
@@ -1,14 +1,13 @@
 import random
 import string
-from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union
+from typing import Any, Awaitable, Callable, Optional, TypeVar

 from neo4j_graphrag.llm import LLMInterface, LLMResponse
 from neo4j_graphrag.llm.rate_limit import (
     RateLimitHandler,
     # rate_limit_handler,
     # async_rate_limit_handler,
 )
-from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.types import LLMMessage


@@ -18,38 +17,26 @@ def __init__(
     ):
         super().__init__(model_name, **kwargs)

-    # Optional: Apply rate limit handling to synchronous invoke method
-    # @rate_limit_handler
-    def invoke(
+    def _invoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         content: str = (
             self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30))
         )
         return LLMResponse(content=content)

-    # Optional: Apply rate limit handling to asynchronous ainvoke method
-    # @async_rate_limit_handler
-    async def ainvoke(
+    async def _ainvoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         raise NotImplementedError()


-llm = CustomLLM(
-    ""
-)  # if rate_limit_handler and async_rate_limit_handler decorators are used, the default rate limit handler will be applied automatically (retry with exponential backoff)
+llm = CustomLLM("")
 res: LLMResponse = llm.invoke("text")
 print(res.content)

-# If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler
-# Type variables for function signatures used in rate limit handlers
 F = TypeVar("F", bound=Callable[..., Any])
 AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]])

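The F and AF type variables kept at the end of the example describe the callables a RateLimitHandler wraps. Below is a minimal sketch of a custom handler, assuming RateLimitHandler exposes handle_sync/handle_async hooks and that the LLM constructor accepts a rate_limit_handler argument; neither detail is shown in this diff, so treat the wiring as hypothetical.

import asyncio
import time
from typing import Any, Awaitable, Callable, TypeVar

from neo4j_graphrag.llm.rate_limit import RateLimitHandler

F = TypeVar("F", bound=Callable[..., Any])
AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]])


class FixedDelayRateLimitHandler(RateLimitHandler):
    """Hypothetical handler that waits a fixed delay before every call."""

    def __init__(self, delay: float = 1.0) -> None:
        self.delay = delay

    def handle_sync(self, func: F) -> F:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            time.sleep(self.delay)
            return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    def handle_async(self, func: AF) -> AF:
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            await asyncio.sleep(self.delay)
            return await func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]


# Assumed wiring; the exact keyword is not part of this commit:
# llm = CustomLLM("my-model", rate_limit_handler=FixedDelayRateLimitHandler(2.0))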
examples/customize/llms/ollama_llm.py

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 ]


-
 llm = OllamaLLM(
     model_name="orca-mini:latest",
     # model_params={"options": {"temperature": 0}, "format": "json"},

src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 5 additions & 5 deletions
@@ -13,8 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Iterable, Optional
-
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union

 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
@@ -28,6 +27,7 @@

 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
+    from anthropic import NotGiven


 class AnthropicLLM(LLMInterface):
@@ -78,16 +78,16 @@ def get_messages(
     def get_messages(
         self,
         input: list[LLMMessage],
-    ) -> tuple[str, Iterable[MessageParam]]:
+    ) -> tuple[Union[str, NotGiven], Iterable[MessageParam]]:
         messages: list[MessageParam] = []
-        system_instruction = self.anthropic.NOT_GIVEN
+        system_instruction: Union[str, NotGiven] = self.anthropic.NOT_GIVEN
         for i in input:
             if i["role"] == "system":
                 system_instruction = i["content"]
             else:
                 messages.append(
                     self.anthropic.types.MessageParam(
-                        role=i["role"],  # type: ignore
+                        role=i["role"],
                         content=i["content"],
                     )
                 )

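get_messages now returns either the extracted system prompt or anthropic's NOT_GIVEN sentinel, typed as Union[str, NotGiven], so the caller can forward it directly. A minimal sketch of how that pair is typically consumed with the anthropic SDK; the _invoke body is not part of this diff, the model name and input are made up, and it assumes the anthropic package is installed with ANTHROPIC_API_KEY set.

import anthropic

from neo4j_graphrag.llm import AnthropicLLM

llm = AnthropicLLM(model_name="claude-3-5-haiku-latest")
input_messages = [
    {"role": "system", "content": "Answer in one sentence."},
    {"role": "user", "content": "What is a knowledge graph?"},
]

# system_instruction is a str if a system message was present, otherwise anthropic.NOT_GIVEN
system_instruction, messages = llm.get_messages(input_messages)

client = anthropic.Anthropic()
response = client.messages.create(
    model="claude-3-5-haiku-latest",
    max_tokens=512,
    system=system_instruction,  # NOT_GIVEN simply omits the field from the request
    messages=list(messages),
)
print(response.content[0].text)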
src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
 except ImportError:
     Mistral = None  # type: ignore
     SDKError = None  # type: ignore
-    Messages = Any
+    Messages = None  # type: ignore


 class MistralAILLM(LLMInterface):

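Messages now falls back to None like the other names in the except ImportError branch, so all three sentinels are handled the same way, with the # type: ignore silencing mypy's complaint about re-binding an imported name. A generic sketch of this optional-dependency pattern follows, using made-up names rather than the actual mistralai imports.

from typing import Any

try:
    # Optional dependency: only available when the corresponding extra is installed.
    from somesdk import Client, Messages, SDKError  # hypothetical package
except ImportError:
    Client = None  # type: ignore
    SDKError = None  # type: ignore
    Messages = None  # type: ignore


class SomeLLM:
    def __init__(self, model_name: str, **kwargs: Any) -> None:
        if Client is None:
            raise ImportError(
                "Could not import somesdk. Install it with `pip install somesdk`."
            )
        self.model_name = model_name
        self.client = Client(**kwargs)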
src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 1 addition & 4 deletions
@@ -71,10 +71,7 @@ def get_messages(
         self,
         input: list[LLMMessage],
     ) -> Sequence[Message]:
-        return [
-            self.ollama.Message(**i)
-            for i in input
-        ]
+        return [self.ollama.Message(**i) for i in input]

     def _invoke(
         self,

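The collapsed comprehension still converts each LLMMessage dict into an ollama Message object, one per input entry. A tiny illustration of that mapping, assuming the ollama Python package is installed; the input list is made up.

import ollama

input_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Each dict unpacks directly into the Message model's role/content fields.
messages = [ollama.Message(**m) for m in input_messages]
print(messages[0].role)     # "system"
print(messages[1].content)  # "Hello!"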
src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 8 additions & 6 deletions
@@ -25,7 +25,7 @@
     Iterable,
     Sequence,
     Union,
-    cast,
+    cast, Type,
 )

 from neo4j_graphrag.message_history import MessageHistory
@@ -44,9 +44,11 @@
 if TYPE_CHECKING:
     from openai.types.chat import (
         ChatCompletionMessageParam,
-        ChatCompletionToolParam, ChatCompletionUserMessageParam,
-        ChatCompletionSystemMessageParam, ChatCompletionAssistantMessageParam,
-    )
+        ChatCompletionToolParam,
+        ChatCompletionUserMessageParam,
+        ChatCompletionSystemMessageParam,
+        ChatCompletionAssistantMessageParam,
+    )
     from openai import OpenAI, AsyncOpenAI
     from .rate_limit import RateLimitHandler
 else:
@@ -93,7 +95,7 @@ def get_messages(
     ) -> Iterable[ChatCompletionMessageParam]:
         chat_messages = []
         for m in messages:
-            message_type: ChatCompletionMessageParam
+            message_type: Type[ChatCompletionMessageParam]
             if m["role"] == "system":
                 message_type = ChatCompletionSystemMessageParam
             elif m["role"] == "user":
@@ -104,7 +106,7 @@ def get_messages(
                 raise ValueError(f"Unknown message type: {m['role']}")
             chat_messages.append(
                 message_type(
-                    role=m["role"],
+                    role=m["role"],  # type: ignore
                     content=m["content"],
                 )
             )

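In the openai change, message_type now annotates the class to instantiate rather than an instance, hence Type[ChatCompletionMessageParam]; because the per-role classes are TypedDicts whose role field is a Literal, building one from a plain str needs the # type: ignore. A standalone sketch of the same dispatch pattern, assuming the openai package is installed; the message list is made up.

from typing import Type

from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
)

raw_messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi there"},
]

chat_messages: list[ChatCompletionMessageParam] = []
for m in raw_messages:
    message_type: Type[ChatCompletionMessageParam]
    if m["role"] == "system":
        message_type = ChatCompletionSystemMessageParam
    elif m["role"] == "user":
        message_type = ChatCompletionUserMessageParam
    elif m["role"] == "assistant":
        message_type = ChatCompletionAssistantMessageParam
    else:
        raise ValueError(f"Unknown message type: {m['role']}")
    # TypedDict classes are plain dict constructors at runtime.
    chat_messages.append(message_type(role=m["role"], content=m["content"]))  # type: ignore

print(chat_messages)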