diff --git a/encode.go b/encode.go index ee459765..bfb76744 100644 --- a/encode.go +++ b/encode.go @@ -318,7 +318,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { length := rv.Len() enc.wf("[") for i := 0; i < length; i++ { - elem := rv.Index(i) + elem := eindirect(rv.Index(i)) enc.eElement(elem) if i != length-1 { enc.wf(", ") @@ -332,7 +332,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { encPanic(errNoKey) } for i := 0; i < rv.Len(); i++ { - trv := rv.Index(i) + trv := eindirect(rv.Index(i)) if isNil(trv) { continue } @@ -357,7 +357,7 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) { } func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) { - switch rv := eindirect(rv); rv.Kind() { + switch rv.Kind() { case reflect.Map: enc.eMap(key, rv, inline) case reflect.Struct: @@ -379,7 +379,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) { var mapKeysDirect, mapKeysSub []string for _, mapKey := range rv.MapKeys() { k := mapKey.String() - if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) { + if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) { mapKeysSub = append(mapKeysSub, k) } else { mapKeysDirect = append(mapKeysDirect, k) @@ -389,7 +389,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) { var writeMapKeys = func(mapKeys []string, trailC bool) { sort.Strings(mapKeys) for i, mapKey := range mapKeys { - val := rv.MapIndex(reflect.ValueOf(mapKey)) + val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey))) if isNil(val) { continue } @@ -441,27 +441,16 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) { continue } - frv := rv.Field(i) + frv := eindirect(rv.Field(i)) // Treat anonymous struct fields with tag names as though they are // not anonymous, like encoding/json does. // // Non-struct anonymous fields use the normal encoding logic. if f.Anonymous { - t := f.Type - switch t.Kind() { - case reflect.Struct: - if getOptions(f.Tag).name == "" { - addFields(t, frv, append(start, f.Index...)) - continue - } - case reflect.Ptr: - if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" { - if !frv.IsNil() { - addFields(t.Elem(), frv.Elem(), append(start, f.Index...)) - } - continue - } + if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct { + addFields(frv.Type(), frv, append(start, f.Index...)) + continue } } @@ -487,7 +476,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) { writeFields := func(fields [][]int) { for _, fieldIndex := range fields { fieldType := rt.FieldByIndex(fieldIndex) - fieldVal := rv.FieldByIndex(fieldIndex) + fieldVal := eindirect(rv.FieldByIndex(fieldIndex)) if isNil(fieldVal) { /// Don't write anything for nil fields. continue @@ -540,6 +529,21 @@ func tomlTypeOfGo(rv reflect.Value) tomlType { if isNil(rv) || !rv.IsValid() { return nil } + + if rv.Kind() == reflect.Struct { + if _, ok := rv.Interface().(time.Time); ok { + return tomlDatetime + } + if isMarshaler(rv) { + return tomlString + } + return tomlHash + } + + if isMarshaler(rv) { + return tomlString + } + switch rv.Kind() { case reflect.Bool: return tomlBool @@ -561,19 +565,7 @@ func tomlTypeOfGo(rv reflect.Value) tomlType { return tomlString case reflect.Map: return tomlHash - case reflect.Struct: - if _, ok := rv.Interface().(time.Time); ok { - return tomlDatetime - } - if isMarshaler(rv) { - return tomlString - } - return tomlHash default: - if isMarshaler(rv) { - return tomlString - } - encPanic(errors.New("unsupported type: " + rv.Kind().String())) panic("unreachable") } @@ -586,16 +578,6 @@ func isMarshaler(rv reflect.Value) bool { case Marshaler: return true } - - // Someone used a pointer receiver: we can make it work for pointer values. - if rv.CanAddr() { - if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok { - return true - } - if _, ok := rv.Addr().Interface().(Marshaler); ok { - return true - } - } return false } @@ -605,19 +587,19 @@ func isTableArray(arr reflect.Value) bool { return false } - /// Don't allow nil. + ret := true for i := 0; i < arr.Len(); i++ { - if tomlTypeOfGo(arr.Index(i)) == nil { + tt := tomlTypeOfGo(eindirect(arr.Index(i))) + // Don't allow nil. + if tt == nil { encPanic(errArrayNilElement) } - } - for i := 0; i < arr.Len(); i++ { - if !typeEqual(tomlHash, tomlTypeOfGo(arr.Index(i))) { - return false + if ret && !typeEqual(tomlHash, tt) { + ret = false } } - return true + return ret } type tagOptions struct { @@ -716,12 +698,23 @@ func encPanic(err error) { } func eindirect(v reflect.Value) reflect.Value { - switch v.Kind() { - case reflect.Ptr, reflect.Interface: - return eindirect(v.Elem()) - default: + if v.Kind() != reflect.Ptr && v.Kind() != reflect.Interface { + if isMarshaler(v) { + return v + } + if v.CanAddr() { + if pv := v.Addr(); isMarshaler(pv) { + return pv + } + } return v } + + if v.IsNil() { + return v + } + + return eindirect(v.Elem()) } func isNil(rv reflect.Value) bool {