Skip to content

Commit

Permalink
Merge pull request #7 from cohere-ai/check-api-key
Browse files Browse the repository at this point in the history
check api key on client init
  • Loading branch information
alexguo247 authored Jan 5, 2022
2 parents 0447882 + cad28ef commit fcd44a7
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 22 deletions.
68 changes: 62 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cohere
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"path"
Expand All @@ -16,21 +17,41 @@ type Client struct {
}

const (
endpointGenerate = "/generate"
endpointChooseBest = "/choose-best"
endpointEmbed = "/embed"
endpointLikelihood = "/likelihood"
endpointGenerate = "generate"
endpointChooseBest = "choose-best"
endpointEmbed = "embed"
endpointLikelihood = "likelihood"

endpointCheckAPIKey = "check-api-key"
)

type CheckAPIKeyResponse struct {
Valid bool
}

// Public functions

func CreateClient(apiKey string) *Client {
return &Client{
func CreateClient(apiKey string) (*Client, error) {
client := &Client{
APIKey: apiKey,
BaseURL: "https://api.cohere.ai/",
Client: *http.DefaultClient,
Version: "2021-11-08",
}

res, err := client.CheckAPIKey()
if err != nil {
return nil, err
}

ret := &CheckAPIKeyResponse{}
if err := json.Unmarshal(res, ret); err != nil {
return nil, err
}
if !ret.Valid {
return nil, errors.New("invalid api key")
}
return client, nil
}

// Client methods
Expand All @@ -49,6 +70,7 @@ func (c *Client) post(model string, endpoint string, body interface{}) ([]byte,

req.Header.Set("Authorization", "BEARER "+c.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Request-Source", "go-sdk")
if len(c.Version) > 0 {
req.Header.Set("Cohere-Version", c.Version)
}
Expand All @@ -74,6 +96,40 @@ func (c *Client) post(model string, endpoint string, body interface{}) ([]byte,
return buf, nil
}

func (c *Client) CheckAPIKey() ([]byte, error) {
url := c.BaseURL + endpointCheckAPIKey
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return nil, err
}

req.Header.Set("Authorization", "BEARER "+c.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Request-Source", "go-sdk")
if len(c.Version) > 0 {
req.Header.Set("Cohere-Version", c.Version)
}
res, err := c.Client.Do(req)
if err != nil {
return nil, err
}

defer res.Body.Close()
buf, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
if res.StatusCode != 200 {
e := &APIError{}
if err := json.Unmarshal(buf, e); err != nil {
return nil, err
}
e.StatusCode = res.StatusCode
return nil, e
}
return buf, nil
}

// Generates realistic text conditioned on a given input.
// See: https://docs.cohere.ai/generate-reference
// Returns a GenerateResponse object.
Expand Down
38 changes: 23 additions & 15 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package cohere

import (
"errors"
"os"
"reflect"
"testing"
)

Expand All @@ -17,22 +15,23 @@ func init() {

func TestErrors(t *testing.T) {
t.Run("Invalid api key", func(t *testing.T) {
co := CreateClient("")
_, err := co.Generate("small", GenerateOptions{
Prompt: "",
MaxTokens: 10,
Temperature: 0.75,
})
co, err := CreateClient("")
if co != nil {
t.Error("expected nil client, got client")
}
if err == nil {
t.Errorf("expected error, got nil")
} else if !errors.Is(err, &APIError{}) {
t.Errorf("expected ApiError, got %s", reflect.TypeOf(err))
t.Error("expected error, got nil")
} else if err.Error() != "invalid api key" {
t.Errorf("expected invalid api key, got %s", err.Error())
}
})
}

func TestGenerate(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("Generate basic", func(t *testing.T) {
_, err := co.Generate("medium", GenerateOptions{
Expand Down Expand Up @@ -62,7 +61,10 @@ func TestGenerate(t *testing.T) {
}

func TestChooseBest(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("ChooseBest", func(t *testing.T) {
_, err := co.ChooseBest("small", ChooseBestOptions{
Expand All @@ -78,7 +80,10 @@ func TestChooseBest(t *testing.T) {
}

func TestEmbed(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("Embed", func(t *testing.T) {
texts := []string{"hello", "goodbye"}
Expand All @@ -94,7 +99,10 @@ func TestEmbed(t *testing.T) {
}

func TestLikelihood(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("Likelihood", func(t *testing.T) {
text := "so I crept up the basement stairs and BOOOO!"
Expand Down
6 changes: 5 additions & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ func main() {
os.Exit(1)
}

co := cohere.CreateClient(apiKey)
co, err := cohere.CreateClient(apiKey)
if err != nil {
fmt.Println(err)
return
}

prompt := "What is your"
res, err := co.Generate("medium", cohere.GenerateOptions{
Expand Down

0 comments on commit fcd44a7

Please sign in to comment.