Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions experimental/experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption {
func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption {
return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool)
}

// AcceptCompressors returns a CallOption that limits the values
// advertised in the grpc-accept-encoding header for the provided RPC. The
// supplied names must correspond to compressors registered via
// encoding.RegisterCompressor. Passing no names advertises "identity" (no
// compression) only.
func AcceptCompressors(names ...string) grpc.CallOption {
return internal.AcceptCompressors.(func(...string) grpc.CallOption)(names...)
}
4 changes: 4 additions & 0 deletions internal/experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ var (
// BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server.
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption

// AcceptCompressors is implemented by the grpc package and returns
// a call option that restricts the grpc-accept-encoding header for a call.
AcceptCompressors any // func(...string) grpc.CallOption
)
3 changes: 3 additions & 0 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
hfLen += len(authData) + len(callAuthData)
registeredCompressors := t.registeredCompressors
if callHdr.AcceptedCompressors != nil {
registeredCompressors = *callHdr.AcceptedCompressors
}
if callHdr.PreviousAttempts > 0 {
hfLen++
}
Expand Down
6 changes: 6 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,12 @@ type CallHdr struct {
// outbound message.
SendCompress string

// AcceptedCompressors overrides the grpc-accept-encoding header for this
// call. When nil, the transport advertises the default set of registered
// compressors. A non-nil pointer overrides that value (including the empty
// string to advertise none).
AcceptedCompressors *string

// Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials

Expand Down
93 changes: 81 additions & 12 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
Expand All @@ -41,6 +43,10 @@ import (
"google.golang.org/grpc/status"
)

func init() {
internal.AcceptCompressors = AcceptCompressors
}

// Compressor defines the interface gRPC uses to compress a message.
//
// Deprecated: use package encoding.
Expand Down Expand Up @@ -151,16 +157,32 @@ func (d *gzipDecompressor) Type() string {

// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
acceptedResponseCompressors []string
}

func acceptedCompressorAllows(allowed []string, name string) bool {
if allowed == nil {
return true
}
if name == "" || name == encoding.Identity {
return true
}
for _, a := range allowed {
if a == name {
return true
}
}
return false
}

func defaultCallInfo() *callInfo {
Expand All @@ -170,6 +192,29 @@ func defaultCallInfo() *callInfo {
}
}

func newAcceptedCompressionConfig(names []string) ([]string, error) {
if len(names) == 0 {
return nil, nil
}
var allowed []string
seen := make(map[string]struct{}, len(names))
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" || name == encoding.Identity {
continue
}
if !grpcutil.IsCompressorNameRegistered(name) {
return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name)
}
if _, dup := seen[name]; dup {
continue
}
seen[name] = struct{}{}
allowed = append(allowed, name)
}
return allowed, nil
}

// CallOption configures a Call before it starts or extracts information from
// a Call after it completes.
type CallOption interface {
Expand Down Expand Up @@ -471,6 +516,31 @@ func (o CompressorCallOption) before(c *callInfo) error {
}
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}

// AcceptCompressors returns a CallOption that limits the compression algorithms
// advertised in the grpc-accept-encoding header for response messages.
// Compression algorithms not in the provided list will not be advertised, and
// responses compressed with non-listed algorithms will be rejected.
func AcceptCompressors(names ...string) CallOption {
cp := append([]string(nil), names...)
return AcceptCompressorsCallOption{names: cp}
}

// AcceptCompressorsCallOption is a CallOption that limits response compression.
type AcceptCompressorsCallOption struct {
names []string
}

func (o AcceptCompressorsCallOption) before(c *callInfo) error {
allowed, err := newAcceptedCompressionConfig(o.names)
if err != nil {
return err
}
c.acceptedResponseCompressors = allowed
return nil
}

func (AcceptCompressorsCallOption) after(*callInfo, *csAttempt) {}

// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted
Expand Down Expand Up @@ -857,8 +927,7 @@ func (p *payloadInfo) free() {
// the buffer is no longer needed.
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (out mem.BufferSlice, err error) {
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
Expand Down
112 changes: 112 additions & 0 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,118 @@ const (
decompressionErrorMsg = "invalid compression format"
)

type testCompressorForRegistry struct {
name string
}

func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) {
return &testWriteCloser{w}, nil
}

func (c *testCompressorForRegistry) Decompress(r io.Reader) (io.Reader, error) {
return r, nil
}

func (c *testCompressorForRegistry) Name() string {
return c.name
}

type testWriteCloser struct {
io.Writer
}

func (w *testWriteCloser) Close() error {
return nil
}

func (s) TestNewAcceptedCompressionConfig(t *testing.T) {
// Register a test compressor for multi-compressor tests
testCompressor := &testCompressorForRegistry{name: "test-compressor"}
encoding.RegisterCompressor(testCompressor)
defer func() {
// Unregister the test compressor
encoding.RegisterCompressor(&testCompressorForRegistry{name: "test-compressor"})
}()

tests := []struct {
name string
input []string
wantAllowed []string
wantErr bool
}{
{
name: "identity-only",
input: nil,
wantAllowed: nil,
},
{
name: "single valid",
input: []string{"gzip"},
wantAllowed: []string{"gzip"},
},
{
name: "dedupe and trim",
input: []string{" gzip ", "gzip"},
wantAllowed: []string{"gzip"},
},
{
name: "ignores identity",
input: []string{"identity", "gzip"},
wantAllowed: []string{"gzip"},
},
{
name: "explicit identity only",
input: []string{"identity"},
wantAllowed: nil,
},
{
name: "invalid compressor",
input: []string{"does-not-exist"},
wantErr: true,
},
{
name: "only whitespace",
input: []string{" ", "\t"},
wantAllowed: nil,
},
{
name: "multiple valid compressors",
input: []string{"gzip", "test-compressor"},
wantAllowed: []string{"gzip", "test-compressor"},
},
{
name: "multiple with identity and whitespace",
input: []string{"gzip", "identity", " test-compressor ", " "},
wantAllowed: []string{"gzip", "test-compressor"},
},
{
name: "empty string in list",
input: []string{"gzip", "", "test-compressor"},
wantAllowed: []string{"gzip", "test-compressor"},
},
{
name: "mixed valid and invalid",
input: []string{"gzip", "invalid-comp"},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
allowed, err := newAcceptedCompressionConfig(tt.input)
if (err != nil) != tt.wantErr {
t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
if tt.wantErr {
return
}
if diff := cmp.Diff(tt.wantAllowed, allowed); diff != "" {
t.Fatalf("allowed diff (-want +got): %v", diff)
}
})
}
}

type fullReader struct {
data []byte
}
Expand Down
13 changes: 13 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"math"
rand "math/rand/v2"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -301,6 +302,10 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
DoneFunc: doneFunc,
Authority: callInfo.authority,
}
if allowed := callInfo.acceptedResponseCompressors; len(allowed) > 0 {
headerValue := strings.Join(allowed, ",")
callHdr.AcceptedCompressors = &headerValue
}

// Set our outgoing compression according to the UseCompressor CallOption, if
// set. In that case, also find the compressor from the encoding package.
Expand Down Expand Up @@ -1134,6 +1139,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
a.decompressorV0 = nil
a.decompressorV1 = encoding.GetCompressor(ct)
}
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(cs.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else {
// No compression is used; disable our decompressor.
a.decompressorV0 = nil
Expand Down Expand Up @@ -1479,6 +1488,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
as.decompressorV0 = nil
as.decompressorV1 = encoding.GetCompressor(ct)
}
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(as.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else {
// No compression is used; disable our decompressor.
as.decompressorV0 = nil
Expand Down
Loading