diff --git a/go/vt/vtctl/localvtctldclient/bidi_stream.go b/go/vt/vtctl/localvtctldclient/bidi_stream.go new file mode 100644 index 00000000000..541bd0046e1 --- /dev/null +++ b/go/vt/vtctl/localvtctldclient/bidi_stream.go @@ -0,0 +1,75 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package localvtctldclient + +import ( + "context" + "io" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type bidiStream struct { + m sync.RWMutex + ctx context.Context + errch chan error + sendClosed bool +} + +func newBidiStream(ctx context.Context) *bidiStream { // nolint (TODO:@ajm188) this will be used in a future PR, and the codegen will produce invalid code for streaming rpcs without this + return &bidiStream{ + ctx: ctx, + errch: make(chan error, 1), + } +} + +func (bs *bidiStream) close(err error) { + if err == nil { + err = io.EOF + } + + bs.m.Lock() + defer bs.m.Unlock() + + bs.sendClosed = true + bs.errch <- err +} + +var ( + _ grpc.ClientStream = (*bidiStream)(nil) + _ grpc.ServerStream = (*bidiStream)(nil) +) + +// client and server methods + +func (bs *bidiStream) Context() context.Context { return bs.ctx } +func (bs *bidiStream) RecvMsg(m interface{}) error { return nil } +func (bs *bidiStream) SendMsg(m interface{}) error { return nil } + +// client methods + +func (bs *bidiStream) Header() (metadata.MD, error) { return nil, nil } +func (bs *bidiStream) Trailer() metadata.MD { return nil } +func (bs *bidiStream) CloseSend() error { return nil } + +// server methods + +func (bs *bidiStream) SendHeader(md metadata.MD) error { return nil } +func (bs *bidiStream) SetHeader(md metadata.MD) error { return nil } +func (bs *bidiStream) SetTrailer(md metadata.MD) {} diff --git a/go/vt/vtctl/localvtctldclient/client.go b/go/vt/vtctl/localvtctldclient/client.go index 2580346e54c..c7db19e8804 100644 --- a/go/vt/vtctl/localvtctldclient/client.go +++ b/go/vt/vtctl/localvtctldclient/client.go @@ -26,8 +26,9 @@ import ( ) var ( - m sync.RWMutex - server vtctlservicepb.VtctldServer + m sync.RWMutex + server vtctlservicepb.VtctldServer + errStreamClosed = errors.New("stream is closed for sending") // nolint (TODO:@ajm188) this will be used in a future PR, and the codegen will produce invalid code for streaming rpcs without this ) type localVtctldClient struct { diff --git a/go/vt/vtctl/vtctldclient/codegen/main.go b/go/vt/vtctl/vtctldclient/codegen/main.go index f7001fab7e2..c7d8ef60bad 100644 --- a/go/vt/vtctl/vtctldclient/codegen/main.go +++ b/go/vt/vtctl/vtctldclient/codegen/main.go @@ -130,22 +130,43 @@ func main() { // nolint:funlen // github.com/golang/protobuf/grpc, although in vitess we currently // always use the former. + // In the case of unary RPCs, the first result is a Pointer. In the case + // of streaming RPCs, it is a Named type whose underlying type is an + // Interface. + // // The second result is always error. result := sig.Results().At(0) + switch result.Type().(type) { + case *types.Pointer: + localType, localImport, pkgPath, err = extractLocalPointerType(result) + case *types.Named: + switch result.Type().Underlying().(type) { + case *types.Interface: + f.IsStreaming = true + localType, localImport, pkgPath, err = extractLocalNamedType(result) + if err == nil && *local { + // We need to get the pointer type returned by `stream.Recv()` + // in the local case for the stream adapter. + var recvType, recvImport, recvPkgPath string + recvType, recvImport, recvPkgPath, err = extractRecvType(result) + if err == nil { + f.StreamMessage = buildParam("stream", recvImport, recvType, true) + importNames = addImport(recvImport, recvPkgPath, importNames, imports) + } + } + default: + err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type().Underlying()) + } + default: + err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type()) + } - localType, localImport, pkgPath, err = extractLocalPointerType(result) // (TODO|@amason): does not work for streaming rpcs if err != nil { panic(err) } - f.Result.Name = result.Name() - f.Result.Type = "*" + localImport + "." + localType - - if _, ok := imports[localImport]; !ok { - importNames = append(importNames, localImport) - } - - imports[localImport] = pkgPath + f.Result = buildParam(result.Name(), localImport, localType, !f.IsStreaming) + importNames = addImport(localImport, pkgPath, importNames, imports) } sort.Strings(importNames) @@ -203,9 +224,11 @@ type Import struct { // Func is the variable part of a gRPC client interface method (i.e. not the // context or dialopts arguments, or the error part of the result tuple). type Func struct { - Name string - Param Param - Result Param + Name string + Param Param + Result Param + IsStreaming bool + StreamMessage Param } // Param represents an element of either a parameter list or result list. It @@ -218,6 +241,28 @@ type Param struct { Type string } +func buildParam(name string, localImport string, localType string, isPointer bool) Param { + p := Param{ + Name: name, + Type: fmt.Sprintf("%s.%s", localImport, localType), + } + + if isPointer { + p.Type = "*" + p.Type + } + + return p +} + +func addImport(localImport string, pkgPath string, importNames []string, imports map[string]string) []string { + if _, ok := imports[localImport]; !ok { + importNames = append(importNames, localImport) + } + + imports[localImport] = pkgPath + return importNames +} + func loadPackage(source string) (*packages.Package, error) { pkgs, err := packages.Load(&packages.Config{ Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo, @@ -280,6 +325,19 @@ func rewriteProtoImports(pkg *types.Package) string { return pkg.Name() } +func extractLocalNamedType(v *types.Var) (name string, localImport string, pkgPath string, err error) { + named, ok := v.Type().(*types.Named) + if !ok { + return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type()) + } + + name = named.Obj().Name() + localImport = rewriteProtoImports(named.Obj().Pkg()) + pkgPath = named.Obj().Pkg().Path() + + return name, localImport, pkgPath, nil +} + func extractLocalPointerType(v *types.Var) (name string, localImport string, pkgPath string, err error) { ptr, ok := v.Type().(*types.Pointer) if !ok { @@ -297,3 +355,35 @@ func extractLocalPointerType(v *types.Var) (name string, localImport string, pkg return name, localImport, pkgPath, nil } + +func extractRecvType(v *types.Var) (name string, localImport string, pkgPath string, err error) { + named, ok := v.Type().(*types.Named) + if !ok { + return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type()) + } + + iface, ok := named.Underlying().(*types.Interface) + if !ok { + return "", "", "", fmt.Errorf("expected %s to name an interface type, got %v", v.Name(), named.Underlying()) + } + + for i := 0; i < iface.NumExplicitMethods(); i++ { + m := iface.ExplicitMethod(i) + if m.Name() != "Recv" { + continue + } + + sig, ok := m.Type().(*types.Signature) + if !ok { + return "", "", "", fmt.Errorf("%s.Recv should have type Signature; got %v", v.Name(), m.Type()) + } + + if sig.Results().Len() != 2 { + return "", "", "", fmt.Errorf("%s.Recv should return two values, not %d", v.Name(), sig.Results().Len()) + } + + return extractLocalPointerType(sig.Results().At(0)) + } + + return "", "", "", fmt.Errorf("interface %s has no explicit method named Recv", named.Obj().Name()) +} diff --git a/go/vt/vtctl/vtctldclient/codegen/template.go b/go/vt/vtctl/vtctldclient/codegen/template.go index fbfd6ef6e7d..2bdb5b9ba9f 100644 --- a/go/vt/vtctl/vtctldclient/codegen/template.go +++ b/go/vt/vtctl/vtctldclient/codegen/template.go @@ -16,7 +16,10 @@ limitations under the License. package main -import "text/template" +import ( + "strings" + "text/template" +) const tmplStr = `// Code generated by {{ .ClientName }}-generator. DO NOT EDIT. @@ -52,16 +55,74 @@ import ( {{ end -}} ) {{ range .Methods }} +{{ if and $.Local .IsStreaming -}} +type {{ streamAdapterName .Name }} struct { + *bidiStream + ch chan {{ .StreamMessage.Type }} +} + +func (stream *{{ streamAdapterName .Name }}) Recv() ({{ .StreamMessage.Type }}, error) { + select { + case <-stream.ctx.Done(): + return nil, stream.ctx.Err() + case err := <-stream.errch: + return nil, err + case msg := <-stream.ch: + return msg, nil + } +} + +func (stream *{{ streamAdapterName .Name }}) Send(msg {{ .StreamMessage.Type }}) error { + stream.m.RLock() + defer stream.m.RUnlock() + + if stream.sendClosed { + return errStreamClosed + } + + select { + case <-stream.ctx.Done(): + return stream.ctx.Err() + case stream.ch <- msg: + return nil + } +} +{{ end -}} // {{ .Name }} is part of the vtctlservicepb.VtctldClient interface. func (client *{{ $.Type }}) {{ .Name }}(ctx context.Context, {{ .Param.Name }} {{ .Param.Type }}, opts ...grpc.CallOption) ({{ .Result.Type }}, error) { - {{- if not $.Local -}} + {{ if not $.Local -}} if client.c == nil { return nil, status.Error(codes.Unavailable, connClosedMsg) } - {{ end -}} - return client.{{ if $.Local }}s{{ else }}c{{ end }}.{{ .Name }}(ctx, in{{ if not $.Local }}, opts...{{ end }}) + return client.c.{{ .Name }}(ctx, in, opts...) + {{- else -}} + {{- if .IsStreaming -}} + stream := &{{ streamAdapterName .Name }}{ + bidiStream: newBidiStream(ctx), + ch: make(chan {{ .StreamMessage.Type }}, 1), + } + go func() { + err := client.s.{{ .Name }}(in, stream) + stream.close(err) + }() + + return stream, nil + {{- else -}} + return client.s.{{ .Name }}(ctx, in) + {{- end -}} + {{- end }} } {{ end }}` -var tmpl = template.Must(template.New("vtctldclient-generator").Parse(tmplStr)) +var tmpl = template.Must(template.New("vtctldclient-generator").Funcs(map[string]interface{}{ + "streamAdapterName": func(s string) string { + if len(s) == 0 { + return s + } + + head := s[:1] + tail := s[1:] + return strings.ToLower(head) + tail + "StreamAdapter" + }, +}).Parse(tmplStr))