diff --git a/bson/bsonrw/extjson_reader.go b/bson/bsonrw/extjson_reader.go index dd560c96f6..b83012b21c 100644 --- a/bson/bsonrw/extjson_reader.go +++ b/bson/bsonrw/extjson_reader.go @@ -159,29 +159,18 @@ func (ejvr *extJSONValueReader) pop() { } } -func (ejvr *extJSONValueReader) skipDocument() error { - // read entire document until ErrEOD (using readKey and readValue) - _, typ, err := ejvr.p.readKey() - for err == nil { - _, err = ejvr.p.readValue(typ) - if err != nil { - break +func (ejvr *extJSONValueReader) skipObject() { + // read entire object until depth returns to 0 (last ending } or ] seen) + depth := 1 + for depth > 0 { + ejvr.p.advanceState() + switch ejvr.p.s { + case jpsSawBeginObject, jpsSawBeginArray: + depth++ + case jpsSawEndObject, jpsSawEndArray: + depth-- } - - _, typ, err = ejvr.p.readKey() } - - return err -} - -func (ejvr *extJSONValueReader) skipArray() error { - // read entire array until ErrEOA (using peekType) - _, err := ejvr.p.peekType() - for err == nil { - _, err = ejvr.p.peekType() - } - - return err } func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error { @@ -234,30 +223,9 @@ func (ejvr *extJSONValueReader) Skip() error { t := ejvr.stack[ejvr.frame].vType switch t { - case bsontype.Array: - // read entire array until ErrEOA - err := ejvr.skipArray() - if err != ErrEOA { - return err - } - case bsontype.EmbeddedDocument: - // read entire doc until ErrEOD - err := ejvr.skipDocument() - if err != ErrEOD { - return err - } - case bsontype.CodeWithScope: - // read the code portion and set up parser in document mode - _, err := ejvr.p.readValue(t) - if err != nil { - return err - } - - // read until ErrEOD - err = ejvr.skipDocument() - if err != ErrEOD { - return err - } + case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope: + // read entire array, doc or CodeWithScope + ejvr.skipObject() default: _, err := ejvr.p.readValue(t) if err != nil { diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 283ff5d450..d98f527737 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -166,3 +166,115 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { assert.Equal(t, int32(-1), *second.X, "expected X value to be -1, got %v", *second.X) }) } + +func TestUnmarshalExtJSONWithUndefinedField(t *testing.T) { + // When unmarshalling, fields that are undefined in the destination struct are skipped. + // This process must not skip other, defined fields and must not raise errors. + type expectedResponse struct { + DefinedField string + } + + unmarshalExpectedResponse := func(t *testing.T, extJSON string) *expectedResponse { + t.Helper() + responseDoc := expectedResponse{} + err := UnmarshalExtJSON([]byte(extJSON), false, &responseDoc) + assert.Nil(t, err, "UnmarshalExtJSON error: %v", err) + return &responseDoc + } + + testCases := []struct { + name string + testJSON string + }{ + { + "no array", + `{ + "UndefinedField": {"key": 1}, + "DefinedField": "value" + }`, + }, + { + "outer array", + `{ + "UndefinedField": [{"key": 1}], + "DefinedField": "value" + }`, + }, + { + "embedded array", + `{ + "UndefinedField": {"keys": [2]}, + "DefinedField": "value" + }`, + }, + { + "outer array and embedded array", + `{ + "UndefinedField": [{"keys": [2]}], + "DefinedField": "value" + }`, + }, + { + "embedded document", + `{ + "UndefinedField": {"key": {"one": "two"}}, + "DefinedField": "value" + }`, + }, + { + "doubly embedded document", + `{ + "UndefinedField": {"key": {"one": {"two": "three"}}}, + "DefinedField": "value" + }`, + }, + { + "embedded document and embedded array", + `{ + "UndefinedField": {"key": {"one": {"two": [3]}}}, + "DefinedField": "value" + }`, + }, + { + "embedded document and embedded array in outer array", + `{ + "UndefinedField": [{"key": {"one": [3]}}], + "DefinedField": "value" + }`, + }, + { + "code with scope", + `{ + "UndefinedField": {"logic": {"$code": "foo", "$scope": {"bar": 1}}}, + "DefinedField": "value" + }`, + }, + { + "embedded array of code with scope", + `{ + "UndefinedField": {"logic": [{"$code": "foo", "$scope": {"bar": 1}}]}, + "DefinedField": "value" + }`, + }, + { + "type definition embedded document", + `{ + "UndefinedField": {"myDouble": {"$numberDouble": "1.24"}}, + "DefinedField": "value" + }`, + }, + { + "empty embedded document", + `{ + "UndefinedField": {"empty": {}}, + "DefinedField": "value" + }`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + responseDoc := unmarshalExpectedResponse(t, tc.testJSON) + assert.Equal(t, "value", responseDoc.DefinedField, "expected DefinedField to be 'value', got %q", responseDoc.DefinedField) + }) + } +}