diff --git a/cmd/crictl/logs.go b/cmd/crictl/logs.go index 8b454ec897..8a662e3e40 100644 --- a/cmd/crictl/logs.go +++ b/cmd/crictl/logs.go @@ -67,7 +67,7 @@ var logsCommand = &cli.Command{ Usage: "Show timestamps", }, }, - Action: func(ctx *cli.Context) error { + Action: func(ctx *cli.Context) (retErr error) { runtimeService, err := getRuntimeService(ctx) if err != nil { return err @@ -108,7 +108,26 @@ var logsCommand = &cli.Command{ logPath = fmt.Sprintf("%s%s%s", logPath[:strings.LastIndex(logPath, "/")+1], fmt.Sprint(containerAttempt-1), logPath[strings.LastIndex(logPath, "."):]) } - return logs.ReadLogs(context.Background(), logPath, status.GetId(), logOptions, runtimeService, os.Stdout, os.Stderr) + // build a WithCancel context based on cli.context + readLogCtx, cancelFn := context.WithCancel(ctx.Context) + go func() { + <-SetupInterruptSignalHandler() + // cancel readLogCtx when Interrupt signal received + cancelFn() + }() + defer func() { + // We can not use the typed error "context.Canceled" here + // because the upstream K8S dependency explicitly returns a fmt.Errorf("context cancelled"). + // So we need to compare the error in string. + if retErr != nil && retErr.Error() == "context cancelled" { + // Silent the "context cancelled" error. + // In order to prevent the error msg when user hit Ctrl+C. + retErr = nil + } + // Ensure no context leak + cancelFn() + }() + return logs.ReadLogs(readLogCtx, logPath, status.GetId(), logOptions, runtimeService, os.Stdout, os.Stderr) }, } diff --git a/cmd/crictl/main_unix.go b/cmd/crictl/main_unix.go index fe55934e12..64575a9c0e 100644 --- a/cmd/crictl/main_unix.go +++ b/cmd/crictl/main_unix.go @@ -18,8 +18,15 @@ limitations under the License. package main +import ( + "os" + "syscall" +) + const ( defaultConfigPath = "/etc/crictl.yaml" ) var defaultRuntimeEndpoints = []string{"unix:///var/run/dockershim.sock", "unix:///run/containerd/containerd.sock", "unix:///run/crio/crio.sock"} + +var shutdownSignals = []os.Signal{os.Interrupt, syscall.SIGTERM} diff --git a/cmd/crictl/main_windows.go b/cmd/crictl/main_windows.go index 788161aaa7..0cefd61f26 100644 --- a/cmd/crictl/main_windows.go +++ b/cmd/crictl/main_windows.go @@ -26,6 +26,8 @@ import ( var defaultRuntimeEndpoints = []string{"npipe:////./pipe/dockershim", "npipe:////./pipe/containerd", "npipe:////./pipe/crio"} var defaultConfigPath string +var shutdownSignals = []os.Signal{os.Interrupt} + func init() { defaultConfigPath = filepath.Join(os.Getenv("USERPROFILE"), ".crictl", "crictl.yaml") } diff --git a/cmd/crictl/portforward.go b/cmd/crictl/portforward.go index 883893f0f3..1edf5685fb 100644 --- a/cmd/crictl/portforward.go +++ b/cmd/crictl/portforward.go @@ -21,14 +21,13 @@ import ( "net/http" "net/url" "os" - "os/signal" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" "golang.org/x/net/context" restclient "k8s.io/client-go/rest" - portforward "k8s.io/client-go/tools/portforward" + "k8s.io/client-go/tools/portforward" "k8s.io/client-go/transport/spdy" pb "k8s.io/cri-api/pkg/apis/runtime/v1alpha2" ) @@ -100,21 +99,10 @@ func PortForward(client pb.RuntimeServiceClient, opts portforwardOptions) error } dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", URL) - stopChan := make(chan struct{}, 1) readyChan := make(chan struct{}) - signals := make(chan os.Signal, 1) - signal.Notify(signals, os.Interrupt) - defer signal.Stop(signals) - - go func() { - <-signals - if stopChan != nil { - close(stopChan) - } - }() logrus.Debugf("Ports to forword: %v", opts.ports) - pf, err := portforward.New(dialer, opts.ports, stopChan, readyChan, os.Stdout, os.Stderr) + pf, err := portforward.New(dialer, opts.ports, SetupInterruptSignalHandler(), readyChan, os.Stdout, os.Stderr) if err != nil { return err } diff --git a/cmd/crictl/stats.go b/cmd/crictl/stats.go index 320da08b6e..0af9f91e97 100644 --- a/cmd/crictl/stats.go +++ b/cmd/crictl/stats.go @@ -18,8 +18,6 @@ package main import ( "fmt" - "os" - "os/signal" "sort" "time" @@ -150,29 +148,42 @@ func ContainerStats(client pb.RuntimeServiceClient, opts statsOptions) error { display := newTableDisplay(20, 1, 3, ' ', 0) if !opts.watch { - if err := displayStats(client, request, display, opts); err != nil { + if err := displayStats(context.TODO(), client, request, display, opts); err != nil { return err } } else { - s := make(chan os.Signal, 1) - signal.Notify(s, os.Interrupt) + displayErrCh := make(chan error, 1) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + watchCtx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + // Put the displayStats in another goroutine. + // because it might be time consuming with lots of containers. + // and we want to cancel it ASAP when user hit CtrlC go func() { - <-s - os.Exit(0) - }() - for range time.Tick(500 * time.Millisecond) { - if err := displayStats(client, request, display, opts); err != nil { - return err + for range ticker.C { + if err := displayStats(watchCtx, client, request, display, opts); err != nil { + displayErrCh <- err + break + } } + }() + // listen for CtrlC or error + select { + case <-SetupInterruptSignalHandler(): + cancelFn() + return nil + case err := <-displayErrCh: + return err } } return nil } -func getContainerStats(client pb.RuntimeServiceClient, request *pb.ListContainerStatsRequest) (*pb.ListContainerStatsResponse, error) { +func getContainerStats(ctx context.Context, client pb.RuntimeServiceClient, request *pb.ListContainerStatsRequest) (*pb.ListContainerStatsResponse, error) { logrus.Debugf("ListContainerStatsRequest: %v", request) - r, err := client.ListContainerStats(context.Background(), request) + r, err := client.ListContainerStats(ctx, request) logrus.Debugf("ListContainerResponse: %v", r) if err != nil { return nil, err @@ -181,8 +192,8 @@ func getContainerStats(client pb.RuntimeServiceClient, request *pb.ListContainer return r, nil } -func displayStats(client pb.RuntimeServiceClient, request *pb.ListContainerStatsRequest, display *display, opts statsOptions) error { - r, err := getContainerStats(client, request) +func displayStats(ctx context.Context, client pb.RuntimeServiceClient, request *pb.ListContainerStatsRequest, display *display, opts statsOptions) error { + r, err := getContainerStats(ctx, client, request) if err != nil { return err } @@ -194,18 +205,24 @@ func displayStats(client pb.RuntimeServiceClient, request *pb.ListContainerStats } oldStats := make(map[string]*pb.ContainerStats) for _, s := range r.GetStats() { + if ctx.Err() != nil { + return ctx.Err() + } oldStats[s.Attributes.Id] = s } time.Sleep(opts.sample) - r, err = getContainerStats(client, request) + r, err = getContainerStats(ctx, client, request) if err != nil { return err } display.AddRow([]string{columnContainer, columnCPU, columnMemory, columnDisk, columnInodes}) for _, s := range r.GetStats() { + if ctx.Err() != nil { + return ctx.Err() + } id := getTruncatedID(s.Attributes.Id, "") cpu := s.GetCpu().GetUsageCoreNanoSeconds().GetValue() mem := s.GetMemory().GetWorkingSetBytes().GetValue() diff --git a/cmd/crictl/util.go b/cmd/crictl/util.go index a999fa7b8b..a1b0c8655e 100644 --- a/cmd/crictl/util.go +++ b/cmd/crictl/util.go @@ -22,10 +22,12 @@ import ( "encoding/json" "fmt" "os" + "os/signal" "reflect" "regexp" "sort" "strings" + "sync" "time" "github.com/golang/protobuf/jsonpb" @@ -43,6 +45,32 @@ const ( truncatedIDLen = 13 ) +var ( + // The global stopCh for monitoring Interrupt signal. + // DO NOT use it directly. Use SetupInterruptSignalHandler() to get it. + signalIntStopCh chan struct{} + // only setup stopCh once + signalIntSetupOnce = &sync.Once{} +) + +// SetupInterruptSignalHandler setup a global signal handler monitoring Interrupt signal. e.g: Ctrl+C. +// The returned read-only channel will be closed on receiving Interrupt signals. +// It will directly call os.Exit(1) on receiving Interrupt signal twice. +func SetupInterruptSignalHandler() <-chan struct{} { + signalIntSetupOnce.Do(func() { + signalIntStopCh = make(chan struct{}) + c := make(chan os.Signal, 2) + signal.Notify(c, shutdownSignals...) + go func() { + <-c + close(signalIntStopCh) + <-c + os.Exit(1) // Exit immediately on second signal + }() + }) + return signalIntStopCh +} + type listOptions struct { // id of container or sandbox id string