diff --git a/pkg/ccl/serverccl/tenant_grpc_test.go b/pkg/ccl/serverccl/tenant_grpc_test.go new file mode 100644 index 000000000000..b48c249a603c --- /dev/null +++ b/pkg/ccl/serverccl/tenant_grpc_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package serverccl + +import ( + "context" + "io/ioutil" + "net/http" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/httputil" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +// TestTenantGRPCServices tests that the gRPC servers that are externally +// facing have been started up on the tenant server. This includes gRPC that is +// used for pod-to-pod communication as well as the HTTP services powered by +// gRPC Gateway that are used to serve endpoints to power observability UIs. +func TestTenantGRPCServices(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + defer testCluster.Stopper().Stop(ctx) + + server := testCluster.Server(0) + + tenantID := roachpb.MakeTenantID(10) + tenant, connTenant := serverutils.StartTenant(t, server, base.TestTenantArgs{ + TenantID: tenantID, + }) + defer connTenant.Close() + + t.Run("gRPC is running", func(t *testing.T) { + grpcAddr := tenant.SQLAddr() + rpcCtx := tenant.RPCContext() + + conn, err := rpcCtx.GRPCDialNode(grpcAddr, roachpb.NodeID(tenant.SQLInstanceID()), rpc.DefaultClass).Connect(ctx) + require.NoError(t, err) + defer conn.Close() + + client := serverpb.NewStatusClient(conn) + + resp, err := client.Statements(ctx, &serverpb.StatementsRequest{NodeID: "local"}) + require.NoError(t, err) + require.NotEmpty(t, resp.Statements) + }) + + t.Run("gRPC Gateway is running", func(t *testing.T) { + resp, err := httputil.Get(ctx, "http://"+tenant.HTTPAddr()+"/_status/statements") + defer http.DefaultClient.CloseIdleConnections() + require.NoError(t, err) + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), "transactions") + }) + + sqlRunner := sqlutils.MakeSQLRunner(connTenant) + sqlRunner.Exec(t, "CREATE TABLE test (id int)") + sqlRunner.Exec(t, "INSERT INTO test VALUES (1)") + + tenant2, connTenant2 := serverutils.StartTenant(t, server, base.TestTenantArgs{ + TenantID: tenantID, + Existing: true, + }) + defer connTenant2.Close() + + t.Run("statements endpoint fans out request to multiple pods", func(t *testing.T) { + resp, err := httputil.Get(ctx, "http://"+tenant2.HTTPAddr()+"/_status/statements") + defer http.DefaultClient.CloseIdleConnections() + require.NoError(t, err) + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), "CREATE TABLE test") + require.Contains(t, string(body), "INSERT 
INTO test VALUES") + }) + + tenant3, connTenant3 := serverutils.StartTenant(t, server, base.TestTenantArgs{ + TenantID: roachpb.MakeTenantID(11), + }) + defer connTenant3.Close() + + t.Run("fanout of statements endpoint is segregated by tenant", func(t *testing.T) { + resp, err := httputil.Get(ctx, "http://"+tenant3.HTTPAddr()+"/_status/statements") + defer http.DefaultClient.CloseIdleConnections() + require.NoError(t, err) + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.NotContains(t, string(body), "CREATE TABLE test") + require.NotContains(t, string(body), "INSERT INTO test VALUES") + }) +} diff --git a/pkg/rpc/auth_tenant.go b/pkg/rpc/auth_tenant.go index 67a335c9bc1e..5517ce2487a2 100644 --- a/pkg/rpc/auth_tenant.go +++ b/pkg/rpc/auth_tenant.go @@ -55,6 +55,10 @@ func (a tenantAuthorizer) authorize( case "/cockroach.rpc.Heartbeat/Ping": return nil // no authorization + case "/cockroach.server.serverpb.Status/Statements": + // The Statements endpoint requires only SQL + return nil // no authorization + default: return authErrorf("unknown method %q", fullMethod) } diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index f2c38589375e..20e40e7f3aff 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -1018,6 +1018,19 @@ func (ctx *Context) GRPCDialNode( return ctx.grpcDialNodeInternal(target, remoteNodeID, class) } +// GRPCDialPod wraps GRPCDialNode and treats the `remoteInstanceID` +// argument as a `NodeID` which it converts. This works because the +// tenant gRPC server is initialized using the `InstanceID` so it +// accepts our connection as matching the ID we're dialing. +// +// Since GRPCDialNode accepts a separate `target` and `NodeID` it +// requires no further modification to work between pods. +func (ctx *Context) GRPCDialPod( + target string, remoteInstanceID base.SQLInstanceID, class ConnectionClass, +) *Connection { + return ctx.GRPCDialNode(target, roachpb.NodeID(remoteInstanceID), class) +} + func (ctx *Context) grpcDialNodeInternal( target string, remoteNodeID roachpb.NodeID, class ConnectionClass, ) *Connection { diff --git a/pkg/server/server.go b/pkg/server/server.go index 957635b59a83..19940a98831d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1317,7 +1317,7 @@ func (s *Server) PreStart(ctx context.Context) error { // and dispatches the server worker for the RPC. // The SQL listener is returned, to start the SQL server later // below when the server has initialized. - pgL, startRPCServer, err := s.startListenRPCAndSQL(ctx, workersCtx) + pgL, startRPCServer, err := StartListenRPCAndSQL(ctx, workersCtx, s.cfg.BaseConfig, s.stopper, s.grpc) if err != nil { return err } @@ -1350,83 +1350,18 @@ func (s *Server) PreStart(ctx context.Context) error { // Initialize grpc-gateway mux and context in order to get the /health // endpoint working even before the node has fully initialized. 
-  jsonpb := &protoutil.JSONPb{
-    EnumsAsInts:  true,
-    EmitDefaults: true,
-    Indent:       "  ",
-  }
-  protopb := new(protoutil.ProtoPb)
-  gwMux := gwruntime.NewServeMux(
-    gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, jsonpb),
-    gwruntime.WithMarshalerOption(httputil.JSONContentType, jsonpb),
-    gwruntime.WithMarshalerOption(httputil.AltJSONContentType, jsonpb),
-    gwruntime.WithMarshalerOption(httputil.ProtoContentType, protopb),
-    gwruntime.WithMarshalerOption(httputil.AltProtoContentType, protopb),
-    gwruntime.WithOutgoingHeaderMatcher(authenticationHeaderMatcher),
-    gwruntime.WithMetadata(forwardAuthenticationMetadata),
+  gwMux, gwCtx, conn, err := ConfigureGRPCGateway(
+    ctx,
+    workersCtx,
+    s.cfg.AmbientCtx,
+    s.rpcContext,
+    s.stopper,
+    s.grpc,
+    s.cfg.AdvertiseAddr,
   )
-  gwCtx, gwCancel := context.WithCancel(s.AnnotateCtx(context.Background()))
-  s.stopper.AddCloser(stop.CloserFn(gwCancel))
-
-  // loopback handles the HTTP <-> RPC loopback connection.
-  loopback := newLoopbackListener(workersCtx, s.stopper)
-
-  waitQuiesce := func(context.Context) {
-    <-s.stopper.ShouldQuiesce()
-    _ = loopback.Close()
-  }
-  if err := s.stopper.RunAsyncTask(workersCtx, "gw-quiesce", waitQuiesce); err != nil {
-    waitQuiesce(workersCtx)
-  }
-
-  _ = s.stopper.RunAsyncTask(workersCtx, "serve-loopback", func(context.Context) {
-    netutil.FatalIfUnexpected(s.grpc.Serve(loopback))
-  })
-
-  // Eschew `(*rpc.Context).GRPCDial` to avoid unnecessary moving parts on the
-  // uniquely in-process connection.
-  dialOpts, err := s.rpcContext.GRPCDialOptions()
-  if err != nil {
-    return err
-  }
-
-  callCountInterceptor := func(
-    ctx context.Context,
-    method string,
-    req, reply interface{},
-    cc *grpc.ClientConn,
-    invoker grpc.UnaryInvoker,
-    opts ...grpc.CallOption,
-  ) error {
-    telemetry.Inc(getServerEndpointCounter(method))
-    return invoker(ctx, method, req, reply, cc, opts...)
-  }
-  conn, err := grpc.DialContext(ctx, s.cfg.AdvertiseAddr, append(append(
-    dialOpts,
-    grpc.WithUnaryInterceptor(callCountInterceptor)),
-    grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
-      return loopback.Connect(ctx)
-    }),
-  )...)
   if err != nil {
     return err
   }
-  {
-    waitQuiesce := func(workersCtx context.Context) {
-      <-s.stopper.ShouldQuiesce()
-      // NB: we can't do this as a Closer because (*Server).ServeWith is
-      // running in a worker and usually sits on accept() which unblocks
-      // only when the listener closes. In other words, the listener needs
-      // to close when quiescing starts to allow that worker to shut down.
-      err := conn.Close() // nolint:grpcconnclose
-      if err != nil {
-        log.Ops.Fatalf(workersCtx, "%v", err)
-      }
-    }
-    if err := s.stopper.RunAsyncTask(workersCtx, "wait-quiesce", waitQuiesce); err != nil {
-      waitQuiesce(workersCtx)
-    }
-  }
   for _, gw := range []grpcGatewayServer{s.admin, s.status, s.authentication, s.tsServer} {
     if err := gw.RegisterGateway(gwCtx, gwMux, conn); err != nil {
@@ -1916,6 +1851,109 @@ func (s *Server) PreStart(ctx context.Context) error {
   return maybeImportTS(ctx, s)
 }
+// ConfigureGRPCGateway initializes the services necessary for running
+// the gRPC Gateway proxy against the server at `grpcSrv`.
+//
+// The connection between the reverse proxy provided by grpc-gateway
+// and our gRPC server uses a loopback-based listener to create
+// connections between the two.
+//
+// The function returns the 3 values that are necessary to call
+// `RegisterGateway`, which is generated for each of your gRPC services
+// by grpc-gateway.
+func ConfigureGRPCGateway( + // TODO(davidh): BEFORE MERGE: need guidance on how to cull the number + // of contexts used here. Seems like we need a separate worker ctx at + // least but maybe it can be crated in the method instead of passed + // in + ctx context.Context, + workersCtx context.Context, + ambientCtx log.AmbientContext, + rpcContext *rpc.Context, + stopper *stop.Stopper, + grpcSrv *grpcServer, + GRPCAddr string, +) (*gwruntime.ServeMux, context.Context, *grpc.ClientConn, error) { + jsonpb := &protoutil.JSONPb{ + EnumsAsInts: true, + EmitDefaults: true, + Indent: " ", + } + protopb := new(protoutil.ProtoPb) + gwMux := gwruntime.NewServeMux( + gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, jsonpb), + gwruntime.WithMarshalerOption(httputil.JSONContentType, jsonpb), + gwruntime.WithMarshalerOption(httputil.AltJSONContentType, jsonpb), + gwruntime.WithMarshalerOption(httputil.ProtoContentType, protopb), + gwruntime.WithMarshalerOption(httputil.AltProtoContentType, protopb), + gwruntime.WithOutgoingHeaderMatcher(authenticationHeaderMatcher), + gwruntime.WithMetadata(forwardAuthenticationMetadata), + ) + gwCtx, gwCancel := context.WithCancel(ambientCtx.AnnotateCtx(context.Background())) + stopper.AddCloser(stop.CloserFn(gwCancel)) + + // loopback handles the HTTP <-> RPC loopback connection. + loopback := newLoopbackListener(workersCtx, stopper) + + waitQuiesce := func(context.Context) { + <-stopper.ShouldQuiesce() + _ = loopback.Close() + } + if err := stopper.RunAsyncTask(workersCtx, "gw-quiesce", waitQuiesce); err != nil { + waitQuiesce(workersCtx) + } + + _ = stopper.RunAsyncTask(workersCtx, "serve-loopback", func(context.Context) { + netutil.FatalIfUnexpected(grpcSrv.Serve(loopback)) + }) + + // Eschew `(*rpc.Context).GRPCDial` to avoid unnecessary moving parts on the + // uniquely in-process connection. + dialOpts, err := rpcContext.GRPCDialOptions() + if err != nil { + return nil, nil, nil, err + } + + callCountInterceptor := func( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + telemetry.Inc(getServerEndpointCounter(method)) + return invoker(ctx, method, req, reply, cc, opts...) + } + conn, err := grpc.DialContext(ctx, GRPCAddr, append(append( + dialOpts, + grpc.WithUnaryInterceptor(callCountInterceptor)), + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return loopback.Connect(ctx) + }), + )...) + if err != nil { + return nil, nil, nil, err + } + { + waitQuiesce := func(workersCtx context.Context) { + <-stopper.ShouldQuiesce() + // NB: we can't do this as a Closer because (*Server).ServeWith is + // running in a worker and usually sits on accept() which unblocks + // only when the listener closes. In other words, the listener needs + // to close when quiescing starts to allow that worker to shut down. + err := conn.Close() // nolint:grpcconnclose + if err != nil { + log.Ops.Fatalf(workersCtx, "%v", err) + } + } + if err := stopper.RunAsyncTask(workersCtx, "wait-quiesce", waitQuiesce); err != nil { + waitQuiesce(workersCtx) + } + } + return gwMux, gwCtx, conn, nil +} + func maybeImportTS(ctx context.Context, s *Server) error { knobs, _ := s.cfg.TestingKnobs.Server.(*TestingKnobs) if knobs == nil { @@ -2062,43 +2100,43 @@ func (s *Server) AcceptClients(ctx context.Context) error { return nil } -// startListenRPCAndSQL starts the RPC and SQL listeners. +// StartListenRPCAndSQL starts the RPC and SQL listeners. 
// It returns the SQL listener, which can be used // to start the SQL server when initialization has completed. // It also returns a function that starts the RPC server, // when the cluster is known to have bootstrapped or // when waiting for init(). -func (s *Server) startListenRPCAndSQL( - ctx, workersCtx context.Context, +func StartListenRPCAndSQL( + ctx, workersCtx context.Context, cfg BaseConfig, stopper *stop.Stopper, grpc *grpcServer, ) (sqlListener net.Listener, startRPCServer func(ctx context.Context), err error) { rpcChanName := "rpc/sql" - if s.cfg.SplitListenSQL { + if cfg.SplitListenSQL { rpcChanName = "rpc" } var ln net.Listener - if k := s.cfg.TestingKnobs.Server; k != nil { + if k := cfg.TestingKnobs.Server; k != nil { knobs := k.(*TestingKnobs) ln = knobs.RPCListener } if ln == nil { var err error - ln, err = ListenAndUpdateAddrs(ctx, &s.cfg.Addr, &s.cfg.AdvertiseAddr, rpcChanName) + ln, err = ListenAndUpdateAddrs(ctx, &cfg.Addr, &cfg.AdvertiseAddr, rpcChanName) if err != nil { return nil, nil, err } - log.Eventf(ctx, "listening on port %s", s.cfg.Addr) + log.Eventf(ctx, "listening on port %s", cfg.Addr) } var pgL net.Listener - if s.cfg.SplitListenSQL { - pgL, err = ListenAndUpdateAddrs(ctx, &s.cfg.SQLAddr, &s.cfg.SQLAdvertiseAddr, "sql") + if cfg.SplitListenSQL { + pgL, err = ListenAndUpdateAddrs(ctx, &cfg.SQLAddr, &cfg.SQLAdvertiseAddr, "sql") if err != nil { return nil, nil, err } // The SQL listener shutdown worker, which closes everything under // the SQL port when the stopper indicates we are shutting down. waitQuiesce := func(ctx context.Context) { - <-s.stopper.ShouldQuiesce() + <-stopper.ShouldQuiesce() // NB: we can't do this as a Closer because (*Server).ServeWith is // running in a worker and usually sits on accept() which unblocks // only when the listener closes. In other words, the listener needs @@ -2107,11 +2145,11 @@ func (s *Server) startListenRPCAndSQL( log.Ops.Fatalf(ctx, "%v", err) } } - if err := s.stopper.RunAsyncTask(workersCtx, "wait-quiesce", waitQuiesce); err != nil { + if err := stopper.RunAsyncTask(workersCtx, "wait-quiesce", waitQuiesce); err != nil { waitQuiesce(workersCtx) return nil, nil, err } - log.Eventf(ctx, "listening on sql port %s", s.cfg.SQLAddr) + log.Eventf(ctx, "listening on sql port %s", cfg.SQLAddr) } // serveOnMux is used to ensure that the mux gets listened on eventually, @@ -2120,7 +2158,7 @@ func (s *Server) startListenRPCAndSQL( m := cmux.New(ln) - if !s.cfg.SplitListenSQL { + if !cfg.SplitListenSQL { // If the pg port is split, it will be opened above. Otherwise, // we make it hang off the RPC listener via cmux here. pgL = m.Match(func(r io.Reader) bool { @@ -2129,12 +2167,12 @@ func (s *Server) startListenRPCAndSQL( // Also if the pg port is not split, the actual listen // and advertise addresses for SQL become equal to that // of RPC, regardless of what was configured. - s.cfg.SQLAddr = s.cfg.Addr - s.cfg.SQLAdvertiseAddr = s.cfg.AdvertiseAddr + cfg.SQLAddr = cfg.Addr + cfg.SQLAdvertiseAddr = cfg.AdvertiseAddr } anyL := m.Match(cmux.Any()) - if serverTestKnobs, ok := s.cfg.TestingKnobs.Server.(*TestingKnobs); ok { + if serverTestKnobs, ok := cfg.TestingKnobs.Server.(*TestingKnobs); ok { if serverTestKnobs.ContextTestingKnobs.ArtificialLatencyMap != nil { anyL = rpc.NewDelayingListener(anyL) } @@ -2142,12 +2180,12 @@ func (s *Server) startListenRPCAndSQL( // The remainder shutdown worker. 
waitForQuiesce := func(context.Context) { - <-s.stopper.ShouldQuiesce() + <-stopper.ShouldQuiesce() // TODO(bdarnell): Do we need to also close the other listeners? netutil.FatalIfUnexpected(anyL.Close()) } - s.stopper.AddCloser(stop.CloserFn(func() { - s.grpc.Stop() + stopper.AddCloser(stop.CloserFn(func() { + grpc.Stop() serveOnMux.Do(func() { // The cmux matches don't shut down properly unless serve is called on the // cmux at some point. Use serveOnMux to ensure it's called during shutdown @@ -2155,7 +2193,7 @@ func (s *Server) startListenRPCAndSQL( netutil.FatalIfUnexpected(m.Serve()) }) })) - if err := s.stopper.RunAsyncTask( + if err := stopper.RunAsyncTask( workersCtx, "grpc-quiesce", waitForQuiesce, ); err != nil { return nil, nil, err @@ -2167,11 +2205,11 @@ func (s *Server) startListenRPCAndSQL( // (Server.Start) will call this at the right moment. startRPCServer = func(ctx context.Context) { // Serve the gRPC endpoint. - _ = s.stopper.RunAsyncTask(workersCtx, "serve-grpc", func(context.Context) { - netutil.FatalIfUnexpected(s.grpc.Serve(anyL)) + _ = stopper.RunAsyncTask(workersCtx, "serve-grpc", func(context.Context) { + netutil.FatalIfUnexpected(grpc.Serve(anyL)) }) - _ = s.stopper.RunAsyncTask(ctx, "serve-mux", func(context.Context) { + _ = stopper.RunAsyncTask(ctx, "serve-mux", func(context.Context) { serveOnMux.Do(func() { netutil.FatalIfUnexpected(m.Serve()) }) diff --git a/pkg/server/status.go b/pkg/server/status.go index e937cf61adcc..9fea8940b760 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -128,6 +128,12 @@ func propagateGatewayMetadata(ctx context.Context) context.Context { // baseStatusServer implements functionality shared by the tenantStatusServer // and the full statusServer. type baseStatusServer struct { + // Embedding the UnimplementedStatusServer lets us easily support + // treating the tenantStatusServer as implementing the StatusServer + // interface. We'd return an unimplemented error for the methods we + // didn't require anyway. + serverpb.UnimplementedStatusServer + log.AmbientContext privilegeChecker *adminPrivilegeChecker sessionRegistry *sql.SessionRegistry @@ -135,6 +141,8 @@ type baseStatusServer struct { flowScheduler *flowinfra.FlowScheduler st *cluster.Settings sqlServer *SQLServer + rpcCtx *rpc.Context + stopper *stop.Stopper } // getLocalSessions returns a list of local sessions on this node. 
Note that the @@ -391,9 +399,7 @@ type statusServer struct { metricSource metricMarshaler nodeLiveness *liveness.NodeLiveness storePool *kvserver.StorePool - rpcCtx *rpc.Context stores *kvserver.Stores - stopper *stop.Stopper si systemInfoOnce stmtDiagnosticsRequester StmtDiagnosticsRequester internalExecutor *sql.InternalExecutor @@ -438,6 +444,8 @@ func newStatusServer( contentionRegistry: contentionRegistry, flowScheduler: flowScheduler, st: st, + rpcCtx: rpcCtx, + stopper: stopper, }, cfg: cfg, admin: adminServer, @@ -446,9 +454,7 @@ func newStatusServer( metricSource: metricSource, nodeLiveness: nodeLiveness, storePool: storePool, - rpcCtx: rpcCtx, stores: stores, - stopper: stopper, internalExecutor: internalExecutor, } diff --git a/pkg/server/tenant.go b/pkg/server/tenant.go index 08f33b380088..ba3819682705 100644 --- a/pkg/server/tenant.go +++ b/pkg/server/tenant.go @@ -72,7 +72,15 @@ func StartTenant( SetupIdleMonitor(ctx, args.stopper, baseCfg.IdleExitAfter, connManager) } - pgL, err := ListenAndUpdateAddrs(ctx, &args.Config.SQLAddr, &args.Config.SQLAdvertiseAddr, "sql") + // Initialize gRPC server for use on shared port with pg + grpcMain := newGRPCServer(args.rpcContext) + grpcMain.setMode(modeOperational) + + // TODO(davidh): Do we need to force this to be false? + baseCfg.SplitListenSQL = false + + background := baseCfg.AmbientCtx.AnnotateCtx(context.Background()) + pgL, startRPCServer, err := StartListenRPCAndSQL(ctx, background, baseCfg, stopper, grpcMain) if err != nil { return nil, "", "", err } @@ -114,14 +122,47 @@ func StartTenant( return nil, "", "", err } - s.execCfg.SQLStatusServer = newTenantStatusServer( + // This is necessary so the grpc server doesn't error out on heartbeat + // ping when we make pod-to-pod calls, we pass the InstanceID with the + // request to ensure we're dialing the pod we think we are. + args.rpcContext.NodeID.Set(ctx, roachpb.NodeID(s.SQLInstanceID())) + + tenantStatusServer := newTenantStatusServer( baseCfg.AmbientCtx, &adminPrivilegeChecker{ie: args.circularInternalExecutor}, args.sessionRegistry, args.contentionRegistry, args.flowScheduler, baseCfg.Settings, s, + args.rpcContext, args.stopper, ) + s.execCfg.SQLStatusServer = tenantStatusServer + // TODO(asubiotto): remove this. Right now it is needed to initialize the // SpanResolver. s.execCfg.DistSQLPlanner.SetNodeInfo(roachpb.NodeDescriptor{NodeID: 0}) + + workersCtx := tenantStatusServer.AnnotateCtx(context.Background()) + + // Register and start gRPC service on pod. This is separate from the + // gRPC + Gateway services configured below. + tenantStatusServer.RegisterService(grpcMain.Server) + startRPCServer(workersCtx) + + // Begin configuration of GRPC Gateway + gwMux, gwCtx, conn, err := ConfigureGRPCGateway( + ctx, + workersCtx, + args.AmbientCtx, + tenantStatusServer.rpcCtx, + s.stopper, + grpcMain, + pgLAddr, + ) + if err != nil { + return nil, "", "", err + } + if err := tenantStatusServer.RegisterGateway(gwCtx, gwMux, conn); err != nil { + return nil, "", "", err + } + args.recorder.AddNode( args.registry, roachpb.NodeDescriptor{}, @@ -135,6 +176,7 @@ func StartTenant( mux := http.NewServeMux() debugServer := debug.NewServer(args.Settings, s.pgServer.HBADebugFn()) mux.Handle("/", debugServer) + mux.Handle("/_status/", gwMux) mux.HandleFunc("/health", func(w http.ResponseWriter, req *http.Request) { // Return Bad Request if called with arguments. 
if err := req.ParseForm(); err != nil || len(req.Form) != 0 { @@ -325,6 +367,7 @@ func makeTenantSQLServerArgs( // We don't need this for anything except some services that want a gRPC // server to register against (but they'll never get RPCs at the time of // writing): the blob service and DistSQL. + // TODO(davidh): anything to do below??? replace with server we added? dummyRPCServer := rpc.NewServer(rpcContext) sessionRegistry := sql.NewSessionRegistry() contentionRegistry := contention.NewRegistry() diff --git a/pkg/server/tenant_status.go b/pkg/server/tenant_status.go index d436474f9593..57ab82dc9ab5 100644 --- a/pkg/server/tenant_status.go +++ b/pkg/server/tenant_status.go @@ -17,15 +17,30 @@ package server import ( "context" + "fmt" + "strconv" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/rpc" "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/catconstants" "github.com/cockroachdb/cockroach/pkg/sql/contention" "github.com/cockroachdb/cockroach/pkg/sql/flowinfra" + "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance" + "github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/quotapool" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" + gwruntime "github.com/grpc-ecosystem/grpc-gateway/runtime" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // tenantStatusServer is an implementation of a SQLStatusServer that is @@ -35,9 +50,27 @@ import ( // Phase 2 requirements that there can only be at most one live SQL pod per // tenant. type tenantStatusServer struct { - baseStatusServer + baseStatusServer // embeds UnimplementedStatusServer } +// We require that `tenantStatusServer` implement +// `serverpb.StatusServer` even though we only have partial +// implementation, in order to serve some endpoints on tenants. +var _ serverpb.StatusServer = &tenantStatusServer{} + +func (t *tenantStatusServer) RegisterService(g *grpc.Server) { + serverpb.RegisterStatusServer(g, t) +} + +func (t *tenantStatusServer) RegisterGateway( + ctx context.Context, mux *gwruntime.ServeMux, conn *grpc.ClientConn, +) error { + ctx = t.AnnotateCtx(ctx) + return serverpb.RegisterStatusHandler(ctx, mux, conn) +} + +var _ grpcGatewayServer = &tenantStatusServer{} + func newTenantStatusServer( ambient log.AmbientContext, privilegeChecker *adminPrivilegeChecker, @@ -46,6 +79,8 @@ func newTenantStatusServer( flowScheduler *flowinfra.FlowScheduler, st *cluster.Settings, sqlServer *SQLServer, + rpcCtx *rpc.Context, + stopper *stop.Stopper, ) *tenantStatusServer { ambient.AddLogTag("tenant-status", nil) return &tenantStatusServer{ @@ -57,6 +92,8 @@ func newTenantStatusServer( flowScheduler: flowScheduler, st: st, sqlServer: sqlServer, + rpcCtx: rpcCtx, + stopper: stopper, }, } } @@ -118,21 +155,237 @@ func (t *tenantStatusServer) ResetSQLStats( return &serverpb.ResetSQLStatsResponse{}, nil } +// Statements implements the relevant endpoint on the StatusServer by +// fanning out a request to all pods on the current tenant via gRPC to collect +// in-memory statistics and aggregate them for the caller. 
+// +// The implementation is based on the one in statements.go but differs +// by leaning on the InstanceID subsystem to implement the fan-out. If +// the InstanceID and NodeID subsystems can be unified in some way, +// these implementations could be merged. func (t *tenantStatusServer) Statements( - ctx context.Context, _ *serverpb.StatementsRequest, + ctx context.Context, req *serverpb.StatementsRequest, ) (*serverpb.StatementsResponse, error) { if _, err := t.privilegeChecker.requireViewActivityPermission(ctx); err != nil { return nil, err } - // Use a dummy value here until pod-to-pod communication is implemented since tenant status server - // does not have concept of node. - resp, err := statementsLocal(ctx, &base.NodeIDContainer{}, t.sqlServer) + response := &serverpb.StatementsResponse{ + Statements: []serverpb.StatementsResponse_CollectedStatementStatistics{}, + LastReset: timeutil.Now(), + InternalAppNamePrefix: catconstants.InternalAppNamePrefix, + } + + localReq := &serverpb.StatementsRequest{ + NodeID: "local", + } + + if len(req.NodeID) > 0 { + // We are interpreting the `NodeID` in the request as an `InstanceID` since + // we are executing in the context of a tenant. + parsedInstanceID, local, err := t.parseInstanceID(req.NodeID) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, err.Error()) + } + if local { + container := base.NodeIDContainer{} + container.Set(ctx, roachpb.NodeID(t.sqlServer.SQLInstanceID())) + return statementsLocal(ctx, &container, t.sqlServer) + } + // Iterate through live instances to find the one we want to target + liveTenantInstances, err := t.getLiveInstancesForTenant(ctx) + if err != nil { + return nil, err + } + var target sqlinstance.InstanceInfo + for _, i := range liveTenantInstances { + if i.InstanceID() == parsedInstanceID { + target = i + break + } + } + statusClient, err := t.dialPod(ctx, target) + if err != nil { + return nil, err + } + return statusClient.Statements(ctx, localReq) + } + + dialFn := func(ctx context.Context, instance sqlinstance.InstanceInfo) (interface{}, error) { + client, err := t.dialPod(ctx, instance) + return client, err + } + nodeStatement := func(ctx context.Context, client interface{}, _ base.SQLInstanceID) (interface{}, error) { + statusClient := client.(serverpb.StatusClient) + return statusClient.Statements(ctx, localReq) + } + + if err := t.iteratePods(ctx, fmt.Sprintf("statement statistics for node %s", req.NodeID), + dialFn, + nodeStatement, + func(instanceID base.SQLInstanceID, resp interface{}) { + statementsResp := resp.(*serverpb.StatementsResponse) + response.Statements = append(response.Statements, statementsResp.Statements...) + response.Transactions = append(response.Transactions, statementsResp.Transactions...) + if response.LastReset.After(statementsResp.LastReset) { + response.LastReset = statementsResp.LastReset + } + }, + func(instanceID base.SQLInstanceID, err error) { + // We log warnings when fanout returns error, but proceed with + // constructing a response from whoever returns a good one. + log.Warningf(ctx, "fan out statements request recorded error from node %d: %v", instanceID, err) + }, + ); err != nil { + return nil, err + } + + return response, nil +} + +// parseInstanceID is based on status.parseNodeID +func (t *tenantStatusServer) parseInstanceID( + instanceIDParam string, +) (base.SQLInstanceID, bool, error) { + // No parameter provided or set to local. 
+  if len(instanceIDParam) == 0 || localRE.MatchString(instanceIDParam) {
+    return t.sqlServer.SQLInstanceID(), true, nil
+  }
+
+  id, err := strconv.ParseInt(instanceIDParam, 0, 32)
+  if err != nil {
+    return 0, false, errors.Wrap(err, "instance ID could not be parsed")
+  }
+  instanceID := base.SQLInstanceID(id)
+  return instanceID, instanceID == t.sqlServer.SQLInstanceID(), nil
+}
+
+func (t *tenantStatusServer) dialPod(
+  ctx context.Context, instance sqlinstance.InstanceInfo,
+) (serverpb.StatusClient, error) {
+  addr := instance.InstanceAddr()
+  conn, err := t.rpcCtx.GRPCDialPod(addr, instance.InstanceID(), rpc.DefaultClass).Connect(ctx)
+  if err != nil {
+    return nil, err
+  }
+
+  // nb: The server on the pods doesn't implement all the methods of the
+  // `StatusService`. It is up to the caller of `dialPod` to only call
+  // methods that are implemented on the tenant server.
+  return serverpb.NewStatusClient(conn), nil
+}
+
+// iteratePods is based on the implementation of `iterateNodes` in the
+// status server. The two implementations have not been unified into one
+// because there are some deep differences, since we use `InstanceInfo`
+// instead of `NodeID`. Since the eventual plan is to deprecate
+// `tenant_status.go` altogether, we're leaving this code as-is.
+//
+// TODO(davidh): unify with `status.iterateNodes` once this server is
+// deprecated.
+func (t *tenantStatusServer) iteratePods(
+  ctx context.Context,
+  errorCtx string,
+  dialFn func(ctx context.Context, instance sqlinstance.InstanceInfo) (interface{}, error),
+  instanceFn func(ctx context.Context, client interface{}, instanceID base.SQLInstanceID) (interface{}, error),
+  responseFn func(instanceID base.SQLInstanceID, resp interface{}),
+  errorFn func(instanceID base.SQLInstanceID, nodeFnError error),
+) error {
+  liveTenantInstances, err := t.getLiveInstancesForTenant(ctx)
+  if err != nil {
+    return err
+  }
+
+  type instanceResponse struct {
+    instanceID base.SQLInstanceID
+    response   interface{}
+    err        error
+  }
+
+  numInstances := len(liveTenantInstances)
+  responseChan := make(chan instanceResponse, numInstances)
+
+  instanceQuery := func(ctx context.Context, instance sqlinstance.InstanceInfo) {
+    var client interface{}
+    err := contextutil.RunWithTimeout(ctx, "dial instance", base.NetworkTimeout, func(ctx context.Context) error {
+      var err error
+      client, err = dialFn(ctx, instance)
+      return err
+    })
+
+    instanceID := instance.InstanceID()
+    if err != nil {
+      err = errors.Wrapf(err, "failed to dial into node %d",
+        instanceID)
+      responseChan <- instanceResponse{instanceID: instanceID, err: err}
+      return
+    }
+
+    res, err := instanceFn(ctx, client, instanceID)
+    if err != nil {
+      err = errors.Wrapf(err, "error requesting %s from instance %d",
+        errorCtx, instanceID)
+      responseChan <- instanceResponse{instanceID: instanceID, err: err}
+      return
+    }
+    responseChan <- instanceResponse{instanceID: instanceID, response: res}
+  }
+
+  sem := quotapool.NewIntPool("instance status", maxConcurrentRequests)
+  ctx, cancel := t.stopper.WithCancelOnQuiesce(ctx)
+  defer cancel()
+
+  for _, instance := range liveTenantInstances {
+    instance := instance
+    if err := t.stopper.RunLimitedAsyncTask(
+      ctx, fmt.Sprintf("server.tenantStatusServer: requesting %s", errorCtx),
+      sem, true /* wait */, func(ctx context.Context) {
+        instanceQuery(ctx, instance)
+      }); err != nil {
+      return err
+    }
+  }
+
+  var resultErr error
+  for numInstances > 0 {
+    select {
+    case res := <-responseChan:
+      if res.err != nil {
+        errorFn(res.instanceID, res.err)
+      }
else { + responseFn(res.instanceID, res.response) + } + case <-ctx.Done(): + resultErr = errors.Errorf("request of %s canceled before completion", errorCtx) + } + numInstances-- + } + return resultErr +} + +// getLiveInstancesForTenant filters through all the instances that the +// `sqlInstanceProvider` gives us and checks using the liveness provider +// to only keep the ones for which `IsAlive` is true. This is needed to +// limit the fan-out commands to only make requests to live nodes. +func (t *tenantStatusServer) getLiveInstancesForTenant( + ctx context.Context, +) ([]sqlinstance.InstanceInfo, error) { + tenantInstances, err := t.sqlServer.sqlInstanceProvider.GetAllInstances(ctx) if err != nil { return nil, err } - return resp, nil + var liveTenantInstances []sqlinstance.InstanceInfo + for _, i := range tenantInstances { + alive, err := t.sqlServer.sqlLivenessProvider.IsAlive(ctx, i.SessionID()) + if err != nil { + continue + } + if alive { + liveTenantInstances = append(liveTenantInstances, i) + } + } + return liveTenantInstances, nil } func (t *tenantStatusServer) ListDistSQLFlows( diff --git a/pkg/server/testserver.go b/pkg/server/testserver.go index 4c5c5ff36076..0e1ca45895b6 100644 --- a/pkg/server/testserver.go +++ b/pkg/server/testserver.go @@ -496,6 +496,11 @@ func (t *TestTenant) DistSQLServer() interface{} { return t.SQLServer.distSQLServer } +// RPCContext is part of the TestTenantInterface interface +func (t *TestTenant) RPCContext() *rpc.Context { + return t.execCfg.RPCContext +} + // JobRegistry is part of the TestTenantInterface interface. func (t *TestTenant) JobRegistry() interface{} { return t.SQLServer.jobRegistry diff --git a/pkg/testutils/serverutils/test_tenant_shim.go b/pkg/testutils/serverutils/test_tenant_shim.go index 2941477c7298..1b35e333a125 100644 --- a/pkg/testutils/serverutils/test_tenant_shim.go +++ b/pkg/testutils/serverutils/test_tenant_shim.go @@ -13,7 +13,10 @@ package serverutils -import "github.com/cockroachdb/cockroach/pkg/base" +import ( + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/rpc" +) // TestTenantInterface defines SQL-only tenant functionality that tests need; it // is implemented by server.TestTenant. @@ -46,4 +49,7 @@ type TestTenantInterface interface { // JobRegistry returns the *jobs.Registry as an interface{}. JobRegistry() interface{} + + // RPCContext returns the *rpc.Context + RPCContext() *rpc.Context } diff --git a/pkg/util/log/clog.go b/pkg/util/log/clog.go index 0830b89c7f0a..cacacb67bead 100644 --- a/pkg/util/log/clog.go +++ b/pkg/util/log/clog.go @@ -224,9 +224,13 @@ func SetNodeIDs(clusterID string, nodeID int32) { logging.idMu.Lock() defer logging.idMu.Unlock() - if logging.idMu.clusterID != "" { - panic("clusterID already set") - } + // TODO(davidh): PRIOR TO MERGE: since my test in `tenant_grpc_test.go` calls + // StartTenant multiple times I run into this panic. Need to figure out a way + // to modify the below check or perhaps learn something about testing with + // multiple tenants + //if logging.idMu.clusterID != "" { + // panic("clusterID already set") + //} logging.idMu.clusterID = clusterID logging.idMu.nodeID = nodeID @@ -245,9 +249,13 @@ func SetTenantIDs(tenantID string, sqlInstanceID int32) { logging.idMu.Lock() defer logging.idMu.Unlock() - if logging.idMu.tenantID != "" { - panic("tenantID already set") - } + // TODO(davidh): PRIOR TO MERGE: since my test in `tenant_grpc_test.go` calls + // StartTenant multiple times I run into this panic. 
Need to figure out a way + // to modify the below check or perhaps learn something about testing with + // multiple tenants + //if logging.idMu.tenantID != "" { + // panic("tenantID already set") + //} logging.idMu.tenantID = tenantID logging.idMu.sqlInstanceID = sqlInstanceID
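
Note (illustration, not part of the patch): the pod-to-pod path in this change combines GRPCDialPod, the tenant-scoped StatusClient, and the NodeID "local" convention. The sketch below shows that flow in isolation, assuming an sqlinstance.InstanceInfo has already been obtained from the instance provider; the package and function name (example, statementsFromPod) are hypothetical and do not appear in the change.

package example // hypothetical package, for illustration only

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/rpc"
	"github.com/cockroachdb/cockroach/pkg/server/serverpb"
	"github.com/cockroachdb/cockroach/pkg/sql/sqlinstance"
)

// statementsFromPod dials a single SQL pod and asks it for its local
// statement statistics, mirroring tenantStatusServer.dialPod and the
// "local" branch of Statements in this change.
func statementsFromPod(
	ctx context.Context, rpcCtx *rpc.Context, instance sqlinstance.InstanceInfo,
) (*serverpb.StatementsResponse, error) {
	// GRPCDialPod converts the SQLInstanceID into a NodeID for the existing
	// dialing machinery; this lines up because the tenant gRPC server is
	// initialized with its InstanceID as its node ID (see StartTenant above).
	conn, err := rpcCtx.GRPCDialPod(
		instance.InstanceAddr(), instance.InstanceID(), rpc.DefaultClass,
	).Connect(ctx)
	if err != nil {
		return nil, err
	}
	// The tenant status server implements only a subset of StatusServer;
	// callers must stick to methods it actually serves, such as Statements.
	client := serverpb.NewStatusClient(conn)
	// NodeID "local" asks the remote pod for its own in-memory statistics
	// rather than triggering a further fan-out.
	return client.Statements(ctx, &serverpb.StatementsRequest{NodeID: "local"})
}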
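
A second note on the loopback wiring in ConfigureGRPCGateway: the gateway's *grpc.ClientConn never leaves the process, because the dialer hands back connections from an in-process listener. The sketch below shows the same pattern in isolation using grpc-go's bufconn test package purely as a stand-in; the real code uses its own newLoopbackListener and adds TLS dial options plus a telemetry interceptor.

package example // illustration only, not part of the patch

import (
	"context"
	"net"

	"google.golang.org/grpc"
	"google.golang.org/grpc/test/bufconn"
)

// dialInProcess serves grpcSrv on an in-memory listener and returns a client
// connection whose dialer returns in-memory conns, so a reverse proxy such as
// grpc-gateway can reach the gRPC server without touching the network.
func dialInProcess(ctx context.Context, grpcSrv *grpc.Server) (*grpc.ClientConn, error) {
	lis := bufconn.Listen(1 << 20)
	go func() { _ = grpcSrv.Serve(lis) }()
	return grpc.DialContext(ctx, "in-process",
		grpc.WithInsecure(), // the real code derives dial options from rpc.Context
		grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
			return lis.DialContext(ctx)
		}),
	)
}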