diff --git a/internal/config/config.go b/internal/config/config.go index 6b324a0..171af60 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,9 +1,14 @@ package config -import "github.com/rs/zerolog" +import ( + "github.com/clevyr/yampl/internal/config/flag" + "github.com/rs/zerolog" +) type Config struct { - Values Values + valuesStringToString *flag.StringToString + Values Values + Inplace bool Recursive bool Prefix string @@ -23,7 +28,9 @@ type Config struct { func New() *Config { return &Config{ - Values: make(Values), + valuesStringToString: &flag.StringToString{}, + Values: make(Values), + Prefix: "#yampl", LeftDelim: "{{", RightDelim: "}}", diff --git a/internal/config/flag/string_to_string.go b/internal/config/flag/string_to_string.go new file mode 100644 index 0000000..d525e10 --- /dev/null +++ b/internal/config/flag/string_to_string.go @@ -0,0 +1,99 @@ +package flag + +import ( + "bytes" + "encoding/csv" + "errors" + "fmt" + "io" + "maps" + "strings" +) + +var ErrStringToStringFormat = errors.New("must be formatted as key=value") + +type StringToString struct { + value map[string]string + changed bool +} + +// Set Format: a=1,b=2 +func (s *StringToString) Set(val string) error { + val = strings.TrimSpace(val) + count := strings.Count(val, "=") + records := make([]string, 0, count) + switch count { + case 0: + return fmt.Errorf("%s %w", val, ErrStringToStringFormat) + case 1: + records = append(records, val) + default: + r := csv.NewReader(strings.NewReader(val)) + r.TrimLeadingSpace = true + for { + line, err := r.Read() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + + r.FieldsPerRecord = 0 // Prevent wrong number of fields error + + for _, v := range line { + switch { + case strings.Contains(v, "="): + records = append(records, v) + case len(records) != 0: + records[len(records)-1] += "\n" + v + default: + return fmt.Errorf("%s %w", v, ErrStringToStringFormat) + } + } + } + } + + result := make(map[string]string, len(records)) + for _, pair := range records { + kv := strings.SplitN(pair, "=", 2) + if len(kv) != 2 { + return fmt.Errorf("%s %w", pair, ErrStringToStringFormat) + } + result[kv[0]] = kv[1] + } + + if s.changed { + for k, v := range result { + s.value[k] = v + } + } else { + s.changed = true + s.value = result + } + + return nil +} + +func (s *StringToString) Type() string { + return "stringToString" +} + +func (s *StringToString) String() string { + records := make([]string, 0, len(s.value)) + for k, v := range s.value { + records = append(records, k+"="+v) + } + + var buf bytes.Buffer + w := csv.NewWriter(&buf) + if err := w.Write(records); err != nil { + panic(err) + } + w.Flush() + return "[" + strings.TrimSpace(buf.String()) + "]" +} + +func (s *StringToString) Values() map[string]string { + return maps.Clone(s.value) +} diff --git a/internal/config/flag/string_to_string_test.go b/internal/config/flag/string_to_string_test.go new file mode 100644 index 0000000..80aa0c3 --- /dev/null +++ b/internal/config/flag/string_to_string_test.go @@ -0,0 +1,111 @@ +package flag + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStringToString_Set(t *testing.T) { + type args struct { + val string + } + tests := []struct { + name string + args args + want *StringToString + wantErr require.ErrorAssertionFunc + }{ + { + "one value", + args{"a=b"}, + &StringToString{value: map[string]string{"a": "b"}, changed: true}, + require.NoError, + }, + { + "two values", + args{"a=b,c=d"}, + &StringToString{value: map[string]string{"a": "b", "c": "d"}, changed: true}, + require.NoError, + }, + { + "multiline value", + args{"a=b\nc,d=e"}, + &StringToString{value: map[string]string{"a": "b\nc", "d": "e"}, changed: true}, + require.NoError, + }, + { + "multiline values", + args{"a=b\nc=d"}, + &StringToString{value: map[string]string{"a": "b", "c": "d"}, changed: true}, + require.NoError, + }, + { + "multiple newlines", + args{"a=b\n\nc=d"}, + &StringToString{value: map[string]string{"a": "b", "c": "d"}, changed: true}, + require.NoError, + }, + { + "trim spaces", + args{"a=b\n c=d"}, + &StringToString{value: map[string]string{"a": "b", "c": "d"}, changed: true}, + require.NoError, + }, + { + "newline around values", + args{"\na=b\nc=d\n"}, + &StringToString{value: map[string]string{"a": "b", "c": "d"}, changed: true}, + require.NoError, + }, + { + "json value", + args{"a=[1]"}, + &StringToString{value: map[string]string{"a": "[1]"}, changed: true}, + require.NoError, + }, + { + "json values", + args{"a=[1, 2, 3]"}, + &StringToString{value: map[string]string{"a": "[1, 2, 3]"}, changed: true}, + require.NoError, + }, + {"error empty", args{""}, &StringToString{}, require.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &StringToString{} + tt.wantErr(t, s.Set(tt.args.val)) + assert.Equal(t, tt.want, s) + }) + } + + t.Run("consecutive", func(t *testing.T) { + s := &StringToString{} + require.NoError(t, s.Set("a=b")) + assert.True(t, s.changed) + assert.Equal(t, map[string]string{"a": "b"}, s.value) + require.NoError(t, s.Set("c=d")) + assert.True(t, s.changed) + assert.Equal(t, map[string]string{"a": "b", "c": "d"}, s.value) + }) +} + +func TestStringToString_String(t *testing.T) { + tests := []struct { + name string + value *StringToString + want string + }{ + {"empty", &StringToString{}, "[]"}, + {"simple value", &StringToString{value: map[string]string{"a": "b"}, changed: true}, "[a=b]"}, + {"value with comma", &StringToString{value: map[string]string{"a": "b,c"}, changed: true}, `["a=b,c"]`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.value.String() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/config/flags.go b/internal/config/flags.go index 223979a..007dc73 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -27,7 +27,7 @@ const ( func (c *Config) RegisterFlags(cmd *cobra.Command) { cmd.Flags().BoolVarP(&c.Inplace, InplaceFlag, "i", c.Inplace, "Edit files in place") - cmd.Flags().StringToStringP(ValueFlag, ValueFlagShort, map[string]string{}, "Define a template variable. Can be used more than once.") + cmd.Flags().VarP(c.valuesStringToString, ValueFlag, ValueFlagShort, "Define a template variable. Can be used more than once.") cmd.Flags().BoolVarP(&c.Recursive, RecursiveFlag, "r", c.Recursive, "Recursively update yaml files in the given directory") cmd.Flags().StringVarP(&c.Prefix, PrefixFlag, "p", c.Prefix, "Template comments must begin with this prefix. The beginning '#' is implied.") cmd.Flags().StringVar(&c.LeftDelim, LeftDelimFlag, c.LeftDelim, "Override template left delimiter") diff --git a/internal/config/load.go b/internal/config/load.go index 6932ecb..cca4418 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -39,11 +39,7 @@ func (c *Config) Load(cmd *cobra.Command) error { c.Prefix = "#" + c.Prefix } - rawValues, err := cmd.Flags().GetStringToString(ValueFlag) - if err != nil { - return err - } - c.Values.Fill(rawValues) + c.Values.Fill(c.valuesStringToString.Values()) if f := cmd.Flags().Lookup(FailFlag); f.Changed { val, err := cmd.Flags().GetBool(FailFlag) diff --git a/internal/config/values_hack.go b/internal/config/values_hack.go deleted file mode 100644 index d3928c6..0000000 --- a/internal/config/values_hack.go +++ /dev/null @@ -1,64 +0,0 @@ -package config - -import ( - "bufio" - "os" - "strings" -) - -func FixStringToStringNewlines(s []string) []string { - var prevValueFlag string - result := make([]string, 0, len(s)) - for i, arg := range s { - switch { - case arg == "--": - if prevValueFlag != "" { - result = append(result, prevValueFlag) - } - result = append(result, s[i:]...) - return result - case prevValueFlag != "": - result = append(result, fixArgNewlines(prevValueFlag+"="+arg)...) - prevValueFlag = "" - case hasValueFlag(arg): - if strings.ContainsRune(arg, '=') { - result = append(result, fixArgNewlines(arg)...) - } else { - prevValueFlag = arg - } - default: - result = append(result, arg) - } - } - envName := EnvPrefix + strings.ToUpper(ValueFlag) - if env := os.Getenv(envName); env != "" { - _ = os.Setenv(envName, strings.ReplaceAll(env, "\n", ",")) - } - return result -} - -func hasValueFlag(s string) bool { - return s == "-"+ValueFlagShort || - s == "--"+ValueFlag || - strings.HasPrefix(s, "-"+ValueFlagShort+"=") || - strings.HasPrefix(s, "--"+ValueFlag+"=") -} - -func fixArgNewlines(arg string) []string { - if strings.ContainsRune(arg, '\n') { - prefix, arg, found := strings.Cut(arg, "=") - if !found { - return []string{prefix} - } - - result := make([]string, 0, 2) - s := bufio.NewScanner(strings.NewReader(arg)) - for s.Scan() { - if len(s.Bytes()) > 0 { - result = append(result, prefix+"="+strings.TrimSpace(s.Text())) - } - } - return result - } - return []string{arg} -} diff --git a/internal/config/values_hack_test.go b/internal/config/values_hack_test.go deleted file mode 100644 index 61a2009..0000000 --- a/internal/config/values_hack_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package config - -import ( - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestFixStringToStringNewlines(t *testing.T) { - t.Run("env", func(t *testing.T) { - t.Setenv("YAMPL_VALUE", "a=a\nb=b") - FixStringToStringNewlines([]string{}) - assert.Equal(t, "a=a,b=b", os.Getenv("YAMPL_VALUE")) - }) - - type args struct { - s []string - } - tests := []struct { - name string - args args - want []string - }{ - {"no value flag", args{[]string{"yampl"}}, []string{"yampl"}}, - {"no newline", args{[]string{"yampl", "--value=a=a"}}, []string{"yampl", "--value=a=a"}}, - {"newline with equal", args{[]string{"yampl", "--value=a=a\nb=b"}}, []string{"yampl", "--value=a=a", "--value=b=b"}}, - {"newline with space", args{[]string{"yampl", "--value", "a=a\nb=b"}}, []string{"yampl", "--value=a=a", "--value=b=b"}}, - {"newline in file", args{[]string{"yampl", "test\nfile.yaml"}}, []string{"yampl", "test\nfile.yaml"}}, - {"newline after end of options", args{[]string{"yampl", "-v=a=a", "---", "-v\nfile.yaml"}}, []string{"yampl", "-v=a=a", "---", "-v\nfile.yaml"}}, - {"trim newline", args{[]string{"yampl", "-v=\na=a\n"}}, []string{"yampl", "-v=a=a"}}, - {"collapse newlines", args{[]string{"yampl", "-v=a=a\n\nb=b"}}, []string{"yampl", "-v=a=a", "-v=b=b"}}, - {"trim spaces", args{[]string{"yampl", "-v=a=a\n b=b"}}, []string{"yampl", "-v=a=a", "-v=b=b"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := FixStringToStringNewlines(tt.args.s) - assert.Equal(t, tt.want, got) - }) - } -} - -func Test_hasValueFlag(t *testing.T) { - type args struct { - s string - } - tests := []struct { - name string - args args - want bool - }{ - {"no flag", args{"yampl"}, false}, - {"normal", args{"--value"}, true}, - {"normal with value", args{"--value=test"}, true}, - {"shorthand", args{"-v"}, true}, - {"shorthand with value", args{"-v=test"}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := hasValueFlag(tt.args.s) - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/main.go b/main.go index c74d344..badae98 100644 --- a/main.go +++ b/main.go @@ -4,13 +4,10 @@ import ( "os" "github.com/clevyr/yampl/cmd" - "github.com/clevyr/yampl/internal/config" ) func main() { - os.Args = config.FixStringToStringNewlines(os.Args) - rootCmd := cmd.NewCommand() - if err := rootCmd.Execute(); err != nil { + if err := cmd.NewCommand().Execute(); err != nil { os.Exit(1) } }