diff --git a/cmd/crictl/image.go b/cmd/crictl/image.go index 3c4410ed75..869b26b2c4 100644 --- a/cmd/crictl/image.go +++ b/cmd/crictl/image.go @@ -417,7 +417,9 @@ var removeImageCommand = &cli.Command{ } // Container images - containers, err := runtimeClient.ListContainers(context.TODO(), nil) + containers, err := InterruptableRPC(nil, func(ctx context.Context) ([]*pb.Container, error) { + return runtimeClient.ListContainers(ctx, nil) + }) if err != nil { return err } @@ -669,7 +671,9 @@ func PullImageWithSandbox(client internalapi.ImageManagerService, image string, defer cancel() } - res, err := client.PullImage(ctx, request.Image, request.Auth, request.SandboxConfig) + res, err := InterruptableRPC(ctx, func(ctx context.Context) (string, error) { + return client.PullImage(ctx, request.Image, request.Auth, request.SandboxConfig) + }) if err != nil { return nil, err } @@ -683,7 +687,9 @@ func PullImageWithSandbox(client internalapi.ImageManagerService, image string, func ListImages(client internalapi.ImageManagerService, image string) (*pb.ListImagesResponse, error) { request := &pb.ListImagesRequest{Filter: &pb.ImageFilter{Image: &pb.ImageSpec{Image: image}}} logrus.Debugf("ListImagesRequest: %v", request) - res, err := client.ListImages(context.TODO(), request.Filter) + res, err := InterruptableRPC(nil, func(ctx context.Context) ([]*pb.Image, error) { + return client.ListImages(ctx, request.Filter) + }) if err != nil { return nil, err } @@ -790,7 +796,9 @@ func ImageStatus(client internalapi.ImageManagerService, image string, verbose b Verbose: verbose, } logrus.Debugf("ImageStatusRequest: %v", request) - res, err := client.ImageStatus(context.TODO(), request.Image, request.Verbose) + res, err := InterruptableRPC(nil, func(ctx context.Context) (*pb.ImageStatusResponse, error) { + return client.ImageStatus(ctx, request.Image, request.Verbose) + }) if err != nil { return nil, err } @@ -806,16 +814,18 @@ func RemoveImage(client internalapi.ImageManagerService, image string) error { } request := &pb.RemoveImageRequest{Image: &pb.ImageSpec{Image: image}} logrus.Debugf("RemoveImageRequest: %v", request) - if err := client.RemoveImage(context.TODO(), request.Image); err != nil { - return err - } - return nil + _, err := InterruptableRPC(nil, func(ctx context.Context) (*pb.RemoveImageResponse, error) { + return nil, client.RemoveImage(ctx, request.Image) + }) + return err } // ImageFsInfo sends an ImageStatusRequest to the server, and parses // the returned ImageFsInfoResponse. func ImageFsInfo(client internalapi.ImageManagerService) (*pb.ImageFsInfoResponse, error) { - res, err := client.ImageFsInfo(context.TODO()) + res, err := InterruptableRPC(nil, func(ctx context.Context) (*pb.ImageFsInfoResponse, error) { + return client.ImageFsInfo(ctx) + }) if err != nil { return nil, err } diff --git a/cmd/crictl/util.go b/cmd/crictl/util.go index 559debe991..18398117d1 100644 --- a/cmd/crictl/util.go +++ b/cmd/crictl/util.go @@ -70,6 +70,39 @@ func SetupInterruptSignalHandler() <-chan struct{} { return signalIntStopCh } +func InterruptableRPC[T any]( + ctx context.Context, + rpcFunc func(context.Context) (T, error), +) (res T, err error) { + if ctx == nil { + ctx = context.Background() + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + resCh := make(chan T, 1) + errCh := make(chan error, 1) + + go func() { + res, err := rpcFunc(ctx) + if err != nil { + errCh <- err + return + } + resCh <- res + }() + + select { + case <-SetupInterruptSignalHandler(): + cancel() + return res, fmt.Errorf("interrupted: %w", ctx.Err()) + case err := <-errCh: + return res, err + case res := <-resCh: + return res, nil + } +} + type listOptions struct { // id of container or sandbox id string