Feature/cmd mode (#13)
Changes:
* Cmd Mode - it postprocesses and attempts to execute LLM output in bash (see the sketch below)
* Cleaning and improved error handling
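For orientation before the diff: a minimal sketch of the "postprocess and attempt to execute in bash" idea. The function name and the fence-trimming details are assumptions for illustration; the commit's actual logic lives in handleCmdMode, whose body is collapsed in this view.

```go
package main

import (
	"os"
	"os/exec"
	"strings"
)

// runModelOutput postprocesses a model reply into a bare command string
// (trimming whitespace and stray backtick fences), then attempts to run it
// in bash. Sketch only; not the commit's actual implementation.
func runModelOutput(output string) error {
	cmdStr := strings.Trim(strings.TrimSpace(output), "`")
	cmdStr = strings.TrimSpace(cmdStr)
	c := exec.Command("bash", "-c", cmdStr)
	c.Stdin, c.Stdout, c.Stderr = os.Stdin, os.Stdout, os.Stderr
	return c.Run()
}

func main() {
	// e.g. the model answers "pwd" to "give me a command to show my current directory"
	_ = runModelOutput("pwd")
}
```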
baalimago authored Jun 17, 2024
1 parent 27630ce commit 4044c6c
Showing 13 changed files with 191 additions and 43 deletions.
15 changes: 15 additions & 0 deletions .vscode/launch.json
@@ -78,6 +78,21 @@
"env": {
"NO_COLOR": "true"
}
},
{
"name": "ChatGPT - Chat - Cmd",
"type": "go",
"request": "launch",
"program": "${workspaceFolder}",
"args": [
"-cm",
"gpt-4o",
"cmd",
"give me a command to show my current directory"
],
"env": {
"NO_COLOR": "true"
}
}
]
}
2 changes: 1 addition & 1 deletion go.mod
@@ -2,6 +2,6 @@ module github.com/baalimago/clai

go 1.22

require github.com/baalimago/go_away_boilerplate v1.3.10
require github.com/baalimago/go_away_boilerplate v1.3.11

require golang.org/x/net v0.24.0
6 changes: 2 additions & 4 deletions go.sum
@@ -1,6 +1,4 @@
github.com/baalimago/go_away_boilerplate v1.3.9 h1:oLtCiTNZAU2OM528QKxTsXyWXjTuK1j2YYVLXpyN7JQ=
github.com/baalimago/go_away_boilerplate v1.3.9/go.mod h1:2O+zQ0Zm8vPD5SeccFFlgyf3AnYWQSHAut/ecPMmRdU=
github.com/baalimago/go_away_boilerplate v1.3.10 h1:zID0+yZPimRZxw1XM8KcdBQ+IT9fbAqlA8pfYCe1Qrc=
github.com/baalimago/go_away_boilerplate v1.3.10/go.mod h1:2O+zQ0Zm8vPD5SeccFFlgyf3AnYWQSHAut/ecPMmRdU=
github.com/baalimago/go_away_boilerplate v1.3.11 h1:dM3Hckn55zSb6D9N03+TTSUxFJ6207GsCiS6n8vVZCs=
github.com/baalimago/go_away_boilerplate v1.3.11/go.mod h1:2O+zQ0Zm8vPD5SeccFFlgyf3AnYWQSHAut/ecPMmRdU=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
8 changes: 4 additions & 4 deletions internal/glob/glob.go
@@ -7,6 +7,7 @@ import (
"strings"

"github.com/baalimago/clai/internal/models"
"github.com/baalimago/clai/internal/utils"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/misc"
)
@@ -64,11 +65,10 @@ func constructGlobMessages(globMessages []models.Message) []models.Message {
}

func parseGlob(glob string) ([]models.Message, error) {
home, err := os.UserHomeDir()
if err != nil && strings.Contains(glob, "~/") { // only fail if glob contains ~/ and home dir is not found
return nil, fmt.Errorf("failed to get home dir: %w", err)
glob, err := utils.ReplaceTildeWithHome(glob)
if err != nil {
return nil, fmt.Errorf("parseGlob, ReplaceTildeWithHome: %w", err)
}
glob = strings.Replace(glob, "~", home, 1)
files, err := filepath.Glob(glob)
ret := make([]models.Message, 0, len(files))
if err != nil {
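The removed inline logic above pins down what the new helper must do, so a plausible reconstruction of utils.ReplaceTildeWithHome looks like the following (the actual implementation is not part of this diff):

```go
package utils

import (
	"fmt"
	"os"
	"strings"
)

// ReplaceTildeWithHome swaps a leading "~" for the user's home directory.
// Per the replaced code above, it should only fail when the path actually
// contains "~/" and the home directory cannot be resolved.
func ReplaceTildeWithHome(path string) (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		if strings.Contains(path, "~/") {
			return "", fmt.Errorf("failed to get home dir: %w", err)
		}
		return path, nil // no tilde to expand, so the error is irrelevant
	}
	return strings.Replace(path, "~", home, 1), nil
}
```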
17 changes: 14 additions & 3 deletions internal/setup.go
@@ -15,6 +15,8 @@ import (
"github.com/baalimago/clai/internal/text"
"github.com/baalimago/clai/internal/utils"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
imagodebug "github.com/baalimago/go_away_boilerplate/pkg/debug"

"github.com/baalimago/go_away_boilerplate/pkg/misc"
)

@@ -33,6 +35,7 @@ const (
PHOTO
VERSION
SETUP
CMD
)

var defaultFlags = Configurations{
@@ -66,6 +69,8 @@ func getModeFromArgs(cmd string) (Mode, error) {
return SETUP, nil
case "version", "v":
return VERSION, nil
case "cmd":
return CMD, nil
default:
return HELP, fmt.Errorf("unknown command: '%s'", os.Args[1])
}
@@ -81,13 +86,19 @@ func setupTextQuerier(mode Mode, confDir string, flagSet Configurations) (models
if mode == CHAT {
tConf.ChatMode = true
}

if mode == CMD {
tConf.CmdMode = true
tConf.SystemPrompt = tConf.CmdModePrompt
}

// At the moment, the configurations are based on the config file. But
// the configuration precedence is flags > file > default. So, we need
// to re-apply the flag overrides to the configuration
applyFlagOverridesForText(&tConf, flagSet, defaultFlags)

if misc.Truthy(os.Getenv("DEBUG")) {
ancli.PrintOK(fmt.Sprintf("config post flag override: %+v\n", tConf))
ancli.PrintOK(fmt.Sprintf("config post flag override: %+v\n", imagodebug.IndentedJsonFmt(tConf)))
}
args := flag.Args()
if mode == GLOB || flagSet.Glob != "" {
@@ -133,7 +144,7 @@ func Setup(usage string) (models.Querier, error) {
}

switch mode {
case CHAT, QUERY, GLOB:
case CHAT, QUERY, GLOB, CMD:
return setupTextQuerier(mode, confDir, flagSet)
case PHOTO:
pConf, err := utils.LoadConfigFromFile(confDir, "photoConfig.json", migrateOldPhotoConfig, &photo.DEFAULT)
@@ -153,7 +164,7 @@
}
pq, err := NewPhotoQuerier(pConf)
if misc.Truthy(os.Getenv("DEBUG")) {
ancli.PrintOK(fmt.Sprintf("photo querier: %+v\n", pq))
ancli.PrintOK(fmt.Sprintf("photo querier: %+v\n", imagodebug.IndentedJsonFmt(pq)))
}
if err != nil {
return nil, fmt.Errorf("failed to create photo querier: %v", err)
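The comment in setupTextQuerier spells out the precedence flags > file > default. applyFlagOverridesForText itself is collapsed in this view; a minimal sketch of that kind of override, with invented names and an assumed empty-string flag default, could look like:

```go
package main

import "fmt"

// overrideString lets a flag value win over the file-loaded value whenever
// the flag differs from its default, yielding flags > file > default.
// Illustrative only; not the project's actual applyFlagOverridesForText.
func overrideString(fileVal *string, flagVal, defaultVal string) {
	if flagVal != defaultVal {
		*fileVal = flagVal
	}
}

func main() {
	model := "gpt-4o"                         // loaded from textConfig.json
	overrideString(&model, "gpt-4-turbo", "") // hypothetical flag value
	fmt.Println(model)                        // "gpt-4-turbo": the flag wins
}
```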
22 changes: 17 additions & 5 deletions internal/text/conf.go
@@ -10,13 +10,15 @@ import (
"github.com/baalimago/clai/internal/reply"
"github.com/baalimago/clai/internal/utils"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/debug"
"github.com/baalimago/go_away_boilerplate/pkg/misc"
)

// Configurations used to setup the requirements of text models
type Configurations struct {
Model string `json:"model"`
SystemPrompt string `json:"system-prompt"`
CmdModePrompt string `json:"cmd-mode-prompt"`
Raw bool `json:"raw"`
UseTools bool `json:"use-tools"`
TokenWarnLimit int `json:"token-warn-limit"`
@@ -25,17 +27,19 @@
Stream bool `json:"-"`
ReplyMode bool `json:"-"`
ChatMode bool `json:"-"`
CmdMode bool `json:"-"`
Glob string `json:"-"`
InitialPrompt models.Chat `json:"-"`
// PostProccessedPrompt which has had its strings replaced etc
PostProccessedPrompt string `json:"-"`
}

var DEFAULT = Configurations{
Model: "gpt-4o",
SystemPrompt: "You are an assistant for a CLI interface. Answer concisely and informatively. Prefer markdown if possible.",
Raw: false,
UseTools: false,
Model: "gpt-4o",
SystemPrompt: "You are an assistant for a CLI tool. Answer concisely and informatively. Prefer markdown if possible.",
CmdModePrompt: "You are an assistant for a CLI tool aiding with cli tool suggestions. Write ONLY the command and nothing else.",
Raw: false,
UseTools: false,
// Approximately $1 for the worst input rates as of 2024-05
TokenWarnLimit: 17000,
}
@@ -45,6 +49,8 @@ func (c *Configurations) SetupPrompts(args []string) error {
ancli.PrintWarn("Using glob + reply modes together might yield strange results. The prevQuery will be appended after the glob messages.\n")
}

// Always replace the system prompt in cmd mode. This somewhat corrupts the chat, since it will always
// be the command prompt, but it's better than not having it
if !c.ReplyMode {
c.InitialPrompt = models.Chat{
Messages: []models.Message{
@@ -69,6 +75,12 @@ func (c *Configurations) SetupPrompts(args []string) error {
return fmt.Errorf("failed to load previous query: %w", err)
}
c.InitialPrompt.Messages = append(c.InitialPrompt.Messages, iP.Messages...)

if c.CmdMode {
// Replace the initial message with the cmd prompt. This sort of
// destroys the history, but since the conversation might be long it's fine
c.InitialPrompt.Messages[0].Content = c.SystemPrompt
}
}

prompt, err := utils.Prompt(c.StdinReplace, args)
@@ -84,7 +96,7 @@
}

if misc.Truthy(os.Getenv("DEBUG")) {
ancli.PrintOK(fmt.Sprintf("InitialPrompt: %+v\n", c.InitialPrompt))
ancli.PrintOK(fmt.Sprintf("InitialPrompt: %v\n", debug.IndentedJsonFmt(c.InitialPrompt)))
}
c.PostProccessedPrompt = prompt
if c.InitialPrompt.ID == "" {
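Given the json tags and the DEFAULT values above, a textConfig.json written from these defaults would plausibly look like this (illustrative; fields tagged json:"-" never reach the file, and collapsed fields may add further keys):

```json
{
  "model": "gpt-4o",
  "system-prompt": "You are an assistant for a CLI tool. Answer concisely and informatively. Prefer markdown if possible.",
  "cmd-mode-prompt": "You are an assistant for a CLI tool aiding with cli tool suggestions. Write ONLY the command and nothing else.",
  "raw": false,
  "use-tools": false,
  "token-warn-limit": 17000
}
```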
3 changes: 2 additions & 1 deletion internal/text/generic/stream_completer.go
@@ -13,6 +13,7 @@ import (
"github.com/baalimago/clai/internal/models"
"github.com/baalimago/clai/internal/tools"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/debug"
"github.com/baalimago/go_away_boilerplate/pkg/misc"
)

@@ -64,7 +65,7 @@ func (s *StreamCompleter) createRequest(ctx context.Context, chat models.Chat) (
reqData.ToolChoice = s.ToolChoice
}
if s.debug {
ancli.PrintOK(fmt.Sprintf("generic streamcompleter request: %+v\n", reqData))
ancli.PrintOK(fmt.Sprintf("generic streamcompleter request: %v\n", debug.IndentedJsonFmt(reqData)))
}
jsonData, err := json.Marshal(reqData)
if err != nil {
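This commit swaps %+v for debug.IndentedJsonFmt in several hunks. The helper's source is not part of the diff; given its call shape, it is plausibly a thin wrapper over json.MarshalIndent, something like this (an assumption, not the library's actual code):

```go
package debug

import "encoding/json"

// IndentedJsonFmt renders any value as indented JSON for readable debug
// output. Sketch of the assumed behavior of the go_away_boilerplate helper.
func IndentedJsonFmt(v any) string {
	b, err := json.MarshalIndent(v, "", "  ")
	if err != nil {
		return err.Error()
	}
	return string(b)
}
```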
75 changes: 52 additions & 23 deletions internal/text/querier.go
@@ -16,6 +16,7 @@ import (
"github.com/baalimago/clai/internal/tools"
"github.com/baalimago/clai/internal/utils"
"github.com/baalimago/go_away_boilerplate/pkg/ancli"
"github.com/baalimago/go_away_boilerplate/pkg/debug"
)

const (
@@ -38,32 +39,15 @@ type Querier[C models.StreamCompleter] struct {
hasPrinted bool
Model C
tokenWarnLimit int
cmdMode bool
}

// Query using the underlying model to stream completions and then print the output
// from the model to stdout. Blocking operation.
func (q *Querier[C]) Query(ctx context.Context) error {
amTokens := q.countTokens()
if q.tokenWarnLimit > 0 && amTokens > q.tokenWarnLimit {
ancli.PrintWarn(
fmt.Sprintf("You're about to send: ~%v tokens to the model, which may amount to: ~$%.3f (applying worst input rates as of 2024-05). This limit may be changed in: '%v'. Do you wish to continue? [yY]: ",
amTokens,
// Worst rates found at 2024-05 were gpt-4-32k at $60 per 1M tokens
float64(amTokens)*(float64(60)/float64(1000000)),
path.Join(q.configDir, "textConfig.json"),
))
var userInput string
reader := bufio.NewReader(os.Stdin)
userInput, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read user input: %w", err)
}
switch userInput {
case "y\n", "Y\n":
// Continue on y or Y
default:
return errors.New("query canceled due to token amount check")
}
err := q.tokenLengthWarning()
if err != nil {
return fmt.Errorf("Querier.Query: %w", err)
}
completionsChan, err := q.Model.StreamCompletions(ctx, q.chat)
if err != nil {
@@ -106,6 +90,32 @@ func (q *Querier[C]) Query(ctx context.Context) error {
}
}

func (q *Querier[C]) tokenLengthWarning() error {
amTokens := q.countTokens()
if q.tokenWarnLimit > 0 && amTokens > q.tokenWarnLimit {
ancli.PrintWarn(
fmt.Sprintf("You're about to send: ~%v tokens to the model, which may amount to: ~$%.3f (applying worst input rates as of 2024-05). This limit may be changed in: '%v'. Do you wish to continue? [yY]: ",
amTokens,
// Worst rates found at 2024-05 were gpt-4-32k at $60 per 1M tokens
float64(amTokens)*(float64(60)/float64(1000000)),
path.Join(q.configDir, "textConfig.json"),
))
var userInput string
reader := bufio.NewReader(os.Stdin)
userInput, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read user input: %w", err)
}
switch userInput {
case "y\n", "Y\n":
// Continue on y or Y
default:
return errors.New("Querier.tokenLengthWarning: query canceled due to token amount check")
}
}
return nil
}

// countTokens by simply counting the amount of strings which are delimited by whitespace
// and multiply by some factor. This factor is somewhat arbitrary, and adjusted to be good enough
// for all the different models
@@ -139,6 +149,23 @@ func (q *Querier[C]) postProcess() {
}
}

if q.debug {
ancli.PrintOK(fmt.Sprintf("Querier.postProcess:\n%v\n", debug.IndentedJsonFmt(q)))
}

// Cmd mode is a bit of a hack; it will handle all output
if q.cmdMode {
err := q.handleCmdMode()
if err != nil {
ancli.PrintErr(fmt.Sprintf("Querier.postProcess: %v\n", err))
}
return
}

q.postProcessOutput(newSysMsg)
}

func (q *Querier[C]) postProcessOutput(newSysMsg models.Message) {
// The token should already have been printed while streamed
if q.Raw {
return
Expand All @@ -163,13 +190,15 @@ func (q *Querier[C]) reset() {
func (q *Querier[C]) TextQuery(ctx context.Context, chat models.Chat) (models.Chat, error) {
q.reset()
q.chat = chat
// Query will update the chat with the latest system message
err := q.Query(ctx)
if err != nil {
return models.Chat{}, fmt.Errorf("failed to query: %w", err)
return models.Chat{}, fmt.Errorf("TextQuery: %w", err)
}
if q.debug {
ancli.PrintOK(fmt.Sprintf("chat: %v", q.chat))
ancli.PrintOK(fmt.Sprintf("Querier.TextQuery:\n%v", debug.IndentedJsonFmt(q)))
}

return q.chat, nil
}

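The doc comment on countTokens above describes a whitespace-split-times-factor heuristic, and the warning applies the worst 2024-05 input rate of $60 per 1M tokens. A sketch under those assumptions (the factor value here is invented), including the arithmetic behind the roughly $1 default limit:

```go
package main

import (
	"fmt"
	"strings"
)

// approxTokens counts whitespace-delimited fields and scales by a factor,
// per countTokens' doc comment. The 1.5 multiplier is an assumption; the
// real value is collapsed in the diff above.
func approxTokens(text string) int {
	const factor = 1.5
	return int(float64(len(strings.Fields(text))) * factor)
}

func main() {
	fmt.Println(approxTokens("ls -la | grep foo")) // 4 fields -> 6 tokens
	// Default warn limit of 17000 tokens at $60 per 1M tokens:
	fmt.Printf("$%.2f\n", 17000*(60.0/1_000_000)) // $1.02, i.e. the ~$1 noted above
}
```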