Skip to content

Commit

Permalink
restore NewClientFromConnection (bazelbuild#558)
Browse files Browse the repository at this point in the history
The function is useful for creating other gRPC clients that share the same
underlying connection. Exporting an interface allows users to reuse the connection
pool as well as provide their own pre-dialed connection.
  • Loading branch information
mrahs authored May 3, 2024
1 parent faba519 commit aaaa08f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
6 changes: 3 additions & 3 deletions go/pkg/cas/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
//
// All fields are considered immutable, and should not be changed.
type Client struct {
conn *grpc.ClientConn
conn grpc.ClientConnInterface
// InstanceName is the full name of the RBE instance.
InstanceName string

Expand Down Expand Up @@ -234,12 +234,12 @@ func (c *RPCConfig) validate() error {

// NewClient creates a new client with the default configuration.
// Use client.Dial to create a connection.
func NewClient(ctx context.Context, conn *grpc.ClientConn, instanceName string) (*Client, error) {
func NewClient(ctx context.Context, conn grpc.ClientConnInterface, instanceName string) (*Client, error) {
return NewClientWithConfig(ctx, conn, instanceName, DefaultClientConfig())
}

// NewClientWithConfig creates a new client and accepts a configuration.
func NewClientWithConfig(ctx context.Context, conn *grpc.ClientConn, instanceName string, config ClientConfig) (*Client, error) {
func NewClientWithConfig(ctx context.Context, conn grpc.ClientConnInterface, instanceName string, config ClientConfig) (*Client, error) {
switch err := config.Validate(); {
case err != nil:
return nil, errors.Wrap(err, "invalid config")
Expand Down
50 changes: 32 additions & 18 deletions go/pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ func (ce *InitError) Error() string {
return fmt.Sprintf("%v, authentication type (identity) used=%q", ce.Err.Error(), ce.AuthUsed)
}

// Temporary interface definition until the gcp balancer is removed in favour of
// the round-robin balancer.
type grpcClientConn interface {
// GrpcClientConn allows accepting pre-created connections to be used when creating clients.
// It is only intended to be used by methods in this package.
// It is not intended for SDK users to depend on.
// It might be removed in future update.
type GrpcClientConn interface {
grpc.ClientConnInterface
io.Closer
}
Expand All @@ -136,8 +138,8 @@ type Client struct {
//
// These fields are logically "protected" and are intended for use by extensions of Client.
Retrier *Retrier
connection grpcClientConn
casConnection grpcClientConn
connection GrpcClientConn
casConnection GrpcClientConn
// StartupCapabilities denotes whether to load ServerCapabilities on startup.
StartupCapabilities StartupCapabilities
// LegacyExecRootRelativeOutputs denotes whether outputs are relative to the exec root.
Expand Down Expand Up @@ -216,18 +218,16 @@ const (
DefaultRegularMode = 0644
)

func (c *Client) Connection() *grpc.ClientConn {
if conn, ok := c.connection.(*grpc.ClientConn); ok {
return conn
}
return c.connection.(*balancer.RRConnPool).Conn()
// Connection is meant to be used with generated methods that accept
// grpc.ClientConnInterface
func (c *Client) Connection() grpc.ClientConnInterface {
return c.connection
}

func (c *Client) CASConnection() *grpc.ClientConn {
if conn, ok := c.casConnection.(*grpc.ClientConn); ok {
return conn
}
return c.casConnection.(*balancer.RRConnPool).Conn()
// CASConnection is meant to be used with generated methods that accept
// grpc.ClientConnInterface
func (c *Client) CASConnection() grpc.ClientConnInterface {
return c.casConnection
}

// Close closes the underlying gRPC connection(s).
Expand Down Expand Up @@ -715,7 +715,7 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts
return nil, fmt.Errorf("failed to prepare gRPC dial options: %v", err)
}

var conn, casConn grpcClientConn
var conn, casConn GrpcClientConn
if params.RoundRobinBalancer {
dial := func(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, params.Service, dialOpts...)
Expand Down Expand Up @@ -744,6 +744,21 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts
return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed}
}

client, err := NewClientFromConnection(ctx, instanceName, conn, casConn, opts...)
if err != nil {
return nil, &InitError{Err: err, AuthUsed: authUsed}
}
return client, nil
}

// NewClientFromConnection creates a client from gRPC connections to a remote execution service and a cas service.
func NewClientFromConnection(ctx context.Context, instanceName string, conn, casConn GrpcClientConn, opts ...Opt) (*Client, error) {
if conn == nil {
return nil, fmt.Errorf("connection to remote execution service may not be nil")
}
if casConn == nil {
return nil, fmt.Errorf("connection to CAS service may not be nil")
}
client := &Client{
InstanceName: instanceName,
actionCache: regrpc.NewActionCacheClient(casConn),
Expand Down Expand Up @@ -787,7 +802,6 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts
return nil, fmt.Errorf("CASConcurrency should be at least 1")
}
client.RunBackgroundTasks(ctx)

return client, nil
}

Expand Down Expand Up @@ -1048,7 +1062,7 @@ func (c *Client) WaitExecution(ctx context.Context, req *repb.WaitExecutionReque

// GetBackendCapabilities returns the capabilities for a specific server connection
// (either the main connection or the CAS connection).
func (c *Client) GetBackendCapabilities(ctx context.Context, conn *grpc.ClientConn, req *repb.GetCapabilitiesRequest) (res *repb.ServerCapabilities, err error) {
func (c *Client) GetBackendCapabilities(ctx context.Context, conn grpc.ClientConnInterface, req *repb.GetCapabilitiesRequest) (res *repb.ServerCapabilities, err error) {
opts := c.RPCOpts()
err = c.Retrier.Do(ctx, func() (e error) {
return c.CallWithTimeout(ctx, "GetCapabilities", func(ctx context.Context) (e error) {
Expand Down

0 comments on commit aaaa08f

Please sign in to comment.