diff --git a/Gopkg.lock b/Gopkg.lock index 25cf65ea15..260abfcaaf 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -144,12 +144,11 @@ revision = "7d11b49dc0769f6dbb0d1b19f3d48524d1bad9ad" [[projects]] - branch = "master" - digest = "1:25435262330720ca0cade25af7ee7fb96d0cb70cc1ea0c0961694681c12a90e6" + digest = "1:c710f7b09759fe5ef0122a55a14afa8fab25a51a56e0826c24499563ef9e0e92" name = "github.com/containerd/ttrpc" packages = ["."] pruneopts = "NUT" - revision = "69144327078caa5a2f1d5eda8bea6110bf16eeb3" + revision = "92c8520ef9f86600c650dd540266a007bf03670f" [[projects]] branch = "master" diff --git a/Gopkg.toml b/Gopkg.toml index 83366412eb..a4a4663e84 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -70,6 +70,10 @@ name = "github.com/gogo/protobuf" revision = "4cbf7e384e768b4e01799441fdf2a706a5635ae7" +[[override]] + name = "github.com/containerd/ttrpc" + revision = "92c8520ef9f86600c650dd540266a007bf03670f" + [[override]] branch = "master" name = "github.com/hashicorp/yamux" diff --git a/vendor/github.com/containerd/ttrpc/channel.go b/vendor/github.com/containerd/ttrpc/channel.go index 22f5496b4b..aa8c9541cf 100644 --- a/vendor/github.com/containerd/ttrpc/channel.go +++ b/vendor/github.com/containerd/ttrpc/channel.go @@ -18,7 +18,6 @@ package ttrpc import ( "bufio" - "context" "encoding/binary" "io" "net" @@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel { // returned will be valid and caller should send that along to // the correct consumer. The bytes on the underlying channel // will be discarded. -func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { +func (ch *channel) recv() (messageHeader, []byte, error) { mh, err := readMessageHeader(ch.hrbuf[:], ch.br) if err != nil { return messageHeader{}, nil, err @@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { return mh, p, nil } -func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { +func (ch *channel) send(streamID uint32, t messageType, p []byte) error { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { return err } diff --git a/vendor/github.com/containerd/ttrpc/client.go b/vendor/github.com/containerd/ttrpc/client.go index 41e83c2d00..bdd1d12e7a 100644 --- a/vendor/github.com/containerd/ttrpc/client.go +++ b/vendor/github.com/containerd/ttrpc/client.go @@ -29,6 +29,7 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -36,28 +37,56 @@ import ( // closed. var ErrClosed = errors.New("ttrpc: closed") +// Client for a ttrpc server type Client struct { codec codec conn net.Conn channel *channel calls chan *callRequest - closed chan struct{} - closeOnce sync.Once - closeFunc func() - done chan struct{} - err error + ctx context.Context + closed func() + + closeOnce sync.Once + userCloseFunc func() + + errOnce sync.Once + err error + interceptor UnaryClientInterceptor +} + +// ClientOpts configures a client +type ClientOpts func(c *Client) + +// WithOnClose sets the close func whenever the client's Close() method is called +func WithOnClose(onClose func()) ClientOpts { + return func(c *Client) { + c.userCloseFunc = onClose + } +} + +// WithUnaryClientInterceptor sets the provided client interceptor +func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts { + return func(c *Client) { + c.interceptor = i + } } -func NewClient(conn net.Conn) *Client { +func NewClient(conn net.Conn, opts ...ClientOpts) *Client { + ctx, cancel := context.WithCancel(context.Background()) c := &Client{ - codec: codec{}, - conn: conn, - channel: newChannel(conn), - calls: make(chan *callRequest), - closed: make(chan struct{}), - done: make(chan struct{}), - closeFunc: func() {}, + codec: codec{}, + conn: conn, + channel: newChannel(conn), + calls: make(chan *callRequest), + closed: cancel, + ctx: ctx, + userCloseFunc: func() {}, + interceptor: defaultClientInterceptor, + } + + for _, o := range opts { + o(c) } go c.run() @@ -87,11 +116,18 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int cresp = &Response{} ) + if metadata, ok := GetMetadata(ctx); ok { + metadata.setRequest(creq) + } + if dl, ok := ctx.Deadline(); ok { creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds() } - if err := c.dispatch(ctx, creq, cresp); err != nil { + info := &UnaryClientInfo{ + FullMethod: fullPath(service, method), + } + if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil { return err } @@ -99,11 +135,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int return err } - if cresp.Status == nil { - return errors.New("no status provided on response") + if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) { + return status.ErrorProto(cresp.Status) } - - return status.ErrorProto(cresp.Status) + return nil } func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { @@ -119,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err case <-ctx.Done(): return ctx.Err() case c.calls <- call: - case <-c.done: - return c.err + case <-c.ctx.Done(): + return c.error() } select { @@ -128,75 +163,100 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err return ctx.Err() case err := <-errs: return filterCloseErr(err) - case <-c.done: - return c.err + case <-c.ctx.Done(): + return c.error() } } func (c *Client) Close() error { c.closeOnce.Do(func() { - close(c.closed) + c.closed() }) - return nil } -// OnClose allows a close func to be called when the server is closed -func (c *Client) OnClose(closer func()) { - c.closeFunc = closer -} - type message struct { messageHeader p []byte err error } -func (c *Client) run() { - var ( - streamID uint32 = 1 - waiters = make(map[uint32]*callRequest) - calls = c.calls - incoming = make(chan *message) - shutdown = make(chan struct{}) - shutdownErr error - ) +type receiver struct { + wg *sync.WaitGroup + messages chan *message + err error +} - go func() { - defer close(shutdown) +func (r *receiver) run(ctx context.Context, c *channel) { + defer r.wg.Done() - // start one more goroutine to recv messages without blocking. - for { - mh, p, err := c.channel.recv(context.TODO()) + for { + select { + case <-ctx.Done(): + r.err = ctx.Err() + return + default: + mh, p, err := c.recv() if err != nil { _, ok := status.FromError(err) if !ok { // treat all errors that are not an rpc status as terminal. // all others poison the connection. - shutdownErr = err + r.err = filterCloseErr(err) return } } select { - case incoming <- &message{ + case r.messages <- &message{ messageHeader: mh, p: p[:mh.Length], err: err, }: - case <-c.done: + case <-ctx.Done(): + r.err = ctx.Err() return } } + } +} + +func (c *Client) run() { + var ( + streamID uint32 = 1 + waiters = make(map[uint32]*callRequest) + calls = c.calls + incoming = make(chan *message) + receiversDone = make(chan struct{}) + wg sync.WaitGroup + ) + + // broadcast the shutdown error to the remaining waiters. + abortWaiters := func(wErr error) { + for _, waiter := range waiters { + waiter.errs <- wErr + } + } + recv := &receiver{ + wg: &wg, + messages: incoming, + } + wg.Add(1) + + go func() { + wg.Wait() + close(receiversDone) }() + go recv.run(c.ctx, c.channel) - defer c.conn.Close() - defer close(c.done) - defer c.closeFunc() + defer func() { + c.conn.Close() + c.userCloseFunc() + }() for { select { case call := <-calls: - if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil { + if err := c.send(streamID, messageTypeRequest, call.req); err != nil { call.errs <- err continue } @@ -212,41 +272,42 @@ func (c *Client) run() { call.errs <- c.recv(call.resp, msg) delete(waiters, msg.StreamID) - case <-shutdown: - if shutdownErr != nil { - shutdownErr = filterCloseErr(shutdownErr) - } else { - shutdownErr = ErrClosed - } - - shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down") - - c.err = shutdownErr - for _, waiter := range waiters { - waiter.errs <- shutdownErr + case <-receiversDone: + // all the receivers have exited + if recv.err != nil { + c.setError(recv.err) } + // don't return out, let the close of the context trigger the abort of waiters c.Close() - return - case <-c.closed: - if c.err == nil { - c.err = ErrClosed - } - // broadcast the shutdown error to the remaining waiters. - for _, waiter := range waiters { - waiter.errs <- c.err - } + case <-c.ctx.Done(): + abortWaiters(c.error()) return } } } -func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error { +func (c *Client) error() error { + c.errOnce.Do(func() { + if c.err == nil { + c.err = ErrClosed + } + }) + return c.err +} + +func (c *Client) setError(err error) { + c.errOnce.Do(func() { + c.err = err + }) +} + +func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error { p, err := c.codec.Marshal(msg) if err != nil { return err } - return c.channel.send(ctx, streamID, mtype, p) + return c.channel.send(streamID, mtype, p) } func (c *Client) recv(resp *Response, msg *message) error { @@ -267,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error { // // This purposely ignores errors with a wrapped cause. func filterCloseErr(err error) error { - if err == nil { + switch { + case err == nil: return nil - } - - if err == io.EOF { + case err == io.EOF: return ErrClosed - } - - if strings.Contains(err.Error(), "use of closed network connection") { + case errors.Cause(err) == io.EOF: return ErrClosed - } - - // if we have an epipe on a write, we cast to errclosed - if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" { - if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE { - return ErrClosed + case strings.Contains(err.Error(), "use of closed network connection"): + return ErrClosed + default: + // if we have an epipe on a write, we cast to errclosed + if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" { + if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE { + return ErrClosed + } } } diff --git a/vendor/github.com/containerd/ttrpc/config.go b/vendor/github.com/containerd/ttrpc/config.go index 019b7a09dd..6a53c112b7 100644 --- a/vendor/github.com/containerd/ttrpc/config.go +++ b/vendor/github.com/containerd/ttrpc/config.go @@ -19,9 +19,11 @@ package ttrpc import "github.com/pkg/errors" type serverConfig struct { - handshaker Handshaker + handshaker Handshaker + interceptor UnaryServerInterceptor } +// ServerOpt for configuring a ttrpc server type ServerOpt func(*serverConfig) error // WithServerHandshaker can be passed to NewServer to ensure that the @@ -37,3 +39,14 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt { return nil } } + +// WithUnaryServerInterceptor sets the provided interceptor on the server +func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt { + return func(c *serverConfig) error { + if c.interceptor != nil { + return errors.New("only one interceptor allowed per server") + } + c.interceptor = i + return nil + } +} diff --git a/vendor/github.com/containerd/ttrpc/interceptor.go b/vendor/github.com/containerd/ttrpc/interceptor.go new file mode 100644 index 0000000000..c1219dac65 --- /dev/null +++ b/vendor/github.com/containerd/ttrpc/interceptor.go @@ -0,0 +1,50 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import "context" + +// UnaryServerInfo provides information about the server request +type UnaryServerInfo struct { + FullMethod string +} + +// UnaryClientInfo provides information about the client request +type UnaryClientInfo struct { + FullMethod string +} + +// Unmarshaler contains the server request data and allows it to be unmarshaled +// into a concrete type +type Unmarshaler func(interface{}) error + +// Invoker invokes the client's request and response from the ttrpc server +type Invoker func(context.Context, *Request, *Response) error + +// UnaryServerInterceptor specifies the interceptor function for server request/response +type UnaryServerInterceptor func(context.Context, Unmarshaler, *UnaryServerInfo, Method) (interface{}, error) + +// UnaryClientInterceptor specifies the interceptor function for client request/response +type UnaryClientInterceptor func(context.Context, *Request, *Response, *UnaryClientInfo, Invoker) error + +func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, info *UnaryServerInfo, method Method) (interface{}, error) { + return method(ctx, unmarshal) +} + +func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error { + return invoker(ctx, req, resp) +} diff --git a/vendor/github.com/containerd/ttrpc/metadata.go b/vendor/github.com/containerd/ttrpc/metadata.go new file mode 100644 index 0000000000..ce8c0d13c4 --- /dev/null +++ b/vendor/github.com/containerd/ttrpc/metadata.go @@ -0,0 +1,107 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "strings" +) + +// MD is the user type for ttrpc metadata +type MD map[string][]string + +// Get returns the metadata for a given key when they exist. +// If there is no metadata, a nil slice and false are returned. +func (m MD) Get(key string) ([]string, bool) { + key = strings.ToLower(key) + list, ok := m[key] + if !ok || len(list) == 0 { + return nil, false + } + + return list, true +} + +// Set sets the provided values for a given key. +// The values will overwrite any existing values. +// If no values provided, a key will be deleted. +func (m MD) Set(key string, values ...string) { + key = strings.ToLower(key) + if len(values) == 0 { + delete(m, key) + return + } + m[key] = values +} + +// Append appends additional values to the given key. +func (m MD) Append(key string, values ...string) { + key = strings.ToLower(key) + if len(values) == 0 { + return + } + current, ok := m[key] + if ok { + m.Set(key, append(current, values...)...) + } else { + m.Set(key, values...) + } +} + +func (m MD) setRequest(r *Request) { + for k, values := range m { + for _, v := range values { + r.Metadata = append(r.Metadata, &KeyValue{ + Key: k, + Value: v, + }) + } + } +} + +func (m MD) fromRequest(r *Request) { + for _, kv := range r.Metadata { + m[kv.Key] = append(m[kv.Key], kv.Value) + } +} + +type metadataKey struct{} + +// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata) +func GetMetadata(ctx context.Context) (MD, bool) { + metadata, ok := ctx.Value(metadataKey{}).(MD) + return metadata, ok +} + +// GetMetadataValue gets a specific metadata value by name from context.Context +func GetMetadataValue(ctx context.Context, name string) (string, bool) { + metadata, ok := GetMetadata(ctx) + if !ok { + return "", false + } + + if list, ok := metadata.Get(name); ok { + return list[0], true + } + + return "", false +} + +// WithMetadata attaches metadata map to a context.Context +func WithMetadata(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, metadataKey{}, md) +} diff --git a/vendor/github.com/containerd/ttrpc/server.go b/vendor/github.com/containerd/ttrpc/server.go index 40804eac0d..1d4f1df653 100644 --- a/vendor/github.com/containerd/ttrpc/server.go +++ b/vendor/github.com/containerd/ttrpc/server.go @@ -53,10 +53,13 @@ func NewServer(opts ...ServerOpt) (*Server, error) { return nil, err } } + if config.interceptor == nil { + config.interceptor = defaultServerInterceptor + } return &Server{ config: config, - services: newServiceSet(), + services: newServiceSet(config.interceptor), done: make(chan struct{}), listeners: make(map[net.Listener]struct{}), connections: make(map[*serverConn]struct{}), @@ -341,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) { default: // proceed } - mh, p, err := ch.recv(ctx) + mh, p, err := ch.recv() if err != nil { status, ok := status.FromError(err) if !ok { @@ -438,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) { return } - if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil { + if err := ch.send(response.id, messageTypeResponse, p); err != nil { logrus.WithError(err).Error("failed sending message on channel") return } @@ -449,7 +452,12 @@ func (c *serverConn) run(sctx context.Context) { // branch. Basically, it means that we are no longer receiving // requests due to a terminal error. recvErr = nil // connection is now "closing" - if err != nil && err != io.EOF { + if err == io.EOF || err == io.ErrUnexpectedEOF { + // The client went away and we should stop processing + // requests, so that the client connection is closed + return + } + if err != nil { logrus.WithError(err).Error("error receiving message") } case <-shutdown: @@ -461,6 +469,12 @@ func (c *serverConn) run(sctx context.Context) { var noopFunc = func() {} func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) { + if len(req.Metadata) > 0 { + md := MD{} + md.fromRequest(req) + ctx = WithMetadata(ctx, md) + } + cancel = noopFunc if req.TimeoutNano == 0 { return ctx, cancel diff --git a/vendor/github.com/containerd/ttrpc/services.go b/vendor/github.com/containerd/ttrpc/services.go index fe1cade5ad..0eacfd79aa 100644 --- a/vendor/github.com/containerd/ttrpc/services.go +++ b/vendor/github.com/containerd/ttrpc/services.go @@ -37,12 +37,14 @@ type ServiceDesc struct { } type serviceSet struct { - services map[string]ServiceDesc + services map[string]ServiceDesc + interceptor UnaryServerInterceptor } -func newServiceSet() *serviceSet { +func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet { return &serviceSet{ - services: make(map[string]ServiceDesc), + services: make(map[string]ServiceDesc), + interceptor: interceptor, } } @@ -84,7 +86,11 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin return nil } - resp, err := method(ctx, unmarshal) + info := &UnaryServerInfo{ + FullMethod: fullPath(serviceName, methodName), + } + + resp, err := s.interceptor(ctx, unmarshal, info, method) if err != nil { return nil, err } @@ -146,5 +152,5 @@ func convertCode(err error) codes.Code { } func fullPath(service, method string) string { - return "/" + path.Join("/", service, method) + return "/" + path.Join(service, method) } diff --git a/vendor/github.com/containerd/ttrpc/types.go b/vendor/github.com/containerd/ttrpc/types.go index a6b3b818e0..9a1c19a723 100644 --- a/vendor/github.com/containerd/ttrpc/types.go +++ b/vendor/github.com/containerd/ttrpc/types.go @@ -23,10 +23,11 @@ import ( ) type Request struct { - Service string `protobuf:"bytes,1,opt,name=service,proto3"` - Method string `protobuf:"bytes,2,opt,name=method,proto3"` - Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` - TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` + Service string `protobuf:"bytes,1,opt,name=service,proto3"` + Method string `protobuf:"bytes,2,opt,name=method,proto3"` + Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` + TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` + Metadata []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"` } func (r *Request) Reset() { *r = Request{} } @@ -41,3 +42,22 @@ type Response struct { func (r *Response) Reset() { *r = Response{} } func (r *Response) String() string { return fmt.Sprintf("%+#v", r) } func (r *Response) ProtoMessage() {} + +type StringList struct { + List []string `protobuf:"bytes,1,rep,name=list,proto3"` +} + +func (r *StringList) Reset() { *r = StringList{} } +func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) } +func (r *StringList) ProtoMessage() {} + +func makeStringList(item ...string) StringList { return StringList{List: item} } + +type KeyValue struct { + Key string `protobuf:"bytes,1,opt,name=key,proto3"` + Value string `protobuf:"bytes,2,opt,name=value,proto3"` +} + +func (m *KeyValue) Reset() { *m = KeyValue{} } +func (*KeyValue) ProtoMessage() {} +func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }