Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llms/bedrock: Add AWS Bedrock LLMs #666

Merged
merged 7 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions llms/bedrock/bedrockllm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package bedrock

import (
"context"
"errors"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/bedrock/internal/bedrockclient"
)

// defaultModel is the Bedrock model ID used when none is supplied via WithModel.
const defaultModel = "amazon.titan-text-lite-v1"

// LLM is an llms.Model implementation backed by Amazon Bedrock.
type LLM struct {
	modelID string                // default model ID; per-call options may override it
	client  *bedrockclient.Client // provider-dispatching wrapper around bedrockruntime
	// CallbacksHandler, when non-nil, receives generate start/end/error events.
	CallbacksHandler callbacks.Handler
}

// New constructs a Bedrock-backed LLM, applying the supplied options.
// When no client option is given, one is built from the default AWS config.
func New(opts ...Option) (*LLM, error) {
	o, client, err := newClient(opts...)
	if err != nil {
		return nil, err
	}
	llm := &LLM{
		client:           client,
		modelID:          o.modelID,
		CallbacksHandler: o.callbackHandler,
	}
	return llm, nil
}

// newClient resolves the option set and returns it together with a ready
// bedrockclient.Client. A bedrockruntime client is created from the default
// AWS configuration chain when the caller did not provide one.
func newClient(opts ...Option) (*options, *bedrockclient.Client, error) {
	o := &options{modelID: defaultModel}
	for _, apply := range opts {
		apply(o)
	}

	if o.client == nil {
		cfg, err := config.LoadDefaultConfig(context.Background())
		if err != nil {
			return o, nil, err
		}
		o.client = bedrockruntime.NewFromConfig(cfg)
	}

	return o, bedrockclient.NewClient(o.client), nil
}

// Call implements llms.Model. It delegates a single text prompt to
// GenerateContent via the llms.GenerateFromSinglePrompt helper.
func (l *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	return llms.GenerateFromSinglePrompt(ctx, l, prompt, options...)
}

// GenerateContent implements llms.Model. It converts the incoming messages
// into the internal bedrockclient representation, forwards the request to the
// Bedrock client, and reports start/end/error events to CallbacksHandler when
// one is configured.
func (l *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
	if l.CallbacksHandler != nil {
		l.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
	}

	// Start from the configured model; per-call options may override it.
	opts := llms.CallOptions{
		Model: l.modelID,
	}
	for _, opt := range options {
		opt(&opts)
	}

	m, err := processMessages(messages)
	if err != nil {
		// Report conversion failures through the callback handler too, so this
		// error path is consistent with the completion error path below.
		if l.CallbacksHandler != nil {
			l.CallbacksHandler.HandleLLMError(ctx, err)
		}
		return nil, err
	}

	res, err := l.client.CreateCompletion(ctx, opts.Model, m, opts)
	if err != nil {
		if l.CallbacksHandler != nil {
			l.CallbacksHandler.HandleLLMError(ctx, err)
		}
		return nil, err
	}

	if l.CallbacksHandler != nil {
		l.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, res)
	}

	return res, nil
}

// processMessages flattens llms message parts into bedrockclient messages.
// Text and binary (image) parts are supported; any other part type is an
// error.
func processMessages(messages []llms.MessageContent) ([]bedrockclient.Message, error) {
	out := make([]bedrockclient.Message, 0, len(messages))

	for _, msg := range messages {
		for _, p := range msg.Parts {
			var converted bedrockclient.Message
			switch content := p.(type) {
			case llms.TextContent:
				converted = bedrockclient.Message{
					Role:    msg.Role,
					Content: content.Text,
					Type:    "text",
				}
			case llms.BinaryContent:
				converted = bedrockclient.Message{
					Role:     msg.Role,
					Content:  string(content.Data),
					MimeType: content.MIMEType,
					Type:     "image",
				}
			default:
				return nil, errors.New("unsupported message type")
			}
			out = append(out, converted)
		}
	}
	return out, nil
}

// Compile-time check that *LLM satisfies the llms.Model interface.
var _ llms.Model = (*LLM)(nil)
35 changes: 35 additions & 0 deletions llms/bedrock/bedrockllm_option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package bedrock

import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/tmc/langchaingo/callbacks"
)

// Option mutates the configuration assembled by New.
type Option func(*options)

// options collects the configurable settings for the Bedrock LLM.
type options struct {
	modelID         string                 // Bedrock model ID; defaults to defaultModel
	client          *bedrockruntime.Client // optional pre-built client; built from the default AWS config when nil
	callbackHandler callbacks.Handler      // optional callbacks handler for lifecycle events
}

// WithModel allows setting a custom modelId, overriding the package default.
func WithModel(modelID string) Option {
	return func(o *options) {
		o.modelID = modelID
	}
}

// WithClient allows setting a custom bedrockruntime.Client; when unset, New
// builds one from the default AWS configuration.
func WithClient(client *bedrockruntime.Client) Option {
	return func(o *options) {
		o.client = client
	}
}

// WithCallback allows setting a custom Callback Handler that receives
// generation start/end/error events.
func WithCallback(callbackHandler callbacks.Handler) Option {
	return func(o *options) {
		o.callbackHandler = callbackHandler
	}
}
84 changes: 84 additions & 0 deletions llms/bedrock/bedrockllm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package bedrock_test

import (
"context"
"os"
"testing"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/bedrock"
"github.com/tmc/langchaingo/schema"
)

// setUpTest builds a bedrockruntime client from the default AWS config chain.
func setUpTest() (*bedrockruntime.Client, error) {
	cfg, err := config.LoadDefaultConfig(context.Background())
	if err != nil {
		return nil, err
	}
	return bedrockruntime.NewFromConfig(cfg), nil
}

// TestAmazonOutput exercises GenerateContent against every supported Bedrock
// model. It requires live AWS credentials and is skipped unless TEST_AWS=true.
func TestAmazonOutput(t *testing.T) {
	t.Parallel()

	if os.Getenv("TEST_AWS") != "true" {
		t.Skip("Skipping test, requires AWS access")
	}

	client, err := setUpTest()
	if err != nil {
		t.Fatal(err)
	}
	llm, err := bedrock.New(bedrock.WithClient(client))
	if err != nil {
		t.Fatal(err)
	}

	msgs := []llms.MessageContent{
		{
			Role: schema.ChatMessageTypeSystem,
			Parts: []llms.ContentPart{
				llms.TextPart("You know all about AI."),
			},
		},
		{
			Role: schema.ChatMessageTypeHuman,
			Parts: []llms.ContentPart{
				llms.TextPart("Explain AI in 10 words or less."),
			},
		},
	}

	// All the test models.
	models := []string{
		bedrock.ModelAi21J2MidV1,
		bedrock.ModelAi21J2UltraV1,
		bedrock.ModelAmazonTitanTextLiteV1,
		bedrock.ModelAmazonTitanTextExpressV1,
		bedrock.ModelAnthropicClaude3Sonnet20240229V10,
		bedrock.ModelAnthropicClaudeV21,
		bedrock.ModelAnthropicClaudeV2,
		bedrock.ModelAnthropicClaudeInstantV1,
		bedrock.ModelCohereCommandTextV14,
		bedrock.ModelCohereCommandLightTextV14,
		bedrock.ModelMetaLlama213bChatV1,
		bedrock.ModelMetaLlama270bChatV1,
	}

	ctx := context.Background()

	for _, model := range models {
		model := model // shadow for the closure (pre-Go 1.22 loop semantics)
		// Run each model as its own subtest so that one failing model does
		// not abort the remaining models (t.Fatal in the flat loop did).
		t.Run(model, func(t *testing.T) {
			resp, err := llm.GenerateContent(ctx, msgs, llms.WithModel(model), llms.WithMaxTokens(512))
			if err != nil {
				t.Fatal(err)
			}
			t.Logf("Model output for %s:-", model)
			for i, choice := range resp.Choices {
				t.Logf("Choice %d: %s", i, choice.Content)
			}
		})
	}
}
77 changes: 77 additions & 0 deletions llms/bedrock/internal/bedrockclient/bedrockclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package bedrockclient

import (
"context"
"errors"
"strings"

"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
)

// Client dispatches completion requests to provider-specific Bedrock APIs,
// selected by the model ID's provider prefix.
type Client struct {
	client *bedrockruntime.Client
}

// Message is a flattened, provider-agnostic chat message passed to the
// provider-specific completion builders.
type Message struct {
	// Role is the chat role (system/human/ai/...) of the original message.
	Role schema.ChatMessageType
	// Content holds the text, or the raw binary data for image parts.
	Content string
	// Type may be "text" or "image"
	Type string
	// MimeType is the MIME type (set for "image" messages)
	MimeType string
}

// getProvider extracts the provider prefix from a Bedrock model ID, e.g.
// "amazon" from "amazon.titan-text-lite-v1". A model ID without a "." is
// returned unchanged. strings.Cut avoids allocating the full segment slice
// that strings.Split would build just to take its first element.
func getProvider(modelID string) string {
	provider, _, _ := strings.Cut(modelID, ".")
	return provider
}

// NewClient wraps an existing bedrockruntime.Client.
func NewClient(client *bedrockruntime.Client) *Client {
	c := &Client{client: client}
	return c
}

// CreateCompletion routes the completion request to the provider-specific
// implementation, chosen from the model ID's provider prefix (the part
// before the first "."). Unknown providers yield an error.
func (c *Client) CreateCompletion(ctx context.Context,
	modelID string,
	messages []Message,
	options llms.CallOptions,
) (*llms.ContentResponse, error) {
	provider := getProvider(modelID)
	switch provider {
	case "ai21":
		return createAi21Completion(ctx, c.client, modelID, messages, options)
	case "amazon":
		return createAmazonCompletion(ctx, c.client, modelID, messages, options)
	case "anthropic":
		return createAnthropicCompletion(ctx, c.client, modelID, messages, options)
	case "cohere":
		return createCohereCompletion(ctx, c.client, modelID, messages, options)
	case "meta":
		return createMetaCompletion(ctx, c.client, modelID, messages, options)
	default:
		return nil, errors.New("unsupported provider")
	}
}

// processInputMessagesGeneric renders messages into a single prompt string.
// Messages with a non-empty Role are prefixed with "\n<role>: "; only "text"
// message bodies are included. If any message carried a role, a trailing
// "\nAI: " turn is appended to cue the model's reply.
func processInputMessagesGeneric(messages []Message) string {
	var prompt strings.Builder
	sawRole := false
	for _, msg := range messages {
		if msg.Role != "" {
			sawRole = true
			prompt.WriteString("\n")
			prompt.WriteString(string(msg.Role))
			prompt.WriteString(": ")
		}
		if msg.Type == "text" {
			prompt.WriteString(msg.Content)
		}
	}
	if sawRole {
		prompt.WriteString("\nAI: ")
	}
	return prompt.String()
}
Loading
Loading