diff --git a/context.go b/context.go index 5d43de016b..312efb5b38 100644 --- a/context.go +++ b/context.go @@ -105,13 +105,11 @@ func (cCtx *Context) Lineage() []*Context { return lineage } -// NumOccurrences returns the num of occurences of this flag +// Count returns the num of occurences of this flag func (cCtx *Context) Count(name string) int { if fs := cCtx.lookupFlagSet(name); fs != nil { - if bf, ok := fs.Lookup(name).Value.(*boolValue); ok { - if bf.count != nil { - return *bf.count - } + if cf, ok := fs.Lookup(name).Value.(Countable); ok { + return cf.Count() } } return 0 diff --git a/flag.go b/flag.go index 4f0871d332..1618b4da86 100644 --- a/flag.go +++ b/flag.go @@ -124,6 +124,12 @@ type Flag interface { GetValue() string } +// Countable is an interface to enable detection of flag values which support +// repetitive flags +type Countable interface { + Count() int +} + func flagSet(name string, flags []Flag) (*flag.FlagSet, error) { set := flag.NewFlagSet(name, flag.ContinueOnError) diff --git a/flag_bool.go b/flag_bool.go index 08705766f4..e5fa700e1a 100644 --- a/flag_bool.go +++ b/flag_bool.go @@ -51,6 +51,13 @@ func (b *boolValue) String() string { func (b *boolValue) IsBoolFlag() bool { return true } +func (b *boolValue) Count() int { + if b.count != nil { + return *b.count + } + return 0 +} + // GetValue returns the flags value as string representation and an empty // string if the flag takes no value at all. func (f *BoolFlag) GetValue() string {