-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
thrift: Implement client "StreamClient" to support thriftrw streaming…
… protocol for outbound requests
- Loading branch information
Showing
2 changed files
with
217 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,231 @@ | ||
package thrift | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"io" | ||
|
||
"go.uber.org/thriftrw/envelope/stream" | ||
pstream "go.uber.org/thriftrw/protocol/stream" | ||
envelope "go.uber.org/thriftrw/envelope/stream" | ||
"go.uber.org/thriftrw/protocol/binary" | ||
"go.uber.org/thriftrw/protocol/stream" | ||
"go.uber.org/thriftrw/wire" | ||
"go.uber.org/yarpc" | ||
encodingapi "go.uber.org/yarpc/api/encoding" | ||
"go.uber.org/yarpc/api/transport" | ||
"go.uber.org/yarpc/encoding/thrift/internal" | ||
"go.uber.org/yarpc/pkg/encoding" | ||
"go.uber.org/yarpc/pkg/errors" | ||
"go.uber.org/yarpc/pkg/procedure" | ||
) | ||
|
||
// StreamClient is a generic Thrift client. It speaks in raw Thrift payloads. | ||
// StreamClient is a generic Thrift client for stream encoding/decoding. | ||
// It speaks in raw Thrift payloads. | ||
// | ||
// Users should use the client generated by the code generator rather than | ||
// using this directly. | ||
type StreamClient interface { | ||
// Call the given Thrift method. | ||
Call(ctx context.Context, reqBody stream.Enveloper, opts ...yarpc.CallOption) (pstream.Reader, error) | ||
CallOneway(ctx context.Context, reqBody stream.Enveloper, opts ...yarpc.CallOption) (transport.Ack, error) | ||
Call(ctx context.Context, reqBody envelope.Enveloper, opts ...yarpc.CallOption) (stream.Reader, error) | ||
CallOneway(ctx context.Context, reqBody envelope.Enveloper, opts ...yarpc.CallOption) (transport.Ack, error) | ||
} | ||
|
||
// NewStreamClient creates a new Thrift client. | ||
func NewStreamClient(c Config, opts ...ClientOption) StreamClient { | ||
// Code generated for Thrift client instantiation will probably be something | ||
// like this: | ||
// | ||
// func New(cc transport.ClientConfig, opts ...ClientOption) *MyServiceClient { | ||
// c := thrift.NewStreamClient(thrift.Config{ | ||
// Service: "MyService", | ||
// ClientConfig: cc, | ||
// Protocol: binary.Default, | ||
// }, opts...) | ||
// return &MyServiceClient{client: c} | ||
// } | ||
// | ||
// So Config is really the internal config as far as consumers of the | ||
// generated client are concerned. | ||
|
||
var cc clientConfig | ||
for _, opt := range opts { | ||
opt.applyClientOption(&cc) | ||
} | ||
|
||
p := stream.Protocol(binary.Default) | ||
if cc.Protocol != nil { | ||
val, ok := cc.Protocol.(stream.Protocol) | ||
if !ok { | ||
panic("yarpc.NewStreamClient expects a Protocol of type stream.Protocol") | ||
} | ||
p = val | ||
} | ||
|
||
svc := c.Service | ||
if cc.ServiceName != "" { | ||
svc = cc.ServiceName | ||
} | ||
|
||
return streamThriftClient{ | ||
p: p, | ||
cc: c.ClientConfig, | ||
thriftService: svc, | ||
Enveloping: cc.Enveloping, | ||
Multiplexed: cc.Multiplexed, | ||
} | ||
} | ||
|
||
type streamThriftClient struct { | ||
cc transport.ClientConfig | ||
p stream.Protocol | ||
|
||
// name of the Thrift service | ||
thriftService string | ||
Enveloping bool | ||
Multiplexed bool | ||
} | ||
|
||
func (c streamThriftClient) Call(ctx context.Context, reqBody envelope.Enveloper, opts ...yarpc.CallOption) (stream.Reader, error) { | ||
// Code generated for Thrift client calls will probably be something like | ||
// this: | ||
// | ||
// func (c *MyServiceClient) someMethod(ctx context.Context, arg1 Arg1Type, arg2 arg2Type, opts ...yarpc.CallOption) (returnValue, error) { | ||
// args := myservice.SomeMethodHelper.Args(arg1, arg2) | ||
// resBody, err := c.client.Call(ctx, args, opts...) | ||
// var result myservice.SomeMethodResult | ||
// if err = result.Decode(resBody); err != nil { | ||
// return nil, err | ||
// } | ||
// success, err := myservice.SomeMethodHelper.UnwrapResponse(&result) | ||
// return success, err | ||
// } | ||
|
||
out := c.cc.GetUnaryOutbound() | ||
|
||
treq, proto, err := c.buildTransportRequest(reqBody) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
call := encodingapi.NewOutboundCall(encoding.FromOptions(opts)...) | ||
ctx, err = call.WriteToRequest(ctx, treq) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
tres, err := out.Call(ctx, treq) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer tres.Body.Close() | ||
|
||
if _, err = call.ReadFromResponse(ctx, tres); err != nil { | ||
return nil, err | ||
} | ||
|
||
var r io.Reader | ||
// optimization for avoiding additional buffer copy as tchannel outbound | ||
// already decodes the body into io.Reader compatible type | ||
// thrift deserializer reads sets, maps, and lists lazily which makes | ||
// buffer pool unusable as response handling is out of scope of this method | ||
if body, ok := tres.Body.(io.Reader); ok { | ||
r = body | ||
} else { | ||
buf := bytes.NewBuffer(make([]byte, 0, _defaultBufferSize)) | ||
if _, err = buf.ReadFrom(tres.Body); err != nil { | ||
return nil, err | ||
} | ||
r = bytes.NewReader(buf.Bytes()) | ||
} | ||
|
||
sr := proto.Reader(r) | ||
|
||
env, err := sr.ReadEnvelopeBegin() | ||
if err != nil { | ||
return nil, errors.ResponseBodyDecodeError(treq, err) | ||
} | ||
|
||
switch env.Type { | ||
case wire.Reply: | ||
return sr, nil | ||
case wire.Exception: | ||
defer sr.Close() | ||
var exc internal.TApplicationException | ||
if err = exc.Decode(sr); err != nil { | ||
return nil, errors.ResponseBodyDecodeError(treq, err) | ||
} | ||
return nil, thriftException{ | ||
Service: treq.Service, | ||
Procedure: treq.Procedure, | ||
Reason: &exc, | ||
} | ||
default: | ||
sr.Close() | ||
return nil, errors.ResponseBodyDecodeError( | ||
treq, errUnexpectedEnvelopeType(env.Type)) | ||
} | ||
|
||
} | ||
|
||
func (c streamThriftClient) CallOneway(ctx context.Context, reqBody envelope.Enveloper, opts ...yarpc.CallOption) (transport.Ack, error) { | ||
out := c.cc.GetOnewayOutbound() | ||
|
||
treq, _, err := c.buildTransportRequest(reqBody) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
call := encodingapi.NewOutboundCall(encoding.FromOptions(opts)...) | ||
ctx, err = call.WriteToRequest(ctx, treq) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return out.CallOneway(ctx, treq) | ||
} | ||
|
||
func (c streamThriftClient) buildTransportRequest(reqBody envelope.Enveloper) (*transport.Request, stream.Protocol, error) { | ||
proto := c.p | ||
treq := transport.Request{ | ||
Caller: c.cc.Caller(), | ||
Service: c.cc.Service(), | ||
Encoding: Encoding, | ||
Procedure: procedure.ToName(c.thriftService, reqBody.MethodName()), | ||
} | ||
|
||
envType := reqBody.EnvelopeType() | ||
if envType != wire.Call && envType != wire.OneWay { | ||
return nil, nil, errors.RequestBodyEncodeError( | ||
&treq, errUnexpectedEnvelopeType(envType), | ||
) | ||
} | ||
|
||
var buffer bytes.Buffer | ||
sw := proto.Writer(&buffer) | ||
defer sw.Close() | ||
|
||
if c.Enveloping { | ||
if err := sw.WriteEnvelopeBegin(stream.EnvelopeHeader{ | ||
Name: reqBody.MethodName(), | ||
Type: reqBody.EnvelopeType(), | ||
SeqID: 1, // don't care | ||
}); err != nil { | ||
return nil, nil, errors.RequestBodyEncodeError(&treq, err) | ||
} | ||
|
||
if err := reqBody.Encode(sw); err != nil { | ||
return nil, nil, errors.RequestBodyEncodeError(&treq, err) | ||
} | ||
|
||
if err := sw.WriteEnvelopeEnd(); err != nil { | ||
return nil, nil, errors.RequestBodyEncodeError(&treq, err) | ||
} | ||
} else { | ||
if err := reqBody.Encode(sw); err != nil { | ||
return nil, nil, errors.RequestBodyEncodeError(&treq, err) | ||
} | ||
} | ||
|
||
treq.Body = &buffer | ||
treq.BodySize = buffer.Len() | ||
return &treq, proto, nil | ||
} |