-
-
Notifications
You must be signed in to change notification settings - Fork 735
/
Copy pathcompletions.go
89 lines (79 loc) · 2.7 KB
/
completions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package openaiclient
import (
"context"
)
// CompletionRequest is the payload for a text-completion call. Fields map
// one-to-one onto the JSON body of the completions API; optional fields
// carry `omitempty` so their zero values are dropped from the encoded
// request.
type CompletionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	Temperature float64 `json:"temperature"`
	// MaxTokens caps the number of tokens generated; 0 means "use the
	// client default" (filled in by setCompletionDefaults).
	MaxTokens        int      `json:"max_tokens,omitempty"`
	N                int      `json:"n,omitempty"`
	FrequencyPenalty float64  `json:"frequency_penalty,omitempty"`
	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
	TopP             float64  `json:"top_p,omitempty"`
	StopWords        []string `json:"stop,omitempty"`
	// StreamingFunc, when non-nil, is invoked once per chunk of a
	// streaming response. Returning an error stops streaming early.
	// Excluded from the JSON body via `json:"-"`.
	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
}
// CompletionResponse is the decoded JSON body of a completion reply.
// Numeric fields are declared float64 to match JSON's number type as
// the API returns it.
type CompletionResponse struct {
	ID      string  `json:"id,omitempty"`
	Created float64 `json:"created,omitempty"`
	// Choices holds the generated alternatives; Text carries the
	// actual completion output for each choice.
	Choices []struct {
		FinishReason string      `json:"finish_reason,omitempty"`
		Index        float64     `json:"index,omitempty"`
		Logprobs     interface{} `json:"logprobs,omitempty"`
		Text         string      `json:"text,omitempty"`
	} `json:"choices,omitempty"`
	Model  string `json:"model,omitempty"`
	Object string `json:"object,omitempty"`
	// Usage reports token accounting for the request/response pair.
	Usage struct {
		CompletionTokens float64 `json:"completion_tokens,omitempty"`
		PromptTokens     float64 `json:"prompt_tokens,omitempty"`
		TotalTokens      float64 `json:"total_tokens,omitempty"`
	} `json:"usage,omitempty"`
}
// errorMessage mirrors the JSON error envelope returned by the API on
// failed requests, e.g. {"error": {"message": "...", "type": "..."}}.
type errorMessage struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
	} `json:"error"`
}
// setCompletionDefaults fills in zero-valued fields of payload with
// sensible defaults before the request is sent.
func (c *Client) setCompletionDefaults(payload *CompletionRequest) {
	// Cap generation length when the caller did not specify one.
	if payload.MaxTokens == 0 {
		payload.MaxTokens = 256
	}
	// Normalize an empty stop-word slice to nil so `omitempty` drops
	// the field from the encoded JSON entirely.
	if len(payload.StopWords) == 0 {
		payload.StopWords = nil
	}
	// Model resolution order: request payload, then client-level
	// setting, then the package default.
	if payload.Model == "" {
		if c.Model != "" {
			payload.Model = c.Model
		} else {
			payload.Model = defaultChatModel
		}
	}
}
// nolint:lll
// createCompletion satisfies a completion request by translating it into an
// equivalent single-message chat request and delegating to createChat.
func (c *Client) createCompletion(ctx context.Context, payload *CompletionRequest) (*ChatResponse, error) {
	c.setCompletionDefaults(payload)

	// The prompt becomes the sole user message of the chat conversation.
	messages := []*ChatMessage{
		{Role: "user", Content: payload.Prompt},
	}

	chatReq := ChatRequest{
		Model:            payload.Model,
		Messages:         messages,
		Temperature:      payload.Temperature,
		TopP:             payload.TopP,
		MaxTokens:        payload.MaxTokens,
		N:                payload.N,
		StopWords:        payload.StopWords,
		FrequencyPenalty: payload.FrequencyPenalty,
		PresencePenalty:  payload.PresencePenalty,
		StreamingFunc:    payload.StreamingFunc,
	}
	return c.createChat(ctx, &chatReq)
}