Skip to content

Commit

Permalink
ChatSession interface (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Mar 13, 2024
1 parent 38ac197 commit a568231
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 28 deletions.
12 changes: 11 additions & 1 deletion ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type AI interface {
Model

Chatbot
ChatSession() Chatbot
ChatSession() ChatSession

Close() error
}
Expand All @@ -33,6 +33,16 @@ type Chatbot interface {
ChatStream(context.Context, ...string) (ChatStream, error)
}

type Message struct {
Content string
Role string
}

type ChatSession interface {
Chatbot
History() []Message
}

type ChatStream interface {
Next() (ChatResponse, error)
Close() error
Expand Down
41 changes: 30 additions & 11 deletions chatgpt/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ func (resp *ChatResponse[Response]) String() string {
}

func (ai *ChatGPT) createRequest(
stream bool,
one bool,
history []openai.ChatCompletionMessage,
messages ...string,
) (req openai.ChatCompletionRequest) {
req.Model = ai.model
if ai.maxTokens != nil {
req.MaxTokens = int(*ai.maxTokens)
}
if !stream && ai.count != nil {
if !one && ai.count != nil {
req.N = int(*ai.count)
}
if ai.temperature != nil {
Expand All @@ -114,6 +114,7 @@ func (ai *ChatGPT) createRequest(

func (chatgpt *ChatGPT) chat(
ctx context.Context,
session bool,
history []openai.ChatCompletionMessage,
messages ...string,
) (resp openai.ChatCompletionResponse, err error) {
Expand All @@ -124,11 +125,11 @@ func (chatgpt *ChatGPT) chat(
if err = chatgpt.wait(ctx); err != nil {
return
}
return chatgpt.c.CreateChatCompletion(ctx, chatgpt.createRequest(false, history, messages...))
return chatgpt.c.CreateChatCompletion(ctx, chatgpt.createRequest(session, history, messages...))
}

func (ai *ChatGPT) Chat(ctx context.Context, messages ...string) (ai.ChatResponse, error) {
resp, err := ai.chat(ctx, nil, messages...)
resp, err := ai.chat(ctx, false, nil, messages...)
if err != nil {
return nil, err
}
Expand All @@ -148,7 +149,7 @@ func (stream *ChatStream) Next() (ai.ChatResponse, error) {
if err != nil {
if err == io.EOF {
if stream.cs != nil {
stream.cs.History = append(stream.cs.History, openai.ChatCompletionMessage{
stream.cs.history = append(stream.cs.history, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant, Content: stream.merged})
}
}
Expand Down Expand Up @@ -189,31 +190,49 @@ func (ai *ChatGPT) ChatStream(ctx context.Context, messages ...string) (ai.ChatS
return &ChatStream{stream, nil, ""}, nil
}

var _ ai.Chatbot = new(ChatSession)
var _ ai.ChatSession = new(ChatSession)

type ChatSession struct {
ai *ChatGPT
History []openai.ChatCompletionMessage
history []openai.ChatCompletionMessage
}

func addToHistory(history *[]openai.ChatCompletionMessage, role string, messages ...string) {
for _, i := range messages {
*history = append(
*history,
openai.ChatCompletionMessage{Role: role, Content: i},
)
}
}

func (session *ChatSession) Chat(ctx context.Context, messages ...string) (ai.ChatResponse, error) {
resp, err := session.ai.chat(ctx, session.History, messages...)
resp, err := session.ai.chat(ctx, true, session.history, messages...)
if err != nil {
return nil, err
}
session.History = append(session.History, resp.Choices[0].Message)
addToHistory(&session.history, openai.ChatMessageRoleUser, messages...)
session.history = append(session.history, resp.Choices[0].Message)
return &ChatResponse[openai.ChatCompletionResponse]{resp}, nil
}

func (session *ChatSession) ChatStream(ctx context.Context, messages ...string) (ai.ChatStream, error) {
stream, err := session.ai.chatStream(ctx, session.History, messages...)
stream, err := session.ai.chatStream(ctx, session.history, messages...)
if err != nil {
return nil, err
}
addToHistory(&session.history, openai.ChatMessageRoleUser, messages...)
return &ChatStream{stream, session, ""}, nil
}

func (ai *ChatGPT) ChatSession() ai.Chatbot {
func (session *ChatSession) History() (history []ai.Message) {
for _, i := range session.history {
history = append(history, ai.Message{Content: i.Content, Role: i.Role})
}
return
}

func (ai *ChatGPT) ChatSession() ai.ChatSession {
return &ChatSession{ai: ai}
}

Expand Down
78 changes: 78 additions & 0 deletions chatgpt/chatgpt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package chatgpt

import (
"context"
"fmt"
"io"
"os"
"testing"
"time"
)

func TestChatGPT(t *testing.T) {
apiKey := os.Getenv("CHATGPT_API_KEY")
if apiKey == "" {
return
}
chatgpt := New(apiKey)
defer chatgpt.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println("Who are you?")
resp, err := chatgpt.Chat(ctx, "Who are you?")
if err != nil {
t.Fatal(err)
}
fmt.Println(resp.Results())
fmt.Println("---")
fmt.Println("Who am I?")
ctx, cancel = context.WithTimeout(context.Background(), time.Minute)
defer cancel()
stream, err := chatgpt.ChatStream(ctx, "Who am I?")
if err != nil {
t.Fatal(err)
}
defer stream.Close()
for {
resp, err := stream.Next()
if err != nil {
if err == io.EOF {
break
}
t.Fatal(err)
}
fmt.Println(resp.Results())
}
fmt.Println("---")
s := chatgpt.ChatSession()
ctx, cancel = context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println("Hello, I have 2 dogs in my house.")
resp, err = s.Chat(ctx, "Hello, I have 2 dogs in my house.")
if err != nil {
t.Fatal(err)
}
fmt.Println(resp.Results())
ctx, cancel = context.WithTimeout(context.Background(), time.Minute)
defer cancel()
fmt.Println("How many paws are in my house?")
stream, err = s.ChatStream(ctx, "How many paws are in my house?")
if err != nil {
t.Fatal(err)
}
defer stream.Close()
for {
resp, err := stream.Next()
if err != nil {
if err == io.EOF {
break
}
t.Fatal(err)
}
fmt.Println(resp.Results())
}
fmt.Println("---")
for _, i := range s.History() {
fmt.Println(i.Role, ":", i.Content)
}
}
42 changes: 26 additions & 16 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ func (ai *Gemini) SetMaxTokens(i int32) { ai.config.SetMaxOutputTokens(i) }
func (ai *Gemini) SetTemperature(f float32) { ai.config.SetTemperature(f) }
func (ai *Gemini) SetTopP(f float32) { ai.config.SetTopP(f) }

func texts2parts(texts []string) (parts []genai.Part) {
for _, i := range texts {
parts = append(parts, genai.Text(i))
}
return
}

func parts2texts(parts []genai.Part) (texts []string) {
for _, i := range parts {
if text, ok := i.(genai.Text); ok {
texts = append(texts, string(text))
}
}
return
}

var _ ai.ChatResponse = new(ChatResponse)

type ChatResponse struct {
Expand All @@ -68,13 +84,7 @@ type ChatResponse struct {
func (resp *ChatResponse) Results() (res []string) {
for _, i := range resp.Candidates {
if i.Content != nil {
var parts []string
for _, part := range i.Content.Parts {
if text, ok := part.(genai.Text); ok {
parts = append(parts, string(text))
}
}
res = append(res, strings.Join(parts, "\n"))
res = append(res, strings.Join(parts2texts(i.Content.Parts), "\n"))
}
}
return
Expand All @@ -87,13 +97,6 @@ func (resp *ChatResponse) String() string {
return ""
}

func texts2parts(texts []string) (parts []genai.Part) {
for _, i := range texts {
parts = append(parts, genai.Text(i))
}
return
}

func (ai *Gemini) Chat(ctx context.Context, parts ...string) (ai.ChatResponse, error) {
if err := ai.wait(ctx); err != nil {
return nil, err
Expand Down Expand Up @@ -137,7 +140,7 @@ func (ai *Gemini) ChatStream(ctx context.Context, parts ...string) (ai.ChatStrea
return &ChatStream{ai.model.GenerateContentStream(ctx, texts2parts(parts)...)}, nil
}

var _ ai.Chatbot = new(ChatSession)
var _ ai.ChatSession = new(ChatSession)

type ChatSession struct {
ai *Gemini
Expand All @@ -162,7 +165,14 @@ func (session *ChatSession) ChatStream(ctx context.Context, parts ...string) (ai
return &ChatStream{session.cs.SendMessageStream(ctx, texts2parts(parts)...)}, nil
}

func (ai *Gemini) ChatSession() ai.Chatbot {
func (session *ChatSession) History() (history []ai.Message) {
for _, i := range session.cs.History {
history = append(history, ai.Message{Content: strings.Join(parts2texts(i.Parts), "\n"), Role: i.Role})
}
return
}

func (ai *Gemini) ChatSession() ai.ChatSession {
return &ChatSession{ai, ai.model.StartChat()}
}

Expand Down
4 changes: 4 additions & 0 deletions gemini/gemini_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,8 @@ func TestGemini(t *testing.T) {
}
fmt.Println(resp.Results())
}
fmt.Println("---")
for _, i := range s.History() {
fmt.Println(i.Role, ":", i.Content)
}
}

0 comments on commit a568231

Please sign in to comment.