diff --git a/client/v3/client.go b/client/v3/client.go index 2990379ab9f..4dfae89c610 100644 --- a/client/v3/client.go +++ b/client/v3/client.go @@ -286,8 +286,7 @@ func (c *Client) dial(creds grpccredentials.TransportCredentials, dopts ...grpc. if err != nil { return nil, fmt.Errorf("failed to configure dialer: %v", err) } - if c.Username != "" && c.Password != "" { - c.authTokenBundle = credentials.NewBundle(credentials.Config{}) + if c.authTokenBundle != nil { opts = append(opts, grpc.WithPerRPCCredentials(c.authTokenBundle.PerRPCCredentials())) } @@ -383,6 +382,7 @@ func newClient(cfg *Config) (*Client, error) { if cfg.Username != "" && cfg.Password != "" { client.Username = cfg.Username client.Password = cfg.Password + client.authTokenBundle = credentials.NewBundle(credentials.Config{}) } if cfg.MaxCallSendMsgSize > 0 || cfg.MaxCallRecvMsgSize > 0 { if cfg.MaxCallRecvMsgSize > 0 && cfg.MaxCallSendMsgSize > cfg.MaxCallRecvMsgSize { diff --git a/client/v3/client_test.go b/client/v3/client_test.go index 8e2c03e280b..e441476374b 100644 --- a/client/v3/client_test.go +++ b/client/v3/client_test.go @@ -200,6 +200,55 @@ func TestZapWithLogger(t *testing.T) { } } +func TestAuthTokenBundleNoOverwrite(t *testing.T) { + // Create a mock AuthServer to handle Authenticate RPCs. + lis, err := net.Listen("unix", "etcd-auth-test:0") + if err != nil { + t.Fatal(err) + } + defer lis.Close() + addr := "unix:" + lis.Addr().String() + srv := grpc.NewServer() + etcdserverpb.RegisterAuthServer(srv, mockAuthServer{}) + go srv.Serve(lis) + defer srv.Stop() + + // Create a client, which should call Authenticate on the mock server to + // exchange username/password for an auth token. + c, err := NewClient(t, Config{ + DialTimeout: 5 * time.Second, + Endpoints: []string{addr}, + Username: "foo", + Password: "bar", + }) + if err != nil { + t.Fatal(err) + } + defer c.Close() + oldTokenBundle := c.authTokenBundle + + // Call the public Dial again, which should preserve the original + // authTokenBundle. + gc, err := c.Dial(addr) + if err != nil { + t.Fatal(err) + } + defer gc.Close() + newTokenBundle := c.authTokenBundle + + if oldTokenBundle != newTokenBundle { + t.Error("Client.authTokenBundle has been overwritten during Client.Dial") + } +} + +type mockAuthServer struct { + *etcdserverpb.UnimplementedAuthServer +} + +func (mockAuthServer) Authenticate(context.Context, *etcdserverpb.AuthenticateRequest) (*etcdserverpb.AuthenticateResponse, error) { + return &etcdserverpb.AuthenticateResponse{Token: "mock-token"}, nil +} + func TestSyncFiltersMembers(t *testing.T) { c, _ := NewClient(t, Config{Endpoints: []string{"http://254.0.0.1:12345"}}) defer c.Close()