Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: while using gzip and call with MaxCallRecvMsgSize(math.MaxInt64) #2

Closed
wants to merge 11 commits into from
38 changes: 20 additions & 18 deletions examples/helloworld/greeter_client/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright 2015 gRPC authors.
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,45 +16,47 @@
*
*/

// Package main implements a client for Greeter service.
// Binary client is an example client.
package main

import (
"context"
"flag"
"fmt"
"log"
"math"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
pb "google.golang.org/grpc/examples/helloworld/helloworld"
"google.golang.org/grpc/encoding/gzip" // Install the gzip compressor
pb "google.golang.org/grpc/examples/features/proto/echo"
)

const (
defaultName = "world"
)

var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
name = flag.String("name", defaultName, "Name to greet")
)
var addr = flag.String("addr", "localhost:50051", "the address to connect to")

func main() {
flag.Parse()

// Set up a connection to the server.
conn, err := grpc.NewClient(*addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
log.Fatalf("did not connect: %v", err)
}
defer conn.Close()
c := pb.NewGreeterClient(conn)

// Contact the server and print out its response.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
c := pb.NewEchoClient(conn)

// Send the RPC compressed. If all RPCs on a client should be sent this
// way, use the DialOption:
// grpc.WithDefaultCallOptions(grpc.UseCompressor(gzip.Name))
const msg = "compress"
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
r, err := c.SayHello(ctx, &pb.HelloRequest{Name: *name})
if err != nil {
log.Fatalf("could not greet: %v", err)
res, err := c.UnaryEcho(ctx, &pb.EchoRequest{Message: msg}, grpc.UseCompressor(gzip.Name), grpc.MaxCallRecvMsgSize(math.MaxInt64))
fmt.Printf("UnaryEcho call returned %q, %v\n", res.GetMessage(), err)
if err != nil || res.GetMessage() != msg {
log.Fatalf("Message=%q, err=%v; want Message=%q, err=<nil>", res.GetMessage(), err, msg)
}
log.Printf("Greeting: %s", r.GetMessage())

}
32 changes: 15 additions & 17 deletions examples/helloworld/greeter_server/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright 2015 gRPC authors.
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,7 @@
*
*/

// Package main implements a server for Greeter service.
// Binary server is an example server.
package main

import (
Expand All @@ -27,34 +27,32 @@ import (
"net"

"google.golang.org/grpc"
pb "google.golang.org/grpc/examples/helloworld/helloworld"
)
_ "google.golang.org/grpc/encoding/gzip" // Install the gzip compressor

var (
port = flag.Int("port", 50051, "The server port")
pb "google.golang.org/grpc/examples/features/proto/echo"
)

// server is used to implement helloworld.GreeterServer.
var port = flag.Int("port", 50051, "the port to serve on")

type server struct {
pb.UnimplementedGreeterServer
pb.UnimplementedEchoServer
}

// SayHello implements helloworld.GreeterServer
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
log.Printf("Received: %v", in.GetName())
return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil
func (s *server) UnaryEcho(ctx context.Context, in *pb.EchoRequest) (*pb.EchoResponse, error) {
fmt.Printf("UnaryEcho called with message %q\n", in.GetMessage())
return &pb.EchoResponse{Message: in.Message}, nil
}

func main() {
flag.Parse()

lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
fmt.Printf("server listening at %v\n", lis.Addr())

s := grpc.NewServer()
pb.RegisterGreeterServer(s, &server{})
log.Printf("server listening at %v", lis.Addr())
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
pb.RegisterEchoServer(s, &server{})
s.Serve(lis)
}
50 changes: 38 additions & 12 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"compress/gzip"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -809,25 +810,50 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
}); ok {
if size := sizer.DecompressedSize(d); size >= 0 {
if size > maxReceiveMessageSize {
return nil, size, nil
return nil, size, errors.New("message size exceeds maximum allowed")
}
bufferSize := uint64(size) + bytes.MinRead
if bufferSize > math.MaxInt {
bufferSize = math.MaxInt
}
buf := bytes.NewBuffer(make([]byte, 0, int(bufferSize)))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)))
if err != nil {
return nil, int(bytesRead), err
}
if err = checkReceiveMessageOverflow(bytesRead, int64(maxReceiveMessageSize), dcReader); err != nil {
return nil, size + 1, err
}
// size is used as an estimate to size the buffer, but we
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
//
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
// we can also utilize the recv buffer pool here.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
}
}
// Read from LimitReader with limit max+1. So if the underlying
// reader is over limit, the result will be bigger than max.
d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)))
if err != nil {
return nil, len(d), err
}
if err = checkReceiveMessageOverflow(int64(len(d)), int64(maxReceiveMessageSize), dcReader); err != nil {
return nil, len(d) + 1, err
}
return d, len(d), err
}

