From 75008d0e3a8395d8bb05893115dd0dad0dd9fb57 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 1 Jan 2024 13:16:27 -0700 Subject: [PATCH 1/4] #54: Refactor UnifiedChatResponse struct and add ProviderResponse and TokenCount structs --- pkg/api/schemas/language.go | 26 +++++++++++++++----- pkg/providers/openai/chat.go | 47 ++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go index 29aad057..f692a2f2 100644 --- a/pkg/api/schemas/language.go +++ b/pkg/api/schemas/language.go @@ -8,12 +8,26 @@ type UnifiedChatRequest struct { // UnifiedChatResponse defines Glide's Chat Response Schema unified across all language models type UnifiedChatResponse struct { - ID string `json:"id,omitempty"` - Created float64 `json:"created,omitempty"` - Choices []*ChatChoice `json:"choices,omitempty"` - Model string `json:"model,omitempty"` - Object string `json:"object,omitempty"` // TODO: what does this mean "Object"? - Usage Usage `json:"usage,omitempty"` + ID string `json:"id,omitempty"` + Created float64 `json:"created,omitempty"` + Provider string `json:"provider,omitempty"` + Router string `json:"router,omitempty"` + Model string `json:"model,omitempty"` + Cached bool `json:"cached,omitempty"` + ProviderResponse ProviderResponse `json:"provider_response,omitempty"` +} + +// ProviderResponse contains data from the chosen provider +type ProviderResponse struct { + ResponseId map[string]string `json:"response_id,omitempty"` + Message ChatMessage `json:"message"` + TokenCount TokenCount `json:"token_count"` +} + +type TokenCount struct { + PromptTokens int `json:"prompt_tokens"` + ResponseTokens int `json:"response_tokens"` + TotalTokens int `json:"total_tokens"` } // ChatMessage is a message in a chat request. diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 0c2844f7..720555db 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "time" "glide/pkg/providers/errs" @@ -147,8 +148,54 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, errs.ErrProviderUnavailable } + // Read the response body into a byte slice + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.telemetry.Logger.Error("failed to read openai chat response", zap.Error(err)) + return nil, err + } + + // Parse the response JSON + var responseJSON map[string]interface{} + err = json.Unmarshal(bodyBytes, &responseJSON) + if err != nil { + c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) + return nil, err + } + // Parse response var response schemas.UnifiedChatResponse + var responsePayload schemas.ProviderResponse + var tokenCount schemas.TokenCount + + message := responseJSON["choices"].([]interface{})[0].(map[string]interface{})["message"].(map[string]interface{}) + messageStruct := schemas.ChatMessage{ + Role: message["role"].(string), + Content: message["content"].(string), + } + + tokenCount = schemas.TokenCount{ + PromptTokens: responseJSON["usage"].(map[string]interface{})["prompt_tokens"].(int), + ResponseTokens: responseJSON["usage"].(map[string]interface{})["completion_tokens"].(int), + TotalTokens: responseJSON["usage"].(map[string]interface{})["total_tokens"].(int), + } + + responsePayload = schemas.ProviderResponse{ + ResponseId: map[string]string{"system_fingerprint": responseJSON["system_fingerprint"].(string)}, + Message: messageStruct, + TokenCount: tokenCount, + } + + + response = schemas.UnifiedChatResponse{ + ID: responseJSON["id"].(string), + Created: float64(time.Now().Unix()), + Provider: "openai", + Router: "chat", + Model: payload.Model, + Cached: false, + ProviderResponse: responsePayload, + } return &response, json.NewDecoder(resp.Body).Decode(&response) } From eaf61ee8ce88fa84ad2fa7beb42bbab4bbfa585c Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 1 Jan 2024 13:38:48 -0700 Subject: [PATCH 2/4] #54: Unified response created and tested - passing --- pkg/api/schemas/language.go | 6 +++--- pkg/providers/openai/chat.go | 19 ++++++++++--------- pkg/providers/openai/client_test.go | 3 ++- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go index f692a2f2..b0335b02 100644 --- a/pkg/api/schemas/language.go +++ b/pkg/api/schemas/language.go @@ -25,9 +25,9 @@ type ProviderResponse struct { } type TokenCount struct { - PromptTokens int `json:"prompt_tokens"` - ResponseTokens int `json:"response_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens float64`json:"prompt_tokens"` + ResponseTokens float64 `json:"response_tokens"` + TotalTokens float64 `json:"total_tokens"` } // ChatMessage is a message in a chat request. diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 720555db..1681a7e3 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -77,16 +77,15 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch // Chat sends a chat request to the specified OpenAI model. func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { // Create a new chat request + chatRequest := c.createChatRequestSchema(request) - - // TODO: this is suspicious we do zero remapping of OpenAI response and send it back as is. - // Does it really work well across providers? + chatResponse, err := c.doChatRequest(ctx, chatRequest) if err != nil { return nil, err } - if len(chatResponse.Choices) == 0 { + if len(chatResponse.ProviderResponse.Message.Content) == 0 { return nil, ErrEmptyResponse } @@ -163,6 +162,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, err } + + // Parse response var response schemas.UnifiedChatResponse var responsePayload schemas.ProviderResponse @@ -175,9 +176,9 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } tokenCount = schemas.TokenCount{ - PromptTokens: responseJSON["usage"].(map[string]interface{})["prompt_tokens"].(int), - ResponseTokens: responseJSON["usage"].(map[string]interface{})["completion_tokens"].(int), - TotalTokens: responseJSON["usage"].(map[string]interface{})["total_tokens"].(int), + PromptTokens: responseJSON["usage"].(map[string]interface{})["prompt_tokens"].(float64), + ResponseTokens: responseJSON["usage"].(map[string]interface{})["completion_tokens"].(float64), + TotalTokens: responseJSON["usage"].(map[string]interface{})["total_tokens"].(float64), } responsePayload = schemas.ProviderResponse{ @@ -192,10 +193,10 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche Created: float64(time.Now().Unix()), Provider: "openai", Router: "chat", - Model: payload.Model, + Model: responseJSON["model"].(string), Cached: false, ProviderResponse: responsePayload, } - return &response, json.NewDecoder(resp.Body).Decode(&response) + return &response, nil } diff --git a/pkg/providers/openai/client_test.go b/pkg/providers/openai/client_test.go index 394e7403..c81d3a3a 100644 --- a/pkg/providers/openai/client_test.go +++ b/pkg/providers/openai/client_test.go @@ -12,8 +12,9 @@ import ( "glide/pkg/api/schemas" - "github.com/stretchr/testify/require" "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" ) func TestOpenAIClient_ChatRequest(t *testing.T) { From 956c404b7739c44b8b7e52a97588c3b9c52ef530 Mon Sep 17 00:00:00 2001 From: Max Date: Mon, 1 Jan 2024 13:41:43 -0700 Subject: [PATCH 3/4] #54: lint --- pkg/api/schemas/language.go | 2 +- pkg/providers/openai/chat.go | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go index b0335b02..cf86f7f1 100644 --- a/pkg/api/schemas/language.go +++ b/pkg/api/schemas/language.go @@ -25,7 +25,7 @@ type ProviderResponse struct { } type TokenCount struct { - PromptTokens float64`json:"prompt_tokens"` + PromptTokens float64 `json:"prompt_tokens"` ResponseTokens float64 `json:"response_tokens"` TotalTokens float64 `json:"total_tokens"` } diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 1681a7e3..6992eb83 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -77,9 +77,8 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.UnifiedChatRequest) []Ch // Chat sends a chat request to the specified OpenAI model. func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) { // Create a new chat request - chatRequest := c.createChatRequestSchema(request) - + chatResponse, err := c.doChatRequest(ctx, chatRequest) if err != nil { return nil, err @@ -156,19 +155,20 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche // Parse the response JSON var responseJSON map[string]interface{} + err = json.Unmarshal(bodyBytes, &responseJSON) if err != nil { c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) return nil, err } - - // Parse response var response schemas.UnifiedChatResponse + var responsePayload schemas.ProviderResponse + var tokenCount schemas.TokenCount - + message := responseJSON["choices"].([]interface{})[0].(map[string]interface{})["message"].(map[string]interface{}) messageStruct := schemas.ChatMessage{ Role: message["role"].(string), @@ -187,7 +187,6 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche TokenCount: tokenCount, } - response = schemas.UnifiedChatResponse{ ID: responseJSON["id"].(string), Created: float64(time.Now().Unix()), From e879a8d320be11b6e5cc8e02a8ce60dcbcd545a2 Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 2 Jan 2024 08:59:29 -0700 Subject: [PATCH 4/4] #54: lint --- pkg/providers/openai/chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 6992eb83..70de85fa 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -155,7 +155,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche // Parse the response JSON var responseJSON map[string]interface{} - + err = json.Unmarshal(bodyBytes, &responseJSON) if err != nil { c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))