diff --git a/encode.go b/encode.go index de9e72a3..d01fcf1f 100644 --- a/encode.go +++ b/encode.go @@ -21,7 +21,6 @@ import ( "io" "reflect" "regexp" - "sort" "strconv" "strings" "time" @@ -185,9 +184,7 @@ func (e *encoder) marshal(tag string, in reflect.Value) { func (e *encoder) mapv(tag string, in reflect.Value) { e.mappingv(tag, func() { - keys := keyList(in.MapKeys()) - sort.Sort(keys) - for _, k := range keys { + for _, k := range NewSortedKeys(in) { e.marshal("", k) e.marshal("", in.MapIndex(k)) } @@ -238,9 +235,7 @@ func (e *encoder) structv(tag string, in reflect.Value) { m := in.Field(sinfo.InlineMap) if m.Len() > 0 { e.flow = false - keys := keyList(m.MapKeys()) - sort.Sort(keys) - for _, k := range keys { + for _, k := range NewSortedKeys(m) { if _, found := sinfo.FieldsMap[k.String()]; found { panic(fmt.Sprintf("cannot have key %q in inlined map: conflicts with struct field", k.String())) } diff --git a/encode_test.go b/encode_test.go index 4a8bf2e2..38a4626c 100644 --- a/encode_test.go +++ b/encode_test.go @@ -494,6 +494,31 @@ var marshalTests = []struct { }, "value: !!seq []\n", }, + {map[struct { + A int + B string + }]int{{A: 10, B: "a"}: 1, + {A: 11, B: "b"}: 3, + {A: 11, B: "a"}: 2, + {A: 11, B: "c"}: 4, + {A: 20, B: "a"}: 5}, + "" + + "? a: 10\n" + + " b: a\n" + + ": 1\n" + + "? a: 11\n" + + " b: a\n" + + ": 2\n" + + "? a: 11\n" + + " b: b\n" + + ": 3\n" + + "? a: 11\n" + + " b: c\n" + + ": 4\n" + + "? a: 20\n" + + " b: a\n" + + ": 5\n", + }, } func (s *S) TestMarshal(c *C) { diff --git a/sorter.go b/sorter.go index 9210ece7..8566e5d6 100644 --- a/sorter.go +++ b/sorter.go @@ -17,16 +17,23 @@ package yaml import ( "reflect" + "sort" "unicode" ) -type keyList []reflect.Value +type reflectSorter []reflect.Value -func (l keyList) Len() int { return len(l) } -func (l keyList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } -func (l keyList) Less(i, j int) bool { - a := l[i] - b := l[j] +func NewSortedKeys(mv reflect.Value) []reflect.Value { + keys := mv.MapKeys() + sort.Slice(keys, reflectSorter(keys).Less) + return keys +} + +func (l reflectSorter) Less(i, j int) bool { + return lessByValues(l[i], l[j]) +} + +func lessByValues(a, b reflect.Value) bool { ak := a.Kind() bk := b.Kind() for (ak == reflect.Interface || ak == reflect.Ptr) && !a.IsNil() { @@ -48,6 +55,25 @@ func (l keyList) Less(i, j int) bool { } return numLess(a, b) } + if ak == reflect.Struct && bk == ak { + tp := a.Type() + isEqual := false + for fi, fc := 0, tp.NumField(); fi < fc; fi++ { //compare struct fields in declaration order + if !tp.Field(fi).IsExported() { + continue + } + if lessByValues(a.Field(fi), b.Field(fi)) { + return true + } + if lessByValues(b.Field(fi), a.Field(fi)) { + return false + } + isEqual = true + } + if isEqual { + return false + } + } if ak != reflect.String || bk != reflect.String { return ak < bk }