-
-
Notifications
You must be signed in to change notification settings - Fork 734
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
embeddings: Add Amazon Bedrock embeddings (#643)
* Add bedrock embedding provider * Add bedrock tests --------- Co-authored-by: Travis Cline <travis.cline@gmail.com>
- Loading branch information
1 parent
8cbd678
commit a51062f
Showing
7 changed files
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package bedrock | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"strings" | ||
|
||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
"github.com/tmc/langchaingo/embeddings" | ||
) | ||
|
||
// Bedrock is the embedder used generate text embeddings through Amazon Bedrock. | ||
type Bedrock struct { | ||
ModelID string | ||
client *bedrockruntime.Client | ||
StripNewLines bool | ||
BatchSize int | ||
} | ||
|
||
// NewBedrock returns a new embeddings.Embedder that uses Amazon Bedrock to generate embeddings. | ||
func NewBedrock(opts ...Option) (*Bedrock, error) { | ||
v, err := applyOptions(opts...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return v, nil | ||
} | ||
|
||
func getProvider(modelID string) string { | ||
return strings.Split(modelID, ".")[0] | ||
} | ||
|
||
// EmbedDocuments implements embeddings.Embedder | ||
// and generates embeddings for the supplied texts. | ||
func (b *Bedrock) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { | ||
batchedTexts := embeddings.BatchTexts( | ||
embeddings.MaybeRemoveNewLines(texts, b.StripNewLines), | ||
b.BatchSize, | ||
) | ||
provider := getProvider(b.ModelID) | ||
|
||
allEmbeds := make([][]float32, 0, len(texts)) | ||
var embeddings [][]float32 | ||
var err error | ||
|
||
for _, batch := range batchedTexts { | ||
switch provider { | ||
case "amazon": | ||
embeddings, err = FetchAmazonTextEmbeddings(ctx, b.client, b.ModelID, batch) | ||
case "cohere": | ||
embeddings, err = FetchCohereTextEmbeddings(ctx, b.client, b.ModelID, batch, CohereInputTypeText) | ||
default: | ||
err = errors.New("unsupported text embedding provider: " + provider) | ||
} | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
allEmbeds = append(allEmbeds, embeddings...) | ||
} | ||
return allEmbeds, nil | ||
} | ||
|
||
// EmbedQuery implements embeddings.Embedder | ||
// and generates an embedding for the supplied text. | ||
func (b *Bedrock) EmbedQuery(ctx context.Context, text string) ([]float32, error) { | ||
var embeddings [][]float32 | ||
var err error | ||
|
||
switch provider := getProvider(b.ModelID); provider { | ||
case "amazon": | ||
embeddings, err = FetchAmazonTextEmbeddings(ctx, b.client, b.ModelID, []string{text}) | ||
case "cohere": | ||
embeddings, err = FetchCohereTextEmbeddings(ctx, b.client, b.ModelID, []string{text}, CohereInputTypeQuery) | ||
default: | ||
err = errors.New("unsupported text embedding provider: " + provider) | ||
} | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
return embeddings[0], nil | ||
} | ||
|
||
var _ embeddings.Embedder = &Bedrock{} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package bedrock_test | ||
|
||
import ( | ||
"context" | ||
"os" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/tmc/langchaingo/embeddings/bedrock" | ||
) | ||
|
||
func TestEmbedQuery(t *testing.T) { | ||
t.Parallel() | ||
if os.Getenv("TEST_AWS") != "true" { | ||
t.Skip("Skipping test, requires AWS access") | ||
} | ||
model, err := bedrock.NewBedrock(bedrock.WithModel(bedrock.ModelTitanEmbedG1)) | ||
require.NoError(t, err) | ||
_, err = model.EmbedQuery(context.Background(), "hello world") | ||
|
||
require.NoError(t, err) | ||
} | ||
|
||
func TestEmbedDocuments(t *testing.T) { | ||
t.Parallel() | ||
if os.Getenv("TEST_AWS") != "true" { | ||
t.Skip("Skipping test, requires AWS access") | ||
} | ||
model, err := bedrock.NewBedrock(bedrock.WithModel(bedrock.ModelCohereEn)) | ||
require.NoError(t, err) | ||
|
||
embeddings, err := model.EmbedDocuments(context.Background(), []string{"hello world", "goodbye world"}) | ||
|
||
require.NoError(t, err) | ||
require.Len(t, embeddings, 2) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package bedrock | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/aws/aws-sdk-go-v2/config" | ||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
) | ||
|
||
const ( | ||
_defaultBatchSize = 512 | ||
_defaultStripNewLines = true | ||
_defaultModel = ModelTitanEmbedG1 | ||
) | ||
|
||
// Option is a function type that can be used to modify the client. | ||
type Option func(p *Bedrock) | ||
|
||
// WithStripNewLines is an option for specifying the should it strip new lines. | ||
func WithStripNewLines(stripNewLines bool) Option { | ||
return func(p *Bedrock) { | ||
p.StripNewLines = stripNewLines | ||
} | ||
} | ||
|
||
// WithBatchSize is an option for specifying the batch size. | ||
// Only applicable to Cohere provider. | ||
func WithBatchSize(batchSize int) Option { | ||
return func(p *Bedrock) { | ||
p.BatchSize = batchSize | ||
} | ||
} | ||
|
||
// WithModel is an option for providing the model name to use. | ||
func WithModel(model string) Option { | ||
return func(p *Bedrock) { | ||
p.ModelID = model | ||
} | ||
} | ||
|
||
// WithClient is an option for providing the Bedrock client. | ||
func WithClient(client *bedrockruntime.Client) Option { | ||
return func(p *Bedrock) { | ||
p.client = client | ||
} | ||
} | ||
|
||
func applyOptions(opts ...Option) (*Bedrock, error) { | ||
o := &Bedrock{ | ||
StripNewLines: _defaultStripNewLines, | ||
BatchSize: _defaultBatchSize, | ||
} | ||
|
||
for _, opt := range opts { | ||
opt(o) | ||
} | ||
|
||
if o.client == nil { | ||
cfg, err := config.LoadDefaultConfig(context.Background()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
o.client = bedrockruntime.NewFromConfig(cfg) | ||
} | ||
return o, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package bedrock | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
) | ||
|
||
const ( | ||
/* | ||
ModelTitanEmbedG1 is the model id for the amazon text embeddings. | ||
MaxTokens := 8000 | ||
ModelDimensions := 1536 | ||
Languages := []string{"English", "Arabic", "Chinese (Simplified)", "French", "German", "Hindi", "Japanese", "Spanish", "Czech", "Filipino", "Hebrew", "Italian", "Korean", "Portuguese", "Russian", "Swedish", "Turkish", "Chinese (Traditional)", "Dutch", "Kannada", "Malayalam", "Marathi", "Polish", "Tamil", "Telugu", ...} | ||
*/ | ||
ModelTitanEmbedG1 = "amazon.titan-embed-text-v1" | ||
) | ||
|
||
type amazonEmbeddingsInput struct { | ||
InputText string `json:"inputText"` | ||
} | ||
|
||
type amazonEmbeddingsOutput struct { | ||
Embedding []float32 `json:"embedding"` | ||
} | ||
|
||
func FetchAmazonTextEmbeddings(ctx context.Context, | ||
client *bedrockruntime.Client, | ||
modelID string, | ||
texts []string, | ||
) ([][]float32, error) { | ||
embeddings := make([][]float32, 0, len(texts)) | ||
|
||
for _, text := range texts { | ||
bodyStruct := amazonEmbeddingsInput{ | ||
InputText: text, | ||
} | ||
body, err := json.Marshal(bodyStruct) | ||
if err != nil { | ||
return nil, err | ||
} | ||
modelInput := &bedrockruntime.InvokeModelInput{ | ||
ModelId: aws.String(modelID), | ||
Accept: aws.String("*/*"), | ||
ContentType: aws.String("application/json"), | ||
Body: body, | ||
} | ||
|
||
result, err := client.InvokeModel(ctx, modelInput) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var response amazonEmbeddingsOutput | ||
err = json.Unmarshal(result.Body, &response) | ||
if err != nil { | ||
return nil, err | ||
} | ||
embeddings = append(embeddings, response.Embedding) | ||
} | ||
|
||
return embeddings, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package bedrock | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
) | ||
|
||
const ( | ||
/* | ||
ModelCohereEn is the model id for the cohere english embeddings. | ||
ModelDimensions := 1024 | ||
MaxTokens := 512 | ||
Languages := []string{"English"} | ||
*/ | ||
ModelCohereEn = "cohere.embed-english-v3" | ||
|
||
/* | ||
ModelCohereMulti is the model id for the cohere multilingual embeddings. | ||
ModelDimensions := 1024 | ||
MaxTokens:= 512 | ||
Languages := [108]string | ||
*/ | ||
ModelCohereMulti = "cohere.embed-multilingual-v3" | ||
) | ||
|
||
const ( | ||
// CohereInputTypeText is the input type for text embeddings. | ||
CohereInputTypeText = "search_document" | ||
// CohereInputTypeQuery is the input type for query embeddings. | ||
CohereInputTypeQuery = "search_query" | ||
) | ||
|
||
type cohereTextEmbeddingsInput struct { | ||
Texts []string `json:"texts"` | ||
InputType string `json:"input_type"` | ||
} | ||
|
||
type cohereTextEmbeddingsOutput struct { | ||
ResponseType string `json:"response_type"` | ||
Embeddings [][]float32 `json:"embeddings"` | ||
} | ||
|
||
func FetchCohereTextEmbeddings( | ||
ctx context.Context, | ||
client *bedrockruntime.Client, | ||
modelID string, | ||
inputs []string, | ||
inputType string, | ||
) ([][]float32, error) { | ||
var err error | ||
|
||
bodyStruct := cohereTextEmbeddingsInput{ | ||
Texts: inputs, | ||
InputType: inputType, | ||
} | ||
body, err := json.Marshal(bodyStruct) | ||
if err != nil { | ||
return nil, err | ||
} | ||
modelInput := &bedrockruntime.InvokeModelInput{ | ||
ModelId: aws.String(modelID), | ||
Accept: aws.String("*/*"), | ||
ContentType: aws.String("application/json"), | ||
Body: body, | ||
} | ||
|
||
result, err := client.InvokeModel(ctx, modelInput) | ||
if err != nil { | ||
return nil, err | ||
} | ||
var response cohereTextEmbeddingsOutput | ||
err = json.Unmarshal(result.Body, &response) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return response.Embeddings, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.