Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jsonpb: change Marshal/Unmarshal to return error if any required field is not set #472

Merged
merged 7 commits into from
Jan 5, 2018
136 changes: 135 additions & 1 deletion jsonpb/jsonpb.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ func (m *Marshaler) Marshal(out io.Writer, pb proto.Message) error {
if pb == nil || (v.Kind() == reflect.Ptr && v.IsNil()) {
return errors.New("Marshal called with nil")
}
// Check for unset required fields first.
if err := checkRequiredFields(pb); err != nil {
return err
}
writer := &errWriter{writer: out}
return m.marshalObject(writer, pb, "", "")
}
Expand Down Expand Up @@ -636,7 +640,10 @@ func (u *Unmarshaler) UnmarshalNext(dec *json.Decoder, pb proto.Message) error {
if err := dec.Decode(&inputValue); err != nil {
return err
}
return u.unmarshalValue(reflect.ValueOf(pb).Elem(), inputValue, nil)
if err := u.unmarshalValue(reflect.ValueOf(pb).Elem(), inputValue, nil); err != nil {
return err
}
return checkRequiredFields(pb)
}

// Unmarshal unmarshals a JSON object stream into a protocol
Expand Down Expand Up @@ -1080,3 +1087,130 @@ func (s mapKeys) Less(i, j int) bool {
}
return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
}

// checkRequiredFields returns an error if any required field in the given proto message is not set.
// This function is used by both Marshal and Unmarshal. While required fields only exist in a
// proto2 message, a proto3 message can contain proto2 message(s).
func checkRequiredFields(pb proto.Message) error {
// Most well-known type messages do not contain required fields. The "Any" type may contain
// a message that has required fields.
//
// When an Any message is being marshaled, the code will invoked proto.Unmarshal on Any.Value
// field in order to transform that into JSON, and that should have returned an error if a
// required field is not set in the embedded message.
//
// When an Any message is being unmarshaled, the code will have invoked proto.Marshal on the
// embedded message to store the serialized message in Any.Value field, and that should have
// returned an error if a required field is not set.
if _, ok := pb.(wkt); ok {
return nil
}

v := reflect.ValueOf(pb)
// Skip message if it is not a struct pointer.
if v.Kind() != reflect.Ptr {
return nil
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return nil
}

for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
sfield := v.Type().Field(i)
if strings.HasPrefix(sfield.Name, "XXX_") {
continue
}

// Oneof field is an interface implemented by wrapper structs containing the actual oneof
// field, i.e. an interface containing &T{real_value}.
if sfield.Tag.Get("protobuf_oneof") != "" {
if field.Kind() != reflect.Interface {
continue
}
v := field.Elem()
if v.Kind() != reflect.Ptr || v.IsNil() {
continue
}
v = v.Elem()
if v.Kind() != reflect.Struct || v.NumField() < 1 {
continue
}
field = v.Field(0)
sfield = v.Type().Field(0)
}

var prop proto.Properties
prop.Init(sfield.Type, sfield.Name, sfield.Tag.Get("protobuf"), &sfield)

switch field.Kind() {
case reflect.Map:
if field.IsNil() {
continue
}
// Check each map value.
keys := field.MapKeys()
for _, k := range keys {
v := field.MapIndex(k)
if err := checkRequiredFieldsInValue(v); err != nil {
return err
}
}
case reflect.Slice:
// Handle non-repeated type, e.g. bytes.
if !prop.Repeated {
if prop.Required && field.IsNil() {
return fmt.Errorf("required field %q is not set", prop.Name)
}
continue
}

// Handle repeated type.
if field.IsNil() {
continue
}
// Check each slice item.
for i := 0; i < field.Len(); i++ {
v := field.Index(i)
if err := checkRequiredFieldsInValue(v); err != nil {
return err
}
}
case reflect.Ptr:
if field.IsNil() {
if prop.Required {
return fmt.Errorf("required field %q is not set", prop.Name)
}
continue
}
if err := checkRequiredFieldsInValue(field); err != nil {
return err
}
}
}

