Skip to content

Commit

Permalink
Merge branch 'wit/inbounds_nowire_handler' into wit/inbounds_nowire_o…
Browse files Browse the repository at this point in the history
…ptions
  • Loading branch information
abhinav committed Aug 18, 2021
2 parents 96a3051 + 41afa2e commit 4388d13
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 263 deletions.
92 changes: 4 additions & 88 deletions encoding/thrift/inbound_nowire.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"context"
"io"

"go.uber.org/thriftrw/protocol/binary"
"go.uber.org/thriftrw/protocol/stream"
"go.uber.org/thriftrw/wire"
encodingapi "go.uber.org/yarpc/api/encoding"
Expand Down Expand Up @@ -62,8 +61,7 @@ type NoWireHandler interface {
// request, given the ThriftRW primitives for reading the raw representation.
type thriftNoWireHandler struct {
NoWireHandler NoWireHandler
Protocol stream.Protocol
Enveloping bool
RequestReader stream.RequestReader
}

var (
Expand Down Expand Up @@ -137,92 +135,10 @@ func (t thriftNoWireHandler) decodeAndHandle(
}

nwc := NoWireCall{
Reader: treq.Body,
EnvelopeType: reqEnvelopeType,
}

// If the underlying nowire transport supports envelope-agnostic
// request decoding, use that. Otherwise, build a RequestReader based
// on envelope configuration.
if reqReader, ok := t.Protocol.(stream.RequestReader); ok {
nwc.RequestReader = reqReader
} else if t.Enveloping {
nwc.RequestReader = &envelopeRequestReader{
proto: t.Protocol,
treq: treq,
}
} else {
nwc.RequestReader = &noEnvelopeReader{
proto: t.Protocol,
treq: treq,
}
Reader: treq.Body,
EnvelopeType: reqEnvelopeType,
RequestReader: t.RequestReader,
}

return t.NoWireHandler.Handle(ctx, &nwc)
}

// envelopeRequestReader implements ThriftRW's stream.RequestReader, decoding
// requests with enveloping support.
type envelopeRequestReader struct {
proto stream.Protocol
treq *transport.Request
}

var _ stream.RequestReader = (*envelopeRequestReader)(nil)

func (p *envelopeRequestReader) ReadRequest(
ctx context.Context,
et wire.EnvelopeType,
r io.Reader,
body stream.BodyReader,
) (stream.ResponseWriter, error) {
sr := p.proto.Reader(r)
defer sr.Close()

eh, err := sr.ReadEnvelopeBegin()
if err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

if eh.Type != et {
return nil, errors.RequestBodyDecodeError(p.treq, errUnexpectedEnvelopeType(eh.Type))
}

if err := body.Decode(sr); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

if err := sr.ReadEnvelopeEnd(); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

return binary.EnvelopeV1Responder{
Name: eh.Name,
SeqID: eh.SeqID,
}, nil
}

// noEnvelopeReader implements ThriftRW's stream.RequestReader, decoding
// requests without enveloping support.
type noEnvelopeReader struct {
proto stream.Protocol
treq *transport.Request
}

var _ stream.RequestReader = (*noEnvelopeReader)(nil)

func (p *noEnvelopeReader) ReadRequest(
ctx context.Context,
et wire.EnvelopeType,
r io.Reader,
body stream.BodyReader,
) (stream.ResponseWriter, error) {
sr := p.proto.Reader(r)
defer sr.Close()

if err := body.Decode(sr); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

return binary.NoEnvelopeResponder, nil
}
204 changes: 29 additions & 175 deletions encoding/thrift/inbound_nowire_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ type responseHandler struct {
t *testing.T
nwc *NoWireCall

reqBody stream.BodyReader
body stream.Enveloper
reqBody stream.BodyReader
body stream.Enveloper
appError bool
}

var _ NoWireHandler = (*responseHandler)(nil)
Expand All @@ -90,8 +91,9 @@ func (rh *responseHandler) Handle(ctx context.Context, nwc *NoWireCall) (NoWireR
rh.nwc = nwc
rw, err := nwc.RequestReader.ReadRequest(ctx, nwc.EnvelopeType, nwc.Reader, rh.reqBody)
return NoWireResponse{
Body: rh.body,
ResponseWriter: rw,
Body: rh.body,
ResponseWriter: rw,
IsApplicationError: rh.appError,
}, err
}

Expand All @@ -104,7 +106,7 @@ func TestDecodeNoWireRequestUnary(t *testing.T) {
proto := binary.Default
h := thriftNoWireHandler{
NoWireHandler: &rh,
Protocol: proto,
RequestReader: proto,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
Expand All @@ -127,7 +129,7 @@ func TestDecodeNoWireRequestOneway(t *testing.T) {
proto := binary.Default
h := thriftNoWireHandler{
NoWireHandler: &rh,
Protocol: proto,
RequestReader: proto,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
Expand All @@ -140,76 +142,6 @@ func TestDecodeNoWireRequestOneway(t *testing.T) {
assert.Equal(t, wire.OneWay, rh.nwc.EnvelopeType) // OneWay call
}

func TestDecodeNoWireRequestEnveloping(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

sr := streamtest.NewMockReader(mockCtrl)
sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.Call}, nil)
sr.EXPECT().ReadEnvelopeEnd().Return(nil)
sr.EXPECT().Close().Return(nil)

proto := streamtest.NewMockProtocol(mockCtrl)
proto.EXPECT().Reader(gomock.Any()).Return(sr)

br := &bodyReader{}
rh := responseHandler{
t: t,
reqBody: br,
body: responseEnveloper{name: "caller", envelopeType: wire.Reply},
}
h := thriftNoWireHandler{
NoWireHandler: &rh,
Protocol: proto,
Enveloping: true,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
defer cancel()

rw := new(transporttest.FakeResponseWriter)
require.NoError(t, h.Handle(ctx, request(), rw))
assert.Equal(t, _body, br.body, "request body expected to be decoded")

rrp, ok := rh.nwc.RequestReader.(*envelopeRequestReader)
require.True(t, ok)
assert.Equal(t, proto, rrp.proto)
}

func TestDecodeNoWireRequestEnvelopingFalse(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

sr := streamtest.NewMockReader(mockCtrl)
sr.EXPECT().Close().Return(nil)

proto := streamtest.NewMockProtocol(mockCtrl)
proto.EXPECT().Reader(gomock.Any()).Return(sr)

br := &bodyReader{}
rh := responseHandler{
t: t,
reqBody: br,
body: responseEnveloper{name: "caller", envelopeType: wire.Reply},
}
h := thriftNoWireHandler{
NoWireHandler: &rh,
Protocol: proto,
Enveloping: false,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
defer cancel()

rw := new(transporttest.FakeResponseWriter)
require.NoError(t, h.Handle(ctx, request(), rw))
assert.Equal(t, _body, br.body, "request body expected to be decoded")

rrp, ok := rh.nwc.RequestReader.(*noEnvelopeReader)
require.True(t, ok)
assert.Equal(t, proto, rrp.proto)
}

func TestNoWireHandleIncorrectResponseEnvelope(t *testing.T) {
br := &bodyReader{}
rh := responseHandler{
Expand All @@ -220,7 +152,7 @@ func TestNoWireHandleIncorrectResponseEnvelope(t *testing.T) {
proto := binary.Default
h := thriftNoWireHandler{
NoWireHandler: &rh,
Protocol: proto,
RequestReader: proto,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
Expand All @@ -238,14 +170,15 @@ func TestNoWireHandleWriteResponseError(t *testing.T) {
defer mockCtrl.Finish()

proto := streamtest.NewMockRequestReader(mockCtrl)
proto.EXPECT().ReadRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(responseWriter{err: fmt.Errorf("write response error")}, nil)
proto.EXPECT().ReadRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(responseWriter{err: fmt.Errorf("write response error")}, nil)

re := responseEnveloper{name: "caller", envelopeType: wire.Reply}
br := &bodyReader{}
rh := responseHandler{t: t, reqBody: br, body: re}
h := thriftNoWireHandler{
NoWireHandler: &rh,
Protocol: proto,
RequestReader: proto,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
Expand All @@ -262,7 +195,7 @@ func TestDecodeNoWireRequestExpectEncodingsError(t *testing.T) {
re := responseEnveloper{name: "caller", envelopeType: wire.Reply}
h := thriftNoWireHandler{
NoWireHandler: &responseHandler{t: t, body: re},
Protocol: binary.Default,
RequestReader: binary.Default,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
Expand All @@ -277,104 +210,25 @@ func TestDecodeNoWireRequestExpectEncodingsError(t *testing.T) {
assert.Contains(t, err.Error(), `expected encoding "thrift" but got "grpc"`)
}

func TestReqReaderEnvelopingEnvelopeBeginError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

sr := streamtest.NewMockReader(mockCtrl)
sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{}, fmt.Errorf("read envelope begin error"))
sr.EXPECT().Close().Return(nil)

proto := streamtest.NewMockProtocol(mockCtrl)
proto.EXPECT().Reader(gomock.Any()).Return(sr)

_, err := testEnvelopedReadRequest(t, proto, &bodyReader{}, true /* enveloping */)
require.Error(t, err)
assert.Contains(t, err.Error(), "read envelope begin error")
}

func TestReqReaderEnvelopingBadEnvelopeType(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

sr := streamtest.NewMockReader(mockCtrl)
sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.Exception}, nil)
sr.EXPECT().Close().Return(nil)

proto := streamtest.NewMockProtocol(mockCtrl)
proto.EXPECT().Reader(gomock.Any()).Return(sr)

_, err := testEnvelopedReadRequest(t, proto, &bodyReader{}, true /* enveloping */)
require.Error(t, err)
assert.Contains(t, err.Error(), "unexpected envelope type")
}

func TestReqReaderEnvelopingDecodeError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

sr := streamtest.NewMockReader(mockCtrl)
sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.OneWay}, nil)
sr.EXPECT().Close().Return(nil)

proto := streamtest.NewMockProtocol(mockCtrl)
proto.EXPECT().Reader(gomock.Any()).Return(sr)

_, err := testEnvelopedReadRequest(t, proto, &bodyReader{err: fmt.Errorf("decode error")}, true /* enveloping */)
require.Error(t, err)
assert.Contains(t, err.Error(), "decode error")
}

func TestReqReaderEnvelopingEnvelopeEndError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

sr := streamtest.NewMockReader(mockCtrl)
sr.EXPECT().ReadEnvelopeBegin().Return(stream.EnvelopeHeader{Type: wire.OneWay}, nil)
sr.EXPECT().ReadEnvelopeEnd().Return(fmt.Errorf("read envelope end error"))
sr.EXPECT().Close().Return(nil)

proto := streamtest.NewMockProtocol(mockCtrl)
proto.EXPECT().Reader(gomock.Any()).Return(sr)

_, err := testEnvelopedReadRequest(t, proto, &bodyReader{}, true /* enveloping */)
require.Error(t, err)
assert.Contains(t, err.Error(), "read envelope end error")
}

func TestReqReaderNotEnvelopingDecodeError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

sr := streamtest.NewMockReader(mockCtrl)
sr.EXPECT().Close().Return(nil)

proto := streamtest.NewMockProtocol(mockCtrl)
proto.EXPECT().Reader(gomock.Any()).Return(sr)

_, err := testEnvelopedReadRequest(t, proto, &bodyReader{err: fmt.Errorf("another decode error")}, false /* enveloping */)
require.Error(t, err)
assert.Contains(t, err.Error(), "another decode error")
}

func testEnvelopedReadRequest(
t *testing.T,
proto stream.Protocol,
body stream.BodyReader,
enveloping bool,
) (stream.ResponseWriter, error) {
t.Helper()

req := request()
var rrp stream.RequestReader
if enveloping {
rrp = &envelopeRequestReader{proto: proto, treq: req}
} else {
rrp = &noEnvelopeReader{proto: proto, treq: req}
func TestDecodeNoWireAppliationError(t *testing.T) {
br := &bodyReader{}
re := responseEnveloper{name: "caller", envelopeType: wire.Reply}
h := thriftNoWireHandler{
NoWireHandler: &responseHandler{
t: t,
reqBody: br,
body: re,
appError: true,
},
RequestReader: binary.Default,
}

ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
defer cancel()

return rrp.ReadRequest(ctx, wire.OneWay, req.Body, body)
req := request()

rw := new(transporttest.FakeResponseWriter)
require.NoError(t, h.Handle(ctx, req, rw))
assert.True(t, rw.IsApplicationError)
}

0 comments on commit 4388d13

Please sign in to comment.