Skip to content

Commit

Permalink
apply feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
fredcarle committed Sep 2, 2022
1 parent fa1128b commit 43f8e0b
Showing 1 changed file with 38 additions and 29 deletions.
67 changes: 38 additions & 29 deletions db/collection_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,13 @@ func (c *collection) applyMerge(
return errors.New("invalid field in Patch")
}

var err error
mergeCBOR[mfield], err = validateFieldSchema(mval, fd)
cborVal, err := validateFieldSchema(mval, fd)
if err != nil {
return err
}
mergeCBOR[mfield] = cborVal

val := client.NewCBORValue(fd.Typ, mergeCBOR[mfield])
val := client.NewCBORValue(fd.Typ, cborVal)
fieldKey, fieldExists := c.tryGetFieldKey(key, mfield)
if !fieldExists {
return client.ErrFieldNotExist
Expand Down Expand Up @@ -434,7 +434,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in
return getArray(val, getString, "")

case client.FieldKind_NILLABLE_STRING_ARRAY:
return getArray(val, getString, nil)
return getNillableArray(val, getString)

case client.FieldKind_BOOL:
return getBool(val)
Expand All @@ -443,7 +443,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in
return getArray(val, getBool, false)

case client.FieldKind_NILLABLE_BOOL_ARRAY:
return getArray(val, getBool, nil)
return getNillableArray(val, getBool)

case client.FieldKind_FLOAT, client.FieldKind_DECIMAL:
return getFloat64(val)
Expand All @@ -452,7 +452,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in
return getArray(val, getFloat64, 0)

case client.FieldKind_NILLABLE_FLOAT_ARRAY:
return getArray(val, getFloat64, nil)
return getNillableArray(val, getFloat64)

case client.FieldKind_DATE:
return getDate(val)
Expand All @@ -464,7 +464,7 @@ func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (in
return getArray(val, getInt64, 0)

case client.FieldKind_NILLABLE_INT_ARRAY:
return getArray(val, getInt64, nil)
return getNillableArray(val, getInt64)

case client.FieldKind_OBJECT, client.FieldKind_OBJECT_ARRAY,
client.FieldKind_FOREIGN_OBJECT, client.FieldKind_FOREIGN_OBJECT_ARRAY:
Expand Down Expand Up @@ -502,8 +502,8 @@ func getDate(v *fastjson.Value) (time.Time, error) {
func getArray[T any](
val *fastjson.Value,
typeGetter func(*fastjson.Value) (T, error),
zeroValue any,
) (any, error) {
zeroValue T,
) ([]T, error) {
if val.Type() == fastjson.TypeNull {
return nil, nil
}
Expand All @@ -513,37 +513,46 @@ func getArray[T any](
return nil, err
}

if zeroValue == nil {
arr := make([]*T, len(valArray))
for i, arrItem := range valArray {
if arrItem.Type() == fastjson.TypeNull {
arr[i] = nil
continue
}
v, err := typeGetter(arrItem)
if err != nil {
return nil, err
}
arr[i] = &v
arr := make([]T, len(valArray))
for i, arrItem := range valArray {
if arrItem.Type() == fastjson.TypeNull {
arr[i] = zeroValue
continue
}
arr[i], err = typeGetter(arrItem)
if err != nil {
return nil, err
}
return arr, nil
}

arr := make([]T, len(valArray))
return arr, nil
}

func getNillableArray[T any](
val *fastjson.Value,
typeGetter func(*fastjson.Value) (T, error),
) ([]*T, error) {
if val.Type() == fastjson.TypeNull {
return nil, nil
}

valArray, err := val.Array()
if err != nil {
return nil, err
}

arr := make([]*T, len(valArray))
for i, arrItem := range valArray {
if arrItem.Type() == fastjson.TypeNull {
var ok bool
arr[i], ok = zeroValue.(T)
if !ok {
return nil, errors.New("zeroValue should be of the same type as the array items type")
}
continue
}
arr[i], err = typeGetter(arrItem)
v, err := typeGetter(arrItem)
if err != nil {
return nil, err
}
arr[i] = &v
}

return arr, nil
}

Expand Down

0 comments on commit 43f8e0b

Please sign in to comment.