Skip to content

Commit

Permalink
embeddings: Add Amazon Bedrock embeddings (#643)
Browse files Browse the repository at this point in the history
* Add bedrock embedding provider

* Add bedrock tests

---------

Co-authored-by: Travis Cline <travis.cline@gmail.com>
  • Loading branch information
sansmoraxz and tmc authored Mar 6, 2024
1 parent 8cbd678 commit a51062f
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 0 deletions.
86 changes: 86 additions & 0 deletions embeddings/bedrock/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package bedrock

import (
"context"
"errors"
"strings"

"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/tmc/langchaingo/embeddings"
)

// Bedrock is the embedder used generate text embeddings through Amazon Bedrock.
type Bedrock struct {
ModelID string
client *bedrockruntime.Client
StripNewLines bool
BatchSize int
}

// NewBedrock returns a new embeddings.Embedder that uses Amazon Bedrock to generate embeddings.
func NewBedrock(opts ...Option) (*Bedrock, error) {
v, err := applyOptions(opts...)
if err != nil {
return nil, err
}

return v, nil
}

func getProvider(modelID string) string {
return strings.Split(modelID, ".")[0]
}

// EmbedDocuments implements embeddings.Embedder
// and generates embeddings for the supplied texts.
func (b *Bedrock) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
batchedTexts := embeddings.BatchTexts(
embeddings.MaybeRemoveNewLines(texts, b.StripNewLines),
b.BatchSize,
)
provider := getProvider(b.ModelID)

allEmbeds := make([][]float32, 0, len(texts))
var embeddings [][]float32
var err error

for _, batch := range batchedTexts {
switch provider {
case "amazon":
embeddings, err = FetchAmazonTextEmbeddings(ctx, b.client, b.ModelID, batch)
case "cohere":
embeddings, err = FetchCohereTextEmbeddings(ctx, b.client, b.ModelID, batch, CohereInputTypeText)
default:
err = errors.New("unsupported text embedding provider: " + provider)
}

if err != nil {
return nil, err
}
allEmbeds = append(allEmbeds, embeddings...)
}
return allEmbeds, nil
}

// EmbedQuery implements embeddings.Embedder
// and generates an embedding for the supplied text.
func (b *Bedrock) EmbedQuery(ctx context.Context, text string) ([]float32, error) {
var embeddings [][]float32
var err error

switch provider := getProvider(b.ModelID); provider {
case "amazon":
embeddings, err = FetchAmazonTextEmbeddings(ctx, b.client, b.ModelID, []string{text})
case "cohere":
embeddings, err = FetchCohereTextEmbeddings(ctx, b.client, b.ModelID, []string{text}, CohereInputTypeQuery)
default:
err = errors.New("unsupported text embedding provider: " + provider)
}

if err != nil {
return nil, err
}
return embeddings[0], nil
}

var _ embeddings.Embedder = &Bedrock{}
36 changes: 36 additions & 0 deletions embeddings/bedrock/bedrock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package bedrock_test

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/embeddings/bedrock"
)

func TestEmbedQuery(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_AWS") != "true" {
t.Skip("Skipping test, requires AWS access")
}
model, err := bedrock.NewBedrock(bedrock.WithModel(bedrock.ModelTitanEmbedG1))
require.NoError(t, err)
_, err = model.EmbedQuery(context.Background(), "hello world")

require.NoError(t, err)
}

func TestEmbedDocuments(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_AWS") != "true" {
t.Skip("Skipping test, requires AWS access")
}
model, err := bedrock.NewBedrock(bedrock.WithModel(bedrock.ModelCohereEn))
require.NoError(t, err)

embeddings, err := model.EmbedDocuments(context.Background(), []string{"hello world", "goodbye world"})

require.NoError(t, err)
require.Len(t, embeddings, 2)
}
66 changes: 66 additions & 0 deletions embeddings/bedrock/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package bedrock

import (
"context"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)

const (
_defaultBatchSize = 512
_defaultStripNewLines = true
_defaultModel = ModelTitanEmbedG1
)

// Option is a function type that can be used to modify the client.
type Option func(p *Bedrock)

// WithStripNewLines is an option for specifying the should it strip new lines.
func WithStripNewLines(stripNewLines bool) Option {
return func(p *Bedrock) {
p.StripNewLines = stripNewLines
}
}

// WithBatchSize is an option for specifying the batch size.
// Only applicable to Cohere provider.
func WithBatchSize(batchSize int) Option {
return func(p *Bedrock) {
p.BatchSize = batchSize
}
}

// WithModel is an option for providing the model name to use.
func WithModel(model string) Option {
return func(p *Bedrock) {
p.ModelID = model
}
}

// WithClient is an option for providing the Bedrock client.
func WithClient(client *bedrockruntime.Client) Option {
return func(p *Bedrock) {
p.client = client
}
}

