From d5c2270f40e452fea045e06d871f6a15158e5f26 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte <961963+jasdel@users.noreply.github.com> Date: Tue, 22 Feb 2022 11:37:28 -0800 Subject: [PATCH] Fix AttributeValue marshaling and names in expressions (#1590) Updates the `attributevalue` and `expression` package's handling of AttributeValue marshaling fixing several bugs in the packages. * Fixes #1569 `Inconsistent struct field name marshaled`. Fields will now be consistent with the EncoderOptions or DecoderOptions the Go struct was used with. Previously the Go struct fields would be cached with the first options used for the type. Causes subsequent usages to have the wrong field names if the encoding options used different TagKeys. * Fixes #645, #411 `Support more than string types for map keys`. Updates (un)marshaler to support number, bool, and types that implement encoding.Text(Un)Marshaler interfaces. * Fixes Support for expression Names with literal dots in name. Adds new function NameNoDotSplit to expression package. This function allows you to provide a literal expression Name containing dots. Also adds a new method to NameBuilder, AppendName, for joining multiple name path components together. Helpful for joining names with literal dots with subsequent object path fields. * Fixes bug with AttributeValue marshaler struct struct tag usage that caused TagKey to be ignored if the member had a struct tag with `dynamodbav` struct tag. Now both tags will be read as documented, with the TagKey struct tag options taking precedence. --- .../2096a6beb82d44bea4a469c197f6de40.json | 8 + .../98a8c469e1d64c9aa06b0e467ac61dd9.json | 9 + .../db81731fa3ab450e9ea3535a0d4aaedd.json | 9 + feature/dynamodb/attributevalue/decode.go | 149 +++++++++++-- .../dynamodb/attributevalue/decode_test.go | 199 ++++++++++++++++- feature/dynamodb/attributevalue/encode.go | 42 +++- .../dynamodb/attributevalue/encode_test.go | 132 +++++++++++- feature/dynamodb/attributevalue/field.go | 13 +- .../dynamodb/attributevalue/field_cache.go | 16 +- feature/dynamodb/attributevalue/field_test.go | 201 ++++++++++++++---- .../dynamodb/attributevalue/marshaler_test.go | 8 +- .../dynamodb/attributevalue/shared_test.go | 23 +- feature/dynamodb/attributevalue/tag.go | 2 +- .../dynamodb/expression/expression_test.go | 38 ++-- feature/dynamodb/expression/go.mod | 1 + feature/dynamodb/expression/operand.go | 133 +++++++++++- feature/dynamodb/expression/operand_test.go | 108 ++++++++-- feature/dynamodb/expression/update_test.go | 8 +- .../dynamodbstreams/attributevalue/decode.go | 149 +++++++++++-- .../attributevalue/decode_test.go | 199 ++++++++++++++++- .../dynamodbstreams/attributevalue/encode.go | 42 +++- .../attributevalue/encode_test.go | 132 +++++++++++- .../dynamodbstreams/attributevalue/field.go | 13 +- .../attributevalue/field_cache.go | 16 +- .../attributevalue/field_test.go | 201 ++++++++++++++---- .../attributevalue/marshaler_test.go | 8 +- .../attributevalue/shared_test.go | 23 +- feature/dynamodbstreams/attributevalue/tag.go | 2 +- 28 files changed, 1671 insertions(+), 213 deletions(-) create mode 100644 .changelog/2096a6beb82d44bea4a469c197f6de40.json create mode 100644 .changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json create mode 100644 .changelog/db81731fa3ab450e9ea3535a0d4aaedd.json diff --git a/.changelog/2096a6beb82d44bea4a469c197f6de40.json b/.changelog/2096a6beb82d44bea4a469c197f6de40.json new file mode 100644 index 00000000000..e580502da74 --- /dev/null +++ b/.changelog/2096a6beb82d44bea4a469c197f6de40.json @@ -0,0 +1,8 @@ +{ + "id": "2096a6be-b82d-44be-a4a4-69c197f6de40", + "type": "feature", + "description": "Add support for expression names with dots via new NameBuilder function NameNoDotSplit, related to [aws/aws-sdk-go#2570](https://github.com/aws/aws-sdk-go/issues/2570)", + "modules": [ + "feature/dynamodb/expression" + ] +} \ No newline at end of file diff --git a/.changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json b/.changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json new file mode 100644 index 00000000000..4eb187140df --- /dev/null +++ b/.changelog/98a8c469e1d64c9aa06b0e467ac61dd9.json @@ -0,0 +1,9 @@ +{ + "id": "98a8c469-e1d6-4c9a-a06b-0e467ac61dd9", + "type": "bugfix", + "description": "Fixes [#1569](https://github.com/aws/aws-sdk-go-v2/issues/1569) inconsistent serialization of Go struct field names", + "modules": [ + "feature/dynamodb/attributevalue", + "feature/dynamodbstreams/attributevalue" + ] +} diff --git a/.changelog/db81731fa3ab450e9ea3535a0d4aaedd.json b/.changelog/db81731fa3ab450e9ea3535a0d4aaedd.json new file mode 100644 index 00000000000..36c874b163c --- /dev/null +++ b/.changelog/db81731fa3ab450e9ea3535a0d4aaedd.json @@ -0,0 +1,9 @@ +{ + "id": "db81731f-a3ab-450e-9ea3-535a0d4aaedd", + "type": "feature", + "description": "Fixes [#645](https://github.com/aws/aws-sdk-go-v2/issues/645), [#411](https://github.com/aws/aws-sdk-go-v2/issues/411) by adding support for (un)marshaling AttributeValue maps to Go maps key types of string, number, bool, and types implementing encoding.Text(un)Marshaler interface", + "modules": [ + "feature/dynamodb/attributevalue", + "feature/dynamodbstreams/attributevalue" + ] +} \ No newline at end of file diff --git a/feature/dynamodb/attributevalue/decode.go b/feature/dynamodb/attributevalue/decode.go index 5a02853dc95..fc3f322dd01 100644 --- a/feature/dynamodb/attributevalue/decode.go +++ b/feature/dynamodb/attributevalue/decode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -197,7 +198,7 @@ func UnmarshalListOfMapsWithOptions(l []map[string]types.AttributeValue, out int } // DecoderOptions is a collection of options to configure how the decoder -// unmarshalls the value. +// unmarshals the value. type DecoderOptions struct { // Support other custom struct tag keys, such as `yaml`, `json`, or `toml`. // Note that values provided with a custom TagKey must also be supported @@ -221,7 +222,7 @@ type Decoder struct { // NewDecoder creates a new Decoder with default configuration. Use // the `opts` functional options to override the default configuration. func NewDecoder(optFns ...func(*DecoderOptions)) *Decoder { - var options DecoderOptions + options := DecoderOptions{TagKey: defaultTagKey} for _, fn := range optFns { fn(&options) } @@ -254,14 +255,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) var u Unmarshaler _, isNull := av.(*types.AttributeValueMemberNULL) if av == nil || isNull { - u, v = indirect(v, true) + u, v = indirect(v, indirectOptions{decodeNull: true}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } return d.decodeNull(v) } - u, v = indirect(v, false) + u, v = indirect(v, indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(av) } @@ -386,7 +387,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberBS{Value: bs}) } @@ -513,7 +514,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberNS{Value: ns}) } @@ -564,32 +565,48 @@ func (d *Decoder) decodeList(avList []types.AttributeValue, v reflect.Value) err return nil } -func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) error { +func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) (err error) { + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + switch v.Kind() { case reflect.Map: - t := v.Type() - if t.Key().Kind() != reflect.String { - return &UnmarshalTypeError{Value: "map string key", Type: t.Key()} + decodeMapKey, err = d.getMapKeyDecoder(v.Type().Key()) + if err != nil { + return err } + if v.IsNil() { - v.Set(reflect.MakeMap(t)) + v.Set(reflect.MakeMap(v.Type())) } case reflect.Struct: case reflect.Interface: v.Set(reflect.MakeMap(stringInterfaceMapType)) + decodeMapKey = d.decodeString v = v.Elem() default: return &UnmarshalTypeError{Value: "map", Type: v.Type()} } if v.Kind() == reflect.Map { + keyType := v.Type().Key() + valueType := v.Type().Elem() for k, av := range avMap { - key := reflect.New(v.Type().Key()).Elem() - key.SetString(k) - elem := reflect.New(v.Type().Elem()).Elem() + key := reflect.New(keyType).Elem() + // handle pointer keys + _, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true}) + if err := decodeMapKey(k, indirectKey, tag{}); err != nil { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("map key %q", k), + Type: keyType, + Err: err, + } + } + + elem := reflect.New(valueType).Elem() if err := d.decode(av, elem, tag{}); err != nil { return err } + v.SetMapIndex(key, elem) } } else if v.Kind() == reflect.Struct { @@ -609,6 +626,50 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val return nil } +var numberType = reflect.TypeOf(Number("")) +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func (d *Decoder) getMapKeyDecoder(keyType reflect.Type) (func(string, reflect.Value, tag) error, error) { + // Test the key type to determine if it implements the TextUnmarshaler interface. + if reflect.PtrTo(keyType).Implements(textUnmarshalerType) || keyType.Implements(textUnmarshalerType) { + return func(v string, k reflect.Value, _ tag) error { + if !k.CanAddr() { + return fmt.Errorf("cannot take address of map key, %v", k.Type()) + } + return k.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(v)) + }, nil + } + + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + + switch keyType.Kind() { + case reflect.Bool: + decodeMapKey = func(v string, key reflect.Value, fieldTag tag) error { + b, err := strconv.ParseBool(v) + if err != nil { + return err + } + return d.decodeBool(b, key) + } + case reflect.String: + // Number type handled as a string + decodeMapKey = d.decodeString + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + decodeMapKey = d.decodeNumber + + default: + return nil, &UnmarshalTypeError{ + Value: "map key must be string, number, bool, or TextUnmarshaler", + Type: keyType, + } + } + + return decodeMapKey, nil +} + func (d *Decoder) decodeNull(v reflect.Value) error { if v.IsValid() && v.CanSet() { v.Set(reflect.Zero(v.Type())) @@ -675,7 +736,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberSS{Value: ss}) } @@ -713,38 +774,82 @@ func decoderFieldByIndex(v reflect.Value, index []int) reflect.Value { return v } +type indirectOptions struct { + decodeNull bool + skipUnmarshaler bool +} + // indirect will walk a value's interface or pointer value types. Returning // the final value or the value a unmarshaler is defined on. // // Based on the enoding/json type reflect value type indirection in Go Stdlib // https://golang.org/src/encoding/json/decode.go indirect func. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { +func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true v = v.Addr() } + for { + // Load value from interface, but only if the result will be + // usefully addressable. if v.Kind() == reflect.Interface && !v.IsNil() { e := v.Elem() - if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + if e.Kind() == reflect.Ptr && !e.IsNil() && (!opts.decodeNull || e.Elem().Kind() == reflect.Ptr) { + haveAddr = false v = e continue } + if e.Kind() != reflect.Ptr && e.IsValid() { + return nil, e + } } if v.Kind() != reflect.Ptr { break } - if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + if opts.decodeNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() break } if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } - if v.Type().NumMethod() > 0 { + if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { return u, reflect.Value{} } } - v = v.Elem() + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } } return nil, v @@ -782,8 +887,12 @@ func (n Number) String() string { type UnmarshalTypeError struct { Value string Type reflect.Type + Err error } +// Unwrap returns the underlying error if any. +func (e *UnmarshalTypeError) Unwrap() error { return e.Err } + // Error returns the string representation of the error. // satisfying the error interface func (e *UnmarshalTypeError) Error() string { diff --git a/feature/dynamodb/attributevalue/decode_test.go b/feature/dynamodb/attributevalue/decode_test.go index feafd697c83..10e279de89e 100644 --- a/feature/dynamodb/attributevalue/decode_test.go +++ b/feature/dynamodb/attributevalue/decode_test.go @@ -335,7 +335,10 @@ func TestUnmarshalMapError(t *testing.T) { }, actual: &map[int]interface{}{}, expected: nil, - err: &UnmarshalTypeError{Value: "map string key", Type: reflect.TypeOf(int(0))}, + err: &UnmarshalTypeError{ + Value: `map key "BOOL"`, + Type: reflect.TypeOf(int(0)), + }, }, } @@ -765,3 +768,197 @@ func TestDecodeAliasType(t *testing.T) { t.Errorf("expect:\n%v\nactual:\n%v", expect, actual) } } + +type testUnmarshalMapKeyComplex struct { + Foo string +} + +func (t *testUnmarshalMapKeyComplex) UnmarshalText(b []byte) error { + t.Foo = string(b) + return nil +} +func (t *testUnmarshalMapKeyComplex) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + avM, ok := av.(*types.AttributeValueMemberM) + if !ok { + return fmt.Errorf("unexpected AttributeValue type %T, %v", av, av) + } + avFoo, ok := avM.Value["foo"] + if !ok { + return nil + } + + avS, ok := avFoo.(*types.AttributeValueMemberS) + if !ok { + return fmt.Errorf("unexpected Foo AttributeValue type, %T, %v", avM, avM) + } + + t.Foo = avS.Value + + return nil +} + +func TestUnmarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input map[string]types.AttributeValue + expectVal interface{} + expectType func() interface{} + }{ + "string key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[string]interface{}{} }, + expectVal: map[string]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "string alias key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[StrAlias]interface{}{} }, + expectVal: map[StrAlias]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "Number key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[Number]interface{}{} }, + expectVal: map[Number]interface{}{ + Number("1"): 123., + Number("2"): "efg", + }, + }, + "int key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[int]interface{}{} }, + expectVal: map[int]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "int alias key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[IntAlias]interface{}{} }, + expectVal: map[IntAlias]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "bool key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[bool]interface{}{} }, + expectVal: map[bool]interface{}{ + true: 123., + false: "efg", + }, + }, + "bool alias key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[BoolAlias]interface{}{} }, + expectVal: map[BoolAlias]interface{}{ + true: 123., + false: "efg", + }, + }, + "textMarshaler key": { + input: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testTextMarshaler]interface{}{} }, + expectVal: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + "textMarshaler DDBAvMarshaler key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testUnmarshalMapKeyComplex]interface{}{} }, + expectVal: map[testUnmarshalMapKeyComplex]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualVal := c.expectType() + err := UnmarshalMap(c.input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if diff := cmp.Diff(c.expectVal, actualVal); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + }) + } +} + +func TestUnmarshalMap_keyPtrTypes(t *testing.T) { + input := map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + } + + expectVal := map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + } + + actualVal := map[*testTextMarshaler]interface{}{} + err := UnmarshalMap(input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if e, a := len(expectVal), len(actualVal); e != a { + t.Errorf("expect %v values, got %v", e, a) + } + + for k, v := range expectVal { + var found bool + for ak, av := range actualVal { + if *k == *ak { + found = true + if diff := cmp.Diff(v, av); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + } + } + if !found { + t.Errorf("expect %v key not found", *k) + } + } + +} diff --git a/feature/dynamodb/attributevalue/encode.go b/feature/dynamodb/attributevalue/encode.go index c8dcf94736a..4a96c08d31c 100644 --- a/feature/dynamodb/attributevalue/encode.go +++ b/feature/dynamodb/attributevalue/encode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -380,6 +381,7 @@ type Encoder struct { // the `opts` functional options to override the default configuration. func NewEncoder(optFns ...func(*EncoderOptions)) *Encoder { options := EncoderOptions{ + TagKey: defaultTagKey, NullEmptySets: true, } for _, fn := range optFns { @@ -497,9 +499,9 @@ func (e *Encoder) encodeStruct(v reflect.Value, fieldTag tag) (types.AttributeVa func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { m := &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}} for _, key := range v.MapKeys() { - keyName := fmt.Sprint(key.Interface()) - if keyName == "" { - return nil, &InvalidMarshalError{msg: "map key cannot be empty"} + keyName, err := mapKeyAsString(key, fieldTag) + if err != nil { + return nil, err } elemVal := v.MapIndex(key) @@ -519,6 +521,40 @@ func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue return m, nil } +func mapKeyAsString(keyVal reflect.Value, fieldTag tag) (keyStr string, err error) { + defer func() { + if err != nil { + return + } + if keyStr == "" { + err = &InvalidMarshalError{msg: "map key cannot be empty"} + } + }() + + if k, ok := keyVal.Interface().(encoding.TextMarshaler); ok { + b, err := k.MarshalText() + if err != nil { + return "", fmt.Errorf("failed to marshal text, %w", err) + } + return string(b), err + } + + switch keyVal.Kind() { + case reflect.Bool, + reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + + return fmt.Sprint(keyVal.Interface()), nil + + default: + return "", &InvalidMarshalError{ + msg: "map key type not supported, must be string, number, bool, or TextMarshaler", + } + } +} + func (e *Encoder) encodeSlice(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { if v.Type().Elem().Kind() == reflect.Uint8 { slice := reflect.MakeSlice(byteSliceType, v.Len(), v.Len()) diff --git a/feature/dynamodb/attributevalue/encode_test.go b/feature/dynamodb/attributevalue/encode_test.go index 58793aee34f..23c2fbfc01a 100644 --- a/feature/dynamodb/attributevalue/encode_test.go +++ b/feature/dynamodb/attributevalue/encode_test.go @@ -1,13 +1,14 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" "reflect" "strconv" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/go-cmp/cmp" @@ -366,3 +367,130 @@ func TestEncoderFieldByIndex(t *testing.T) { t.Error("expected f to be of kind Int with value equal to outer.Inner") } } + +func TestMarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input interface{} + expectAV map[string]types.AttributeValue + }{ + "string key": { + input: map[string]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "string alias key": { + input: map[StrAlias]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "Number key": { + input: map[Number]interface{}{ + Number("1"): 123, + Number("2"): "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int key": { + input: map[int]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int alias key": { + input: map[IntAlias]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool key": { + input: map[bool]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool alias key": { + input: map[BoolAlias]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler key": { + input: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler ptr key": { + input: map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + av, err := MarshalMap(c.input) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + cmpOptions := cmp.Options{ + cmpopts.IgnoreUnexported(types.AttributeValueMemberM{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberN{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBOOL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberB{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberSS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNULL{}), + } + if diff := cmp.Diff(c.expectAV, av, cmpOptions...); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} diff --git a/feature/dynamodb/attributevalue/field.go b/feature/dynamodb/attributevalue/field.go index 7abd3479a96..4f63bc7df99 100644 --- a/feature/dynamodb/attributevalue/field.go +++ b/feature/dynamodb/attributevalue/field.go @@ -5,6 +5,8 @@ import ( "sort" ) +const defaultTagKey = "dynamodbav" + type field struct { tag @@ -46,7 +48,12 @@ type structFieldOptions struct { // unionStructFields returns a list of fields for the given type. Type info is cached // to avoid repeated calls into the reflect package func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { - if cached, ok := fieldCache.Load(t); ok { + key := fieldCacheKey{ + typ: t, + opts: opts, + } + + if cached, ok := fieldCache.Load(key); ok { return cached } @@ -62,7 +69,7 @@ func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { fs.fieldsByName[f.Name] = i } - cached, _ := fieldCache.LoadOrStore(t, fs) + cached, _ := fieldCache.LoadOrStore(key, fs) return cached } @@ -105,7 +112,7 @@ func enumFields(t reflect.Type, opts structFieldOptions) []field { fieldTag := tag{} fieldTag.parseAVTag(sf.Tag) // Because MarshalOptions.TagKey must be explicitly set. - if opts.TagKey != "" && fieldTag == (tag{}) { + if opts.TagKey != "" && opts.TagKey != defaultTagKey { fieldTag.parseStructTag(opts.TagKey, sf.Tag) } diff --git a/feature/dynamodb/attributevalue/field_cache.go b/feature/dynamodb/attributevalue/field_cache.go index 60a9d9c7499..c0fc4679a86 100644 --- a/feature/dynamodb/attributevalue/field_cache.go +++ b/feature/dynamodb/attributevalue/field_cache.go @@ -1,25 +1,31 @@ package attributevalue import ( + "reflect" "strings" "sync" ) -var fieldCache fieldCacher +var fieldCache = &fieldCacher{} + +type fieldCacheKey struct { + typ reflect.Type + opts structFieldOptions +} type fieldCacher struct { cache sync.Map } -func (c *fieldCacher) Load(t interface{}) (*cachedFields, bool) { - if v, ok := c.cache.Load(t); ok { +func (c *fieldCacher) Load(key fieldCacheKey) (*cachedFields, bool) { + if v, ok := c.cache.Load(key); ok { return v.(*cachedFields), true } return nil, false } -func (c *fieldCacher) LoadOrStore(t interface{}, fs *cachedFields) (*cachedFields, bool) { - v, ok := c.cache.LoadOrStore(t, fs) +func (c *fieldCacher) LoadOrStore(key fieldCacheKey, fs *cachedFields) (*cachedFields, bool) { + v, ok := c.cache.LoadOrStore(key, fs) return v.(*cachedFields), ok } diff --git a/feature/dynamodb/attributevalue/field_test.go b/feature/dynamodb/attributevalue/field_test.go index 82a09d6cf99..9c2b291703d 100644 --- a/feature/dynamodb/attributevalue/field_test.go +++ b/feature/dynamodb/attributevalue/field_test.go @@ -1,6 +1,7 @@ package attributevalue import ( + "fmt" "reflect" "testing" ) @@ -22,7 +23,7 @@ type unionComplex struct { } type unionTagged struct { - A int `json:"A"` + A int `dynamodbav:"ddbav" json:"A" taga:"TagA" tagb:"TagB"` } type unionTaggedComplex struct { @@ -32,97 +33,211 @@ type unionTaggedComplex struct { } func TestUnionStructFields(t *testing.T) { - var cases = []struct { + origFieldCache := fieldCache + defer func() { fieldCache = origFieldCache }() + + fieldCache = &fieldCacher{} + + var cases = map[string]struct { in interface{} + opts structFieldOptions expect []testUnionValues }{ - { - in: unionSimple{1, "2", []string{"abc"}}, + "simple input": { + in: unionSimple{1, "2", []string{"abc"}}, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"A", 1}, {"B", "2"}, {"C", []string{"abc"}}, }, }, - { + "nested struct": { in: unionComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, A: 2, }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"B", "2"}, {"C", []string{"abc"}}, {"A", 2}, }, }, - { + "with TagKey unset": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"ddbav", 3}, + {"B", "3"}, + }, + }, + "with TagKey json": { in: unionTaggedComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, unionTagged: unionTagged{3}, B: "3", }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"C", []string{"abc"}}, {"A", 3}, {"B", "3"}, }, }, + "with TagKey taga": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "taga"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagA", 3}, + {"B", "3"}, + }, + }, + "with TagKey tagb": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "tagb"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagB", 3}, + {"B", "3"}, + }, + }, } - for i, c := range cases { - v := reflect.ValueOf(c.in) + for name, c := range cases { + t.Run(name, func(t *testing.T) { + v := reflect.ValueOf(c.in) - fields := unionStructFields(v.Type(), structFieldOptions{TagKey: "json"}) - for j, f := range fields.All() { - expected := c.expect[j] - if e, a := expected.Name, f.Name; e != a { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) - } - actual := v.FieldByIndex(f.Index).Interface() - if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) + fields := unionStructFields(v.Type(), c.opts) + for i, f := range fields.All() { + expected := c.expect[i] + if e, a := expected.Name, f.Name; e != a { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } + actual := v.FieldByIndex(f.Index).Interface() + if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } } - } + }) } } func TestCachedFields(t *testing.T) { type myStruct struct { - Dog int + Dog int `tag1:"rabbit" tag2:"cow" tag3:"horse"` CAT string bird bool } - fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{}) - - const expectedNumFields = 2 - if numFields := len(fields.All()); numFields != expectedNumFields { - t.Errorf("expected number of fields to be %d but got %d", expectedNumFields, numFields) - } - - cases := []struct { + cases := map[string][]struct { Name string FieldName string Found bool }{ - {"Dog", "Dog", true}, - {"dog", "Dog", true}, - {"DOG", "Dog", true}, - {"Yorkie", "", false}, - {"Cat", "CAT", true}, - {"cat", "CAT", true}, - {"CAT", "CAT", true}, - {"tiger", "", false}, - {"bird", "", false}, + "": { + {"Dog", "Dog", true}, + {"dog", "Dog", true}, + {"DOG", "Dog", true}, + {"Yorkie", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag1": { + {"rabbit", "rabbit", true}, + {"Rabbit", "rabbit", true}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag2": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "cow", true}, + {"Cow", "cow", true}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag3": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "horse", true}, + {"Horse", "horse", true}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, } - for _, c := range cases { - f, found := fields.FieldByName(c.Name) - if found != c.Found { - t.Errorf("expected found to be %v but got %v", c.Found, found) - } - if found && f.Name != c.FieldName { - t.Errorf("expected field name to be %s but got %s", c.FieldName, f.Name) + for tagKey, cs := range cases { + for _, c := range cs { + name := tagKey + if name == "" { + name = "none" + } + t.Run(fmt.Sprintf("%s/%s", name, c.Name), func(t *testing.T) { + t.Parallel() + + fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{ + TagKey: tagKey, + }) + + const expectedNumFields = 2 + if numFields := len(fields.All()); numFields != expectedNumFields { + t.Errorf("expect %v fields, got %d", expectedNumFields, numFields) + } + + f, found := fields.FieldByName(c.Name) + if found != c.Found { + t.Errorf("expect %v found, got %v", c.Found, found) + } + if found && f.Name != c.FieldName { + t.Errorf("expect %v field name, got %s", c.FieldName, f.Name) + } + }) } } } diff --git a/feature/dynamodb/attributevalue/marshaler_test.go b/feature/dynamodb/attributevalue/marshaler_test.go index 6d84c8adac5..6d895e89b4a 100644 --- a/feature/dynamodb/attributevalue/marshaler_test.go +++ b/feature/dynamodb/attributevalue/marshaler_test.go @@ -520,7 +520,7 @@ func compareObjects(t *testing.T, expected interface{}, actual interface{}) { } func BenchmarkMarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -547,7 +547,7 @@ func BenchmarkMarshalOneMember(b *testing.B) { } func BenchmarkMarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -576,7 +576,7 @@ func BenchmarkMarshalTwoMembers(b *testing.B) { } func BenchmarkUnmarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", @@ -605,7 +605,7 @@ func BenchmarkUnmarshalOneMember(b *testing.B) { } func BenchmarkUnmarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", diff --git a/feature/dynamodb/attributevalue/shared_test.go b/feature/dynamodb/attributevalue/shared_test.go index 63a09bbf640..b2c249ae4ef 100644 --- a/feature/dynamodb/attributevalue/shared_test.go +++ b/feature/dynamodb/attributevalue/shared_test.go @@ -1,18 +1,37 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" + "fmt" "reflect" "strings" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/go-cmp/cmp" ) +type testTextMarshaler struct { + Foo string +} + +func (t *testTextMarshaler) UnmarshalText(b []byte) error { + if !strings.HasPrefix(string(b), "Foo:") { + return fmt.Errorf(`missing "Foo:" prefix`) + } + + t.Foo = string(b)[len("Foo:"):] + return nil +} + +func (t testTextMarshaler) MarshalText() ([]byte, error) { + return []byte("Foo:" + t.Foo), nil +} + type testBinarySetStruct struct { Binarys [][]byte `dynamodbav:",binaryset"` } diff --git a/feature/dynamodb/attributevalue/tag.go b/feature/dynamodb/attributevalue/tag.go index 6eb901706fb..f01c432e6ea 100644 --- a/feature/dynamodb/attributevalue/tag.go +++ b/feature/dynamodb/attributevalue/tag.go @@ -18,7 +18,7 @@ type tag struct { } func (t *tag) parseAVTag(structTag reflect.StructTag) { - tagStr := structTag.Get("dynamodbav") + tagStr := structTag.Get(defaultTagKey) if len(tagStr) == 0 { return } diff --git a/feature/dynamodb/expression/expression_test.go b/feature/dynamodb/expression/expression_test.go index e748277d281..c3c3565da71 100644 --- a/feature/dynamodb/expression/expression_test.go +++ b/feature/dynamodb/expression/expression_test.go @@ -384,7 +384,7 @@ func TestUpdate(t *testing.T) { setOperation: { { name: NameBuilder{ - name: "foo", + names: []string{"foo"}, }, value: ValueBuilder{ value: 5, @@ -407,7 +407,7 @@ func TestUpdate(t *testing.T) { setOperation: { { name: NameBuilder{ - name: "foo", + names: []string{"foo"}, }, value: ValueBuilder{ value: 5, @@ -416,7 +416,7 @@ func TestUpdate(t *testing.T) { }, { name: NameBuilder{ - name: "bar", + names: []string{"bar"}, }, value: ValueBuilder{ value: 6, @@ -425,7 +425,7 @@ func TestUpdate(t *testing.T) { }, { name: NameBuilder{ - name: "baz", + names: []string{"baz"}, }, value: ValueBuilder{ value: 7, @@ -496,7 +496,7 @@ func TestNames(t *testing.T) { condition: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "foo", + names: []string{"foo"}, }, ValueBuilder{ value: 5, @@ -507,7 +507,7 @@ func TestNames(t *testing.T) { filter: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "bar", + names: []string{"bar"}, }, ValueBuilder{ value: 6, @@ -518,13 +518,13 @@ func TestNames(t *testing.T) { projection: ProjectionBuilder{ names: []NameBuilder{ { - name: "foo", + names: []string{"foo"}, }, { - name: "bar", + names: []string{"bar"}, }, { - name: "baz", + names: []string{"baz"}, }, }, }, @@ -618,7 +618,7 @@ func TestValues(t *testing.T) { condition: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "foo", + names: []string{"foo"}, }, ValueBuilder{ value: 5, @@ -629,7 +629,7 @@ func TestValues(t *testing.T) { filter: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "bar", + names: []string{"bar"}, }, ValueBuilder{ value: 6, @@ -640,13 +640,13 @@ func TestValues(t *testing.T) { projection: ProjectionBuilder{ names: []NameBuilder{ { - name: "foo", + names: []string{"foo"}, }, { - name: "bar", + names: []string{"bar"}, }, { - name: "baz", + names: []string{"baz"}, }, }, }, @@ -702,7 +702,7 @@ func TestBuildChildTrees(t *testing.T) { condition: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "foo", + names: []string{"foo"}, }, ValueBuilder{ value: 5, @@ -713,7 +713,7 @@ func TestBuildChildTrees(t *testing.T) { filter: ConditionBuilder{ operandList: []OperandBuilder{ NameBuilder{ - name: "bar", + names: []string{"bar"}, }, ValueBuilder{ value: 6, @@ -724,13 +724,13 @@ func TestBuildChildTrees(t *testing.T) { projection: ProjectionBuilder{ names: []NameBuilder{ { - name: "foo", + names: []string{"foo"}, }, { - name: "bar", + names: []string{"bar"}, }, { - name: "baz", + names: []string{"baz"}, }, }, }, diff --git a/feature/dynamodb/expression/go.mod b/feature/dynamodb/expression/go.mod index 0d6eb5bbf0c..6f1311f5580 100644 --- a/feature/dynamodb/expression/go.mod +++ b/feature/dynamodb/expression/go.mod @@ -6,6 +6,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.13.0 github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.6.0 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.13.0 + github.com/google/go-cmp v0.5.6 ) replace github.com/aws/aws-sdk-go-v2 => ../../../ diff --git a/feature/dynamodb/expression/operand.go b/feature/dynamodb/expression/operand.go index 77caed0615e..7eabc00a663 100644 --- a/feature/dynamodb/expression/operand.go +++ b/feature/dynamodb/expression/operand.go @@ -18,7 +18,16 @@ import ( // // Create a ValueBuilder representing the string "aValue" // valueBuilder := expression.Value("aValue") type ValueBuilder struct { - value interface{} + value interface{} + options ValueBuilderOptions +} + +// ValueBuilderOptions provides the options for how a value is built, and +// encoded in the expression. +type ValueBuilderOptions struct { + // Use functional options to specify how the value will be encoded. If the + // value is already an AttributeValue, the EncoderOptions will be ignored. + EncoderOptions []func(*attributevalue.EncoderOptions) } // NameBuilder represents a name of a top level item attribute or a nested @@ -32,7 +41,7 @@ type ValueBuilder struct { // // Create a NameBuilder representing the item attribute "aName" // nameBuilder := expression.Name("aName") type NameBuilder struct { - name string + names []string } // SizeBuilder represents the output of the function size ("someName"), which @@ -132,8 +141,75 @@ type OperandBuilder interface { // // Use Name() to create a condition expression // condition := expression.Name("foo").Equal(expression.Name("bar")) func Name(name string) NameBuilder { + if len(name) == 0 { + return NameBuilder{} + } + + return NameBuilder{ + names: strings.Split(name, "."), + } +} + +// AppendName to adds additional name fields, returning a new NameBuilder. Can +// be used to append list indexes and map fields to the Expression attribute +// name. +// +// Leading or trailing dots(`.`) for Names that are not created with +// NameNoDotSplit will result in an error when the expression is built. The +// dot(`.`) will be added automatically as needed. +func (nb NameBuilder) AppendName(field NameBuilder) NameBuilder { + names := make([]string, 0, len(nb.names)+len(field.names)) + names = append(names, nb.names...) + names = append(names, field.names...) + + // If the name being append starts with a list index it to the name being + // appended to. This allows list indexes to be appended to names. If there + // is a syntax error in the name, it will be caught when the expression is + // built via BuildOperand method. + if len(nb.names) != 0 && len(field.names) != 0 { + lastLeftName := len(nb.names) - 1 + firstRightName := lastLeftName + 1 + if v := names[firstRightName]; len(v) > 0 && v[0] == '[' { + if end := strings.Index(v, "]"); end != -1 { + names[lastLeftName] += v[0 : end+1] + names[firstRightName] = v[end+1:] + // Remove the name if it is empty after moving the index. + if len(names[firstRightName]) == 0 { + copy(names[firstRightName:], names[firstRightName+1:]) + names[len(names)-1] = "" + names = names[:len(names)-1] + } + } + } + } + + return NameBuilder{ + names: names, + } +} + +// NameNoDotSplit returns a NameBuilder. The argument should represent the +// desired item attribute. The name will not be split on dots. The name may end +// with square brackets for list indexes. Square brackets will not be +// considered a part of the NameLiteral. +// +// Use NameBuilder.WithField method to add subsequent map field names. +// Use NameBuilder.WithListIndex method to add list index to the name. +// +// See: http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.Attributes.html +// +// Example: +// +// // Specify a name containing dots, and should not be split. +// name := expression.NameLiteral("Top.Level") +// +// // Specify a nested attribute +// nested := expression.Name("Record[6].SongList") +// // Use Name() to create a condition expression +// condition := expression.Name("foo").Equal(expression.Name("bar")) +func NameNoDotSplit(name string) NameBuilder { return NameBuilder{ - name: name, + names: []string{name}, } } @@ -157,6 +233,37 @@ func Value(value interface{}) ValueBuilder { } } +// ValueWithOptions creates a ValueBuilder and sets its value to the argument. The value +// will be marshalled using the attributevalue package, unless it is of +// type types.AttributeValue, where it will be used directly. +// +// The ValueBuilderOptions functional options parameter allows you to specify +// how the value will be encoded. Including options like AttributeValue +// encoding struct tag. If value is already an DynamoDB AttributeValue, +// encoding options will have not effect. +// +// Empty slices and maps will be encoded as their respective empty types.AttributeValue +// types. If a NULL value is required, pass a dynamodb.AttributeValue, e.g.: +// emptyList := &types.AttributeValueMemberNULL{Value: true} +// +// Example: +// +// // Use Value() to create a condition expression +// condition := expression.Name("foo").Equal(expression.Value(10)) +// // Use Value() to set the value of a set expression. +// update := Set(expression.Name("greets"), expression.Value(&types.AttributeValueMemberS{Value: "hello"})) +func ValueWithOptions(value interface{}, optFns ...func(*ValueBuilderOptions)) ValueBuilder { + var options ValueBuilderOptions + for _, fn := range optFns { + fn(&options) + } + + return ValueBuilder{ + value: value, + options: options, + } +} + // Size creates a SizeBuilder representing the size of the item attribute // specified by the argument NameBuilder. Size() is only valid for certain types // of item attributes. For documentation, @@ -455,7 +562,10 @@ func IfNotExists(name NameBuilder, setValue OperandBuilder) SetValueBuilder { // // // Use IfNotExists() to set item attribute "someName" to value 5 if // // "someName" does not exist yet. (Prevents overwrite) -// update, err := expression.Set(expression.Name("someName"), expression.Name("someName").IfNotExists(expression.Value(5))) +// update, err := expression.Set( +// expression.Name("someName"), +// expression.Name("someName").IfNotExists(expression.Value(5)), +// ) // // Expression Equivalent: // @@ -473,9 +583,10 @@ func (nb NameBuilder) IfNotExists(rightOperand OperandBuilder) SetValueBuilder { // Builder is called. BuildOperand() should never be called externally. // BuildOperand() aliases all strings to avoid stepping over DynamoDB's reserved // words. +// // More information on reserved words at http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ReservedWords.html func (nb NameBuilder) BuildOperand() (Operand, error) { - if nb.name == "" { + if len(nb.names) == 0 { return Operand{}, newUnsetParameterError("BuildOperand", "NameBuilder") } @@ -483,15 +594,17 @@ func (nb NameBuilder) BuildOperand() (Operand, error) { names: []string{}, } - nameSplit := strings.Split(nb.name, ".") - fmtNames := make([]string, 0, len(nameSplit)) - - for _, word := range nameSplit { + fmtNames := make([]string, 0, len(nb.names)) + for _, word := range nb.names { var substr string if word == "" { return Operand{}, newInvalidParameterError("BuildOperand", "NameBuilder") } + if idx := strings.Index(word, "]"); idx != -1 && idx != len(word)-1 { + return Operand{}, newInvalidParameterError("BuildOperand", "NameBuilder") + } + if word[len(word)-1] == ']' { for j, char := range word { if char == '[' { @@ -554,7 +667,7 @@ func (vb ValueBuilder) BuildOperand() (Operand, error) { case types.AttributeValue: expr = v default: - expr, err = attributevalue.Marshal(vb.value) + expr, err = attributevalue.MarshalWithOptions(vb.value, vb.options.EncoderOptions...) if err != nil { return Operand{}, newInvalidParameterError("BuildOperand", "ValueBuilder") } diff --git a/feature/dynamodb/expression/operand_test.go b/feature/dynamodb/expression/operand_test.go index 5921b0dfec0..108d49916ce 100644 --- a/feature/dynamodb/expression/operand_test.go +++ b/feature/dynamodb/expression/operand_test.go @@ -1,11 +1,13 @@ package expression import ( - "reflect" "strings" "testing" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) // opeErrorMode will help with error cases and checking error types @@ -23,6 +25,11 @@ const ( ) func TestBuildOperand(t *testing.T) { + type mockStructValue struct { + A string `dynamodbav:"ddbA" tagb:"TagB"` + B string + } + cases := []struct { name string input OperandBuilder @@ -45,6 +52,37 @@ func TestBuildOperand(t *testing.T) { fmtExpr: "$n.$n", }, }, + { + name: "struct value", + input: ValueWithOptions(mockStructValue{A: "abc123", B: "efg456"}), + expected: exprNode{ + values: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "ddbA": &types.AttributeValueMemberS{Value: "abc123"}, + "B": &types.AttributeValueMemberS{Value: "efg456"}, + }}, + }, + fmtExpr: "$v", + }, + }, + { + name: "struct value with TagKey", + input: ValueWithOptions(mockStructValue{A: "abc123", B: "efg456"}, + func(o *ValueBuilderOptions) { + o.EncoderOptions = append(o.EncoderOptions, func(o *attributevalue.EncoderOptions) { + o.TagKey = "tagb" + }) + }), + expected: exprNode{ + values: []types.AttributeValue{ + &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "TagB": &types.AttributeValueMemberS{Value: "abc123"}, + "B": &types.AttributeValueMemberS{Value: "efg456"}, + }}, + }, + fmtExpr: "$v", + }, + }, { name: "basic value", input: Value(5), @@ -181,6 +219,14 @@ func TestBuildOperand(t *testing.T) { fmtExpr: "$n.$n", }, }, + { + name: "no split name", + input: NameNoDotSplit("foo.bar"), + expected: exprNode{ + names: []string{"foo.bar"}, + fmtExpr: "$n", + }, + }, { name: "nested name with index", input: Name("foo.bar[0].baz"), @@ -189,6 +235,33 @@ func TestBuildOperand(t *testing.T) { fmtExpr: "$n.$n[0].$n", }, }, + { + name: "no split name with index", + input: NameNoDotSplit("foo.bar[0]"), + expected: exprNode{ + names: []string{"foo.bar"}, + fmtExpr: "$n[0]", + }, + }, + { + name: "no split name append name", + input: NameNoDotSplit("foo.bar").AppendName(Name("foo.bar")), + expected: exprNode{ + names: []string{"foo.bar", "foo", "bar"}, + fmtExpr: "$n.$n.$n", + }, + }, + { + name: "no split name append name with list index", + input: NameNoDotSplit("foo.bar"). + AppendName(Name("foo.bar")). + AppendName(Name("[0]")). + AppendName(Name("abc123")), + expected: exprNode{ + names: []string{"foo.bar", "foo", "bar", "abc123"}, + fmtExpr: "$n.$n.$n[0].$n", + }, + }, { name: "basic size", input: Name("foo").Size(), @@ -238,19 +311,30 @@ func TestBuildOperand(t *testing.T) { if c.err != noOperandError { if err == nil { t.Errorf("expect error %q, got no error", c.err) - } else { - if e, a := string(c.err), err.Error(); !strings.Contains(a, e) { - t.Errorf("expect %q error message to be in %q", e, a) - } - } - } else { - if err != nil { - t.Errorf("expect no error, got unexpected Error %q", err) + } else if e, a := string(c.err), err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %q error message to be in %q", e, a) } + return + } + if err != nil { + t.Fatalf("expect no error, got unexpected Error %q", err) + } - if e, a := c.expected, operand.exprNode; !reflect.DeepEqual(a, e) { - t.Errorf("expect %v, got %v", e, a) - } + cmpOptions := cmp.Options{ + cmp.AllowUnexported(exprNode{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberM{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberN{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBOOL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberB{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberSS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNULL{}), + } + if diff := cmp.Diff(c.expected, operand.exprNode, cmpOptions...); diff != "" { + t.Errorf("expect operand match\n%s", diff) } }) } diff --git a/feature/dynamodb/expression/update_test.go b/feature/dynamodb/expression/update_test.go index 917e63d2276..4c0a4108ce6 100644 --- a/feature/dynamodb/expression/update_test.go +++ b/feature/dynamodb/expression/update_test.go @@ -599,7 +599,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "foo", + names: []string{"foo"}, }, value: ValueBuilder{ value: 5, @@ -608,7 +608,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "bar", + names: []string{"bar"}, }, value: ValueBuilder{ value: 6, @@ -617,7 +617,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "baz", + names: []string{"baz"}, }, value: ValueBuilder{ value: 7, @@ -626,7 +626,7 @@ func TestUpdateBuildChildNodes(t *testing.T) { { mode: setOperation, name: NameBuilder{ - name: "qux", + names: []string{"qux"}, }, value: ValueBuilder{ value: 8, diff --git a/feature/dynamodbstreams/attributevalue/decode.go b/feature/dynamodbstreams/attributevalue/decode.go index 368c51936ec..2d6ec05e2d7 100644 --- a/feature/dynamodbstreams/attributevalue/decode.go +++ b/feature/dynamodbstreams/attributevalue/decode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -197,7 +198,7 @@ func UnmarshalListOfMapsWithOptions(l []map[string]types.AttributeValue, out int } // DecoderOptions is a collection of options to configure how the decoder -// unmarshalls the value. +// unmarshals the value. type DecoderOptions struct { // Support other custom struct tag keys, such as `yaml`, `json`, or `toml`. // Note that values provided with a custom TagKey must also be supported @@ -221,7 +222,7 @@ type Decoder struct { // NewDecoder creates a new Decoder with default configuration. Use // the `opts` functional options to override the default configuration. func NewDecoder(optFns ...func(*DecoderOptions)) *Decoder { - var options DecoderOptions + options := DecoderOptions{TagKey: defaultTagKey} for _, fn := range optFns { fn(&options) } @@ -254,14 +255,14 @@ func (d *Decoder) decode(av types.AttributeValue, v reflect.Value, fieldTag tag) var u Unmarshaler _, isNull := av.(*types.AttributeValueMemberNULL) if av == nil || isNull { - u, v = indirect(v, true) + u, v = indirect(v, indirectOptions{decodeNull: true}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(av) } return d.decodeNull(v) } - u, v = indirect(v, false) + u, v = indirect(v, indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(av) } @@ -386,7 +387,7 @@ func (d *Decoder) decodeBinarySet(bs [][]byte, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberBS{Value: bs}) } @@ -513,7 +514,7 @@ func (d *Decoder) decodeNumberSet(ns []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberNS{Value: ns}) } @@ -564,32 +565,48 @@ func (d *Decoder) decodeList(avList []types.AttributeValue, v reflect.Value) err return nil } -func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) error { +func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) (err error) { + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + switch v.Kind() { case reflect.Map: - t := v.Type() - if t.Key().Kind() != reflect.String { - return &UnmarshalTypeError{Value: "map string key", Type: t.Key()} + decodeMapKey, err = d.getMapKeyDecoder(v.Type().Key()) + if err != nil { + return err } + if v.IsNil() { - v.Set(reflect.MakeMap(t)) + v.Set(reflect.MakeMap(v.Type())) } case reflect.Struct: case reflect.Interface: v.Set(reflect.MakeMap(stringInterfaceMapType)) + decodeMapKey = d.decodeString v = v.Elem() default: return &UnmarshalTypeError{Value: "map", Type: v.Type()} } if v.Kind() == reflect.Map { + keyType := v.Type().Key() + valueType := v.Type().Elem() for k, av := range avMap { - key := reflect.New(v.Type().Key()).Elem() - key.SetString(k) - elem := reflect.New(v.Type().Elem()).Elem() + key := reflect.New(keyType).Elem() + // handle pointer keys + _, indirectKey := indirect(key, indirectOptions{skipUnmarshaler: true}) + if err := decodeMapKey(k, indirectKey, tag{}); err != nil { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("map key %q", k), + Type: keyType, + Err: err, + } + } + + elem := reflect.New(valueType).Elem() if err := d.decode(av, elem, tag{}); err != nil { return err } + v.SetMapIndex(key, elem) } } else if v.Kind() == reflect.Struct { @@ -609,6 +626,50 @@ func (d *Decoder) decodeMap(avMap map[string]types.AttributeValue, v reflect.Val return nil } +var numberType = reflect.TypeOf(Number("")) +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func (d *Decoder) getMapKeyDecoder(keyType reflect.Type) (func(string, reflect.Value, tag) error, error) { + // Test the key type to determine if it implements the TextUnmarshaler interface. + if reflect.PtrTo(keyType).Implements(textUnmarshalerType) || keyType.Implements(textUnmarshalerType) { + return func(v string, k reflect.Value, _ tag) error { + if !k.CanAddr() { + return fmt.Errorf("cannot take address of map key, %v", k.Type()) + } + return k.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(v)) + }, nil + } + + var decodeMapKey func(v string, key reflect.Value, fieldTag tag) error + + switch keyType.Kind() { + case reflect.Bool: + decodeMapKey = func(v string, key reflect.Value, fieldTag tag) error { + b, err := strconv.ParseBool(v) + if err != nil { + return err + } + return d.decodeBool(b, key) + } + case reflect.String: + // Number type handled as a string + decodeMapKey = d.decodeString + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + decodeMapKey = d.decodeNumber + + default: + return nil, &UnmarshalTypeError{ + Value: "map key must be string, number, bool, or TextUnmarshaler", + Type: keyType, + } + } + + return decodeMapKey, nil +} + func (d *Decoder) decodeNull(v reflect.Value) error { if v.IsValid() && v.CanSet() { v.Set(reflect.Zero(v.Type())) @@ -675,7 +736,7 @@ func (d *Decoder) decodeStringSet(ss []string, v reflect.Value) error { if !isArray { v.SetLen(i + 1) } - u, elem := indirect(v.Index(i), false) + u, elem := indirect(v.Index(i), indirectOptions{}) if u != nil { return u.UnmarshalDynamoDBStreamsAttributeValue(&types.AttributeValueMemberSS{Value: ss}) } @@ -713,38 +774,82 @@ func decoderFieldByIndex(v reflect.Value, index []int) reflect.Value { return v } +type indirectOptions struct { + decodeNull bool + skipUnmarshaler bool +} + // indirect will walk a value's interface or pointer value types. Returning // the final value or the value a unmarshaler is defined on. // // Based on the enoding/json type reflect value type indirection in Go Stdlib // https://golang.org/src/encoding/json/decode.go indirect func. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { +func indirect(v reflect.Value, opts indirectOptions) (Unmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true v = v.Addr() } + for { + // Load value from interface, but only if the result will be + // usefully addressable. if v.Kind() == reflect.Interface && !v.IsNil() { e := v.Elem() - if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { + if e.Kind() == reflect.Ptr && !e.IsNil() && (!opts.decodeNull || e.Elem().Kind() == reflect.Ptr) { + haveAddr = false v = e continue } + if e.Kind() != reflect.Ptr && e.IsValid() { + return nil, e + } } if v.Kind() != reflect.Ptr { break } - if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { + if opts.decodeNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() break } if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } - if v.Type().NumMethod() > 0 { + if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { return u, reflect.Value{} } } - v = v.Elem() + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } } return nil, v @@ -782,8 +887,12 @@ func (n Number) String() string { type UnmarshalTypeError struct { Value string Type reflect.Type + Err error } +// Unwrap returns the underlying error if any. +func (e *UnmarshalTypeError) Unwrap() error { return e.Err } + // Error returns the string representation of the error. // satisfying the error interface func (e *UnmarshalTypeError) Error() string { diff --git a/feature/dynamodbstreams/attributevalue/decode_test.go b/feature/dynamodbstreams/attributevalue/decode_test.go index e64c36e5e4a..0f83592a558 100644 --- a/feature/dynamodbstreams/attributevalue/decode_test.go +++ b/feature/dynamodbstreams/attributevalue/decode_test.go @@ -335,7 +335,10 @@ func TestUnmarshalMapError(t *testing.T) { }, actual: &map[int]interface{}{}, expected: nil, - err: &UnmarshalTypeError{Value: "map string key", Type: reflect.TypeOf(int(0))}, + err: &UnmarshalTypeError{ + Value: `map key "BOOL"`, + Type: reflect.TypeOf(int(0)), + }, }, } @@ -765,3 +768,197 @@ func TestDecodeAliasType(t *testing.T) { t.Errorf("expect:\n%v\nactual:\n%v", expect, actual) } } + +type testUnmarshalMapKeyComplex struct { + Foo string +} + +func (t *testUnmarshalMapKeyComplex) UnmarshalText(b []byte) error { + t.Foo = string(b) + return nil +} +func (t *testUnmarshalMapKeyComplex) UnmarshalDynamoDBStreamsAttributeValue(av types.AttributeValue) error { + avM, ok := av.(*types.AttributeValueMemberM) + if !ok { + return fmt.Errorf("unexpected AttributeValue type %T, %v", av, av) + } + avFoo, ok := avM.Value["foo"] + if !ok { + return nil + } + + avS, ok := avFoo.(*types.AttributeValueMemberS) + if !ok { + return fmt.Errorf("unexpected Foo AttributeValue type, %T, %v", avM, avM) + } + + t.Foo = avS.Value + + return nil +} + +func TestUnmarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input map[string]types.AttributeValue + expectVal interface{} + expectType func() interface{} + }{ + "string key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[string]interface{}{} }, + expectVal: map[string]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "string alias key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[StrAlias]interface{}{} }, + expectVal: map[StrAlias]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "Number key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[Number]interface{}{} }, + expectVal: map[Number]interface{}{ + Number("1"): 123., + Number("2"): "efg", + }, + }, + "int key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[int]interface{}{} }, + expectVal: map[int]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "int alias key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[IntAlias]interface{}{} }, + expectVal: map[IntAlias]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "bool key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[bool]interface{}{} }, + expectVal: map[bool]interface{}{ + true: 123., + false: "efg", + }, + }, + "bool alias key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[BoolAlias]interface{}{} }, + expectVal: map[BoolAlias]interface{}{ + true: 123., + false: "efg", + }, + }, + "textMarshaler key": { + input: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testTextMarshaler]interface{}{} }, + expectVal: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + "textMarshaler DDBAvMarshaler key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testUnmarshalMapKeyComplex]interface{}{} }, + expectVal: map[testUnmarshalMapKeyComplex]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualVal := c.expectType() + err := UnmarshalMap(c.input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if diff := cmp.Diff(c.expectVal, actualVal); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + }) + } +} + +func TestUnmarshalMap_keyPtrTypes(t *testing.T) { + input := map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + } + + expectVal := map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + } + + actualVal := map[*testTextMarshaler]interface{}{} + err := UnmarshalMap(input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if e, a := len(expectVal), len(actualVal); e != a { + t.Errorf("expect %v values, got %v", e, a) + } + + for k, v := range expectVal { + var found bool + for ak, av := range actualVal { + if *k == *ak { + found = true + if diff := cmp.Diff(v, av); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + } + } + if !found { + t.Errorf("expect %v key not found", *k) + } + } + +} diff --git a/feature/dynamodbstreams/attributevalue/encode.go b/feature/dynamodbstreams/attributevalue/encode.go index 66f649f05d4..89f15c1ee5c 100644 --- a/feature/dynamodbstreams/attributevalue/encode.go +++ b/feature/dynamodbstreams/attributevalue/encode.go @@ -1,6 +1,7 @@ package attributevalue import ( + "encoding" "fmt" "reflect" "strconv" @@ -380,6 +381,7 @@ type Encoder struct { // the `opts` functional options to override the default configuration. func NewEncoder(optFns ...func(*EncoderOptions)) *Encoder { options := EncoderOptions{ + TagKey: defaultTagKey, NullEmptySets: true, } for _, fn := range optFns { @@ -497,9 +499,9 @@ func (e *Encoder) encodeStruct(v reflect.Value, fieldTag tag) (types.AttributeVa func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { m := &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}} for _, key := range v.MapKeys() { - keyName := fmt.Sprint(key.Interface()) - if keyName == "" { - return nil, &InvalidMarshalError{msg: "map key cannot be empty"} + keyName, err := mapKeyAsString(key, fieldTag) + if err != nil { + return nil, err } elemVal := v.MapIndex(key) @@ -519,6 +521,40 @@ func (e *Encoder) encodeMap(v reflect.Value, fieldTag tag) (types.AttributeValue return m, nil } +func mapKeyAsString(keyVal reflect.Value, fieldTag tag) (keyStr string, err error) { + defer func() { + if err != nil { + return + } + if keyStr == "" { + err = &InvalidMarshalError{msg: "map key cannot be empty"} + } + }() + + if k, ok := keyVal.Interface().(encoding.TextMarshaler); ok { + b, err := k.MarshalText() + if err != nil { + return "", fmt.Errorf("failed to marshal text, %w", err) + } + return string(b), err + } + + switch keyVal.Kind() { + case reflect.Bool, + reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + + return fmt.Sprint(keyVal.Interface()), nil + + default: + return "", &InvalidMarshalError{ + msg: "map key type not supported, must be string, number, bool, or TextMarshaler", + } + } +} + func (e *Encoder) encodeSlice(v reflect.Value, fieldTag tag) (types.AttributeValue, error) { if v.Type().Elem().Kind() == reflect.Uint8 { slice := reflect.MakeSlice(byteSliceType, v.Len(), v.Len()) diff --git a/feature/dynamodbstreams/attributevalue/encode_test.go b/feature/dynamodbstreams/attributevalue/encode_test.go index 0079eea8f76..af64000ddaa 100644 --- a/feature/dynamodbstreams/attributevalue/encode_test.go +++ b/feature/dynamodbstreams/attributevalue/encode_test.go @@ -1,13 +1,14 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" "reflect" "strconv" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" "github.com/google/go-cmp/cmp" @@ -366,3 +367,130 @@ func TestEncoderFieldByIndex(t *testing.T) { t.Error("expected f to be of kind Int with value equal to outer.Inner") } } + +func TestMarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input interface{} + expectAV map[string]types.AttributeValue + }{ + "string key": { + input: map[string]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "string alias key": { + input: map[StrAlias]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "Number key": { + input: map[Number]interface{}{ + Number("1"): 123, + Number("2"): "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int key": { + input: map[int]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int alias key": { + input: map[IntAlias]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool key": { + input: map[bool]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool alias key": { + input: map[BoolAlias]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler key": { + input: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler ptr key": { + input: map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + av, err := MarshalMap(c.input) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + cmpOptions := cmp.Options{ + cmpopts.IgnoreUnexported(types.AttributeValueMemberM{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberN{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBOOL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberB{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberBS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberL{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberSS{}), + cmpopts.IgnoreUnexported(types.AttributeValueMemberNULL{}), + } + if diff := cmp.Diff(c.expectAV, av, cmpOptions...); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} diff --git a/feature/dynamodbstreams/attributevalue/field.go b/feature/dynamodbstreams/attributevalue/field.go index 7abd3479a96..4f63bc7df99 100644 --- a/feature/dynamodbstreams/attributevalue/field.go +++ b/feature/dynamodbstreams/attributevalue/field.go @@ -5,6 +5,8 @@ import ( "sort" ) +const defaultTagKey = "dynamodbav" + type field struct { tag @@ -46,7 +48,12 @@ type structFieldOptions struct { // unionStructFields returns a list of fields for the given type. Type info is cached // to avoid repeated calls into the reflect package func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { - if cached, ok := fieldCache.Load(t); ok { + key := fieldCacheKey{ + typ: t, + opts: opts, + } + + if cached, ok := fieldCache.Load(key); ok { return cached } @@ -62,7 +69,7 @@ func unionStructFields(t reflect.Type, opts structFieldOptions) *cachedFields { fs.fieldsByName[f.Name] = i } - cached, _ := fieldCache.LoadOrStore(t, fs) + cached, _ := fieldCache.LoadOrStore(key, fs) return cached } @@ -105,7 +112,7 @@ func enumFields(t reflect.Type, opts structFieldOptions) []field { fieldTag := tag{} fieldTag.parseAVTag(sf.Tag) // Because MarshalOptions.TagKey must be explicitly set. - if opts.TagKey != "" && fieldTag == (tag{}) { + if opts.TagKey != "" && opts.TagKey != defaultTagKey { fieldTag.parseStructTag(opts.TagKey, sf.Tag) } diff --git a/feature/dynamodbstreams/attributevalue/field_cache.go b/feature/dynamodbstreams/attributevalue/field_cache.go index 60a9d9c7499..c0fc4679a86 100644 --- a/feature/dynamodbstreams/attributevalue/field_cache.go +++ b/feature/dynamodbstreams/attributevalue/field_cache.go @@ -1,25 +1,31 @@ package attributevalue import ( + "reflect" "strings" "sync" ) -var fieldCache fieldCacher +var fieldCache = &fieldCacher{} + +type fieldCacheKey struct { + typ reflect.Type + opts structFieldOptions +} type fieldCacher struct { cache sync.Map } -func (c *fieldCacher) Load(t interface{}) (*cachedFields, bool) { - if v, ok := c.cache.Load(t); ok { +func (c *fieldCacher) Load(key fieldCacheKey) (*cachedFields, bool) { + if v, ok := c.cache.Load(key); ok { return v.(*cachedFields), true } return nil, false } -func (c *fieldCacher) LoadOrStore(t interface{}, fs *cachedFields) (*cachedFields, bool) { - v, ok := c.cache.LoadOrStore(t, fs) +func (c *fieldCacher) LoadOrStore(key fieldCacheKey, fs *cachedFields) (*cachedFields, bool) { + v, ok := c.cache.LoadOrStore(key, fs) return v.(*cachedFields), ok } diff --git a/feature/dynamodbstreams/attributevalue/field_test.go b/feature/dynamodbstreams/attributevalue/field_test.go index 82a09d6cf99..9c2b291703d 100644 --- a/feature/dynamodbstreams/attributevalue/field_test.go +++ b/feature/dynamodbstreams/attributevalue/field_test.go @@ -1,6 +1,7 @@ package attributevalue import ( + "fmt" "reflect" "testing" ) @@ -22,7 +23,7 @@ type unionComplex struct { } type unionTagged struct { - A int `json:"A"` + A int `dynamodbav:"ddbav" json:"A" taga:"TagA" tagb:"TagB"` } type unionTaggedComplex struct { @@ -32,97 +33,211 @@ type unionTaggedComplex struct { } func TestUnionStructFields(t *testing.T) { - var cases = []struct { + origFieldCache := fieldCache + defer func() { fieldCache = origFieldCache }() + + fieldCache = &fieldCacher{} + + var cases = map[string]struct { in interface{} + opts structFieldOptions expect []testUnionValues }{ - { - in: unionSimple{1, "2", []string{"abc"}}, + "simple input": { + in: unionSimple{1, "2", []string{"abc"}}, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"A", 1}, {"B", "2"}, {"C", []string{"abc"}}, }, }, - { + "nested struct": { in: unionComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, A: 2, }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"B", "2"}, {"C", []string{"abc"}}, {"A", 2}, }, }, - { + "with TagKey unset": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"ddbav", 3}, + {"B", "3"}, + }, + }, + "with TagKey json": { in: unionTaggedComplex{ unionSimple: unionSimple{1, "2", []string{"abc"}}, unionTagged: unionTagged{3}, B: "3", }, + opts: structFieldOptions{TagKey: "json"}, expect: []testUnionValues{ {"C", []string{"abc"}}, {"A", 3}, {"B", "3"}, }, }, + "with TagKey taga": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "taga"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagA", 3}, + {"B", "3"}, + }, + }, + "with TagKey tagb": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "tagb"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagB", 3}, + {"B", "3"}, + }, + }, } - for i, c := range cases { - v := reflect.ValueOf(c.in) + for name, c := range cases { + t.Run(name, func(t *testing.T) { + v := reflect.ValueOf(c.in) - fields := unionStructFields(v.Type(), structFieldOptions{TagKey: "json"}) - for j, f := range fields.All() { - expected := c.expect[j] - if e, a := expected.Name, f.Name; e != a { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) - } - actual := v.FieldByIndex(f.Index).Interface() - if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { - t.Errorf("%d:%d expect %v, got %v", i, j, e, f) + fields := unionStructFields(v.Type(), c.opts) + for i, f := range fields.All() { + expected := c.expect[i] + if e, a := expected.Name, f.Name; e != a { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } + actual := v.FieldByIndex(f.Index).Interface() + if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } } - } + }) } } func TestCachedFields(t *testing.T) { type myStruct struct { - Dog int + Dog int `tag1:"rabbit" tag2:"cow" tag3:"horse"` CAT string bird bool } - fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{}) - - const expectedNumFields = 2 - if numFields := len(fields.All()); numFields != expectedNumFields { - t.Errorf("expected number of fields to be %d but got %d", expectedNumFields, numFields) - } - - cases := []struct { + cases := map[string][]struct { Name string FieldName string Found bool }{ - {"Dog", "Dog", true}, - {"dog", "Dog", true}, - {"DOG", "Dog", true}, - {"Yorkie", "", false}, - {"Cat", "CAT", true}, - {"cat", "CAT", true}, - {"CAT", "CAT", true}, - {"tiger", "", false}, - {"bird", "", false}, + "": { + {"Dog", "Dog", true}, + {"dog", "Dog", true}, + {"DOG", "Dog", true}, + {"Yorkie", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag1": { + {"rabbit", "rabbit", true}, + {"Rabbit", "rabbit", true}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag2": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "cow", true}, + {"Cow", "cow", true}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag3": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "horse", true}, + {"Horse", "horse", true}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, } - for _, c := range cases { - f, found := fields.FieldByName(c.Name) - if found != c.Found { - t.Errorf("expected found to be %v but got %v", c.Found, found) - } - if found && f.Name != c.FieldName { - t.Errorf("expected field name to be %s but got %s", c.FieldName, f.Name) + for tagKey, cs := range cases { + for _, c := range cs { + name := tagKey + if name == "" { + name = "none" + } + t.Run(fmt.Sprintf("%s/%s", name, c.Name), func(t *testing.T) { + t.Parallel() + + fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{ + TagKey: tagKey, + }) + + const expectedNumFields = 2 + if numFields := len(fields.All()); numFields != expectedNumFields { + t.Errorf("expect %v fields, got %d", expectedNumFields, numFields) + } + + f, found := fields.FieldByName(c.Name) + if found != c.Found { + t.Errorf("expect %v found, got %v", c.Found, found) + } + if found && f.Name != c.FieldName { + t.Errorf("expect %v field name, got %s", c.FieldName, f.Name) + } + }) } } } diff --git a/feature/dynamodbstreams/attributevalue/marshaler_test.go b/feature/dynamodbstreams/attributevalue/marshaler_test.go index 26d4d91a5c7..ac0969111aa 100644 --- a/feature/dynamodbstreams/attributevalue/marshaler_test.go +++ b/feature/dynamodbstreams/attributevalue/marshaler_test.go @@ -520,7 +520,7 @@ func compareObjects(t *testing.T, expected interface{}, actual interface{}) { } func BenchmarkMarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -547,7 +547,7 @@ func BenchmarkMarshalOneMember(b *testing.B) { } func BenchmarkMarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} simple := simpleMarshalStruct{ String: "abc", @@ -576,7 +576,7 @@ func BenchmarkMarshalTwoMembers(b *testing.B) { } func BenchmarkUnmarshalOneMember(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", @@ -605,7 +605,7 @@ func BenchmarkUnmarshalOneMember(b *testing.B) { } func BenchmarkUnmarshalTwoMembers(b *testing.B) { - fieldCache = fieldCacher{} + fieldCache = &fieldCacher{} myStructAVMap, _ := Marshal(simpleMarshalStruct{ String: "abc", diff --git a/feature/dynamodbstreams/attributevalue/shared_test.go b/feature/dynamodbstreams/attributevalue/shared_test.go index 62071554878..41a747147a4 100644 --- a/feature/dynamodbstreams/attributevalue/shared_test.go +++ b/feature/dynamodbstreams/attributevalue/shared_test.go @@ -1,18 +1,37 @@ package attributevalue import ( - smithydocument "github.com/aws/smithy-go/document" - "github.com/google/go-cmp/cmp/cmpopts" + "fmt" "reflect" "strings" "testing" "time" + smithydocument "github.com/aws/smithy-go/document" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" "github.com/google/go-cmp/cmp" ) +type testTextMarshaler struct { + Foo string +} + +func (t *testTextMarshaler) UnmarshalText(b []byte) error { + if !strings.HasPrefix(string(b), "Foo:") { + return fmt.Errorf(`missing "Foo:" prefix`) + } + + t.Foo = string(b)[len("Foo:"):] + return nil +} + +func (t testTextMarshaler) MarshalText() ([]byte, error) { + return []byte("Foo:" + t.Foo), nil +} + type testBinarySetStruct struct { Binarys [][]byte `dynamodbav:",binaryset"` } diff --git a/feature/dynamodbstreams/attributevalue/tag.go b/feature/dynamodbstreams/attributevalue/tag.go index 6eb901706fb..f01c432e6ea 100644 --- a/feature/dynamodbstreams/attributevalue/tag.go +++ b/feature/dynamodbstreams/attributevalue/tag.go @@ -18,7 +18,7 @@ type tag struct { } func (t *tag) parseAVTag(structTag reflect.StructTag) { - tagStr := structTag.Get("dynamodbav") + tagStr := structTag.Get(defaultTagKey) if len(tagStr) == 0 { return }