// Handle proto2 extensions.
for _, ext := range proto.RegisteredExtensions(pb) {
if !proto.HasExtension(pb, ext) {
continue
}
ep, err := proto.GetExtension(pb, ext)
if err != nil {
return err
}
err = checkRequiredFieldsInValue(reflect.ValueOf(ep))
if err != nil {
return err
}
}

return nil
}

func checkRequiredFieldsInValue(v reflect.Value) error {
if pm, ok := v.Interface().(proto.Message); ok {
return checkRequiredFields(pm)
}
return nil
}
197 changes: 197 additions & 0 deletions jsonpb/jsonpb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ var marshalingTests = []struct {
{"BoolValue", marshaler, &pb.KnownTypes{Bool: &wpb.BoolValue{Value: true}}, `{"bool":true}`},
{"StringValue", marshaler, &pb.KnownTypes{Str: &wpb.StringValue{Value: "plush"}}, `{"str":"plush"}`},
{"BytesValue", marshaler, &pb.KnownTypes{Bytes: &wpb.BytesValue{Value: []byte("wow")}}, `{"bytes":"d293"}`},

{"required", marshaler, &pb.MsgWithRequired{Str: proto.String("hello")}, `{"str":"hello"}`},
{"required bytes", marshaler, &pb.MsgWithRequiredBytes{Byts: []byte{}}, `{"byts":""}`},
}

func TestMarshaling(t *testing.T) {
Expand Down Expand Up @@ -500,6 +503,91 @@ func TestMarshalAnyJSONPBMarshaler(t *testing.T) {
}
}

// Test marshaling message containing unset required fields should produce error.
func TestMarshalUnsetRequiredFields(t *testing.T) {
msgExt := &pb.Real{}
proto.SetExtension(msgExt, pb.E_Extm, &pb.MsgWithRequired{})

tests := []struct {
desc string
marshaler *Marshaler
pb proto.Message
}{
{
desc: "direct required field",
marshaler: &Marshaler{},
pb: &pb.MsgWithRequired{},
},
{
desc: "direct required field + emit defaults",
marshaler: &Marshaler{EmitDefaults: true},
pb: &pb.MsgWithRequired{},
},
{
desc: "indirect required field",
marshaler: &Marshaler{},
pb: &pb.MsgWithIndirectRequired{Subm: &pb.MsgWithRequired{}},
},
{
desc: "indirect required field + emit defaults",
marshaler: &Marshaler{EmitDefaults: true},
pb: &pb.MsgWithIndirectRequired{Subm: &pb.MsgWithRequired{}},
},
{
desc: "direct required wkt field",
marshaler: &Marshaler{},
pb: &pb.MsgWithRequiredWKT{},
},
{
desc: "direct required wkt field + emit defaults",
marshaler: &Marshaler{EmitDefaults: true},
pb: &pb.MsgWithRequiredWKT{},
},
{
desc: "direct required bytes field",
marshaler: &Marshaler{},
pb: &pb.MsgWithRequiredBytes{},
},
{
desc: "required in map value",
marshaler: &Marshaler{},
pb: &pb.MsgWithIndirectRequired{
MapField: map[string]*pb.MsgWithRequired{
"key": {},
},
},
},
{
desc: "required in repeated item",
marshaler: &Marshaler{},
pb: &pb.MsgWithIndirectRequired{
SliceField: []*pb.MsgWithRequired{
{Str: proto.String("hello")},
{},
},
},
},
{
desc: "required inside oneof",
marshaler: &Marshaler{},
pb: &pb.MsgWithOneof{
Union: &pb.MsgWithOneof_MsgWithRequired{&pb.MsgWithRequired{}},
},
},
{
desc: "required inside extension",
marshaler: &Marshaler{},
pb: msgExt,
},
}

for _, tc := range tests {
if _, err := tc.marshaler.MarshalToString(tc.pb); err == nil {
t.Errorf("%s: expecting error in marshaling with unset required fields %+v", tc.desc, tc.pb)
}
}
}

