Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(flag_groups): flag groups implementation improved #1775

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ type Command struct {
// that we can use on every pflag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName

// flagGroups is the list of groups that contain grouped names of flags.
// Groups are like "relationships" between flags that allow to validate
// flags and adjust completions taking into account these "relationships".
flagGroups []flagGroup

// usageFunc is usage func defined by user.
usageFunc func(*Command) error
// usageTemplate is usage template defined by user.
Expand Down
4 changes: 2 additions & 2 deletions completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
var completions []string
var directive ShellCompDirective

// Enforce flag groups before doing flag completions
finalCmd.enforceFlagGroupsForCompletion()
// Allow flagGroups to update the command to improve completions
finalCmd.adjustByFlagGroupsForCompletions()

// Note that we want to perform flagname completion even if finalCmd.DisableFlagParsing==true;
// doing this allows for completion of persistent flag names even for commands that disable flag parsing.
Expand Down
288 changes: 130 additions & 158 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,209 +16,181 @@ package cobra

import (
"fmt"
"sort"
"strings"

flag "github.com/spf13/pflag"
)

const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
// if the command is invoked with a subset (but not all) of the given flags.
// MarkFlagsRequiredTogether creates a relationship between flags, which ensures
// that if any of flags with names from flagNames is set, other flags must be set too.
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
}
if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
}
c.addFlagGroup(&requiredTogetherFlagGroup{
flagNames: flagNames,
})
}

// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
// if the command is invoked with more than one flag from the given set of flags.
// MarkFlagsMutuallyExclusive creates a relationship between flags, which ensures
// that if any of flags with names from flagNames is set, other flags must not be set.
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
c.addFlagGroup(&mutuallyExclusiveFlagGroup{
flagNames: flagNames,
})
}

// addFlagGroup merges persistent flags of the command and adds flagGroup into command's flagGroups list.
// Panics, if flagGroup g contains the name of the flag, which is not defined in the Command c.
func (c *Command) addFlagGroup(g flagGroup) {
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
panic(err)

for _, flagName := range g.AssignedFlagNames() {
if c.Flags().Lookup(flagName) == nil {
panic(fmt.Sprintf("flag %q is not defined", flagName))
}
}

c.flagGroups = append(c.flagGroups, g)
}

// ValidateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
// first error encountered.
// ValidateFlagGroups runs validation for each group from command's flagGroups list,
// and returns the first error encountered, or nil, if there were no validation errors.
func (c *Command) ValidateFlagGroups() error {
if c.DisableFlagParsing {
return nil
setFlags := makeSetFlagsSet(c.Flags())
for _, group := range c.flagGroups {
if err := group.ValidateSetFlags(setFlags); err != nil {
return err
}
}
return nil
}

flags := c.Flags()

// groupStatus format is the list of flags as a unique ID,
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
return err
// adjustByFlagGroupsForCompletions changes the command by each flagGroup from command's flagGroups list
// to make the further command completions generation more convenient.
// Does nothing, if Command.DisableFlagParsing is true.
func (c *Command) adjustByFlagGroupsForCompletions() {
if c.DisableFlagParsing {
return
}
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err

for _, group := range c.flagGroups {
group.AdjustCommandForCompletions(c)
}
return nil
}

func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
for _, fname := range flagnames {
f := fs.Lookup(fname)
if f == nil {
return false
}
}
return true
type flagGroup interface {
// ValidateSetFlags checks whether the combination of flags that have been set is valid.
// If not, an error is returned.
ValidateSetFlags(setFlags setFlagsSet) error

// AssignedFlagNames returns a full list of flag names that have been assigned to the group.
AssignedFlagNames() []string

// AdjustCommandForCompletions updates the command to generate more convenient for this group completions.
AdjustCommandForCompletions(c *Command)
}

func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
groupInfo, found := pflag.Annotations[annotation]
if found {
for _, group := range groupInfo {
if groupStatus[group] == nil {
flagnames := strings.Split(group, " ")
// requiredTogetherFlagGroup groups flags that are required together and
// must all be set, if any of flags from this group is set.
type requiredTogetherFlagGroup struct {
flagNames []string
}

// Only consider this flag group at all if all the flags are defined.
if !hasAllFlags(flags, flagnames...) {
continue
}
func (g *requiredTogetherFlagGroup) AssignedFlagNames() []string {
return g.flagNames
}
func (g *requiredTogetherFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
unset := setFlags.selectUnsetFlagNamesFrom(g.flagNames)

groupStatus[group] = map[string]bool{}
for _, name := range flagnames {
groupStatus[group][name] = false
}
}
if unsetCount := len(unset); unsetCount != 0 && unsetCount != len(g.flagNames) {
return fmt.Errorf("flags %v must be set together, but %v were not set", g.flagNames, unset)
}

groupStatus[group][pflag.Name] = pflag.Changed
return nil
}
func (g *requiredTogetherFlagGroup) AdjustCommandForCompletions(c *Command) {
setFlags := makeSetFlagsSet(c.Flags())
if setFlags.hasAnyFrom(g.flagNames) {
for _, requiredFlagName := range g.flagNames {
_ = c.MarkFlagRequired(requiredFlagName)
}
}
}

func validateRequiredFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
// mutuallyExclusiveFlagGroup groups flags that are mutually exclusive
// and must not be set together, if any of flags from this group is set.
type mutuallyExclusiveFlagGroup struct {
flagNames []string
}

unset := []string{}
for flagname, isSet := range flagnameAndStatus {
if !isSet {
unset = append(unset, flagname)
}
}
if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
continue
}
func (g *mutuallyExclusiveFlagGroup) AssignedFlagNames() []string {
return g.flagNames
}
func (g *mutuallyExclusiveFlagGroup) ValidateSetFlags(setFlags setFlagsSet) error {
set := setFlags.selectSetFlagNamesFrom(g.flagNames)

// Sort values, so they can be tested/scripted against consistently.
sort.Strings(unset)
return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
if len(set) > 1 {
return fmt.Errorf("exactly one of the flags %v can be set, but %v were set", g.flagNames, set)
}

return nil
}

func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
keys := sortedKeys(data)
for _, flagList := range keys {
flagnameAndStatus := data[flagList]
var set []string
for flagname, isSet := range flagnameAndStatus {
if isSet {
set = append(set, flagname)
func (g *mutuallyExclusiveFlagGroup) AdjustCommandForCompletions(c *Command) {
setFlags := makeSetFlagsSet(c.Flags())
firstSetFlagName, hasAny := setFlags.selectFirstSetFlagNameFrom(g.flagNames)
if hasAny {
for _, exclusiveFlagName := range g.flagNames {
if exclusiveFlagName != firstSetFlagName {
c.Flags().Lookup(exclusiveFlagName).Hidden = true
}
}
if len(set) == 0 || len(set) == 1 {
continue
}

// Sort values, so they can be tested/scripted against consistently.
sort.Strings(set)
return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
for k := range m {
keys[i] = k
i++
}
sort.Strings(keys)
return keys
}
// setFlagsSet is a helper set type that is intended to be used to store names of the flags
// that have been set in flag.FlagSet and to perform some lookups and checks on those flags.
type setFlagsSet map[string]struct{}

// enforceFlagGroupsForCompletion will do the following:
// - when a flag in a group is present, other flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
// This allows the standard completion logic to behave appropriately for flag groups
func (c *Command) enforceFlagGroupsForCompletion() {
if c.DisableFlagParsing {
return
}
// makeSetFlagsSet creates setFlagsSet of names of the flags that have been set in the given flag.FlagSet.
func makeSetFlagsSet(fs *flag.FlagSet) setFlagsSet {
s := make(setFlagsSet)

flags := c.Flags()
groupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
// Visit flags that have been set and add them to the set
fs.Visit(func(f *flag.Flag) {
s[f.Name] = struct{}{}
})

// If a flag that is part of a group is present, we make all the other flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range groupStatus {
for _, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the group is set, mark the other ones as required
for _, fName := range strings.Split(flagList, " ") {
_ = c.MarkFlagRequired(fName)
}
}
return s
}
func (s setFlagsSet) has(flagName string) bool {
_, ok := s[flagName]
return ok
}
func (s setFlagsSet) hasAnyFrom(flagNames []string) bool {
for _, flagName := range flagNames {
if s.has(flagName) {
return true
}
}

// If a flag that is mutually exclusive to others is present, we hide the other
// flags of that group so the shell completion does not suggest them
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
for flagName, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
// Don't mark the flag that is already set as hidden because it may be an
// array or slice flag and therefore must continue being suggested
for _, fName := range strings.Split(flagList, " ") {
if fName != flagName {
flag := c.Flags().Lookup(fName)
flag.Hidden = true
}
}
}
return false
}
func (s setFlagsSet) selectFirstSetFlagNameFrom(flagNames []string) (string, bool) {
for _, flagName := range flagNames {
if s.has(flagName) {
return flagName, true
}
}
return "", false
}
func (s setFlagsSet) selectSetFlagNamesFrom(flagNames []string) (setFlagNames []string) {
for _, flagName := range flagNames {
if s.has(flagName) {
setFlagNames = append(setFlagNames, flagName)
}
}
return
}
func (s setFlagsSet) selectUnsetFlagNamesFrom(flagNames []string) (unsetFlagNames []string) {
for _, flagName := range flagNames {
if !s.has(flagName) {
unsetFlagNames = append(unsetFlagNames, flagName)
}
}
return
}
Loading