Skip to content

Commit

Permalink
Add StringMap flag
Browse files Browse the repository at this point in the history
StringMap wraps a `map[string]string` and parses inputs in the form of
"key=value".
  • Loading branch information
avorima committed Nov 18, 2022
1 parent eb6cc14 commit d32d928
Show file tree
Hide file tree
Showing 8 changed files with 584 additions and 22 deletions.
20 changes: 20 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2901,6 +2901,16 @@ func TestFlagAction(t *testing.T) {
return nil
},
},
&StringMapFlag{
Name: "f_string_map",
Action: func(c *Context, v map[string]string) error {
if _, ok := v["err"]; ok {
return fmt.Errorf("error string map")
}
c.App.Writer.Write([]byte(fmt.Sprintf("%v", v)))
return nil
},
},
},
Action: func(ctx *Context) error { return nil },
}
Expand Down Expand Up @@ -3081,6 +3091,16 @@ func TestFlagAction(t *testing.T) {
args: []string{"app", "--f_string=app", "--f_uint=1", "--f_int_slice=1,2,3", "--f_duration=1h30m20s", "c1", "--f_string=c1", "sub1", "--f_string=sub1"},
exp: "app 1h30m20s [1 2 3] 1 c1 sub1 ",
},
{
name: "flag_string_map",
args: []string{"app", "--f_string_map=s1=s2,s3="},
exp: "map[s1:s2 s3:]",
},
{
name: "flag_string_map_error",
args: []string{"app", "--f_string_map=err="},
err: fmt.Errorf("error string map"),
},
}

for _, test := range tests {
Expand Down
9 changes: 9 additions & 0 deletions flag-spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,12 @@ flag_types:
type: bool
- name: Action
type: "func(*Context, Path) error"
StringMap:
value_pointer: true
skip_interfaces:
- fmt.Stringer
struct_fields:
- name: TakesFile
type: bool
- name: Action
type: "func(*Context, map[string]string) error"
5 changes: 3 additions & 2 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import (
const defaultPlaceholder = "value"

var (
defaultSliceFlagSeparator = ","
disableSliceFlagSeparator = false
defaultSliceFlagSeparator = ","
defaultMapFlagKeyValueSeparator = "="
disableSliceFlagSeparator = false
)

var (
Expand Down
195 changes: 195 additions & 0 deletions flag_string_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package cli

import (
"encoding/json"
"flag"
"fmt"
"strconv"
"strings"
)

// StringMap wraps a map[string]string to satisfy flag.Value
type StringMap struct {
dict map[string]string
hasBeenSet bool
}

// NewStringMap creates a *StringMap with default values
func NewStringMap(defaults map[string]string) *StringMap {
return &StringMap{dict: copyStringDict(defaults)}
}

// clone creates a deep copy the object
func (m *StringMap) clone() *StringMap {
return &StringMap{
dict: copyStringDict(m.dict),
hasBeenSet: m.hasBeenSet,
}
}

func copyStringDict(in map[string]string) map[string]string {
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}

// Set appends a key and value to the map
func (m *StringMap) Set(value string) error {
if !m.hasBeenSet {
m.dict = make(map[string]string)
m.hasBeenSet = true
}

if strings.HasPrefix(value, slPfx) {
// Deserializing assumes overwrite
_ = json.Unmarshal([]byte(strings.Replace(value, slPfx, "", 1)), &m.dict)
m.hasBeenSet = true
return nil
}

for _, item := range flagSplitMultiValues(value) {
key, value, ok := strings.Cut(item, defaultMapFlagKeyValueSeparator)
if !ok {
return fmt.Errorf("item %q is missing separator %q", item, defaultMapFlagKeyValueSeparator)
}
if key == "" {
return fmt.Errorf("must provide key for StringMap")
}
m.dict[key] = value
}

return nil
}

// String returns a readable representation of this value (for usage defaults)
func (m *StringMap) String() string {
if len(m.dict) == 0 {
return ""
}

var sb strings.Builder
for k, v := range m.dict {
sb.WriteString(k + defaultMapFlagKeyValueSeparator + v + ",")
}
s := sb.String()
return s[:len(s)-1]
}

// Serialize allows StringMap to fulfill Serializer
func (m *StringMap) Serialize() string {
jsonBytes, _ := json.Marshal(m.dict)
return fmt.Sprintf("%s%s", slPfx, string(jsonBytes))
}

// Value returns the map of strings set by this flag
func (m *StringMap) Value() map[string]string {
return m.dict
}

// Get returns the flag structure
func (m *StringMap) Get() interface{} {
return *m
}

// String returns a readable representation of this value
// (for usage defaults)
func (f *StringMapFlag) String() string {
return FlagStringer(f)
}

// GetValue returns the flags value as string representation and an empty
// string if the flag takes no value at all.
func (f *StringMapFlag) GetValue() string {
var defaultVals []string
if f.Value != nil && len(f.Value.Value()) > 0 {
for k, v := range f.Value.Value() {
kv := strconv.Quote(k) + defaultMapFlagKeyValueSeparator + strconv.Quote(v)
defaultVals = append(defaultVals, kv)
}
}

return strings.Join(defaultVals, ", ")
}

// GetDefaultText returns the default text for this flag
func (f *StringMapFlag) GetDefaultText() string {
if f.DefaultText != "" {
return f.DefaultText
}
return f.GetValue()
}

// IsSliceFlag implements DocGenerationSliceFlag.
func (f *StringMapFlag) IsSliceFlag() bool {
return true
}

// Apply populates the flag given the flag set and environment
func (f *StringMapFlag) Apply(set *flag.FlagSet) error {
if f.Destination != nil && f.Value != nil {
f.Destination.dict = copyStringDict(f.Value.dict)
}

// resolve setValue (what we will assign to the set)
var setValue *StringMap
switch {
case f.Destination != nil:
setValue = f.Destination
case f.Value != nil:
setValue = f.Value.clone()
default:
setValue = new(StringMap)
}

if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if err := setValue.Set(strings.TrimSpace(val)); err != nil {
return fmt.Errorf("could not parse %q as string value from %s for flag %s: %s", val, source, f.Name, err)
}

// Set this to false so that we reset the map if we then set values from
// flags that have already been set by the environment.
setValue.hasBeenSet = true
f.HasBeenSet = true
}

for _, name := range f.Names() {
set.Var(setValue, name, f.Usage)
}

return nil
}

// Get returns the flag’s value in the given Context.
func (f *StringMapFlag) Get(ctx *Context) map[string]string {
return ctx.StringMap(f.Name)
}

// RunAction executes flag action if set
func (f *StringMapFlag) RunAction(c *Context) error {
if f.Action != nil {
return f.Action(c, c.StringMap(f.Name))
}

return nil
}

// StringMap looks up the value of a local StringMapFlag, returns
// nil if not found
func (cCtx *Context) StringMap(name string) map[string]string {
if fs := cCtx.lookupFlagSet(name); fs != nil {
return lookupStringMap(name, fs)
}
return nil
}

func lookupStringMap(name string, set *flag.FlagSet) map[string]string {
f := set.Lookup(name)
if f != nil {
if mapping, ok := unwrapFlagValue(f.Value).(*StringMap); ok {
return mapping.Value()
}
}
return nil
}
Loading

0 comments on commit d32d928

Please sign in to comment.