From 739d2eee46aab952542d3c10ff49ce2c32c45589 Mon Sep 17 00:00:00 2001 From: Lyrian Date: Thu, 24 Oct 2024 14:43:20 -0700 Subject: [PATCH] feat: use StreamErrorHandler to send back invalid argument error in bidirectional streaming (#4864) * feat: use StreamErrorHandler to send back invalid argument error in bidirectional streaming * add unit tests and fix casing * try integration tests * give up on async request streaming in integration test * stablized local runs --------- Co-authored-by: Ian Lee --- examples/internal/integration/BUILD.bazel | 1 + .../internal/integration/integration_test.go | 94 +++++++++++++++++++ .../proto/examplepb/flow_combination.pb.gw.go | 2 +- .../internal/proto/examplepb/stream.pb.go | 68 ++++++++------ .../internal/proto/examplepb/stream.pb.gw.go | 91 +++++++++++++++++- .../internal/proto/examplepb/stream.proto | 7 ++ .../proto/examplepb/stream.swagger.json | 40 ++++++++ .../proto/examplepb/stream_grpc.pb.go | 43 ++++++++- .../internal/server/a_bit_of_everything.go | 52 ++++++++++ .../internal/gengateway/template.go | 2 +- runtime/errors.go | 15 +++ runtime/errors_test.go | 50 ++++++++++ 12 files changed, 430 insertions(+), 35 deletions(-) diff --git a/examples/internal/integration/BUILD.bazel b/examples/internal/integration/BUILD.bazel index 08d1a021385..9ae6ba31a23 100644 --- a/examples/internal/integration/BUILD.bazel +++ b/examples/internal/integration/BUILD.bazel @@ -25,6 +25,7 @@ go_test( "@org_golang_google_protobuf//encoding/protojson", "@org_golang_google_protobuf//proto", "@org_golang_google_protobuf//testing/protocmp", + "@org_golang_google_protobuf//types/known/durationpb", "@org_golang_google_protobuf//types/known/emptypb", "@org_golang_google_protobuf//types/known/fieldmaskpb", "@org_golang_google_protobuf//types/known/structpb", diff --git a/examples/internal/integration/integration_test.go b/examples/internal/integration/integration_test.go index d04b5f43484..f7caa5fdf43 100644 --- a/examples/internal/integration/integration_test.go +++ b/examples/internal/integration/integration_test.go @@ -27,6 +27,7 @@ import ( "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/emptypb" fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/structpb" @@ -521,6 +522,7 @@ func TestABE(t *testing.T) { testABEDownload(t, 8088) testABEBulkEcho(t, 8088) testABEBulkEchoZeroLength(t, 8088) + testABEBulkEchoDurationError(t, 8088) testAdditionalBindings(t, 8088) testABERepeated(t, 8088) testABEExists(t, 8088) @@ -1448,6 +1450,98 @@ func testABEBulkEchoZeroLength(t *testing.T, port int) { } } +func testABEBulkEchoDurationError(t *testing.T, port int) { + reqr, reqw := io.Pipe() + var wg sync.WaitGroup + var want []*durationpb.Duration + wg.Add(1) + go func() { + defer wg.Done() + defer reqw.Close() + for i := 0; i < 10; i++ { + s := fmt.Sprintf("%d.123s", i) + if i == 5 { + s = "invalidDurationFormat" + } + buf, err := marshaler.Marshal(s) + if err != nil { + t.Errorf("marshaler.Marshal(%v) failed with %v; want success", s, err) + return + } + if _, err = reqw.Write(buf); err != nil { + t.Errorf("reqw.Write(%q) failed with %v; want success", string(buf), err) + return + } + want = append(want, &durationpb.Duration{Seconds: int64(i), Nanos: int32(0.123 * 1e9)}) + } + }() + apiURL := fmt.Sprintf("http://localhost:%d/v1/example/a_bit_of_everything/echo_duration", port) + req, err := http.NewRequest("POST", apiURL, reqr) + if err != nil { + t.Errorf("http.NewRequest(%q, %q, reqr) failed with %v; want success", "POST", apiURL, err) + return + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Transfer-Encoding", "chunked") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("http.Post(%q, %q, req) failed with %v; want success", apiURL, "application/json", err) + return + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %d; want %d", got, want) + } + + var got []*durationpb.Duration + var invalidArgumentCount int + wg.Add(1) + go func() { + defer wg.Done() + + dec := marshaler.NewDecoder(resp.Body) + for i := 0; ; i++ { + var item struct { + Result json.RawMessage `json:"result"` + Error map[string]interface{} `json:"error"` + } + err := dec.Decode(&item) + if err == io.EOF { + break + } + if err != nil { + t.Errorf("dec.Decode(&item) failed with %v; want success; i = %d", err, i) + } + if len(item.Error) != 0 { + code, ok := item.Error["code"].(float64) + if !ok { + t.Errorf("item.Error[code] not found or not a number: %#v; i = %d", item.Error, i) + } else if int32(code) == 3 { + invalidArgumentCount++ + } else { + t.Errorf("item.Error[code] = %v; want 3; i = %d", code, i) + } + continue + } + + msg := new(durationpb.Duration) + if err := marshaler.Unmarshal(item.Result, msg); err != nil { + t.Errorf("marshaler.Unmarshal(%q, msg) failed with %v; want success", item.Result, err) + } + got = append(got, msg) + } + + if invalidArgumentCount != 1 { + t.Errorf("got %d errors with code 3; want exactly 1", invalidArgumentCount) + } + }() + + wg.Wait() + if diff := cmp.Diff(got, want[:5], protocmp.Transform()); diff != "" { + t.Error(diff) + } +} + func testAdditionalBindings(t *testing.T, port int) { for i, f := range []func() *http.Response{ func() *http.Response { diff --git a/examples/internal/proto/examplepb/flow_combination.pb.gw.go b/examples/internal/proto/examplepb/flow_combination.pb.gw.go index 98af744ae42..17db93e3001 100644 --- a/examples/internal/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/internal/proto/examplepb/flow_combination.pb.gw.go @@ -1907,7 +1907,7 @@ func RegisterFlowCombinationHandlerClient(ctx context.Context, mux *runtime.Serv go func() { for err := range reqErrChan { if err != nil && err != io.EOF { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + runtime.HTTPStreamError(annotatedContext, mux, outboundMarshaler, w, req, err) } } }() diff --git a/examples/internal/proto/examplepb/stream.pb.go b/examples/internal/proto/examplepb/stream.pb.go index f8b8dfcf79f..1e2cb3001bc 100644 --- a/examples/internal/proto/examplepb/stream.pb.go +++ b/examples/internal/proto/examplepb/stream.pb.go @@ -12,6 +12,7 @@ import ( httpbody "google.golang.org/genproto/googleapis/api/httpbody" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + durationpb "google.golang.org/protobuf/types/known/durationpb" emptypb "google.golang.org/protobuf/types/known/emptypb" reflect "reflect" sync "sync" @@ -88,11 +89,13 @@ var file_examples_internal_proto_examplepb_stream_proto_rawDesc = []byte{ 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x68, 0x74, 0x74, 0x70, 0x62, 0x6f, 0x64, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x1f, 0x0a, 0x07, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x32, - 0x89, 0x05, 0x0a, 0x0d, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x92, 0x06, 0x0a, 0x0d, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x99, 0x01, 0x0a, 0x0a, 0x42, 0x75, 0x6c, 0x6b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x12, 0x40, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x67, 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, @@ -125,20 +128,28 @@ var file_examples_internal_proto_examplepb_stream_proto_rawDesc = []byte{ 0x93, 0x02, 0x29, 0x3a, 0x01, 0x2a, 0x22, 0x24, 0x2f, 0x76, 0x31, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x61, 0x5f, 0x62, 0x69, 0x74, 0x5f, 0x6f, 0x66, 0x5f, 0x65, 0x76, 0x65, 0x72, 0x79, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x2f, 0x65, 0x63, 0x68, 0x6f, 0x28, 0x01, 0x30, 0x01, - 0x12, 0x79, 0x0a, 0x08, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x37, 0x2e, 0x67, - 0x72, 0x70, 0x63, 0x2e, 0x67, 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, 0x2e, 0x65, 0x78, 0x61, 0x6d, - 0x70, 0x6c, 0x65, 0x73, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x70, 0x62, 0x2e, 0x4f, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x61, - 0x70, 0x69, 0x2e, 0x48, 0x74, 0x74, 0x70, 0x42, 0x6f, 0x64, 0x79, 0x22, 0x1c, 0x82, 0xd3, 0xe4, - 0x93, 0x02, 0x16, 0x12, 0x14, 0x2f, 0x76, 0x31, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, - 0x2f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x30, 0x01, 0x42, 0x4d, 0x5a, 0x4b, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x65, - 0x63, 0x6f, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x67, 0x61, - 0x74, 0x65, 0x77, 0x61, 0x79, 0x2f, 0x76, 0x32, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, - 0x73, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x12, 0x86, 0x01, 0x0a, 0x10, 0x42, 0x75, 0x6c, 0x6b, 0x45, 0x63, 0x68, 0x6f, 0x44, 0x75, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x1a, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x38, 0x82, 0xd3, 0xe4, + 0x93, 0x02, 0x32, 0x3a, 0x01, 0x2a, 0x22, 0x2d, 0x2f, 0x76, 0x31, 0x2f, 0x65, 0x78, 0x61, 0x6d, + 0x70, 0x6c, 0x65, 0x2f, 0x61, 0x5f, 0x62, 0x69, 0x74, 0x5f, 0x6f, 0x66, 0x5f, 0x65, 0x76, 0x65, + 0x72, 0x79, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x2f, 0x65, 0x63, 0x68, 0x6f, 0x5f, 0x64, 0x75, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x28, 0x01, 0x30, 0x01, 0x12, 0x79, 0x0a, 0x08, 0x44, 0x6f, 0x77, + 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x37, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x2e, 0x67, 0x61, 0x74, + 0x65, 0x77, 0x61, 0x79, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2e, 0x69, 0x6e, + 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x65, 0x78, 0x61, + 0x6d, 0x70, 0x6c, 0x65, 0x70, 0x62, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x14, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x48, 0x74, 0x74, 0x70, + 0x42, 0x6f, 0x64, 0x79, 0x22, 0x1c, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x16, 0x12, 0x14, 0x2f, 0x76, + 0x31, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, + 0x61, 0x64, 0x30, 0x01, 0x42, 0x4d, 0x5a, 0x4b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x65, 0x63, 0x6f, 0x73, 0x79, 0x73, 0x74, 0x65, + 0x6d, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2d, 0x67, 0x61, 0x74, 0x65, 0x77, 0x61, 0x79, 0x2f, 0x76, + 0x32, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, + 0x6e, 0x61, 0x6c, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -155,23 +166,26 @@ func file_examples_internal_proto_examplepb_stream_proto_rawDescGZIP() []byte { var file_examples_internal_proto_examplepb_stream_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_examples_internal_proto_examplepb_stream_proto_goTypes = []any{ - (*Options)(nil), // 0: grpc.gateway.examples.internal.proto.examplepb.Options - (*ABitOfEverything)(nil), // 1: grpc.gateway.examples.internal.proto.examplepb.ABitOfEverything - (*sub.StringMessage)(nil), // 2: grpc.gateway.examples.internal.proto.sub.StringMessage - (*emptypb.Empty)(nil), // 3: google.protobuf.Empty - (*httpbody.HttpBody)(nil), // 4: google.api.HttpBody + (*Options)(nil), // 0: grpc.gateway.examples.internal.proto.examplepb.Options + (*ABitOfEverything)(nil), // 1: grpc.gateway.examples.internal.proto.examplepb.ABitOfEverything + (*sub.StringMessage)(nil), // 2: grpc.gateway.examples.internal.proto.sub.StringMessage + (*durationpb.Duration)(nil), // 3: google.protobuf.Duration + (*emptypb.Empty)(nil), // 4: google.protobuf.Empty + (*httpbody.HttpBody)(nil), // 5: google.api.HttpBody } var file_examples_internal_proto_examplepb_stream_proto_depIdxs = []int32{ 1, // 0: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkCreate:input_type -> grpc.gateway.examples.internal.proto.examplepb.ABitOfEverything 0, // 1: grpc.gateway.examples.internal.proto.examplepb.StreamService.List:input_type -> grpc.gateway.examples.internal.proto.examplepb.Options 2, // 2: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkEcho:input_type -> grpc.gateway.examples.internal.proto.sub.StringMessage - 0, // 3: grpc.gateway.examples.internal.proto.examplepb.StreamService.Download:input_type -> grpc.gateway.examples.internal.proto.examplepb.Options - 3, // 4: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkCreate:output_type -> google.protobuf.Empty - 1, // 5: grpc.gateway.examples.internal.proto.examplepb.StreamService.List:output_type -> grpc.gateway.examples.internal.proto.examplepb.ABitOfEverything - 2, // 6: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkEcho:output_type -> grpc.gateway.examples.internal.proto.sub.StringMessage - 4, // 7: grpc.gateway.examples.internal.proto.examplepb.StreamService.Download:output_type -> google.api.HttpBody - 4, // [4:8] is the sub-list for method output_type - 0, // [0:4] is the sub-list for method input_type + 3, // 3: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkEchoDuration:input_type -> google.protobuf.Duration + 0, // 4: grpc.gateway.examples.internal.proto.examplepb.StreamService.Download:input_type -> grpc.gateway.examples.internal.proto.examplepb.Options + 4, // 5: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkCreate:output_type -> google.protobuf.Empty + 1, // 6: grpc.gateway.examples.internal.proto.examplepb.StreamService.List:output_type -> grpc.gateway.examples.internal.proto.examplepb.ABitOfEverything + 2, // 7: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkEcho:output_type -> grpc.gateway.examples.internal.proto.sub.StringMessage + 3, // 8: grpc.gateway.examples.internal.proto.examplepb.StreamService.BulkEchoDuration:output_type -> google.protobuf.Duration + 5, // 9: grpc.gateway.examples.internal.proto.examplepb.StreamService.Download:output_type -> google.api.HttpBody + 5, // [5:10] is the sub-list for method output_type + 0, // [0:5] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name diff --git a/examples/internal/proto/examplepb/stream.pb.gw.go b/examples/internal/proto/examplepb/stream.pb.gw.go index 9a419c81512..b86d0902d82 100644 --- a/examples/internal/proto/examplepb/stream.pb.gw.go +++ b/examples/internal/proto/examplepb/stream.pb.gw.go @@ -22,6 +22,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" ) // Suppress "imported and not used" errors @@ -151,6 +152,53 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar return stream, metadata, errChan, nil } +func request_StreamService_BulkEchoDuration_0(ctx context.Context, marshaler runtime.Marshaler, client StreamServiceClient, req *http.Request, pathParams map[string]string) (StreamService_BulkEchoDurationClient, runtime.ServerMetadata, chan error, error) { + var metadata runtime.ServerMetadata + errChan := make(chan error, 1) + stream, err := client.BulkEchoDuration(ctx) + if err != nil { + grpclog.Errorf("Failed to start streaming: %v", err) + close(errChan) + return nil, metadata, errChan, err + } + dec := marshaler.NewDecoder(req.Body) + handleSend := func() error { + var protoReq durationpb.Duration + err := dec.Decode(&protoReq) + if err == io.EOF { + return err + } + if err != nil { + grpclog.Errorf("Failed to decode request: %v", err) + return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err) + } + if err := stream.Send(&protoReq); err != nil { + grpclog.Errorf("Failed to send request: %v", err) + return err + } + return nil + } + go func() { + defer close(errChan) + for { + if err := handleSend(); err != nil { + errChan <- err + break + } + } + if err := stream.CloseSend(); err != nil { + grpclog.Errorf("Failed to terminate client stream: %v", err) + } + }() + header, err := stream.Header() + if err != nil { + grpclog.Errorf("Failed to get header from client: %v", err) + return nil, metadata, errChan, err + } + metadata.HeaderMD = header + return stream, metadata, errChan, nil +} + var ( filter_StreamService_Download_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} ) @@ -207,6 +255,13 @@ func RegisterStreamServiceHandlerServer(ctx context.Context, mux *runtime.ServeM return }) + mux.Handle("POST", pattern_StreamService_BulkEchoDuration_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport") + _, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + }) + mux.Handle("GET", pattern_StreamService_Download_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport") _, outboundMarshaler := runtime.MarshalerForRequest(mux, req) @@ -320,7 +375,7 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM go func() { for err := range reqErrChan { if err != nil && err != io.EOF { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + runtime.HTTPStreamError(annotatedContext, mux, outboundMarshaler, w, req, err) } } }() @@ -329,6 +384,36 @@ func RegisterStreamServiceHandlerClient(ctx context.Context, mux *runtime.ServeM }) + mux.Handle("POST", pattern_StreamService_BulkEchoDuration_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + var err error + var annotatedContext context.Context + annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/grpc.gateway.examples.internal.proto.examplepb.StreamService/BulkEchoDuration", runtime.WithHTTPPathPattern("/v1/example/a_bit_of_everything/echo_duration")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + resp, md, reqErrChan, err := request_StreamService_BulkEchoDuration_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + go func() { + for err := range reqErrChan { + if err != nil && err != io.EOF { + runtime.HTTPStreamError(annotatedContext, mux, outboundMarshaler, w, req, err) + } + } + }() + + forward_StreamService_BulkEchoDuration_0(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...) + + }) + mux.Handle("GET", pattern_StreamService_Download_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -361,6 +446,8 @@ var ( pattern_StreamService_BulkEcho_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "echo"}, "")) + pattern_StreamService_BulkEchoDuration_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "echo_duration"}, "")) + pattern_StreamService_Download_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "download"}, "")) ) @@ -371,5 +458,7 @@ var ( forward_StreamService_BulkEcho_0 = runtime.ForwardResponseStream + forward_StreamService_BulkEchoDuration_0 = runtime.ForwardResponseStream + forward_StreamService_Download_0 = runtime.ForwardResponseStream ) diff --git a/examples/internal/proto/examplepb/stream.proto b/examples/internal/proto/examplepb/stream.proto index 8d1a58db8bd..5add38d0492 100644 --- a/examples/internal/proto/examplepb/stream.proto +++ b/examples/internal/proto/examplepb/stream.proto @@ -6,6 +6,7 @@ import "examples/internal/proto/examplepb/a_bit_of_everything.proto"; import "examples/internal/proto/sub/message.proto"; import "google/api/annotations.proto"; import "google/api/httpbody.proto"; +import "google/protobuf/duration.proto"; import "google/protobuf/empty.proto"; option go_package = "github.com/grpc-ecosystem/grpc-gateway/v2/examples/internal/proto/examplepb"; @@ -27,6 +28,12 @@ service StreamService { body: "*" }; } + rpc BulkEchoDuration(stream google.protobuf.Duration) returns (stream google.protobuf.Duration) { + option (google.api.http) = { + post: "/v1/example/a_bit_of_everything/echo_duration" + body: "*" + }; + } rpc Download(Options) returns (stream google.api.HttpBody) { option (google.api.http) = {get: "/v1/example/download"}; diff --git a/examples/internal/proto/examplepb/stream.swagger.json b/examples/internal/proto/examplepb/stream.swagger.json index 0b521e59415..42cf40cf197 100644 --- a/examples/internal/proto/examplepb/stream.swagger.json +++ b/examples/internal/proto/examplepb/stream.swagger.json @@ -131,6 +131,46 @@ ] } }, + "/v1/example/a_bit_of_everything/echo_duration": { + "post": { + "operationId": "StreamService_BulkEchoDuration", + "responses": { + "200": { + "description": "A successful response.(streaming responses)", + "schema": { + "type": "object", + "properties": { + "result": {}, + "error": { + "$ref": "#/definitions/rpcStatus" + } + }, + "title": "Stream result of protobufDuration" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "body", + "description": " (streaming inputs)", + "in": "body", + "required": true, + "schema": { + "type": "string" + } + } + ], + "tags": [ + "StreamService" + ] + } + }, "/v1/example/download": { "get": { "operationId": "StreamService_Download", diff --git a/examples/internal/proto/examplepb/stream_grpc.pb.go b/examples/internal/proto/examplepb/stream_grpc.pb.go index 8bfc29c7f64..72b3509d6e4 100644 --- a/examples/internal/proto/examplepb/stream_grpc.pb.go +++ b/examples/internal/proto/examplepb/stream_grpc.pb.go @@ -13,6 +13,7 @@ import ( grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" + durationpb "google.golang.org/protobuf/types/known/durationpb" emptypb "google.golang.org/protobuf/types/known/emptypb" ) @@ -22,10 +23,11 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - StreamService_BulkCreate_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/BulkCreate" - StreamService_List_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/List" - StreamService_BulkEcho_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/BulkEcho" - StreamService_Download_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/Download" + StreamService_BulkCreate_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/BulkCreate" + StreamService_List_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/List" + StreamService_BulkEcho_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/BulkEcho" + StreamService_BulkEchoDuration_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/BulkEchoDuration" + StreamService_Download_FullMethodName = "/grpc.gateway.examples.internal.proto.examplepb.StreamService/Download" ) // StreamServiceClient is the client API for StreamService service. @@ -37,6 +39,7 @@ type StreamServiceClient interface { BulkCreate(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[ABitOfEverything, emptypb.Empty], error) List(ctx context.Context, in *Options, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ABitOfEverything], error) BulkEcho(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[sub.StringMessage, sub.StringMessage], error) + BulkEchoDuration(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[durationpb.Duration, durationpb.Duration], error) Download(ctx context.Context, in *Options, opts ...grpc.CallOption) (grpc.ServerStreamingClient[httpbody.HttpBody], error) } @@ -93,9 +96,22 @@ func (c *streamServiceClient) BulkEcho(ctx context.Context, opts ...grpc.CallOpt // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. type StreamService_BulkEchoClient = grpc.BidiStreamingClient[sub.StringMessage, sub.StringMessage] +func (c *streamServiceClient) BulkEchoDuration(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[durationpb.Duration, durationpb.Duration], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &StreamService_ServiceDesc.Streams[3], StreamService_BulkEchoDuration_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[durationpb.Duration, durationpb.Duration]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type StreamService_BulkEchoDurationClient = grpc.BidiStreamingClient[durationpb.Duration, durationpb.Duration] + func (c *streamServiceClient) Download(ctx context.Context, in *Options, opts ...grpc.CallOption) (grpc.ServerStreamingClient[httpbody.HttpBody], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &StreamService_ServiceDesc.Streams[3], StreamService_Download_FullMethodName, cOpts...) + stream, err := c.cc.NewStream(ctx, &StreamService_ServiceDesc.Streams[4], StreamService_Download_FullMethodName, cOpts...) if err != nil { return nil, err } @@ -121,6 +137,7 @@ type StreamServiceServer interface { BulkCreate(grpc.ClientStreamingServer[ABitOfEverything, emptypb.Empty]) error List(*Options, grpc.ServerStreamingServer[ABitOfEverything]) error BulkEcho(grpc.BidiStreamingServer[sub.StringMessage, sub.StringMessage]) error + BulkEchoDuration(grpc.BidiStreamingServer[durationpb.Duration, durationpb.Duration]) error Download(*Options, grpc.ServerStreamingServer[httpbody.HttpBody]) error } @@ -140,6 +157,9 @@ func (UnimplementedStreamServiceServer) List(*Options, grpc.ServerStreamingServe func (UnimplementedStreamServiceServer) BulkEcho(grpc.BidiStreamingServer[sub.StringMessage, sub.StringMessage]) error { return status.Errorf(codes.Unimplemented, "method BulkEcho not implemented") } +func (UnimplementedStreamServiceServer) BulkEchoDuration(grpc.BidiStreamingServer[durationpb.Duration, durationpb.Duration]) error { + return status.Errorf(codes.Unimplemented, "method BulkEchoDuration not implemented") +} func (UnimplementedStreamServiceServer) Download(*Options, grpc.ServerStreamingServer[httpbody.HttpBody]) error { return status.Errorf(codes.Unimplemented, "method Download not implemented") } @@ -188,6 +208,13 @@ func _StreamService_BulkEcho_Handler(srv interface{}, stream grpc.ServerStream) // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. type StreamService_BulkEchoServer = grpc.BidiStreamingServer[sub.StringMessage, sub.StringMessage] +func _StreamService_BulkEchoDuration_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(StreamServiceServer).BulkEchoDuration(&grpc.GenericServerStream[durationpb.Duration, durationpb.Duration]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type StreamService_BulkEchoDurationServer = grpc.BidiStreamingServer[durationpb.Duration, durationpb.Duration] + func _StreamService_Download_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(Options) if err := stream.RecvMsg(m); err != nil { @@ -223,6 +250,12 @@ var StreamService_ServiceDesc = grpc.ServiceDesc{ ServerStreams: true, ClientStreams: true, }, + { + StreamName: "BulkEchoDuration", + Handler: _StreamService_BulkEchoDuration_Handler, + ServerStreams: true, + ClientStreams: true, + }, { StreamName: "Download", Handler: _StreamService_Download_Handler, diff --git a/examples/internal/server/a_bit_of_everything.go b/examples/internal/server/a_bit_of_everything.go index b2583baccee..fbf5847b146 100644 --- a/examples/internal/server/a_bit_of_everything.go +++ b/examples/internal/server/a_bit_of_everything.go @@ -325,6 +325,58 @@ func (s *_ABitOfEverythingServer) BulkEcho(stream examples.StreamService_BulkEch return nil } +func (s *_ABitOfEverythingServer) BulkEchoDuration(stream examples.StreamService_BulkEchoDurationServer) error { + hmd := metadata.New(map[string]string{ + "foo": "foo1", + "bar": "bar1", + }) + if err := stream.SendHeader(hmd); err != nil { + return err + } + + // Channel to coordinate between read and write goroutines + msgChan := make(chan *durationpb.Duration) + errChan := make(chan error) + + go func() { + defer close(msgChan) + for { + msg, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + errChan <- err + return + } + msgChan <- msg + } + }() + + go func() { + for msg := range msgChan { + grpclog.Info(msg) + if err := stream.Send(msg); err != nil { + errChan <- err + return + } + } + // Sleep to mock the delay in receiving the request close. + // Accommodates the integration test client which is not a true + // bidirectional streaming client that supports request streaming. + time.Sleep(1 * time.Second) + close(errChan) + }() + + err := <-errChan + + stream.SetTrailer(metadata.New(map[string]string{ + "foo": "foo2", + "bar": "bar2", + })) + return err +} + func (s *_ABitOfEverythingServer) DeepPathEcho(ctx context.Context, msg *examples.ABitOfEverything) (*examples.ABitOfEverything, error) { s.m.Lock() defer s.m.Unlock() diff --git a/protoc-gen-grpc-gateway/internal/gengateway/template.go b/protoc-gen-grpc-gateway/internal/gengateway/template.go index 7f2924e13b1..937ce682bdb 100644 --- a/protoc-gen-grpc-gateway/internal/gengateway/template.go +++ b/protoc-gen-grpc-gateway/internal/gengateway/template.go @@ -747,7 +747,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, go func() { for err := range reqErrChan { if err != nil && err != io.EOF { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + runtime.HTTPStreamError(annotatedContext, mux, outboundMarshaler, w, req, err) } } }() diff --git a/runtime/errors.go b/runtime/errors.go index 01f57341918..41cd4f5030e 100644 --- a/runtime/errors.go +++ b/runtime/errors.go @@ -81,6 +81,21 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R mux.errorHandler(ctx, mux, marshaler, w, r, err) } +// HTTPStreamError uses the mux-configured stream error handler to notify error to the client without closing the connection. +func HTTPStreamError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { + st := mux.streamErrorHandler(ctx, err) + msg := errorChunk(st) + buf, err := marshaler.Marshal(msg) + if err != nil { + grpclog.Errorf("Failed to marshal an error: %v", err) + return + } + if _, err := w.Write(buf); err != nil { + grpclog.Errorf("Failed to notify error to client: %v", err) + return + } +} + // DefaultHTTPErrorHandler is the default error handler. // If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode. // If "err" is a HTTPStatusError, the function replies with the status code provide by that struct. This is diff --git a/runtime/errors_test.go b/runtime/errors_test.go index 39bcc46ffee..c5eba00e21b 100644 --- a/runtime/errors_test.go +++ b/runtime/errors_test.go @@ -1,6 +1,7 @@ package runtime_test import ( + "bytes" "context" "errors" "net/http" @@ -132,3 +133,52 @@ func TestDefaultHTTPError(t *testing.T) { }) } } + +func TestHTTPStreamError(t *testing.T) { + ctx := context.Background() + + for _, tc := range []struct { + name string + err error + expectedStatus *status.Status + expectedResponse []byte + }{ + { + name: "Simple error", + err: errors.New("simple error"), + expectedStatus: status.New(codes.Unknown, "simple error"), + expectedResponse: []byte(`{"error":{"code":2,"message":"simple error"}}`), + }, + { + name: "Invalid request error", + err: status.Error(codes.InvalidArgument, "invalid request"), + expectedStatus: status.New(codes.InvalidArgument, "invalid request"), + expectedResponse: []byte(`{"error":{"code":3,"message":"invalid request"}}`), + }, + } { + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + mux := runtime.NewServeMux(runtime.WithStreamErrorHandler( + runtime.DefaultStreamErrorHandler, + )) + + marshaler := &runtime.JSONPb{} + + runtime.HTTPStreamError(ctx, mux, marshaler, w, r, tc.err) + + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) + } + + if !proto.Equal(status.Convert(tc.err).Proto(), tc.expectedStatus.Proto()) { + t.Errorf("Expected status %v, got %v", tc.expectedStatus, status.Convert(tc.err)) + } + + if !bytes.Equal(w.Body.Bytes(), tc.expectedResponse) { + t.Errorf("Expected response %s, got %s", tc.expectedResponse, w.Body.Bytes()) + } + }) + } +}