Skip to content

Commit

Permalink
Add StringMap flag
Browse files Browse the repository at this point in the history
StringMap wraps a `map[string]string` and parses inputs in the form of
"key=value".
  • Loading branch information
avorima committed Nov 28, 2022
1 parent 26b286a commit 4b5281d
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 18 deletions.
20 changes: 20 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2874,6 +2874,16 @@ func TestFlagAction(t *testing.T) {
return nil
},
},
&StringMapFlag{
Name: "f_string_map",
Action: func(c *Context, v map[string]string) error {
if _, ok := v["err"]; ok {
return fmt.Errorf("error string map")
}
c.App.Writer.Write([]byte(fmt.Sprintf("%v", v)))
return nil
},
},
},
Action: func(ctx *Context) error { return nil },
}
Expand Down Expand Up @@ -3034,6 +3044,16 @@ func TestFlagAction(t *testing.T) {
args: []string{"app", "--f_string=app", "--f_uint=1", "--f_int_slice=1,2,3", "--f_duration=1h30m20s", "c1", "--f_string=c1", "sub1", "--f_string=sub1"},
exp: "app 1h30m20s [1 2 3] 1 c1 sub1 ",
},
{
name: "flag_string_map",
args: []string{"app", "--f_string_map=s1=s2,s3="},
exp: "map[s1:s2 s3:]",
},
{
name: "flag_string_map_error",
args: []string{"app", "--f_string_map=err="},
err: fmt.Errorf("error string map"),
},
}

