Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

embeddings: use NewEmbedder for OpenAI embeddings #385

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions embeddings/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@ import (
"strings"
)

// NewEmbedder creates a new Embedder from the given EmbedderClient, with
// some options that affect how embedding will be done.
func NewEmbedder(client EmbedderClient, opts ...Option) (*EmbedderImpl, error) {
e := &EmbedderImpl{
client: client,
StripNewLines: defaultStripNewLines,
BatchSize: defaultBatchSize,
}

for _, opt := range opts {
opt(e)
}
return e, nil
}

// Embedder is the interface for creating vector embeddings from texts.
type Embedder interface {
// EmbedDocuments returns a vector for each text.
Expand All @@ -25,19 +40,6 @@ type EmbedderImpl struct {
BatchSize int
}

func NewEmbedder(client EmbedderClient, opts ...Option) (*EmbedderImpl, error) {
e := &EmbedderImpl{
client: client,
StripNewLines: defaultStripNewLines,
BatchSize: defaultBatchSize,
}

for _, opt := range opts {
opt(e)
}
return e, nil
}

// EmbedQuery embeds a single text.
func (ei *EmbedderImpl) EmbedQuery(ctx context.Context, text string) ([]float32, error) {
if ei.StripNewLines {
Expand Down
49 changes: 0 additions & 49 deletions embeddings/openai/openai.go

This file was deleted.

55 changes: 0 additions & 55 deletions embeddings/openai/options.go

This file was deleted.

58 changes: 24 additions & 34 deletions embeddings/openai/openai_test.go → embeddings/openai_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package openai
package embeddings

import (
"context"
Expand All @@ -7,20 +7,30 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/openai"
)

func TestOpenaiEmbeddings(t *testing.T) {
t.Parallel()

func newOpenAIEmbedder(t *testing.T, opts ...Option) *EmbedderImpl {
t.Helper()
if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
t.Skip("OPENAI_API_KEY not set")
return nil
}
e, err := NewOpenAI()

llm, err := openai.New()
require.NoError(t, err)

_, err = e.EmbedQuery(context.Background(), "Hello world!")
embedder, err := NewEmbedder(llm, opts...)
require.NoError(t, err)

return embedder
}

func TestOpenaiEmbeddings(t *testing.T) {
t.Parallel()

e := newOpenAIEmbedder(t)
_, err := e.EmbedQuery(context.Background(), "Hello world!")
require.NoError(t, err)

embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world", "The world is ending", "good bye"})
Expand All @@ -33,14 +43,8 @@ func TestOpenaiEmbeddingsQueryVsDocuments(t *testing.T) {
// of which method we call.
t.Parallel()

if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
t.Skip("OPENAI_API_KEY not set")
}
e, err := NewOpenAI()
require.NoError(t, err)

e := newOpenAIEmbedder(t)
text := "hi there"

eq, err := e.EmbedQuery(context.Background(), text)
require.NoError(t, err)

Expand All @@ -55,17 +59,9 @@ func TestOpenaiEmbeddingsQueryVsDocuments(t *testing.T) {
func TestOpenaiEmbeddingsWithOptions(t *testing.T) {
t.Parallel()

if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
t.Skip("OPENAI_API_KEY not set")
}

client, err := openai.New()
require.NoError(t, err)

e, err := NewOpenAI(WithClient(*client), WithBatchSize(1), WithStripNewLines(false))
require.NoError(t, err)
e := newOpenAIEmbedder(t, WithBatchSize(1), WithStripNewLines(false))

_, err = e.EmbedQuery(context.Background(), "Hello world!")
_, err := e.EmbedQuery(context.Background(), "Hello world!")
require.NoError(t, err)

embeddings, err := e.EmbedDocuments(context.Background(), []string{"Hello world"})
Expand Down Expand Up @@ -94,7 +90,7 @@ func TestOpenaiEmbeddingsWithAzureAPI(t *testing.T) {
)
require.NoError(t, err)

e, err := NewOpenAI(WithClient(*client), WithBatchSize(1), WithStripNewLines(false))
e, err := NewEmbedder(client, WithBatchSize(1), WithStripNewLines(false))
require.NoError(t, err)

_, err = e.EmbedQuery(context.Background(), "Hello world!")
Expand All @@ -112,17 +108,11 @@ func TestUseLLMAndChatAsEmbedderClient(t *testing.T) {
t.Skip("OPENAI_API_KEY not set")
}

llm, err := openai.New()
require.NoError(t, err)

embedderFromLLM, err := embeddings.NewEmbedder(llm)
require.NoError(t, err)
var _ embeddings.Embedder = embedderFromLLM

// Shows that we can pass an openai chat value to NewEmbedder
chat, err := openai.NewChat()
require.NoError(t, err)

embedderFromChat, err := embeddings.NewEmbedder(chat)
embedderFromChat, err := NewEmbedder(chat)
require.NoError(t, err)
var _ embeddings.Embedder = embedderFromChat
var _ Embedder = embedderFromChat
}
2 changes: 1 addition & 1 deletion examples/anthropic-completion-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langchaingo/examples/anthropic-completion-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
require github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/anthropic-completion-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81 h1:WRrvtNwd7S1etCMnYjEaem5cizL5TP7q8MTXum/OSUA=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2 h1:3arY5l84Sp5SRx+9xY8vXiTpin932qv0BNAOSQbtlHY=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2 changes: 1 addition & 1 deletion examples/chroma-vectorstore-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.20
require (
github.com/amikos-tech/chroma-go v0.0.0-20230901221218-d0087270239e
github.com/google/uuid v1.3.1
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2
)

require (
Expand Down
4 changes: 2 additions & 2 deletions examples/chroma-vectorstore-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81 h1:WRrvtNwd7S1etCMnYjEaem5cizL5TP7q8MTXum/OSUA=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2 h1:3arY5l84Sp5SRx+9xY8vXiTpin932qv0BNAOSQbtlHY=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254 h1:Ss6D3hLXTM0KobyBYEAygXzFfGcjnmfEJOBgSbemCtg=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea h1:vLCWI/yYrdEHyN2JzIzPO3aaQJHQdp89IZBA/+azVC4=
Expand Down
2 changes: 1 addition & 1 deletion examples/cohere-llm-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langchaingo/examples/basic-llm-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
require github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2

require (
github.com/cohere-ai/tokenizer v1.1.2 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/cohere-llm-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81 h1:WRrvtNwd7S1etCMnYjEaem5cizL5TP7q8MTXum/OSUA=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2 h1:3arY5l84Sp5SRx+9xY8vXiTpin932qv0BNAOSQbtlHY=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2 changes: 1 addition & 1 deletion examples/document-qa-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langchaingo/examples/document-qa-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
require github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2

require (
github.com/Masterminds/goutils v1.1.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/document-qa-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81 h1:WRrvtNwd7S1etCMnYjEaem5cizL5TP7q8MTXum/OSUA=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2 h1:3arY5l84Sp5SRx+9xY8vXiTpin932qv0BNAOSQbtlHY=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 h1:K+bMSIx9A7mLES1rtG+qKduLIXq40DAzYHtb0XuCukA=
gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 h1:oYrL81N608MLZhma3ruL8qTM4xcpYECGut8KSxRY59g=
Expand Down
2 changes: 1 addition & 1 deletion examples/ernie-chat-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langchaingo/examples/ernie-chat-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
require github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/ernie-chat-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81 h1:WRrvtNwd7S1etCMnYjEaem5cizL5TP7q8MTXum/OSUA=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2 h1:3arY5l84Sp5SRx+9xY8vXiTpin932qv0BNAOSQbtlHY=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2 changes: 1 addition & 1 deletion examples/ernie-completion-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langchaingo/examples/ernie-completion-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
require github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/ernie-completion-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81 h1:WRrvtNwd7S1etCMnYjEaem5cizL5TP7q8MTXum/OSUA=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2 h1:3arY5l84Sp5SRx+9xY8vXiTpin932qv0BNAOSQbtlHY=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2 changes: 1 addition & 1 deletion examples/ernie-function-call-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langchaingo/examples/ernie-function-call-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
require github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
Expand Down
4 changes: 2 additions & 2 deletions examples/ernie-function-call-example/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHt
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81 h1:WRrvtNwd7S1etCMnYjEaem5cizL5TP7q8MTXum/OSUA=
github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2 h1:3arY5l84Sp5SRx+9xY8vXiTpin932qv0BNAOSQbtlHY=
github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2/go.mod h1:WgJkGMb5Ac/WpD6YLo3zRAiHtALrgGnH42Hcu5Rs4/A=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2 changes: 1 addition & 1 deletion examples/ernie-function-call-streaming-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/tmc/langchaingo/examples/ernie-function-call-streaming-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231130160443-fc423fab7b81
require github.com/tmc/langchaingo v0.0.0-20231130223434-98fa24d3e7d2

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
Expand Down
Loading
Loading