From 8eb1d0f0d6826c96871697a7e68b50746edd275d Mon Sep 17 00:00:00 2001 From: Will Plano <51954103+william-plano-oxb@users.noreply.github.com> Date: Mon, 30 Sep 2019 10:34:02 +0100 Subject: [PATCH] Support json names in field mask generation (#1050) * Support json names in field mask generation Added tests for using jsonpb with OrigName: false Added benchmarks for FieldMaskFromRequestBody called with a message descriptor * Add descriptor dependency to go_proto_library --- docs/_docs/patch.md | 2 - examples/integration/BUILD.bazel | 2 + examples/integration/fieldmask_test.go | 163 ++++++++++++++++++ examples/integration/integration_test.go | 25 +++ examples/proto/examplepb/BUILD.bazel | 2 + .../examplepb/a_bit_of_everything.pb.gw.go | 9 +- .../proto/examplepb/echo_service.pb.gw.go | 3 + .../proto/examplepb/flow_combination.pb.gw.go | 3 + .../examplepb/non_standard_names.pb.gw.go | 15 +- .../examplepb/response_body_service.pb.gw.go | 3 + examples/proto/examplepb/stream.pb.gw.go | 3 + .../unannotated_echo_service.pb.gw.go | 3 + examples/proto/examplepb/wrappers.pb.gw.go | 3 + .../gengateway/generator.go | 1 + .../gengateway/template.go | 17 +- runtime/BUILD.bazel | 2 + runtime/fieldmask.go | 39 ++++- runtime/fieldmask_test.go | 6 +- 18 files changed, 285 insertions(+), 16 deletions(-) create mode 100644 examples/integration/fieldmask_test.go diff --git a/docs/_docs/patch.md b/docs/_docs/patch.md index 42b96eab25f..2eaac3b0166 100644 --- a/docs/_docs/patch.md +++ b/docs/_docs/patch.md @@ -10,8 +10,6 @@ There are two scenarios: - The FieldMask is hidden from the REST request as per the [Google API design guide](https://cloud.google.com/apis/design/standard_methods#update) (as in the first additional binding in the [UpdateV2](https://github.com/grpc-ecosystem/grpc-gateway/blob/master/examples/proto/examplepb/a_bit_of_everything.proto#L366) example). In this case, the FieldMask is updated from the request body and set in the gRPC request message. - The FieldMask is exposed to the REST request (as in the second additional binding in the [UpdateV2](https://github.com/grpc-ecosystem/grpc-gateway/blob/master/examples/proto/examplepb/a_bit_of_everything.proto#L370) example). For this case, the field mask is left untouched by the gateway. -Currently this feature is not compatible with using the JSON marshaller option to use json names in the http request/response described [here](https://grpc-ecosystem.github.io/grpc-gateway/docs/customizingyourgateway.html#using-camelcase-for-json). If you want to use json names then you may with to disable this feature with the command line option `-allow_patch_feature=false` - ## Example Usage 1. Create PATCH request. diff --git a/examples/integration/BUILD.bazel b/examples/integration/BUILD.bazel index b37ad78ec78..cee9f33f7f8 100644 --- a/examples/integration/BUILD.bazel +++ b/examples/integration/BUILD.bazel @@ -4,6 +4,7 @@ go_test( name = "go_default_test", srcs = [ "client_test.go", + "fieldmask_test.go", "integration_test.go", "main_test.go", "proto_error_test.go", @@ -19,6 +20,7 @@ go_test( "//examples/server:go_default_library", "//runtime:go_default_library", "@com_github_golang_glog//:go_default_library", + "@com_github_golang_protobuf//descriptor:go_default_library_gen", "@com_github_golang_protobuf//jsonpb:go_default_library_gen", "@com_github_golang_protobuf//proto:go_default_library", "@go_googleapis//google/rpc:status_go_proto", diff --git a/examples/integration/fieldmask_test.go b/examples/integration/fieldmask_test.go new file mode 100644 index 00000000000..2ebe44ff05c --- /dev/null +++ b/examples/integration/fieldmask_test.go @@ -0,0 +1,163 @@ +package integration_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/golang/protobuf/descriptor" + "github.com/grpc-ecosystem/grpc-gateway/examples/proto/examplepb" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "google.golang.org/genproto/protobuf/field_mask" +) + +func fieldMasksEqual(fm1, fm2 *field_mask.FieldMask) bool { + if fm1 == nil && fm2 == nil { + return true + } + if fm1 == nil || fm2 == nil { + return false + } + if len(fm1.GetPaths()) != len(fm2.GetPaths()) { + return false + } + + paths := make(map[string]bool) + for _, path := range fm1.GetPaths() { + paths[path] = true + } + for _, path := range fm2.GetPaths() { + if _, ok := paths[path]; !ok { + return false + } + } + + return true +} + +func newFieldMask(paths ...string) *field_mask.FieldMask { + return &field_mask.FieldMask{Paths: paths} +} + +func fieldMaskString(fm *field_mask.FieldMask) string { + if fm == nil { + return "" + } + return fmt.Sprintf("%v", fm.GetPaths()) +} + +// N.B. These tests are here rather than in the runtime package because they need +// to import examplepb for the descriptor, which would result in a circular +// dependency since examplepb imports runtime from the pb.gw.go files +func TestFieldMaskFromRequestBodyWithDescriptor(t *testing.T) { + _, md := descriptor.ForMessage(new(examplepb.NonStandardMessage)) + jsonInput := `{"id":"foo", "thing":{"subThing":{"sub_value":"bar"}}}` + expected := newFieldMask("id", "thing.subThing.sub_value") + + actual, err := runtime.FieldMaskFromRequestBody(bytes.NewReader([]byte(jsonInput)), md) + if !fieldMasksEqual(actual, expected) { + t.Errorf("want %v; got %v", fieldMaskString(expected), fieldMaskString(actual)) + } + if err != nil { + t.Errorf("err %v", err) + } +} + +func TestFieldMaskFromRequestBodyWithJsonNames(t *testing.T) { + _, md := descriptor.ForMessage(new(examplepb.NonStandardMessageWithJSONNames)) + jsonInput := `{"ID":"foo", "Thingy":{"SubThing":{"sub_Value":"bar"}}}` + expected := newFieldMask("id", "thing.subThing.sub_value") + + actual, err := runtime.FieldMaskFromRequestBody(bytes.NewReader([]byte(jsonInput)), md) + if !fieldMasksEqual(actual, expected) { + t.Errorf("want %v; got %v", fieldMaskString(expected), fieldMaskString(actual)) + } + if err != nil { + t.Errorf("err %v", err) + } +} + +// avoid compiler optimising benchmark away +var result *field_mask.FieldMask + +func BenchmarkABEFieldMaskFromRequestBodyWithDescriptor(b *testing.B) { + _, md := descriptor.ForMessage(new(examplepb.ABitOfEverything)) + input := `{` + + `"single_nested": {"name": "bar",` + + ` "amount": 10,` + + ` "ok": "TRUE"},` + + `"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",` + + `"nested": [{"name": "bar",` + + ` "amount": 10},` + + ` {"name": "baz",` + + ` "amount": 20}],` + + `"float_value": 1.5,` + + `"double_value": 2.5,` + + `"int64_value": 4294967296,` + + `"uint64_value": 9223372036854775807,` + + `"int32_value": -2147483648,` + + `"fixed64_value": 9223372036854775807,` + + `"fixed32_value": 4294967295,` + + `"bool_value": true,` + + `"string_value": "strprefix/foo",` + + `"bytes_value": "132456",` + + `"uint32_value": 4294967295,` + + `"enum_value": "ONE",` + + `"path_enum_value": "DEF",` + + `"nested_path_enum_value": "JKL",` + + `"sfixed32_value": 2147483647,` + + `"sfixed64_value": -4611686018427387904,` + + `"sint32_value": 2147483647,` + + `"sint64_value": 4611686018427387903,` + + `"repeated_string_value": ["a", "b", "c"],` + + `"oneof_value": {"oneof_string":"x"},` + + `"map_value": {"a": "ONE",` + + ` "b": "ZERO"},` + + `"mapped_string_value": {"a": "x",` + + ` "b": "y"},` + + `"mapped_nested_value": {"a": {"name": "x", "amount": 1},` + + ` "b": {"name": "y", "amount": 2}},` + + `"nonConventionalNameValue": "camelCase",` + + `"timestamp_value": "2016-05-10T10:19:13.123Z",` + + `"repeated_enum_value": ["ONE", "ZERO"],` + + `"repeated_enum_annotation": ["ONE", "ZERO"],` + + `"enum_value_annotation": "ONE",` + + `"repeated_string_annotation": ["a", "b"],` + + `"repeated_nested_annotation": [{"name": "hoge",` + + ` "amount": 10},` + + ` {"name": "fuga",` + + ` "amount": 20}],` + + `"nested_annotation": {"name": "hoge",` + + ` "amount": 10},` + + `"int64_override_type": 12345` + + `}` + var r *field_mask.FieldMask + var err error + for i := 0; i < b.N; i++ { + r, err = runtime.FieldMaskFromRequestBody(bytes.NewReader([]byte(input)), md) + } + if err != nil { + b.Error(err) + } + result = r +} + +func BenchmarkNonStandardFieldMaskFromRequestBodyWithDescriptor(b *testing.B) { + _, md := descriptor.ForMessage(new(examplepb.NonStandardMessage)) + input := `{` + + `"id": "foo",` + + `"Num": 2,` + + `"line_num": 3,` + + `"langIdent": "bar",` + + `"STATUS": "baz"` + + `}` + var r *field_mask.FieldMask + var err error + for i := 0; i < b.N; i++ { + r, err = runtime.FieldMaskFromRequestBody(bytes.NewReader([]byte(input)), md) + } + if err != nil { + b.Error(err) + } + result = r +} diff --git a/examples/integration/integration_test.go b/examples/integration/integration_test.go index 939fecde84a..febbb75b023 100644 --- a/examples/integration/integration_test.go +++ b/examples/integration/integration_test.go @@ -1643,10 +1643,23 @@ func TestNonStandardNames(t *testing.T) { return } }() + go func() { + if err := runGateway( + ctx, + ":8082", + runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{OrigName: false, EmitDefaults: true}), + ); err != nil { + t.Errorf("runGateway() failed with %v; want success", err) + return + } + }() if err := waitForGateway(ctx, 8081); err != nil { t.Errorf("waitForGateway(ctx, 8081) failed with %v; want success", err) } + if err := waitForGateway(ctx, 8082); err != nil { + t.Errorf("waitForGateway(ctx, 8082) failed with %v; want success", err) + } for _, tc := range []struct { name string @@ -1667,6 +1680,18 @@ func TestNonStandardNames(t *testing.T) { // N.B. json_names have no effect if not using OrigName: false `{"id":"foo","Num":"1","line_num":"42","langIdent":"English","STATUS":"good","en_GB":"1","no":"yes","thing":{"subThing":{"sub_value":"hi"}}}`, }, + { + "Test standard update method with OrigName: false marshaller option", + 8082, + "update", + `{"id":"foo","Num":"1","lineNum":"42","langIdent":"English","STATUS":"good","enGB":"1","no":"yes","thing":{"subThing":{"subValue":"hi"}}}`, + }, + { + "Test update method using json_names in message with OrigName: false marshaller option", + 8082, + "update_with_json_names", + `{"ID":"foo","Num":"1","LineNum":"42","langIdent":"English","status":"good","En_GB":"1","yes":"no","Thingy":{"SubThing":{"sub_Value":"hi"}}}`, + }, } { t.Run(tc.name, func(t *testing.T) { testNonStandardNames(t, tc.port, tc.method, tc.jsonBody) diff --git a/examples/proto/examplepb/BUILD.bazel b/examples/proto/examplepb/BUILD.bazel index 4888a802ff9..62a3a147115 100644 --- a/examples/proto/examplepb/BUILD.bazel +++ b/examples/proto/examplepb/BUILD.bazel @@ -51,6 +51,7 @@ go_proto_library( "//examples/proto/sub:go_default_library", "//examples/proto/sub2:go_default_library", "//protoc-gen-swagger/options:go_default_library", + "@com_github_golang_protobuf//descriptor:go_default_library_gen", # keep "@go_googleapis//google/api:annotations_go_proto", ], ) @@ -63,6 +64,7 @@ go_library( deps = [ "//runtime:go_default_library", "//utilities:go_default_library", + "@com_github_golang_protobuf//descriptor:go_default_library_gen", "@com_github_golang_protobuf//proto:go_default_library", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//codes:go_default_library", diff --git a/examples/proto/examplepb/a_bit_of_everything.pb.gw.go b/examples/proto/examplepb/a_bit_of_everything.pb.gw.go index c6aa89f505c..eaa39346919 100644 --- a/examples/proto/examplepb/a_bit_of_everything.pb.gw.go +++ b/examples/proto/examplepb/a_bit_of_everything.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes/empty" "github.com/grpc-ecosystem/grpc-gateway/examples/proto/pathenum" @@ -26,11 +27,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage var ( filter_ABitOfEverythingService_Create_0 = &utilities.DoubleArray{Encoding: map[string]int{"float_value": 0, "double_value": 1, "int64_value": 2, "uint64_value": 3, "int32_value": 4, "fixed64_value": 5, "fixed32_value": 6, "bool_value": 7, "string_value": 8, "uint32_value": 9, "sfixed32_value": 10, "sfixed64_value": 11, "sint32_value": 12, "sint64_value": 13, "nonConventionalNameValue": 14, "enum_value": 15, "path_enum_value": 16, "nested_path_enum_value": 17, "enum_value_annotation": 18}, Base: []int{1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Check: []int{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}} @@ -774,7 +777,8 @@ func request_ABitOfEverythingService_UpdateV2_1(ctx context.Context, marshaler r return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.Abe) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.UpdateMask = fieldMask @@ -823,7 +827,8 @@ func local_request_ABitOfEverythingService_UpdateV2_1(ctx context.Context, marsh return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.Abe) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.UpdateMask = fieldMask diff --git a/examples/proto/examplepb/echo_service.pb.gw.go b/examples/proto/examplepb/echo_service.pb.gw.go index a0445e04929..001f9661119 100644 --- a/examples/proto/examplepb/echo_service.pb.gw.go +++ b/examples/proto/examplepb/echo_service.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/grpc-ecosystem/grpc-gateway/utilities" @@ -22,11 +23,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage var ( filter_EchoService_Echo_0 = &utilities.DoubleArray{Encoding: map[string]int{"id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} diff --git a/examples/proto/examplepb/flow_combination.pb.gw.go b/examples/proto/examplepb/flow_combination.pb.gw.go index c69b4f1daae..14599e7a759 100644 --- a/examples/proto/examplepb/flow_combination.pb.gw.go +++ b/examples/proto/examplepb/flow_combination.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/grpc-ecosystem/grpc-gateway/utilities" @@ -22,11 +23,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage func request_FlowCombination_RpcEmptyRpc_0(ctx context.Context, marshaler runtime.Marshaler, client FlowCombinationClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var protoReq EmptyProto diff --git a/examples/proto/examplepb/non_standard_names.pb.gw.go b/examples/proto/examplepb/non_standard_names.pb.gw.go index 606554372cf..11cb0b98adc 100644 --- a/examples/proto/examplepb/non_standard_names.pb.gw.go +++ b/examples/proto/examplepb/non_standard_names.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/grpc-ecosystem/grpc-gateway/utilities" @@ -22,11 +23,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage var ( filter_NonStandardService_Update_0 = &utilities.DoubleArray{Encoding: map[string]int{"body": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} @@ -44,7 +47,8 @@ func request_NonStandardService_Update_0(ctx context.Context, marshaler runtime. return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.Body) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.UpdateMask = fieldMask @@ -75,7 +79,8 @@ func local_request_NonStandardService_Update_0(ctx context.Context, marshaler ru return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.Body) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.UpdateMask = fieldMask @@ -107,7 +112,8 @@ func request_NonStandardService_UpdateWithJSONNames_0(ctx context.Context, marsh return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.Body) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.UpdateMask = fieldMask @@ -138,7 +144,8 @@ func local_request_NonStandardService_UpdateWithJSONNames_0(ctx context.Context, return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.Body) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.UpdateMask = fieldMask diff --git a/examples/proto/examplepb/response_body_service.pb.gw.go b/examples/proto/examplepb/response_body_service.pb.gw.go index e90bba3f44c..86d9c4ed9be 100644 --- a/examples/proto/examplepb/response_body_service.pb.gw.go +++ b/examples/proto/examplepb/response_body_service.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/grpc-ecosystem/grpc-gateway/utilities" @@ -22,11 +23,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage func request_ResponseBodyService_GetResponseBody_0(ctx context.Context, marshaler runtime.Marshaler, client ResponseBodyServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var protoReq ResponseBodyIn diff --git a/examples/proto/examplepb/stream.pb.gw.go b/examples/proto/examplepb/stream.pb.gw.go index b2629f9a564..8b33ff5c40a 100644 --- a/examples/proto/examplepb/stream.pb.gw.go +++ b/examples/proto/examplepb/stream.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes/empty" "github.com/grpc-ecosystem/grpc-gateway/examples/proto/sub" @@ -24,11 +25,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage func request_StreamService_BulkCreate_0(ctx context.Context, marshaler runtime.Marshaler, client StreamServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var metadata runtime.ServerMetadata diff --git a/examples/proto/examplepb/unannotated_echo_service.pb.gw.go b/examples/proto/examplepb/unannotated_echo_service.pb.gw.go index e92b4af943c..27258d252ad 100644 --- a/examples/proto/examplepb/unannotated_echo_service.pb.gw.go +++ b/examples/proto/examplepb/unannotated_echo_service.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/grpc-ecosystem/grpc-gateway/utilities" @@ -22,11 +23,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage var ( filter_UnannotatedEchoService_Echo_0 = &utilities.DoubleArray{Encoding: map[string]int{"id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} diff --git a/examples/proto/examplepb/wrappers.pb.gw.go b/examples/proto/examplepb/wrappers.pb.gw.go index 57ff5f69962..edb9b10a020 100644 --- a/examples/proto/examplepb/wrappers.pb.gw.go +++ b/examples/proto/examplepb/wrappers.pb.gw.go @@ -13,6 +13,7 @@ import ( "io" "net/http" + "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes/empty" "github.com/golang/protobuf/ptypes/wrappers" @@ -24,11 +25,13 @@ import ( "google.golang.org/grpc/status" ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage func request_WrappersService_Create_0(ctx context.Context, marshaler runtime.Marshaler, client WrappersServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var protoReq Wrappers diff --git a/protoc-gen-grpc-gateway/gengateway/generator.go b/protoc-gen-grpc-gateway/gengateway/generator.go index 43943f95643..0b6bfbd2b93 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/gengateway/generator.go @@ -44,6 +44,7 @@ func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, p "net/http", "github.com/grpc-ecosystem/grpc-gateway/runtime", "github.com/grpc-ecosystem/grpc-gateway/utilities", + "github.com/golang/protobuf/descriptor", "github.com/golang/protobuf/proto", "google.golang.org/grpc", "google.golang.org/grpc/codes", diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index cd6b1f9e850..1d3d3ca8f19 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -2,6 +2,7 @@ package gengateway import ( "bytes" + "errors" "fmt" "strings" "text/template" @@ -34,6 +35,14 @@ func (b binding) GetBodyFieldPath() string { return "*" } +// GetBodyFieldPath returns the binding body's struct field name. +func (b binding) GetBodyFieldStructName() (string, error) { + if b.Body != nil && len(b.Body.FieldPath) != 0 { + return generator2.CamelCase(b.Body.FieldPath.String()), nil + } + return "", errors.New("No body field found") +} + // HasQueryParam determines if the binding needs parameters in query string. // // It sometimes returns true even though actually the binding does not need. @@ -224,11 +233,13 @@ import ( {{range $i := .Imports}}{{if not $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}} ) +// Suppress "imported and not used" errors var _ codes.Code var _ io.Reader var _ status.Status var _ = runtime.String var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage `)) handlerTemplate = template.Must(template.New("handler").Parse(` @@ -316,7 +327,8 @@ var ( } {{- if and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }} if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.{{.GetBodyFieldStructName}}) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.{{.FieldMaskField}} = fieldMask @@ -477,7 +489,8 @@ func local_request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ct } {{- if and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }} if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 { - if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader()); err != nil { + _, md := descriptor.ForMessage(protoReq.{{.GetBodyFieldStructName}}) + if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } else { protoReq.{{.FieldMaskField}} = fieldMask diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 2462dc053f7..819c45a7657 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -27,10 +27,12 @@ go_library( deps = [ "//internal:go_default_library", "//utilities:go_default_library", + "@com_github_golang_protobuf//descriptor:go_default_library_gen", "@com_github_golang_protobuf//jsonpb:go_default_library_gen", "@com_github_golang_protobuf//proto:go_default_library", "@go_googleapis//google/api:httpbody_go_proto", "@io_bazel_rules_go//proto/wkt:any_go_proto", + "@io_bazel_rules_go//proto/wkt:descriptor_go_proto", "@io_bazel_rules_go//proto/wkt:duration_go_proto", "@io_bazel_rules_go//proto/wkt:field_mask_go_proto", "@io_bazel_rules_go//proto/wkt:timestamp_go_proto", diff --git a/runtime/fieldmask.go b/runtime/fieldmask.go index d3eb5782b70..341aad5a3ea 100644 --- a/runtime/fieldmask.go +++ b/runtime/fieldmask.go @@ -5,11 +5,37 @@ import ( "io" "strings" + descriptor2 "github.com/golang/protobuf/descriptor" + "github.com/golang/protobuf/protoc-gen-go/descriptor" "google.golang.org/genproto/protobuf/field_mask" ) +func translateName(name string, md *descriptor.DescriptorProto) (string, *descriptor.DescriptorProto) { + // TODO - should really gate this with a test that the marshaller has used json names + if md != nil { + for _, f := range md.Field { + if f.JsonName != nil && f.Name != nil && *f.JsonName == name { + var subType *descriptor.DescriptorProto + + // If the field has a TypeName then we retrieve the nested type for translating the embedded message names. + if f.TypeName != nil { + typeSplit := strings.Split(*f.TypeName, ".") + typeName := typeSplit[len(typeSplit)-1] + for _, t := range md.NestedType { + if typeName == *t.Name { + subType = t + } + } + } + return *f.Name, subType + } + } + } + return name, nil +} + // FieldMaskFromRequestBody creates a FieldMask printing all complete paths from the JSON body. -func FieldMaskFromRequestBody(r io.Reader) (*field_mask.FieldMask, error) { +func FieldMaskFromRequestBody(r io.Reader, md *descriptor.DescriptorProto) (*field_mask.FieldMask, error) { fm := &field_mask.FieldMask{} var root interface{} if err := json.NewDecoder(r).Decode(&root); err != nil { @@ -19,7 +45,7 @@ func FieldMaskFromRequestBody(r io.Reader) (*field_mask.FieldMask, error) { return nil, err } - queue := []fieldMaskPathItem{{node: root}} + queue := []fieldMaskPathItem{{node: root, md: md}} for len(queue) > 0 { // dequeue an item item := queue[0] @@ -28,7 +54,11 @@ func FieldMaskFromRequestBody(r io.Reader) (*field_mask.FieldMask, error) { if m, ok := item.node.(map[string]interface{}); ok { // if the item is an object, then enqueue all of its children for k, v := range m { - queue = append(queue, fieldMaskPathItem{path: append(item.path, k), node: v}) + protoName, subMd := translateName(k, item.md) + if subMsg, ok := v.(descriptor2.Message); ok { + _, subMd = descriptor2.ForMessage(subMsg) + } + queue = append(queue, fieldMaskPathItem{path: append(item.path, protoName), node: v, md: subMd}) } } else if len(item.path) > 0 { // otherwise, it's a leaf node so print its path @@ -46,4 +76,7 @@ type fieldMaskPathItem struct { // a generic decoded json object the current item to inspect for further path extraction node interface{} + + // descriptor for parent message + md *descriptor.DescriptorProto } diff --git a/runtime/fieldmask_test.go b/runtime/fieldmask_test.go index dae968b8284..7a5ddaae957 100644 --- a/runtime/fieldmask_test.go +++ b/runtime/fieldmask_test.go @@ -56,7 +56,7 @@ func TestFieldMaskFromRequestBody(t *testing.T) { {name: "canonical", input: `{"f": {"b": {"d": 1, "x": 2}, "c": 1}}`, expected: newFieldMask("f.b.d", "f.b.x", "f.c")}, } { t.Run(tc.name, func(t *testing.T) { - actual, err := FieldMaskFromRequestBody(bytes.NewReader([]byte(tc.input))) + actual, err := FieldMaskFromRequestBody(bytes.NewReader([]byte(tc.input)), nil) if !fieldMasksEqual(actual, tc.expected) { t.Errorf("want %v; got %v", fieldMaskString(tc.expected), fieldMaskString(actual)) } @@ -123,7 +123,7 @@ func BenchmarkABEFieldMaskFromRequestBody(b *testing.B) { var r *field_mask.FieldMask var err error for i := 0; i < b.N; i++ { - r, err = FieldMaskFromRequestBody(bytes.NewReader([]byte(input))) + r, err = FieldMaskFromRequestBody(bytes.NewReader([]byte(input)), nil) } if err != nil { b.Error(err) @@ -142,7 +142,7 @@ func BenchmarkNonStandardFieldMaskFromRequestBody(b *testing.B) { var r *field_mask.FieldMask var err error for i := 0; i < b.N; i++ { - r, err = FieldMaskFromRequestBody(bytes.NewReader([]byte(input))) + r, err = FieldMaskFromRequestBody(bytes.NewReader([]byte(input)), nil) } if err != nil { b.Error(err)