func applyOptions(opts ...Option) (*Bedrock, error) {
o := &Bedrock{
StripNewLines: _defaultStripNewLines,
BatchSize: _defaultBatchSize,
}

for _, opt := range opts {
opt(o)
}

if o.client == nil {
cfg, err := config.LoadDefaultConfig(context.Background())
if err != nil {
return nil, err
}
o.client = bedrockruntime.NewFromConfig(cfg)
}
return o, nil
}
66 changes: 66 additions & 0 deletions embeddings/bedrock/provider_amazon.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package bedrock

import (
"context"
"encoding/json"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)

const (
/*
ModelTitanEmbedG1 is the model id for the amazon text embeddings.
MaxTokens := 8000
ModelDimensions := 1536
Languages := []string{"English", "Arabic", "Chinese (Simplified)", "French", "German", "Hindi", "Japanese", "Spanish", "Czech", "Filipino", "Hebrew", "Italian", "Korean", "Portuguese", "Russian", "Swedish", "Turkish", "Chinese (Traditional)", "Dutch", "Kannada", "Malayalam", "Marathi", "Polish", "Tamil", "Telugu", ...}
*/
ModelTitanEmbedG1 = "amazon.titan-embed-text-v1"
)

type amazonEmbeddingsInput struct {
InputText string `json:"inputText"`
}

type amazonEmbeddingsOutput struct {
Embedding []float32 `json:"embedding"`
}

func FetchAmazonTextEmbeddings(ctx context.Context,
client *bedrockruntime.Client,
modelID string,
texts []string,
) ([][]float32, error) {
embeddings := make([][]float32, 0, len(texts))

for _, text := range texts {
bodyStruct := amazonEmbeddingsInput{
InputText: text,
}
body, err := json.Marshal(bodyStruct)
if err != nil {
return nil, err
}
modelInput := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(modelID),
Accept: aws.String("*/*"),
ContentType: aws.String("application/json"),
Body: body,
}

result, err := client.InvokeModel(ctx, modelInput)
if err != nil {
return nil, err
}

var response amazonEmbeddingsOutput
err = json.Unmarshal(result.Body, &response)
if err != nil {
return nil, err
}
embeddings = append(embeddings, response.Embedding)
}

return embeddings, nil
}
83 changes: 83 additions & 0 deletions embeddings/bedrock/provider_cohere.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package bedrock

import (
"context"
"encoding/json"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)

const (
/*
ModelCohereEn is the model id for the cohere english embeddings.
ModelDimensions := 1024
MaxTokens := 512
Languages := []string{"English"}
*/
ModelCohereEn = "cohere.embed-english-v3"

/*
ModelCohereMulti is the model id for the cohere multilingual embeddings.
ModelDimensions := 1024
MaxTokens:= 512
Languages := [108]string
*/
ModelCohereMulti = "cohere.embed-multilingual-v3"
)

const (
// CohereInputTypeText is the input type for text embeddings.
CohereInputTypeText = "search_document"
// CohereInputTypeQuery is the input type for query embeddings.
CohereInputTypeQuery = "search_query"
)

type cohereTextEmbeddingsInput struct {
Texts []string `json:"texts"`
InputType string `json:"input_type"`
}

type cohereTextEmbeddingsOutput struct {
ResponseType string `json:"response_type"`
Embeddings [][]float32 `json:"embeddings"`
}

func FetchCohereTextEmbeddings(
ctx context.Context,
client *bedrockruntime.Client,
modelID string,
inputs []string,
inputType string,
) ([][]float32, error) {
var err error

bodyStruct := cohereTextEmbeddingsInput{
Texts: inputs,
InputType: inputType,
}
body, err := json.Marshal(bodyStruct)
if err != nil {
return nil, err
}
modelInput := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(modelID),
Accept: aws.String("*/*"),
ContentType: aws.String("application/json"),
Body: body,
}

result, err := client.InvokeModel(ctx, modelInput)
if err != nil {
return nil, err
}
var response cohereTextEmbeddingsOutput
err = json.Unmarshal(result.Body, &response)
if err != nil {
return nil, err
}

return response.Embeddings, nil
}
15 changes: 15 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ require (
github.com/antchfx/xmlquery v1.3.17 // indirect
github.com/antchfx/xpath v1.2.4 // indirect
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d // indirect
github.com/aws/aws-sdk-go-v2 v1.25.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.4 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.20.1 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.28.1 // indirect
github.com/aws/smithy-go v1.20.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cockroachdb/errors v1.9.1 // indirect
Expand Down Expand Up @@ -162,6 +175,8 @@ require (
github.com/Masterminds/sprig/v3 v3.2.3
github.com/PuerkitoBio/goquery v1.8.1
github.com/amikos-tech/chroma-go v0.0.1
github.com/aws/aws-sdk-go-v2/config v1.27.4
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.1
github.com/cohere-ai/tokenizer v1.1.2
github.com/go-openapi/strfmt v0.21.3
github.com/go-sql-driver/mysql v1.7.1
Expand Down
Loading

0 comments on commit a51062f

Please sign in to comment.