-
Notifications
You must be signed in to change notification settings - Fork 746
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add azure openai provider (#309)
* feat: add azure openai provider Signed-off-by: Aris Boutselis <aris.boutselis@senseon.io> * feat: validate backend name Signed-off-by: Aris Boutselis <aris.boutselis@senseon.io> * fix: remove BaseURL from the mandatory env variables Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> * fix: conflicts Signed-off-by: Aris Boutselis <aris.boutselis@senseon.io> * chore: updated logo (#365) Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * chore: added changing banners (#367) Signed-off-by: Alex Jones <alexsimonjones@gmail.com> * feat: add additionalLabels to Service Monitor (#366) * feat: add additionalLabels to Service Monitor Signed-off-by: Brad McCoy <bradmccoydev@gmail.com> * feat: update additionalLabels Signed-off-by: Brad McCoy <bradmccoydev@gmail.com> --------- Signed-off-by: Brad McCoy <bradmccoydev@gmail.com> * fix: update README file's ai provider section. Signed-off-by: Aris Boutselis <aris.boutselis@senseon.io> --------- Signed-off-by: Aris Boutselis <aris.boutselis@senseon.io> Signed-off-by: Aris Boutselis <arisboutselis08@gmail.com> Signed-off-by: Alex Jones <alexsimonjones@gmail.com> Signed-off-by: Brad McCoy <bradmccoydev@gmail.com> Co-authored-by: Aris Boutselis <arisboutselis08@gmail.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com> Co-authored-by: Brad McCoy <bradmccoydev@gmail.com>
- Loading branch information
1 parent
a89a5cf
commit d8357ce
Showing
5 changed files
with
195 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package ai | ||
|
||
import ( | ||
"context" | ||
"encoding/base64" | ||
"errors" | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/k8sgpt-ai/k8sgpt/pkg/cache" | ||
"github.com/k8sgpt-ai/k8sgpt/pkg/util" | ||
|
||
"github.com/fatih/color" | ||
|
||
"github.com/sashabaranov/go-openai" | ||
) | ||
|
||
type AzureAIClient struct { | ||
client *openai.Client | ||
language string | ||
model string | ||
} | ||
|
||
func (c *AzureAIClient) Configure(config IAIConfig, lang string) error { | ||
token := config.GetPassword() | ||
baseURL := config.GetBaseURL() | ||
engine := config.GetEngine() | ||
defaultConfig := openai.DefaultAzureConfig(token, baseURL, engine) | ||
client := openai.NewClientWithConfig(defaultConfig) | ||
if client == nil { | ||
return errors.New("error creating Azure OpenAI client") | ||
} | ||
c.language = lang | ||
c.client = client | ||
c.model = config.GetModel() | ||
return nil | ||
} | ||
|
||
func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { | ||
// Create a completion request | ||
resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ | ||
Model: c.model, | ||
Messages: []openai.ChatCompletionMessage{ | ||
{ | ||
Role: "user", | ||
Content: fmt.Sprintf(default_prompt, c.language, prompt), | ||
}, | ||
}, | ||
}) | ||
if err != nil { | ||
return "", err | ||
} | ||
return resp.Choices[0].Message.Content, nil | ||
} | ||
|
||
func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { | ||
inputKey := strings.Join(prompt, " ") | ||
// Check for cached data | ||
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) | ||
|
||
if !cache.IsCacheDisabled() && cache.Exists(cacheKey) { | ||
response, err := cache.Load(cacheKey) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if response != "" { | ||
output, err := base64.StdEncoding.DecodeString(response) | ||
if err != nil { | ||
color.Red("error decoding cached data: %v", err) | ||
return "", nil | ||
} | ||
return string(output), nil | ||
} | ||
} | ||
|
||
response, err := a.GetCompletion(ctx, inputKey) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))) | ||
|
||
if err != nil { | ||
color.Red("error storing value to cache: %v", err) | ||
return "", nil | ||
} | ||
|
||
return response, nil | ||
} | ||
|
||
func (a *AzureAIClient) GetName() string { | ||
return "azureopenai" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters