Skip to content

Commit 3cf9054

Browse files
Fixed the review changes
1 parent 26bf731 commit 3cf9054

File tree

2 files changed

+26
-28
lines changed

2 files changed

+26
-28
lines changed

rpc_util.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -845,17 +845,16 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
845845
}
846846
} else {
847847
out, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
848+
if err == errMaxMessageSizeExceeded {
849+
out.Free()
850+
// TODO: Revisit the error code. Currently keep it consistent with java
851+
// implementation.
852+
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
853+
}
848854
}
849855
if err != nil {
850856
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
851857
}
852-
853-
if err == errMaxMessageSizeExceeded {
854-
out.Free()
855-
// TODO: Revisit the error code. Currently keep it consistent with java
856-
// implementation.
857-
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
858-
}
859858
} else {
860859
out = compressed
861860
}

rpc_util_test.go

+20-21
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ package grpc
2121
import (
2222
"bytes"
2323
"compress/gzip"
24-
"errors"
2524
"io"
2625
"math"
2726
"reflect"
2827
"testing"
2928

3029
"github.com/google/go-cmp/cmp"
30+
"github.com/google/go-cmp/cmp/cmpopts"
3131
"google.golang.org/grpc/codes"
3232
"google.golang.org/grpc/encoding"
3333
_ "google.golang.org/grpc/encoding/gzip"
@@ -300,27 +300,33 @@ func BenchmarkGZIPCompressor1MiB(b *testing.B) {
300300
}
301301

302302
// compressData compresses data using gzip and returns the compressed bytes.
303-
func compressData(data []byte) []byte {
303+
// It now accepts *testing.T to handle errors during compression.
304+
func compressData(t *testing.T, data []byte) []byte {
304305
var buf bytes.Buffer
305306
gz := gzip.NewWriter(&buf)
306-
_, _ = gz.Write(data)
307-
_ = gz.Close()
307+
if _, err := gz.Write(data); err != nil {
308+
t.Fatalf("compressData() failed to write data: %v", err)
309+
}
310+
311+
if err := gz.Close(); err != nil {
312+
t.Fatalf("compressData() failed to close gzip writer: %v", err)
313+
}
308314
return buf.Bytes()
309315
}
310316

317+
// compressInput compresses input data and returns a BufferSlice.
318+
func compressInput(input []byte) mem.BufferSlice {
319+
compressedData := compressData(nil, input)
320+
return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
321+
}
322+
311323
// TestDecompress tests the decompress function with various scenarios, including
312324
// successful decompression, error handling, and edge cases like overflow or
313325
// premature data end. It ensures that the function behaves correctly with different
314326
// inputs, buffer sizes, and error conditions, using the "gzip" compressor for testing.
315-
316327
func TestDecompress(t *testing.T) {
317328
c := encoding.GetCompressor("gzip")
318329

319-
compressInput := func(input []byte) mem.BufferSlice {
320-
compressedData := compressData(input)
321-
return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
322-
}
323-
324330
tests := []struct {
325331
name string
326332
compressor encoding.Compressor
@@ -354,10 +360,10 @@ func TestDecompress(t *testing.T) {
354360
wantErr: nil,
355361
},
356362
{
357-
name: "Handles maxReceiveMessageSize as MaxInt64",
363+
name: "Handles maxReceiveMessageSize as MaxInt",
358364
compressor: c,
359365
input: []byte("small message"),
360-
maxReceiveMessageSize: math.MaxInt64,
366+
maxReceiveMessageSize: math.MaxInt,
361367
want: []byte("small message"),
362368
wantErr: nil,
363369
},
@@ -368,15 +374,8 @@ func TestDecompress(t *testing.T) {
368374
compressedMsg := compressInput(tt.input)
369375
output, err := decompress(tt.compressor, compressedMsg, tt.maxReceiveMessageSize, mem.DefaultBufferPool())
370376

371-
if tt.wantErr != nil {
372-
if !errors.Is(err, tt.wantErr) {
373-
t.Fatalf("decompress() error = %v, wantErr = %v", err, tt.wantErr)
374-
}
375-
return
376-
}
377-
378-
if err != nil {
379-
t.Fatalf("decompress() unexpected error = %v", err)
377+
if !cmp.Equal(err, tt.wantErr, cmpopts.EquateErrors()) {
378+
t.Fatalf("decompress() error = %v, wantErr = %v", err, tt.wantErr)
380379
}
381380

382381
if diff := cmp.Diff(tt.want, output.Materialize()); diff != "" {

0 commit comments

Comments
 (0)