Skip to content

Commit

Permalink
Split reqReaderProto into envelopeRequestReader and noEnvelopeReader
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinav authored and witriew committed Aug 17, 2021
1 parent cd8fb79 commit b26373b
Showing 1 changed file with 57 additions and 33 deletions.
90 changes: 57 additions & 33 deletions encoding/thrift/inbound_nowire.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,64 +141,88 @@ func (t thriftNoWireHandler) decodeAndHandle(
EnvelopeType: reqEnvelopeType,
}

// If the underlying nowire transport supports envelope-agnostic
// request decoding, use that. Otherwise, build a RequestReader based
// on envelope configuration.
if reqReader, ok := t.Protocol.(stream.RequestReader); ok {
nwc.RequestReader = reqReader
} else if t.Enveloping {
nwc.RequestReader = &envelopeRequestReader{
proto: t.Protocol,
treq: treq,
}
} else {
nwc.RequestReader = &reqReaderProto{
Protocol: t.Protocol,
treq: treq,
enveloping: t.Enveloping,
nwc.RequestReader = &noEnvelopeReader{
proto: t.Protocol,
treq: treq,
}
}

return t.NoWireHandler.Handle(ctx, &nwc)
}

// reqReaderProto is an implementation of ThriftRW's stream.RequestReader in
// case the provided stream.Protocol does not implement the necessary
// `ReadRequest` to discover the correct enveloping.
type reqReaderProto struct {
stream.Protocol

treq *transport.Request
enveloping bool
// envelopeRequestReader implements ThriftRW's stream.RequestReader, decoding
// requests with enveloping support.
type envelopeRequestReader struct {
proto stream.Protocol
treq *transport.Request
}

var _ stream.RequestReader = (*reqReaderProto)(nil)
var _ stream.RequestReader = (*envelopeRequestReader)(nil)

func (p *reqReaderProto) ReadRequest(
func (p *envelopeRequestReader) ReadRequest(
ctx context.Context,
et wire.EnvelopeType,
r io.Reader,
body stream.BodyReader,
) (stream.ResponseWriter, error) {
sr := p.Protocol.Reader(r)
sr := p.proto.Reader(r)
defer sr.Close()

if p.enveloping {
eh, err := sr.ReadEnvelopeBegin()
if err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}
eh, err := sr.ReadEnvelopeBegin()
if err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

if eh.Type != et {
return nil, errors.RequestBodyDecodeError(p.treq, errUnexpectedEnvelopeType(eh.Type))
}
if eh.Type != et {
return nil, errors.RequestBodyDecodeError(p.treq, errUnexpectedEnvelopeType(eh.Type))
}

if err := body.Decode(sr); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}
if err := body.Decode(sr); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

if err := sr.ReadEnvelopeEnd(); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}
return binary.EnvelopeV1Responder{
Name: eh.Name,
SeqID: eh.SeqID,
}, nil
if err := sr.ReadEnvelopeEnd(); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

return binary.EnvelopeV1Responder{
Name: eh.Name,
SeqID: eh.SeqID,
}, nil
}

// noEnvelopeReader implements ThriftRW's stream.RequestReader, decoding
// requests without enveloping support.
type noEnvelopeReader struct {
proto stream.Protocol
treq *transport.Request
}

var _ stream.RequestReader = (*noEnvelopeReader)(nil)

func (p *noEnvelopeReader) ReadRequest(
ctx context.Context,
et wire.EnvelopeType,
r io.Reader,
body stream.BodyReader,
) (stream.ResponseWriter, error) {
sr := p.proto.Reader(r)
defer sr.Close()

if err := body.Decode(sr); err != nil {
return nil, errors.RequestBodyDecodeError(p.treq, err)
}

return binary.NoEnvelopeResponder, nil
}

0 comments on commit b26373b

Please sign in to comment.