Skip to content

Commit

Permalink
Add OpenAI embeddings
Browse files Browse the repository at this point in the history
Fixes #7
  • Loading branch information
KennethEnevoldsen committed Sep 24, 2023
1 parent ac8551f commit 0ab8418
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/seb/seb_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .fairseq_models import *
from .hf_models import *
from .openai_models import *
40 changes: 40 additions & 0 deletions src/seb/seb_models/openai_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
The OpenAI embedding APIs evaluated on the SEB benchmark.
"""

from functools import partial
from typing import List

import torch

from seb.model_interface import ModelInterface, ModelMeta, SebModel
from seb.registries import models


class OpenaiTextEmbeddingModel(ModelInterface):
    """Wraps an OpenAI embedding API model behind the SEB ModelInterface."""

    def __init__(self, api_name: str):
        # Name of the OpenAI embedding model, e.g. "text-embedding-ada-002".
        self.api_name = api_name

    def encode(self, sentences: List[str], batch_size: int = 32) -> torch.Tensor:
        """Embed `sentences` via the OpenAI API and return an (n, dim) tensor.

        Requests are issued in chunks of `batch_size` sentences. Previously the
        whole input was sent in a single request (batch_size was ignored), which
        fails once the input exceeds the API's per-request limit.
        """
        import openai  # deferred so the package is only required when used

        # The embeddings API performs poorly on newlines; flatten to spaces.
        cleaned = [s.replace("\n", " ") for s in sentences]

        vectors: List[List[float]] = []
        for start in range(0, len(cleaned), batch_size):
            batch = cleaned[start : start + batch_size]
            response = openai.Embedding.create(input=batch, model=self.api_name)
            vectors.extend(item.embedding for item in response["data"])
        return torch.tensor(vectors)


@models.register("text-embedding-ada-002")
def create_openai_ada_002() -> SebModel:
    """Create the SEB wrapper for OpenAI's text-embedding-ada-002 model."""
    api_name = "text-embedding-ada-002"
    meta = ModelMeta(
        name=api_name,
        huggingface_name=None,  # API-only model; no Hugging Face checkpoint
        # Plain string: the original used an f-string with no placeholders.
        reference="https://openai.com/blog/new-and-improved-embedding-model",
        languages=[],  # NOTE(review): presumably "unspecified/any" — confirm convention
    )
    return SebModel(
        # partial defers construction so the API client is only set up on load.
        loader=partial(OpenaiTextEmbeddingModel, api_name=api_name),
        meta=meta,
    )
13 changes: 13 additions & 0 deletions tests/test_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import seb

# Collect every registered model once at import time for test parametrization.
all_models = seb.get_all_models()
# NOTE(review): openai_models appears unused in this visible chunk — verify it
# is referenced elsewhere in the file or remove it.
openai_models = []


@pytest.mark.skip(
Expand All @@ -20,3 +21,15 @@ def test_model(model: seb.SebModel, task: seb.Task):
Test if the models encodes as expected
"""
task.evaluate(model)


@pytest.mark.skip(
    reason="This test calls the live OpenAI API and loads models. It is too heavy to have running as a CI"
)
@pytest.mark.parametrize("model", [seb.get_model("text-embedding-ada-002")])
@pytest.mark.parametrize("task", [seb.get_task("test encode task")])
def test_openai_model(model: seb.SebModel, task: seb.Task):
    """
    Test if the model encodes as expected.

    Skipped in CI (like test_model above): it requires OpenAI API access.
    """
    task.evaluate(model)

0 comments on commit 0ab8418

Please sign in to comment.