Skip to content

Commit

Permalink
[#343] pool: Provide dial context to clients
Browse files Browse the repository at this point in the history
After recent changes `client.Client` accepts dial context. There is a
need to forward the context passed into `Pool.Dial` to the underlying
`Client` instances.

Define type aliases of different client constructors: context-based and
non-context. Use context-based constructor in `Pool`. Pass `ctx`
parameter of `Pool.Dial` method to the client builder.

Signed-off-by: Leonard Lyubich <ctulhurider@gmail.com>
  • Loading branch information
cthulhu-rider authored and fyrchik committed Oct 7, 2022
1 parent 452a50e commit 8c68264
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 24 deletions.
46 changes: 40 additions & 6 deletions pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ type wrapperPrm struct {
timeout time.Duration
errorThreshold uint32
responseInfoCallback func(sdkClient.ResponseMetaInfo) error
dialCtx context.Context
}

// setAddress sets endpoint to connect in NeoFS network.
Expand Down Expand Up @@ -247,6 +248,11 @@ func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo)
x.responseInfoCallback = f
}

// setDialContext specifies context for client dial.
func (x *wrapperPrm) setDialContext(ctx context.Context) {
x.dialCtx = ctx
}

// newWrapper creates a clientWrapper that implements the client interface.
func newWrapper(prm wrapperPrm) (*clientWrapper, error) {
var prmInit sdkClient.PrmInit
Expand All @@ -263,6 +269,7 @@ func newWrapper(prm wrapperPrm) (*clientWrapper, error) {
var prmDial sdkClient.PrmDial
prmDial.SetServerURI(prm.address)
prmDial.SetTimeout(prm.timeout)
prmDial.SetContext(prm.dialCtx)

err := res.client.Dial(prmDial)
if err != nil {
Expand Down Expand Up @@ -818,6 +825,14 @@ func (c *clientStatusMonitor) handleError(st apistatus.Status, err error) error
return err
}

// clientBuilder is a type alias of client constructors which open connection
// to the given endpoint.
type clientBuilder = func(endpoint string) (client, error)

// clientBuilderContext is a type alias of client constructors which open
// connection to the given endpoint using provided context.
type clientBuilderContext = func(ctx context.Context, endpoint string) (client, error)

// InitParameters contains values used to initialize connection Pool.
type InitParameters struct {
key *ecdsa.PrivateKey
Expand All @@ -829,7 +844,7 @@ type InitParameters struct {
errorThreshold uint32
nodeParams []NodeParam

clientBuilder func(endpoint string) (client, error)
clientBuilder clientBuilderContext
}

// SetKey specifies default key to be used for the protocol communication by default.
Expand Down Expand Up @@ -876,6 +891,24 @@ func (x *InitParameters) AddNode(nodeParam NodeParam) {
x.nodeParams = append(x.nodeParams, nodeParam)
}

// setClientBuilder sets clientBuilder used for client construction.
// Wraps setClientBuilderContext without a context.
func (x *InitParameters) setClientBuilder(builder clientBuilder) {
x.setClientBuilderContext(func(_ context.Context, endpoint string) (client, error) {
return builder(endpoint)
})
}

// setClientBuilderContext sets clientBuilderContext used for client construction.
func (x *InitParameters) setClientBuilderContext(builder clientBuilderContext) {
x.clientBuilder = builder
}

// isMissingClientBuilder checks if client constructor was not specified.
func (x *InitParameters) isMissingClientBuilder() bool {
return x.clientBuilder == nil
}

type rebalanceParameters struct {
nodesParams []*nodesParam
nodeRequestTimeout time.Duration
Expand Down Expand Up @@ -1303,7 +1336,7 @@ type Pool struct {
cache *sessionCache
stokenDuration uint64
rebalanceParams rebalanceParameters
clientBuilder func(endpoint string) (client, error)
clientBuilder clientBuilderContext
logger *zap.Logger
}

Expand Down Expand Up @@ -1371,7 +1404,7 @@ func (p *Pool) Dial(ctx context.Context) error {
for i, params := range p.rebalanceParams.nodesParams {
clients := make([]client, len(params.weights))
for j, addr := range params.addresses {
c, err := p.clientBuilder(addr)
c, err := p.clientBuilder(ctx, addr)
if err != nil {
return err
}
Expand Down Expand Up @@ -1428,8 +1461,8 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
params.healthcheckTimeout = defaultRequestTimeout
}

if params.clientBuilder == nil {
params.clientBuilder = func(addr string) (client, error) {
if params.isMissingClientBuilder() {
params.setClientBuilderContext(func(ctx context.Context, addr string) (client, error) {
var prm wrapperPrm
prm.setAddress(addr)
prm.setKey(*params.key)
Expand All @@ -1439,8 +1472,9 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) {
cache.updateEpoch(info.Epoch())
return nil
})
prm.setDialContext(ctx)
return newWrapper(prm)
}
})
}
}

Expand Down
36 changes: 18 additions & 18 deletions pool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ func TestBuildPoolClientFailed(t *testing.T) {
}

opts := InitParameters{
key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}},
clientBuilder: clientBuilder,
key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}},
}
opts.setClientBuilder(clientBuilder)

