package main

import (
	"bufio"
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"

	"github.com/alecthomas/chroma"
	"github.com/alecthomas/chroma/formatters"
	"github.com/alecthomas/chroma/lexers"
	"github.com/alecthomas/chroma/styles"
	openai "github.com/sashabaranov/go-openai"
)

// formatCode renders a code snippet to stdout with ANSI syntax highlighting.
// Unknown languages fall back to a plain-text lexer; tokenise/format errors
// are reported on stderr and the snippet is simply not printed.
func formatCode(code, lang string) {
	lexer := lexers.Get(lang)
	if lexer == nil {
		lexer = lexers.Fallback
	}
	lexer = chroma.Coalesce(lexer)

	style := styles.Get("solarized-dark")
	if style == nil {
		style = styles.Fallback
	}

	formatter := formatters.Get("terminal256")
	if formatter == nil {
		formatter = formatters.Fallback
	}

	iterator, err := lexer.Tokenise(nil, code)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		return
	}
	if err := formatter.Format(os.Stdout, style, iterator); err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
	}
}

// processChar is a state machine that scans a streamed response one rune at a
// time, echoing plain text immediately and buffering fenced code blocks
// (```lang ... ```) so the whole block can be highlighted via formatCode once
// the closing fence arrives.
//
// States:
//
//	0    plain text
//	1,2  saw one / two opening backticks
//	3    reading the language tag after ``` (until newline)
//	4    inside the code body
//	5,6  saw one / two closing backticks inside the body
func processChar(char rune, state *int, buffer *string, code *string, lang *string) {
	switch *state {
	case 0: // plain text
		if char == '`' {
			*state = 1
		} else {
			fmt.Print(string(char))
		}
	case 1: // one opening backtick seen
		if char == '`' {
			*state = 2
		} else {
			// Not a fence after all — flush the withheld backtick.
			*state = 0
			fmt.Print("`", string(char))
		}
	case 2: // two opening backticks seen
		if char == '`' {
			*state = 3
			*buffer = ""
		} else {
			*state = 0
			fmt.Print("``", string(char))
		}
	case 3: // reading the language tag
		if char == '\n' {
			*state = 4
			*lang = *buffer
			*buffer = ""
		} else {
			*buffer += string(char)
		}
	case 4: // accumulating the code body
		if char == '`' {
			*state = 5
		} else {
			*code += string(char)
		}
	case 5: // one closing backtick seen
		if char == '`' {
			*state = 6
		} else {
			// False alarm — the backtick belongs to the code itself.
			*state = 4
			*code += "`" + string(char)
		}
	case 6: // two closing backticks seen
		if char == '`' {
			// Fence closed: highlight the collected block and reset.
			formatCode(*code, *lang)
			*code = ""
			*lang = ""
			*state = 0
		} else {
			*state = 4
			*code += "``" + string(char)
		}
	}
}

// contains reports whether target is present in array.
func contains(array []string, target string) bool {
	for _, str := range array {
		if str == target {
			return true
		}
	}
	return false
}

// validModels is the set of model names accepted via OPEN_AI_MODEL.
var validModels = []string{
	"gpt-4-32k-0314",
	"gpt-4-32k",
	"gpt-4-0314",
	"gpt-4",
	"gpt-3.5-turbo-0301",
	"gpt-3.5-turbo",
	"mistralai/Mixtral-8x7B-Instruct-v0.1",
}

// newClient builds an OpenAI API client, honoring an optional custom base URL
// for self-hosted / proxy endpoints.
func newClient(bearer, customURL string) *openai.Client {
	if customURL != "" {
		config := openai.DefaultConfig(bearer)
		config.BaseURL = customURL
		return openai.NewClientWithConfig(config)
	}
	return openai.NewClient(bearer)
}

// readInput reads a single user message from reader. A newline normally ends
// the message, but newlines inside a triple-backtick code fence are kept so
// multi-line code can be pasted. On EOF any partial input is returned first;
// a subsequent call yields io.EOF.
func readInput(reader *bufio.Reader) (string, error) {
	var buf bytes.Buffer
	inCodeBlock := false
	backticks := 0
	for {
		ch, _, err := reader.ReadRune()
		if err != nil {
			if err == io.EOF && buf.Len() > 0 {
				return buf.String(), nil
			}
			return "", err
		}
		if ch == '`' {
			backticks++
			if backticks == 3 {
				inCodeBlock = !inCodeBlock
				backticks = 0
			}
		} else {
			backticks = 0
		}
		buf.WriteRune(ch)
		if !inCodeBlock && ch == '\n' {
			return buf.String(), nil
		}
	}
}

// printStream renders a completion stream to the terminal, highlighting
// fenced code blocks via processChar, and returns the full concatenated
// response text so the caller can feed it back into the chat context.
func printStream(stream *openai.ChatCompletionStream) (string, error) {
	state := 0
	buffer := ""
	code := ""
	lang := ""
	var full strings.Builder
	for {
		response, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return full.String(), nil
		}
		if err != nil {
			return full.String(), err
		}
		delta := response.Choices[0].Delta.Content
		for _, char := range delta {
			processChar(char, &state, &buffer, &code, &lang)
		}
		full.WriteString(delta)
	}
}

