diff --git a/gapis/server/grpc.go b/gapis/server/grpc.go index 565f4defff..2c81f38964 100644 --- a/gapis/server/grpc.go +++ b/gapis/server/grpc.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net" + "sync/atomic" "time" "github.com/google/gapid/core/app/auth" @@ -46,17 +47,10 @@ func Listen(ctx context.Context, addr string, cfg Config) error { // NewWithListener starts a new GRPC server listening on l. // This is a blocking call. func NewWithListener(ctx context.Context, l net.Listener, cfg Config, srvChan chan<- *grpc.Server) error { - keepAlive := make(chan struct{}, 1) s := &grpcServer{ - handler: New(ctx, cfg), - bindCtx: func(c context.Context) context.Context { - // Write to keepAlive if it has no pending signal. - select { - case keepAlive <- struct{}{}: - default: - } - return keys.Clone(c, ctx) - }, + handler: New(ctx, cfg), + bindCtx: func(c context.Context) context.Context { return keys.Clone(c, ctx) }, + keepAlive: make(chan struct{}, 1), } return grpcutil.ServeWithListener(ctx, l, func(ctx context.Context, listener net.Listener, server *grpc.Server) error { if addr, ok := listener.Addr().(*net.TCPAddr); ok { @@ -68,21 +62,40 @@ func NewWithListener(ctx context.Context, l net.Listener, cfg Config, srvChan ch srvChan <- server } if cfg.IdleTimeout != 0 { - go s.stopIfIdle(ctx, server, keepAlive, cfg.IdleTimeout) + go s.stopIfIdle(ctx, server, cfg.IdleTimeout) } return nil }, grpc.UnaryInterceptor(auth.ServerInterceptor(cfg.AuthToken))) } type grpcServer struct { - handler Server - bindCtx func(context.Context) context.Context + handler Server + bindCtx func(context.Context) context.Context + keepAlive chan struct{} + inFlightRPCs uint32 +} + +// inRPC should be called at the start of an RPC call. The returned function +// should be called when the RPC call finishes. +func (s *grpcServer) inRPC() func() { + atomic.AddUint32(&s.inFlightRPCs, 1) + select { + case s.keepAlive <- struct{}{}: + default: + } + return func() { + select { + case s.keepAlive <- struct{}{}: + default: + } + atomic.AddUint32(&s.inFlightRPCs, ^uint32(0)) + } } // stopIfIdle calls GracefulStop on server if there are no writes the the // keepAlive chan within idleTimeout. // This function blocks until there's an idle timeout, or ctx is cancelled. -func (s *grpcServer) stopIfIdle(ctx context.Context, server *grpc.Server, keepAlive <-chan struct{}, idleTimeout time.Duration) { +func (s *grpcServer) stopIfIdle(ctx context.Context, server *grpc.Server, idleTimeout time.Duration) { // Split the idleTimeout into N smaller chunks, and check that there was // no activity from the client in a contiguous N chunks of time. // This avoids killing the server if the machine is suspended (where the @@ -96,21 +109,24 @@ func (s *grpcServer) stopIfIdle(ctx context.Context, server *grpc.Server, keepAl case <-task.ShouldStop(ctx): return case <-time.After(waitTime): + if rpcs := atomic.LoadUint32(&s.inFlightRPCs); rpcs != 0 { + continue + } idleTime += waitTime if idleTime >= idleTimeout { - log.E(ctx, fmt.Sprintf("Stopping GAPIS server: No communication with the client for %v (--idle-timeout %v)", idleTime, idleTimeout)) + log.E(ctx, "Stopping GAPIS server: No communication with the client for %v (--idle-timeout %v)", idleTime, idleTimeout) time.Sleep(time.Second * 3) // Wait a little in the hope this message makes its way to the client(s). return - } else { - log.W(ctx, fmt.Sprintf("No communication with the client for %v (--idle-timeout %v)", idleTime, idleTimeout)) } - case <-keepAlive: + log.W(ctx, "No communication with the client for %v (--idle-timeout %v)", idleTime, idleTimeout) + case <-s.keepAlive: idleTime = 0 } } } func (s *grpcServer) Ping(ctx xctx.Context, req *service.PingRequest) (*service.PingResponse, error) { + defer s.inRPC()() err := s.handler.Ping(s.bindCtx(ctx)) if err := service.NewError(err); err != nil { return &service.PingResponse{}, nil @@ -119,6 +135,7 @@ func (s *grpcServer) Ping(ctx xctx.Context, req *service.PingRequest) (*service. } func (s *grpcServer) GetServerInfo(ctx xctx.Context, req *service.GetServerInfoRequest) (*service.GetServerInfoResponse, error) { + defer s.inRPC()() info, err := s.handler.GetServerInfo(s.bindCtx(ctx)) if err := service.NewError(err); err != nil { return &service.GetServerInfoResponse{Res: &service.GetServerInfoResponse_Error{Error: err}}, nil @@ -127,6 +144,7 @@ func (s *grpcServer) GetServerInfo(ctx xctx.Context, req *service.GetServerInfoR } func (s *grpcServer) Get(ctx xctx.Context, req *service.GetRequest) (*service.GetResponse, error) { + defer s.inRPC()() res, err := s.handler.Get(s.bindCtx(ctx), req.Path) if err := service.NewError(err); err != nil { return &service.GetResponse{Res: &service.GetResponse_Error{Error: err}}, nil @@ -136,6 +154,7 @@ func (s *grpcServer) Get(ctx xctx.Context, req *service.GetRequest) (*service.Ge } func (s *grpcServer) Set(ctx xctx.Context, req *service.SetRequest) (*service.SetResponse, error) { + defer s.inRPC()() res, err := s.handler.Set(s.bindCtx(ctx), req.Path, req.Value.Get()) if err := service.NewError(err); err != nil { return &service.SetResponse{Res: &service.SetResponse_Error{Error: err}}, nil @@ -144,6 +163,7 @@ func (s *grpcServer) Set(ctx xctx.Context, req *service.SetRequest) (*service.Se } func (s *grpcServer) Follow(ctx xctx.Context, req *service.FollowRequest) (*service.FollowResponse, error) { + defer s.inRPC()() res, err := s.handler.Follow(s.bindCtx(ctx), req.Path) if err := service.NewError(err); err != nil { return &service.FollowResponse{Res: &service.FollowResponse_Error{Error: err}}, nil @@ -152,6 +172,7 @@ func (s *grpcServer) Follow(ctx xctx.Context, req *service.FollowRequest) (*serv } func (s *grpcServer) BeginCPUProfile(ctx xctx.Context, req *service.BeginCPUProfileRequest) (*service.BeginCPUProfileResponse, error) { + defer s.inRPC()() err := s.handler.BeginCPUProfile(s.bindCtx(ctx)) if err := service.NewError(err); err != nil { return &service.BeginCPUProfileResponse{Error: err}, nil @@ -160,6 +181,7 @@ func (s *grpcServer) BeginCPUProfile(ctx xctx.Context, req *service.BeginCPUProf } func (s *grpcServer) EndCPUProfile(ctx xctx.Context, req *service.EndCPUProfileRequest) (*service.EndCPUProfileResponse, error) { + defer s.inRPC()() data, err := s.handler.EndCPUProfile(s.bindCtx(ctx)) if err := service.NewError(err); err != nil { return &service.EndCPUProfileResponse{Res: &service.EndCPUProfileResponse_Error{Error: err}}, nil @@ -168,6 +190,7 @@ func (s *grpcServer) EndCPUProfile(ctx xctx.Context, req *service.EndCPUProfileR } func (s *grpcServer) GetPerformanceCounters(ctx xctx.Context, req *service.GetPerformanceCountersRequest) (*service.GetPerformanceCountersResponse, error) { + defer s.inRPC()() data, err := s.handler.GetPerformanceCounters(s.bindCtx(ctx)) if err := service.NewError(err); err != nil { return &service.GetPerformanceCountersResponse{Res: &service.GetPerformanceCountersResponse_Error{Error: err}}, nil @@ -176,6 +199,7 @@ func (s *grpcServer) GetPerformanceCounters(ctx xctx.Context, req *service.GetPe } func (s *grpcServer) GetProfile(ctx xctx.Context, req *service.GetProfileRequest) (*service.GetProfileResponse, error) { + defer s.inRPC()() data, err := s.handler.GetProfile(s.bindCtx(ctx), req.Name, req.Debug) if err := service.NewError(err); err != nil { return &service.GetProfileResponse{Res: &service.GetProfileResponse_Error{Error: err}}, nil @@ -184,6 +208,7 @@ func (s *grpcServer) GetProfile(ctx xctx.Context, req *service.GetProfileRequest } func (s *grpcServer) GetAvailableStringTables(ctx xctx.Context, req *service.GetAvailableStringTablesRequest) (*service.GetAvailableStringTablesResponse, error) { + defer s.inRPC()() tables, err := s.handler.GetAvailableStringTables(s.bindCtx(ctx)) if err := service.NewError(err); err != nil { return &service.GetAvailableStringTablesResponse{Res: &service.GetAvailableStringTablesResponse_Error{Error: err}}, nil @@ -196,6 +221,7 @@ func (s *grpcServer) GetAvailableStringTables(ctx xctx.Context, req *service.Get } func (s *grpcServer) GetStringTable(ctx xctx.Context, req *service.GetStringTableRequest) (*service.GetStringTableResponse, error) { + defer s.inRPC()() table, err := s.handler.GetStringTable(s.bindCtx(ctx), req.Table) if err := service.NewError(err); err != nil { return &service.GetStringTableResponse{Res: &service.GetStringTableResponse_Error{Error: err}}, nil @@ -204,6 +230,7 @@ func (s *grpcServer) GetStringTable(ctx xctx.Context, req *service.GetStringTabl } func (s *grpcServer) ImportCapture(ctx xctx.Context, req *service.ImportCaptureRequest) (*service.ImportCaptureResponse, error) { + defer s.inRPC()() capture, err := s.handler.ImportCapture(s.bindCtx(ctx), req.Name, req.Data) if err := service.NewError(err); err != nil { return &service.ImportCaptureResponse{Res: &service.ImportCaptureResponse_Error{Error: err}}, nil @@ -212,6 +239,7 @@ func (s *grpcServer) ImportCapture(ctx xctx.Context, req *service.ImportCaptureR } func (s *grpcServer) ExportCapture(ctx xctx.Context, req *service.ExportCaptureRequest) (*service.ExportCaptureResponse, error) { + defer s.inRPC()() data, err := s.handler.ExportCapture(s.bindCtx(ctx), req.Capture) if err := service.NewError(err); err != nil { return &service.ExportCaptureResponse{Res: &service.ExportCaptureResponse_Error{Error: err}}, nil @@ -220,6 +248,7 @@ func (s *grpcServer) ExportCapture(ctx xctx.Context, req *service.ExportCaptureR } func (s *grpcServer) LoadCapture(ctx xctx.Context, req *service.LoadCaptureRequest) (*service.LoadCaptureResponse, error) { + defer s.inRPC()() capture, err := s.handler.LoadCapture(s.bindCtx(ctx), req.Path) if err := service.NewError(err); err != nil { return &service.LoadCaptureResponse{Res: &service.LoadCaptureResponse_Error{Error: err}}, nil @@ -228,6 +257,7 @@ func (s *grpcServer) LoadCapture(ctx xctx.Context, req *service.LoadCaptureReque } func (s *grpcServer) GetDevices(ctx xctx.Context, req *service.GetDevicesRequest) (*service.GetDevicesResponse, error) { + defer s.inRPC()() devices, err := s.handler.GetDevices(s.bindCtx(ctx)) if err := service.NewError(err); err != nil { return &service.GetDevicesResponse{Res: &service.GetDevicesResponse_Error{Error: err}}, nil @@ -240,6 +270,7 @@ func (s *grpcServer) GetDevices(ctx xctx.Context, req *service.GetDevicesRequest } func (s *grpcServer) GetDevicesForReplay(ctx xctx.Context, req *service.GetDevicesForReplayRequest) (*service.GetDevicesForReplayResponse, error) { + defer s.inRPC()() devices, err := s.handler.GetDevicesForReplay(s.bindCtx(ctx), req.Capture) if err := service.NewError(err); err != nil { return &service.GetDevicesForReplayResponse{Res: &service.GetDevicesForReplayResponse_Error{Error: err}}, nil @@ -252,6 +283,7 @@ func (s *grpcServer) GetDevicesForReplay(ctx xctx.Context, req *service.GetDevic } func (s *grpcServer) GetFramebufferAttachment(ctx xctx.Context, req *service.GetFramebufferAttachmentRequest) (*service.GetFramebufferAttachmentResponse, error) { + defer s.inRPC()() image, err := s.handler.GetFramebufferAttachment( s.bindCtx(ctx), req.Device, @@ -267,12 +299,14 @@ func (s *grpcServer) GetFramebufferAttachment(ctx xctx.Context, req *service.Get } func (s *grpcServer) GetLogStream(req *service.GetLogStreamRequest, server service.Gapid_GetLogStreamServer) error { + defer s.inRPC()() ctx := server.Context() h := log.NewHandler(func(m *log.Message) { server.Send(log_pb.From(m)) }, nil) return s.handler.GetLogStream(s.bindCtx(ctx), h) } func (s *grpcServer) Find(req *service.FindRequest, server service.Gapid_FindServer) error { + defer s.inRPC()() ctx := server.Context() return s.handler.Find(s.bindCtx(ctx), req, server.Send) }