Commit

Improve close (#12)

sunshineplan authored Mar 13, 2024
1 parent f3dd064 commit 38ac197

Showing 4 changed files with 74 additions and 27 deletions.
3 changes: 3 additions & 0 deletions ai.go
@@ -2,10 +2,13 @@ package ai
 
 import (
 	"context"
+	"errors"
 	"net/http"
 	"net/url"
 )
 
+var ErrAIClosed = errors.New("AI client is nil or already closed")
+
 type AI interface {
 	Limiter
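With ErrAIClosed exported from the ai package, callers can distinguish a closed client from any other failure by comparing against the sentinel. A minimal caller-side sketch, assuming the module path github.com/sunshineplan/ai and that Chat and Close belong to the ai.AI interface (the token is a placeholder):

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/sunshineplan/ai"         // assumed module path
	"github.com/sunshineplan/ai/chatgpt" // assumed module path
)

func main() {
	client := chatgpt.New("sk-placeholder")
	client.Close() // drops the wrapped *openai.Client

	// Any chat call after Close should surface the sentinel error.
	if _, err := client.Chat(context.Background(), "hello"); errors.Is(err, ai.ErrAIClosed) {
		fmt.Println("client is closed:", err)
	}
}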
63 changes: 44 additions & 19 deletions chatgpt/chatgpt.go
@@ -15,7 +15,7 @@ const defaultModel = openai.GPT3Dot5Turbo
 var _ ai.AI = new(ChatGPT)
 
 type ChatGPT struct {
-	*openai.Client
+	c           *openai.Client
 	model       string
 	maxTokens   *int32
 	temperature *float32
@@ -30,7 +30,10 @@ func New(authToken string) ai.AI {
 }
 
 func NewWithClient(client *openai.Client) ai.AI {
-	return &ChatGPT{Client: client, model: defaultModel}
+	if client == nil {
+		panic("cannot create AI from nil client")
+	}
+	return &ChatGPT{c: client, model: defaultModel}
 }
 
 func (chatgpt *ChatGPT) SetLimit(limit rate.Limit) {
@@ -74,12 +77,23 @@ func (resp *ChatResponse[Response]) Results() (res []string) {
 	return
 }
 
-func (ai *ChatGPT) createRequest(history []openai.ChatCompletionMessage, messages ...string) (req openai.ChatCompletionRequest) {
+func (resp *ChatResponse[Response]) String() string {
+	if res := resp.Results(); len(res) > 0 {
+		return res[0]
+	}
+	return ""
+}
+
+func (ai *ChatGPT) createRequest(
+	stream bool,
+	history []openai.ChatCompletionMessage,
+	messages ...string,
+) (req openai.ChatCompletionRequest) {
 	req.Model = ai.model
 	if ai.maxTokens != nil {
 		req.MaxTokens = int(*ai.maxTokens)
 	}
-	if ai.count != nil {
+	if !stream && ai.count != nil {
 		req.N = int(*ai.count)
 	}
 	if ai.temperature != nil {
@@ -98,15 +112,19 @@ func (ai *ChatGPT) createRequest(history []openai.ChatCompletionMessage, messages ...string) (req openai.ChatCompletionRequest) {
 	return
 }
 
-func (ai *ChatGPT) chat(
+func (chatgpt *ChatGPT) chat(
 	ctx context.Context,
 	history []openai.ChatCompletionMessage,
 	messages ...string,
 ) (resp openai.ChatCompletionResponse, err error) {
-	if err = ai.wait(ctx); err != nil {
+	if chatgpt.c == nil {
+		err = ai.ErrAIClosed
+		return
+	}
+	if err = chatgpt.wait(ctx); err != nil {
 		return
 	}
-	return ai.CreateChatCompletion(ctx, ai.createRequest(history, messages...))
+	return chatgpt.c.CreateChatCompletion(ctx, chatgpt.createRequest(false, history, messages...))
 }
 
 func (ai *ChatGPT) Chat(ctx context.Context, messages ...string) (ai.ChatResponse, error) {
@@ -120,40 +138,47 @@ func (ai *ChatGPT) Chat(ctx context.Context, messages ...string) (ai.ChatResponse, error) {
 var _ ai.ChatStream = new(ChatStream)
 
 type ChatStream struct {
-	*openai.ChatCompletionStream
-	cs      *ChatSession
-	content string
+	sr     *openai.ChatCompletionStream
+	cs     *ChatSession
+	merged string
 }
 
 func (stream *ChatStream) Next() (ai.ChatResponse, error) {
-	resp, err := stream.Recv()
+	resp, err := stream.sr.Recv()
 	if err != nil {
 		if err == io.EOF {
 			if stream.cs != nil {
 				stream.cs.History = append(stream.cs.History, openai.ChatCompletionMessage{
-					Role: openai.ChatMessageRoleAssistant, Content: stream.content})
+					Role: openai.ChatMessageRoleAssistant, Content: stream.merged})
 			}
 		}
-		stream.content = ""
+		stream.merged = ""
 		return nil, err
 	}
 	if stream.cs != nil {
-		stream.content += resp.Choices[0].Delta.Content
+		stream.merged += resp.Choices[0].Delta.Content
 	}
 	return &ChatResponse[openai.ChatCompletionStreamResponse]{resp}, nil
 }
 
-func (ai *ChatGPT) chatStream(
+func (stream *ChatStream) Close() error {
+	return stream.sr.Close()
+}
+
+func (chatgpt *ChatGPT) chatStream(
 	ctx context.Context,
 	history []openai.ChatCompletionMessage,
 	messages ...string,
 ) (*openai.ChatCompletionStream, error) {
-	if err := ai.wait(ctx); err != nil {
+	if chatgpt.c == nil {
+		return nil, ai.ErrAIClosed
+	}
+	if err := chatgpt.wait(ctx); err != nil {
 		return nil, err
 	}
-	req := ai.createRequest(history, messages...)
+	req := chatgpt.createRequest(true, history, messages...)
 	req.Stream = true
-	return ai.CreateChatCompletionStream(ctx, req)
+	return chatgpt.c.CreateChatCompletionStream(ctx, req)
 }
 
 func (ai *ChatGPT) ChatStream(ctx context.Context, messages ...string) (ai.ChatStream, error) {
@@ -189,10 +214,10 @@ func (session *ChatSession) ChatStream(ctx context.Context, messages ...string) (ai.ChatStream, error) {
 }
 
 func (ai *ChatGPT) ChatSession() ai.Chatbot {
-	ai.count = nil
 	return &ChatSession{ai: ai}
 }
 
 func (ai *ChatGPT) Close() error {
+	ai.c = nil
 	return nil
 }
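Taken together, these changes make Close observable on the ChatGPT side: the wrapped *openai.Client is now unexported, Close drops the reference, and both chat paths check for nil before touching the rate limiter. Note also that createRequest now takes a stream flag so the N (choice count) parameter is only set on non-streaming requests. A rough consumption sketch under those semantics (client, ctx, and the surrounding function are assumed; error handling abbreviated):

stream, err := client.ChatStream(ctx, "Tell me a story")
if err != nil {
	return err // yields ai.ErrAIClosed once Close has been called
}
defer stream.Close() // now closes the underlying openai stream

for {
	resp, err := stream.Next()
	if err != nil {
		if err == io.EOF {
			break // stream done; a session, if any, has recorded the merged reply
		}
		return err
	}
	fmt.Print(resp) // the new String() method prints the first result
}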
32 changes: 24 additions & 8 deletions gemini/gemini.go
@@ -2,6 +2,7 @@ package gemini
 
 import (
 	"context"
+	"errors"
 	"io"
 	"strings"
@@ -18,7 +19,7 @@ const defaultModel = "gemini-1.0-pro"
 var _ ai.AI = new(Gemini)
 
 type Gemini struct {
-	*genai.Client
+	c      *genai.Client
 	model  *genai.GenerativeModel
 	config genai.GenerationConfig
@@ -34,7 +35,7 @@ func New(apiKey string) (ai.AI, error) {
 }
 
 func NewWithClient(client *genai.Client) ai.AI {
-	return &Gemini{Client: client, model: client.GenerativeModel(defaultModel)}
+	return &Gemini{c: client, model: client.GenerativeModel(defaultModel)}
 }
 
 func (gemini *Gemini) SetLimit(limit rate.Limit) {
@@ -49,7 +50,7 @@ func (ai *Gemini) wait(ctx context.Context) error {
 }
 
 func (ai *Gemini) SetModel(model string) {
-	ai.model = ai.GenerativeModel(model)
+	ai.model = ai.c.GenerativeModel(model)
 	ai.model.GenerationConfig = ai.config
 }
 
@@ -79,6 +80,13 @@ func (resp *ChatResponse) Results() (res []string) {
 	return
 }
 
+func (resp *ChatResponse) String() string {
+	if res := resp.Results(); len(res) > 0 {
+		return res[0]
+	}
+	return ""
+}
+
 func texts2parts(texts []string) (parts []genai.Part) {
 	for _, i := range texts {
 		parts = append(parts, genai.Text(i))
@@ -100,11 +108,14 @@ func (ai *Gemini) Chat(ctx context.Context, parts ...string) (ai.ChatResponse, error) {
 
 var _ ai.ChatStream = new(ChatStream)
 
 type ChatStream struct {
-	*genai.GenerateContentResponseIterator
+	iter *genai.GenerateContentResponseIterator
 }
 
 func (stream *ChatStream) Next() (ai.ChatResponse, error) {
-	resp, err := stream.GenerateContentResponseIterator.Next()
+	if stream.iter == nil {
+		return nil, errors.New("stream iterator is nil or already closed")
+	}
+	resp, err := stream.iter.Next()
 	if err != nil {
 		if err == iterator.Done {
 			return nil, io.EOF
@@ -115,6 +126,7 @@ func (stream *ChatStream) Next() (ai.ChatResponse, error) {
 }
 
 func (stream *ChatStream) Close() error {
+	stream.iter = nil
 	return nil
 }
 
@@ -129,14 +141,14 @@ var _ ai.Chatbot = new(ChatSession)
 
 type ChatSession struct {
 	ai *Gemini
-	*genai.ChatSession
+	cs *genai.ChatSession
 }
 
 func (session *ChatSession) Chat(ctx context.Context, parts ...string) (ai.ChatResponse, error) {
 	if err := session.ai.wait(ctx); err != nil {
 		return nil, err
 	}
-	resp, err := session.SendMessage(ctx, texts2parts(parts)...)
+	resp, err := session.cs.SendMessage(ctx, texts2parts(parts)...)
 	if err != nil {
 		return nil, err
 	}
@@ -147,9 +159,13 @@ func (session *ChatSession) ChatStream(ctx context.Context, parts ...string) (ai.ChatStream, error) {
 	if err := session.ai.wait(ctx); err != nil {
 		return nil, err
 	}
-	return &ChatStream{session.SendMessageStream(ctx, texts2parts(parts)...)}, nil
+	return &ChatStream{session.cs.SendMessageStream(ctx, texts2parts(parts)...)}, nil
 }
 
 func (ai *Gemini) ChatSession() ai.Chatbot {
	return &ChatSession{ai, ai.model.StartChat()}
 }
+
+func (ai *Gemini) Close() error {
+	return ai.c.Close()
+}
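The Gemini side gains real teardown: Close now closes the underlying *genai.Client, and a closed ChatStream fails with a descriptive error instead of dereferencing a nil iterator. A small sketch of those semantics (the demo function and its inputs are illustrative only):

func demo(ctx context.Context, client ai.AI) error {
	defer client.Close() // closes the wrapped *genai.Client

	stream, err := client.ChatStream(ctx, "Who are you?")
	if err != nil {
		return err
	}
	stream.Close() // drops the iterator reference

	// Next after Close returns the "stream iterator is nil or already
	// closed" error rather than panicking.
	_, err = stream.Next()
	return err
}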
3 changes: 3 additions & 0 deletions gemini/gemini_test.go
@@ -23,6 +23,7 @@ func TestGemini(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer gemini.Close()
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 	defer cancel()
 	fmt.Println("Who are you?")
@@ -39,6 +40,7 @@ func TestGemini(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer stream.Close()
 	for {
 		resp, err := stream.Next()
 		if err != nil {
@@ -66,6 +68,7 @@ func TestGemini(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer stream.Close()
 	for {
 		resp, err := stream.Next()
 		if err != nil {
