diff --git a/db/collection_update.go b/db/collection_update.go index 05428a5450..047bcd5d0c 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -359,13 +359,13 @@ func (c *collection) applyMerge( return errors.New("invalid field in Patch") } - var err error - mergeCBOR[mfield], err = validateFieldSchema(mval, fd) + cborVal, err := validateFieldSchema(mval, fd) if err != nil { return err } + mergeCBOR[mfield] = cborVal - val := client.NewCBORValue(fd.Typ, mergeCBOR[mfield]) + val := client.NewCBORValue(fd.Typ, cborVal) fieldKey, fieldExists := c.tryGetFieldKey(key, mfield) if !fieldExists { return client.ErrFieldNotExist @@ -434,7 +434,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in return getArray(val, getString, "") case client.FieldKind_NILLABLE_STRING_ARRAY: - return getArray(val, getString, nil) + return getNillableArray(val, getString) case client.FieldKind_BOOL: return getBool(val) @@ -443,7 +443,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in return getArray(val, getBool, false) case client.FieldKind_NILLABLE_BOOL_ARRAY: - return getArray(val, getBool, nil) + return getNillableArray(val, getBool) case client.FieldKind_FLOAT, client.FieldKind_DECIMAL: return getFloat64(val) @@ -452,7 +452,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in return getArray(val, getFloat64, 0) case client.FieldKind_NILLABLE_FLOAT_ARRAY: - return getArray(val, getFloat64, nil) + return getNillableArray(val, getFloat64) case client.FieldKind_DATE: return getDate(val) @@ -464,7 +464,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in return getArray(val, getInt64, 0) case client.FieldKind_NILLABLE_INT_ARRAY: - return getArray(val, getInt64, nil) + return getNillableArray(val, getInt64) case client.FieldKind_OBJECT, client.FieldKind_OBJECT_ARRAY, client.FieldKind_FOREIGN_OBJECT, client.FieldKind_FOREIGN_OBJECT_ARRAY: @@ -502,8 +502,8 @@ func getDate(v *fastjson.Value) (time.Time, error) { func getArray[T any]( val *fastjson.Value, typeGetter func(*fastjson.Value) (T, error), - zeroValue any, -) (any, error) { + zeroValue T, +) ([]T, error) { if val.Type() == fastjson.TypeNull { return nil, nil } @@ -513,37 +513,46 @@ func getArray[T any]( return nil, err } - if zeroValue == nil { - arr := make([]*T, len(valArray)) - for i, arrItem := range valArray { - if arrItem.Type() == fastjson.TypeNull { - arr[i] = nil - continue - } - v, err := typeGetter(arrItem) - if err != nil { - return nil, err - } - arr[i] = &v + arr := make([]T, len(valArray)) + for i, arrItem := range valArray { + if arrItem.Type() == fastjson.TypeNull { + arr[i] = zeroValue + continue + } + arr[i], err = typeGetter(arrItem) + if err != nil { + return nil, err } - return arr, nil } - arr := make([]T, len(valArray)) + return arr, nil +} + +func getNillableArray[T any]( + val *fastjson.Value, + typeGetter func(*fastjson.Value) (T, error), +) ([]*T, error) { + if val.Type() == fastjson.TypeNull { + return nil, nil + } + + valArray, err := val.Array() + if err != nil { + return nil, err + } + + arr := make([]*T, len(valArray)) for i, arrItem := range valArray { if arrItem.Type() == fastjson.TypeNull { - var ok bool - arr[i], ok = zeroValue.(T) - if !ok { - return nil, errors.New("zeroValue should be of the same type as the array items type") - } continue } - arr[i], err = typeGetter(arrItem) + v, err := typeGetter(arrItem) if err != nil { return nil, err } + arr[i] = &v } + return arr, nil }