llms/bedrock: Add AWS Bedrock LLMs #666

Merged · 7 commits · Mar 18, 2024
121 changes: 121 additions & 0 deletions llms/bedrock/bedrockllm.go
@@ -0,0 +1,121 @@
package bedrock

import (
	"context"
	"errors"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/tmc/langchaingo/callbacks"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/bedrock/internal/bedrockclient"
)

const defaultModel = ModelAmazonTitanTextLiteV1

// LLM is a Bedrock LLM implementation.
type LLM struct {
	modelID          string
	client           *bedrockclient.Client
	CallbacksHandler callbacks.Handler
}

// New creates a new Bedrock LLM implementation.
func New(opts ...Option) (*LLM, error) {
	o, c, err := newClient(opts...)
	if err != nil {
		return nil, err
	}
	return &LLM{
		client:           c,
		modelID:          o.modelID,
		CallbacksHandler: o.callbackHandler,
	}, nil
}

func newClient(opts ...Option) (*options, *bedrockclient.Client, error) {
	options := &options{
		modelID: defaultModel,
	}

	for _, opt := range opts {
		opt(options)
	}

	if options.client == nil {
		cfg, err := config.LoadDefaultConfig(context.Background())
		if err != nil {
			return options, nil, err
		}
		options.client = bedrockruntime.NewFromConfig(cfg)
	}

	return options, bedrockclient.NewClient(options.client), nil
}

// Call implements llms.Model.
func (l *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	return llms.GenerateFromSinglePrompt(ctx, l, prompt, options...)
}

// GenerateContent implements llms.Model.
func (l *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
	if l.CallbacksHandler != nil {
		l.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
	}

	opts := llms.CallOptions{
		Model: l.modelID,
	}
	for _, opt := range options {
		opt(&opts)
	}

	m, err := processMessages(messages)
	if err != nil {
		return nil, err
	}

	res, err := l.client.CreateCompletion(ctx, opts.Model, m, opts)
	if err != nil {
		if l.CallbacksHandler != nil {
			l.CallbacksHandler.HandleLLMError(ctx, err)
		}
		return nil, err
	}

	if l.CallbacksHandler != nil {
		l.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, res)
	}

	return res, nil
}

// processMessages flattens the content parts of each message into
// provider-agnostic bedrockclient.Message values.
func processMessages(messages []llms.MessageContent) ([]bedrockclient.Message, error) {
	bedrockMsgs := make([]bedrockclient.Message, 0, len(messages))

	for _, m := range messages {
		for _, part := range m.Parts {
			switch part := part.(type) {
			case llms.TextContent:
				bedrockMsgs = append(bedrockMsgs, bedrockclient.Message{
					Role:    m.Role,
					Content: part.Text,
					Type:    "text",
				})
			case llms.BinaryContent:
				bedrockMsgs = append(bedrockMsgs, bedrockclient.Message{
					Role:     m.Role,
					Content:  string(part.Data),
					MimeType: part.MIMEType,
					Type:     "image",
				})
			default:
				return nil, errors.New("unsupported message type")
			}
		}
	}
	return bedrockMsgs, nil
}

var _ llms.Model = (*LLM)(nil)
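
A minimal usage sketch of the new package, assuming AWS credentials are resolvable via the default chain and that the account has been granted access to the chosen Bedrock model; the prompt text is illustrative:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms/bedrock"
)

func main() {
	// Uses the default AWS credential chain; assumes Bedrock model
	// access has been granted in the target region.
	llm, err := bedrock.New(bedrock.WithModel(bedrock.ModelAnthropicClaudeV2))
	if err != nil {
		log.Fatal(err)
	}

	// Call wraps GenerateContent for a single text prompt.
	out, err := llm.Call(context.Background(), "What is AWS Bedrock?")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}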
45 changes: 45 additions & 0 deletions llms/bedrock/bedrockllm_option.go
@@ -0,0 +1,45 @@
package bedrock

import (
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/tmc/langchaingo/callbacks"
)

// Option is an option for the Bedrock LLM.
type Option func(*options)

type options struct {
	modelID         string
	client          *bedrockruntime.Client
	callbackHandler callbacks.Handler
}

// WithModel allows setting a custom model ID.
//
// If not set, the default model,
// "amazon.titan-text-lite-v1", is used.
func WithModel(modelID string) Option {
	return func(o *options) {
		o.modelID = modelID
	}
}

// WithClient allows setting a custom bedrockruntime.Client.
//
// Use this to pass a client configured with non-default
// credentials, region, endpoint, etc.
//
// By default, a new client is created using the default credential chain.
func WithClient(client *bedrockruntime.Client) Option {
	return func(o *options) {
		o.client = client
	}
}

// WithCallback allows setting a custom callbacks.Handler.
func WithCallback(callbackHandler callbacks.Handler) Option {
	return func(o *options) {
		o.callbackHandler = callbackHandler
	}
}
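
A sketch of how these options compose when the caller supplies its own client. config.WithRegion is a standard aws-sdk-go-v2 config option; the function name and region value are illustrative assumptions, not part of this PR, and the imports mirror those in the test file below:

// newRegionalLLM builds a Bedrock LLM pinned to a specific region.
func newRegionalLLM(ctx context.Context) (*bedrock.LLM, error) {
	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("us-east-1"))
	if err != nil {
		return nil, err
	}
	return bedrock.New(
		bedrock.WithClient(bedrockruntime.NewFromConfig(cfg)),
		bedrock.WithModel(bedrock.ModelAmazonTitanTextExpressV1),
	)
}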
85 changes: 85 additions & 0 deletions llms/bedrock/bedrockllm_test.go
@@ -0,0 +1,85 @@
package bedrock_test

import (
	"context"
	"os"
	"testing"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/bedrock"
	"github.com/tmc/langchaingo/schema"
)

func setUpTest() (*bedrockruntime.Client, error) {
	cfg, err := config.LoadDefaultConfig(context.Background())
	if err != nil {
		return nil, err
	}
	client := bedrockruntime.NewFromConfig(cfg)
	return client, nil
}

func TestAmazonOutput(t *testing.T) {
	t.Parallel()

	if os.Getenv("TEST_AWS") != "true" {
		t.Skip("Skipping test, requires AWS access")
	}

	client, err := setUpTest()
	if err != nil {
		t.Fatal(err)
	}
	llm, err := bedrock.New(bedrock.WithClient(client))
	if err != nil {
		t.Fatal(err)
	}

	msgs := []llms.MessageContent{
		{
			Role: schema.ChatMessageTypeSystem,
			Parts: []llms.ContentPart{
				llms.TextPart("You know all about AI."),
			},
		},
		{
			Role: schema.ChatMessageTypeHuman,
			Parts: []llms.ContentPart{
				llms.TextPart("Explain AI in 10 words or less."),
			},
		},
	}

	// All the test models.
	models := []string{
		bedrock.ModelAi21J2MidV1,
		bedrock.ModelAi21J2UltraV1,
		bedrock.ModelAmazonTitanTextLiteV1,
		bedrock.ModelAmazonTitanTextExpressV1,
		bedrock.ModelAnthropicClaudeV3Sonnet,
		bedrock.ModelAnthropicClaudeV3Haiku,
		bedrock.ModelAnthropicClaudeV21,
		bedrock.ModelAnthropicClaudeV2,
		bedrock.ModelAnthropicClaudeInstantV1,
		bedrock.ModelCohereCommandTextV14,
		bedrock.ModelCohereCommandLightTextV14,
		bedrock.ModelMetaLlama213bChatV1,
		bedrock.ModelMetaLlama270bChatV1,
	}

	ctx := context.Background()

	for _, model := range models {
		t.Logf("Model output for %s:-", model)

		resp, err := llm.GenerateContent(ctx, msgs, llms.WithModel(model), llms.WithMaxTokens(512))
		if err != nil {
			t.Fatal(err)
		}
		for i, choice := range resp.Choices {
			t.Logf("Choice %d: %s", i, choice.Content)
		}
	}
}
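
The test is gated on live AWS access: it runs only when TEST_AWS=true is set in the environment (a live run would presumably look like TEST_AWS=true go test ./llms/bedrock/, assuming the resolved credentials have been granted access to every model in the list).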
88 changes: 88 additions & 0 deletions llms/bedrock/internal/bedrockclient/bedrockclient.go
@@ -0,0 +1,88 @@
package bedrockclient

import (
	"context"
	"errors"
	"strings"

	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/schema"
)

// Client is a Bedrock client.
type Client struct {
	client *bedrockruntime.Client
}

// Message is a chunk of text or data
// that will be sent to the provider.
//
// The provider may then transform the message into its own
// format before sending it to the LLM model API.
type Message struct {
	Role    schema.ChatMessageType
	Content string
	// Type may be "text" or "image"
	Type string
	// MimeType is the MIME type of the content
	MimeType string
}

// getProvider extracts the provider name from the model ID prefix,
// e.g. "anthropic" from "anthropic.claude-v2".
func getProvider(modelID string) string {
	return strings.Split(modelID, ".")[0]
}

// NewClient creates a new Bedrock client.
func NewClient(client *bedrockruntime.Client) *Client {
	return &Client{
		client: client,
	}
}

// CreateCompletion creates a new completion response from the provider
// after sending the messages to the provider.
func (c *Client) CreateCompletion(ctx context.Context,
	modelID string,
	messages []Message,
	options llms.CallOptions,
) (*llms.ContentResponse, error) {
	provider := getProvider(modelID)
	switch provider {
	case "ai21":
		return createAi21Completion(ctx, c.client, modelID, messages, options)
	case "amazon":
		return createAmazonCompletion(ctx, c.client, modelID, messages, options)
	case "anthropic":
		return createAnthropicCompletion(ctx, c.client, modelID, messages, options)
	case "cohere":
		return createCohereCompletion(ctx, c.client, modelID, messages, options)
	case "meta":
		return createMetaCompletion(ctx, c.client, modelID, messages, options)
	default:
		return nil, errors.New("unsupported provider")
	}
}

// processInputMessagesGeneric concatenates text chat messages into a
// single prompt string, prefixing each message with its role (when set)
// and ending with an "AI:" turn for the model to complete.
func processInputMessagesGeneric(messages []Message) string {
	var sb strings.Builder
	var hasRole bool
	for _, message := range messages {
		if message.Role != "" {
			hasRole = true
			sb.WriteString("\n")
			sb.WriteString(string(message.Role))
			sb.WriteString(": ")
		}
		if message.Type == "text" {
			sb.WriteString(message.Content)
		}
	}
	if hasRole {
		sb.WriteString("\n")
		sb.WriteString("AI: ")
	}
	return sb.String()
}
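
To make the dispatch and prompt shaping concrete, a small in-package sketch with invented message values; the behavior follows directly from getProvider and processInputMessagesGeneric above:

// getProvider("anthropic.claude-v2") == "anthropic", so CreateCompletion
// routes to createAnthropicCompletion.
msgs := []Message{
	{Role: "Human", Type: "text", Content: "Hello"},
}
// For providers that accept a single prompt string, this renders as
// "\nHuman: Hello\nAI: ", leaving the final "AI: " turn for the model.
prompt := processInputMessagesGeneric(msgs)
_ = prompt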