pool, err := NewPool(opts)
require.NoError(t, err)
Expand All @@ -46,10 +46,10 @@ func TestBuildPoolCreateSessionFailed(t *testing.T) {
}

opts := InitParameters{
key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}},
clientBuilder: clientBuilder,
key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}},
}
opts.setClientBuilder(clientBuilder)

pool, err := NewPool(opts)
require.NoError(t, err)
Expand Down Expand Up @@ -87,11 +87,11 @@ func TestBuildPoolOneNodeFailed(t *testing.T) {
require.NoError(t, err)
opts := InitParameters{
key: newPrivateKey(t),
clientBuilder: clientBuilder,
clientRebalanceInterval: 1000 * time.Millisecond,
logger: log,
nodeParams: nodes,
}
opts.setClientBuilder(clientBuilder)

clientPool, err := NewPool(opts)
require.NoError(t, err)
Expand Down Expand Up @@ -127,10 +127,10 @@ func TestOneNode(t *testing.T) {
}

opts := InitParameters{
key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}},
clientBuilder: clientBuilder,
key: newPrivateKey(t),
nodeParams: []NodeParam{{1, "peer0", 1}},
}
opts.setClientBuilder(clientBuilder)

pool, err := NewPool(opts)
require.NoError(t, err)
Expand Down Expand Up @@ -159,8 +159,8 @@ func TestTwoNodes(t *testing.T) {
{1, "peer0", 1},
{1, "peer1", 1},
},
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

pool, err := NewPool(opts)
require.NoError(t, err)
Expand Down Expand Up @@ -209,8 +209,8 @@ func TestOneOfTwoFailed(t *testing.T) {
key: newPrivateKey(t),
nodeParams: nodes,
clientRebalanceInterval: 200 * time.Millisecond,
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

pool, err := NewPool(opts)
require.NoError(t, err)
Expand Down Expand Up @@ -247,8 +247,8 @@ func TestTwoFailed(t *testing.T) {
{1, "peer1", 1},
},
clientRebalanceInterval: 200 * time.Millisecond,
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

pool, err := NewPool(opts)
require.NoError(t, err)
Expand Down Expand Up @@ -280,8 +280,8 @@ func TestSessionCache(t *testing.T) {
{1, "peer0", 1},
},
clientRebalanceInterval: 30 * time.Second,
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -348,8 +348,8 @@ func TestPriority(t *testing.T) {
key: newPrivateKey(t),
nodeParams: nodes,
clientRebalanceInterval: 1500 * time.Millisecond,
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -395,8 +395,8 @@ func TestSessionCacheWithKey(t *testing.T) {
{1, "peer0", 1},
},
clientRebalanceInterval: 30 * time.Second,
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -434,8 +434,8 @@ func TestSessionTokenOwner(t *testing.T) {
nodeParams: []NodeParam{
{1, "peer0", 1},
},
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -638,8 +638,8 @@ func TestSwitchAfterErrorThreshold(t *testing.T) {
key: newPrivateKey(t),
nodeParams: nodes,
clientRebalanceInterval: 30 * time.Second,
clientBuilder: clientBuilder,
}
opts.setClientBuilder(clientBuilder)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down

0 comments on commit 8c68264

Please sign in to comment.