Skip to content

Commit

Permalink
fix: config init writes unnecessary values to config file (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbckr authored Oct 9, 2023
1 parent fa01a83 commit 1acacb6
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 44 deletions.
3 changes: 2 additions & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ tasks:
ci:
desc: Run all CI steps
cmds:
- task: setup
- task: clean
- task: prepare
- task: build
- task: test

Expand Down
11 changes: 5 additions & 6 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import (
"os"
"strings"

"github.com/tbckr/sgpt/v2/chat"

"github.com/sashabaranov/go-openai"
"github.com/spf13/viper"
"github.com/tbckr/sgpt/v2/chat"
"github.com/tbckr/sgpt/v2/modifiers"
)

Expand Down Expand Up @@ -85,7 +86,7 @@ func CreateClient() (*OpenAIClient, error) {
return client, nil
}

func (c *OpenAIClient) GetChatCompletion(ctx context.Context, config *viper.Viper, prompt, modifier string) (string, error) {
func (c *OpenAIClient) GetChatCompletion(ctx context.Context, config *viper.Viper, chatID, prompt, modifier string) (string, error) {
var err error
var chatSessionManager chat.SessionManager
var messages []openai.ChatCompletionMessage
Expand All @@ -95,10 +96,8 @@ func (c *OpenAIClient) GetChatCompletion(ctx context.Context, config *viper.Vipe
return "", err
}

var chatID string
var isChat bool
if config.IsSet("chat") {
chatID = config.GetString("chat")
isChat := false
if chatID != "" {
isChat = true
}
chatExists := false
Expand Down
12 changes: 4 additions & 8 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestSimplePrompt(t *testing.T) {
config := createTestConfig(t)

var result string
result, err = client.GetChatCompletion(context.Background(), config, prompt, "txt")
result, err = client.GetChatCompletion(context.Background(), config, "", prompt, "txt")
require.NoError(t, err)
require.Equal(t, expected, result)

Expand All @@ -117,10 +117,8 @@ func TestPromptSaveAsChat(t *testing.T) {
require.NoError(t, err)
config := createTestConfig(t)

config.Set("chat", "test_chat")

var result string
result, err = client.GetChatCompletion(context.Background(), config, prompt, "txt")
result, err = client.GetChatCompletion(context.Background(), config, "test_chat", prompt, "txt")
require.NoError(t, err)
require.Equal(t, expected, result)

Expand Down Expand Up @@ -152,8 +150,6 @@ func TestPromptLoadChat(t *testing.T) {
require.NoError(t, err)
config := createTestConfig(t)

config.Set("chat", "test_chat")

var manager chat.SessionManager
manager, err = chat.NewFilesystemChatSessionManager(config)
require.NoError(t, err)
Expand All @@ -171,7 +167,7 @@ func TestPromptLoadChat(t *testing.T) {
require.NoError(t, err)

var result string
result, err = client.GetChatCompletion(context.Background(), config, prompt, "txt")
result, err = client.GetChatCompletion(context.Background(), config, "test_chat", prompt, "txt")
require.NoError(t, err)
require.Equal(t, expected, result)

Expand Down Expand Up @@ -206,7 +202,7 @@ func TestPromptWithModifier(t *testing.T) {
config.Set("chat", "test_chat")

var result string
result, err = client.GetChatCompletion(context.Background(), config, prompt, "sh")
result, err = client.GetChatCompletion(context.Background(), config, "test_chat", prompt, "sh")
require.NoError(t, err)
require.Equal(t, expected, result)

Expand Down
7 changes: 7 additions & 0 deletions cli/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ func TestConfigCmdInit(t *testing.T) {
require.Equal(t, 0, mem.code)

require.FileExists(t, filepath.Join(configDir, "config.yaml"))
// config must only contain values for model, maxtokens, temperature, topp
require.NoError(t, config.ReadInConfig())
// TESTING may be in the config, because this is a test
require.Equal(t, 5, len(config.AllSettings()))
for _, key := range []string{"model", "maxtokens", "temperature", "topp", "testing"} {
require.Contains(t, config.AllSettings(), key)
}
}

func TestConfigCmdInitAlreadyExists(t *testing.T) {
Expand Down
49 changes: 20 additions & 29 deletions cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ import (
)

type rootCmd struct {
cmd *cobra.Command
exit func(int)
cmd *cobra.Command
exit func(int)

chat string
execute bool
copyToClipboard bool

verbose bool
}

Expand Down Expand Up @@ -209,7 +214,7 @@ ls | sort
}

var response string
response, err = client.GetChatCompletion(cmd.Context(), config, prompt, mode)
response, err = client.GetChatCompletion(cmd.Context(), config, root.chat, prompt, mode)
if err != nil {
return err
}
Expand All @@ -218,7 +223,7 @@ ls | sort
return err
}

if config.GetBool("clipboard") {
if root.copyToClipboard {
slog.Debug("Sending client response to clipboard")
err = clipboard.WriteAll(response)
if err != nil {
Expand All @@ -227,19 +232,26 @@ ls | sort
}
}

if config.GetBool("execute") {
if root.execute {
slog.Debug("Trying to execute response in shell")
return shell.ExecuteCommandWithConfirmation(cmd.Context(), cmd.InOrStdin(), cmd.OutOrStdout(), response)
}
return nil
},
}

// flags
cmd.Flags().BoolVarP(&root.execute, "execute", "e", false, "execute a response in the shell")
cmd.Flags().BoolVarP(&root.copyToClipboard, "clipboard", "b", false, "send client response to clipboard")
cmd.Flags().StringVarP(&root.chat, "chat", "c", "", "use an existing chat session or create a new one")

// flags with config binding
createFlagsWithConfigBinding(cmd, config)

// verbose persistent flag
cmd.PersistentFlags().BoolVarP(&root.verbose, "verbose", "v", false,
"enable more verbose output for debugging")

createFlags(cmd, config)

cmd.AddCommand(
newChatCmd(config).cmd,
newCheckCmd(config, createClientFn).cmd,
Expand All @@ -253,7 +265,7 @@ ls | sort
return root
}

func createFlags(cmd *cobra.Command, config *viper.Viper) {
func createFlagsWithConfigBinding(cmd *cobra.Command, config *viper.Viper) {
var bindErrors []error
var err error
// text based commands
Expand Down Expand Up @@ -281,27 +293,6 @@ func createFlags(cmd *cobra.Command, config *viper.Viper) {
bindErrors = append(bindErrors, err)
}

// shell command
cmd.Flags().BoolP("execute", "e", false, "execute a response in the shell")
err = config.BindPFlag("execute", cmd.Flags().Lookup("execute"))
if err != nil {
bindErrors = append(bindErrors, err)
}

// clipboard flags
cmd.Flags().BoolP("clipboard", "b", false, "send client response to clipboard")
err = config.BindPFlag("clipboard", cmd.Flags().Lookup("clipboard"))
if err != nil {
bindErrors = append(bindErrors, err)
}

// chat flags
cmd.Flags().StringP("chat", "c", "", "use an existing chat session or create a new one")
err = config.BindPFlag("chat", cmd.Flags().Lookup("chat"))
if err != nil {
bindErrors = append(bindErrors, err)
}

if len(bindErrors) > 0 {
for _, err = range bindErrors {
slog.Error("Failed to bind flag to viper", "error", err)
Expand Down

0 comments on commit 1acacb6

Please sign in to comment.