Skip to content

Commit

Permalink
Extend Persistent*Run to allow multiple hooks
Browse files Browse the repository at this point in the history
Supersedes: PR #1123
Resolves: #252, #219
Related PR's: #714, #220

- Introduces On*Run() methods to register hooks
- Introduces OnPersistent*Run() methods to register hooks that are
  always executed (childs not overriding parents).
- Allows to register multiple *Run hooks
- Keeps current Persistent*Run behavior unaffected (childs still
  override parents when set as a property)
- Introduces EnablePersistentRunOverride option to control if child
  Persistent*Run hooks should override their parent's hooks

Merge spf13/cobra#1142
  • Loading branch information
bartdeboer authored and hoshsadiq committed Feb 10, 2022
1 parent da6f963 commit a21be02
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 58 deletions.
3 changes: 3 additions & 0 deletions cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ var templateFuncs = template.FuncMap{

var initializers []func()

// EnablePersistentRunOverride ensures Persistent*Run* functions in childs override their parents
var EnablePersistentRunOverride = true

// EnablePrefixMatching allows to set automatic prefix matching. Automatic prefix matching can be a dangerous thing
// to automatically enable in CLI tools.
// Set this to true to enable it.
Expand Down
150 changes: 115 additions & 35 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ type Command struct {
// PersistentPostRunE: PersistentPostRun but returns an error.
PersistentPostRunE func(cmd *Command, args []string) error

// persistentPreRunHooks are executed before the command or one of its children are executed
persistentPreRunHooks []func(cmd *Command, args []string) error
// preRunHooks are executed before the command is executed
preRunHooks []func(cmd *Command, args []string) error
// runHooks are executed when the command is executed
runHooks []func(cmd *Command, args []string) error
// postRunHooks are executed after the command has executed
postRunHooks []func(cmd *Command, args []string) error
// persistentPostRunHooks are executed after the command or one of its children have executed
persistentPostRunHooks []func(cmd *Command, args []string) error

// groups for commands
commandgroups []*Group

Expand Down Expand Up @@ -878,59 +889,128 @@ func (c *Command) execute(a []string) (err error) {
return err
}

for p := c; p != nil; p = p.Parent() {
if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
return err
}
break
} else if p.PersistentPreRun != nil {
p.PersistentPreRun(c, argWoFlags)
break
}
}
// Allocate the hooks execution chain for the current command
var hooks []func(cmd *Command, args []string) error

// First append the PreRun* hooks
hooks = append(hooks, c.preRunHooks...)
if c.PreRunE != nil {
if err := c.PreRunE(c, argWoFlags); err != nil {
return err
}
hooks = append(hooks, c.PreRunE)
} else if c.PreRun != nil {
c.PreRun(c, argWoFlags)
hooks = append(hooks, wrapVoidHook(c.PreRun))
}

if err := c.validateRequiredFlags(); err != nil {
return c.FlagErrorFunc()(c, err)
}
if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
// Include the validateRequiredFlags() logic as a hook
// to be executed before running the main Run hooks.
hooks = append(hooks, func(cmd *Command, args []string) error {
if err := cmd.validateRequiredFlags(); err != nil {
return err
}
} else {
if c.Run != nil {
c.Run(c, argWoFlags)
}
return nil
})

// Append the main Run* hooks
hooks = append(hooks, c.runHooks...)
if c.RunE != nil {
hooks = append(hooks, c.RunE)
} else if c.Run != nil {
hooks = append(hooks, wrapVoidHook(c.Run))
}

// Append the PostRun* hooks
hooks = append(hooks, c.postRunHooks...)
if c.PostRunE != nil {
if err := c.PostRunE(c, argWoFlags); err != nil {
return err
}
hooks = append(hooks, c.PostRunE)
} else if c.PostRun != nil {
c.PostRun(c, argWoFlags)
hooks = append(hooks, wrapVoidHook(c.PostRun))
}

