#24: Update DefinePayload func
mkrueger12 committed Dec 15, 2023
1 parent 5a4b019 commit c7a9329
Showing 2 changed files with 118 additions and 151 deletions.
146 changes: 85 additions & 61 deletions pkg/buildApiReq.go
@@ -8,12 +8,13 @@ package pkg

import (
//"errors"
//"github.com/go-playground/validator/v10"
"github.com/go-playground/validator/v10"
"fmt"
"glide/pkg/providers"
//"glide/pkg/providers"
"glide/pkg/providers/openai"
"encoding/json"
//"net/http"
"reflect"
)


@@ -23,70 +24,93 @@ var configList = map[string]interface{}{
"openai": openai.OpenAIConfig,
}

// Sample JSON
//var jsonStr = `{"provider": "openai", "params": {"model": "gpt-3.5-turbo", "messages": "Hello, how are you?"}}`


func DefinePayload(payload []byte) (interface{}, error) {
// API route api.glide.com/v1/chat
// {"provider": "openai", "params": {"model": "gpt-3.5-turbo", "messages": "Hello, how are you?"}}

// Define a map to hold the JSON data
var payload_data map[string]interface{}

// Parse the JSON
err := json.Unmarshal([]byte(payload), &payload_data)
if err != nil {
// Handle error
fmt.Println(err)
}

endpoints, ok := payload_data["endpoints"].([]interface{})
if !ok {
// Handle error
fmt.Println("Endpoints not found")
}

providerList := make([]string, len(endpoints))
for i, endpoint := range endpoints {
endpointMap, ok := endpoint.(map[string]interface{})
if !ok {
// Handle error
fmt.Println("Endpoint is not a map")
}

provider, ok := endpointMap["provider"].(string)
if !ok {
// Handle error
fmt.Println("Provider not found")
}

providerList[i] = provider
}

// TODO: use mode and providerList to determine which provider to use
//modeList := payload_data["mode"].([]interface{})

provider := "openai"

// select the predefined config for the provider
var providerConfig map[string]interface{}
if config, ok := configList[provider].(pkg.ProviderConfigs); ok { // this pulls the config in index.go
if modeConfig, ok := config["chat"].(map[string]interface{}); ok { // this pulls the specific config for the endpoint
providerConfig = modeConfig
// Define a map to hold the JSON data
var payload_data map[string]interface{}

// Parse the JSON
err := json.Unmarshal([]byte(payload), &payload_data)
if err != nil {
// Handle error
fmt.Println(err)
}

endpoints, ok := payload_data["endpoints"].([]interface{})
if !ok {
// Handle error
fmt.Println("Endpoints not found")
}
}

// Build the providerConfig map by iterating over the keys in the providerConfig map and checking if the key exists in the params map
providerList := make([]string, len(endpoints))
for i, endpoint := range endpoints {
endpointMap, ok := endpoint.(map[string]interface{})
if !ok {
// Handle error
fmt.Println("Endpoint is not a map")
}

provider, ok := endpointMap["provider"].(string)
if !ok {
// Handle error
fmt.Println("Provider not found")
}

providerList[i] = provider
}

for key := range providerConfig {
if value, exists := payload_data[key]; exists {
providerConfig[key] = value
}
}
// TODO: use mode and providerList to determine which provider to use
//modeList := payload_data["mode"].([]interface{})

// If everything is fine, return the providerConfig and nil error
println(providerConfig)
return providerConfig, nil
}
provider := "openai"

// TODO: the following is inefficient. Needs updating.
// Reuse the endpoints slice parsed above: json.Unmarshal produces
// []interface{}, so asserting payload_data["endpoints"] directly to
// []map[string]interface{} would always fail.
var params map[string]interface{}

for _, endpoint := range endpoints {
endpointMap, ok := endpoint.(map[string]interface{})
if !ok {
continue
}
if endpointMap["provider"] == provider {
params, _ = endpointMap["params"].(map[string]interface{}) // assign with "=", not ":=", to avoid shadowing the outer params
fmt.Println(params)
break
}
}

var defaultConfig interface{} // holds a pointer to the provider's config struct

if provider == "openai" {
cfg := openai.OpenAiChatDefaultConfig()
defaultConfig = &cfg // a pointer is required so reflect can set fields below
} else if provider == "cohere" {
cfg := openai.OpenAiChatDefaultConfig() //TODO: change this to cohere
defaultConfig = &cfg
}

// Use reflect to copy request params onto defaultConfig.
// NOTE: FieldByName matches the exported Go field name (e.g. "MaxTokens"),
// not the json tag, so param keys must use the Go field names.
v := reflect.ValueOf(defaultConfig).Elem()
for key, value := range params {
field := v.FieldByName(key)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Int:
// encoding/json decodes JSON numbers as float64, never int
if val, ok := value.(float64); ok {
field.SetInt(int64(val))
}
case reflect.String:
if val, ok := value.(string); ok {
field.SetString(val)
}
}
}
}

// Validate the struct
validate := validator.New()
err = validate.Struct(defaultConfig)
if err != nil {
fmt.Printf("Validation failed: %v\n", err)
return nil, err
}

return defaultConfig, nil
}
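
For reference, a minimal, hypothetical caller for the updated DefinePayload. The request body follows the endpoints/params shape the function parses; the glide/pkg import path and the param keys (Go field names, per the FieldByName note above) are illustrative assumptions, not part of this commit.

package main

import (
	"fmt"

	"glide/pkg" // assumed import path for the package declaring DefinePayload
)

func main() {
	// One entry per candidate provider; params use Go field names
	// ("Model", "MaxTokens") because DefinePayload matches them with
	// reflect.FieldByName rather than json tags.
	body := []byte(`{
		"endpoints": [
			{"provider": "openai", "params": {"Model": "gpt-4", "MaxTokens": 50}}
		]
	}`)

	cfg, err := pkg.DefinePayload(body)
	if err != nil {
		fmt.Println("config rejected by validator:", err)
		return
	}
	fmt.Printf("resolved provider config: %+v\n", cfg)
}
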
123 changes: 33 additions & 90 deletions pkg/providers/openai/chat.go
@@ -1,28 +1,22 @@
package openai

type OpenAiProviderConfig struct {
Model string `json:"model" validate:"required,lowercase"`
Messages ConfigItem `json:"messages" validate:"required"`
MaxTokens ConfigItem `json:"max_tokens" validate:"omitempty,gte=0"`
Temperature ConfigItem `json:"temperature" validate:"omitempty,gte=0,lte=2"`
TopP ConfigItem `json:"top_p" validate:"omitempty,gte=0,lte=1"`
N ConfigItem `json:"n" validate:"omitempty,gte=1"`
Stream ConfigItem `json:"stream" validate:"omitempty, boolean"`
Stop ConfigItem `json:"stop"`
PresencePenalty ConfigItem `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"`
FrequencyPenalty ConfigItem `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"`
LogitBias ConfigItem `json:"logit_bias" validate:"omitempty"`
User ConfigItem `json:"user"`
Seed ConfigItem `json:"seed" validate:"omitempty,gte=0"`
Tools ConfigItem `json:"tools"`
ToolChoice ConfigItem `json:"tool_choice"`
ResponseFormat ConfigItem `json:"response_format"`
}

type ConfigItem struct {
Param string `json:"param" validate:"required"`
Required bool `json:"required" validate:"omitempty,boolean"`
Default interface{} `json:"default"`
Model string `json:"model" validate:"required,lowercase"`
Messages string `json:"messages" validate:"required"` // does this need to be updated to []string?
MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"`
Temperature int `json:"temperature" validate:"omitempty,gte=0,lte=2"`
TopP int `json:"top_p" validate:"omitempty,gte=0,lte=1"`
N int `json:"n" validate:"omitempty,gte=1"`
Stream bool `json:"stream" validate:"omitempty"` // validator's "boolean" rule applies to strings, not bool fields
Stop interface{} `json:"stop"`
PresencePenalty int `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"`
FrequencyPenalty int `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"`
LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"`
User interface{} `json:"user"`
Seed interface{} `json:"seed" validate:"omitempty,gte=0"`
Tools []string `json:"tools"`
ToolChoice interface{} `json:"tool_choice"`
ResponseFormat interface{} `json:"response_format"`
}

var defaultMessage = `[
@@ -38,73 +32,22 @@ var defaultMessage = `[

// Provide the request body for OpenAI's ChatCompletion API
func OpenAiChatDefaultConfig() OpenAiProviderConfig {
return OpenAiProviderConfig{
Model: "gpt-3.5-turbo",
Messages: ConfigItem{
Param: "messages",
Required: true,
Default: defaultMessage,
},
MaxTokens: ConfigItem{
Param: "max_tokens",
Required: false,
Default: 100,
},
Temperature: ConfigItem{
Param: "temperature",
Required: false,
Default: 1,
},
TopP: ConfigItem{
Param: "top_p",
Required: false,
Default: 1,
},
N: ConfigItem{
Param: "n",
Required: false,
Default: 1,
},
Stream: ConfigItem{
Param: "stream",
Required: false,
Default: false,
},
Stop: ConfigItem{
Param: "stop",
Required: false,
},
PresencePenalty: ConfigItem{
Param: "presence_penalty",
Required: false,
},
FrequencyPenalty: ConfigItem{
Param: "frequency_penalty",
Required: false,
},
LogitBias: ConfigItem{
Param: "logit_bias",
Required: false,
},
User: ConfigItem{
Param: "user",
Required: false,
},
Seed: ConfigItem{
Param: "seed",
Required: false,
},
Tools: ConfigItem{
Param: "tools",
Required: false,
},
ToolChoice: ConfigItem{
Param: "tool_choice",
Required: false,
},
ResponseFormat: ConfigItem{
Param: "response_format",
Required: false,
},
return OpenAiProviderConfig{
Model: "gpt-3.5-turbo",
Messages: defaultMessage,
MaxTokens: 100,
Temperature: 1,
TopP: 1,
N: 1,
Stream: false,
Stop: nil,
PresencePenalty: 0,
FrequencyPenalty: 0,
LogitBias: nil,
User: nil,
Seed: nil,
Tools: nil,
ToolChoice: nil,
ResponseFormat: nil,
}
}
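
As a quick sanity check of the flattened struct: with the ConfigItem wrappers gone, the default config marshals directly into the JSON body OpenAI's chat/completions endpoint expects. A short sketch, assuming only the standard library plus the glide/pkg/providers/openai path from this diff:

package main

import (
	"encoding/json"
	"fmt"

	"glide/pkg/providers/openai"
)

func main() {
	cfg := openai.OpenAiChatDefaultConfig()
	cfg.MaxTokens = 256 // overrides are now plain field assignments

	body, err := json.MarshalIndent(cfg, "", "  ")
	if err != nil {
		fmt.Println("marshal failed:", err)
		return
	}
	fmt.Println(string(body))
	// Prints {"model":"gpt-3.5-turbo","messages":"[ ... ]","max_tokens":256, ...}
	// with the nil-able fields (stop, user, seed, ...) rendered as null.
}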
