Skip to content

Commit c766f30

Browse files
authored
Fix embedders (#435)
* Fix embedders * Remove test for rate limit handler for SentenceTransformerEmbeddings
1 parent 4f563e2 commit c766f30

File tree

12 files changed

+23
-79
lines changed

12 files changed

+23
-79
lines changed

examples/customize/embeddings/custom_embeddings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
class CustomEmbeddings(Embedder):
88
def __init__(self, dimension: int = 10, **kwargs: Any):
9+
super().__init__(**kwargs)
910
self.dimension = dimension
1011

11-
def _embed_query(self, input: str) -> list[float]:
12+
def embed_query(self, input: str) -> list[float]:
1213
return [random.random() for _ in range(self.dimension)]
1314

1415

examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
# Create Embedder object
2222
class CustomEmbedder(Embedder):
23-
def _embed_query(self, text: str) -> list[float]:
23+
def embed_query(self, text: str) -> list[float]:
2424
return [random() for _ in range(DIMENSION)]
2525

2626

examples/customize/retrievers/hybrid_retrievers/hybrid_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
# Create Embedder object
2222
class CustomEmbedder(Embedder):
23-
def _embed_query(self, text: str) -> list[float]:
23+
def embed_query(self, text: str) -> list[float]:
2424
return [random() for _ in range(DIMENSION)]
2525

2626

src/neo4j_graphrag/embeddings/base.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from neo4j_graphrag.utils.rate_limit import (
2121
DEFAULT_RATE_LIMIT_HANDLER,
2222
RateLimitHandler,
23-
rate_limit_handler,
2423
)
2524

2625

@@ -39,20 +38,8 @@ def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None):
3938
else:
4039
self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER
4140

42-
@rate_limit_handler
43-
def embed_query(self, text: str) -> list[float]:
44-
"""Embed query text.
45-
46-
Args:
47-
text (str): Text to convert to vector embedding
48-
49-
Returns:
50-
list[float]: A vector embedding.
51-
"""
52-
return self._embed_query(text)
53-
5441
@abstractmethod
55-
def _embed_query(self, text: str) -> list[float]:
42+
def embed_query(self, text: str) -> list[float]:
5643
"""Embed query text.
5744
5845
Args:

src/neo4j_graphrag/embeddings/cohere.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from neo4j_graphrag.embeddings.base import Embedder
2020
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
21-
from neo4j_graphrag.utils.rate_limit import RateLimitHandler
21+
from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler
2222

2323
try:
2424
import cohere
@@ -42,7 +42,8 @@ def __init__(
4242
self.model = model
4343
self.client = cohere.Client(**kwargs)
4444

45-
def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
45+
@rate_limit_handler
46+
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
4647
try:
4748
response = self.client.embed(
4849
texts=[text],

src/neo4j_graphrag/embeddings/mistral.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from neo4j_graphrag.embeddings.base import Embedder
2222
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
23-
from neo4j_graphrag.utils.rate_limit import RateLimitHandler
23+
from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler
2424

2525
try:
2626
from mistralai import Mistral
@@ -55,7 +55,8 @@ def __init__(
5555
self.model = model
5656
self.mistral_client = Mistral(api_key=api_key, **kwargs)
5757

58-
def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
58+
@rate_limit_handler
59+
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
5960
"""
6061
Generate embeddings for a given query using a Mistral AI text embedding model.
6162

src/neo4j_graphrag/embeddings/ollama.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from neo4j_graphrag.embeddings.base import Embedder
2121
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
22-
from neo4j_graphrag.utils.rate_limit import RateLimitHandler
22+
from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler
2323

2424

2525
class OllamaEmbeddings(Embedder):
@@ -48,7 +48,8 @@ def __init__(
4848
self.model = model
4949
self.client = ollama.Client(**kwargs)
5050

51-
def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
51+
@rate_limit_handler
52+
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
5253
"""
5354
Generate embeddings for a given query using an Ollama text embedding model.
5455

src/neo4j_graphrag/embeddings/openai.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from neo4j_graphrag.embeddings.base import Embedder
2222
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
23-
from neo4j_graphrag.utils.rate_limit import RateLimitHandler
23+
from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler
2424

2525
if TYPE_CHECKING:
2626
import openai
@@ -59,7 +59,8 @@ def _initialize_client(self, **kwargs: Any) -> Any:
5959
"""
6060
pass
6161

62-
def _embed_query(self, text: str, **kwargs: Any) -> list[float]:
62+
@rate_limit_handler
63+
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
6364
"""
6465
Generate embeddings for a given query using an OpenAI text embedding model.
6566

src/neo4j_graphrag/embeddings/sentence_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
self.np = np
4343
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)
4444

45-
def _embed_query(self, text: str) -> Any:
45+
def embed_query(self, text: str) -> Any:
4646
try:
4747
result = self.model.encode([text])
4848

src/neo4j_graphrag/embeddings/vertexai.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from neo4j_graphrag.embeddings.base import Embedder
2020
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
21-
from neo4j_graphrag.utils.rate_limit import RateLimitHandler
21+
from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler
2222

2323
try:
2424
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
@@ -52,7 +52,8 @@ def __init__(
5252
super().__init__(rate_limit_handler)
5353
self.model = TextEmbeddingModel.from_pretrained(model)
5454

55-
def _embed_query(
55+
@rate_limit_handler
56+
def embed_query(
5657
self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any
5758
) -> list[float]:
5859
"""

0 commit comments

Comments
 (0)