From 19cf1349ba85f2f3aa2c03eb203548408de9f0bb Mon Sep 17 00:00:00 2001 From: Ian Lee Date: Wed, 9 Oct 2024 16:22:43 -0700 Subject: [PATCH 1/5] feat: expose invalid argument error to clients in bidirectional streaming (#4795) --- .../proto/examplepb/flow_combination.pb.gw.go | 30 +++++++++++----- .../internal/proto/examplepb/stream.pb.gw.go | 30 +++++++++++----- .../internal/gengateway/template.go | 35 ++++++++++++++----- .../internal/gengateway/template_test.go | 2 +- 4 files changed, 70 insertions(+), 27 deletions(-) diff --git a/examples/internal/proto/examplepb/flow_combination.pb.gw.go b/examples/internal/proto/examplepb/flow_combination.pb.gw.go index 94b63fa25c8..9922def36c7 100644 --- a/examples/internal/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/internal/proto/examplepb/flow_combination.pb.gw.go @@ -110,12 +110,13 @@ func request_FlowCombination_StreamEmptyRpc_0(ctx context.Context, marshaler run } -func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (FlowCombination_StreamEmptyStreamClient, runtime.ServerMetadata, error) { +func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (FlowCombination_StreamEmptyStreamClient, runtime.ServerMetadata, chan error, error) { var metadata runtime.ServerMetadata + errChan := make(chan error, 1) stream, err := client.StreamEmptyStream(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) handleSend := func() error { @@ -126,7 +127,7 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler } if err != nil { grpclog.Errorf("Failed to decode request: %v", err) - return err + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) } if err := stream.Send(&protoReq); err != nil { grpclog.Errorf("Failed to send request: %v", err) @@ -137,20 +138,18 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler go func() { for { if err := handleSend(); err != nil { + errChan <- err break } } - if err := stream.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } }() header, err := stream.Header() if err != nil { grpclog.Errorf("Failed to get header from client: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } metadata.HeaderMD = header - return stream, metadata, nil + return stream, metadata, errChan, nil } func request_FlowCombination_RpcBodyRpc_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { @@ -1893,12 +1892,25 @@ func RegisterFlowCombinationHandlerClient(ctx context.Context, mux *runtime.Serv runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := request_FlowCombination_StreamEmptyStream_0(annotatedContext, inboundMarshaler, client, req, pathParams) + + resp, md, reqErrChan, err := request_FlowCombination_StreamEmptyStream_0(annotatedContext, inboundMarshaler, client, req, pathParams) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + go func() { + for { + err := <-reqErrChan + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } + return + } + } + }() forward_FlowCombination_StreamEmptyStream_0(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) diff --git a/examples/internal/proto/examplepb/stream.pb.gw.go b/examples/internal/proto/examplepb/stream.pb.gw.go index c3d44eed522..c9308f1631f 100644 --- a/examples/internal/proto/examplepb/stream.pb.gw.go +++ b/examples/internal/proto/examplepb/stream.pb.gw.go @@ -104,12 +104,13 @@ func request_StreamService_List_0(ctx context.Context, marshaler runtime.Marshal } -func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Marshaler, client StreamServiceClient, req *http.Request, pathParams map[string]string) (StreamService_BulkEchoClient, runtime.ServerMetadata, error) { +func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Marshaler, client StreamServiceClient, req *http.Request, pathParams map[string]string) (StreamService_BulkEchoClient, runtime.ServerMetadata, chan error, error) { var metadata runtime.ServerMetadata + errChan := make(chan error, 1) stream, err := client.BulkEcho(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) handleSend := func() error { @@ -120,7 +121,7 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar } if err != nil { grpclog.Errorf("Failed to decode request: %v", err) - return err + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) } if err := stream.Send(&protoReq); err != nil { grpclog.Errorf("Failed to send request: %v", err) @@ -131,20 +132,18 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar go func() { for { if err := handleSend(); err != nil { + errChan <- err break } } - if err := stream.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } }() header, err := stream.Header() if err != nil { grpclog.Errorf("Failed to get header from client: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } metadata.HeaderMD = header - return stream, metadata, nil + return stream, metadata, errChan, nil } var ( @@ -306,12 +305,25 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := request_StreamService_BulkEcho_0(annotatedContext, inboundMarshaler, client, req, pathParams) + + resp, md, reqErrChan, err := request_StreamService_BulkEcho_0(annotatedContext, inboundMarshaler, client, req, pathParams) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + go func() { + for { + err := <-reqErrChan + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } + return + } + } + }() forward_StreamService_BulkEcho_0(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index 373ed2798f4..2aeec622021 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -256,7 +256,9 @@ var _ = metadata.Join `)) _ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.ReplaceAll(` -{{if .Method.GetServerStreaming}} +{{if and .Method.GetClientStreaming .Method.GetServerStreaming}} +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, chan error, error) +{{else if .Method.GetServerStreaming}} func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error) {{else}} func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) @@ -439,10 +441,11 @@ var ( _ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(` {{template "request-func-signature" .}} { var metadata runtime.ServerMetadata + errChan := make(chan error, 1) stream, err := client.{{.Method.GetName}}(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) handleSend := func() error { @@ -453,7 +456,7 @@ var ( } if err != nil { grpclog.Errorf("Failed to decode request: %v", err) - return err + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) } if err := stream.Send(&protoReq); err != nil { grpclog.Errorf("Failed to send request: %v", err) @@ -464,20 +467,18 @@ var ( go func() { for { if err := handleSend(); err != nil { + errChan <- err break } } - if err := stream.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } }() header, err := stream.Header() if err != nil { grpclog.Errorf("Failed to get header from client: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } metadata.HeaderMD = header - return stream, metadata, nil + return stream, metadata, errChan, nil } `)) @@ -727,12 +728,30 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } + {{if and $m.GetClientStreaming $m.GetServerStreaming }} + resp, md, reqErrChan, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams) + {{- else -}} resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams) + {{- end }} annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + {{- if and $m.GetClientStreaming $m.GetServerStreaming }} + go func() { + for { + err := <-reqErrChan + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } + return + } + } + }() + {{- end }} {{if $m.GetServerStreaming}} {{ if $b.ResponseBody }} forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template_test.go b/protoc-gen-grpc-gateway/internal/gengateway/template_test.go index cabc284319f..f2dc868487e 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template_test.go @@ -316,7 +316,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, chan error, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming) From 8015368c32cd6f5097cd6240541231bec18a837f Mon Sep 17 00:00:00 2001 From: Ian Lee Date: Tue, 15 Oct 2024 17:00:14 -0700 Subject: [PATCH 2/5] fix for loop --- .../proto/examplepb/flow_combination.pb.gw.go | 3 +-- examples/internal/proto/examplepb/stream.pb.gw.go | 3 +-- .../internal/gengateway/template.go | 13 ++++++------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/examples/internal/proto/examplepb/flow_combination.pb.gw.go b/examples/internal/proto/examplepb/flow_combination.pb.gw.go index 9922def36c7..7a5b3735c3b 100644 --- a/examples/internal/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/internal/proto/examplepb/flow_combination.pb.gw.go @@ -1900,8 +1900,7 @@ func RegisterFlowCombinationHandlerClient(ctx context.Context, mux *runtime.Serv return } go func() { - for { - err := <-reqErrChan + for err := range reqErrChan { if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) if err := resp.CloseSend(); err != nil { diff --git a/examples/internal/proto/examplepb/stream.pb.gw.go b/examples/internal/proto/examplepb/stream.pb.gw.go index c9308f1631f..b1e7bc6bea2 100644 --- a/examples/internal/proto/examplepb/stream.pb.gw.go +++ b/examples/internal/proto/examplepb/stream.pb.gw.go @@ -313,8 +313,7 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM return } go func() { - for { - err := <-reqErrChan + for err := range reqErrChan { if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) if err := resp.CloseSend(); err != nil { diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index 2aeec622021..8fe7412d6ef 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -471,6 +471,7 @@ var ( break } } + close(errChan) }() header, err := stream.Header() if err != nil { @@ -740,14 +741,12 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, } {{- if and $m.GetClientStreaming $m.GetServerStreaming }} go func() { - for { - err := <-reqErrChan - if err != nil { + for err := range reqErrChan { + if err != io.EOF { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - if err := resp.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } - return + } + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) } } }() From 690ed314b2ef058838e8acf0107eda0e78f98d0d Mon Sep 17 00:00:00 2001 From: Ian Lee Date: Tue, 15 Oct 2024 17:04:33 -0700 Subject: [PATCH 3/5] remove for loop --- .../proto/examplepb/flow_combination.pb.gw.go | 15 +++++++-------- examples/internal/proto/examplepb/stream.pb.gw.go | 15 +++++++-------- .../internal/gengateway/template.go | 13 ++++++------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/examples/internal/proto/examplepb/flow_combination.pb.gw.go b/examples/internal/proto/examplepb/flow_combination.pb.gw.go index 7a5b3735c3b..48f803afcf9 100644 --- a/examples/internal/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/internal/proto/examplepb/flow_combination.pb.gw.go @@ -142,6 +142,7 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler break } } + close(errChan) }() header, err := stream.Header() if err != nil { @@ -1900,14 +1901,12 @@ func RegisterFlowCombinationHandlerClient(ctx context.Context, mux *runtime.Serv return } go func() { - for err := range reqErrChan { - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - if err := resp.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } - return - } + err := <-reqErrChan + if err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) } }() diff --git a/examples/internal/proto/examplepb/stream.pb.gw.go b/examples/internal/proto/examplepb/stream.pb.gw.go index b1e7bc6bea2..1b46a3ca584 100644 --- a/examples/internal/proto/examplepb/stream.pb.gw.go +++ b/examples/internal/proto/examplepb/stream.pb.gw.go @@ -136,6 +136,7 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar break } } + close(errChan) }() header, err := stream.Header() if err != nil { @@ -313,14 +314,12 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM return } go func() { - for err := range reqErrChan { - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - if err := resp.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } - return - } + err := <-reqErrChan + if err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) } }() diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index 8fe7412d6ef..32a01577eca 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -741,13 +741,12 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, } {{- if and $m.GetClientStreaming $m.GetServerStreaming }} go func() { - for err := range reqErrChan { - if err != io.EOF { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - } - if err := resp.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } + err := <-reqErrChan + if err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } + if err := resp.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) } }() {{- end }} From 8184dce9acf75f63902fd2758d3257245e9fd796 Mon Sep 17 00:00:00 2001 From: Ian Lee Date: Thu, 17 Oct 2024 11:36:18 -0700 Subject: [PATCH 4/5] revert to for loop --- .../internal/proto/examplepb/flow_combination.pb.gw.go | 9 +++++---- examples/internal/proto/examplepb/stream.pb.gw.go | 9 +++++---- protoc-gen-grpc-gateway/internal/gengateway/template.go | 9 +++++---- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/internal/proto/examplepb/flow_combination.pb.gw.go b/examples/internal/proto/examplepb/flow_combination.pb.gw.go index 48f803afcf9..a441d646289 100644 --- a/examples/internal/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/internal/proto/examplepb/flow_combination.pb.gw.go @@ -136,13 +136,13 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler return nil } go func() { + defer close(errChan) for { if err := handleSend(); err != nil { errChan <- err break } } - close(errChan) }() header, err := stream.Header() if err != nil { @@ -1901,9 +1901,10 @@ func RegisterFlowCombinationHandlerClient(ctx context.Context, mux *runtime.Serv return } go func() { - err := <-reqErrChan - if err != io.EOF { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + for err := range reqErrChan { + if err != nil && err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } } if err := resp.CloseSend(); err != nil { grpclog.Errorf("Failed to terminate client stream: %v", err) diff --git a/examples/internal/proto/examplepb/stream.pb.gw.go b/examples/internal/proto/examplepb/stream.pb.gw.go index 1b46a3ca584..8414013c14e 100644 --- a/examples/internal/proto/examplepb/stream.pb.gw.go +++ b/examples/internal/proto/examplepb/stream.pb.gw.go @@ -130,13 +130,13 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar return nil } go func() { + defer close(errChan) for { if err := handleSend(); err != nil { errChan <- err break } } - close(errChan) }() header, err := stream.Header() if err != nil { @@ -314,9 +314,10 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM return } go func() { - err := <-reqErrChan - if err != io.EOF { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + for err := range reqErrChan { + if err != nil && err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } } if err := resp.CloseSend(); err != nil { grpclog.Errorf("Failed to terminate client stream: %v", err) diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index 32a01577eca..c615a54cfde 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -465,13 +465,13 @@ var ( return nil } go func() { + defer close(errChan) for { if err := handleSend(); err != nil { errChan <- err break } } - close(errChan) }() header, err := stream.Header() if err != nil { @@ -741,9 +741,10 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, } {{- if and $m.GetClientStreaming $m.GetServerStreaming }} go func() { - err := <-reqErrChan - if err != io.EOF { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + for err := range reqErrChan { + if err != nil && err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } } if err := resp.CloseSend(); err != nil { grpclog.Errorf("Failed to terminate client stream: %v", err) From a99930615d5a56acc29010dd048ad207ddb1f3b4 Mon Sep 17 00:00:00 2001 From: Ian Lee Date: Sat, 19 Oct 2024 00:45:16 -0700 Subject: [PATCH 5/5] move stream close to not block stream.header() --- .../internal/proto/examplepb/flow_combination.pb.gw.go | 7 ++++--- examples/internal/proto/examplepb/stream.pb.gw.go | 7 ++++--- protoc-gen-grpc-gateway/internal/gengateway/template.go | 7 ++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/internal/proto/examplepb/flow_combination.pb.gw.go b/examples/internal/proto/examplepb/flow_combination.pb.gw.go index a441d646289..98af744ae42 100644 --- a/examples/internal/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/internal/proto/examplepb/flow_combination.pb.gw.go @@ -116,6 +116,7 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler stream, err := client.StreamEmptyStream(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) + close(errChan) return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) @@ -143,6 +144,9 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler break } } + if err := stream.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } }() header, err := stream.Header() if err != nil { @@ -1906,9 +1910,6 @@ func RegisterFlowCombinationHandlerClient(ctx context.Context, mux *runtime.Serv runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) } } - if err := resp.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } }() forward_FlowCombination_StreamEmptyStream_0(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) diff --git a/examples/internal/proto/examplepb/stream.pb.gw.go b/examples/internal/proto/examplepb/stream.pb.gw.go index 8414013c14e..9a419c81512 100644 --- a/examples/internal/proto/examplepb/stream.pb.gw.go +++ b/examples/internal/proto/examplepb/stream.pb.gw.go @@ -110,6 +110,7 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar stream, err := client.BulkEcho(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) + close(errChan) return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) @@ -137,6 +138,9 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar break } } + if err := stream.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } }() header, err := stream.Header() if err != nil { @@ -319,9 +323,6 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) } } - if err := resp.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } }() forward_StreamService_BulkEcho_0(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index c615a54cfde..7f2924e13b1 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -445,6 +445,7 @@ var ( stream, err := client.{{.Method.GetName}}(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) + close(errChan) return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) @@ -472,6 +473,9 @@ var ( break } } + if err := stream.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } }() header, err := stream.Header() if err != nil { @@ -746,9 +750,6 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) } } - if err := resp.CloseSend(); err != nil { - grpclog.Errorf("Failed to terminate client stream: %v", err) - } }() {{- end }} {{if $m.GetServerStreaming}}