Skip to content

Commit

Permalink
refactor(flag_groups): flag groups implementation changed
Browse files Browse the repository at this point in the history
This commit changes the flag groups feature logic. New implementation is more clean, readable and extendable (hope it won't be just my opinion).

The following changes have been made:

1. Main change:
Flags annotating by "cobra_annotation_required_if_others_set" and "cobra_annotation_mutually_exclusive" annotations was removed as well as all related and hard-to-understand "hacks" to combine flags back into groups on validation process.
Instead, `flagGroups` field was added to the `Command` struct. `flagGroups` field is a list of (new) structs `flagGroup`, which represents the "relationships" between flags within the command.

2. "Required together" and "mutually exclusive" groups logic was updated by implementing `requiredTogetherFlagGroup` and `mutuallyExclusiveFlagGroup` `flagGroup`s.

3. `enforceFlagGroupsForCompletion` `Command`'s method was renamed to `adjustByFlagGroupsForCompletions`.

4. Groups failed validation error messages were changed:
  - `"if any flags in the group [...] are set they must all be set; missing [...]"` to `"flags [...] must be set together, but [...] were not set"`
  - `"if any flags in the group [...] are set none of the others can be; [...] were all set"` to `"exactly one of the flags [...] can be set, but [...] were set"`

5. Not found flag on group marking error messages were updated from "Failed to find flag %q and mark it as being required in a flag group" and "Failed to find flag %q and mark it as being in a mutually exclusive flag group" to "flag %q is not defined"

6. `TestValidateFlagGroups` test was updated in `flag_groups_test.go`.
  - `getCmd` function was updated and test flag names were changed to improve readability
  - 2 testcases (`Validation of required groups occurs on groups in sorted order` and `Validation of exclusive groups occurs on groups in sorted order`) were removed, because groups validation now occur in the same order those groups were registered
  - other 16 testcases are preserved with updated descriptions, error messages

The completions generation tests that contain flag groups related testcases and updated flag groups tests, as well as all other tests, have been passed.

API was not changed: `MarkFlagsRequiredTogether` and `MarkFlagsMutuallyExclusive` functions have the same signatures.
  • Loading branch information
evermake committed Oct 7, 2022
1 parent 212ea40 commit 2bec783
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 249 deletions.
5 changes: 5 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ type Command struct {
// that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName

// flagGroups is the list of groups that contain grouped names of flags.
// Groups are like "relationships" between flags that allow to validate
// flags and adjust completions taking into account these "relationships".
flagGroups []flagGroup

// usageFunc is usage func defined by user.
usageFunc func(*Command) error
// usageTemplate is usage template defined by user.
Expand Down
4 changes: 2 additions & 2 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
var completions []string
var directive ShellCompDirective

// Enforce flag groups before doing flag completions
finalCmd.enforceFlagGroupsForCompletion()
// Allow flagGroups to update the command to improve completions
finalCmd.adjustByFlagGroupsForCompletions()

// Note that we want to perform flagname completion even if finalCmd.DisableFlagParsing==true;
// doing this allows for completion of persistent flag names even for commands that disable flag parsing.
Expand Down
288 changes: 130 additions & 158 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,209 +16,181 @@ package cobra

import (
"fmt"
"sort"
"strings"

flag "github.com/spf13/pflag"
)

const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
// if the command is invoked with a subset (but not all) of the given flags.
// MarkFlagsRequiredTogether creates a relationship between flags, which ensures
// that if any of flags with names from flagNames is set, other flags must be set too.
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
}
if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
}
c.addFlagGroup(&requiredTogetherFlagGroup{
flagNames: flagNames,
})
}

// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
// if the command is invoked with more than one flag from the given set of flags.
// MarkFlagsMutuallyExclusive creates a relationship between flags, which ensures
// that if any of flags with names from flagNames is set, other flags must not be set.
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
c.addFlagGroup(&mutuallyExclusiveFlagGroup{
flagNames: flagNames,
})
}

// addFlagGroup merges persistent flags of the command and adds flagGroup into command's flagGroups list.
// Panics, if flagGroup g contains the name of the flag, which is not defined in the Command c.
func (c *Command) addFlagGroup(g flagGroup) {
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))

for _, flagName := range g.AssignedFlagNames() {
if c.Flags().Lookup(flagName) == nil {
panic(fmt.Sprintf("flag %q is not defined", flagName))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
panic(err)
}

c.flagGroups = append(c.flagGroups, g)
}

// validateFlagGroups runs validation for each group from command's flagGroups list,
// and returns the first error encountered, or nil, if there were no validation errors.
func (c *Command) validateFlagGroups() error {
setFlags := makeSetFlagsSet(c.Flags())
for _, group := range c.flagGroups {
if err := group.ValidateSetFlags(setFlags); err != nil {
return err
}
}
return nil
}

// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
// adjustByFlagGroupsForCompletions changes the command by each flagGroup from command's flagGroups list
// to make the further command completions generation more convenient.
// Does nothing, if Command.DisableFlagParsing is true.
func (c *Command) adjustByFlagGroupsForCompletions() {
if c.DisableFlagParsing {
return nil
return
}

flags := c.Flags()

// groupStatus format is the list of flags as a unique ID,
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
return err
}
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
for _, group := range c.flagGroups {
group.AdjustCommandForCompletions(c)
}
return nil
}

func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
for _, fname := range flagnames {
f := fs.Lookup(fname)
if f == nil {
return false
}
}
return true
type flagGroup interface {
// ValidateSetFlags checks whether the combination of flags that have been set is valid.
// If not, an error is returned.
ValidateSetFlags(setFlags setFlagsSet) error

// AssignedFlagNames returns a full list of flag names that have been assigned to the group.
AssignedFlagNames() []string

// AdjustCommandForCompletions updates the command to generate more convenient for this group completions.
AdjustCommandForCompletions(c *Command)
}

