Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming #11

Merged
merged 22 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agency.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func NewOperation(handler OperationHandler) *Operation {
func (p *Operation) Execute(ctx context.Context, input Message) (Message, error) {
output, err := p.handler(ctx, input, p.config)
if err != nil {
return Message{}, err
return nil, err
}
return output, nil
}
Expand Down
2 changes: 1 addition & 1 deletion examples/chat/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {
panic(err)
}

input := agency.UserMessage(text)
input := agency.NewMessage(agency.UserRole, agency.TextKind, []byte(text))
answer, err := assistant.SetMessages(messages).Execute(ctx, input)
if err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion examples/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func main() {
MaxTokens: *maxTokens,
}).
SetPrompt(*prompt).
Execute(context.Background(), agency.UserMessage(content))
Execute(context.Background(), agency.NewMessage(agency.UserRole, agency.TextKind, []byte(content)))

if err != nil {
fmt.Println(err)
Expand Down
31 changes: 0 additions & 31 deletions examples/completion_streaming/main.go

This file was deleted.

8 changes: 4 additions & 4 deletions examples/custom_operation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func main() {

msg, err := agency.NewProcess(
increment, increment, increment,
).Execute(context.Background(), agency.UserMessage("0"))
).Execute(context.Background(), agency.NewMessage(agency.UserRole, agency.TextKind, []byte("0")))

if err != nil {
panic(err)
Expand All @@ -23,10 +23,10 @@ func main() {
}

func incrementFunc(ctx context.Context, msg agency.Message, _ *agency.OperationConfig) (agency.Message, error) {
i, err := strconv.ParseInt(string(msg.Content), 10, 10)
i, err := strconv.ParseInt(string(msg.Content()), 10, 10)
if err != nil {
return agency.Message{}, err
return nil, err
}
inc := strconv.Itoa(int(i) + 1)
return agency.SystemMessage(inc), nil
return agency.NewMessage(agency.ToolRole, agency.TextKind, []byte(inc)), nil
}
27 changes: 17 additions & 10 deletions examples/func_call/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func main() {
{
Name: "GetMeaningOfLife",
Description: "Answer questions about meaning of life",
Body: func(ctx context.Context, _ []byte) (any, error) {
return 42, nil
Body: func(ctx context.Context, _ []byte) (agency.Message, error) {
return agency.NewTextMessage(agency.ToolRole, "42"), nil
},
},
// function with parameters
Expand All @@ -38,12 +38,15 @@ func main() {
"b": {Type: "integer"},
},
},
Body: func(ctx context.Context, params []byte) (any, error) {
Body: func(ctx context.Context, params []byte) (agency.Message, error) {
var pp struct{ A, B int }
if err := json.Unmarshal(params, &pp); err != nil {
return nil, err
}
return (pp.A + pp.B) * 10, nil // *10 is just to distinguish from normal response
return agency.NewTextMessage(
agency.ToolRole,
fmt.Sprintf("%d", (pp.A+pp.B)*10),
), nil // *10 is just to distinguish from normal response
},
},
},
Expand All @@ -64,30 +67,34 @@ Examples:
// test for first function call
answer, err := t2tOp.Execute(
ctx,
agency.UserMessage("what is the meaning of life?"),
agency.NewMessage(agency.UserRole, agency.TextKind, []byte("what is the meaning of life?")),
)
if err != nil {
panic(err)
}
fmt.Println(answer)
printAnswer(answer)

// test for second function call
answer, err = t2tOp.Execute(
ctx,
agency.UserMessage("1+1?"),
agency.NewMessage(agency.UserRole, agency.TextKind, []byte("1+1?")),
)
if err != nil {
panic(err)
}
fmt.Println(answer)
printAnswer(answer)

// test for both function calls at the same time
answer, err = t2tOp.Execute(
ctx,
agency.UserMessage("1+1 and what is the meaning of life?"),
agency.NewMessage(agency.UserRole, agency.TextKind, []byte("1+1 and what is the meaning of life?")),
)
if err != nil {
panic(err)
}
fmt.Println(answer)
printAnswer(answer)
}

func printAnswer(message agency.Message) {
fmt.Printf("Role: %s; Type: %s; Data: %s\n", message.Role(), message.Kind(), agency.GetStringContent(message))
}
35 changes: 35 additions & 0 deletions examples/image_to_stream/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package main

import (
"context"
"fmt"
"os"

_ "github.com/joho/godotenv/autoload"
"github.com/neurocult/agency"

providers "github.com/neurocult/agency/providers/openai"
)

func main() {
imgBytes, err := os.ReadFile("assets/test.jpg")
if err != nil {
panic(err)
}

_, err = providers.New(providers.Params{Key: os.Getenv("OPENAI_API_KEY")}).
TextToStream(providers.TextToStreamParams{
TextToTextParams: providers.TextToTextParams{MaxTokens: 300, Model: "gpt-4o"},
StreamHandler: func(delta, total string, isFirst, isLast bool) error {
fmt.Println(delta)
return nil
}}).
SetPrompt("describe what you see").
Execute(
context.Background(),
agency.NewMessage(agency.UserRole, agency.ImageKind, imgBytes),
)
if err != nil {
panic(err)
}
}
5 changes: 2 additions & 3 deletions examples/image_to_text/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ import (
"os"

_ "github.com/joho/godotenv/autoload"

"github.com/neurocult/agency"
openAIProvider "github.com/neurocult/agency/providers/openai"
"github.com/sashabaranov/go-openai"
)

func main() {
imgBytes, err := os.ReadFile("example.png")
imgBytes, err := os.ReadFile("../example.png")
if err != nil {
panic(err)
}
Expand All @@ -23,7 +22,7 @@ func main() {
SetPrompt("describe what you see").
Execute(
context.Background(),
agency.Message{Content: imgBytes},
agency.NewMessage(agency.UserRole, agency.ImageKind, imgBytes),
)
if err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion examples/logging/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func main() {
).
Execute(
context.Background(),
agency.UserMessage("Kazakhstan alga!"),
agency.NewMessage(agency.UserRole, agency.TextKind, []byte("Kazakhstan alga!")),
Logger,
)

Expand Down
2 changes: 1 addition & 1 deletion examples/prompt_template/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func main() {
).
Execute(
context.Background(),
agency.UserMessage("%s", "I love programming."),
agency.NewMessage(agency.UserRole, agency.TextKind, []byte("I love programming.")),
)

if err != nil {
Expand Down
19 changes: 10 additions & 9 deletions examples/rag_vector_database/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import (
"os"

_ "github.com/joho/godotenv/autoload"
"github.com/neurocult/agency"
"github.com/weaviate/weaviate-go-client/v4/weaviate"
"github.com/weaviate/weaviate-go-client/v4/weaviate/graphql"
"github.com/weaviate/weaviate/entities/models"

"github.com/neurocult/agency"
"github.com/neurocult/agency/providers/openai"
)

Expand All @@ -36,7 +36,7 @@ func main() {
retrieve,
summarize,
voice,
).Execute(ctx, agency.UserMessage("programming"))
).Execute(ctx, agency.NewMessage(agency.UserRole, agency.TextKind, []byte("programming")))
if err != nil {
panic(err)
}
Expand All @@ -49,7 +49,7 @@ func main() {
// RAGoperation retrieves relevant objects from vector store and builds a text message to pass further to the process
func RAGoperation(client *weaviate.Client) *agency.Operation {
return agency.NewOperation(func(ctx context.Context, msg agency.Message, po *agency.OperationConfig) (agency.Message, error) {
input := msg.String()
input := string(msg.Content())

result, err := client.GraphQL().Get().
WithClassName("Records").
Expand All @@ -71,15 +71,16 @@ func RAGoperation(client *weaviate.Client) *agency.Operation {
for _, obj := range result.Data {
bb, err := json.Marshal(&obj)
if err != nil {
return agency.Message{}, err
return nil, err
}
content += string(bb)
}

return agency.Message{
Role: agency.AssistantRole,
Content: []byte(content),
}, nil
return agency.NewMessage(
agency.AssistantRole,
agency.TextKind,
[]byte(content),
), nil
})
}

Expand Down Expand Up @@ -125,7 +126,7 @@ func saveToDisk(msg agency.Message) error {
}
defer file.Close()

_, err = file.Write(msg.Content)
_, err = file.Write(msg.Content())
if err != nil {
return err
}
Expand Down
4 changes: 1 addition & 3 deletions examples/speech_to_text/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ func main() {
Model: goopenai.Whisper1,
}).Execute(
context.Background(),
agency.Message{
Content: data,
},
agency.NewMessage(agency.UserRole, agency.VoiceKind, data),
)

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions examples/speech_to_text_multi_model/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func main() {
}

ctx := context.Background()
speechMsg := agency.Message{Content: sound}
speechMsg := agency.NewMessage(agency.UserRole, agency.VoiceKind, sound)

_, err = agency.NewProcess(
hear,
Expand All @@ -64,6 +64,6 @@ func main() {
}

for _, msg := range saver {
fmt.Println(msg.String())
fmt.Println(string(msg.Content()))
}
}
4 changes: 2 additions & 2 deletions examples/speech_to_text_to_image/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func main() {
Model: goopenai.CreateImageModelDallE2,
ImageSize: goopenai.CreateImageSize256x256,
}),
).Execute(context.Background(), agency.Message{Content: data})
).Execute(context.Background(), agency.NewMessage(agency.UserRole, agency.VoiceKind, data))
if err != nil {
panic(err)
}
Expand All @@ -39,7 +39,7 @@ func main() {
}

func saveImgToDisk(msg agency.Message) error {
r := bytes.NewReader(msg.Content)
r := bytes.NewReader(msg.Content())

imgData, err := png.Decode(r)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions examples/text_to_image_dalle2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func main() {
Style: "vivid",
}).Execute(
context.Background(),
agency.UserMessage("Halloween night at a haunted museum"),
agency.NewMessage(agency.UserRole, agency.TextKind, []byte("Halloween night at a haunted museum")),
)
if err != nil {
panic(err)
Expand All @@ -40,7 +40,7 @@ func main() {
}

func saveToDisk(msg agency.Message) error {
r := bytes.NewReader(msg.Content)
r := bytes.NewReader(msg.Content())

// for dall-e-3 use third party libraries due to lack of webp support in go stdlib
imgData, format, err := image.Decode(r)
Expand Down
10 changes: 7 additions & 3 deletions examples/text_to_speech/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ import (
)

func main() {
input := agency.UserMessage(`
input := agency.NewMessage(

agency.UserRole,
agency.TextKind,
[]byte(`
One does not simply walk into Mordor.
Its black gates are guarded by more than just Orcs.
There is evil there that does not sleep, and the Great Eye is ever watchful.
`)
`))

msg, err := openai.New(openai.Params{Key: os.Getenv("OPENAI_API_KEY")}).
TextToSpeech(openai.TextToSpeechParams{
Expand All @@ -42,7 +46,7 @@ func saveToDisk(msg agency.Message) error {
}
defer file.Close()

_, err = file.Write(msg.Content)
_, err = file.Write(msg.Content())
if err != nil {
return err
}
Expand Down
Loading