-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ability to mark flags as required or exclusive as a group
This change adds two features for dealing with flags: - requiring flags be provided as a group (or not at all) - requiring flags be mutually exclusive of each other By utilizing the flag annotations we can mark which flag groups a flag is a part of and during the parsing process we track which ones we have seen or not. A flag may be a part of multiple groups. The list of flags and the type of group (required together or exclusive) make it a unique group. Signed-off-by: John Schnake <jschnake@vmware.com>
- Loading branch information
1 parent
bf6cb58
commit c437f2c
Showing
5 changed files
with
354 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
// Copyright © 2022 Steve Francia <spf@spf13.com>. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package cobra | ||
|
||
import ( | ||
"fmt" | ||
"sort" | ||
"strings" | ||
|
||
flag "github.com/spf13/pflag" | ||
) | ||
|
||
const ( | ||
requiredAsGroup = "cobra_annotation_required_if_others_set" | ||
mutuallyExclusive = "cobra_annotation_mutually_exclusive" | ||
) | ||
|
||
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors | ||
// if the command is invoked with a subset (but not all) of the given flags. | ||
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { | ||
c.mergePersistentFlags() | ||
for _, v := range flagNames { | ||
f := c.Flags().Lookup(v) | ||
if f == nil { | ||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) | ||
} | ||
if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil { | ||
// Only errs if the flag isn't found. | ||
panic(err) | ||
} | ||
} | ||
} | ||
|
||
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors | ||
// if the command is invoked with more than one flag from the given set of flags. | ||
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { | ||
c.mergePersistentFlags() | ||
for _, v := range flagNames { | ||
f := c.Flags().Lookup(v) | ||
if f == nil { | ||
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) | ||
} | ||
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. | ||
if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil { | ||
panic(err) | ||
} | ||
} | ||
} | ||
|
||
// validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the | ||
// first error encountered. | ||
func (c *Command) validateFlagGroups() error { | ||
if c.DisableFlagParsing { | ||
return nil | ||
} | ||
|
||
flags := c.Flags() | ||
|
||
// groupStatus format is the list of flags as a unique ID, | ||
// then a map of each flag name and whether it is set or not. | ||
groupStatus := map[string]map[string]bool{} | ||
mutuallyExclusiveGroupStatus := map[string]map[string]bool{} | ||
flags.VisitAll(func(pflag *flag.Flag) { | ||
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) | ||
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) | ||
}) | ||
|
||
if err := validateRequiredFlagGroups(groupStatus); err != nil { | ||
return err | ||
} | ||
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { | ||
for _, fname := range flagnames { | ||
f := fs.Lookup(fname) | ||
if f == nil { | ||
return false | ||
} | ||
} | ||
return true | ||
} | ||
|
||
func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { | ||
groupInfo, found := pflag.Annotations[annotation] | ||
if found { | ||
for _, group := range groupInfo { | ||
if groupStatus[group] == nil { | ||
flagnames := strings.Split(group, " ") | ||
|
||
// Only consider this flag group at all if all the flags are defined. | ||
if !hasAllFlags(flags, flagnames...) { | ||
continue | ||
} | ||
|
||
groupStatus[group] = map[string]bool{} | ||
for _, name := range flagnames { | ||
groupStatus[group][name] = false | ||
} | ||
} | ||
|
||
groupStatus[group][pflag.Name] = pflag.Changed | ||
} | ||
} | ||
} | ||
|
||
func validateRequiredFlagGroups(data map[string]map[string]bool) error { | ||
keys := sortedKeys(data) | ||
for _, flagList := range keys { | ||
flagnameAndStatus := data[flagList] | ||
|
||
unset := []string{} | ||
for flagname, isSet := range flagnameAndStatus { | ||
if !isSet { | ||
unset = append(unset, flagname) | ||
} | ||
} | ||
if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { | ||
continue | ||
} | ||
|
||
// Sort values, so they can be tested/scripted against consistently. | ||
sort.Strings(unset) | ||
return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func validateExclusiveFlagGroups(data map[string]map[string]bool) error { | ||
keys := sortedKeys(data) | ||
for _, flagList := range keys { | ||
flagnameAndStatus := data[flagList] | ||
var set []string | ||
for flagname, isSet := range flagnameAndStatus { | ||
if isSet { | ||
set = append(set, flagname) | ||
} | ||
} | ||
if len(set) == 0 || len(set) == 1 { | ||
continue | ||
} | ||
|
||
// Sort values, so they can be tested/scripted against consistently. | ||
sort.Strings(set) | ||
return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) | ||
} | ||
return nil | ||
} | ||
|
||
func sortedKeys(m map[string]map[string]bool) []string { | ||
keys := make([]string, len(m)) | ||
i := 0 | ||
for k := range m { | ||
keys[i] = k | ||
i++ | ||
} | ||
sort.Strings(keys) | ||
return keys | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
// Copyright © 2022 Steve Francia <spf@spf13.com>. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package cobra | ||
|
||
import ( | ||
"strings" | ||
"testing" | ||
) | ||
|
||
func TestValidateFlagGroups(t *testing.T) { | ||
getCmd := func() *Command { | ||
c := &Command{ | ||
Use: "testcmd", | ||
Run: func(cmd *Command, args []string) { | ||
}} | ||
// Define lots of flags to utilize for testing. | ||
for _, v := range []string{"a", "b", "c", "d"} { | ||
c.Flags().String(v, "", "") | ||
} | ||
for _, v := range []string{"e", "f", "g"} { | ||
c.PersistentFlags().String(v, "", "") | ||
} | ||
subC := &Command{ | ||
Use: "subcmd", | ||
Run: func(cmd *Command, args []string) { | ||
}} | ||
subC.Flags().String("subonly", "", "") | ||
c.AddCommand(subC) | ||
return c | ||
} | ||
|
||
// Each test case uses a unique command from the function above. | ||
testcases := []struct { | ||
desc string | ||
flagGroupsRequired []string | ||
flagGroupsExclusive []string | ||
subCmdFlagGroupsRequired []string | ||
subCmdFlagGroupsExclusive []string | ||
args []string | ||
expectErr string | ||
}{ | ||
{ | ||
desc: "No flags no problem", | ||
}, { | ||
desc: "No flags no problem even with conflicting groups", | ||
flagGroupsRequired: []string{"a b"}, | ||
flagGroupsExclusive: []string{"a b"}, | ||
}, { | ||
desc: "Required flag group not satisfied", | ||
flagGroupsRequired: []string{"a b c"}, | ||
args: []string{"--a=foo"}, | ||
expectErr: "if any flags in the group [a b c] are set they must all be set; missing [b c]", | ||
}, { | ||
desc: "Exclusive flag group not satisfied", | ||
flagGroupsExclusive: []string{"a b c"}, | ||
args: []string{"--a=foo", "--b=foo"}, | ||
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", | ||
}, { | ||
desc: "Multiple required flag group not satisfied returns first error", | ||
flagGroupsRequired: []string{"a b c", "a d"}, | ||
args: []string{"--c=foo", "--d=foo"}, | ||
expectErr: `if any flags in the group [a b c] are set they must all be set; missing [a b]`, | ||
}, { | ||
desc: "Multiple exclusive flag group not satisfied returns first error", | ||
flagGroupsExclusive: []string{"a b c", "a d"}, | ||
args: []string{"--a=foo", "--c=foo", "--d=foo"}, | ||
expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`, | ||
}, { | ||
desc: "Validation of required groups occurs on groups in sorted order", | ||
flagGroupsRequired: []string{"a d", "a b", "a c"}, | ||
args: []string{"--a=foo"}, | ||
expectErr: `if any flags in the group [a b] are set they must all be set; missing [b]`, | ||
}, { | ||
desc: "Validation of exclusive groups occurs on groups in sorted order", | ||
flagGroupsExclusive: []string{"a d", "a b", "a c"}, | ||
args: []string{"--a=foo", "--b=foo", "--c=foo"}, | ||
expectErr: `if any flags in the group [a b] are set none of the others can be; [a b] were all set`, | ||
}, { | ||
desc: "Persistent flags utilize both features and can fail required groups", | ||
flagGroupsRequired: []string{"a e", "e f"}, | ||
flagGroupsExclusive: []string{"f g"}, | ||
args: []string{"--a=foo", "--f=foo", "--g=foo"}, | ||
expectErr: `if any flags in the group [a e] are set they must all be set; missing [e]`, | ||
}, { | ||
desc: "Persistent flags utilize both features and can fail mutually exclusive groups", | ||
flagGroupsRequired: []string{"a e", "e f"}, | ||
flagGroupsExclusive: []string{"f g"}, | ||
args: []string{"--a=foo", "--e=foo", "--f=foo", "--g=foo"}, | ||
expectErr: `if any flags in the group [f g] are set none of the others can be; [f g] were all set`, | ||
}, { | ||
desc: "Persistent flags utilize both features and can pass", | ||
flagGroupsRequired: []string{"a e", "e f"}, | ||
flagGroupsExclusive: []string{"f g"}, | ||
args: []string{"--a=foo", "--e=foo", "--f=foo"}, | ||
}, { | ||
desc: "Subcmds can use required groups using inherited flags", | ||
subCmdFlagGroupsRequired: []string{"e subonly"}, | ||
args: []string{"subcmd", "--e=foo", "--subonly=foo"}, | ||
}, { | ||
desc: "Subcmds can use exclusive groups using inherited flags", | ||
subCmdFlagGroupsExclusive: []string{"e subonly"}, | ||
args: []string{"subcmd", "--e=foo", "--subonly=foo"}, | ||
expectErr: "if any flags in the group [e subonly] are set none of the others can be; [e subonly] were all set", | ||
}, { | ||
desc: "Subcmds can use exclusive groups using inherited flags and pass", | ||
subCmdFlagGroupsExclusive: []string{"e subonly"}, | ||
args: []string{"subcmd", "--e=foo"}, | ||
}, { | ||
desc: "Flag groups not applied if not found on invoked command", | ||
subCmdFlagGroupsRequired: []string{"e subonly"}, | ||
args: []string{"--e=foo"}, | ||
}, | ||
} | ||
for _, tc := range testcases { | ||
t.Run(tc.desc, func(t *testing.T) { | ||
c := getCmd() | ||
sub := c.Commands()[0] | ||
for _, flagGroup := range tc.flagGroupsRequired { | ||
c.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) | ||
} | ||
for _, flagGroup := range tc.flagGroupsExclusive { | ||
c.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) | ||
} | ||
for _, flagGroup := range tc.subCmdFlagGroupsRequired { | ||
sub.MarkFlagsRequiredTogether(strings.Split(flagGroup, " ")...) | ||
} | ||
for _, flagGroup := range tc.subCmdFlagGroupsExclusive { | ||
sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) | ||
} | ||
c.SetArgs(tc.args) | ||
err := c.Execute() | ||
switch { | ||
case err == nil && len(tc.expectErr) > 0: | ||
t.Errorf("Expected error %q but got nil", tc.expectErr) | ||
case err != nil && err.Error() != tc.expectErr: | ||
t.Errorf("Expected error %q but got %q", tc.expectErr, err) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters