diff --git a/call.go b/call.go index 0843865244de..a2b89ac6a36e 100644 --- a/call.go +++ b/call.go @@ -73,7 +73,10 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } } for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil { + if c.maxReceiveMessageSize == nil { + return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + } + if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload); err != nil { if err == io.EOF { break } @@ -93,7 +96,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } // sendRequest writes out various information of an RPC such as Context and Message. -func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) { +func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) { defer func() { if err != nil { // If err is connection error, t will be closed, no need to close stream here. @@ -118,6 +121,12 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, if err != nil { return Errorf(codes.Internal, "grpc: %v", err) } + if c.maxSendMessageSize == nil { + return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") + } + if len(outBuf) > *c.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize) + } err = t.Write(stream, outBuf, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() @@ -145,14 +154,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { c := defaultCallInfo - if mc, ok := cc.getMethodConfig(method); ok { - c.failFast = !mc.WaitForReady - if mc.Timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, mc.Timeout) - defer cancel() - } + mc := cc.GetMethodConfig(method) + if mc.WaitForReady != nil { + c.failFast = !*mc.WaitForReady } + + if mc.Timeout != nil && *mc.Timeout >= 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) + defer cancel() + } + + opts = append(cc.dopts.callOptions, opts...) for _, o := range opts { if err := o.before(&c); err != nil { return toRPCErr(err) @@ -163,6 +176,10 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli o.after(&c) } }() + + c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize) + c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize) + if EnableTracing { c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) defer c.traceInfo.tr.Finish() @@ -260,7 +277,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, stream, t, args, topts) + err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, stream, t, args, topts) if err != nil { if put != nil { updateRPCInfoInContext(ctx, rpcInfo{ diff --git a/clientconn.go b/clientconn.go index be511f939c3f..9dca29c37b52 100644 --- a/clientconn.go +++ b/clientconn.go @@ -36,8 +36,8 @@ package grpc import ( "errors" "fmt" - "math" "net" + "strings" "sync" "time" @@ -87,22 +87,25 @@ var ( // dialOptions configure a Dial call. dialOptions are set by the DialOption // values passed to Dial. type dialOptions struct { - unaryInt UnaryClientInterceptor - streamInt StreamClientInterceptor - codec Codec - cp Compressor - dc Decompressor - bs backoffStrategy - balancer Balancer - block bool - insecure bool - timeout time.Duration - scChan <-chan ServiceConfig - copts transport.ConnectOptions - maxMsgSize int + unaryInt UnaryClientInterceptor + streamInt StreamClientInterceptor + codec Codec + cp Compressor + dc Decompressor + bs backoffStrategy + balancer Balancer + block bool + insecure bool + timeout time.Duration + scChan <-chan ServiceConfig + copts transport.ConnectOptions + callOptions []CallOption } -const defaultClientMaxMsgSize = math.MaxInt32 +const ( + defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4 + defaultClientMaxSendMessageSize = 1024 * 1024 * 4 +) // DialOption configures how we set up the connection. type DialOption func(*dialOptions) @@ -123,10 +126,15 @@ func WithInitialConnWindowSize(s int32) DialOption { } } -// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. +// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead. func WithMaxMsgSize(s int) DialOption { + return WithDefaultCallOptions(MaxCallRecvMsgSize(s)) +} + +// WithDefaultCallOptions returns a DialOption which sets the default CallOptions for calls over the connection. +func WithDefaultCallOptions(cos ...CallOption) DialOption { return func(o *dialOptions) { - o.maxMsgSize = s + o.callOptions = append(o.callOptions, cos...) } } @@ -321,7 +329,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * conns: make(map[Address]*addrConn), } cc.ctx, cc.cancel = context.WithCancel(context.Background()) - cc.dopts.maxMsgSize = defaultClientMaxMsgSize + for _, opt := range opts { opt(&cc.dopts) } @@ -359,15 +367,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } }() + scSet := false if cc.dopts.scChan != nil { - // Wait for the initial service config. + // Try to get an initial service config. select { case sc, ok := <-cc.dopts.scChan: if ok { cc.sc = sc + scSet = true } - case <-ctx.Done(): - return nil, ctx.Err() + default: } } // Set defaults. @@ -430,7 +439,17 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * return nil, err } } - + if cc.dopts.scChan != nil && !scSet { + // Blocking wait for the initial service config. + select { + case sc, ok := <-cc.dopts.scChan: + if ok { + cc.sc = sc + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } if cc.dopts.scChan != nil { go cc.scWatcher() } @@ -640,12 +659,23 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error) return nil } -// TODO: Avoid the locking here. -func (cc *ClientConn) getMethodConfig(method string) (m MethodConfig, ok bool) { +// GetMethodConfig gets the method config of the input method. +// If there's an exact match for input method (i.e. /service/method), we return +// the corresponding MethodConfig. +// If there isn't an exact match for the input method, we look for the default config +// under the service (i.e /service/). If there is a default MethodConfig for +// the serivce, we return it. +// Otherwise, we return an empty MethodConfig. +func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { + // TODO: Avoid the locking here. cc.mu.RLock() defer cc.mu.RUnlock() - m, ok = cc.sc.Methods[method] - return + m, ok := cc.sc.Methods[method] + if !ok { + i := strings.LastIndex(method, "/") + m, _ = cc.sc.Methods[method[:i+1]] + } + return m } func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { diff --git a/rpc_util.go b/rpc_util.go index 446d9fbee0c0..11558d701721 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -137,12 +137,14 @@ func (d *gzipDecompressor) Type() string { // callInfo contains all related configuration and information about an RPC. type callInfo struct { - failFast bool - headerMD metadata.MD - trailerMD metadata.MD - peer *peer.Peer - traceInfo traceInfo // in trace.go - creds credentials.PerRPCCredentials + failFast bool + headerMD metadata.MD + trailerMD metadata.MD + peer *peer.Peer + traceInfo traceInfo // in trace.go + maxReceiveMessageSize *int + maxSendMessageSize *int + creds credentials.PerRPCCredentials } var defaultCallInfo = callInfo{failFast: true} @@ -217,6 +219,22 @@ func FailFast(failFast bool) CallOption { }) } +// MaxCallRecvMsgSize returns a CallOption which sets the maximum message size the client can receive. +func MaxCallRecvMsgSize(s int) CallOption { + return beforeCall(func(o *callInfo) error { + o.maxReceiveMessageSize = &s + return nil + }) +} + +// MaxCallSendMsgSize returns a CallOption which sets the maximum message size the client can send. +func MaxCallSendMsgSize(s int) CallOption { + return beforeCall(func(o *callInfo) error { + o.maxSendMessageSize = &s + return nil + }) +} + // PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials // for a call. func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption { @@ -259,7 +277,7 @@ type parser struct { // No other error values or types must be returned, which also means // that the underlying io.Reader must not return an incompatible // error. -func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) { +func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) { if _, err := io.ReadFull(p.r, p.header[:]); err != nil { return 0, nil, err } @@ -270,8 +288,8 @@ func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err erro if length == 0 { return pf, nil, nil } - if length > uint32(maxMsgSize) { - return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize) + if length > uint32(maxReceiveMessageSize) { + return 0, nil, Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) } // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // of making it for each message: @@ -314,7 +332,7 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayl length = uint(len(b)) } if length > math.MaxUint32 { - return nil, Errorf(codes.InvalidArgument, "grpc: message too large (%d bytes)", length) + return nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", length) } const ( @@ -355,8 +373,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int, inPayload *stats.InPayload) error { - pf, d, err := p.recvMsg(maxMsgSize) +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error { + pf, d, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return err } @@ -372,10 +390,10 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } - if len(d) > maxMsgSize { + if len(d) > maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with java // implementation. - return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize) + return Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) } if err := c.Unmarshal(d, m); err != nil { return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) @@ -501,24 +519,22 @@ type MethodConfig struct { // WaitForReady indicates whether RPCs sent to this method should wait until // the connection is ready by default (!failfast). The value specified via the // gRPC client API will override the value set here. - WaitForReady bool + WaitForReady *bool // Timeout is the default timeout for RPCs sent to this method. The actual // deadline used will be the minimum of the value specified here and the value // set by the application via the gRPC client API. If either one is not set, // then the other will be used. If neither is set, then the RPC has no deadline. - Timeout time.Duration + Timeout *time.Duration // MaxReqSize is the maximum allowed payload size for an individual request in a // stream (client->server) in bytes. The size which is measured is the serialized // payload after per-message compression (but before stream compression) in bytes. // The actual value used is the minumum of the value specified here and the value set // by the application via the gRPC client API. If either one is not set, then the other // will be used. If neither is set, then the built-in default is used. - // TODO: support this. - MaxReqSize uint32 + MaxReqSize *int // MaxRespSize is the maximum allowed payload size for an individual response in a // stream (server->client) in bytes. - // TODO: support this. - MaxRespSize uint32 + MaxRespSize *int } // ServiceConfig is provided by the service provider and contains parameters for how @@ -529,9 +545,32 @@ type ServiceConfig struct { // via grpc.WithBalancer will override this. LB Balancer // Methods contains a map for the methods in this service. + // If there is an exact match for a method (i.e. /service/method) in the map, use the corresponding MethodConfig. + // If there's no exact match, look for the default config for the service (/service/) and use the corresponding MethodConfig if it exists. + // Otherwise, the method has no MethodConfig to use. Methods map[string]MethodConfig } +func min(a, b *int) *int { + if *a < *b { + return a + } + return b +} + +func getMaxSize(mcMax, doptMax *int, defaultVal int) *int { + if mcMax == nil && doptMax == nil { + return &defaultVal + } + if mcMax != nil && doptMax != nil { + return min(mcMax, doptMax) + } + if mcMax != nil { + return mcMax + } + return doptMax +} + // SupportPackageIsVersion4 is referenced from generated protocol buffer files // to assert that that code is compatible with this version of the grpc package. // diff --git a/server.go b/server.go index 5a1d4ea1b4aa..62f4149a0c1e 100644 --- a/server.go +++ b/server.go @@ -61,6 +61,11 @@ import ( "google.golang.org/grpc/transport" ) +const ( + defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 + defaultServerMaxSendMessageSize = 1024 * 1024 * 4 +) + type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error) // MethodDesc represents an RPC service's method specification. @@ -111,12 +116,13 @@ type options struct { codec Codec cp Compressor dc Decompressor - maxMsgSize int unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor inTapHandle tap.ServerInHandle statsHandler stats.Handler maxConcurrentStreams uint32 + maxReceiveMessageSize int + maxSendMessageSize int useHandlerImpl bool // use http.Handler-based server unknownStreamDesc *StreamDesc keepaliveParams keepalive.ServerParameters @@ -125,7 +131,10 @@ type options struct { initialConnWindowSize int32 } -var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit +var defaultServerOptions = options{ + maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, + maxSendMessageSize: defaultServerMaxSendMessageSize, +} // A ServerOption sets options such as credentials, codec and keepalive parameters, etc. type ServerOption func(*options) @@ -181,11 +190,25 @@ func RPCDecompressor(dc Decompressor) ServerOption { } } -// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages. -// If this is not set, gRPC uses the default 4MB. +// MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. +// If this is not set, gRPC uses the default limit. Deprecated: use MaxRecvMsgSize instead. func MaxMsgSize(m int) ServerOption { + return MaxRecvMsgSize(m) +} + +// MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive. +// If this is not set, gRPC uses the default 4MB. +func MaxRecvMsgSize(m int) ServerOption { return func(o *options) { - o.maxMsgSize = m + o.maxReceiveMessageSize = m + } +} + +// MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send. +// If this is not set, gRPC uses the default 4MB. +func MaxSendMsgSize(m int) ServerOption { + return func(o *options) { + o.maxSendMessageSize = m } } @@ -266,8 +289,7 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { - var opts options - opts.maxMsgSize = defaultMaxMsgSize + opts := defaultServerOptions for _, o := range opt { o(&opts) } @@ -649,6 +671,9 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str // the optimal option. grpclog.Fatalf("grpc: Server failed to encode response %v", err) } + if len(p) > s.opts.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize) + } err = t.Write(stream, p, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() @@ -690,7 +715,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. stream.SetSendCompress(s.opts.cp.Type()) } p := &parser{r: stream} - pf, req, err := p.recvMsg(s.opts.maxMsgSize) + pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -748,10 +773,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return Errorf(codes.Internal, err.Error()) } } - if len(req) > s.opts.maxMsgSize { + if len(req) > s.opts.maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with // java implementation. - return status.Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) + return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize) } if err := s.opts.codec.Unmarshal(req, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) @@ -844,15 +869,16 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, - maxMsgSize: s.opts.maxMsgSize, - trInfo: trInfo, - statsHandler: sh, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cp: s.opts.cp, + dc: s.opts.dc, + maxReceiveMessageSize: s.opts.maxReceiveMessageSize, + maxSendMessageSize: s.opts.maxSendMessageSize, + trInfo: trInfo, + statsHandler: sh, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) @@ -927,7 +953,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.SetError() } errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) - if err := t.WriteStatus(stream, status.New(codes.InvalidArgument, errDesc)); err != nil { + if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() diff --git a/stream.go b/stream.go index ec534a017b1d..ed0ebe7b73ec 100644 --- a/stream.go +++ b/stream.go @@ -113,17 +113,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cancel context.CancelFunc ) c := defaultCallInfo - if mc, ok := cc.getMethodConfig(method); ok { - c.failFast = !mc.WaitForReady - if mc.Timeout > 0 { - ctx, cancel = context.WithTimeout(ctx, mc.Timeout) - } + mc := cc.GetMethodConfig(method) + if mc.WaitForReady != nil { + c.failFast = !*mc.WaitForReady + } + + if mc.Timeout != nil { + ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) } + + opts = append(cc.dopts.callOptions, opts...) for _, o := range opts { if err := o.before(&c); err != nil { return nil, toRPCErr(err) } } + c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize) + c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize) + callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, @@ -215,14 +222,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth break } cs := &clientStream{ - opts: opts, - c: c, - desc: desc, - codec: cc.dopts.codec, - cp: cc.dopts.cp, - dc: cc.dopts.dc, - maxMsgSize: cc.dopts.maxMsgSize, - cancel: cancel, + opts: opts, + c: c, + desc: desc, + codec: cc.dopts.codec, + cp: cc.dopts.cp, + dc: cc.dopts.dc, + cancel: cancel, put: put, t: t, @@ -266,18 +272,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { - opts []CallOption - c callInfo - t transport.ClientTransport - s *transport.Stream - p *parser - desc *StreamDesc - codec Codec - cp Compressor - cbuf *bytes.Buffer - dc Decompressor - maxMsgSize int - cancel context.CancelFunc + opts []CallOption + c callInfo + t transport.ClientTransport + s *transport.Stream + p *parser + desc *StreamDesc + codec Codec + cp Compressor + cbuf *bytes.Buffer + dc Decompressor + cancel context.CancelFunc tracing bool // set to EnableTracing when the clientStream is created. @@ -361,6 +366,12 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { return Errorf(codes.Internal, "grpc: %v", err) } + if cs.c.maxSendMessageSize == nil { + return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") + } + if len(out) > *cs.c.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize) + } err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() @@ -376,7 +387,10 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { Client: true, } } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload) + if cs.c.maxReceiveMessageSize == nil { + return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + } + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -399,7 +413,10 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } // Special handling for client streaming rpc. // This recv expects EOF or errors, so we don't collect inPayload. - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil) + if cs.c.maxReceiveMessageSize == nil { + return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") + } + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -524,15 +541,16 @@ type ServerStream interface { // serverStream implements a server side Stream. type serverStream struct { - t transport.ServerTransport - s *transport.Stream - p *parser - codec Codec - cp Compressor - dc Decompressor - cbuf *bytes.Buffer - maxMsgSize int - trInfo *traceInfo + t transport.ServerTransport + s *transport.Stream + p *parser + codec Codec + cp Compressor + dc Decompressor + cbuf *bytes.Buffer + maxReceiveMessageSize int + maxSendMessageSize int + trInfo *traceInfo statsHandler stats.Handler @@ -591,6 +609,9 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { err = Errorf(codes.Internal, "grpc: %v", err) return err } + if len(out) > ss.maxSendMessageSize { + return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize) + } if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } @@ -620,7 +641,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { if ss.statsHandler != nil { inPayload = &stats.InPayload{} } - if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil { + if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil { if err == io.EOF { return err } diff --git a/test/end2end_test.go b/test/end2end_test.go index 0eee77d01341..b028e83eb9bc 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -435,7 +435,11 @@ type test struct { healthServer *health.Server // nil means disabled maxStream uint32 tapHandle tap.ServerInHandle - maxMsgSize int + maxMsgSize *int + maxClientReceiveMsgSize *int + maxClientSendMsgSize *int + maxServerReceiveMsgSize *int + maxServerSendMsgSize *int userAgent string clientCompression bool serverCompression bool @@ -496,8 +500,14 @@ func (te *test) startServer(ts testpb.TestServiceServer) { te.testServer = ts te.t.Logf("Running test in %s environment...", te.e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} - if te.maxMsgSize > 0 { - sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize)) + if te.maxMsgSize != nil { + sopts = append(sopts, grpc.MaxMsgSize(*te.maxMsgSize)) + } + if te.maxServerReceiveMsgSize != nil { + sopts = append(sopts, grpc.MaxRecvMsgSize(*te.maxServerReceiveMsgSize)) + } + if te.maxServerSendMsgSize != nil { + sopts = append(sopts, grpc.MaxSendMsgSize(*te.maxServerSendMsgSize)) } if te.tapHandle != nil { sopts = append(sopts, grpc.InTapHandle(te.tapHandle)) @@ -596,8 +606,14 @@ func (te *test) clientConn() *grpc.ClientConn { if te.streamClientInt != nil { opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt)) } - if te.maxMsgSize > 0 { - opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize)) + if te.maxMsgSize != nil { + opts = append(opts, grpc.WithMaxMsgSize(*te.maxMsgSize)) + } + if te.maxClientReceiveMsgSize != nil { + opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(*te.maxClientReceiveMsgSize))) + } + if te.maxClientSendMsgSize != nil { + opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(*te.maxClientSendMsgSize))) } switch te.e.security { case "tls": @@ -1097,16 +1113,10 @@ func testFailFast(t *testing.T, e env) { awaitNewConnLogOutput() } -func TestServiceConfig(t *testing.T) { - defer leakCheck(t)() - for _, e := range listTestEnv() { - testServiceConfig(t, e) - } -} - -func testServiceConfig(t *testing.T, e env) { +func testServiceConfigSetup(t *testing.T, e env) (*test, chan grpc.ServiceConfig) { te := newTest(t, e) - ch := make(chan grpc.ServiceConfig) + // We write before read. + ch := make(chan grpc.ServiceConfig, 1) te.sc = ch te.userAgent = testAppUA te.declareLogNoise( @@ -1115,37 +1125,152 @@ func testServiceConfig(t *testing.T, e env) { "grpc: addrConn.resetTransport failed to create client transport: connection error", "Failed to dial : context canceled; please retry.", ) + return te, ch +} + +func newBool(b bool) (a *bool) { + return &b +} + +func newInt(b int) (a *int) { + return &b +} + +func newDuration(b time.Duration) (a *time.Duration) { + a = new(time.Duration) + *a = b + return +} + +func TestServiceConfigGetMethodConfig(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testGetMethodConfig(t, e) + } +} + +func testGetMethodConfig(t *testing.T, e env) { + te, ch := testServiceConfigSetup(t, e) defer te.tearDown() - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - mc := grpc.MethodConfig{ - WaitForReady: true, - Timeout: time.Millisecond, - } - m := make(map[string]grpc.MethodConfig) - m["/grpc.testing.TestService/EmptyCall"] = mc - m["/grpc.testing.TestService/FullDuplexCall"] = mc - sc := grpc.ServiceConfig{ - Methods: m, + mc1 := grpc.MethodConfig{ + WaitForReady: newBool(true), + Timeout: newDuration(time.Millisecond), + } + mc2 := grpc.MethodConfig{WaitForReady: newBool(false)} + m := make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/EmptyCall"] = mc1 + m["/grpc.testing.TestService/"] = mc2 + sc := grpc.ServiceConfig{ + Methods: m, + } + ch <- sc + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + // The following RPCs are expected to become non-fail-fast ones with 1ms deadline. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) + } + + m = make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/UnaryCall"] = mc1 + m["/grpc.testing.TestService/"] = mc2 + sc = grpc.ServiceConfig{ + Methods: m, + } + ch <- sc + // Wait for the new service config to propagate. + for { + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.DeadlineExceeded { + continue } - ch <- sc - }() + break + } + // The following RPCs are expected to become fail-fast. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.Unavailable) + } +} + +func TestServiceConfigWaitForReady(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testServiceConfigWaitForReady(t, e) + } +} + +func testServiceConfigWaitForReady(t *testing.T, e env) { + te, ch := testServiceConfigSetup(t, e) + defer te.tearDown() + + // Case1: Client API set failfast to be false, and service config set wait_for_ready to be false, Client API should win, and the rpc will wait until deadline exceeds. + mc := grpc.MethodConfig{ + WaitForReady: newBool(false), + Timeout: newDuration(time.Millisecond), + } + m := make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/EmptyCall"] = mc + m["/grpc.testing.TestService/FullDuplexCall"] = mc + sc := grpc.ServiceConfig{ + Methods: m, + } + ch <- sc + cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) // The following RPCs are expected to become non-fail-fast ones with 1ms deadline. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) + } + if _, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) + } + + // Generate a service config update. + // Case2: Client API does not set failfast, and service config set wait_for_ready to be true, and the rpc will wait until deadline exceeds. + mc.WaitForReady = newBool(true) + m = make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/EmptyCall"] = mc + m["/grpc.testing.TestService/FullDuplexCall"] = mc + sc = grpc.ServiceConfig{ + Methods: m, + } + ch <- sc + + // Wait for the new service config to take effect. + mc = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall") + for { + if !*mc.WaitForReady { + time.Sleep(100 * time.Millisecond) + mc = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall") + continue + } + break + } + // The following RPCs are expected to become non-fail-fast ones with 1ms deadline. if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) } if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) } - wg.Wait() - // Generate a service config update. +} + +func TestServiceConfigTimeout(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testServiceConfigTimeout(t, e) + } +} + +func testServiceConfigTimeout(t *testing.T, e env) { + te, ch := testServiceConfigSetup(t, e) + defer te.tearDown() + + // Case1: Client API sets timeout to be 1ns and ServiceConfig sets timeout to be 1hr. Timeout should be 1ns (min of 1ns and 1hr) and the rpc will wait until deadline exceeds. mc := grpc.MethodConfig{ - WaitForReady: false, + Timeout: newDuration(time.Hour), } m := make(map[string]grpc.MethodConfig) m["/grpc.testing.TestService/EmptyCall"] = mc @@ -1154,19 +1279,530 @@ func testServiceConfig(t *testing.T, e env) { Methods: m, } ch <- sc - // Loop until the new update becomes effective. + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + // The following RPCs are expected to become non-fail-fast ones with 1ns deadline. + ctx, _ := context.WithTimeout(context.Background(), time.Nanosecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) + } + ctx, _ = context.WithTimeout(context.Background(), time.Nanosecond) + if _, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) + } + + // Generate a service config update. + // Case2: Client API sets timeout to be 1hr and ServiceConfig sets timeout to be 1ns. Timeout should be 1ns (min of 1ns and 1hr) and the rpc will wait until deadline exceeds. + mc.Timeout = newDuration(time.Nanosecond) + m = make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/EmptyCall"] = mc + m["/grpc.testing.TestService/FullDuplexCall"] = mc + sc = grpc.ServiceConfig{ + Methods: m, + } + ch <- sc + + // Wait for the new service config to take effect. + mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall") for { - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { + if *mc.Timeout != time.Nanosecond { + time.Sleep(100 * time.Millisecond) + mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall") continue } break } - // The following RPCs are expected to become fail-fast. - if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.Unavailable) + + ctx, _ = context.WithTimeout(context.Background(), time.Hour) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) + } + + ctx, _ = context.WithTimeout(context.Background(), time.Hour) + if _, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) + } +} + +func TestServiceConfigMaxMsgSize(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testServiceConfigMaxMsgSize(t, e) + } +} + +func testServiceConfigMaxMsgSize(t *testing.T, e env) { + // Setting up values and objects shared across all test cases. + const smallSize = 1 + const largeSize = 1024 + const extraLargeSize = 2048 + + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + extraLargePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, extraLargeSize) + if err != nil { + t.Fatal(err) + } + + mc := grpc.MethodConfig{ + MaxReqSize: newInt(extraLargeSize), + MaxRespSize: newInt(extraLargeSize), + } + + m := make(map[string]grpc.MethodConfig) + m["/grpc.testing.TestService/UnaryCall"] = mc + m["/grpc.testing.TestService/FullDuplexCall"] = mc + sc := grpc.ServiceConfig{ + Methods: m, + } + // Case1: sc set maxReqSize to 2048 (send), maxRespSize to 2048 (recv). + te1, ch1 := testServiceConfigSetup(t, e) + te1.startServer(&testServer{security: e.security}) + defer te1.tearDown() + + ch1 <- sc + tc := testpb.NewTestServiceClient(te1.clientConn()) + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(extraLargeSize)), + Payload: smallPayload, + } + // Test for unary RPC recv. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for unary RPC send. + req.Payload = extraLargePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for streaming RPC recv. + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(extraLargeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + stream, err := tc.FullDuplexCall(te1.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) + } + + // Test for streaming RPC send. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = extraLargePayload + stream, err = tc.FullDuplexCall(te1.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.ResourceExhausted) + } + + // Case2: Client API set maxReqSize to 1024 (send), maxRespSize to 1024 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv). + te2, ch2 := testServiceConfigSetup(t, e) + te2.maxClientReceiveMsgSize = newInt(1024) + te2.maxClientSendMsgSize = newInt(1024) + te2.startServer(&testServer{security: e.security}) + defer te2.tearDown() + ch2 <- sc + tc = testpb.NewTestServiceClient(te2.clientConn()) + + // Test for unary RPC recv. + req.Payload = smallPayload + req.ResponseSize = proto.Int32(int32(largeSize)) + + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for streaming RPC recv. + stream, err = tc.FullDuplexCall(te2.ctx) + respParam[0].Size = proto.Int32(int32(largeSize)) + sreq.Payload = smallPayload + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) + } + + // Test for streaming RPC send. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te2.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.ResourceExhausted) + } + + // Case3: Client API set maxReqSize to 4096 (send), maxRespSize to 4096 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv). + te3, ch3 := testServiceConfigSetup(t, e) + te3.maxClientReceiveMsgSize = newInt(4096) + te3.maxClientSendMsgSize = newInt(4096) + te3.startServer(&testServer{security: e.security}) + defer te3.tearDown() + ch3 <- sc + tc = testpb.NewTestServiceClient(te3.clientConn()) + + // Test for unary RPC recv. + req.Payload = smallPayload + req.ResponseSize = proto.Int32(int32(largeSize)) + + if _, err := tc.UnaryCall(context.Background(), req); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want ", err) + } + + req.ResponseSize = proto.Int32(int32(extraLargeSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want ", err) + } + + req.Payload = extraLargePayload + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for streaming RPC recv. + stream, err = tc.FullDuplexCall(te3.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam[0].Size = proto.Int32(int32(largeSize)) + sreq.Payload = smallPayload + + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = _, %v, want ", stream, err) + } + + respParam[0].Size = proto.Int32(int32(extraLargeSize)) + + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) + } + + // Test for streaming RPC send. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te3.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + sreq.Payload = extraLargePayload + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.ResourceExhausted) + } +} + +func TestMaxMsgSizeClientDefault(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testMaxMsgSizeClientDefault(t, e) + } +} + +func testMaxMsgSizeClientDefault(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + // To avoid error on server side. + te.maxServerSendMsgSize = newInt(5 * 1024 * 1024) + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + "Failed to dial : context canceled; please retry.", + ) + te.startServer(&testServer{security: e.security}) + + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const smallSize = 1 + const largeSize = 4 * 1024 * 1024 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeSize)), + Payload: smallPayload, + } + // Test for unary RPC recv. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(largeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + + // Test for streaming RPC recv. + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) + } + + // Test for streaming RPC send. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Send(%v) = %v, want _, error codes: %s", stream, sreq, err, codes.ResourceExhausted) + } +} + +func TestMaxMsgSizeClientAPI(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testMaxMsgSizeClientAPI(t, e) + } +} + +func testMaxMsgSizeClientAPI(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + // To avoid error on server side. + te.maxServerSendMsgSize = newInt(5 * 1024 * 1024) + te.maxClientReceiveMsgSize = newInt(1024) + te.maxClientSendMsgSize = newInt(1024) + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + "Failed to dial : context canceled; please retry.", + ) + te.startServer(&testServer{security: e.security}) + + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const smallSize = 1 + const largeSize = 1024 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeSize)), + Payload: smallPayload, + } + // Test for unary RPC recv. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for unary RPC send. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(largeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + + // Test for streaming RPC recv. + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) + } + + // Test for streaming RPC send. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.ResourceExhausted) + } +} + +func TestMaxMsgSizeServerAPI(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testMaxMsgSizeServerAPI(t, e) + } +} + +func testMaxMsgSizeServerAPI(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.maxServerReceiveMsgSize = newInt(1024) + te.maxServerSendMsgSize = newInt(1024) + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + "Failed to dial : context canceled; please retry.", + ) + te.startServer(&testServer{security: e.security}) + + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const smallSize = 1 + const largeSize = 1024 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeSize)), + Payload: smallPayload, + } + // Test for unary RPC send. + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + // Test for unary RPC recv. + req.Payload = largePayload + req.ResponseSize = proto.Int32(int32(smallSize)) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(largeSize)), + }, + } + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: smallPayload, + } + + // Test for streaming RPC send. + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) + } + + // Test for streaming RPC recv. + respParam[0].Size = proto.Int32(int32(smallSize)) + sreq.Payload = largePayload + stream, err = tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) } - if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.Unavailable { - t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.Unavailable) + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) } } @@ -1486,6 +2122,7 @@ func testLargeUnary(t *testing.T, e env) { } } +// Test backward-compatability API for setting msg size limit. func TestExceedMsgLimit(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -1495,12 +2132,12 @@ func TestExceedMsgLimit(t *testing.T) { func testExceedMsgLimit(t *testing.T, e env) { te := newTest(t, e) - te.maxMsgSize = 1024 + te.maxMsgSize = newInt(1024) te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) - argSize := int32(te.maxMsgSize + 1) + argSize := int32(*te.maxMsgSize + 1) const smallSize = 1 payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) @@ -1512,23 +2149,23 @@ func testExceedMsgLimit(t *testing.T, e env) { t.Fatal(err) } - // test on server side for unary RPC + // Test on server side for unary RPC. req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), ResponseSize: proto.Int32(smallSize), Payload: payload, } - if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) } - // test on client side for unary RPC - req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1) + // Test on client side for unary RPC. + req.ResponseSize = proto.Int32(int32(*te.maxMsgSize) + 1) req.Payload = smallPayload - if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) } - // test on server side for streaming RPC + // Test on server side for streaming RPC. stream, err := tc.FullDuplexCall(te.ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -1539,7 +2176,7 @@ func testExceedMsgLimit(t *testing.T, e env) { }, } - spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1)) + spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(*te.maxMsgSize+1)) if err != nil { t.Fatal(err) } @@ -1552,22 +2189,22 @@ func testExceedMsgLimit(t *testing.T, e env) { if err := stream.Send(sreq); err != nil { t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) } - if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal) + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) } - // test on client side for streaming RPC + // Test on client side for streaming RPC. stream, err = tc.FullDuplexCall(te.ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } - respParam[0].Size = proto.Int32(int32(te.maxMsgSize) + 1) + respParam[0].Size = proto.Int32(int32(*te.maxMsgSize) + 1) sreq.Payload = smallPayload if err := stream.Send(sreq); err != nil { t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) } - if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal) + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted) } }