
grpc: fix data race in balancer registration #16229

Merged 3 commits on Feb 28, 2023
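
For context on the race being fixed (reconstructed from the registry.go comments in this diff; the test below is an illustrative sketch, not code from this PR): gRPC keeps registered balancer builders in a plain package-level map. balancer.Register writes that map, and grpc.Dial reads it when parsing a service config that names a load-balancing policy, with no lock on the read side. Parallel tests that each register an agent-specific builder can therefore race another test's in-flight Dial, even when the Register calls themselves are serialized (as the old registerLock did).

// Hypothetical reproduction; stubBuilder stands in for Consul's *balancer.Builder.
package balancer_test

import (
	"fmt"
	"testing"

	"google.golang.org/grpc"
	gbalancer "google.golang.org/grpc/balancer"
	"google.golang.org/grpc/credentials/insecure"
)

type stubBuilder struct{ name string }

func (b stubBuilder) Name() string { return b.name }

// Build is never invoked in this sketch: no RPCs are issued, so the channel
// stays idle. It exists only to satisfy the balancer.Builder interface.
func (b stubBuilder) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
	return nil
}

func TestParallelAgents(t *testing.T) {
	for i := 0; i < 4; i++ {
		name := fmt.Sprintf("agent-balancer-%d", i)
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			// Write to gRPC's global builder map.
			gbalancer.Register(stubBuilder{name: name})
			// The default service config names this subtest's balancer, so
			// Dial reads the same global map (via balancer.Get) while sibling
			// subtests are still writing it via Register above: a data race.
			conn, err := grpc.Dial(
				"localhost:0",
				grpc.WithTransportCredentials(insecure.NewCredentials()),
				grpc.WithDefaultServiceConfig(
					fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, name),
				),
			)
			if err != nil {
				t.Fatal(err)
			}
			t.Cleanup(func() { conn.Close() })
		})
	}
}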
5 changes: 1 addition & 4 deletions agent/agent.go
@@ -1598,10 +1598,7 @@ func (a *Agent) ShutdownAgent() error {
 
 	a.stopLicenseManager()
 
-	// this would be cancelled anyways (by the closing of the shutdown ch) but
-	// this should help them to be stopped more quickly
-	a.baseDeps.AutoConfig.Stop()
-	a.baseDeps.MetricsConfig.Cancel()
+	a.baseDeps.Close()
 
 	a.stateLock.Lock()
 	defer a.stateLock.Unlock()
5 changes: 4 additions & 1 deletion agent/consul/client_test.go
@@ -522,9 +522,13 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
 
 	resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter))
 	resolver.Register(resolverBuilder)
+	t.Cleanup(func() {
+		resolver.Deregister(resolverBuilder.Authority())
+	})
 
 	balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
 	balancerBuilder.Register()
+	t.Cleanup(balancerBuilder.Deregister)
 
 	r := router.NewRouter(
 		logger,
@@ -559,7 +563,6 @@ func newDefaultDeps(t *testing.T, c *Config) Deps {
 			UseTLSForDC:           tls.UseTLS,
 			DialingFromServer:     true,
 			DialingFromDatacenter: c.Datacenter,
-			BalancerBuilder:       balancerBuilder,
 		}),
 		LeaderForwarder:        resolverBuilder,
 		NewRequestRecorderFunc: middleware.NewRequestRecorder,
3 changes: 1 addition & 2 deletions agent/consul/rpc_test.go
@@ -1165,7 +1165,7 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
 
 	var conn *grpc.ClientConn
 	{
-		client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
+		client, resolverBuilder := newClientWithGRPCPlumbing(t, func(c *Config) {
 			c.Datacenter = "dc2"
 			c.PrimaryDatacenter = "dc1"
 			c.RPCConfig.EnableStreaming = true
@@ -1177,7 +1176,6 @@ func TestRPC_LocalTokenStrippedOnForward_GRPC(t *testing.T) {
 			Servers:               resolverBuilder,
 			DialingFromServer:     false,
 			DialingFromDatacenter: "dc2",
-			BalancerBuilder:       balancerBuilder,
 		})
 
 		conn, err = pool.ClientConn("dc2")
15 changes: 6 additions & 9 deletions agent/consul/subscribe_backend_test.go
@@ -39,7 +39,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
 	require.NoError(t, err)
 	defer server.Shutdown()
 
-	client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
+	client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
 
 	// Try to join
 	testrpc.WaitForLeader(t, server.RPC, "dc1")
@@ -71,7 +70,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -116,7 +115,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSEnabled(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -191,7 +189,7 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
 	defer server.Shutdown()
 
 	// Set up a client with valid certs and verify_outgoing = true
-	client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
+	client, resolverBuilder := newClientWithGRPCPlumbing(t, configureTLS, clientConfigVerifyOutgoing)
 
 	testrpc.WaitForLeader(t, server.RPC, "dc1")
 
@@ -204,7 +202,6 @@ func TestSubscribeBackend_IntegrationWithServer_TLSReload(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -284,7 +281,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T) {
 	codec := rpcClient(t, server)
 	defer codec.Close()
 
-	client, resolverBuilder, balancerBuilder := newClientWithGRPCPlumbing(t)
+	client, resolverBuilder := newClientWithGRPCPlumbing(t)
 
 	// Try to join
 	testrpc.WaitForLeader(t, server.RPC, "dc1")
@@ -346,7 +343,6 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T) {
 		UseTLSForDC:           client.tlsConfigurator.UseTLS,
 		DialingFromServer:     true,
 		DialingFromDatacenter: "dc1",
-		BalancerBuilder:       balancerBuilder,
 	})
 	conn, err := pool.ClientConn("dc1")
 	require.NoError(t, err)
@@ -376,7 +372,7 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T) {
 		"at least some of the subscribers should have received non-snapshot updates")
 }
 
-func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder, *balancer.Builder) {
+func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
 	_, config := testClientConfig(t)
 	for _, op := range ops {
 		op(config)
@@ -392,6 +388,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
 
 	balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
 	balancerBuilder.Register()
+	t.Cleanup(balancerBuilder.Deregister)
 
 	deps := newDefaultDeps(t, config)
 	deps.Router = router.NewRouter(
@@ -406,7 +403,7 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *resolver.ServerResolverBuilder) {
 	t.Cleanup(func() {
 		client.Shutdown()
 	})
-	return client, resolverBuilder, balancerBuilder
+	return client, resolverBuilder
 }
 
 type testLogger interface {
40 changes: 20 additions & 20 deletions agent/grpc-internal/balancer/balancer.go
@@ -65,21 +65,25 @@ import (
 	"google.golang.org/grpc/status"
 )
 
-// NewBuilder constructs a new Builder with the given name.
-func NewBuilder(name string, logger hclog.Logger) *Builder {
+// NewBuilder constructs a new Builder. Calling Register will add the Builder
+// to our global registry under the given "authority" such that it will be used
+// when dialing targets in the form "consul-internal://<authority>/...". This
+// allows us to add and remove balancers for different in-memory agents during
+// tests.
+func NewBuilder(authority string, logger hclog.Logger) *Builder {
 	return &Builder{
-		name:     name,
-		logger:   logger,
-		byTarget: make(map[string]*list.List),
-		shuffler: randomShuffler(),
+		authority: authority,
+		logger:    logger,
+		byTarget:  make(map[string]*list.List),
+		shuffler:  randomShuffler(),
 	}
 }
 
 // Builder implements gRPC's balancer.Builder interface to construct balancers.
 type Builder struct {
-	name     string
-	logger   hclog.Logger
-	shuffler shuffler
+	authority string
+	logger    hclog.Logger
+	shuffler  shuffler
 
 	mu       sync.Mutex
 	byTarget map[string]*list.List
@@ -129,19 +133,15 @@ func (b *Builder) removeBalancer(targetURL string, elem *list.Element) {
 	}
 }
 
-// Name implements the gRPC Balancer interface by returning its given name.
-func (b *Builder) Name() string { return b.name }
-
-// gRPC's balancer.Register method is not thread-safe, so we guard our calls
-// with a global lock (as it may be called from parallel tests).
-var registerLock sync.Mutex
-
-// Register the Builder in gRPC's global registry using its given name.
+// Register the Builder in our global registry. Users should call Deregister
+// when finished using the Builder to clean up global state.
 func (b *Builder) Register() {
-	registerLock.Lock()
-	defer registerLock.Unlock()
+	globalRegistry.register(b.authority, b)
+}
 
-	gbalancer.Register(b)
+// Deregister the Builder from our global registry to clean up state.
+func (b *Builder) Deregister() {
+	globalRegistry.deregister(b.authority)
 }
 
 // Rebalance randomizes the priority order of servers for the given target to
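
The resulting lifecycle, mirrored by the test diffs above and below: each in-memory agent builds one Builder keyed by its resolver's unique authority, registers it with the package's own registry, and deregisters it on test cleanup. A condensed sketch (resolverBuilder and testutil.Logger are assumed from the surrounding test code):

// Per-agent setup inside a test, as in newDefaultDeps above.
balancerBuilder := balancer.NewBuilder(resolverBuilder.Authority(), testutil.Logger(t))
balancerBuilder.Register()            // adds to this package's registry, not gRPC's
t.Cleanup(balancerBuilder.Deregister) // drops the agent-specific state when the test ends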
41 changes: 25 additions & 16 deletions agent/grpc-internal/balancer/balancer_test.go
@@ -21,6 +21,8 @@ import (
 	"google.golang.org/grpc/stats"
 	"google.golang.org/grpc/status"
 
+	"github.com/hashicorp/go-uuid"
+
 	"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
 	"github.com/hashicorp/consul/sdk/testutil"
 	"github.com/hashicorp/consul/sdk/testutil/retry"
@@ -34,12 +36,13 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")
 
-		target, _ := stubResolver(t, server1, server2)
+		target, authority, _ := stubResolver(t, server1, server2)
 
-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)
 
-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)
 
 		var serverName string
@@ -78,12 +81,13 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")
 
-		target, _ := stubResolver(t, server1, server2)
+		target, authority, _ := stubResolver(t, server1, server2)
 
-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)
 
-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)
 
 		// Figure out which server we're talking to now, and which we should switch to.
@@ -123,10 +127,11 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")
 
-		target, _ := stubResolver(t, server1, server2)
+		target, authority, _ := stubResolver(t, server1, server2)
 
-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)
 
 		// Provide a custom prioritizer that causes Rebalance to choose whichever
 		// server didn't get our first request.
@@ -137,7 +142,7 @@ func TestBalancer(t *testing.T) {
 			})
 		}
 
-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)
 
 		// Figure out which server we're talking to now.
@@ -177,12 +182,13 @@ func TestBalancer(t *testing.T) {
 		server1 := runServer(t, "server1")
 		server2 := runServer(t, "server2")
 
-		target, res := stubResolver(t, server1, server2)
+		target, authority, res := stubResolver(t, server1, server2)
 
-		balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
+		balancerBuilder := NewBuilder(authority, testutil.Logger(t))
 		balancerBuilder.Register()
+		t.Cleanup(balancerBuilder.Deregister)
 
-		conn := dial(t, target, balancerBuilder)
+		conn := dial(t, target)
 		client := testservice.NewSimpleClient(conn)
 
 		// Figure out which server we're talking to now.
@@ -233,7 +239,7 @@ func TestBalancer(t *testing.T) {
 	})
 }
 
-func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
+func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
 	t.Helper()
 
 	addresses := make([]resolver.Address, len(servers))
@@ -249,7 +255,10 @@ func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
 	resolver.Register(r)
 	t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })
 
-	return fmt.Sprintf("%s://", scheme), r
+	authority, err := uuid.GenerateUUID()
+	require.NoError(t, err)
+
+	return fmt.Sprintf("%s://%s", scheme, authority), authority, r
 }
 
 func runServer(t *testing.T, name string) *server {
@@ -309,12 +318,12 @@ func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
 	return &testservice.Resp{ServerName: s.name}, nil
 }
 
-func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
+func dial(t *testing.T, target string) *grpc.ClientConn {
 	conn, err := grpc.Dial(
 		target,
 		grpc.WithTransportCredentials(insecure.NewCredentials()),
 		grpc.WithDefaultServiceConfig(
-			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
+			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
 		),
 	)
 	t.Cleanup(func() {
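
One detail worth noting from stubResolver: the dial target is now "scheme://authority", where the authority is a fresh UUID, because gRPC parses targets as URLs and hands the host portion back to the balancer machinery as Target.URL.Host, which is what the registry (below) keys builders by. A tiny standalone illustration (the scheme, authority, and path values are made up):

package main

import (
	"fmt"
	"net/url"
)

func main() {
	// gRPC parses dial targets as URLs; the host component is the
	// "authority" used to pick the per-agent Builder.
	u, err := url.Parse("consul-internal://1b9c9a02-d1c2-4bb6-ac20-3a78c68a0f44/server.dc1")
	if err != nil {
		panic(err)
	}
	fmt.Println(u.Host) // 1b9c9a02-d1c2-4bb6-ac20-3a78c68a0f44
}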
69 changes: 69 additions & 0 deletions agent/grpc-internal/balancer/registry.go
@@ -0,0 +1,69 @@
+package balancer
+
+import (
+	"fmt"
+	"sync"
+
+	gbalancer "google.golang.org/grpc/balancer"
+)
+
+// BuilderName should be given in gRPC service configuration to enable our
+// custom balancer. It refers to this package's global registry, rather than
+// an instance of Builder, to enable us to add and remove builders at runtime,
+// specifically during tests.
+const BuilderName = "consul-internal"
+
+// gRPC's balancer.Register method is thread-unsafe because it mutates a global
+// map without holding a lock. As such, it's expected that you register custom
+// balancers once at the start of your program (e.g. in a package init function).
+//
+// In production, this is fine. Agents register a single instance of our builder
+// and use it for the duration. Tests are where this becomes problematic, as we
+// spin up several agents in-memory and register/deregister a builder for each,
+// with its own agent-specific state, logger, etc.
+//
+// To avoid data races, we call gRPC's Register method once, on package init,
+// with a global registry struct that implements the Builder interface but
+// delegates the building to N instances of our Builder that are registered and
+// deregistered at runtime. We use the dial target's host (aka "authority"),
+// which is unique per-agent, to pick the correct builder.
+func init() {
+	gbalancer.Register(globalRegistry)
+}
+
+var globalRegistry = &registry{
+	byAuthority: make(map[string]*Builder),
+}
+
+type registry struct {
+	mu          sync.RWMutex
+	byAuthority map[string]*Builder
+}
+
+func (r *registry) Build(cc gbalancer.ClientConn, opts gbalancer.BuildOptions) gbalancer.Balancer {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	auth := opts.Target.URL.Host
+	builder, ok := r.byAuthority[auth]
+	if !ok {
+		panic(fmt.Sprintf("no gRPC balancer builder registered for authority: %q", auth))
+	}
+	return builder.Build(cc, opts)
+}
+
+func (r *registry) Name() string { return BuilderName }
+
+func (r *registry) register(auth string, builder *Builder) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	r.byAuthority[auth] = builder
+}
+
+func (r *registry) deregister(auth string) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	delete(r.byAuthority, auth)
+}
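
Putting it together, a hedged end-to-end sketch of the new flow (the authority, logger, and target path are placeholders; it also assumes a resolver for the "consul-internal" scheme is registered, as in the agent): gRPC sees exactly one builder, registered once at package init under BuilderName, each per-agent Builder hangs off the registry by authority, and registry.Build panics loudly if a target's authority has no registered builder.

package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/consul/agent/grpc-internal/balancer"
	"github.com/hashicorp/go-hclog"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// The balancer package's init() has already performed the one-time,
	// race-free gbalancer.Register(globalRegistry).

	// Per agent: register an agent-specific Builder under its authority.
	b := balancer.NewBuilder("agent-1", hclog.NewNullLogger())
	b.Register()
	defer b.Deregister()

	// The service config names the single gRPC-registered balancer; the
	// target's host ("agent-1") routes registry.Build to b.
	conn, err := grpc.Dial(
		"consul-internal://agent-1/server.dc1",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(
			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, balancer.BuilderName),
		),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
}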