diff --git a/examples/internal/proto/examplepb/flow_combination.pb.gw.go b/examples/internal/proto/examplepb/flow_combination.pb.gw.go index 94b63fa25c8..98af744ae42 100644 --- a/examples/internal/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/internal/proto/examplepb/flow_combination.pb.gw.go @@ -110,12 +110,14 @@ func request_FlowCombination_StreamEmptyRpc_0(ctx context.Context, marshaler run } -func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (FlowCombination_StreamEmptyStreamClient, runtime.ServerMetadata, error) { +func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (FlowCombination_StreamEmptyStreamClient, runtime.ServerMetadata, chan error, error) { var metadata runtime.ServerMetadata + errChan := make(chan error, 1) stream, err := client.StreamEmptyStream(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) - return nil, metadata, err + close(errChan) + return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) handleSend := func() error { @@ -126,7 +128,7 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler } if err != nil { grpclog.Errorf("Failed to decode request: %v", err) - return err + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) } if err := stream.Send(&protoReq); err != nil { grpclog.Errorf("Failed to send request: %v", err) @@ -135,8 +137,10 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler return nil } go func() { + defer close(errChan) for { if err := handleSend(); err != nil { + errChan <- err break } } @@ -147,10 +151,10 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler header, err := stream.Header() if err != nil { grpclog.Errorf("Failed to get header from client: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } metadata.HeaderMD = header - return stream, metadata, nil + return stream, metadata, errChan, nil } func request_FlowCombination_RpcBodyRpc_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { @@ -1893,12 +1897,20 @@ func RegisterFlowCombinationHandlerClient(ctx context.Context, mux *runtime.Serv runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := request_FlowCombination_StreamEmptyStream_0(annotatedContext, inboundMarshaler, client, req, pathParams) + + resp, md, reqErrChan, err := request_FlowCombination_StreamEmptyStream_0(annotatedContext, inboundMarshaler, client, req, pathParams) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + go func() { + for err := range reqErrChan { + if err != nil && err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } + } + }() forward_FlowCombination_StreamEmptyStream_0(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) diff --git a/examples/internal/proto/examplepb/stream.pb.gw.go b/examples/internal/proto/examplepb/stream.pb.gw.go index c3d44eed522..9a419c81512 100644 --- a/examples/internal/proto/examplepb/stream.pb.gw.go +++ b/examples/internal/proto/examplepb/stream.pb.gw.go @@ -104,12 +104,14 @@ func request_StreamService_List_0(ctx context.Context, marshaler runtime.Marshal } -func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Marshaler, client StreamServiceClient, req *http.Request, pathParams map[string]string) (StreamService_BulkEchoClient, runtime.ServerMetadata, error) { +func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Marshaler, client StreamServiceClient, req *http.Request, pathParams map[string]string) (StreamService_BulkEchoClient, runtime.ServerMetadata, chan error, error) { var metadata runtime.ServerMetadata + errChan := make(chan error, 1) stream, err := client.BulkEcho(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) - return nil, metadata, err + close(errChan) + return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) handleSend := func() error { @@ -120,7 +122,7 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar } if err != nil { grpclog.Errorf("Failed to decode request: %v", err) - return err + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) } if err := stream.Send(&protoReq); err != nil { grpclog.Errorf("Failed to send request: %v", err) @@ -129,8 +131,10 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar return nil } go func() { + defer close(errChan) for { if err := handleSend(); err != nil { + errChan <- err break } } @@ -141,10 +145,10 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar header, err := stream.Header() if err != nil { grpclog.Errorf("Failed to get header from client: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } metadata.HeaderMD = header - return stream, metadata, nil + return stream, metadata, errChan, nil } var ( @@ -306,12 +310,20 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := request_StreamService_BulkEcho_0(annotatedContext, inboundMarshaler, client, req, pathParams) + + resp, md, reqErrChan, err := request_StreamService_BulkEcho_0(annotatedContext, inboundMarshaler, client, req, pathParams) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + go func() { + for err := range reqErrChan { + if err != nil && err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } + } + }() forward_StreamService_BulkEcho_0(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index 373ed2798f4..7f2924e13b1 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -256,7 +256,9 @@ var _ = metadata.Join `)) _ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.ReplaceAll(` -{{if .Method.GetServerStreaming}} +{{if and .Method.GetClientStreaming .Method.GetServerStreaming}} +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, chan error, error) +{{else if .Method.GetServerStreaming}} func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error) {{else}} func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) @@ -439,10 +441,12 @@ var ( _ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(` {{template "request-func-signature" .}} { var metadata runtime.ServerMetadata + errChan := make(chan error, 1) stream, err := client.{{.Method.GetName}}(ctx) if err != nil { grpclog.Errorf("Failed to start streaming: %v", err) - return nil, metadata, err + close(errChan) + return nil, metadata, errChan, err } dec := marshaler.NewDecoder(req.Body) handleSend := func() error { @@ -453,7 +457,7 @@ var ( } if err != nil { grpclog.Errorf("Failed to decode request: %v", err) - return err + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) } if err := stream.Send(&protoReq); err != nil { grpclog.Errorf("Failed to send request: %v", err) @@ -462,8 +466,10 @@ var ( return nil } go func() { + defer close(errChan) for { if err := handleSend(); err != nil { + errChan <- err break } } @@ -474,10 +480,10 @@ var ( header, err := stream.Header() if err != nil { grpclog.Errorf("Failed to get header from client: %v", err) - return nil, metadata, err + return nil, metadata, errChan, err } metadata.HeaderMD = header - return stream, metadata, nil + return stream, metadata, errChan, nil } `)) @@ -727,12 +733,25 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } + {{if and $m.GetClientStreaming $m.GetServerStreaming }} + resp, md, reqErrChan, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams) + {{- else -}} resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams) + {{- end }} annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + {{- if and $m.GetClientStreaming $m.GetServerStreaming }} + go func() { + for err := range reqErrChan { + if err != nil && err != io.EOF { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + } + } + }() + {{- end }} {{if $m.GetServerStreaming}} {{ if $b.ResponseBody }} forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template_test.go b/protoc-gen-grpc-gateway/internal/gengateway/template_test.go index cabc284319f..f2dc868487e 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template_test.go @@ -316,7 +316,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, chan error, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming)