Commit
chore: split embedders in individual files (#315)
smoya authored Dec 18, 2024
1 parent 17e3d28 commit 77673ee
Showing 8 changed files with 452 additions and 396 deletions.
10 changes: 8 additions & 2 deletions docs/adding-embedding-integration.md
@@ -31,13 +31,17 @@ integration. Update the tests to account for the new function.
The vectorizer worker reads the database's vectorizer configuration at runtime
and turns it into a `pgai.vectorizer.Config`.

-To add a new integration, add a new embedding class with fields corresponding
-to the database's jsonb configuration to `pgai/vectorizer/embeddings.py`. See
+To add a new integration, add a new file containing the embedding class
+with fields corresponding to the database's jsonb configuration to the
+[embedders directory]. See
the existing implementations for examples of how to do this. Implement the
`Embedder` class's abstract methods. Use first-party Python libraries for the
integration, if available. If no first-party Python libraries are available,
use direct HTTP requests.

+Remember to add an import for your newly created class to the
+[embedders \_\_init\_\_.py].
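For illustration, a minimal new embedder file might look like the sketch
below. It is modeled on the `Ollama` embedder added by this commit; the
`Acme` class, its fields, and its batch limit are hypothetical, and only the
`Embedder`/`BatchApiCaller` plumbing reflects the real interfaces:

```python
# embedders/acme.py -- hypothetical sketch, not part of this commit
from collections.abc import Sequence
from functools import cached_property
from typing import Literal

from pydantic import BaseModel
from typing_extensions import override

from ..embeddings import (
    BatchApiCaller,
    Embedder,
    EmbeddingResponse,
    EmbeddingVector,
    StringDocument,
)


class Acme(BaseModel, Embedder):
    # Fields mirror the vectorizer's jsonb configuration for this embedder.
    implementation: Literal["acme"]
    model: str

    @override
    async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]:
        return await self._batcher.batch_chunks_and_embed(documents)

    @cached_property
    def _batcher(self) -> BatchApiCaller[StringDocument]:
        return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api)

    @override
    def _max_chunks_per_batch(self) -> int:
        return 128  # hypothetical; use the provider's documented batch limit

    async def call_embed_api(self, documents: str | list[str]) -> EmbeddingResponse:
        # Call the provider's embeddings API here and map its response into
        # EmbeddingResponse(embeddings=..., usage=...).
        raise NotImplementedError
```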

Add tests which perform end-to-end testing of the new integration. There are
two options for handling API calls to the integration API:

@@ -49,6 +53,8 @@
used conservatively. We will determine on a case-by-case basis what level of
testing we would like.

[vcr.py]:https://vcrpy.readthedocs.io/en/latest/
+[embedders directory]:/projects/pgai/pgai/vectorizer/embedders
+[embedders \_\_init\_\_.py]:/projects/pgai/pgai/vectorizer/embedders/__init__.py
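For the recorded-fixture option, a test sketch using [vcr.py] might look
like the following; the cassette directory, cassette name, and test body are
illustrative assumptions rather than the repository's actual test layout:

```python
# Hypothetical end-to-end test recording HTTP traffic with vcr.py.
import vcr

my_vcr = vcr.VCR(
    cassette_library_dir="tests/vectorizer/cassettes",
    filter_headers=["authorization"],  # keep API keys out of recordings
)


def test_acme_embedder_end_to_end():
    # The first run records the real HTTP exchange; later runs replay it
    # offline from the cassette.
    with my_vcr.use_cassette("acme_embed.yaml"):
        ...  # run the vectorizer against Acme and assert on the embeddings
```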

## Documentation

4 changes: 4 additions & 0 deletions projects/pgai/justfile
@@ -39,6 +39,10 @@ test:
lint:
    @uv run ruff check ./

+# Run ruff linter checks and fix all auto-fixable issues
+lint-fix:
+    @uv run ruff check ./ --fix
+
# Run pyright type checking
type-check:
    @uv run pyright ./
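With the new recipe, `just lint-fix` applies every fix ruff can make
automatically, while the existing `just lint` remains a read-only check.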
3 changes: 3 additions & 0 deletions projects/pgai/pgai/vectorizer/embedders/__init__.py
@@ -0,0 +1,3 @@
from .ollama import Ollama as Ollama
from .openai import OpenAI as OpenAI
from .voyageai import VoyageAI as VoyageAI
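The seemingly redundant `Ollama as Ollama` form is the explicit re-export
idiom from PEP 484: it signals to type checkers and linters such as pyright
and ruff that these names are intentionally part of the package's public
interface. A new integration adds one matching line, for example the
hypothetical:

```python
from .acme import Acme as Acme  # hypothetical new embedder
```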
158 changes: 158 additions & 0 deletions projects/pgai/pgai/vectorizer/embedders/ollama.py
@@ -0,0 +1,158 @@
import os
from collections.abc import Mapping, Sequence
from functools import cached_property
from typing import (
    Any,
    Literal,
)

import ollama
from pydantic import BaseModel
from typing_extensions import TypedDict, override

from ..embeddings import (
    BatchApiCaller,
    Embedder,
    EmbeddingResponse,
    EmbeddingVector,
    StringDocument,
    Usage,
    logger,
)


# Note: this is a re-declaration of ollama.Options, which we are forced to do
# otherwise pydantic complains (ollama.Options subclasses typing.TypedDict):
# pydantic.errors.PydanticUserError: Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12. # noqa
class OllamaOptions(TypedDict, total=False):
    # load time options
    numa: bool
    num_ctx: int
    num_batch: int
    num_gpu: int
    main_gpu: int
    low_vram: bool
    f16_kv: bool
    logits_all: bool
    vocab_only: bool
    use_mmap: bool
    use_mlock: bool
    embedding_only: bool
    num_thread: int

    # runtime options
    num_keep: int
    seed: int
    num_predict: int
    top_k: int
    top_p: float
    tfs_z: float
    typical_p: float
    repeat_last_n: int
    temperature: float
    repeat_penalty: float
    presence_penalty: float
    frequency_penalty: float
    mirostat: int
    mirostat_tau: float
    mirostat_eta: float
    penalize_newline: bool
    stop: Sequence[str]


class Ollama(BaseModel, Embedder):
"""
Embedder that uses Ollama to embed documents into vector representations.
Attributes:
implementation (Literal["ollama"]): The literal identifier for this
implementation.
model (str): The name of the Ollama model used for embeddings.
base_url (str): The base url used to access the Ollama API.
options (dict): Additional ollama-specific runtime options
keep_alive (str): How long to keep the model loaded after the request
"""

implementation: Literal["ollama"]
model: str
base_url: str | None = None
options: OllamaOptions | None = None
keep_alive: str | None = None # this is only `str` because of the SQL API

    @override
    async def embed(self, documents: list[str]) -> Sequence[EmbeddingVector]:
        """
        Embeds a list of documents into vectors using Ollama's embeddings API.

        Args:
            documents (list[str]): A list of documents to be embedded.

        Returns:
            Sequence[EmbeddingVector | ChunkEmbeddingError]: The embeddings or
            errors for each document.
        """
        await logger.adebug(f"Chunks produced: {len(documents)}")
        return await self._batcher.batch_chunks_and_embed(documents)

    @cached_property
    def _batcher(self) -> BatchApiCaller[StringDocument]:
        return BatchApiCaller(self._max_chunks_per_batch(), self.call_embed_api)

    @override
    def _max_chunks_per_batch(self) -> int:
        # Note: the chosen default is arbitrary - Ollama doesn't place a limit
        return int(
            os.getenv("PGAI_VECTORIZER_OLLAMA_MAX_CHUNKS_PER_BATCH", default="2048")
        )

    @override
    async def setup(self):
        client = ollama.AsyncClient(host=self.base_url)
        try:
            await client.show(self.model)
        except ollama.ResponseError as e:
            if f"model '{self.model}' not found" in e.error:
                logger.warn(
                    f"pulling ollama model '{self.model}', this may take a while"
                )
                await client.pull(self.model)

    async def call_embed_api(self, documents: str | list[str]) -> EmbeddingResponse:
        response = await ollama.AsyncClient(host=self.base_url).embed(
            model=self.model,
            input=documents,
            options=self.options,
            keep_alive=self.keep_alive,
        )
        usage = Usage(
            prompt_tokens=response["prompt_eval_count"],
            total_tokens=response["prompt_eval_count"],
        )
        return EmbeddingResponse(embeddings=response["embeddings"], usage=usage)

    async def _model(self) -> Mapping[str, Any]:
        """
        Gets the model details from the Ollama API.
        """
        return await ollama.AsyncClient(host=self.base_url).show(self.model)

    async def _context_length(self) -> int | None:
        """
        Gets the context_length of the configured model, if available.
        """
        model = await self._model()
        architecture = model["model_info"].get("general.architecture", None)
        if architecture is None:
            logger.warn(f"unable to determine architecture for model '{self.model}'")
            return None
        context_key = f"{architecture}.context_length"
        # see https://github.com/ollama/ollama/blob/712d63c3f06f297e22b1ae32678349187dccd2e4/llm/ggml.go#L116-L118 # noqa
        model_context_length = model["model_info"][context_key]
        # the context window can be configured, so pull the value from the config
        num_ctx = (
            float("inf")
            if self.options is None
            else self.options.get("num_ctx", float("inf"))
        )
        return min(model_context_length, num_ctx)
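As a rough usage sketch of the new module layout (the model name and local
server URL are assumptions, and this snippet is not part of the commit):

```python
import asyncio

from pgai.vectorizer.embedders import Ollama


async def main() -> None:
    # Field values mirror the jsonb configuration the vectorizer stores.
    embedder = Ollama(
        implementation="ollama",
        model="nomic-embed-text",           # assumes this model is available
        base_url="http://localhost:11434",  # default local Ollama endpoint
    )
    await embedder.setup()  # pulls the model if it is not already present
    vectors = await embedder.embed(["pgai adds vector search to Postgres"])
    print(f"embedded 1 chunk into {len(vectors[0])} dimensions")


asyncio.run(main())
```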