var unmarshalingTests = []struct {
desc string
unmarshaler Unmarshaler
Expand Down Expand Up @@ -631,6 +719,9 @@ var unmarshalingTests = []struct {
{"null BoolValue", Unmarshaler{}, `{"bool":null}`, &pb.KnownTypes{Bool: nil}},
{"null StringValue", Unmarshaler{}, `{"str":null}`, &pb.KnownTypes{Str: nil}},
{"null BytesValue", Unmarshaler{}, `{"bytes":null}`, &pb.KnownTypes{Bytes: nil}},

{"required", Unmarshaler{}, `{"str":"hello"}`, &pb.MsgWithRequired{Str: proto.String("hello")}},
{"required bytes", Unmarshaler{}, `{"byts": []}`, &pb.MsgWithRequiredBytes{Byts: []byte{}}},
}

func TestUnmarshaling(t *testing.T) {
Expand Down Expand Up @@ -902,3 +993,109 @@ func (m *dynamicMessage) UnmarshalJSONPB(jum *Unmarshaler, js []byte) error {
m.rawJson = string(js)
return nil
}

// Test unmarshaling message containing unset required fields should produce error.
func TestUnmarshalUnsetRequiredFields(t *testing.T) {
tests := []struct {
desc string
pb proto.Message
json string
}{
{
desc: "direct required field missing",
pb: &pb.MsgWithRequired{},
json: `{}`,
},
{
desc: "direct required field set to null",
pb: &pb.MsgWithRequired{},
json: `{"str": null}`,
},
{
desc: "indirect required field missing",
pb: &pb.MsgWithIndirectRequired{},
json: `{"subm": {}}`,
},
{
desc: "indirect required field set to null",
pb: &pb.MsgWithIndirectRequired{},
json: `{"subm": {"str": null}}`,
},
{
desc: "direct required bytes field missing",
pb: &pb.MsgWithRequiredBytes{},
json: `{}`,
},
{
desc: "direct required bytes field set to null",
pb: &pb.MsgWithRequiredBytes{},
json: `{"byts": null}`,
},
{
desc: "direct required wkt field missing",
pb: &pb.MsgWithRequiredWKT{},
json: `{}`,
},
{
desc: "direct required wkt field set to null",
pb: &pb.MsgWithRequiredWKT{},
json: `{"str": null}`,
},
{
desc: "any containing message with required field set to null",
pb: &pb.KnownTypes{},
json: `{"an": {"@type": "example.com/jsonpb.MsgWithRequired", "str": null}}`,
},
{
desc: "any containing message with missing required field",
pb: &pb.KnownTypes{},
json: `{"an": {"@type": "example.com/jsonpb.MsgWithRequired"}}`,
},
{
desc: "missing required in map value",
pb: &pb.MsgWithIndirectRequired{},
json: `{"map_field": {"a": {}, "b": {"str": "hi"}}}`,
},
{
desc: "required in map value set to null",
pb: &pb.MsgWithIndirectRequired{},
json: `{"map_field": {"a": {"str": "hello"}, "b": {"str": null}}}`,
},
{
desc: "missing required in slice item",
pb: &pb.MsgWithIndirectRequired{},
json: `{"slice_field": [{}, {"str": "hi"}]}`,
},
{
desc: "required in slice item set to null",
pb: &pb.MsgWithIndirectRequired{},
json: `{"slice_field": [{"str": "hello"}, {"str": null}]}`,
},
{
desc: "required inside oneof missing",
pb: &pb.MsgWithOneof{},
json: `{"msgWithRequired": {}}`,
},
{
desc: "required inside oneof set to null",
pb: &pb.MsgWithOneof{},
json: `{"msgWithRequired": {"str": null}}`,
},
{
desc: "required field in extension missing",
pb: &pb.Real{},
json: `{"[jsonpb.extm]":{}}`,
},
{
desc: "required field in extension set to null",
pb: &pb.Real{},
json: `{"[jsonpb.extm]":{"str": null}}`,
},
}

for _, tc := range tests {
if err := UnmarshalString(tc.json, tc.pb); err == nil {
t.Errorf("%s: expecting error in unmarshaling with unset required fields %s", tc.desc, tc.json)
}
}
}
4 changes: 4 additions & 0 deletions jsonpb/jsonpb_test_proto/more_test_objects.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading