Skip to content

Commit

Permalink
grpc: Add a pointer of server to ctx passed into stats handler (#6750)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasweq authored Oct 26, 2023
1 parent 8190d88 commit 8cb9846
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 0 deletions.
5 changes: 5 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ var (
// xDS-enabled server invokes this method on a grpc.Server when a particular
// listener moves to "not-serving" mode.
DrainServerTransports any // func(*grpc.Server, string)
// IsRegisteredMethod returns whether the passed in method is registered as
// a method on the server.
IsRegisteredMethod any // func(*grpc.Server, string) bool
// ServerFromContext returns the server from the context.
ServerFromContext any // func(context.Context) *grpc.Server
// AddGlobalServerOptions adds an array of ServerOption that will be
// effective globally for newly created servers. The priority will be: 1.
// user-provided; 2. this method; 3. default values.
Expand Down
65 changes: 65 additions & 0 deletions internal/testutils/stubstatshandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package testutils

import (
"context"

"google.golang.org/grpc/stats"
)

// StubStatsHandler is a stats handler that is easy to customize within
// individual test cases. It is a stubbable implementation of
// google.golang.org/grpc/stats.Handler for testing purposes.
type StubStatsHandler struct {
TagRPCF func(ctx context.Context, info *stats.RPCTagInfo) context.Context
HandleRPCF func(ctx context.Context, info stats.RPCStats)
TagConnF func(ctx context.Context, info *stats.ConnTagInfo) context.Context
HandleConnF func(ctx context.Context, info stats.ConnStats)
}

// TagRPC calls the StubStatsHandler's TagRPCF, if set.
func (ssh *StubStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
if ssh.TagRPCF != nil {
return ssh.TagRPCF(ctx, info)
}
return ctx
}

// HandleRPC calls the StubStatsHandler's HandleRPCF, if set.
func (ssh *StubStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
if ssh.HandleRPCF != nil {
ssh.HandleRPCF(ctx, rs)
}
}

// TagConn calls the StubStatsHandler's TagConnF, if set.
func (ssh *StubStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
if ssh.TagConnF != nil {
return ssh.TagConnF(ctx, info)
}
return ctx
}

// HandleConn calls the StubStatsHandler's HandleConnF, if set.
func (ssh *StubStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
if ssh.HandleConnF != nil {
ssh.HandleConnF(ctx, cs)
}
}
43 changes: 43 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ func init() {
internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
return srv.opts.creds
}
internal.IsRegisteredMethod = func(srv *Server, method string) bool {
return srv.isRegisteredMethod(method)
}
internal.ServerFromContext = serverFromContext
internal.DrainServerTransports = func(srv *Server, addr string) {
srv.drainServerTransports(addr)
}
Expand Down Expand Up @@ -1707,6 +1711,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran

func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
ctx := stream.Context()
ctx = contextWithServer(ctx, s)
var ti *traceInfo
if EnableTracing {
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
Expand Down Expand Up @@ -1953,6 +1958,44 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
return codec
}

type serverKey struct{}

// serverFromContext gets the Server from the context.
func serverFromContext(ctx context.Context) *Server {
s, _ := ctx.Value(serverKey{}).(*Server)
return s
}

// contextWithServer sets the Server in the context.
func contextWithServer(ctx context.Context, server *Server) context.Context {
return context.WithValue(ctx, serverKey{}, server)
}

// isRegisteredMethod returns whether the passed in method is registered as a
// method on the server. /service/method and service/method will match if the
// service and method are registered on the server.
func (s *Server) isRegisteredMethod(serviceMethod string) bool {
if serviceMethod != "" && serviceMethod[0] == '/' {
serviceMethod = serviceMethod[1:]
}
pos := strings.LastIndex(serviceMethod, "/")
if pos == -1 { // Invalid method name syntax.
return false
}
service := serviceMethod[:pos]
method := serviceMethod[pos+1:]
srv, knownService := s.services[service]
if knownService {
if _, ok := srv.methods[method]; ok {
return true
}
if _, ok := srv.streams[method]; ok {
return true
}
}
return false
}

// SetHeader sets the header metadata to be sent from the server to the client.
// The context provided must be the context passed to the server's handler.
//
Expand Down
61 changes: 61 additions & 0 deletions stats/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -1457,3 +1460,61 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) {
t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
}
}

// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
// gets access to a Server on the server side, and thus the method that the
// server owns which specifies whether a method is made or not. The test sets up
// a server with a unary call and full duplex call configured, and makes an RPC.
// Within the stats handler, asking the server whether unary or duplex method
// names are registered should return true, and any other query should return
// false.
func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(1)
stubStatsHandler := &testutils.StubStatsHandler{
TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
// OpenTelemetry instrumentation needs the passed in Server to determine if
// methods are registered in different handle calls in to record metrics.
// This tag RPC call context gets passed into every handle call, so can
// assert once here, since it maps to all the handle RPC calls that come
// after. These internal calls will be how the OpenTelemetry instrumentation
// component accesses this server and the subsequent helper on the server.
server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx)
if server == nil {
t.Errorf("stats handler received ctx has no server present")
}
isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)
// /s/m and s/m are valid.
if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") {
t.Errorf("UnaryCall should be a registered method according to server")
}
if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") {
t.Errorf("FullDuplexCall should be a registered method according to server")
}
if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") {
t.Errorf("DoesNotExistCall should not be a registered method according to server")
}
if isRegisteredMethod(server, "/unknownService/UnaryCall") {
t.Errorf("/unknownService/UnaryCall should not be a registered method according to server")
}
wg.Done()
return ctx
},
}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
}
if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
t.Fatalf("Unexpected error from UnaryCall: %v", err)
}
wg.Wait()
}

0 comments on commit 8cb9846

Please sign in to comment.