// checkReceiveMessageOverflow checks if the number of bytes read from the stream exceeds
// the maximum receive message size allowed by the client. If the `readBytes` equals
// `maxReceiveMessageSize`, the function attempts to read one more byte from the `dcReader`
// to detect if there's an overflow.
//
// If additional data is read, or an error other than `io.EOF` is encountered, the function
// returns an error indicating that the message size has exceeded the permissible limit.
func checkReceiveMessageOverflow(readBytes, maxReceiveMessageSize int64, dcReader io.Reader) error {
if readBytes == maxReceiveMessageSize {
b := make([]byte, 1)
if n, err := dcReader.Read(b); n > 0 || err != io.EOF {
return fmt.Errorf("overflow: message larger than max size receivable by client (%d bytes)", maxReceiveMessageSize)
}
}
return nil
}

// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
Expand Down
131 changes: 131 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package grpc
import (
"bytes"
"compress/gzip"
"errors"
"io"
"math"
"reflect"
Expand Down Expand Up @@ -266,3 +267,133 @@ func BenchmarkGZIPCompressor512KiB(b *testing.B) {
func BenchmarkGZIPCompressor1MiB(b *testing.B) {
bmCompressor(b, 1024*1024, NewGZIPCompressor())
}
func TestCheckReceiveMessageOverflow(t *testing.T) {
tests := []struct {
name string
readBytes int64
maxReceiveMessageSize int64
dcReader io.Reader
wantErr error
}{
{
name: "No overflow",
readBytes: 5,
maxReceiveMessageSize: 10,
dcReader: bytes.NewReader([]byte{}),
wantErr: nil,
},
{
name: "Overflow with additional data",
readBytes: 10,
maxReceiveMessageSize: 10,
dcReader: bytes.NewReader([]byte{1}),
wantErr: errors.New("overflow: message larger than max size receivable by client (10 bytes)"),
},
{
name: "No overflow with EOF",
maxReceiveMessageSize: 10,
dcReader: bytes.NewReader([]byte{}),
wantErr: nil,
},
{
name: "Overflow condition with error handling",
readBytes: 15,
maxReceiveMessageSize: 15,
dcReader: bytes.NewReader([]byte{1, 2, 3}),
wantErr: errors.New("overflow: message larger than max size receivable by client (15 bytes)"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkReceiveMessageOverflow(tt.readBytes, tt.maxReceiveMessageSize, tt.dcReader)
if (err != nil) != (tt.wantErr != nil) {
t.Errorf("unexpected error state: got err=%v, want err=%v", err, tt.wantErr)
} else if err != nil && err.Error() != tt.wantErr.Error() {
t.Errorf("unexpected error message: got err=%v, want err=%v", err, tt.wantErr)
}

})
}
}

type testCompressor struct {
triggerDecompressError bool
}

func (c *testCompressor) Name() string {
// Return a name for the compressor.
return "testCompressor"
}

func (c *testCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
return nil, errors.New("Compress not implemented")
}

func (c *testCompressor) Decompress(r io.Reader) (io.Reader, error) {
if c.triggerDecompressError {
return nil, errors.New("decompression failed")
}
return r, nil
}

func (c *testCompressor) DecompressedSize(compressedBytes []byte) int {
return len(compressedBytes) * 2 // Assume decompressed size is double for testing
}

// TestDecompress tests the decompress function.
func TestDecompress(t *testing.T) {
tests := []struct {
name string
compressor *testCompressor
input []byte
maxReceiveMessageSize int
wantOutput []byte
wantSize int
wantErr bool
}{
{
name: "Successful decompression",
compressor: &testCompressor{},
input: []byte{0x01, 0x02, 0x03, 0x04},
maxReceiveMessageSize: 10,
wantOutput: []byte{0x01, 0x02, 0x03, 0x04},
wantSize: 4,
wantErr: false,
},
{
name: "Message size overflow",
compressor: &testCompressor{},
input: []byte{0x01, 0x02, 0x03, 0x04},
maxReceiveMessageSize: 2,
wantOutput: nil,
wantSize: 8,
wantErr: true,
},
{
name: "Error during decompression",
compressor: &testCompressor{triggerDecompressError: true},
input: []byte{0x01, 0x02, 0x03, 0x04},
maxReceiveMessageSize: 10,
wantOutput: nil,
wantSize: 0,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output, size, err := decompress(tt.compressor, tt.input, tt.maxReceiveMessageSize)

if (err != nil) != tt.wantErr {
t.Errorf("decompress() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !bytes.Equal(output, tt.wantOutput) {
t.Errorf("decompress() got = %v, want %v", output, tt.wantOutput)
}
if size != tt.wantSize {
t.Errorf("decompress() size = %d, want %d", size, tt.wantSize)
}
})
}
}