for _, test := range tests {
Expand Down
5 changes: 3 additions & 2 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import (
const defaultPlaceholder = "value"

var (
defaultSliceFlagSeparator = ","
disableSliceFlagSeparator = false
defaultSliceFlagSeparator = ","
defaultMapFlagKeyValueSeparator = "="
disableSliceFlagSeparator = false
)

var (
Expand Down
3 changes: 2 additions & 1 deletion flag_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ func (f *FlagBase[T, C, V]) RunAction(ctx *Context) error {
// IsSliceFlag returns true if the value type T is of kind slice
func (f *FlagBase[T, C, VC]) IsSliceFlag() bool {
// TBD how to specify
return reflect.TypeOf(f.Value).Kind() == reflect.Slice
kind := reflect.TypeOf(f.Value).Kind()
return kind == reflect.Slice || kind == reflect.Map
}

// IsPersistent returns true if flag needs to be persistent across subcommands
Expand Down
116 changes: 116 additions & 0 deletions flag_map_impl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package cli

import (
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
)

// MapBase wraps map[string]T to satisfy flag.Value
type MapBase[T any, C any, VC ValueCreator[T, C]] struct {
dict *map[string]T
hasBeenSet bool
value Value
}

func (i MapBase[T, C, VC]) Create(val map[string]T, p *map[string]T, c C) Value {
*p = map[string]T{}
for k, v := range val {
(*p)[k] = v
}
var t T
np := new(T)
var vc VC
return &MapBase[T, C, VC]{
dict: p,
value: vc.Create(t, np, c),
}
}

// NewMapBase makes a *MapBase with default values
func NewMapBase[T any, C any, VC ValueCreator[T, C]](defaults map[string]T) *MapBase[T, C, VC] {
return &MapBase[T, C, VC]{
dict: &defaults,
}
}

// Set parses the value and appends it to the list of values
func (i *MapBase[T, C, VC]) Set(value string) error {
if !i.hasBeenSet {
*i.dict = map[string]T{}
i.hasBeenSet = true
}

if strings.HasPrefix(value, slPfx) {
// Deserializing assumes overwrite
_ = json.Unmarshal([]byte(strings.Replace(value, slPfx, "", 1)), &i.dict)
i.hasBeenSet = true
return nil
}

for _, item := range flagSplitMultiValues(value) {
key, value, ok := strings.Cut(item, defaultMapFlagKeyValueSeparator)
if !ok {
return fmt.Errorf("item %q is missing separator %q", item, defaultMapFlagKeyValueSeparator)
}
if err := i.value.Set(strings.TrimSpace(value)); err != nil {
return err
}
tmp, ok := i.value.Get().(T)
if !ok {
return fmt.Errorf("unable to cast %v", i.value)
}
(*i.dict)[key] = tmp
}

return nil
}

// String returns a readable representation of this value (for usage defaults)
func (i *MapBase[T, C, VC]) String() string {
v := i.Value()
var t T
if reflect.TypeOf(t).Kind() == reflect.String {
return fmt.Sprintf("%v", v)
}
return fmt.Sprintf("%T{%s}", v, i.ToString(v))
}

// Serialize allows MapBase to fulfill Serializer
func (i *MapBase[T, C, VC]) Serialize() string {
jsonBytes, _ := json.Marshal(i.dict)
return fmt.Sprintf("%s%s", slPfx, string(jsonBytes))
}

// Value returns the mapping of values set by this flag
func (i *MapBase[T, C, VC]) Value() map[string]T {
if i.dict == nil {
return map[string]T{}
}
return *i.dict
}

// Get returns the mapping of values set by this flag
func (i *MapBase[T, C, VC]) Get() interface{} {
return *i.dict
}

func (i MapBase[T, C, VC]) ToString(t map[string]T) string {
var defaultVals []string
var vc VC
for _, k := range sortedKeys(t) {
defaultVals = append(defaultVals, k+defaultMapFlagKeyValueSeparator+vc.ToString(t[k]))
}
return strings.Join(defaultVals, ", ")
}

func sortedKeys[T any](dict map[string]T) []string {
keys := make([]string, 0, len(dict))
for k := range dict {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
27 changes: 27 additions & 0 deletions flag_string_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package cli

import "flag"

type StringMap = MapBase[string, NoConfig, stringValue]
type StringMapFlag = FlagBase[map[string]string, NoConfig, StringMap]

var NewStringMap = NewMapBase[string, NoConfig, stringValue]

// StringMap looks up the value of a local StringMapFlag, returns
// nil if not found
func (cCtx *Context) StringMap(name string) map[string]string {
if fs := cCtx.lookupFlagSet(name); fs != nil {
return lookupStringMap(name, fs)
}
return nil
}

func lookupStringMap(name string, set *flag.FlagSet) map[string]string {
f := set.Lookup(name)
if f != nil {
if mapping, ok := f.Value.(*StringMap); ok {
return mapping.Value()
}
}
return nil
}
148 changes: 148 additions & 0 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ func TestFlagsFromEnv(t *testing.T) {
{"08", 0, &Uint64Flag{Name: "seconds", EnvVars: []string{"SECONDS"}, Config: IntegerConfig{Base: 0}}, `could not parse "08" as uint64 value from environment variable "SECONDS" for flag seconds: .*`},
{"1.2", 0, &Uint64Flag{Name: "seconds", EnvVars: []string{"SECONDS"}}, `could not parse "1.2" as uint64 value from environment variable "SECONDS" for flag seconds: .*`},
{"foobar", 0, &Uint64Flag{Name: "seconds", EnvVars: []string{"SECONDS"}}, `could not parse "foobar" as uint64 value from environment variable "SECONDS" for flag seconds: .*`},

{"foo=bar,empty=", map[string]string{"foo": "bar", "empty": ""}, &StringMapFlag{Name: "names", EnvVars: []string{"NAMES"}}, ""},
}

for i, test := range flagTests {
Expand Down Expand Up @@ -2652,6 +2654,22 @@ func TestUint64Slice_Serialized_Set(t *testing.T) {
}
}

func TestStringMap_Serialized_Set(t *testing.T) {
m0 := NewStringMap(map[string]string{"a": "b"})
ser0 := m0.Serialize()

if len(ser0) < len(slPfx) {
t.Fatalf("serialized shorter than expected: %q", ser0)
}

m1 := NewStringMap(map[string]string{"c": "d"})
_ = m1.Set(ser0)

if m0.String() != m1.String() {
t.Fatalf("pre and post serialization do not match: %v != %v", m0, m1)
}
}

func TestTimestamp_set(t *testing.T) {
ts := timestampValue{
timestamp: nil,
Expand Down Expand Up @@ -2804,6 +2822,12 @@ func TestFlagDefaultValue(t *testing.T) {
toParse: []string{"--flag", "13"},
expect: `--flag value (default: 1)`,
},
{
name: "stringMap",
flag: &StringMapFlag{Name: "flag", Value: map[string]string{"default1": "default2"}},
toParse: []string{"--flag", "parsed="},
expect: `--flag value [ --flag value ] (default: default1="default2")`,
},
}
for i, v := range cases {
set := flag.NewFlagSet("test", 0)
Expand Down Expand Up @@ -2961,6 +2985,15 @@ func TestFlagDefaultValueWithEnv(t *testing.T) {
"tflag": "2010-01-02T15:04:05Z",
},
},
{
name: "stringMap",
flag: &StringMapFlag{Name: "flag", Value: map[string]string{"default1": "default2"}, EnvVars: []string{"ssflag"}},
toParse: []string{"--flag", "parsed="},
expect: `--flag value [ --flag value ] (default: default1="default2")` + withEnvHint([]string{"ssflag"}, ""),
environ: map[string]string{
"ssflag": "some-other-env_value=",
},
},
}
for i, v := range cases {
for key, val := range v.environ {
Expand Down Expand Up @@ -3025,6 +3058,12 @@ func TestFlagValue(t *testing.T) {
toParse: []string{"--flag", "13,14", "--flag", "15,16"},
expect: `[]uint{13, 14, 15, 16}`,
},
{
name: "stringMap",
flag: &StringMapFlag{Name: "flag", Value: map[string]string{"default1": "default2"}},
toParse: []string{"--flag", "parsed=parsed2", "--flag", "parsed3=parsed4"},
expect: `map[parsed:parsed2 parsed3:parsed4]`,
},
}
for i, v := range cases {
set := flag.NewFlagSet("test", 0)
Expand Down Expand Up @@ -3125,3 +3164,112 @@ func TestFlagSplitMultiValues_Disabled(t *testing.T) {
t.Fatalf("failed to disable split slice flag, want: %s, but got: %s", strings.Join(opts, defaultSliceFlagSeparator), ret[0])
}
}

var stringMapFlagTests = []struct {
name string
aliases []string
value map[string]string
expected string
}{
{"foo", nil, nil, "--foo value [ --foo value ]\t"},
{"f", nil, nil, "-f value [ -f value ]\t"},
{"f", nil, map[string]string{"Lipstick": ""}, "-f value [ -f value ]\t(default: Lipstick=)"},
{"test", nil, map[string]string{"Something": ""}, "--test value [ --test value ]\t(default: Something=)"},
{"dee", []string{"d"}, map[string]string{"Inka": "Dinka", "dooo": ""}, "--dee value, -d value [ --dee value, -d value ]\t(default: Inka=\"Dinka\", dooo=)"},
}

func TestStringMapFlagHelpOutput(t *testing.T) {
for _, test := range stringMapFlagTests {
f := &StringMapFlag{Name: test.name, Aliases: test.aliases, Value: test.value}
output := f.String()

if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}

func TestStringMapFlagWithEnvVarHelpOutput(t *testing.T) {
defer resetEnv(os.Environ())
os.Clearenv()
_ = os.Setenv("APP_QWWX", "11,4")

for _, test := range stringMapFlagTests {
fl := &StringMapFlag{Name: test.name, Aliases: test.aliases, Value: test.value, EnvVars: []string{"APP_QWWX"}}
output := fl.String()

expectedSuffix := withEnvHint([]string{"APP_QWWX"}, "")
if !strings.HasSuffix(output, expectedSuffix) {
t.Errorf("%q does not end with"+expectedSuffix, output)
}
}
}

func TestStringMapFlagApply_SetsAllNames(t *testing.T) {
fl := StringMapFlag{Name: "goat", Aliases: []string{"G", "gooots"}}
set := flag.NewFlagSet("test", 0)
_ = fl.Apply(set)

err := set.Parse([]string{"--goat", "aaa=", "-G", "bbb=", "--gooots", "eeeee="})
expect(t, err, nil)
}

func TestStringMapFlagApply_UsesEnvValues_noDefault(t *testing.T) {
defer resetEnv(os.Environ())
os.Clearenv()
_ = os.Setenv("MY_GOAT", "vincent van goat=scape goat")
var val map[string]string
fl := StringMapFlag{Name: "goat", EnvVars: []string{"MY_GOAT"}, Value: val}
set := flag.NewFlagSet("test", 0)
_ = fl.Apply(set)

err := set.Parse(nil)
expect(t, err, nil)
expect(t, val, map[string]string(nil))
expect(t, set.Lookup("goat").Value.(*StringMap).Value(), map[string]string{"vincent van goat": "scape goat"})
}

func TestStringMapFlagApply_UsesEnvValues_withDefault(t *testing.T) {
defer resetEnv(os.Environ())
os.Clearenv()
_ = os.Setenv("MY_GOAT", "vincent van goat=scape goat")
val := map[string]string{`some default`: `values here`}
fl := StringMapFlag{Name: "goat", EnvVars: []string{"MY_GOAT"}, Value: val}
set := flag.NewFlagSet("test", 0)
_ = fl.Apply(set)
err := set.Parse(nil)
expect(t, err, nil)
expect(t, val, map[string]string{`some default`: `values here`})
expect(t, set.Lookup("goat").Value.(*StringMap).Value(), map[string]string{"vincent van goat": "scape goat"})
}

func TestStringMapFlagApply_DefaultValueWithDestination(t *testing.T) {
defValue := map[string]string{"UA": "US"}

fl := StringMapFlag{Name: "country", Value: defValue, Destination: &map[string]string{"CA": ""}}
set := flag.NewFlagSet("test", 0)
_ = fl.Apply(set)

err := set.Parse([]string{})
expect(t, err, nil)
expect(t, defValue, *fl.Destination)
}

func TestStringMapFlagValueFromContext(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Var(NewStringMap(map[string]string{"a": "b", "c": ""}), "myflag", "doc")
ctx := NewContext(nil, set, nil)
f := &StringMapFlag{Name: "myflag"}
expect(t, f.Get(ctx), map[string]string{"a": "b", "c": ""})
}

func TestStringMapFlagApply_Error(t *testing.T) {
fl := StringMapFlag{Name: "goat"}
set := flag.NewFlagSet("test", 0)
_ = fl.Apply(set)

err := set.Parse([]string{"--goat", "aaa", "bbb="})
if err == nil {
t.Errorf("expected error, but got none")
}
}
Loading

0 comments on commit 4b5281d

Please sign in to comment.