diff --git a/cmd/gpu_plugin/rm/gpu_plugin_resource_manager_test.go b/cmd/gpu_plugin/rm/gpu_plugin_resource_manager_test.go index 09a5c68b2..d2c1a7a65 100644 --- a/cmd/gpu_plugin/rm/gpu_plugin_resource_manager_test.go +++ b/cmd/gpu_plugin/rm/gpu_plugin_resource_manager_test.go @@ -92,7 +92,7 @@ func (w *mockPodResources) Get(ctx context.Context, } func newMockResourceManager(pods []v1.Pod) ResourceManager { - client, err := grpc.Dial("fake", grpc.WithTransportCredentials(insecure.NewCredentials())) + client, err := grpc.NewClient("fake", grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { fmt.Fprintf(os.Stderr, "failed to create client: %v\n", err) diff --git a/pkg/deviceplugin/server.go b/pkg/deviceplugin/server.go index 7f02b79e8..2636b8fda 100644 --- a/pkg/deviceplugin/server.go +++ b/pkg/deviceplugin/server.go @@ -26,6 +26,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/pkg/errors" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "k8s.io/klog/v2" @@ -326,15 +327,9 @@ func watchFile(file string) error { } func (srv *server) registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string) error { - ctx := context.Background() - - conn, err := grpc.DialContext(ctx, kubeletSocket, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, "unix", addr) - })) + conn, err := grpc.NewClient(filepath.Join("unix://", kubeletSocket), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - return errors.Wrap(err, "Cannot connect to kubelet service") + return errors.Wrap(err, "Cannot create a gRPC client") } defer conn.Close() @@ -347,7 +342,7 @@ func (srv *server) registerWithKubelet(kubeletSocket, pluginEndPoint, resourceNa Options: srv.getDevicePluginOptions(), } - _, err = client.Register(ctx, reqt) + _, err = client.Register(context.Background(), reqt) if err != nil { return errors.Wrap(err, "Cannot register to kubelet service") } @@ -358,20 +353,33 @@ func (srv *server) registerWithKubelet(kubeletSocket, pluginEndPoint, resourceNa // waitForServer checks if grpc server is alive // by making grpc blocking connection to the server socket. func waitForServer(socket string, timeout time.Duration) error { + conn, err := grpc.NewClient(filepath.Join("unix://", socket), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return errors.Wrap(err, "Cannot create a gRPC client") + } + + defer conn.Close() + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - conn, err := grpc.DialContext(ctx, socket, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithBlock(), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, "unix", addr) - }), - ) - if conn != nil { - _ = conn.Close() - } + // A blocking dial blocks until the clientConn is ready. Based + // on grpc-go's DialContext() that moved to use NewClient() but + // marked DialContext() deprecated. + for { + state := conn.GetState() + if state == connectivity.Idle { + conn.Connect() + } + + if state == connectivity.Ready { + return nil + } - return errors.Wrapf(err, "Failed dial context at %s", socket) + if !conn.WaitForStateChange(ctx, state) { + // ctx got timeout or canceled. + return errors.Wrapf(ctx.Err(), "Failed dial context at %s", socket) + } + } } diff --git a/pkg/deviceplugin/server_test.go b/pkg/deviceplugin/server_test.go index 1efea061f..c1849c1b3 100644 --- a/pkg/deviceplugin/server_test.go +++ b/pkg/deviceplugin/server_test.go @@ -21,6 +21,7 @@ import ( "net" "os" "path" + "path/filepath" "reflect" "sync" "testing" @@ -111,7 +112,7 @@ func (k *kubeletStub) start() error { return waitForServer(k.socket, 10*time.Second) } -func TestRegisterWithKublet(t *testing.T) { +func TestRegisterWithKubelet(t *testing.T) { pluginSocket := path.Join(devicePluginPath, pluginEndpoint) srv := newTestServer() @@ -180,11 +181,8 @@ func TestSetupAndServe(t *testing.T) { ctx := context.Background() - conn, err := grpc.DialContext(ctx, pluginSocket, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, "unix", addr) - })) + conn, err := grpc.NewClient(filepath.Join("unix://", pluginSocket), + grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { t.Fatalf("Failed to get connection: %+v", err) } @@ -231,12 +229,8 @@ func TestSetupAndServe(t *testing.T) { time.Sleep(1 * time.Second) } - conn, err = grpc.DialContext(ctx, pluginSocket, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, "unix", addr) - })) - + conn, err = grpc.NewClient(filepath.Join("unix://", pluginSocket), + grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { t.Fatalf("Failed to get connection: %+v", err) }