From c06cfb02b448f4f5752eb77ad8576839704c1a42 Mon Sep 17 00:00:00 2001 From: pacificbelt30 <57101176+pacificbelt30@users.noreply.github.com> Date: Thu, 25 Jul 2024 11:42:00 +0900 Subject: [PATCH 1/9] VertexAI prompts now include images --- cmd/main.go | 40 ++++++++++++++++++++--------- internal/ai/ai.go | 6 +++++ internal/ai/vertexai.go | 10 +++++++- internal/utils/utils.go | 56 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 13 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index 1c6c1f3..f38a589 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -4,6 +4,7 @@ import ( "flag" "log" "os" + "regexp" "github.com/3-shake/alert-menta/internal/ai" "github.com/3-shake/alert-menta/internal/github" @@ -22,9 +23,9 @@ func main() { analyze: Perform a root cause analysis based on the contents of the Issue. suggest: Provide suggestions for improvement based on the contents of the Issue. ask: Answer free-text questions.`) - configFile = flag.String("config", "./internal/config/config.yaml", "Configuration file") - gh_token = flag.String("github-token", "", "GitHub token") - oai_key = flag.String("api-key", "", "OpenAI api key") + configFile = flag.String("config", "./internal/config/config.yaml", "Configuration file") + gh_token = flag.String("github-token", "", "GitHub token") + oai_key = flag.String("api-key", "", "OpenAI api key") ) flag.Parse() @@ -68,6 +69,7 @@ func main() { user_prompt += "Body:" + *body + "\n" // Get comments under the Issue and add them to the user prompt except for comments by Actions. + images := []ai.Image{} comments, _ := issue.GetComments() if err != nil || comments == nil { logger.Fatalf("Error getting comments: %v", err) @@ -80,23 +82,37 @@ func main() { logger.Printf("%s: %s", *v.User.Login, *v.Body) } user_prompt += *v.User.Login + ":" + *v.Body + "\n" + + // Get image + imageRegex := regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) // Get URL from ![alt](url) + matches := imageRegex.FindAllStringSubmatch(*v.Body, -1) + for _, match := range matches { + logger.Println(match[2]) // 画像の URL を出力 + image_url, ext, err := utils.DownloadImage(match[2], *gh_token) + if err != nil { + logger.Fatalf("Error downloading image: %s", err) + return + } + + images = append(images, ai.Image{Data: image_url, Extension: ext}) + } } // Set system prompt var system_prompt string - if *command == "ask" { - if *intent == "" { + if *command == "ask" { + if *intent == "" { log.SetOutput(os.Stdout) logger.Println("Error: intent is required for 'ask' command") flag.PrintDefaults() os.Exit(1) - } - system_prompt = cfg.Ai.Commands[*command].System_prompt + *intent - } else { - system_prompt = cfg.Ai.Commands[*command].System_prompt - } - prompt := ai.Prompt{UserPrompt: user_prompt, SystemPrompt: system_prompt} - logger.Println("\x1b[34mPrompt: |\n", prompt.SystemPrompt, prompt.UserPrompt, "\x1b[0m") + } + system_prompt = cfg.Ai.Commands[*command].System_prompt + *intent + } else { + system_prompt = cfg.Ai.Commands[*command].System_prompt + } + prompt := ai.Prompt{UserPrompt: user_prompt, SystemPrompt: system_prompt, Images: images} + logger.Println("\x1b[34mPrompt: |\n", prompt.SystemPrompt, prompt.UserPrompt, "\x1b[0m + ", len(prompt.Images), "images") // Get response from OpenAI or VertexAI var aic ai.Ai diff --git a/internal/ai/ai.go b/internal/ai/ai.go index ce0b506..21002cb 100644 --- a/internal/ai/ai.go +++ b/internal/ai/ai.go @@ -4,7 +4,13 @@ type Ai interface { GetResponse(prompt Prompt) (string, error) } +type Image struct { + Data []byte + Extension string +} + type Prompt struct { UserPrompt string SystemPrompt string + Images []Image } diff --git a/internal/ai/vertexai.go b/internal/ai/vertexai.go index 6f28ef6..d31801b 100644 --- a/internal/ai/vertexai.go +++ b/internal/ai/vertexai.go @@ -18,7 +18,14 @@ func (ai *VertexAI) GetResponse(prompt Prompt) (string, error) { model := ai.client.GenerativeModel(ai.model) model.SetTemperature(0.9) - resp, err := model.GenerateContent(ai.context, genai.Text(prompt.SystemPrompt+prompt.UserPrompt)) + integrated_prompt := []genai.Part{} // image + text prompt + for _, image := range prompt.Images { + integrated_prompt = append(integrated_prompt, genai.ImageData(image.Extension, image.Data)) + } + integrated_prompt = append(integrated_prompt, genai.Text(prompt.SystemPrompt+prompt.UserPrompt)) + + // Generate AI response + resp, err := model.GenerateContent(ai.context, integrated_prompt...) if err != nil { log.Fatal(err) return "", err @@ -41,6 +48,7 @@ func getResponseText(resp *genai.GenerateContentResponse) string { func NewVertexAIClient(projectID, localtion, modelName string) *VertexAI { // Secret is provided in json and PATH is specified in the environment variable `GOOGLE_APPLICATION_CREDENTIALS`. + // If you are using gcloud cli authentication or workload identity federation, you do not need to specify the secret json file. ctx := context.Background() client, err := genai.NewClient(ctx, projectID, localtion) if err != nil { diff --git a/internal/utils/utils.go b/internal/utils/utils.go index f3bbe5b..c9bf3bb 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -1,9 +1,14 @@ package utils import ( + "fmt" + "io" "log" + "net/http" "os" "path/filepath" + "regexp" + "time" "github.com/spf13/viper" ) @@ -89,3 +94,54 @@ func NewConfig(filename string) (*Config, error) { logger.Println("Config:", cfg) return cfg, nil } + +func DownloadImage(url string, token string) ([]byte, string, error) { + // Create a new HTTP client + client := &http.Client{ + Timeout: 15 * time.Second, + } + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return []byte{}, "", fmt.Errorf("failed to create a new request: %w", err) + } + + // Download the image with the token + req.Header.Set("Authorization", "Bearer "+token) // set token to header + resp, err := client.Do(req) + if err != nil { + return []byte{}, "", fmt.Errorf("failed to get a response: %w", err) + } + defer resp.Body.Close() + + // Write the response body to the temporary file + file, err := os.CreateTemp("", "downloaded-image-*") + if err != nil { + return []byte{}, "", fmt.Errorf("failed to create a temporary file: %w", err) + } + defer func() { + log.Println("remove", file.Name(), "Content-Type:", resp.Header.Get("Content-Type")) + file.Close() + os.Remove(file.Name()) + }() + _, err = io.Copy(file, resp.Body) + if err != nil { + return []byte{}, "", fmt.Errorf("failed to write the response body to the temporary file: %w", err) + } + + // Read image data from the temporary file + data, err := os.ReadFile(file.Name()) + if err != nil { + return []byte{}, "", fmt.Errorf("failed to read the file: %w", err) + } + + // Get the extension of the image + contentType := resp.Header.Get("Content-Type") + imageRegex := regexp.MustCompile(`.+/(.*)`) + matches := imageRegex.FindAllStringSubmatch(contentType, -1) + if len(matches) == 0 { + return []byte{}, "", fmt.Errorf("failed to get the extension of the image") + } + ext := matches[0][1] + + return data, ext, nil +} From 956fb768ab14adc08bd55d1250c8ec6538deefdf Mon Sep 17 00:00:00 2001 From: pacificbelt30 <57101176+pacificbelt30@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:04:29 +0900 Subject: [PATCH 2/9] Replaced Japanese comments with English. --- cmd/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/main.go b/cmd/main.go index d51dea6..e24cd58 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -93,7 +93,7 @@ func main() { imageRegex := regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) // Get URL from ![alt](url) matches := imageRegex.FindAllStringSubmatch(*v.Body, -1) for _, match := range matches { - logger.Println(match[2]) // 画像の URL を出力 + logger.Println(match[2]) // Output the URL of the image image_url, ext, err := utils.DownloadImage(match[2], *gh_token) if err != nil { logger.Fatalf("Error downloading image: %s", err) From e9e49d963b9d03cdc825d98e905164159a61dd56 Mon Sep 17 00:00:00 2001 From: pacificbelt30 <57101176+pacificbelt30@users.noreply.github.com> Date: Fri, 9 Aug 2024 11:41:53 +0900 Subject: [PATCH 3/9] Added the function to perform image recognition when OpenAI is used. --- cmd/main.go | 4 ++-- internal/ai/openai.go | 23 ++++++++++++++++++++++- internal/utils/utils.go | 6 ++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index f47e26b..d2ba135 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -116,7 +116,7 @@ func main() { } else { system_prompt = cfg.Ai.Commands[*command].System_prompt } - prompt := ai.Prompt{UserPrompt: user_prompt, SystemPrompt: system_prompt} + prompt := ai.Prompt{UserPrompt: user_prompt, SystemPrompt: system_prompt, Images: images} logger.Println("\x1b[34mPrompt: |\n", prompt.SystemPrompt, prompt.UserPrompt, "\x1b[0m") // Get response from OpenAI or VertexAI @@ -137,7 +137,7 @@ func main() { } comment, _ := aic.GetResponse(prompt) - logger.Println("Response:", comment) + logger.Println("\x1b[32mResponse:", comment, "\x1b[0m") // Post a comment on the Issue err = issue.PostComment(comment) diff --git a/internal/ai/openai.go b/internal/ai/openai.go index fdb78e2..fd70d58 100644 --- a/internal/ai/openai.go +++ b/internal/ai/openai.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/3-shake/alert-menta/internal/utils" "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) @@ -18,10 +19,30 @@ func (ai *OpenAI) GetResponse(prompt Prompt) (string, error) { keyCredential := azcore.NewKeyCredential(ai.apiKey) client, _ := azopenai.NewClientForOpenAI("https://api.openai.com/v1/", keyCredential, nil) + // Convert images to base64 + base64Images := func(images []Image) []string { + var base64Images []string + for _, image := range images { + base64Images = append(base64Images, utils.ImageToBase64(image.Data, image.Extension)) + } + return base64Images + }(prompt.Images) + + // create a user prompt with text and images + user_prompt := []azopenai.ChatCompletionRequestMessageContentPartClassification{ + &azopenai.ChatCompletionRequestMessageContentPartText{Text: &prompt.UserPrompt}, + } + for _, image := range base64Images { + user_prompt = append(user_prompt, &azopenai.ChatCompletionRequestMessageContentPartImage{ImageURL: &azopenai.ChatCompletionRequestMessageContentPartImageURL{URL: &image}}) + } + // Create a chat request with the prompt messages := []azopenai.ChatRequestMessageClassification{ + &azopenai.ChatRequestSystemMessage{ + Content: &prompt.SystemPrompt, + }, &azopenai.ChatRequestUserMessage{ - Content: azopenai.NewChatRequestUserMessageContent(prompt.SystemPrompt + prompt.UserPrompt), + Content: azopenai.NewChatRequestUserMessageContent(user_prompt), }, } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index c9bf3bb..85ff43e 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "encoding/base64" "fmt" "io" "log" @@ -145,3 +146,8 @@ func DownloadImage(url string, token string) ([]byte, string, error) { return data, ext, nil } + +func ImageToBase64(data []byte, ext string) string { + base64img := base64.StdEncoding.EncodeToString(data) + return "data:image/" + ext + ";base64," + base64img +} \ No newline at end of file From 1df2952e96b61df71c320b7deb46e25a185eea37 Mon Sep 17 00:00:00 2001 From: pacificbelt30 <57101176+pacificbelt30@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:18:48 +0900 Subject: [PATCH 4/9] Change the name of the variable that holds the image to be downloaded from GitHub image_url -> img_data --- cmd/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index d2ba135..f0b1de6 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -94,13 +94,13 @@ func main() { matches := imageRegex.FindAllStringSubmatch(*v.Body, -1) for _, match := range matches { logger.Println(match[2]) // Output the URL of the image - image_url, ext, err := utils.DownloadImage(match[2], *gh_token) + img_data, ext, err := utils.DownloadImage(match[2], *gh_token) if err != nil { logger.Fatalf("Error downloading image: %s", err) return } - images = append(images, ai.Image{Data: image_url, Extension: ext}) + images = append(images, ai.Image{Data: img_data, Extension: ext}) } } From 848408343214d7d50bc11b46fbe5bdc8bcf59545 Mon Sep 17 00:00:00 2001 From: vagrant Date: Fri, 15 Nov 2024 05:58:48 +0000 Subject: [PATCH 5/9] update main.go vertexai.go --- cmd/main.go | 10 +++++----- internal/ai/vertexai.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index e37e6fb..c0a39aa 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -59,7 +59,7 @@ func main() { issue := github.NewIssue(cfg.owner, cfg.repo, cfg.issueNumber, cfg.ghToken) - userPrompt, imgs, err := constructUserPrompt(issue, loadedcfg, logger) + userPrompt, imgs, err := constructUserPrompt(cfg.ghToken, issue, loadedcfg, logger) if err != nil { logger.Fatalf("Erro constructing userPrompt: %v", err) } @@ -98,7 +98,7 @@ func validateCommand(command string, cfg *utils.Config) error { } // Construct user prompt from issue -func constructUserPrompt(issue *github.GitHubIssue, cfg *utils.Config, logger *log.Logger) (string, []ai.Image, error) { +func constructUserPrompt(ghToken string, issue *github.GitHubIssue, cfg *utils.Config, logger *log.Logger) (string, []ai.Image, error) { title, err := issue.GetTitle() if err != nil { return "", nil, fmt.Errorf("Error getting Title: %w", err) @@ -130,10 +130,10 @@ func constructUserPrompt(issue *github.GitHubIssue, cfg *utils.Config, logger *l } userPrompt.WriteString(*v.User.Login + ":" + *v.Body + "\n") - matches := imageRegex.FindAllStringSubmatch(*comment.Body, -1) + matches := imageRegex.FindAllStringSubmatch(*v.Body, -1) for _, match := range matches { logger.Println("Image URL:", match[2]) // Log the URL of the image - imgData, ext, err := utils.DownloadImage(match[2], *ghToken) + imgData, ext, err := utils.DownloadImage(match[2], ghToken) if err != nil { return "", nil, fmt.Errorf("Error downloading image: %w", err) } @@ -156,7 +156,7 @@ func constructPrompt(command, intent, userPrompt string, imgs []ai.Image, cfg *u systemPrompt = cfg.Ai.Commands[command].System_prompt } logger.Println("\x1b[34mPrompt: |\n", systemPrompt, userPrompt, "\x1b[0m") - return &ai.Prompt{UserPrompt: userPrompt, SystemPrompt: systemPrompt, Images: images}, nil + return &ai.Prompt{UserPrompt: userPrompt, SystemPrompt: systemPrompt, Images: imgs}, nil } // Initialize AI client diff --git a/internal/ai/vertexai.go b/internal/ai/vertexai.go index 5a48816..f513d27 100644 --- a/internal/ai/vertexai.go +++ b/internal/ai/vertexai.go @@ -24,7 +24,7 @@ func (ai *VertexAI) GetResponse(prompt *Prompt) (string, error) { for _, image := range prompt.Images { integrated_prompt = append(integrated_prompt, genai.ImageData(image.Extension, image.Data)) } - integrated_prompt = append(integrated_prompt, genai.Text(prompt.SystemPrompt + prompt.UserPrompt)) + integrated_prompt = append(integrated_prompt, genai.Text(prompt.SystemPrompt + prompt.UserPrompt)) // Generate AI response resp, err := model.GenerateContent(ai.context, integrated_prompt...) From e041ff46f695bc2263010218d5945a97727fd1a8 Mon Sep 17 00:00:00 2001 From: vagrant Date: Fri, 15 Nov 2024 06:43:21 +0000 Subject: [PATCH 6/9] update openai.go --- internal/ai/openai.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/ai/openai.go b/internal/ai/openai.go index a40a949..fa03ab1 100644 --- a/internal/ai/openai.go +++ b/internal/ai/openai.go @@ -39,7 +39,7 @@ func (ai *OpenAI) GetResponse(prompt *Prompt) (string, error) { // Create a chat request with the prompt messages := []azopenai.ChatRequestMessageClassification{ &azopenai.ChatRequestSystemMessage{ - Content: &prompt.SystemPrompt, + Content: azopenai.NewChatRequestSystemMessageContent(prompt.SystemPrompt), }, &azopenai.ChatRequestUserMessage{ Content: azopenai.NewChatRequestUserMessageContent(user_prompt), From a1992f1cc198c3e892fe703808d0b63f974ab23d Mon Sep 17 00:00:00 2001 From: vagrant Date: Fri, 15 Nov 2024 08:46:58 +0000 Subject: [PATCH 7/9] upadate .alert-menta.user.yaml --- .alert-menta.user.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.alert-menta.user.yaml b/.alert-menta.user.yaml index c03cb38..74664f3 100644 --- a/.alert-menta.user.yaml +++ b/.alert-menta.user.yaml @@ -5,7 +5,7 @@ system: ai: provider: "openai" # "openai" or "vertexai" openai: - model: "gpt-3.5-turbo" # Check the list of available models by `curl https://api.openai.com/v1/models -H "Authorization: Bearer $OPENAI_API_KEY"` + model: "gpt-4o-mini-2024-07-18" # Check the list of available models by `curl https://api.openai.com/v1/models -H "Authorization: Bearer $OPENAI_API_KEY"` vertexai: project: "" From 724a14dead308baf4007e844b3be2106693bd0a9 Mon Sep 17 00:00:00 2001 From: Keigo Kurita Date: Mon, 6 Jan 2025 13:21:07 +0900 Subject: [PATCH 8/9] update main_test.go --- cmd/main_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cmd/main_test.go b/cmd/main_test.go index fa066da..7ccdad5 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -6,6 +6,7 @@ import ( "log" "os" + "github.com/3-shake/alert-menta/internal/ai" "github.com/3-shake/alert-menta/internal/utils" ) @@ -54,17 +55,18 @@ func TestConstructPrompt(t *testing.T) { command string intent string userPrompt string + imgs []ai.Image expectErr bool expectedSystemPrompt string }{ - {"Valid Ask Command with Intent", "ask", "What is the first thing to work on in suggestions?", "userPrompt", false, "Ask system prompt: What is the first thing to work on in suggestions?\n"}, - {"Ask Command without Intent", "ask", "", "userPrompt", true, ""}, - {"Valid Other Command", "other", "", "userPrompt", false, "Other system prompt: "}, + {"Valid Ask Command with Intent", "ask", "What is the first thing to work on in suggestions?", "userPrompt", []ai.Image{}, false, "Ask system prompt: What is the first thing to work on in suggestions?\n"}, + {"Ask Command without Intent", "ask", "", "userPrompt", []ai.Image{}, true, ""}, + {"Valid Other Command", "other", "", "userPrompt", []ai.Image{}, false, "Other system prompt: "}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - prompt, err := constructPrompt(tt.command, tt.intent, tt.userPrompt, mockCfg, logger) + prompt, err := constructPrompt(tt.command, tt.intent, tt.userPrompt, tt.imgs, mockCfg, logger) if (err != nil) != tt.expectErr { t.Errorf("expected error: %v, got error %v", tt.expectErr, err) } From 7e6c7700dc0fb5515bbb25474f158399601df611 Mon Sep 17 00:00:00 2001 From: Keigo Kurita Date: Mon, 6 Jan 2025 13:26:48 +0900 Subject: [PATCH 9/9] go fmt --- cmd/main.go | 294 +++++++++++++++++------------------ cmd/main_test.go | 174 ++++++++++----------- internal/ai/vertexai.go | 4 +- internal/utils/utils.go | 2 +- internal/utils/utils_test.go | 74 ++++----- 5 files changed, 274 insertions(+), 274 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index c0a39aa..010d151 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,11 +2,11 @@ package main import ( "flag" + "fmt" "log" "os" "regexp" "strings" - "fmt" "github.com/3-shake/alert-menta/internal/ai" "github.com/3-shake/alert-menta/internal/github" @@ -15,169 +15,169 @@ import ( // Struct to hold the command-line arguments type Config struct { - repo string - owner string - issueNumber int - intent string - command string - configFile string - ghToken string - oaiKey string + repo string + owner string + issueNumber int + intent string + command string + configFile string + ghToken string + oaiKey string } func main() { - cfg := &Config{} - flag.StringVar(&cfg.repo, "repo", "", "Repository name") - flag.StringVar(&cfg.owner, "owner", "", "Repository owner") - flag.IntVar(&cfg.issueNumber, "issue", 0, "Issue number") - flag.StringVar(&cfg.intent, "intent", "", "Question or intent for the 'ask' command") - flag.StringVar(&cfg.command, "command", "", "Commands to be executed by AI. Commands defined in the configuration file are available.") - flag.StringVar(&cfg.configFile, "config", "", "Configuration file") - flag.StringVar(&cfg.ghToken, "github-token", "", "GitHub token") - flag.StringVar(&cfg.oaiKey, "api-key", "", "OpenAI api key") - flag.Parse() - - if cfg.repo == "" || cfg.owner == "" || cfg.issueNumber == 0 || cfg.ghToken == "" || cfg.command == "" || cfg.configFile == "" { - flag.PrintDefaults() - os.Exit(1) - } - - logger := log.New( - os.Stdout, "[alert-menta main] ", - log.Ldate|log.Ltime|log.Llongfile|log.Lmsgprefix, - ) - - loadedcfg, err := utils.NewConfig(cfg.configFile) - if err != nil { - logger.Fatalf("Error loading config: %v", err) - } - - err = validateCommand(cfg.command, loadedcfg) - if err != nil { - logger.Fatalf("Error validating command: %v", err) - } - - issue := github.NewIssue(cfg.owner, cfg.repo, cfg.issueNumber, cfg.ghToken) - - userPrompt, imgs, err := constructUserPrompt(cfg.ghToken, issue, loadedcfg, logger) - if err != nil { - logger.Fatalf("Erro constructing userPrompt: %v", err) - } - - prompt, err := constructPrompt(cfg.command, cfg.intent, userPrompt, imgs, loadedcfg, logger) - if err != nil { - logger.Fatalf("Error constructing prompt: %v", err) - } - - aic, err := getAIClient(cfg.oaiKey, loadedcfg, logger) - if err != nil { - logger.Fatalf("Error geting AI client: %v", err) - } - - comment, err := aic.GetResponse(prompt) - if err != nil { - logger.Fatalf("Error getting Response: %v", err) - } - logger.Println("Response:", comment) - - if err := issue.PostComment(comment); err != nil { - logger.Fatalf("Error creating comment: %v", err) - } + cfg := &Config{} + flag.StringVar(&cfg.repo, "repo", "", "Repository name") + flag.StringVar(&cfg.owner, "owner", "", "Repository owner") + flag.IntVar(&cfg.issueNumber, "issue", 0, "Issue number") + flag.StringVar(&cfg.intent, "intent", "", "Question or intent for the 'ask' command") + flag.StringVar(&cfg.command, "command", "", "Commands to be executed by AI. Commands defined in the configuration file are available.") + flag.StringVar(&cfg.configFile, "config", "", "Configuration file") + flag.StringVar(&cfg.ghToken, "github-token", "", "GitHub token") + flag.StringVar(&cfg.oaiKey, "api-key", "", "OpenAI api key") + flag.Parse() + + if cfg.repo == "" || cfg.owner == "" || cfg.issueNumber == 0 || cfg.ghToken == "" || cfg.command == "" || cfg.configFile == "" { + flag.PrintDefaults() + os.Exit(1) + } + + logger := log.New( + os.Stdout, "[alert-menta main] ", + log.Ldate|log.Ltime|log.Llongfile|log.Lmsgprefix, + ) + + loadedcfg, err := utils.NewConfig(cfg.configFile) + if err != nil { + logger.Fatalf("Error loading config: %v", err) + } + + err = validateCommand(cfg.command, loadedcfg) + if err != nil { + logger.Fatalf("Error validating command: %v", err) + } + + issue := github.NewIssue(cfg.owner, cfg.repo, cfg.issueNumber, cfg.ghToken) + + userPrompt, imgs, err := constructUserPrompt(cfg.ghToken, issue, loadedcfg, logger) + if err != nil { + logger.Fatalf("Erro constructing userPrompt: %v", err) + } + + prompt, err := constructPrompt(cfg.command, cfg.intent, userPrompt, imgs, loadedcfg, logger) + if err != nil { + logger.Fatalf("Error constructing prompt: %v", err) + } + + aic, err := getAIClient(cfg.oaiKey, loadedcfg, logger) + if err != nil { + logger.Fatalf("Error geting AI client: %v", err) + } + + comment, err := aic.GetResponse(prompt) + if err != nil { + logger.Fatalf("Error getting Response: %v", err) + } + logger.Println("Response:", comment) + + if err := issue.PostComment(comment); err != nil { + logger.Fatalf("Error creating comment: %v", err) + } } // Validate the provided command func validateCommand(command string, cfg *utils.Config) error { - if _, ok := cfg.Ai.Commands[command]; !ok { - allowedCommands := make([]string, 0, len(cfg.Ai.Commands)) - for cmd := range cfg.Ai.Commands { - allowedCommands = append(allowedCommands, cmd) - } - return fmt.Errorf("Invalid command: %s. Allowed commands are %s", command, strings.Join(allowedCommands, ", ")) - } - return nil + if _, ok := cfg.Ai.Commands[command]; !ok { + allowedCommands := make([]string, 0, len(cfg.Ai.Commands)) + for cmd := range cfg.Ai.Commands { + allowedCommands = append(allowedCommands, cmd) + } + return fmt.Errorf("Invalid command: %s. Allowed commands are %s", command, strings.Join(allowedCommands, ", ")) + } + return nil } // Construct user prompt from issue func constructUserPrompt(ghToken string, issue *github.GitHubIssue, cfg *utils.Config, logger *log.Logger) (string, []ai.Image, error) { - title, err := issue.GetTitle() - if err != nil { - return "", nil, fmt.Errorf("Error getting Title: %w", err) - } - - body, err := issue.GetBody() - if err != nil { - return "", nil, fmt.Errorf("Error getting Body: %w", err) - } - - var userPrompt strings.Builder - userPrompt.WriteString("Title:" + *title + "\n") - userPrompt.WriteString("Body:" + *body + "\n") - - comments, err := issue.GetComments() - if err != nil { - return "", nil, fmt.Errorf("Error getting comments: %w", err) - } - - var images []ai.Image - imageRegex := regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) - - for _, v := range comments { - if *v.User.Login == "github-actions[bot]" { - continue - } - if cfg.System.Debug.Log_level == "debug" { - logger.Printf("%s: %s", *v.User.Login, *v.Body) - } - userPrompt.WriteString(*v.User.Login + ":" + *v.Body + "\n") + title, err := issue.GetTitle() + if err != nil { + return "", nil, fmt.Errorf("Error getting Title: %w", err) + } + + body, err := issue.GetBody() + if err != nil { + return "", nil, fmt.Errorf("Error getting Body: %w", err) + } + + var userPrompt strings.Builder + userPrompt.WriteString("Title:" + *title + "\n") + userPrompt.WriteString("Body:" + *body + "\n") + + comments, err := issue.GetComments() + if err != nil { + return "", nil, fmt.Errorf("Error getting comments: %w", err) + } + + var images []ai.Image + imageRegex := regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) + + for _, v := range comments { + if *v.User.Login == "github-actions[bot]" { + continue + } + if cfg.System.Debug.Log_level == "debug" { + logger.Printf("%s: %s", *v.User.Login, *v.Body) + } + userPrompt.WriteString(*v.User.Login + ":" + *v.Body + "\n") matches := imageRegex.FindAllStringSubmatch(*v.Body, -1) - for _, match := range matches { - logger.Println("Image URL:", match[2]) // Log the URL of the image - imgData, ext, err := utils.DownloadImage(match[2], ghToken) - if err != nil { - return "", nil, fmt.Errorf("Error downloading image: %w", err) - } - - images = append(images, ai.Image{Data: imgData, Extension: ext}) - } - } - return userPrompt.String(), images, nil + for _, match := range matches { + logger.Println("Image URL:", match[2]) // Log the URL of the image + imgData, ext, err := utils.DownloadImage(match[2], ghToken) + if err != nil { + return "", nil, fmt.Errorf("Error downloading image: %w", err) + } + + images = append(images, ai.Image{Data: imgData, Extension: ext}) + } + } + return userPrompt.String(), images, nil } // Construct AI prompt -func constructPrompt(command, intent, userPrompt string, imgs []ai.Image, cfg *utils.Config, logger *log.Logger) (*ai.Prompt, error){ - var systemPrompt string - if command == "ask" { - if intent == "" { - return nil, fmt.Errorf("Error: intent is required for 'ask' command") - } - systemPrompt = cfg.Ai.Commands[command].System_prompt + intent + "\n" - } else { - systemPrompt = cfg.Ai.Commands[command].System_prompt - } - logger.Println("\x1b[34mPrompt: |\n", systemPrompt, userPrompt, "\x1b[0m") - return &ai.Prompt{UserPrompt: userPrompt, SystemPrompt: systemPrompt, Images: imgs}, nil +func constructPrompt(command, intent, userPrompt string, imgs []ai.Image, cfg *utils.Config, logger *log.Logger) (*ai.Prompt, error) { + var systemPrompt string + if command == "ask" { + if intent == "" { + return nil, fmt.Errorf("Error: intent is required for 'ask' command") + } + systemPrompt = cfg.Ai.Commands[command].System_prompt + intent + "\n" + } else { + systemPrompt = cfg.Ai.Commands[command].System_prompt + } + logger.Println("\x1b[34mPrompt: |\n", systemPrompt, userPrompt, "\x1b[0m") + return &ai.Prompt{UserPrompt: userPrompt, SystemPrompt: systemPrompt, Images: imgs}, nil } // Initialize AI client func getAIClient(oaiKey string, cfg *utils.Config, logger *log.Logger) (ai.Ai, error) { - switch cfg.Ai.Provider { - case "openai": - if oaiKey == "" { - return nil, fmt.Errorf("Error: Please provide your Open AI API key") - } - logger.Println("Using OpenAI API") - logger.Println("OpenAI model:", cfg.Ai.OpenAI.Model) - return ai.NewOpenAIClient(oaiKey, cfg.Ai.OpenAI.Model), nil - case "vertexai": - logger.Println("Using VertexAI API") - logger.Println("VertexAI model:", cfg.Ai.VertexAI.Model) - aic, err := ai.NewVertexAIClient(cfg.Ai.VertexAI.Project, cfg.Ai.VertexAI.Region, cfg.Ai.VertexAI.Model) - if err != nil { - return nil, fmt.Errorf("Error: new Vertex AI client: %w", err) - } - return aic, nil - default: - return nil, fmt.Errorf("Error: Invalid provider") - } -} \ No newline at end of file + switch cfg.Ai.Provider { + case "openai": + if oaiKey == "" { + return nil, fmt.Errorf("Error: Please provide your Open AI API key") + } + logger.Println("Using OpenAI API") + logger.Println("OpenAI model:", cfg.Ai.OpenAI.Model) + return ai.NewOpenAIClient(oaiKey, cfg.Ai.OpenAI.Model), nil + case "vertexai": + logger.Println("Using VertexAI API") + logger.Println("VertexAI model:", cfg.Ai.VertexAI.Model) + aic, err := ai.NewVertexAIClient(cfg.Ai.VertexAI.Project, cfg.Ai.VertexAI.Region, cfg.Ai.VertexAI.Model) + if err != nil { + return nil, fmt.Errorf("Error: new Vertex AI client: %w", err) + } + return aic, nil + default: + return nil, fmt.Errorf("Error: Invalid provider") + } +} diff --git a/cmd/main_test.go b/cmd/main_test.go index 7ccdad5..61ffa2b 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -1,113 +1,113 @@ package main import ( - "errors" - "testing" + "errors" "log" "os" + "testing" - "github.com/3-shake/alert-menta/internal/ai" - "github.com/3-shake/alert-menta/internal/utils" + "github.com/3-shake/alert-menta/internal/ai" + "github.com/3-shake/alert-menta/internal/utils" ) // Test for validateCommand func TestValidateCommand(t *testing.T) { - mockCfg := &utils.Config{ - Ai: utils.Ai{ - Commands: map[string]utils.Command{ - "valid": {System_prompt: "hoge"}, - }, - }, - } + mockCfg := &utils.Config{ + Ai: utils.Ai{ + Commands: map[string]utils.Command{ + "valid": {System_prompt: "hoge"}, + }, + }, + } - tests := []struct { - command string - expected error - }{ - {"valid", nil}, - {"invalid", errors.New("Invalid command: invalid. Allowed commands are valid")}, - } + tests := []struct { + command string + expected error + }{ + {"valid", nil}, + {"invalid", errors.New("Invalid command: invalid. Allowed commands are valid")}, + } - for _, tt := range tests { - err := validateCommand(tt.command, mockCfg) - if err != nil && err.Error() != tt.expected.Error() { - t.Errorf("expected %v, got %v", tt.expected, err) - } - } + for _, tt := range tests { + err := validateCommand(tt.command, mockCfg) + if err != nil && err.Error() != tt.expected.Error() { + t.Errorf("expected %v, got %v", tt.expected, err) + } + } } // Test for constructPrompt func TestConstructPrompt(t *testing.T) { - mockCfg := &utils.Config{ - Ai: utils.Ai{ - Commands: map[string]utils.Command{ - "ask": {System_prompt: "Ask system prompt: "}, - "other": {System_prompt: "Other system prompt: "}, - }, - }, - } + mockCfg := &utils.Config{ + Ai: utils.Ai{ + Commands: map[string]utils.Command{ + "ask": {System_prompt: "Ask system prompt: "}, + "other": {System_prompt: "Other system prompt: "}, + }, + }, + } - // Logger setup for testing - logger := log.New(os.Stdout, "", 0) + // Logger setup for testing + logger := log.New(os.Stdout, "", 0) - tests := []struct { - name string - command string - intent string - userPrompt string - imgs []ai.Image - expectErr bool - expectedSystemPrompt string - }{ - {"Valid Ask Command with Intent", "ask", "What is the first thing to work on in suggestions?", "userPrompt", []ai.Image{}, false, "Ask system prompt: What is the first thing to work on in suggestions?\n"}, - {"Ask Command without Intent", "ask", "", "userPrompt", []ai.Image{}, true, ""}, - {"Valid Other Command", "other", "", "userPrompt", []ai.Image{}, false, "Other system prompt: "}, - } + tests := []struct { + name string + command string + intent string + userPrompt string + imgs []ai.Image + expectErr bool + expectedSystemPrompt string + }{ + {"Valid Ask Command with Intent", "ask", "What is the first thing to work on in suggestions?", "userPrompt", []ai.Image{}, false, "Ask system prompt: What is the first thing to work on in suggestions?\n"}, + {"Ask Command without Intent", "ask", "", "userPrompt", []ai.Image{}, true, ""}, + {"Valid Other Command", "other", "", "userPrompt", []ai.Image{}, false, "Other system prompt: "}, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prompt, err := constructPrompt(tt.command, tt.intent, tt.userPrompt, tt.imgs, mockCfg, logger) - if (err != nil) != tt.expectErr { - t.Errorf("expected error: %v, got error %v", tt.expectErr, err) - } - if err == nil { - if prompt.SystemPrompt != tt.expectedSystemPrompt { - t.Errorf("expected system prompt: %s, got %s", tt.expectedSystemPrompt, prompt.SystemPrompt) - } - if prompt.UserPrompt != tt.userPrompt { - t.Errorf("expected user prompt: %s, got %s", tt.userPrompt, prompt.UserPrompt) - } - } - }) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prompt, err := constructPrompt(tt.command, tt.intent, tt.userPrompt, tt.imgs, mockCfg, logger) + if (err != nil) != tt.expectErr { + t.Errorf("expected error: %v, got error %v", tt.expectErr, err) + } + if err == nil { + if prompt.SystemPrompt != tt.expectedSystemPrompt { + t.Errorf("expected system prompt: %s, got %s", tt.expectedSystemPrompt, prompt.SystemPrompt) + } + if prompt.UserPrompt != tt.userPrompt { + t.Errorf("expected user prompt: %s, got %s", tt.userPrompt, prompt.UserPrompt) + } + } + }) + } } // Test for getAIClient func TestGetAIClient(t *testing.T) { - mockCfg := &utils.Config{ - Ai: utils.Ai{ - Provider: "openai", - OpenAI: utils.OpenAI{ - Model: "test-model", - }, - }, - } + mockCfg := &utils.Config{ + Ai: utils.Ai{ + Provider: "openai", + OpenAI: utils.OpenAI{ + Model: "test-model", + }, + }, + } - tests := []struct { - oaiKey string - expectErr bool - provider string - }{ - {"valid-key", false, "openai"}, - {"", true, "openai"}, + tests := []struct { + oaiKey string + expectErr bool + provider string + }{ + {"valid-key", false, "openai"}, + {"", true, "openai"}, {"", true, "invalid"}, - } + } - for _, tt := range tests { - mockCfg.Ai.Provider = tt.provider - _, err := getAIClient(tt.oaiKey, mockCfg, log.New(os.Stdout, "", 0)) - if (err != nil) != tt.expectErr { - t.Errorf("expected error: %v, got %v", tt.expectErr, err) - } - } -} \ No newline at end of file + for _, tt := range tests { + mockCfg.Ai.Provider = tt.provider + _, err := getAIClient(tt.oaiKey, mockCfg, log.New(os.Stdout, "", 0)) + if (err != nil) != tt.expectErr { + t.Errorf("expected error: %v, got %v", tt.expectErr, err) + } + } +} diff --git a/internal/ai/vertexai.go b/internal/ai/vertexai.go index f513d27..c8f886c 100644 --- a/internal/ai/vertexai.go +++ b/internal/ai/vertexai.go @@ -2,9 +2,9 @@ package ai import ( "context" + "fmt" "log" "reflect" - "fmt" "cloud.google.com/go/vertexai/genai" ) @@ -24,7 +24,7 @@ func (ai *VertexAI) GetResponse(prompt *Prompt) (string, error) { for _, image := range prompt.Images { integrated_prompt = append(integrated_prompt, genai.ImageData(image.Extension, image.Data)) } - integrated_prompt = append(integrated_prompt, genai.Text(prompt.SystemPrompt + prompt.UserPrompt)) + integrated_prompt = append(integrated_prompt, genai.Text(prompt.SystemPrompt+prompt.UserPrompt)) // Generate AI response resp, err := model.GenerateContent(ai.context, integrated_prompt...) diff --git a/internal/utils/utils.go b/internal/utils/utils.go index e1edc59..e4c5231 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -138,4 +138,4 @@ func DownloadImage(url string, token string) ([]byte, string, error) { func ImageToBase64(data []byte, ext string) string { base64img := base64.StdEncoding.EncodeToString(data) return "data:image/" + ext + ";base64," + base64img -} \ No newline at end of file +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index 3d32b06..4f0c3c4 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -1,14 +1,14 @@ package utils import ( - "os" - "testing" + "os" + "testing" ) // TestNewConfig tests the NewConfig function func TestNewConfig(t *testing.T) { - // Setup: Create a temporary config file - configContent := ` + // Setup: Create a temporary config file + configContent := ` system: debug: log_level: "debug" @@ -21,39 +21,39 @@ ai: description: "Test command" system_prompt: "Prompt" ` - tempFile, err := os.CreateTemp("", "testconfig*.yaml") - if err != nil { - t.Fatalf("Error creating temporary config file: %v", err) - } - defer os.Remove(tempFile.Name()) // Clean up after the test + tempFile, err := os.CreateTemp("", "testconfig*.yaml") + if err != nil { + t.Fatalf("Error creating temporary config file: %v", err) + } + defer os.Remove(tempFile.Name()) // Clean up after the test - if _, err := tempFile.Write([]byte(configContent)); err != nil { - t.Fatalf("Error writing to temporary config file: %v", err) - } - if err := tempFile.Close(); err != nil { - t.Fatalf("Error closing temporary config file: %v", err) - } + if _, err := tempFile.Write([]byte(configContent)); err != nil { + t.Fatalf("Error writing to temporary config file: %v", err) + } + if err := tempFile.Close(); err != nil { + t.Fatalf("Error closing temporary config file: %v", err) + } - // Test: Call NewConfig - cfg, err := NewConfig(tempFile.Name()) - if err != nil { - t.Fatalf("NewConfig returned an error: %v", err) - } + // Test: Call NewConfig + cfg, err := NewConfig(tempFile.Name()) + if err != nil { + t.Fatalf("NewConfig returned an error: %v", err) + } - // Validate: Check if the values are correctly parsed - if cfg.System.Debug.Log_level != "debug" { - t.Errorf("Expected log_level 'debug', got '%s'", cfg.System.Debug.Log_level) - } - if cfg.Ai.Provider != "openai" { - t.Errorf("Expected provider 'openai', got '%s'", cfg.Ai.Provider) - } - if cfg.Ai.OpenAI.Model != "text-davinci-003" { - t.Errorf("Expected model 'text-davinci-003', got '%s'", cfg.Ai.OpenAI.Model) - } - if cfg.Ai.Commands["command1"].Description != "Test command" { - t.Errorf("Expected command description 'Test command', got '%s'", cfg.Ai.Commands["command1"].Description) - } - if cfg.Ai.Commands["command1"].System_prompt != "Prompt" { - t.Errorf("Expected system_prompt 'Prompt', got '%s'", cfg.Ai.Commands["command1"].System_prompt) - } -} \ No newline at end of file + // Validate: Check if the values are correctly parsed + if cfg.System.Debug.Log_level != "debug" { + t.Errorf("Expected log_level 'debug', got '%s'", cfg.System.Debug.Log_level) + } + if cfg.Ai.Provider != "openai" { + t.Errorf("Expected provider 'openai', got '%s'", cfg.Ai.Provider) + } + if cfg.Ai.OpenAI.Model != "text-davinci-003" { + t.Errorf("Expected model 'text-davinci-003', got '%s'", cfg.Ai.OpenAI.Model) + } + if cfg.Ai.Commands["command1"].Description != "Test command" { + t.Errorf("Expected command description 'Test command', got '%s'", cfg.Ai.Commands["command1"].Description) + } + if cfg.Ai.Commands["command1"].System_prompt != "Prompt" { + t.Errorf("Expected system_prompt 'Prompt', got '%s'", cfg.Ai.Commands["command1"].System_prompt) + } +}