Skip to content

Commit

Permalink
feat: expose invalid argument error to clients in bidirectional strea…
Browse files Browse the repository at this point in the history
…ming (#4795) (#4819)

* feat: expose invalid argument error to clients in bidirectional streaming (#4795)

* fix for loop

* remove for loop

* revert to for loop

* move stream close to not block stream.header()

---------

Co-authored-by: Ian Lee <ianlee@Ians-MacBook-Pro-2.local>
  • Loading branch information
ianbbqzy and Ian Lee authored Oct 23, 2024
1 parent 54f5a01 commit 830ba27
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 18 deletions.
24 changes: 18 additions & 6 deletions examples/internal/proto/examplepb/flow_combination.pb.gw.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 18 additions & 6 deletions examples/internal/proto/examplepb/stream.pb.gw.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 24 additions & 5 deletions protoc-gen-grpc-gateway/internal/gengateway/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ var _ = metadata.Join
`))

_ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.ReplaceAll(`
{{if .Method.GetServerStreaming}}
{{if and .Method.GetClientStreaming .Method.GetServerStreaming}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, chan error, error)
{{else if .Method.GetServerStreaming}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error)
{{else}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error)
Expand Down Expand Up @@ -439,10 +441,12 @@ var (
_ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(`
{{template "request-func-signature" .}} {
var metadata runtime.ServerMetadata
errChan := make(chan error, 1)
stream, err := client.{{.Method.GetName}}(ctx)
if err != nil {
grpclog.Errorf("Failed to start streaming: %v", err)
return nil, metadata, err
close(errChan)
return nil, metadata, errChan, err
}
dec := marshaler.NewDecoder(req.Body)
handleSend := func() error {
Expand All @@ -453,7 +457,7 @@ var (
}
if err != nil {
grpclog.Errorf("Failed to decode request: %v", err)
return err
return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err)
}
if err := stream.Send(&protoReq); err != nil {
grpclog.Errorf("Failed to send request: %v", err)
Expand All @@ -462,8 +466,10 @@ var (
return nil
}
go func() {
defer close(errChan)
for {
if err := handleSend(); err != nil {
errChan <- err
break
}
}
Expand All @@ -474,10 +480,10 @@ var (
header, err := stream.Header()
if err != nil {
grpclog.Errorf("Failed to get header from client: %v", err)
return nil, metadata, err
return nil, metadata, errChan, err
}
metadata.HeaderMD = header
return stream, metadata, nil
return stream, metadata, errChan, nil
}
`))

Expand Down Expand Up @@ -727,12 +733,25 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
{{if and $m.GetClientStreaming $m.GetServerStreaming }}
resp, md, reqErrChan, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams)
{{- else -}}
resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams)
{{- end }}
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
{{- if and $m.GetClientStreaming $m.GetServerStreaming }}
go func() {
for err := range reqErrChan {
if err != nil && err != io.EOF {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
}
}
}()
{{- end }}
{{if $m.GetServerStreaming}}
{{ if $b.ResponseBody }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) {
},
{
serverStreaming: true,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, chan error, error) {`,
},
} {
meth.ServerStreaming = proto.Bool(spec.serverStreaming)
Expand Down

0 comments on commit 830ba27

Please sign in to comment.