// runChat runs the interactive REPL: read a message, stream the reply, and
// keep the running conversation in messages. Typing "exit" (or EOF) quits.
func runChat(ctx context.Context, client *openai.Client, model string, maxTokens int, temperature float32) {
	messages := make([]openai.ChatCompletionMessage, 0)
	reader := bufio.NewReader(os.Stdin)

	fmt.Println()
	fmt.Println("Begin Conversation")
	fmt.Println("------------------")

	for {
		fmt.Print("-> ")
		text, err := readInput(reader)
		if err != nil {
			// EOF (e.g. Ctrl-D) ends the chat cleanly instead of looping
			// forever on an exhausted stdin.
			if err != io.EOF {
				fmt.Println("Error reading from stdin:", err)
			}
			return
		}
		// Drop only the trailing line ending; newlines inside pasted code
		// fences are preserved (stripping all of them would mangle the code
		// block the reader just collected).
		text = strings.TrimRight(text, "\r\n")
		if text == "exit" {
			return
		}

		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleUser,
			Content: text,
		})

		stream, err := client.CreateChatCompletionStream(
			ctx,
			openai.ChatCompletionRequest{
				Model:       model,
				MaxTokens:   maxTokens,
				Temperature: temperature,
				Messages:    messages,
				Stream:      true,
			},
		)
		if err != nil {
			fmt.Printf("ChatCompletion error: %v\n", err)
			continue
		}

		reply, err := printStream(stream)
		// Close per turn; a defer here would pile up until program exit.
		stream.Close()
		if err != nil {
			fmt.Printf("\nStream error: %v\n", err)
			return
		}
		fmt.Println() // spacing after the response
		fmt.Println()

		// Feed only this turn's reply back into the context (not an
		// ever-growing concatenation of every previous reply).
		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleAssistant,
			Content: reply,
		})
	}
}

// runOneShot sends a single prompt with a fixed system message and streams the
// highlighted reply to the terminal.
func runOneShot(ctx context.Context, client *openai.Client, model string, maxTokens int, temperature float32, prompt string) {
	const systemPrompt = "You are a super powerful AI assistant. Answer all queries as concisely as possible and try to think through each response step-by-step."

	req := openai.ChatCompletionRequest{
		Model:       model,
		MaxTokens:   maxTokens,
		Temperature: temperature,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleSystem,
				Content: systemPrompt,
			},
			{
				Role:    openai.ChatMessageRoleUser,
				Content: prompt,
			},
		},
		Stream: true,
	}

	stream, err := client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		fmt.Printf("ChatCompletionStream error: %v\n", err)
		return
	}
	defer stream.Close()

	fmt.Println()
	if _, err := printStream(stream); err != nil {
		fmt.Printf("\nStream error: %v\n", err)
		return
	}
	fmt.Println() // spacing after the response
}

func main() {
	// Configure the API client from environment variables. All configuration
	// problems are reported before exiting so the user can fix them in one
	// pass; previously the program carried on with zero values.
	ok := true

	bearer := os.Getenv("OPEN_AI_TOKEN")
	if bearer == "" {
		fmt.Println("You need to set OPEN_AI_TOKEN environment variable.")
		ok = false
	}

	model := os.Getenv("OPEN_AI_MODEL")
	if model == "" {
		fmt.Println("You need to set OPEN_AI_MODEL environment variable.")
		ok = false
	} else if !contains(validModels, model) {
		// Generate the list from validModels so the message can never drift
		// out of sync with the accepted set again.
		fmt.Println("You have specified an invalid model... Please select from: " + strings.Join(validModels, ", "))
		ok = false
	}

	preTemp := os.Getenv("OPEN_AI_TEMP")
	if preTemp == "" {
		fmt.Println("You need to set OPEN_AI_TEMP environment variable.")
		ok = false
	}
	temperature, err := strconv.ParseFloat(preTemp, 32)
	if err != nil {
		fmt.Println(err)
		ok = false
	}

	preMaxTokens := os.Getenv("OPEN_AI_MAX_TOKENS")
	if preMaxTokens == "" {
		fmt.Println("You need to set OPEN_AI_MAX_TOKENS environment variable.")
		ok = false
	}
	maxTokens, err := strconv.Atoi(preMaxTokens)
	if err != nil {
		fmt.Println(err)
		ok = false
	}

	customURL := os.Getenv("OPEN_AI_CUSTOM_URL")
	if customURL == "" {
		fmt.Println("If you want to use a personal endpoint, set OPEN_AI_CUSTOM_URL")
	}

	if !ok {
		os.Exit(1)
	}

	if len(os.Args) < 2 {
		fmt.Println()
		fmt.Println("No command line argument supplied!")
		fmt.Println("Please use: --chat or -c for chat mode. Type \"exit\" to exit the chat.")
		fmt.Println("Please use: \"Double Quotes\" or 'Single Quotes' for a one-off command.")
		return
	}

	client := newClient(bearer, customURL)
	ctx := context.Background()

	if os.Args[1] == "--chat" || os.Args[1] == "-c" {
		runChat(ctx, client, model, maxTokens, float32(temperature))
		return
	}
	runOneShot(ctx, client, model, maxTokens, float32(temperature), os.Args[1])
}