Update text embedding component #532

Merged on Oct 18, 2023 (5 commits)
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
FROM --platform=linux/amd64 python:3.8-slim as base
FROM --platform=linux/amd64 pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime as base

# System dependencies
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y

# Install requirements
COPY requirements.txt /
COPY requirements.txt ./
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}

# Set the working directory to the component folder
WORKDIR /component/src
WORKDIR /component
COPY src/ src/
ENV PYTHONPATH "${PYTHONPATH}:./src"

# Copy over src-files
COPY src/ .
FROM base as test
COPY test_requirements.txt .
RUN pip3 install --no-cache-dir -r test_requirements.txt
COPY tests/ tests/
RUN python -m pytest tests

FROM base
WORKDIR /component/src
ENTRYPOINT ["fondant", "execute", "main"]
@@ -1,4 +1,4 @@
# Generate embeddings
# Embed text

### Description
Component that generates embeddings of text passages.
@@ -22,9 +22,10 @@ The component takes the following arguments to alter its behavior:

| argument | type | description | default |
| -------- | ---- | ----------- | ------- |
| model_provider | str | The provider of the model - corresponding to langchain embedding classes. Currently the following providers are supported: aleph_alpha, cohere, huggingface, openai. | huggingface |
| model | str | The model to generate embeddings from. Choose an available model name to pass to the model provider's langchain embedding class. | all-MiniLM-L6-v2 |
| model_provider | str | The provider of the model - corresponding to langchain embedding classes. Currently the following providers are supported: aleph_alpha, cohere, huggingface, openai, vertexai. | huggingface |
| model | str | The model to generate embeddings from. Choose an available model name to pass to the model provider's langchain embedding class. | / |
| api_keys | dict | The API keys to use for the model provider that are written to environment variables. Pass only the keys required by the model provider or conveniently pass all keys you will ever need. Pay attention to how you name the dictionary keys so that they can be used by the model provider. | / |
| auth_kwargs | dict | Additional keyword arguments required for api initialization/authentication. | / |
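
As a rough illustration of the `api_keys` argument described in the table: the dictionary keys become environment variable names, so they need to match whatever the provider's client reads. The sketch below assumes the OpenAI provider, whose client picks up `OPENAI_API_KEY` from the environment. The model name and key value are placeholders chosen for illustration, and `to_env_vars` mirrors the helper shown in this PR's `main.py`.

```python
import os


def to_env_vars(api_keys: dict) -> None:
    """Write each API key to an environment variable of the same name."""
    for key, value in api_keys.items():
        os.environ[key] = value


# Hypothetical arguments for the component; the key value is a placeholder.
arguments = {
    "model_provider": "openai",
    "model": "text-embedding-ada-002",
    "api_keys": {"OPENAI_API_KEY": "sk-placeholder"},
    "auth_kwargs": {},
}

to_env_vars(arguments["api_keys"])
print(os.environ["OPENAI_API_KEY"])
```

A misspelled dictionary key would still be written to the environment, but the provider client would not find it, which is why the naming matters.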

### Usage

@@ -34,15 +35,21 @@ You can add this component to your pipeline using the following code:
from fondant.pipeline import ComponentOp


generate_embeddings_op = ComponentOp.from_registry(
name="generate_embeddings",
embed_text_op = ComponentOp.from_registry(
name="embed_text",
arguments={
# Add arguments
# "model_provider": "huggingface",
# "model": "all-MiniLM-L6-v2",
# "model": ,
# "api_keys": {},
# "auth_kwargs": {},
}
)
pipeline.add_op(generate_embeddings_op, dependencies=[...]) #Add previous component as dependency
pipeline.add_op(embed_text_op, dependencies=[...]) #Add previous component as dependency
```

### Testing

You can run the tests using docker with BuildKit. From this directory, run:
```
docker build . --target test
```
@@ -1,6 +1,6 @@
name: Generate embeddings
name: Embed text
description: Component that generates embeddings of text passages.
image: generate_embeddings:latest
image: embed_text:latest

consumes:
text:
@@ -22,20 +22,28 @@ args:
model_provider:
description: |
The provider of the model - corresponding to langchain embedding classes.
Currently the following providers are supported: aleph_alpha, cohere, huggingface, openai.
Currently the following providers are supported: aleph_alpha, cohere, huggingface, openai,
vertexai.
type: str
default: huggingface
model:
description: |
The model to generate embeddings from.
Choose an available model name to pass to the model provider's langchain embedding class.
type: str
default: all-MiniLM-L6-v2
default: None
api_keys:
description: |
The API keys to use for the model provider that are written to environment variables.
Pass only the keys required by the model provider or conveniently pass all keys you will ever need.
Pay attention to how you name the dictionary keys so that they can be used by the model provider.
type: dict
default: {}
auth_kwargs:
description: |
Additional keyword arguments required for api initialization/authentication.
type: dict
default: {}



@@ -1,5 +1,6 @@
aleph_alpha_client==3.5.1
cohere==4.27
google-cloud-aiplatform==1.34.0
langchain==0.0.313
openai==0.28.1
pandas==1.5.0
@@ -1,33 +1,53 @@
import logging
import os

import google.cloud.aiplatform as aip
import pandas as pd
from fondant.component import PandasTransformComponent
from langchain.embeddings import (
AlephAlphaAsymmetricSemanticEmbedding,
CohereEmbeddings,
HuggingFaceEmbeddings,
OpenAIEmbeddings,
VertexAIEmbeddings,
)
from langchain.schema.embeddings import Embeddings
from retry import retry
from utils import to_env_vars

logger = logging.getLogger(__name__)


class GenerateEmbeddingsComponent(PandasTransformComponent):
def to_env_vars(api_keys: dict):
for key, value in api_keys.items():
os.environ[key] = value


class EmbedTextComponent(PandasTransformComponent):
def __init__(
self,
*_,
model_provider: str,
model: str,
api_keys: dict,
auth_kwargs: dict,
):
self.model_provider = model_provider
self.model = model
self.embedding_model = self.get_embedding_model(
model_provider,
model,
auth_kwargs,
)

to_env_vars(api_keys)

def get_embedding_model(self, model_provider, model: str):
@staticmethod
def get_embedding_model(
model_provider,
model: str,
auth_kwargs: dict,
) -> Embeddings:
if model_provider == "vertexai":
aip.init(**auth_kwargs)
return VertexAIEmbeddings(model=model)
# contains a first selection of embedding models
if model_provider == "aleph_alpha":
return AlephAlphaAsymmetricSemanticEmbedding(model=model)
@@ -41,13 +61,11 @@ def get_embedding_model(self, model_provider, model: str):
raise ValueError(msg)

@retry() # make sure to keep trying even when api call limit is reached
def get_embeddings_vectors(self, embedding_model, texts):
return embedding_model.embed_documents(texts.tolist())
def get_embeddings_vectors(self, texts):
return self.embedding_model.embed_documents(texts.tolist())

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
embedding_model = self.get_embedding_model(self.model_provider, self.model)
dataframe[("text", "embedding")] = self.get_embeddings_vectors(
embedding_model,
dataframe[("text", "data")],
)
return dataframe
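
The bare `@retry()` above relies on the `retry` package's defaults, which, as far as I can tell, retry on any exception indefinitely and without delay. If a bounded retry with backoff is preferred for rate-limited APIs, the idea can be sketched with the standard library alone. `retry_with_backoff`, its parameters, and `flaky_embed` are hypothetical names for illustration and are not part of this PR.

```python
import functools
import time


def retry_with_backoff(tries: int = 5, delay: float = 0.01, backoff: float = 2.0):
    """Retry the wrapped callable up to `tries` times, growing the wait by `backoff`."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, tries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    if attempt == tries:
                        raise  # out of attempts: surface the last error
                    time.sleep(wait)
                    wait *= backoff

        return wrapper

    return decorator


calls = {"count": 0}


@retry_with_backoff(tries=4, delay=0.001)
def flaky_embed(texts):
    """Stand-in for an API call that hits a rate limit on the first two attempts."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("rate limit reached")
    return [[0.0] * 3 for _ in texts]


vectors = flaky_embed(["a", "b"])
```

With the `retry` package itself, similar bounds can be expressed through its `tries`, `delay`, and `backoff` arguments.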
1 change: 1 addition & 0 deletions components/embed_text/test_requirements.txt
@@ -0,0 +1 @@
pytest==7.4.2
@@ -4,7 +4,7 @@

import pandas as pd

from components.generate_embeddings.src.main import GenerateEmbeddingsComponent
from src.main import EmbedTextComponent


def embeddings_close(a, b):
@@ -13,9 +13,9 @@ def embeddings_close(a, b):

def test_run_component_test():
"""Test generate embeddings component."""
with open("lorem_300.txt", encoding="utf-8") as f:
with open("tests/lorem_300.txt", encoding="utf-8") as f:
lorem_300 = f.read()
with open("lorem_400.txt", encoding="utf-8") as f:
with open("tests/lorem_400.txt", encoding="utf-8") as f:
lorem_400 = f.read()

# Given: Dataframe with text
@@ -29,15 +29,16 @@ def test_run_component_test():

dataframe = pd.concat({"text": pd.DataFrame(data)}, axis=1, names=["text", "data"])

component = GenerateEmbeddingsComponent(
component = EmbedTextComponent(
Probably not ideal that we're loading an actual model (which makes the test dependent on an internet connection), although I don't mind as much in the component tests since we currently run them manually

model_provider="huggingface",
model="all-MiniLM-L6-v2",
api_keys={},
auth_kwargs={},
)

dataframe = component.transform(dataframe=dataframe)

with open("hello_world_embedding.txt", encoding="utf-8") as f:
with open("tests/hello_world_embedding.txt", encoding="utf-8") as f:
hello_world_embedding = f.read()
hello_world_embedding = json.loads(hello_world_embedding)
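
The review comment above notes that the test loads an actual model and therefore needs an internet connection. One possible way around that, sketched here under assumptions, is to patch the static `get_embedding_model` so it returns a stub. `FakeEmbeddings` and `EmbedTextSketch` are hypothetical stand-ins that only mirror the shape of the component in this PR; they are not the real classes, and the stub vectors are arbitrary.

```python
from unittest import mock


class FakeEmbeddings:
    """Stub exposing the one method the component calls on the model."""

    def embed_documents(self, texts):
        # Deterministic toy vectors: the text length plus a constant.
        return [[float(len(text)), 0.0] for text in texts]


class EmbedTextSketch:
    """Hypothetical stand-in mirroring how the component builds its model."""

    def __init__(self, *, model_provider, model, auth_kwargs):
        self.embedding_model = self.get_embedding_model(
            model_provider, model, auth_kwargs
        )

    @staticmethod
    def get_embedding_model(model_provider, model, auth_kwargs):
        raise RuntimeError("would load a real model over the network")

    def get_embeddings_vectors(self, texts):
        return self.embedding_model.embed_documents(list(texts))


# Swap the factory for the duration of construction so no model is loaded.
with mock.patch.object(
    EmbedTextSketch, "get_embedding_model", return_value=FakeEmbeddings()
):
    component = EmbedTextSketch(
        model_provider="huggingface",
        model="all-MiniLM-L6-v2",
        auth_kwargs={},
    )

vectors = component.get_embeddings_vectors(["Hello World!!", "abc"])
```

A stub like this only checks the wiring around the model; comparing against stored reference embeddings, as the existing test does, still requires the real model.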

6 changes: 0 additions & 6 deletions components/generate_embeddings/src/utils.py

This file was deleted.
