From 41afa2e58ca5eb1c3a9074b976d1e9dd8a87658e Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Wed, 18 Aug 2021 14:21:58 -0700 Subject: [PATCH] encoding/thrift/nowire: Only support RequestReader Since this is a completely green path, we don't need to make stream.RequestReader an optional upgrade. --- encoding/thrift/inbound_nowire.go | 92 +---------- encoding/thrift/inbound_nowire_test.go | 204 ++++--------------------- 2 files changed, 33 insertions(+), 263 deletions(-) diff --git a/encoding/thrift/inbound_nowire.go b/encoding/thrift/inbound_nowire.go index a2e852d1c5..f392bce2d9 100644 --- a/encoding/thrift/inbound_nowire.go +++ b/encoding/thrift/inbound_nowire.go @@ -24,7 +24,6 @@ import ( "context" "io" - "go.uber.org/thriftrw/protocol/binary" "go.uber.org/thriftrw/protocol/stream" "go.uber.org/thriftrw/wire" encodingapi "go.uber.org/yarpc/api/encoding" @@ -62,8 +61,7 @@ type NoWireHandler interface { // request, given the ThriftRW primitives for reading the raw representation. type thriftNoWireHandler struct { NoWireHandler NoWireHandler - Protocol stream.Protocol - Enveloping bool + RequestReader stream.RequestReader } var ( @@ -137,92 +135,10 @@ func (t thriftNoWireHandler) decodeAndHandle( } nwc := NoWireCall{ - Reader: treq.Body, - EnvelopeType: reqEnvelopeType, - } - - // If the underlying nowire transport supports envelope-agnostic - // request decoding, use that. Otherwise, build a RequestReader based - // on envelope configuration. - if reqReader, ok := t.Protocol.(stream.RequestReader); ok { - nwc.RequestReader = reqReader - } else if t.Enveloping { - nwc.RequestReader = &envelopeRequestReader{ - proto: t.Protocol, - treq: treq, - } - } else { - nwc.RequestReader = &noEnvelopeReader{ - proto: t.Protocol, - treq: treq, - } + Reader: treq.Body, + EnvelopeType: reqEnvelopeType, + RequestReader: t.RequestReader, } return t.NoWireHandler.Handle(ctx, &nwc) } - -// envelopeRequestReader implements ThriftRW's stream.RequestReader, decoding -// requests with enveloping support. -type envelopeRequestReader struct { - proto stream.Protocol - treq *transport.Request -} - -var _ stream.RequestReader = (*envelopeRequestReader)(nil) - -func (p *envelopeRequestReader) ReadRequest( - ctx context.Context, - et wire.EnvelopeType, - r io.Reader, - body stream.BodyReader, -) (stream.ResponseWriter, error) { - sr := p.proto.Reader(r) - defer sr.Close() - - eh, err := sr.ReadEnvelopeBegin() - if err != nil { - return nil, errors.RequestBodyDecodeError(p.treq, err) - } - - if eh.Type != et { - return nil, errors.RequestBodyDecodeError(p.treq, errUnexpectedEnvelopeType(eh.Type)) - } - - if err := body.Decode(sr); err != nil { - return nil, errors.RequestBodyDecodeError(p.treq, err) - } - - if err := sr.ReadEnvelopeEnd(); err != nil { - return nil, errors.RequestBodyDecodeError(p.treq, err) - } - - return binary.EnvelopeV1Responder{ - Name: eh.Name, - SeqID: eh.SeqID, - }, nil -} - -// noEnvelopeReader implements ThriftRW's stream.RequestReader, decoding -// requests without enveloping support. -type noEnvelopeReader struct { - proto stream.Protocol - treq *transport.Request -} - -var _ stream.RequestReader = (*noEnvelopeReader)(nil) - -func (p *noEnvelopeReader) ReadRequest( - ctx context.Context, - et wire.EnvelopeType, - r io.Reader, - body stream.BodyReader, -) (stream.ResponseWriter, error) { - sr := p.proto.Reader(r) - defer sr.Close() - - if err := body.Decode(sr); err != nil { - return nil, errors.RequestBodyDecodeError(p.treq, err) - } - - return binary.NoEnvelopeResponder, nil -} diff --git a/encoding/thrift/inbound_nowire_test.go b/encoding/thrift/inbound_nowire_test.go index 1d06d80204..50e111239b 100644 --- a/encoding/thrift/inbound_nowire_test.go +++ b/encoding/thrift/inbound_nowire_test.go @@ -72,8 +72,9 @@ type responseHandler struct { t *testing.T nwc *NoWireCall - reqBody stream.BodyReader - body stream.Enveloper + reqBody stream.BodyReader + body stream.Enveloper + appError bool } var _ NoWireHandler = (*responseHandler)(nil) @@ -90,8 +91,9 @@ func (rh *responseHandler) Handle(ctx context.Context, nwc *NoWireCall) (NoWireR rh.nwc = nwc rw, err := nwc.RequestReader.ReadRequest(ctx, nwc.EnvelopeType, nwc.Reader, rh.reqBody) return NoWireResponse{ - Body: rh.body, - ResponseWriter: rw, + Body: rh.body, + ResponseWriter: rw, + IsApplicationError: rh.appError, }, err } @@ -104,7 +106,7 @@ func TestDecodeNoWireRequestUnary(t *testing.T) { proto := binary.Default h := thriftNoWireHandler{ NoWireHandler: &rh, - Protocol: proto, + RequestReader: proto, } ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) @@ -127,7 +129,7 @@ func TestDecodeNoWireRequestOneway(t *testing.T) { proto := binary.Default h := thriftNoWireHandler{ NoWireHandler: &rh, - Protocol: proto, + RequestReader: proto, } ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) @@ -140,76 +142,6 @@ func TestDecodeNoWireRequestOneway(t *testing.T) { assert.Equal(t, wire.OneWay, rh.nwc.EnvelopeType) // OneWay call } -func TestDecodeNoWireRequestEnveloping(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - sr := streamtest.NewMockReader(mockCtrl) - sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.Call}, nil) - sr.EXPECT().ReadEnvelopeEnd().Return(nil) - sr.EXPECT().Close().Return(nil) - - proto := streamtest.NewMockProtocol(mockCtrl) - proto.EXPECT().Reader(gomock.Any()).Return(sr) - - br := &bodyReader{} - rh := responseHandler{ - t: t, - reqBody: br, - body: responseEnveloper{name: "caller", envelopeType: wire.Reply}, - } - h := thriftNoWireHandler{ - NoWireHandler: &rh, - Protocol: proto, - Enveloping: true, - } - - ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) - defer cancel() - - rw := new(transporttest.FakeResponseWriter) - require.NoError(t, h.Handle(ctx, request(), rw)) - assert.Equal(t, _body, br.body, "request body expected to be decoded") - - rrp, ok := rh.nwc.RequestReader.(*envelopeRequestReader) - require.True(t, ok) - assert.Equal(t, proto, rrp.proto) -} - -func TestDecodeNoWireRequestEnvelopingFalse(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - sr := streamtest.NewMockReader(mockCtrl) - sr.EXPECT().Close().Return(nil) - - proto := streamtest.NewMockProtocol(mockCtrl) - proto.EXPECT().Reader(gomock.Any()).Return(sr) - - br := &bodyReader{} - rh := responseHandler{ - t: t, - reqBody: br, - body: responseEnveloper{name: "caller", envelopeType: wire.Reply}, - } - h := thriftNoWireHandler{ - NoWireHandler: &rh, - Protocol: proto, - Enveloping: false, - } - - ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) - defer cancel() - - rw := new(transporttest.FakeResponseWriter) - require.NoError(t, h.Handle(ctx, request(), rw)) - assert.Equal(t, _body, br.body, "request body expected to be decoded") - - rrp, ok := rh.nwc.RequestReader.(*noEnvelopeReader) - require.True(t, ok) - assert.Equal(t, proto, rrp.proto) -} - func TestNoWireHandleIncorrectResponseEnvelope(t *testing.T) { br := &bodyReader{} rh := responseHandler{ @@ -220,7 +152,7 @@ func TestNoWireHandleIncorrectResponseEnvelope(t *testing.T) { proto := binary.Default h := thriftNoWireHandler{ NoWireHandler: &rh, - Protocol: proto, + RequestReader: proto, } ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) @@ -238,14 +170,15 @@ func TestNoWireHandleWriteResponseError(t *testing.T) { defer mockCtrl.Finish() proto := streamtest.NewMockRequestReader(mockCtrl) - proto.EXPECT().ReadRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(responseWriter{err: fmt.Errorf("write response error")}, nil) + proto.EXPECT().ReadRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(responseWriter{err: fmt.Errorf("write response error")}, nil) re := responseEnveloper{name: "caller", envelopeType: wire.Reply} br := &bodyReader{} rh := responseHandler{t: t, reqBody: br, body: re} h := thriftNoWireHandler{ NoWireHandler: &rh, - Protocol: proto, + RequestReader: proto, } ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) @@ -262,7 +195,7 @@ func TestDecodeNoWireRequestExpectEncodingsError(t *testing.T) { re := responseEnveloper{name: "caller", envelopeType: wire.Reply} h := thriftNoWireHandler{ NoWireHandler: &responseHandler{t: t, body: re}, - Protocol: binary.Default, + RequestReader: binary.Default, } ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) @@ -277,104 +210,25 @@ func TestDecodeNoWireRequestExpectEncodingsError(t *testing.T) { assert.Contains(t, err.Error(), `expected encoding "thrift" but got "grpc"`) } -func TestReqReaderEnvelopingEnvelopeBeginError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - sr := streamtest.NewMockReader(mockCtrl) - sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{}, fmt.Errorf("read envelope begin error")) - sr.EXPECT().Close().Return(nil) - - proto := streamtest.NewMockProtocol(mockCtrl) - proto.EXPECT().Reader(gomock.Any()).Return(sr) - - _, err := testEnvelopedReadRequest(t, proto, &bodyReader{}, true /* enveloping */) - require.Error(t, err) - assert.Contains(t, err.Error(), "read envelope begin error") -} - -func TestReqReaderEnvelopingBadEnvelopeType(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - sr := streamtest.NewMockReader(mockCtrl) - sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.Exception}, nil) - sr.EXPECT().Close().Return(nil) - - proto := streamtest.NewMockProtocol(mockCtrl) - proto.EXPECT().Reader(gomock.Any()).Return(sr) - - _, err := testEnvelopedReadRequest(t, proto, &bodyReader{}, true /* enveloping */) - require.Error(t, err) - assert.Contains(t, err.Error(), "unexpected envelope type") -} - -func TestReqReaderEnvelopingDecodeError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - sr := streamtest.NewMockReader(mockCtrl) - sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.OneWay}, nil) - sr.EXPECT().Close().Return(nil) - - proto := streamtest.NewMockProtocol(mockCtrl) - proto.EXPECT().Reader(gomock.Any()).Return(sr) - - _, err := testEnvelopedReadRequest(t, proto, &bodyReader{err: fmt.Errorf("decode error")}, true /* enveloping */) - require.Error(t, err) - assert.Contains(t, err.Error(), "decode error") -} - -func TestReqReaderEnvelopingEnvelopeEndError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - sr := streamtest.NewMockReader(mockCtrl) - sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.OneWay}, nil) - sr.EXPECT().ReadEnvelopeEnd().Return(fmt.Errorf("read envelope end error")) - sr.EXPECT().Close().Return(nil) - - proto := streamtest.NewMockProtocol(mockCtrl) - proto.EXPECT().Reader(gomock.Any()).Return(sr) - - _, err := testEnvelopedReadRequest(t, proto, &bodyReader{}, true /* enveloping */) - require.Error(t, err) - assert.Contains(t, err.Error(), "read envelope end error") -} - -func TestReqReaderNotEnvelopingDecodeError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - sr := streamtest.NewMockReader(mockCtrl) - sr.EXPECT().Close().Return(nil) - - proto := streamtest.NewMockProtocol(mockCtrl) - proto.EXPECT().Reader(gomock.Any()).Return(sr) - - _, err := testEnvelopedReadRequest(t, proto, &bodyReader{err: fmt.Errorf("another decode error")}, false /* enveloping */) - require.Error(t, err) - assert.Contains(t, err.Error(), "another decode error") -} - -func testEnvelopedReadRequest( - t *testing.T, - proto stream.Protocol, - body stream.BodyReader, - enveloping bool, -) (stream.ResponseWriter, error) { - t.Helper() - - req := request() - var rrp stream.RequestReader - if enveloping { - rrp = &envelopeRequestReader{proto: proto, treq: req} - } else { - rrp = &noEnvelopeReader{proto: proto, treq: req} +func TestDecodeNoWireAppliationError(t *testing.T) { + br := &bodyReader{} + re := responseEnveloper{name: "caller", envelopeType: wire.Reply} + h := thriftNoWireHandler{ + NoWireHandler: &responseHandler{ + t: t, + reqBody: br, + body: re, + appError: true, + }, + RequestReader: binary.Default, } ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) defer cancel() - return rrp.ReadRequest(ctx, wire.OneWay, req.Body, body) + req := request() + + rw := new(transporttest.FakeResponseWriter) + require.NoError(t, h.Handle(ctx, req, rw)) + assert.True(t, rw.IsApplicationError) }