Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ccl/sqlproxyccl: validate cluster name before establishing connection #103479

Merged
merged 1 commit into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/ccl/sqlproxyccl/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func (c *connector) lookupAddr(ctx context.Context) (string, error) {

// Lookup tenant in the directory cache. Once we have retrieved the list of
// pods, use the Balancer for load balancing.
pods, err := c.DirectoryCache.LookupTenantPods(ctx, c.TenantID, c.ClusterName)
pods, err := c.DirectoryCache.LookupTenantPods(ctx, c.TenantID)
switch {
case err == nil:
runningPods := make([]*tenant.Pod, 0, len(pods))
Expand Down
20 changes: 8 additions & 12 deletions pkg/ccl/sqlproxyccl/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ func TestConnector_dialTenantCluster(t *testing.T) {
tenantID := roachpb.MustMakeTenantID(42)
directoryCache := &testTenantDirectoryCache{
lookupTenantPodsFn: func(
fnCtx context.Context, tenantID roachpb.TenantID, clusterName string,
fnCtx context.Context, tenantID roachpb.TenantID,
) ([]*tenant.Pod, error) {
mu.Lock()
defer mu.Unlock()
Expand Down Expand Up @@ -677,12 +677,11 @@ func TestConnector_lookupAddr(t *testing.T) {
}
c.DirectoryCache = &testTenantDirectoryCache{
lookupTenantPodsFn: func(
fnCtx context.Context, tenantID roachpb.TenantID, clusterName string,
fnCtx context.Context, tenantID roachpb.TenantID,
) ([]*tenant.Pod, error) {
lookupTenantPodsFnCount++
require.Equal(t, ctx, fnCtx)
require.Equal(t, c.TenantID, tenantID)
require.Equal(t, c.ClusterName, clusterName)
return []*tenant.Pod{
{TenantID: c.TenantID.ToUint64(), Addr: "127.0.0.10:70", State: tenant.DRAINING},
{TenantID: c.TenantID.ToUint64(), Addr: "127.0.0.10:80", State: tenant.RUNNING},
Expand All @@ -705,12 +704,11 @@ func TestConnector_lookupAddr(t *testing.T) {
}
c.DirectoryCache = &testTenantDirectoryCache{
lookupTenantPodsFn: func(
fnCtx context.Context, tenantID roachpb.TenantID, clusterName string,
fnCtx context.Context, tenantID roachpb.TenantID,
) ([]*tenant.Pod, error) {
lookupTenantPodsFnCount++
require.Equal(t, ctx, fnCtx)
require.Equal(t, c.TenantID, tenantID)
require.Equal(t, c.ClusterName, clusterName)
return nil, status.Errorf(codes.FailedPrecondition, "foo")
},
}
Expand All @@ -730,12 +728,11 @@ func TestConnector_lookupAddr(t *testing.T) {
}
c.DirectoryCache = &testTenantDirectoryCache{
lookupTenantPodsFn: func(
fnCtx context.Context, tenantID roachpb.TenantID, clusterName string,
fnCtx context.Context, tenantID roachpb.TenantID,
) ([]*tenant.Pod, error) {
lookupTenantPodsFnCount++
require.Equal(t, ctx, fnCtx)
require.Equal(t, c.TenantID, tenantID)
require.Equal(t, c.ClusterName, clusterName)
return nil, status.Errorf(codes.NotFound, "foo")
},
}
Expand All @@ -755,12 +752,11 @@ func TestConnector_lookupAddr(t *testing.T) {
}
c.DirectoryCache = &testTenantDirectoryCache{
lookupTenantPodsFn: func(
fnCtx context.Context, tenantID roachpb.TenantID, clusterName string,
fnCtx context.Context, tenantID roachpb.TenantID,
) ([]*tenant.Pod, error) {
lookupTenantPodsFnCount++
require.Equal(t, ctx, fnCtx)
require.Equal(t, c.TenantID, tenantID)
require.Equal(t, c.ClusterName, clusterName)
return nil, errors.New("foo")
},
}
Expand Down Expand Up @@ -982,7 +978,7 @@ var _ tenant.DirectoryCache = &testTenantDirectoryCache{}
// cache.
type testTenantDirectoryCache struct {
lookupTenantFn func(ctx context.Context, tenantID roachpb.TenantID) (*tenant.Tenant, error)
lookupTenantPodsFn func(ctx context.Context, tenantID roachpb.TenantID, clusterName string) ([]*tenant.Pod, error)
lookupTenantPodsFn func(ctx context.Context, tenantID roachpb.TenantID) ([]*tenant.Pod, error)
trylookupTenantPodsFn func(ctx context.Context, tenantID roachpb.TenantID) ([]*tenant.Pod, error)
reportFailureFn func(ctx context.Context, tenantID roachpb.TenantID, addr string) error
}
Expand All @@ -996,9 +992,9 @@ func (r *testTenantDirectoryCache) LookupTenant(

// LookupTenantPods implements the tenant.DirectoryCache interface.
func (r *testTenantDirectoryCache) LookupTenantPods(
ctx context.Context, tenantID roachpb.TenantID, clusterName string,
ctx context.Context, tenantID roachpb.TenantID,
) ([]*tenant.Pod, error) {
return r.lookupTenantPodsFn(ctx, tenantID, clusterName)
return r.lookupTenantPodsFn(ctx, tenantID)
}

// TryLookupTenantPods implements the tenant.DirectoryCache interface.
Expand Down
51 changes: 42 additions & 9 deletions pkg/ccl/sqlproxyccl/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,15 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
return clientErr
}

// Validate the incoming connection and ensure that the cluster name
// matches the tenant's. This prevents malicious actors from attempting to
// connect to the cluster using just the tenant ID.
if err := handler.validateConnection(ctx, tenID, clusterName); err != nil {
// We do not need to log here as validateConnection already logs.
updateMetricsAndSendErrToClient(err, fe.Conn, handler.metrics)
return err
}

errConnection := make(chan error, 1)
removeListener, err := handler.aclWatcher.ListenForDenied(
ctx,
Expand All @@ -408,15 +417,12 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
},
)
if err != nil {
if status.Code(err) == codes.NotFound {
err = withCode(
errors.Newf("cluster %s-%d not found", clusterName, tenID.ToUint64()),
codeParamsRoutingFailed,
)
} else {
log.Errorf(ctx, "connection blocked by access control list: %v", err)
err = withCode(errors.New("connection refused"), codeProxyRefusedConnection)
}
// It is possible that we get a NotFound error here because of a race
// with a deleting tenant. This case is rare, and we'll just return a
// "connection refused" error. The next time they connect, they will
// get a "not found" error.
log.Errorf(ctx, "connection blocked by access control list: %v", err)
err = withCode(errors.New("connection refused"), codeProxyRefusedConnection)
updateMetricsAndSendErrToClient(err, fe.Conn, handler.metrics)
return err
}
Expand Down Expand Up @@ -539,6 +545,33 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn)
}
}

// validateConnection validates the incoming connection by ensuring that the
// SQL connection knows some additional information about the tenant (i.e. the
// cluster name) before being allowed to connect.
//
// A nil error is returned when the tenant has no cluster name configured, or
// when the supplied clusterName matches the tenant's. A missing tenant and a
// cluster name mismatch both produce the same codeParamsRoutingFailed "not
// found" error so that callers cannot distinguish the two cases. Any other
// directory lookup failure is returned to the caller unchanged.
func (handler *proxyHandler) validateConnection(
	ctx context.Context, tenantID roachpb.TenantID, clusterName string,
) error {
	// Look up the tenant in the directory cache. A NotFound error means the
	// tenant does not exist and falls through to the common error below; any
	// other error is an internal failure and is surfaced as-is.
	tenant, err := handler.directoryCache.LookupTenant(ctx, tenantID)
	if err != nil && status.Code(err) != codes.NotFound {
		return err
	}
	if err == nil {
		// An empty ClusterName on the tenant entry means the directory does
		// not enforce cluster names, so any incoming name is accepted.
		if tenant.ClusterName == "" || tenant.ClusterName == clusterName {
			return nil
		}
		log.Errorf(
			ctx,
			"could not validate connection: cluster name '%s' doesn't match expected '%s'",
			clusterName,
			tenant.ClusterName,
		)
	}
	return withCode(
		errors.Newf("cluster %s-%d not found", clusterName, tenantID.ToUint64()),
		codeParamsRoutingFailed,
	)
}

// handleCancelRequest handles a pgwire query cancel request by either
// forwarding it to a SQL node or to another proxy.
func (handler *proxyHandler) handleCancelRequest(
Expand Down
62 changes: 62 additions & 0 deletions pkg/ccl/sqlproxyccl/proxy_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,68 @@ const backendError = "Backend error!"
// the test directory server.
const notFoundTenantID = 99

// TestProxyHandler_ValidateConnection exercises validateConnection against a
// static directory server, covering missing tenants, tenants with and without
// a cluster name, and a stopped directory server.
func TestProxyHandler_ValidateConnection(t *testing.T) {
	defer leaktest.AfterTest(t)()
	defer log.Scope(t).Close(t)

	ctx := context.Background()
	stopper := stop.NewStopper()
	defer stopper.Stop(ctx)

	// Create the directory server with two tenants: one that has a cluster
	// name, and one that does not.
	tds := tenantdirsvr.NewTestStaticDirectoryServer(stopper, nil /* timeSource */)
	invalidTenantID := roachpb.MustMakeTenantID(99)
	tenantID := roachpb.MustMakeTenantID(10)
	tds.CreateTenant(tenantID, &tenant.Tenant{
		TenantID:    tenantID.ToUint64(),
		ClusterName: "my-tenant",
	})
	tenantWithoutNameID := roachpb.MustMakeTenantID(20)
	tds.CreateTenant(tenantWithoutNameID, &tenant.Tenant{
		TenantID: tenantWithoutNameID.ToUint64(),
	})
	require.NoError(t, tds.Start(ctx))

	options := &ProxyOptions{}
	options.testingKnobs.directoryServer = tds
	s, _, _ := newSecureProxyServer(ctx, t, stopper, options)

	// Each case validates one (tenant ID, cluster name) combination; an empty
	// errRegex means validation is expected to succeed.
	for _, tc := range []struct {
		name        string
		tenantID    roachpb.TenantID
		clusterName string
		errRegex    string
	}{
		{"not found/no cluster name", invalidTenantID, "", "codeParamsRoutingFailed: cluster -99 not found"},
		{"not found", invalidTenantID, "foo-bar", "codeParamsRoutingFailed: cluster foo-bar-99 not found"},
		{"found/tenant without name", tenantWithoutNameID, "foo-bar", ""},
		{"found/tenant name matches", tenantID, "my-tenant", ""},
		{"found/connection without name", tenantID, "", "codeParamsRoutingFailed: cluster -10 not found"},
		{"found/tenant name mismatch", tenantID, "foo-bar", "codeParamsRoutingFailed: cluster foo-bar-10 not found"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			err := s.handler.validateConnection(ctx, tc.tenantID, tc.clusterName)
			if tc.errRegex == "" {
				require.NoError(t, err)
			} else {
				require.Regexp(t, tc.errRegex, err.Error())
			}
		})
	}

	// Stop the directory server.
	tds.Stop(ctx)

	// Directory hasn't started
	t.Run("directory error", func(t *testing.T) {
		// Use a new tenant ID here to force GetTenant.
		err := s.handler.validateConnection(ctx, roachpb.MustMakeTenantID(100), "")
		require.Regexp(t, "directory server has not been started", err.Error())
	})
}

func TestProxyProtocol(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
Expand Down
48 changes: 24 additions & 24 deletions pkg/ccl/sqlproxyccl/tenant/directory_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ type DirectoryCache interface {

// LookupTenantPods returns a list of SQL pods in the RUNNING and DRAINING
// states for the given tenant. This blocks until there is at least one
// running SQL pod.
//
// If no matching pods are found (e.g. cluster name mismatch, or tenant was
// deleted), this will return a GRPC NotFound error.
LookupTenantPods(ctx context.Context, tenantID roachpb.TenantID, clusterName string) ([]*Pod, error)
// running SQL pod. If the tenant cannot be found, this will return a GRPC
// NotFound error.
LookupTenantPods(ctx context.Context, tenantID roachpb.TenantID) ([]*Pod, error)

// TryLookupTenantPods returns a list of SQL pods in the RUNNING and
// DRAINING states for the given tenant. It returns a GRPC NotFound error
Expand Down Expand Up @@ -176,6 +174,10 @@ func NewDirectoryCache(
// LookupTenant returns the tenant entry associated to the requested tenant
// ID. If the tenant cannot be found, this will return a GRPC NotFound error.
//
// WARNING: Callers should never attempt to modify values returned by this
// method, or else there may be a race. Other instances may be reading from the
// same object.
//
// LookupTenant implements the DirectoryCache interface.
func (d *directoryCache) LookupTenant(
ctx context.Context, tenantID roachpb.TenantID,
Expand All @@ -193,41 +195,26 @@ func (d *directoryCache) LookupTenant(
// states for the given tenant. If the tenant was just created or is suspended,
// such that there are no available RUNNING processes, then LookupTenantPods
// will trigger resumption of a new instance (or a conversion of a DRAINING pod
// to a RUNNING one) and block until that happens.
//
// If clusterName is non-empty, then a GRPC NotFound error is returned if no
// pods match the cluster name. This can be used to ensure that the incoming SQL
// connection "knows" some additional information about the tenant, such as the
// name of the cluster, before being allowed to connect. Similarly, if the
// tenant does not exist (e.g. because it was deleted), LookupTenantPods returns
// a GRPC NotFound error.
// to a RUNNING one) and block until that happens. If the tenant cannot be
// found, this will return a GRPC NotFound error.
//
// WARNING: Callers should never attempt to modify values returned by this
// method, or else there may be a race. Other instances may be reading from the
// same slice.
//
// LookupTenantPods implements the DirectoryCache interface.
func (d *directoryCache) LookupTenantPods(
ctx context.Context, tenantID roachpb.TenantID, clusterName string,
ctx context.Context, tenantID roachpb.TenantID,
) ([]*Pod, error) {
// Ensure that a directory entry has been created for this tenant.
entry, err := d.getEntry(ctx, tenantID, true /* allowCreate */)
if err != nil {
return nil, err
}

// Check if the cluster name matches. This can be skipped if clusterName
// is empty, or the ClusterName returned by the directory server is empty.
tenant := entry.ToProto()
if clusterName != "" && tenant.ClusterName != "" && clusterName != tenant.ClusterName {
// Return a GRPC NotFound error.
log.Errorf(ctx, "cluster name %s doesn't match expected %s", clusterName, tenant.ClusterName)
return nil, status.Errorf(codes.NotFound,
"cluster name %s doesn't match expected %s", clusterName, tenant.ClusterName)
}

ctx, cancel := d.stopper.WithCancelOnQuiesce(ctx)
defer cancel()

tenantPods := entry.GetPods()

// Trigger resumption if there are no RUNNING pods.
Expand Down Expand Up @@ -560,6 +547,19 @@ func (d *directoryCache) watchTenants(ctx context.Context, stopper *stop.Stopper
// (for a long time, or until the watcher catches up). Marking
// them as stale allows LookupTenant to fetch a new right away
// if needed.
//
// TODO(jaylim-crl): One optimization that could be done here is
// to build a new cache, while allowing the old one to work.
// Once the cache has been populated, we will swap the new and
// old caches. We can tell that the cache has been populated
// when events switch from ADDED to MODIFIED. Though, if we use
// this approach, it is possible that there aren't any MODIFIED
// events, and we're stuck waiting to switch the cache over.
// Perhaps a better idea would be to invoke GetTenant on the
// list of tenants which were previously valid individually.
// Note that it's unlikely for us to hit multiple cache misses
// during this short period unless we're getting thousands of
// connections with unique tenant IDs for the first time.
d.markAllEntriesInvalid()
// If stream ends, immediately try to establish a new one.
// Otherwise, wait for a second to avoid slamming server.
Expand Down
Loading