From ce07e331e1bdfa453fbea2b0837a36bc2e59c111 Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Wed, 13 Mar 2024 16:00:39 +0800 Subject: [PATCH] Rewrite test (#14) --- ai_test.go | 132 ++++++++++++++++++++++++++++++++++++++++ chatgpt/chatgpt_test.go | 78 ------------------------ gemini/gemini_test.go | 86 -------------------------- 3 files changed, 132 insertions(+), 164 deletions(-) create mode 100644 ai_test.go delete mode 100644 chatgpt/chatgpt_test.go delete mode 100644 gemini/gemini_test.go diff --git a/ai_test.go b/ai_test.go new file mode 100644 index 0000000..2b5e88b --- /dev/null +++ b/ai_test.go @@ -0,0 +1,132 @@ +package ai_test + +import ( + "context" + "fmt" + "io" + "os" + "testing" + "time" + + "github.com/sunshineplan/ai" + "github.com/sunshineplan/ai/chatgpt" + "github.com/sunshineplan/ai/gemini" +) + +func init() { + if proxy := os.Getenv("AI_PROXY"); proxy != "" { + ai.SetProxy(proxy) + } +} + +func testChat(ai ai.AI, prompt string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + fmt.Println(prompt) + resp, err := ai.Chat(ctx, prompt) + if err != nil { + return err + } + fmt.Println(resp.Results()) + fmt.Println("---") + return nil +} + +func testChatStream(ai ai.AI, prompt string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + fmt.Println(prompt) + stream, err := ai.ChatStream(ctx, prompt) + if err != nil { + return err + } + defer stream.Close() + for { + resp, err := stream.Next() + if err != nil { + if err == io.EOF { + break + } + return err + } + fmt.Println(resp.Results()) + } + fmt.Println("---") + return nil +} + +func testChatSession(ai ai.AI) error { + s := ai.ChatSession() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + fmt.Println("Hello, I have 2 dogs in my house.") + resp, err := s.Chat(ctx, "Hello, I have 2 dogs in my house.") + if err != nil { + return err + } + fmt.Println(resp.Results()) + ctx, cancel = context.WithTimeout(context.Background(), time.Minute) + defer cancel() + fmt.Println("How many paws are in my house?") + stream, err := s.ChatStream(ctx, "How many paws are in my house?") + if err != nil { + return err + } + defer stream.Close() + for { + resp, err := stream.Next() + if err != nil { + if err == io.EOF { + break + } + return err + } + fmt.Println(resp.Results()) + } + fmt.Println("---") + fmt.Println("History") + for _, i := range s.History() { + fmt.Println(i.Role, ":", i.Content) + } + fmt.Println("---") + return nil +} + +func TestGemini(t *testing.T) { + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + return + } + gemini, err := gemini.New(apiKey) + if err != nil { + t.Fatal(err) + } + defer gemini.Close() + if err := testChat(gemini, "Who are you?"); err != nil { + t.Error(err) + } + if err := testChatStream(gemini, "Who am I?"); err != nil { + t.Error(err) + } + if err := testChatSession(gemini); err != nil { + t.Error(err) + } +} + +func TestChatGPT(t *testing.T) { + apiKey := os.Getenv("CHATGPT_API_KEY") + if apiKey == "" { + return + } + chatgpt := chatgpt.New(apiKey) + defer chatgpt.Close() + if err := testChat(chatgpt, "Who are you?"); err != nil { + t.Error(err) + } + if err := testChatStream(chatgpt, "Who am I?"); err != nil { + t.Error(err) + } + if err := testChatSession(chatgpt); err != nil { + t.Error(err) + } +} diff --git a/chatgpt/chatgpt_test.go b/chatgpt/chatgpt_test.go deleted file mode 100644 index 22e9464..0000000 --- a/chatgpt/chatgpt_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package chatgpt - -import ( - "context" - "fmt" - "io" - "os" - "testing" - "time" -) - -func TestChatGPT(t *testing.T) { - apiKey := os.Getenv("CHATGPT_API_KEY") - if apiKey == "" { - return - } - chatgpt := New(apiKey) - defer chatgpt.Close() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - fmt.Println("Who are you?") - resp, err := chatgpt.Chat(ctx, "Who are you?") - if err != nil { - t.Fatal(err) - } - fmt.Println(resp.Results()) - fmt.Println("---") - fmt.Println("Who am I?") - ctx, cancel = context.WithTimeout(context.Background(), time.Minute) - defer cancel() - stream, err := chatgpt.ChatStream(ctx, "Who am I?") - if err != nil { - t.Fatal(err) - } - defer stream.Close() - for { - resp, err := stream.Next() - if err != nil { - if err == io.EOF { - break - } - t.Fatal(err) - } - fmt.Println(resp.Results()) - } - fmt.Println("---") - s := chatgpt.ChatSession() - ctx, cancel = context.WithTimeout(context.Background(), time.Minute) - defer cancel() - fmt.Println("Hello, I have 2 dogs in my house.") - resp, err = s.Chat(ctx, "Hello, I have 2 dogs in my house.") - if err != nil { - t.Fatal(err) - } - fmt.Println(resp.Results()) - ctx, cancel = context.WithTimeout(context.Background(), time.Minute) - defer cancel() - fmt.Println("How many paws are in my house?") - stream, err = s.ChatStream(ctx, "How many paws are in my house?") - if err != nil { - t.Fatal(err) - } - defer stream.Close() - for { - resp, err := stream.Next() - if err != nil { - if err == io.EOF { - break - } - t.Fatal(err) - } - fmt.Println(resp.Results()) - } - fmt.Println("---") - for _, i := range s.History() { - fmt.Println(i.Role, ":", i.Content) - } -} diff --git a/gemini/gemini_test.go b/gemini/gemini_test.go deleted file mode 100644 index b519f54..0000000 --- a/gemini/gemini_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package gemini - -import ( - "context" - "fmt" - "io" - "os" - "testing" - "time" - - "github.com/sunshineplan/ai" -) - -func TestGemini(t *testing.T) { - if proxy := os.Getenv("GEMINI_PROXY"); proxy != "" { - ai.SetProxy(proxy) - } - apiKey := os.Getenv("GEMINI_API_KEY") - if apiKey == "" { - return - } - gemini, err := New(apiKey) - if err != nil { - t.Fatal(err) - } - defer gemini.Close() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - fmt.Println("Who are you?") - resp, err := gemini.Chat(ctx, "Who are you?") - if err != nil { - t.Fatal(err) - } - fmt.Println(resp.Results()) - fmt.Println("---") - fmt.Println("Who am I?") - ctx, cancel = context.WithTimeout(context.Background(), time.Minute) - defer cancel() - stream, err := gemini.ChatStream(ctx, "Who am I?") - if err != nil { - t.Fatal(err) - } - defer stream.Close() - for { - resp, err := stream.Next() - if err != nil { - if err == io.EOF { - break - } - t.Fatal(err) - } - fmt.Println(resp.Results()) - } - fmt.Println("---") - s := gemini.ChatSession() - ctx, cancel = context.WithTimeout(context.Background(), time.Minute) - defer cancel() - fmt.Println("Hello, I have 2 dogs in my house.") - resp, err = s.Chat(ctx, "Hello, I have 2 dogs in my house.") - if err != nil { - t.Fatal(err) - } - fmt.Println(resp.Results()) - ctx, cancel = context.WithTimeout(context.Background(), time.Minute) - defer cancel() - fmt.Println("How many paws are in my house?") - stream, err = s.ChatStream(ctx, "How many paws are in my house?") - if err != nil { - t.Fatal(err) - } - defer stream.Close() - for { - resp, err := stream.Next() - if err != nil { - if err == io.EOF { - break - } - t.Fatal(err) - } - fmt.Println(resp.Results()) - } - fmt.Println("---") - for _, i := range s.History() { - fmt.Println(i.Role, ":", i.Content) - } -}