diff --git a/app_test.go b/app_test.go index 00e980d29c..ba423d5326 100644 --- a/app_test.go +++ b/app_test.go @@ -3234,3 +3234,64 @@ func TestPersistentFlag(t *testing.T) { } } + +func TestFlagDuplicates(t *testing.T) { + + a := &App{ + Flags: []Flag{ + &StringFlag{ + Name: "sflag", + OnlyOnce: true, + }, + &Int64SliceFlag{ + Name: "isflag", + }, + &Float64SliceFlag{ + Name: "fsflag", + OnlyOnce: true, + }, + &IntFlag{ + Name: "iflag", + }, + }, + Action: func(ctx *Context) error { + return nil + }, + } + + tests := []struct { + name string + args []string + errExpected bool + }{ + { + name: "all args present once", + args: []string{"foo", "--sflag", "hello", "--isflag", "1", "--isflag", "2", "--fsflag", "2.0", "--iflag", "10"}, + }, + { + name: "duplicate non slice flag(duplicatable)", + args: []string{"foo", "--sflag", "hello", "--isflag", "1", "--isflag", "2", "--fsflag", "2.0", "--iflag", "10", "--iflag", "20"}, + }, + { + name: "duplicate non slice flag(non duplicatable)", + args: []string{"foo", "--sflag", "hello", "--isflag", "1", "--isflag", "2", "--fsflag", "2.0", "--iflag", "10", "--sflag", "trip"}, + errExpected: true, + }, + { + name: "duplicate slice flag(non duplicatable)", + args: []string{"foo", "--sflag", "hello", "--isflag", "1", "--isflag", "2", "--fsflag", "2.0", "--fsflag", "3.0", "--iflag", "10"}, + errExpected: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := a.Run(test.args) + if test.errExpected && err == nil { + t.Error("expected error") + } else if !test.errExpected && err != nil { + t.Error(err) + } + }) + } +} diff --git a/flag_float64_slice.go b/flag_float64_slice.go index fb8b56c33d..44d7132091 100644 --- a/flag_float64_slice.go +++ b/flag_float64_slice.go @@ -21,8 +21,8 @@ func (cCtx *Context) Float64Slice(name string) []float64 { func lookupFloat64Slice(name string, set *flag.FlagSet) []float64 { f := set.Lookup(name) if f != nil { - if slice, ok := f.Value.(*Float64Slice); ok { - return slice.Value() + if slice, ok := f.Value.(flag.Getter).Get().([]float64); ok { + return slice } } return nil diff --git a/flag_impl.go b/flag_impl.go index 855e40ee92..a8530d88e8 100644 --- a/flag_impl.go +++ b/flag_impl.go @@ -13,6 +13,52 @@ type Value interface { flag.Getter } +// simple wrapper to intercept Value operations +// to check for duplicates +type valueWrapper struct { + value Value + count int + onlyOnce bool +} + +func (v *valueWrapper) String() string { + if v.value == nil { + return "" + } + return v.value.String() +} + +func (v *valueWrapper) Set(s string) error { + if v.count == 1 && v.onlyOnce { + return fmt.Errorf("cant duplicate this flag") + } + v.count++ + return v.value.Set(s) +} + +func (v *valueWrapper) Get() any { + return v.value.Get() +} + +func (v *valueWrapper) IsBoolFlag() bool { + _, ok := v.value.(*boolValue) + return ok +} + +func (v *valueWrapper) Serialize() string { + if s, ok := v.value.(Serializer); ok { + return s.Serialize() + } + return v.value.String() +} + +func (v *valueWrapper) Count() int { + if s, ok := v.value.(Countable); ok { + return s.Count() + } + return 0 +} + // ValueCreator is responsible for creating a flag.Value emulation // as well as custom formatting // @@ -57,6 +103,8 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { Config C // Additional/Custom configuration associated with this flag type + OnlyOnce bool // whether this flag can be duplicated on the command line + // unexported fields for internal use hasBeenSet bool // whether the flag has been set from env or file applied bool // whether the flag has been applied to a flag set already @@ -108,8 +156,13 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error { } } + vw := &valueWrapper{ + value: f.value, + onlyOnce: f.OnlyOnce, + } + for _, name := range f.Names() { - set.Var(f.value, name, f.Usage) + set.Var(vw, name, f.Usage) } f.applied = true diff --git a/flag_int64_slice.go b/flag_int64_slice.go index 51db8e8c1b..152b9980a0 100644 --- a/flag_int64_slice.go +++ b/flag_int64_slice.go @@ -21,8 +21,8 @@ func (cCtx *Context) Int64Slice(name string) []int64 { func lookupInt64Slice(name string, set *flag.FlagSet) []int64 { f := set.Lookup(name) if f != nil { - if slice, ok := f.Value.(*Int64Slice); ok { - return slice.Value() + if slice, ok := f.Value.(flag.Getter).Get().([]int64); ok { + return slice } } return nil diff --git a/flag_int_slice.go b/flag_int_slice.go index 3737f0071c..95451ad720 100644 --- a/flag_int_slice.go +++ b/flag_int_slice.go @@ -19,8 +19,8 @@ func (cCtx *Context) IntSlice(name string) []int { func lookupIntSlice(name string, set *flag.FlagSet) []int { f := set.Lookup(name) if f != nil { - if slice, ok := f.Value.(*IntSlice); ok { - return slice.Value() + if slice, ok := f.Value.(flag.Getter).Get().([]int); ok { + return slice } } return nil diff --git a/flag_string_map.go b/flag_string_map.go index b4fbb3deb7..f75ed37201 100644 --- a/flag_string_map.go +++ b/flag_string_map.go @@ -19,8 +19,8 @@ func (cCtx *Context) StringMap(name string) map[string]string { func lookupStringMap(name string, set *flag.FlagSet) map[string]string { f := set.Lookup(name) if f != nil { - if mapping, ok := f.Value.(*StringMap); ok { - return mapping.Value() + if mapping, ok := f.Value.(flag.Getter).Get().(map[string]string); ok { + return mapping } } return nil diff --git a/flag_string_slice.go b/flag_string_slice.go index 2e379a382f..19e159320b 100644 --- a/flag_string_slice.go +++ b/flag_string_slice.go @@ -21,8 +21,8 @@ func (cCtx *Context) StringSlice(name string) []string { func lookupStringSlice(name string, set *flag.FlagSet) []string { f := set.Lookup(name) if f != nil { - if slice, ok := f.Value.(*StringSlice); ok { - return slice.Value() + if slice, ok := f.Value.(flag.Getter).Get().([]string); ok { + return slice } } return nil diff --git a/flag_test.go b/flag_test.go index a377bdf30d..b4eec1cbc4 100644 --- a/flag_test.go +++ b/flag_test.go @@ -602,7 +602,7 @@ func TestStringSliceFlagApply_UsesEnvValues_noDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*StringSlice).Value(), []string{"vincent van goat", "scape goat"}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get(), []string{"vincent van goat", "scape goat"}) } func TestStringSliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { @@ -615,7 +615,7 @@ func TestStringSliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { _ = fl.Apply(set) err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*StringSlice).Value(), []string{"vincent van goat", "scape goat"}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get(), []string{"vincent van goat", "scape goat"}) } func TestStringSliceFlagApply_DefaultValueWithDestination(t *testing.T) { @@ -959,7 +959,7 @@ func TestIntSliceFlagApply_UsesEnvValues_noDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*IntSlice).Value(), []int{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get(), []int{1, 2}) } func TestIntSliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { @@ -973,7 +973,7 @@ func TestIntSliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) expect(t, val, []int{3, 4}) - expect(t, set.Lookup("goat").Value.(*IntSlice).Value(), []int{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get(), []int{1, 2}) } func TestIntSliceFlagApply_DefaultValueWithDestination(t *testing.T) { @@ -1098,7 +1098,7 @@ func TestInt64SliceFlagApply_UsesEnvValues_noDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*Int64Slice).Value(), []int64{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get(), []int64{1, 2}) } func TestInt64SliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { @@ -1112,7 +1112,7 @@ func TestInt64SliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) expect(t, val.Value(), []int64{3, 4}) - expect(t, set.Lookup("goat").Value.(*Int64Slice).Value(), []int64{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get().([]int64), []int64{1, 2}) } func TestInt64SliceFlagApply_DefaultValueWithDestination(t *testing.T) { @@ -1254,7 +1254,7 @@ func TestUintSliceFlagApply_UsesEnvValues_noDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*UintSlice).Value(), []uint{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get().([]uint), []uint{1, 2}) } func TestUintSliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { @@ -1268,7 +1268,7 @@ func TestUintSliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) expect(t, val.Value(), []uint{3, 4}) - expect(t, set.Lookup("goat").Value.(*UintSlice).Value(), []uint{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get().([]uint), []uint{1, 2}) } func TestUintSliceFlagApply_DefaultValueWithDestination(t *testing.T) { @@ -1401,7 +1401,7 @@ func TestUint64SliceFlagApply_UsesEnvValues_noDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*Uint64Slice).Value(), []uint64{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get().([]uint64), []uint64{1, 2}) } func TestUint64SliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { @@ -1414,7 +1414,7 @@ func TestUint64SliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { _ = fl.Apply(set) err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*Uint64Slice).Value(), []uint64{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get().([]uint64), []uint64{1, 2}) } func TestUint64SliceFlagApply_DefaultValueWithDestination(t *testing.T) { @@ -1601,7 +1601,7 @@ func TestFloat64SliceFlagApply_UsesEnvValues_noDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*Float64Slice).Value(), []float64{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get().([]float64), []float64{1, 2}) } func TestFloat64SliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { @@ -1614,7 +1614,7 @@ func TestFloat64SliceFlagApply_UsesEnvValues_withDefault(t *testing.T) { _ = fl.Apply(set) err := set.Parse(nil) expect(t, err, nil) - expect(t, set.Lookup("goat").Value.(*Float64Slice).Value(), []float64{1, 2}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get().([]float64), []float64{1, 2}) } func TestFloat64SliceFlagApply_DefaultValueWithDestination(t *testing.T) { @@ -2708,7 +2708,7 @@ func TestTimestampFlagApply(t *testing.T) { err := set.Parse([]string{"--time", "2006-01-02T15:04:05Z"}) expect(t, err, nil) - expect(t, *set.Lookup("time").Value.(*timestampValue).timestamp, expectedResult) + expect(t, set.Lookup("time").Value.(flag.Getter).Get(), expectedResult) } func TestTimestampFlagApplyValue(t *testing.T) { @@ -2719,7 +2719,7 @@ func TestTimestampFlagApplyValue(t *testing.T) { err := set.Parse([]string{""}) expect(t, err, nil) - expect(t, *set.Lookup("time").Value.(*timestampValue).timestamp, expectedResult) + expect(t, set.Lookup("time").Value.(flag.Getter).Get(), expectedResult) } func TestTimestampFlagApply_Fail_Parse_Wrong_Layout(t *testing.T) { @@ -2751,7 +2751,7 @@ func TestTimestampFlagApply_Timezoned(t *testing.T) { err := set.Parse([]string{"--time", "Mon Jan 2 08:04:05 2006"}) expect(t, err, nil) - expect(t, *set.Lookup("time").Value.(*timestampValue).timestamp, expectedResult.In(pdt)) + expect(t, set.Lookup("time").Value.(flag.Getter).Get(), expectedResult.In(pdt)) } func TestTimestampFlagValueFromContext(t *testing.T) { @@ -3230,7 +3230,7 @@ func TestStringMapFlagApply_UsesEnvValues_noDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) expect(t, val, map[string]string(nil)) - expect(t, set.Lookup("goat").Value.(*StringMap).Value(), map[string]string{"vincent van goat": "scape goat"}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get(), map[string]string{"vincent van goat": "scape goat"}) } func TestStringMapFlagApply_UsesEnvValues_withDefault(t *testing.T) { @@ -3244,7 +3244,7 @@ func TestStringMapFlagApply_UsesEnvValues_withDefault(t *testing.T) { err := set.Parse(nil) expect(t, err, nil) expect(t, val, map[string]string{`some default`: `values here`}) - expect(t, set.Lookup("goat").Value.(*StringMap).Value(), map[string]string{"vincent van goat": "scape goat"}) + expect(t, set.Lookup("goat").Value.(flag.Getter).Get(), map[string]string{"vincent van goat": "scape goat"}) } func TestStringMapFlagApply_DefaultValueWithDestination(t *testing.T) { diff --git a/flag_uint64_slice.go b/flag_uint64_slice.go index 4a6453fdc4..eb30dbff93 100644 --- a/flag_uint64_slice.go +++ b/flag_uint64_slice.go @@ -21,8 +21,8 @@ func (cCtx *Context) Uint64Slice(name string) []uint64 { func lookupUint64Slice(name string, set *flag.FlagSet) []uint64 { f := set.Lookup(name) if f != nil { - if slice, ok := f.Value.(*Uint64Slice); ok { - return slice.Value() + if slice, ok := f.Value.(flag.Getter).Get().([]uint64); ok { + return slice } } return nil diff --git a/flag_uint_slice.go b/flag_uint_slice.go index 03f6ad4cfa..0d64f9c0dc 100644 --- a/flag_uint_slice.go +++ b/flag_uint_slice.go @@ -21,8 +21,8 @@ func (cCtx *Context) UintSlice(name string) []uint { func lookupUintSlice(name string, set *flag.FlagSet) []uint { f := set.Lookup(name) if f != nil { - if slice, ok := f.Value.(*UintSlice); ok { - return slice.Value() + if slice, ok := f.Value.(flag.Getter).Get().([]uint); ok { + return slice } } return nil diff --git a/godoc-current.txt b/godoc-current.txt index 9d7be183a0..1aebcd6822 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -762,6 +762,8 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { Config C // Additional/Custom configuration associated with this flag type + OnlyOnce bool // whether this flag can be duplicated on the command line + // Has unexported fields. } FlagBase[T,C,VC] is a generic flag base which can be used as a boilerplate diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index 9d7be183a0..1aebcd6822 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -762,6 +762,8 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { Config C // Additional/Custom configuration associated with this flag type + OnlyOnce bool // whether this flag can be duplicated on the command line + // Has unexported fields. } FlagBase[T,C,VC] is a generic flag base which can be used as a boilerplate