Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check api key on client init #7

Merged
merged 3 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cohere
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"path"
Expand All @@ -16,21 +17,41 @@ type Client struct {
}

const (
endpointGenerate = "/generate"
endpointChooseBest = "/choose-best"
endpointEmbed = "/embed"
endpointLikelihood = "/likelihood"
endpointGenerate = "generate"
endpointChooseBest = "choose-best"
endpointEmbed = "embed"
endpointLikelihood = "likelihood"

endpointCheckAPIKey = "check-api-key"
)

type CheckAPIKeyResponse struct {
Valid bool
}

// Public functions

func CreateClient(apiKey string) *Client {
return &Client{
func CreateClient(apiKey string) (*Client, error) {
client := &Client{
APIKey: apiKey,
BaseURL: "https://api.cohere.ai/",
Client: *http.DefaultClient,
Version: "2021-11-08",
}

res, err := client.CheckAPIKey()
if err != nil {
return nil, err
}

ret := &CheckAPIKeyResponse{}
if err := json.Unmarshal(res, ret); err != nil {
return nil, err
}
if !ret.Valid {
return nil, errors.New("invalid api key")
}
return client, nil
}

// Client methods
Expand All @@ -49,6 +70,7 @@ func (c *Client) post(model string, endpoint string, body interface{}) ([]byte,

req.Header.Set("Authorization", "BEARER "+c.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Request-Source", "go-sdk")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch!

if len(c.Version) > 0 {
req.Header.Set("Cohere-Version", c.Version)
}
Expand All @@ -74,6 +96,40 @@ func (c *Client) post(model string, endpoint string, body interface{}) ([]byte,
return buf, nil
}

func (c *Client) CheckAPIKey() ([]byte, error) {
url := c.BaseURL + endpointCheckAPIKey
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return nil, err
}

req.Header.Set("Authorization", "BEARER "+c.APIKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Request-Source", "go-sdk")
if len(c.Version) > 0 {
req.Header.Set("Cohere-Version", c.Version)
}
res, err := c.Client.Do(req)
if err != nil {
return nil, err
}

defer res.Body.Close()
buf, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
if res.StatusCode != 200 {
e := &APIError{}
if err := json.Unmarshal(buf, e); err != nil {
return nil, err
}
e.StatusCode = res.StatusCode
return nil, e
}
return buf, nil
}

// Generates realistic text conditioned on a given input.
// See: https://docs.cohere.ai/generate-reference
// Returns a GenerateResponse object.
Expand Down
38 changes: 23 additions & 15 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package cohere

import (
"errors"
"os"
"reflect"
"testing"
)

Expand All @@ -17,22 +15,23 @@ func init() {

func TestErrors(t *testing.T) {
t.Run("Invalid api key", func(t *testing.T) {
co := CreateClient("")
_, err := co.Generate("small", GenerateOptions{
Prompt: "",
MaxTokens: 10,
Temperature: 0.75,
})
co, err := CreateClient("")
if co != nil {
t.Error("expected nil client, got client")
}
if err == nil {
t.Errorf("expected error, got nil")
} else if !errors.Is(err, &APIError{}) {
t.Errorf("expected ApiError, got %s", reflect.TypeOf(err))
t.Error("expected error, got nil")
} else if err.Error() != "invalid api key" {
t.Errorf("expected invalid api key, got %s", err.Error())
}
})
}

func TestGenerate(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("Generate basic", func(t *testing.T) {
_, err := co.Generate("medium", GenerateOptions{
Expand Down Expand Up @@ -62,7 +61,10 @@ func TestGenerate(t *testing.T) {
}

func TestChooseBest(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("ChooseBest", func(t *testing.T) {
_, err := co.ChooseBest("small", ChooseBestOptions{
Expand All @@ -78,7 +80,10 @@ func TestChooseBest(t *testing.T) {
}

func TestEmbed(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("Embed", func(t *testing.T) {
texts := []string{"hello", "goodbye"}
Expand All @@ -94,7 +99,10 @@ func TestEmbed(t *testing.T) {
}

func TestLikelihood(t *testing.T) {
co := CreateClient(apiKey)
co, err := CreateClient(apiKey)
if err != nil {
t.Error(err)
}

t.Run("Likelihood", func(t *testing.T) {
text := "so I crept up the basement stairs and BOOOO!"
Expand Down
6 changes: 5 additions & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ func main() {
os.Exit(1)
}

co := cohere.CreateClient(apiKey)
co, err := cohere.CreateClient(apiKey)
if err != nil {
fmt.Println(err)
return
}

prompt := "What is your"
res, err := co.Generate("medium", cohere.GenerateOptions{
Expand Down