func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
groupInfo, found := pflag.Annotations[annotation]
if found {
for _, group := range groupInfo {
if groupStatus[group] == nil {
flagnames := strings.Split(group, " ")
// requiredTogetherFlagGroup groups flags that are required together and
// must all be set, if any of flags from this group is set.
type requiredTogetherFlagGroup struct {
flagNames []string
}

// Only consider this flag group at all if all the flags are defined.
if !hasAllFlags(flags, flagnames...) {
continue
}
func (g *requiredTogetherFlagGroup) AssignedFlagNames() []string {
return g.flagNames
}
func (g *requiredTogetherFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
unset := setFlags.selectUnsetFlagNamesFrom(g.flagNames)

groupStatus[group] = map[string]bool{}
for _, name := range flagnames {
groupStatus[group][name] = false
}
}
if unsetCount := len(unset); unsetCount != 0 && unsetCount != len(g.flagNames) {
return fmt.Errorf("flags %v must be set together, but %v were not set", g.flagNames, unset)
}

groupStatus[group][pflag.Name] = pflag.Changed
return nil
}
func (g *requiredTogetherFlagGroup) AdjustCommandForCompletions(c *Command) {
setFlags := makeSetFlagsSet(c.Flags())
if setFlags.hasAnyFrom(g.flagNames) {
for _, requiredFlagName := range g.flagNames {
_ = c.MarkFlagRequired(requiredFlagName)
}
}
}

func validateRequiredFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
// mutuallyExclusiveFlagGroup groups flags that are mutually exclusive
// and must not be set together, if any of flags from this group is set.
type mutuallyExclusiveFlagGroup struct {
flagNames []string
}

unset := []string{}
for flagname, isSet := range flagnameAndStatus {
if !isSet {
unset = append(unset, flagname)
}
}
if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
continue
}
func (g *mutuallyExclusiveFlagGroup) AssignedFlagNames() []string {
return g.flagNames
}
func (g *mutuallyExclusiveFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
set := setFlags.selectSetFlagNamesFrom(g.flagNames)

// Sort values, so they can be tested/scripted against consistently.
sort.Strings(unset)
return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
if len(set) > 1 {
return fmt.Errorf("exactly one of the flags %v can be set, but %v were set", g.flagNames, set)
}

return nil
}

func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
var set []string
for flagname, isSet := range flagnameAndStatus {
if isSet {
set = append(set, flagname)
func (g *mutuallyExclusiveFlagGroup) AdjustCommandForCompletions(c *Command) {
setFlags := makeSetFlagsSet(c.Flags())
firstSetFlagName, hasAny := setFlags.selectFirstSetFlagNameFrom(g.flagNames)
if hasAny {
for _, exclusiveFlagName := range g.flagNames {
if exclusiveFlagName != firstSetFlagName {
c.Flags().Lookup(exclusiveFlagName).Hidden = true
}
}
if len(set) == 0 || len(set) == 1 {
continue
}

// Sort values, so they can be tested/scripted against consistently.
sort.Strings(set)
return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
for k := range m {
keys[i] = k
i++
}
sort.Strings(keys)
return keys
}
// setFlagsSet is a helper set type that is intended to be used to store names of the flags
// that have been set in flag.FlagSet and to perform some lookups and checks on those flags.
type setFlagsSet map[string]struct{}

// enforceFlagGroupsForCompletion will do the following:
// - when a flag in a group is present, other flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
// This allows the standard completion logic to behave appropriately for flag groups
func (c *Command) enforceFlagGroupsForCompletion() {
if c.DisableFlagParsing {
return
}
// makeSetFlagsSet creates setFlagsSet of names of the flags that have been set in the given flag.FlagSet.
func makeSetFlagsSet(fs *flag.FlagSet) setFlagsSet {
s := make(setFlagsSet)

flags := c.Flags()
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
// Visit flags that have been set and add them to the set
fs.Visit(func(f *flag.Flag) {
s[f.Name] = struct{}{}
})

// If a flag that is part of a group is present, we make all the other flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range groupStatus {
for _, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the group is set, mark the other ones as required
for _, fName := range strings.Split(flagList, " ") {
_ = c.MarkFlagRequired(fName)
}
}
return s
}
func (s setFlagsSet) has(flagName string) bool {
_, ok := s[flagName]
return ok
}
func (s setFlagsSet) hasAnyFrom(flagNames []string) bool {
for _, flagName := range flagNames {
if s.has(flagName) {
return true
}
}

// If a flag that is mutually exclusive to others is present, we hide the other
// flags of that group so the shell completion does not suggest them
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
for flagName, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
// Don't mark the flag that is already set as hidden because it may be an
// array or slice flag and therefore must continue being suggested
for _, fName := range strings.Split(flagList, " ") {
if fName != flagName {
flag := c.Flags().Lookup(fName)
flag.Hidden = true
}
}
}
return false
}
func (s setFlagsSet) selectFirstSetFlagNameFrom(flagNames []string) (string, bool) {
for _, flagName := range flagNames {
if s.has(flagName) {
return flagName, true
}
}
return "", false
}
func (s setFlagsSet) selectSetFlagNamesFrom(flagNames []string) (setFlagNames []string) {
for _, flagName := range flagNames {
if s.has(flagName) {
setFlagNames = append(setFlagNames, flagName)
}
}
return
}
func (s setFlagsSet) selectUnsetFlagNamesFrom(flagNames []string) (unsetFlagNames []string) {
for _, flagName := range flagNames {
if !s.has(flagName) {
unsetFlagNames = append(unsetFlagNames, flagName)
}
}
return
}
Loading

0 comments on commit 2bec783

Please sign in to comment.