From 3a8cd742c1ad311a59ed6ce5ded6ddbdddfc22c6 Mon Sep 17 00:00:00 2001 From: Alex Guo Date: Tue, 4 Jan 2022 18:05:45 -0500 Subject: [PATCH 1/2] added check api key function + tests --- client.go | 70 ++++++++++++++++++++++++++++++++++++++++++++----- client_test.go | 43 +++++++++++++++++++----------- example/main.go | 6 ++++- 3 files changed, 95 insertions(+), 24 deletions(-) diff --git a/client.go b/client.go index cfc0997..a03cbfa 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package cohere import ( "bytes" "encoding/json" + "errors" "io" "net/http" "path" @@ -16,22 +17,42 @@ type Client struct { } const ( - endpointGenerate = "/generate" - endpointSimilarity = "/similarity" - endpointChooseBest = "/choose-best" - endpointEmbed = "/embed" - endpointLikelihood = "/likelihood" + endpointGenerate = "generate" + endpointSimilarity = "similarity" + endpointChooseBest = "choose-best" + endpointEmbed = "embed" + endpointLikelihood = "likelihood" + + endpointCheckAPIKey = "check-api-key" ) +type CheckAPIKeyResponse struct { + Valid bool +} + // Public functions -func CreateClient(apiKey string) *Client { - return &Client{ +func CreateClient(apiKey string) (*Client, error) { + client := &Client{ APIKey: apiKey, BaseURL: "https://api.cohere.ai/", Client: *http.DefaultClient, Version: "2021-11-08", } + + res, err := client.CheckAPIKey() + if err != nil { + return nil, err + } + + ret := &CheckAPIKeyResponse{} + if err := json.Unmarshal(res, ret); err != nil { + return nil, err + } + if !ret.Valid { + return nil, errors.New("invalid api key") + } + return client, nil } // Client methods @@ -50,6 +71,7 @@ func (c *Client) post(model string, endpoint string, body interface{}) ([]byte, req.Header.Set("Authorization", "BEARER "+c.APIKey) req.Header.Set("Content-Type", "application/json") + req.Header.Set("Request-Source", "go-sdk") if len(c.Version) > 0 { req.Header.Set("Cohere-Version", c.Version) } @@ -75,6 +97,40 @@ func (c *Client) post(model string, endpoint string, body interface{}) ([]byte, return buf, nil } +func (c *Client) CheckAPIKey() ([]byte, error) { + url := c.BaseURL + endpointCheckAPIKey + req, err := http.NewRequest("POST", url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "BEARER "+c.APIKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Request-Source", "go-sdk") + if len(c.Version) > 0 { + req.Header.Set("Cohere-Version", c.Version) + } + res, err := c.Client.Do(req) + if err != nil { + return nil, err + } + + defer res.Body.Close() + buf, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + if res.StatusCode != 200 { + e := &APIError{} + if err := json.Unmarshal(buf, e); err != nil { + return nil, err + } + e.StatusCode = res.StatusCode + return nil, e + } + return buf, nil +} + // Generates realistic text conditioned on a given input. // See: https://docs.cohere.ai/generate-reference // Returns a GenerateResponse object. diff --git a/client_test.go b/client_test.go index 4945e59..bacee3e 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,7 @@ package cohere import ( - "errors" "os" - "reflect" "testing" ) @@ -17,22 +15,23 @@ func init() { func TestErrors(t *testing.T) { t.Run("Invalid api key", func(t *testing.T) { - co := CreateClient("") - _, err := co.Generate("small", GenerateOptions{ - Prompt: "", - MaxTokens: 10, - Temperature: 0.75, - }) + co, err := CreateClient("") + if co != nil { + t.Error("expected nil client, got client") + } if err == nil { - t.Errorf("expected error, got nil") - } else if !errors.Is(err, &APIError{}) { - t.Errorf("expected ApiError, got %s", reflect.TypeOf(err)) + t.Error("expected error, got nil") + } else if err.Error() != "invalid api key" { + t.Errorf("expected invalid api key, got %s", err.Error()) } }) } func TestGenerate(t *testing.T) { - co := CreateClient(apiKey) + co, err := CreateClient(apiKey) + if err != nil { + t.Error(err) + } t.Run("Generate basic", func(t *testing.T) { _, err := co.Generate("medium", GenerateOptions{ @@ -62,7 +61,10 @@ func TestGenerate(t *testing.T) { } func TestSimilarity(t *testing.T) { - co := CreateClient(apiKey) + co, err := CreateClient(apiKey) + if err != nil { + t.Error(err) + } t.Run("Similarity", func(t *testing.T) { _, err := co.Similarity("small", SimilarityOptions{ @@ -76,7 +78,10 @@ func TestSimilarity(t *testing.T) { } func TestChooseBest(t *testing.T) { - co := CreateClient(apiKey) + co, err := CreateClient(apiKey) + if err != nil { + t.Error(err) + } t.Run("ChooseBest", func(t *testing.T) { _, err := co.ChooseBest("small", ChooseBestOptions{ @@ -92,7 +97,10 @@ func TestChooseBest(t *testing.T) { } func TestEmbed(t *testing.T) { - co := CreateClient(apiKey) + co, err := CreateClient(apiKey) + if err != nil { + t.Error(err) + } t.Run("Embed", func(t *testing.T) { texts := []string{"hello", "goodbye"} @@ -108,7 +116,10 @@ func TestEmbed(t *testing.T) { } func TestLikelihood(t *testing.T) { - co := CreateClient(apiKey) + co, err := CreateClient(apiKey) + if err != nil { + t.Error(err) + } t.Run("Likelihood", func(t *testing.T) { text := "so I crept up the basement stairs and BOOOO!" diff --git a/example/main.go b/example/main.go index 8e37c69..f99cdb1 100644 --- a/example/main.go +++ b/example/main.go @@ -14,7 +14,11 @@ func main() { os.Exit(1) } - co := cohere.CreateClient(apiKey) + co, err := cohere.CreateClient(apiKey) + if err != nil { + fmt.Println("invalid api key") + return + } prompt := "What is your" res, err := co.Generate("medium", cohere.GenerateOptions{ From cad28ef104fc01990ad506cae3c558d6c903d03c Mon Sep 17 00:00:00 2001 From: Alex Guo Date: Wed, 5 Jan 2022 10:30:18 -0500 Subject: [PATCH 2/2] print error --- example/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/main.go b/example/main.go index f99cdb1..46b2c97 100644 --- a/example/main.go +++ b/example/main.go @@ -16,7 +16,7 @@ func main() { co, err := cohere.CreateClient(apiKey) if err != nil { - fmt.Println("invalid api key") + fmt.Println(err) return }