// Lastly find and append/prepend the Persistent*Run hooks.
// Setting EnablePersistentRunOverride to true (default) preserves
// the previous behavior/concern where childs should override their parents.
// Any hooks registered through OnPersistent*Run will always
// be executed and cannot be overriden.
hasPersistentPreRunFromStruct := false
hasPersistentPostRunFromStruct := false
for p := c; p != nil; p = p.Parent() {
if p.PersistentPostRunE != nil {
if err := p.PersistentPostRunE(c, argWoFlags); err != nil {
return err
// Find and prepend the PersistentPreRun* hooks as defined on the commands
if !hasPersistentPreRunFromStruct || !EnablePersistentRunOverride {
if p.PersistentPreRunE != nil {
hooks = prependHook(&hooks, p.PersistentPreRunE)
hasPersistentPreRunFromStruct = true
} else if p.PersistentPreRun != nil {
hooks = prependHook(&hooks, wrapVoidHook(p.PersistentPreRun))
hasPersistentPreRunFromStruct = true
}
}
// Find and append the PersistentPostRun* hooks as defined on the commands
if !hasPersistentPostRunFromStruct || !EnablePersistentRunOverride {
if p.PersistentPostRunE != nil {
hooks = append(hooks, p.PersistentPostRunE)
hasPersistentPostRunFromStruct = true
} else if p.PersistentPostRun != nil {
hooks = append(hooks, wrapVoidHook(p.PersistentPostRun))
hasPersistentPostRunFromStruct = true
}
break
} else if p.PersistentPostRun != nil {
p.PersistentPostRun(c, argWoFlags)
break
}

// Hooks registered through OnPersistent*Run should always be executed
// Prepend the PersistentPreRun* hooks
hooks = append(p.persistentPreRunHooks, hooks...)
// Append the PersistentPostRun* hooks
hooks = append(hooks, p.persistentPostRunHooks...)
}

// Execute the hooks execution chain:
for _, x := range hooks {
if err := x(c, argWoFlags); err != nil {
return err
}
}

return nil
}

// prependHook prepends a hook onto the array of hooks
func prependHook(hooks *[]func(cmd *Command, args []string) error, hook ...func(cmd *Command, args []string) error) []func(cmd *Command, args []string) error {
return append(hook, *hooks...)
}

// wrapVoidHook wraps a void hook into a function having the return error signature
func wrapVoidHook(hook func(cmd *Command, args []string)) func(cmd *Command, args []string) error {
return func(cmd *Command, args []string) error {
hook(cmd, args)
return nil
}
}

// OnPersistentPreRun registers one or more hooks on the command to be executed
// before the command or one of its children are executed
func (c *Command) OnPersistentPreRun(f ...func(cmd *Command, args []string) error) {
c.persistentPreRunHooks = append(c.persistentPreRunHooks, f...)
}

// OnPreRun registers one or more hooks on the command to be executed before the command is executed
func (c *Command) OnPreRun(f ...func(cmd *Command, args []string) error) {
c.preRunHooks = append(c.preRunHooks, f...)
}

// OnRun registers one or more hooks on the command to be executed when the command is executed
func (c *Command) OnRun(f ...func(cmd *Command, args []string) error) {
c.runHooks = append(c.runHooks, f...)
}

// OnPostRun registers one or more hooks on the command to be executed after the command has executed
func (c *Command) OnPostRun(f ...func(cmd *Command, args []string) error) {
c.postRunHooks = append(c.postRunHooks, f...)
}

// OnPersistentPostRun register one or more hooks on the command to be executed
// after the command or one of its children have executed
func (c *Command) OnPersistentPostRun(f ...func(cmd *Command, args []string) error) {
c.persistentPostRunHooks = append(c.persistentPostRunHooks, f...)
}

func (c *Command) preRun() {
for _, x := range initializers {
x()
Expand Down
146 changes: 123 additions & 23 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,23 @@ func TestPersistentHooks(t *testing.T) {
childPersPostArgs string
)

var (
persParentPersPreArgs string
persParentPreArgs string
persParentRunArgs string
persParentPostArgs string
persParentPersPostArgs string
)

var (
persChildPersPreArgs string
persChildPreArgs string
persChildPreArgs2 string
persChildRunArgs string
persChildPostArgs string
persChildPersPostArgs string
)

parentCmd := &Command{
Use: "parent",
PersistentPreRun: func(_ *Command, args []string) {
Expand Down Expand Up @@ -1424,6 +1441,52 @@ func TestPersistentHooks(t *testing.T) {
}
parentCmd.AddCommand(childCmd)

parentCmd.OnPersistentPreRun(func(_ *Command, args []string) error {
persParentPersPreArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPreRun(func(_ *Command, args []string) error {
persParentPreArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnRun(func(_ *Command, args []string) error {
persParentRunArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPostRun(func(_ *Command, args []string) error {
persParentPostArgs = strings.Join(args, " ")
return nil
})
parentCmd.OnPersistentPostRun(func(_ *Command, args []string) error {
persParentPersPostArgs = strings.Join(args, " ")
return nil
})

childCmd.OnPersistentPreRun(func(_ *Command, args []string) error {
persChildPersPreArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPreRun(func(_ *Command, args []string) error {
persChildPreArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPreRun(func(_ *Command, args []string) error {
persChildPreArgs2 = strings.Join(args, " ") + " three"
return nil
})
childCmd.OnRun(func(_ *Command, args []string) error {
persChildRunArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPostRun(func(_ *Command, args []string) error {
persChildPostArgs = strings.Join(args, " ")
return nil
})
childCmd.OnPersistentPostRun(func(_ *Command, args []string) error {
persChildPersPostArgs = strings.Join(args, " ")
return nil
})

output, err := executeCommand(parentCmd, "child", "one", "two")
if output != "" {
t.Errorf("Unexpected output: %v", output)
Expand All @@ -1433,42 +1496,79 @@ func TestPersistentHooks(t *testing.T) {
}

for _, v := range []struct {
name string
got string
name string
got string
doCheck bool
}{
// TODO: currently PersistentPreRun* defined in parent does not
// run if the matching child subcommand has PersistentPreRun.
// If the behavior changes (https://github.com/spf13/cobra/issues/252)
// this test must be fixed.
{"parentPersPreArgs", parentPersPreArgs},
{"parentPreArgs", parentPreArgs},
{"parentRunArgs", parentRunArgs},
{"parentPostArgs", parentPostArgs},
// TODO: currently PersistentPostRun* defined in parent does not
// run if the matching child subcommand has PersistentPostRun.
// If the behavior changes (https://github.com/spf13/cobra/issues/252)
// this test must be fixed.
{"parentPersPostArgs", parentPersPostArgs},
{"parentPersPreArgs", parentPersPreArgs, EnablePersistentRunOverride},
{"parentPreArgs", parentPreArgs, true},
{"parentRunArgs", parentRunArgs, true},
{"parentPostArgs", parentPostArgs, true},
{"parentPersPostArgs", parentPersPostArgs, !EnablePersistentRunOverride},
} {
if v.got != "" {
if v.doCheck && v.got != "" {
t.Errorf("Expected blank %s, got %q", v.name, v.got)
}
}

for _, v := range []struct {
name string
got string
name string
got string
doCheck bool
}{
{"childPersPreArgs", childPersPreArgs},
{"childPreArgs", childPreArgs},
{"childRunArgs", childRunArgs},
{"childPostArgs", childPostArgs},
{"childPersPostArgs", childPersPostArgs},
{"childPersPreArgs", childPersPreArgs, EnablePersistentRunOverride},
{"childPreArgs", childPreArgs, true},
{"childRunArgs", childRunArgs, true},
{"childPostArgs", childPostArgs, true},
{"childPersPostArgs", childPersPostArgs, EnablePersistentRunOverride},
} {
if v.got != onetwo {
t.Errorf("Expected %s %q, got %q", v.name, onetwo, v.got)
}
}

// Test On*Run hooks

if persParentPersPreArgs != "one two" {
t.Errorf("Expected persParentPersPreArgs %q, got %q", "one two", persParentPersPreArgs)
}
if persParentPreArgs != "" {
t.Errorf("Expected blank persParentPreArgs, got %q", persParentPreArgs)
}
if persParentRunArgs != "" {
t.Errorf("Expected blank persParentRunArgs, got %q", persParentRunArgs)
}
if persParentPostArgs != "" {
t.Errorf("Expected blank persParentPostArg, got %q", persParentPostArgs)
}
if persParentPersPostArgs != "one two" {
t.Errorf("Expected persParentPersPostArgs %q, got %q", "one two", persParentPersPostArgs)
}

if persChildPersPreArgs != "one two" {
t.Errorf("Expected persChildPersPreArgs %q, got %q", "one two", persChildPersPreArgs)
}
if persChildPreArgs != "one two" {
t.Errorf("Expected persChildPreArgs %q, got %q", "one two", persChildPreArgs)
}
if persChildPreArgs2 != "one two three" {
t.Errorf("Expected persChildPreArgs %q, got %q", "one two three", persChildPreArgs2)
}
if persChildRunArgs != "one two" {
t.Errorf("Expected persChildRunArgs %q, got %q", "one two", persChildRunArgs)
}
if persChildPostArgs != "one two" {
t.Errorf("Expected persChildPostArgs %q, got %q", "one two", persChildPostArgs)
}
if persChildPersPostArgs != "one two" {
t.Errorf("Expected persChildPersPostArgs %q, got %q", "one two", persChildPersPostArgs)
}
}

func TestPersistentHooksWoOverride(t *testing.T) {
EnablePersistentRunOverride = false
TestPersistentHooks(t)
EnablePersistentRunOverride = true
}

// Related to https://github.com/spf13/cobra/issues/521.
Expand Down

0 comments on commit a21be02

Please sign in to comment.