From 10f6e39787500f2dd7ae04f64ac1821c570fc3cc Mon Sep 17 00:00:00 2001 From: Ian Lee Date: Wed, 9 Oct 2024 16:22:43 -0700 Subject: [PATCH] feat: expose invalid argument error to clients in bidirectional streaming (#4795) --- .../internal/gengateway/template.go | 28 +++++++++++++------ .../internal/gengateway/template_test.go | 2 +- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index 373ed2798f4..e1c34d3bc60 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -257,7 +257,7 @@ var _ = metadata.Join _ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.ReplaceAll(` {{if .Method.GetServerStreaming}} -func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error) +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, chan error, error) {{else}} func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {{end}}`, "\n", ""))) @@ -439,10 +439,11 @@ var ( _ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(` {{template "request-func-signature" .}} { var metadata runtime.ServerMetadata + errChan := make(chan error) stream, err := client.{{.Method.GetName}}(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) handleSend := func() error { @@ -453,7 +454,7 @@ var ( } if err != nil { grpclog.Errorf("Failed to decode request: %v", err) - return err + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) } if err := stream.Send(&protoReq); err != nil { grpclog.Errorf("Failed to send request: %v", err) @@ -464,20 +465,18 @@ var ( go func() { for { if err := handleSend(); err != nil { + errChan <- err break } } - if err := stream.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } }() header, err := stream.Header() if err != nil { grpclog.Errorf("Failed to get header from client: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } metadata.HeaderMD = header - return stream, metadata, nil + return stream, metadata, errChan, nil } `)) @@ -727,12 +726,23 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams) + resp, md, reqErrChan, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + go func() { + for { + err := <-reqErrChan + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } + } + } + }() {{if $m.GetServerStreaming}} {{ if $b.ResponseBody }} forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template_test.go b/protoc-gen-grpc-gateway/internal/gengateway/template_test.go index cabc284319f..f2dc868487e 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template_test.go @@ -316,7 +316,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, chan error, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming)