diff --git a/pkg/chaos/docker/cmd/restart.go b/pkg/chaos/docker/cmd/restart.go index 90112efe..99e38246 100644 --- a/pkg/chaos/docker/cmd/restart.go +++ b/pkg/chaos/docker/cmd/restart.go @@ -3,6 +3,7 @@ package cmd import ( "context" "fmt" + "time" "github.com/urfave/cli" @@ -20,10 +21,10 @@ func NewRestartCLICommand(ctx context.Context) *cli.Command { return &cli.Command{ Name: "restart", Flags: []cli.Flag{ - cli.StringFlag{ - Name: "command, s", - Usage: "shell command, that will be sent by Pumba to the target container(s)", - Value: "kill 1", + cli.IntFlag{ + Name: "timeout, s", + Usage: "restart timeout for target container(s)", + Value: 1000, }, cli.IntFlag{ Name: "limit, l", @@ -53,11 +54,11 @@ func (cmd *restartContext) restart(c *cli.Context) error { // get names or pattern names, pattern := chaos.GetNamesOrPattern(c) // get command - command := c.String("command") + timeout := time.Duration(c.Int("timeout")) * time.Millisecond // get limit for number of containers to restart limit := c.Int("limit") // init restart command - restartCommand, err := docker.NewRestartCommand(chaos.DockerClient, names, pattern, labels, command, limit, dryRun) + restartCommand, err := docker.NewRestartCommand(chaos.DockerClient, names, pattern, labels, timeout, limit, dryRun) if err != nil { return err } diff --git a/pkg/chaos/docker/restart.go b/pkg/chaos/docker/restart.go index 57817df8..f9d69de9 100644 --- a/pkg/chaos/docker/restart.go +++ b/pkg/chaos/docker/restart.go @@ -2,6 +2,7 @@ package docker import ( "context" + "time" "github.com/alexei-led/pumba/pkg/chaos" "github.com/alexei-led/pumba/pkg/container" @@ -15,17 +16,14 @@ type RestartCommand struct { names []string pattern string labels []string - command string + timeout time.Duration limit int dryRun bool } // NewRestartCommand create new Restart Command instance -func NewRestartCommand(client container.Client, names []string, pattern string, labels []string, command string, limit int, dryRun bool) (chaos.Command, error) { - restart := &RestartCommand{client, names, pattern, labels, command, limit, dryRun} - if restart.command == "" { - restart.command = "kill 1" - } +func NewRestartCommand(client container.Client, names []string, pattern string, labels []string, timeout time.Duration, limit int, dryRun bool) (chaos.Command, error) { + restart := &RestartCommand{client, names, pattern, labels, timeout, limit, dryRun} return restart, nil } @@ -58,10 +56,10 @@ func (k *RestartCommand) Run(ctx context.Context, random bool) error { for _, container := range containers { log.WithFields(log.Fields{ "container": container, - "command": k.command, + "timeout": k.timeout, }).Debug("restarting container") c := container - err = k.client.RestartContainer(ctx, c, k.command, k.dryRun) + err = k.client.RestartContainer(ctx, c, k.timeout, k.dryRun) if err != nil { return errors.Wrap(err, "failed to restart container") } diff --git a/pkg/chaos/docker/restart_test.go b/pkg/chaos/docker/restart_test.go index 0aeb6369..9d028e25 100644 --- a/pkg/chaos/docker/restart_test.go +++ b/pkg/chaos/docker/restart_test.go @@ -5,6 +5,7 @@ import ( "errors" "reflect" "testing" + "time" "github.com/alexei-led/pumba/pkg/chaos" "github.com/alexei-led/pumba/pkg/container" @@ -14,14 +15,14 @@ import ( //nolint:funlen func TestRestartCommand_Run(t *testing.T) { type wantErrors struct { - listError bool + listError bool restartError bool } type fields struct { names []string pattern string labels []string - command string + timeout time.Duration limit int dryRun bool } @@ -41,7 +42,7 @@ func TestRestartCommand_Run(t *testing.T) { name: "restart matching containers by names", fields: fields{ names: []string{"c1", "c2", "c3"}, - command: "kill 1", + timeout: 1 * time.Second, }, args: args{ ctx: context.TODO(), @@ -53,7 +54,7 @@ func TestRestartCommand_Run(t *testing.T) { fields: fields{ names: []string{"c1", "c2", "c3"}, labels: []string{"key=value"}, - command: "kill 1", + timeout: 1 * time.Second, }, args: args{ ctx: context.TODO(), @@ -64,7 +65,7 @@ func TestRestartCommand_Run(t *testing.T) { name: "restart matching containers by filter with limit", fields: fields{ pattern: "^c?", - command: "kill -STOP 1", + timeout: 1 * time.Second, limit: 2, }, args: args{ @@ -76,7 +77,7 @@ func TestRestartCommand_Run(t *testing.T) { name: "restart random matching container by names", fields: fields{ names: []string{"c1", "c2", "c3"}, - command: "kill 1", + timeout: 1 * time.Second, }, args: args{ ctx: context.TODO(), @@ -88,7 +89,7 @@ func TestRestartCommand_Run(t *testing.T) { name: "no matching containers by names", fields: fields{ names: []string{"c1", "c2", "c3"}, - command: "kill 1", + timeout: 1 * time.Second, }, args: args{ ctx: context.TODO(), @@ -98,7 +99,7 @@ func TestRestartCommand_Run(t *testing.T) { name: "error listing containers", fields: fields{ names: []string{"c1", "c2", "c3"}, - command: "kill 1", + timeout: 1 * time.Second, }, args: args{ ctx: context.TODO(), @@ -110,7 +111,7 @@ func TestRestartCommand_Run(t *testing.T) { name: "error restarting container", fields: fields{ names: []string{"c1", "c2", "c3"}, - command: "kill 1", + timeout: 1 * time.Second, }, args: args{ ctx: context.TODO(), @@ -128,7 +129,7 @@ func TestRestartCommand_Run(t *testing.T) { names: tt.fields.names, pattern: tt.fields.pattern, labels: tt.fields.labels, - command: tt.fields.command, + timeout: 1 * time.Second, limit: tt.fields.limit, dryRun: tt.fields.dryRun, } @@ -144,11 +145,11 @@ func TestRestartCommand_Run(t *testing.T) { } } if tt.args.random { - mockClient.On("RestartContainer", tt.args.ctx, mock.AnythingOfType("*container.Container"), tt.fields.command, tt.fields.dryRun).Return(nil) + mockClient.On("RestartContainer", tt.args.ctx, mock.AnythingOfType("*container.Container"), tt.fields.timeout, tt.fields.dryRun).Return(nil) } else { for i := range tt.expected { if tt.fields.limit == 0 || i < tt.fields.limit { - call = mockClient.On("RestartContainer", tt.args.ctx, mock.AnythingOfType("*container.Container"), tt.fields.command, tt.fields.dryRun) + call = mockClient.On("RestartContainer", tt.args.ctx, mock.AnythingOfType("*container.Container"), tt.fields.timeout, tt.fields.dryRun) if tt.errs.restartError { call.Return(errors.New("ERROR")) goto Invoke @@ -173,7 +174,7 @@ func TestNewRestartCommand(t *testing.T) { names []string pattern string labels []string - command string + timeout time.Duration limit int dryRun bool } @@ -187,12 +188,12 @@ func TestNewRestartCommand(t *testing.T) { name: "create new restart command", args: args{ names: []string{"c1", "c2"}, - command: "kill -TERM 1", + timeout: 1 * time.Second, limit: 10, }, want: &RestartCommand{ names: []string{"c1", "c2"}, - command: "kill -TERM 1", + timeout: 1 * time.Second, limit: 10, }, }, @@ -200,17 +201,17 @@ func TestNewRestartCommand(t *testing.T) { name: "empty command", args: args{ names: []string{"c1", "c2"}, - command: "", + timeout: 1 * time.Second, }, want: &RestartCommand{ names: []string{"c1", "c2"}, - command: "kill 1", + timeout: 1 * time.Second, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewRestartCommand(tt.args.client, tt.args.names, tt.args.pattern, tt.args.labels, tt.args.command, tt.args.limit, tt.args.dryRun) + got, err := NewRestartCommand(tt.args.client, tt.args.names, tt.args.pattern, tt.args.labels, tt.args.timeout, tt.args.limit, tt.args.dryRun) if (err != nil) != tt.wantErr { t.Errorf("NewRestartCommand() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/container/client.go b/pkg/container/client.go index 065f311a..e6c92b2d 100644 --- a/pkg/container/client.go +++ b/pkg/container/client.go @@ -38,7 +38,7 @@ type Client interface { StopContainer(context.Context, *Container, int, bool) error KillContainer(context.Context, *Container, string, bool) error ExecContainer(context.Context, *Container, string, bool) error - RestartContainer(context.Context, *Container, string, bool) error + RestartContainer(context.Context, *Container, time.Duration, bool) error RemoveContainer(context.Context, *Container, bool, bool, bool, bool) error NetemContainer(context.Context, *Container, string, []string, []*net.IPNet, []string, []string, time.Duration, string, bool, bool) error StopNetemContainer(context.Context, *Container, string, []*net.IPNet, []string, []string, string, bool, bool) error @@ -186,56 +186,19 @@ func (client dockerClient) ExecContainer(ctx context.Context, c *Container, comm return nil } -func (client dockerClient) RestartContainer(ctx context.Context, c *Container, command string, dryrun bool) error { +func (client dockerClient) RestartContainer(ctx context.Context, c *Container, timeout time.Duration, dryrun bool) error { log.WithFields(log.Fields{ "name": c.Name(), "id": c.ID(), - "command": command, + "timeout": timeout, "dryrun": dryrun, }).Info("restart container") if !dryrun { - createRes, err := client.containerAPI.ContainerRestartCreate( - ctx, c.ID(), types.RestartConfig{ - User: "root", - AttachStdout: true, - AttachStderr: true, - Cmd: strings.Split(command, " "), - }, - ) - if err != nil { - return errors.Wrap(err, "restart create failed") - } - - attachRes, err := client.containerAPI.ContainerAttach( - ctx, createRes.ID, types.ContainerAttachOptions{}, + err := client.containerAPI.ContainerRestart( + ctx, c.ID(), &timeout, ) if err != nil { - return errors.Wrap(err, "restart attach failed") - } - - if err := client.containerAPI.ContainerRestartStart( - ctx, createRes.ID, types.RestartStartCheck{}, - ); err != nil { - return errors.Wrap(err, "restart start failed") - } - - output, err := ioutil.ReadAll(attachRes.Reader) - if err != nil { - return errors.Wrap(err, "reading output from restart reader failed") - } - log.WithFields(log.Fields{ - "name": c.Name(), - "id": c.ID(), - "command": command, - "dryrun": dryrun, - }).Info(string(output)) - - res, err := client.containerAPI.ContainerRestartInspect(ctx, createRes.ID) - if err != nil { - return errors.Wrap(err, "restart inspect failed") - } - if res.ExitCode != 0 { - return errors.New("restart failed " + command + fmt.Sprintf(" %d", res.ExitCode)) + return errors.Wrap(err, "restart failed") } } return nil diff --git a/pkg/container/mock_Client.go b/pkg/container/mock_Client.go index be0ba79f..3f07d737 100644 --- a/pkg/container/mock_Client.go +++ b/pkg/container/mock_Client.go @@ -45,11 +45,11 @@ func (_m *MockClient) ExecContainer(_a0 context.Context, _a1 *Container, _a2 str } // RestartContainer provides a mock function with given fields: _a0, _a1, _a2, _a3 -func (_m *MockClient) RestartContainer(_a0 context.Context, _a1 *Container, _a2 string, _a3 bool) error { +func (_m *MockClient) RestartContainer(_a0 context.Context, _a1 *Container, _a2 time.Duration, _a3 bool) error { ret := _m.Called(_a0, _a1, _a2, _a3) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *Container, string, bool) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *Container, time.Duration, bool) error); ok { r0 = rf(_a0, _a1, _a2, _a3) } else { r0 = ret.Error(0)