Skip to content

Commit

Permalink
Add StringMap flag
Browse files Browse the repository at this point in the history
StringMap wraps a `map[string]string` and parses inputs in the form of
"key=value".
  • Loading branch information
avorima committed Dec 2, 2022
1 parent e0bfec9 commit ef880bd
Show file tree
Hide file tree
Showing 8 changed files with 416 additions and 3 deletions.
48 changes: 48 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,34 @@ func ExampleApp_Run_sliceValues() {
// error: <nil>
}

func ExampleApp_Run_mapValues() {
// set args for examples sake
os.Args = []string{
"multi_values",
"--stringMap", "parsed1=parsed two", "--stringMap", "parsed3=",
}
app := NewApp()
app.Name = "multi_values"
app.Flags = []Flag{
&StringMapFlag{Name: "stringMap"},
}
app.Action = func(ctx *Context) error {
for i, v := range ctx.FlagNames() {
fmt.Printf("%d-%s %#v\n", i, v, ctx.StringMap(v))
}
fmt.Printf("notfound %#v\n", ctx.StringMap("notfound"))
err := ctx.Err()
fmt.Println("error:", err)
return err
}

_ = app.Run(os.Args)
// Output:
// 0-stringMap map[string]string{"parsed1":"parsed two", "parsed3":""}
// notfound map[string]string(nil)
// error: <nil>
}

func TestApp_Run(t *testing.T) {
s := ""

Expand Down Expand Up @@ -2874,6 +2902,16 @@ func TestFlagAction(t *testing.T) {
return nil
},
},
&StringMapFlag{
Name: "f_string_map",
Action: func(c *Context, v map[string]string) error {
if _, ok := v["err"]; ok {
return fmt.Errorf("error string map")
}
c.App.Writer.Write([]byte(fmt.Sprintf("%v", v)))
return nil
},
},
},
Action: func(ctx *Context) error { return nil },
}
Expand Down Expand Up @@ -3034,6 +3072,16 @@ func TestFlagAction(t *testing.T) {
args: []string{"app", "--f_string=app", "--f_uint=1", "--f_int_slice=1,2,3", "--f_duration=1h30m20s", "c1", "--f_string=c1", "sub1", "--f_string=sub1"},
exp: "app 1h30m20s [1 2 3] 1 c1 sub1 ",
},
{
name: "flag_string_map",
args: []string{"app", "--f_string_map=s1=s2,s3="},
exp: "map[s1:s2 s3:]",
},
{
name: "flag_string_map_error",
args: []string{"app", "--f_string_map=err="},
err: fmt.Errorf("error string map"),
},
}

for _, test := range tests {
Expand Down
5 changes: 3 additions & 2 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import (
const defaultPlaceholder = "value"

var (
defaultSliceFlagSeparator = ","
disableSliceFlagSeparator = false
defaultSliceFlagSeparator = ","
defaultMapFlagKeyValueSeparator = "="
disableSliceFlagSeparator = false
)

var (
Expand Down
3 changes: 2 additions & 1 deletion flag_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ func (f *FlagBase[T, C, V]) RunAction(ctx *Context) error {
// IsSliceFlag returns true if the value type T is of kind slice
func (f *FlagBase[T, C, VC]) IsSliceFlag() bool {
// TBD how to specify
return reflect.TypeOf(f.Value).Kind() == reflect.Slice
kind := reflect.TypeOf(f.Value).Kind()
return kind == reflect.Slice || kind == reflect.Map
}

// IsPersistent returns true if flag needs to be persistent across subcommands
Expand Down
116 changes: 116 additions & 0 deletions flag_map_impl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package cli

import (
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
)

// MapBase wraps map[string]T to satisfy flag.Value
type MapBase[T any, C any, VC ValueCreator[T, C]] struct {
dict *map[string]T
hasBeenSet bool
value Value
}

func (i MapBase[T, C, VC]) Create(val map[string]T, p *map[string]T, c C) Value {
*p = map[string]T{}
for k, v := range val {
(*p)[k] = v
}
var t T
np := new(T)
var vc VC
return &MapBase[T, C, VC]{
dict: p,
value: vc.Create(t, np, c),
}
}

// NewMapBase makes a *MapBase with default values
func NewMapBase[T any, C any, VC ValueCreator[T, C]](defaults map[string]T) *MapBase[T, C, VC] {
return &MapBase[T, C, VC]{
dict: &defaults,
}
}

// Set parses the value and appends it to the list of values
func (i *MapBase[T, C, VC]) Set(value string) error {
if !i.hasBeenSet {
*i.dict = map[string]T{}
i.hasBeenSet = true
}

if strings.HasPrefix(value, slPfx) {
// Deserializing assumes overwrite
_ = json.Unmarshal([]byte(strings.Replace(value, slPfx, "", 1)), &i.dict)
i.hasBeenSet = true
return nil
}

for _, item := range flagSplitMultiValues(value) {
key, value, ok := strings.Cut(item, defaultMapFlagKeyValueSeparator)
if !ok {
return fmt.Errorf("item %q is missing separator %q", item, defaultMapFlagKeyValueSeparator)
}
if err := i.value.Set(strings.TrimSpace(value)); err != nil {
return err
}
tmp, ok := i.value.Get().(T)
if !ok {
return fmt.Errorf("unable to cast %v", i.value)
}
(*i.dict)[key] = tmp
}

return nil
}

// String returns a readable representation of this value (for usage defaults)
func (i *MapBase[T, C, VC]) String() string {
v := i.Value()
var t T
if reflect.TypeOf(t).Kind() == reflect.String {
return fmt.Sprintf("%v", v)
}
return fmt.Sprintf("%T{%s}", v, i.ToString(v))
}

// Serialize allows MapBase to fulfill Serializer
func (i *MapBase[T, C, VC]) Serialize() string {
jsonBytes, _ := json.Marshal(i.dict)
return fmt.Sprintf("%s%s", slPfx, string(jsonBytes))
}

// Value returns the mapping of values set by this flag
func (i *MapBase[T, C, VC]) Value() map[string]T {
if i.dict == nil {
return map[string]T{}
}
return *i.dict
}

// Get returns the mapping of values set by this flag
func (i *MapBase[T, C, VC]) Get() interface{} {
return *i.dict
}

func (i MapBase[T, C, VC]) ToString(t map[string]T) string {
var defaultVals []string
var vc VC
for _, k := range sortedKeys(t) {
defaultVals = append(defaultVals, k+defaultMapFlagKeyValueSeparator+vc.ToString(t[k]))
}
return strings.Join(defaultVals, ", ")
}

func sortedKeys[T any](dict map[string]T) []string {
keys := make([]string, 0, len(dict))
for k := range dict {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
27 changes: 27 additions & 0 deletions flag_string_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package cli

import "flag"

type StringMap = MapBase[string, NoConfig, stringValue]
type StringMapFlag = FlagBase[map[string]string, NoConfig, StringMap]

var NewStringMap = NewMapBase[string, NoConfig, stringValue]

// StringMap looks up the value of a local StringMapFlag, returns
// nil if not found
func (cCtx *Context) StringMap(name string) map[string]string {
if fs := cCtx.lookupFlagSet(name); fs != nil {
return lookupStringMap(name, fs)
}
return nil
}

func lookupStringMap(name string, set *flag.FlagSet) map[string]string {
f := set.Lookup(name)
if f != nil {
if mapping, ok := f.Value.(*StringMap); ok {
return mapping.Value()
}
}
return nil
}
Loading

0 comments on commit ef880bd

Please sign in to comment.