Skip to content

Commit

Permalink
fix(t2t): multiple function calls at the same time now working
Browse files Browse the repository at this point in the history
  • Loading branch information
emil14 committed Apr 3, 2024
1 parent a12c1ff commit ecccd2b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 42 deletions.
59 changes: 38 additions & 21 deletions examples/func_call/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bufio"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -30,8 +29,8 @@ func main() {
},
// function with parameters
{
Name: "SumNumbers",
Description: "Sum given numbers when asked",
Name: "ChangeNumbers",
Description: "Change given numbers when asked",
Parameters: &jsonschema.Definition{
Type: "object",
Properties: map[string]jsonschema.Definition{
Expand All @@ -49,28 +48,46 @@ func main() {
},
},
}).
SetPrompt("You are helpful assistant.")
SetPrompt(`
Answer questions about meaning of life and summing numbers.
Always use GetMeaningOfLife and ChangeNumbers functions results as answers.
Examples:
- User: what is the meaning of life?
- Assistant: 42
- User: 1+1
- Assistant: 20
- User: 1+1 and what is the meaning of life?
- Assistant: 20 and 42`)

messages := []agency.Message{}
reader := bufio.NewReader(os.Stdin)
ctx := context.Background()

for {
fmt.Print("User: ")

text, err := reader.ReadString('\n')
if err != nil {
panic(err)
}

input := agency.UserMessage(text)
answer, err := t2tOp.SetMessages(messages).Execute(ctx, input)
if err != nil {
panic(err)
}
// test for first function call
answer, err := t2tOp.Execute(
ctx,
agency.UserMessage("what is the meaning of life?"),
)
if err != nil {
panic(err)
}
fmt.Println(answer)

fmt.Println("Assistant: ", answer)
// test for second function call
answer, err = t2tOp.Execute(
ctx,
agency.UserMessage("1+1?"),
)
if err != nil {
panic(err)
}
fmt.Println(answer)

messages = append(messages, input, answer)
// test for both function calls at the same time
answer, err = t2tOp.Execute(
ctx,
agency.UserMessage("1+1 and what is the meaning of life?"),
)
if err != nil {
panic(err)
}
fmt.Println(answer)
}
43 changes: 22 additions & 21 deletions providers/openai/text_to_text.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,29 @@ func (p Provider) TextToText(params TextToTextParams) *agency.Operation {

openAIMessages = append(openAIMessages, firstChoice.Message)

firstToolCall := firstChoice.Message.ToolCalls[0]
funcToCall := getFuncDefByName(params.FuncDefs, firstToolCall.Function.Name)
if funcToCall == nil {
return agency.Message{}, errors.New("function not found")
for _, toolCall := range firstChoice.Message.ToolCalls {
funcToCall := getFuncDefByName(params.FuncDefs, toolCall.Function.Name)
if funcToCall == nil {
return agency.Message{}, errors.New("function not found")
}

funcResult, err := funcToCall.Body(ctx, []byte(toolCall.Function.Arguments))
if err != nil {
return agency.Message{}, fmt.Errorf("call function %s: %w", funcToCall.Name, err)
}

bb, err := json.Marshal(funcResult)
if err != nil {
return agency.Message{}, fmt.Errorf("marshal function result: %w", err)
}

openAIMessages = append(openAIMessages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleTool,
Content: string(bb),
Name: toolCall.Function.Name,
ToolCallID: toolCall.ID,
})
}

funcResult, err := funcToCall.Body(ctx, []byte(firstToolCall.Function.Arguments))
if err != nil {
return agency.Message{}, fmt.Errorf("call function %s: %w", funcToCall.Name, err)
}

bb, err := json.Marshal(funcResult)
if err != nil {
return agency.Message{}, fmt.Errorf("marshal function result: %w", err)
}

openAIMessages = append(openAIMessages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleTool,
Content: string(bb),
Name: firstToolCall.Function.Name,
ToolCallID: firstToolCall.ID,
})
}
},
)
Expand Down

0 comments on commit ecccd2b

Please sign in to comment.