Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vtctldclient-codegen] Add support in codegen for streaming RPCs #9064

Merged
merged 3 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions go/vt/vtctl/localvtctldclient/bidi_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
Copyright 2021 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package localvtctldclient

import (
"context"
"io"
"sync"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

type bidiStream struct {
m sync.RWMutex
ctx context.Context
errch chan error
sendClosed bool
}

func newBidiStream(ctx context.Context) *bidiStream { // nolint (TODO:@ajm188) this will be used in a future PR, and the codegen will produce invalid code for streaming rpcs without this
return &bidiStream{
ctx: ctx,
errch: make(chan error, 1),
}
}

func (bs *bidiStream) close(err error) {
if err == nil {
err = io.EOF
}

bs.m.Lock()
defer bs.m.Unlock()

bs.sendClosed = true
bs.errch <- err
}

var (
_ grpc.ClientStream = (*bidiStream)(nil)
_ grpc.ServerStream = (*bidiStream)(nil)
)

// client and server methods

func (bs *bidiStream) Context() context.Context { return bs.ctx }
func (bs *bidiStream) RecvMsg(m interface{}) error { return nil }
func (bs *bidiStream) SendMsg(m interface{}) error { return nil }

// client methods

func (bs *bidiStream) Header() (metadata.MD, error) { return nil, nil }
func (bs *bidiStream) Trailer() metadata.MD { return nil }
func (bs *bidiStream) CloseSend() error { return nil }

// server methods

func (bs *bidiStream) SendHeader(md metadata.MD) error { return nil }
func (bs *bidiStream) SetHeader(md metadata.MD) error { return nil }
func (bs *bidiStream) SetTrailer(md metadata.MD) {}
5 changes: 3 additions & 2 deletions go/vt/vtctl/localvtctldclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import (
)

var (
m sync.RWMutex
server vtctlservicepb.VtctldServer
m sync.RWMutex
server vtctlservicepb.VtctldServer
errStreamClosed = errors.New("stream is closed for sending") // nolint (TODO:@ajm188) this will be used in a future PR, and the codegen will produce invalid code for streaming rpcs without this
)

type localVtctldClient struct {
Expand Down
114 changes: 102 additions & 12 deletions go/vt/vtctl/vtctldclient/codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,43 @@ func main() { // nolint:funlen
// github.com/golang/protobuf/grpc, although in vitess we currently
// always use the former.

// In the case of unary RPCs, the first result is a Pointer. In the case
// of streaming RPCs, it is a Named type whose underlying type is an
// Interface.
//
// The second result is always error.
result := sig.Results().At(0)
switch result.Type().(type) {
case *types.Pointer:
localType, localImport, pkgPath, err = extractLocalPointerType(result)
case *types.Named:
switch result.Type().Underlying().(type) {
case *types.Interface:
f.IsStreaming = true
localType, localImport, pkgPath, err = extractLocalNamedType(result)
if err == nil && *local {
// We need to get the pointer type returned by `stream.Recv()`
// in the local case for the stream adapter.
var recvType, recvImport, recvPkgPath string
recvType, recvImport, recvPkgPath, err = extractRecvType(result)
if err == nil {
f.StreamMessage = buildParam("stream", recvImport, recvType, true)
importNames = addImport(recvImport, recvPkgPath, importNames, imports)
}
}
default:
err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type().Underlying())
}
default:
err = fmt.Errorf("expected either pointer (for unary) or named interface (for streaming) rpc result type, got %T", result.Type())
}

localType, localImport, pkgPath, err = extractLocalPointerType(result) // (TODO|@amason): does not work for streaming rpcs
if err != nil {
panic(err)
}

f.Result.Name = result.Name()
f.Result.Type = "*" + localImport + "." + localType

if _, ok := imports[localImport]; !ok {
importNames = append(importNames, localImport)
}

imports[localImport] = pkgPath
f.Result = buildParam(result.Name(), localImport, localType, !f.IsStreaming)
importNames = addImport(localImport, pkgPath, importNames, imports)
}

sort.Strings(importNames)
Expand Down Expand Up @@ -203,9 +224,11 @@ type Import struct {
// Func is the variable part of a gRPC client interface method (i.e. not the
// context or dialopts arguments, or the error part of the result tuple).
type Func struct {
Name string
Param Param
Result Param
Name string
Param Param
Result Param
IsStreaming bool
StreamMessage Param
}

// Param represents an element of either a parameter list or result list. It
Expand All @@ -218,6 +241,28 @@ type Param struct {
Type string
}

func buildParam(name string, localImport string, localType string, isPointer bool) Param {
p := Param{
Name: name,
Type: fmt.Sprintf("%s.%s", localImport, localType),
}

if isPointer {
p.Type = "*" + p.Type
}

return p
}

func addImport(localImport string, pkgPath string, importNames []string, imports map[string]string) []string {
if _, ok := imports[localImport]; !ok {
importNames = append(importNames, localImport)
}

imports[localImport] = pkgPath
return importNames
}

func loadPackage(source string) (*packages.Package, error) {
pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo,
Expand Down Expand Up @@ -280,6 +325,19 @@ func rewriteProtoImports(pkg *types.Package) string {
return pkg.Name()
}

func extractLocalNamedType(v *types.Var) (name string, localImport string, pkgPath string, err error) {
named, ok := v.Type().(*types.Named)
if !ok {
return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type())
}

name = named.Obj().Name()
localImport = rewriteProtoImports(named.Obj().Pkg())
pkgPath = named.Obj().Pkg().Path()

return name, localImport, pkgPath, nil
}

func extractLocalPointerType(v *types.Var) (name string, localImport string, pkgPath string, err error) {
ptr, ok := v.Type().(*types.Pointer)
if !ok {
Expand All @@ -297,3 +355,35 @@ func extractLocalPointerType(v *types.Var) (name string, localImport string, pkg

return name, localImport, pkgPath, nil
}

func extractRecvType(v *types.Var) (name string, localImport string, pkgPath string, err error) {
named, ok := v.Type().(*types.Named)
if !ok {
return "", "", "", fmt.Errorf("expected a named type for %s, got %v", v.Name(), v.Type())
}

iface, ok := named.Underlying().(*types.Interface)
if !ok {
return "", "", "", fmt.Errorf("expected %s to name an interface type, got %v", v.Name(), named.Underlying())
}

for i := 0; i < iface.NumExplicitMethods(); i++ {
m := iface.ExplicitMethod(i)
if m.Name() != "Recv" {
continue
}

sig, ok := m.Type().(*types.Signature)
if !ok {
return "", "", "", fmt.Errorf("%s.Recv should have type Signature; got %v", v.Name(), m.Type())
}

if sig.Results().Len() != 2 {
return "", "", "", fmt.Errorf("%s.Recv should return two values, not %d", v.Name(), sig.Results().Len())
}

return extractLocalPointerType(sig.Results().At(0))
}

return "", "", "", fmt.Errorf("interface %s has no explicit method named Recv", named.Obj().Name())
}
71 changes: 66 additions & 5 deletions go/vt/vtctl/vtctldclient/codegen/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.

package main

import "text/template"
import (
"strings"
"text/template"
)

const tmplStr = `// Code generated by {{ .ClientName }}-generator. DO NOT EDIT.

Expand Down Expand Up @@ -52,16 +55,74 @@ import (
{{ end -}}
)
{{ range .Methods }}
{{ if and $.Local .IsStreaming -}}
type {{ streamAdapterName .Name }} struct {
*bidiStream
ch chan {{ .StreamMessage.Type }}
}

func (stream *{{ streamAdapterName .Name }}) Recv() ({{ .StreamMessage.Type }}, error) {
select {
case <-stream.ctx.Done():
return nil, stream.ctx.Err()
case err := <-stream.errch:
return nil, err
case msg := <-stream.ch:
return msg, nil
}
}

func (stream *{{ streamAdapterName .Name }}) Send(msg {{ .StreamMessage.Type }}) error {
stream.m.RLock()
defer stream.m.RUnlock()

if stream.sendClosed {
return errStreamClosed
}

select {
case <-stream.ctx.Done():
return stream.ctx.Err()
case stream.ch <- msg:
return nil
}
}
{{ end -}}
// {{ .Name }} is part of the vtctlservicepb.VtctldClient interface.
func (client *{{ $.Type }}) {{ .Name }}(ctx context.Context, {{ .Param.Name }} {{ .Param.Type }}, opts ...grpc.CallOption) ({{ .Result.Type }}, error) {
{{- if not $.Local -}}
{{ if not $.Local -}}
if client.c == nil {
return nil, status.Error(codes.Unavailable, connClosedMsg)
}

{{ end -}}
return client.{{ if $.Local }}s{{ else }}c{{ end }}.{{ .Name }}(ctx, in{{ if not $.Local }}, opts...{{ end }})
return client.c.{{ .Name }}(ctx, in, opts...)
{{- else -}}
{{- if .IsStreaming -}}
stream := &{{ streamAdapterName .Name }}{
bidiStream: newBidiStream(ctx),
ch: make(chan {{ .StreamMessage.Type }}, 1),
}
go func() {
err := client.s.{{ .Name }}(in, stream)
stream.close(err)
}()

return stream, nil
{{- else -}}
return client.s.{{ .Name }}(ctx, in)
{{- end -}}
{{- end }}
}
{{ end }}`

var tmpl = template.Must(template.New("vtctldclient-generator").Parse(tmplStr))
var tmpl = template.Must(template.New("vtctldclient-generator").Funcs(map[string]interface{}{
"streamAdapterName": func(s string) string {
if len(s) == 0 {
return s
}

head := s[:1]
tail := s[1:]
return strings.ToLower(head) + tail + "StreamAdapter"
},
}).Parse(tmplStr))