Skip to content

Commit

Permalink
Fix format with alias (#648)
Browse files Browse the repository at this point in the history
* fix format with alias
  • Loading branch information
goccy authored Feb 11, 2025
1 parent c331468 commit 8f17441
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
7 changes: 1 addition & 6 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,7 @@ func (d *Decoder) deleteStructKeys(structType reflect.Type, unknownFields map[st
}

func (d *Decoder) unmarshalableDocument(node ast.Node) ([]byte, error) {
var err error
node, err = d.resolveAlias(node)
if err != nil {
return nil, err
}
doc := format.FormatNode(node)
doc := format.FormatNodeWithResolvedAlias(node, d.anchorNodeMap)
return []byte(doc), nil
}

Expand Down
35 changes: 35 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3151,6 +3151,41 @@ func TestMapKeyCustomUnmarshaler(t *testing.T) {
}
}

type bytesUnmershalerWithMapAlias struct{}

func (*bytesUnmershalerWithMapAlias) UnmarshalYAML(b []byte) error {
expected := strings.TrimPrefix(`
stuff:
bar:
- one
- two
`, "\n")
if string(b) != expected {
return fmt.Errorf("failed to decode: expected:\n[%s]\nbut got:\n[%s]\n", expected, string(b))
}
return nil
}

func TestBytesUnmarshalerWithMapAlias(t *testing.T) {
yml := `
x-foo: &data
bar:
- one
- two
foo:
stuff: *data
`
type T struct {
Foo bytesUnmershalerWithMapAlias `yaml:"foo"`
}
var v T
if err := yaml.Unmarshal([]byte(yml), &v); err != nil {
t.Fatal(err)
}
}

func TestDecoderPreservesDefaultValues(t *testing.T) {
type nested struct {
Val string `yaml:"val"`
Expand Down
43 changes: 42 additions & 1 deletion internal/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ import (
"github.com/goccy/go-yaml/token"
)

func FormatNodeWithResolvedAlias(n ast.Node, anchorNodeMap map[string]ast.Node) string {
tk := n.GetToken()
if tk == nil {
return ""
}
formatter := newFormatter(tk, hasComment(n))
formatter.anchorNodeMap = anchorNodeMap
return formatter.format(n)
}

func FormatNode(n ast.Node) string {
tk := n.GetToken()
if tk == nil {
Expand Down Expand Up @@ -124,6 +134,7 @@ func hasComment(n ast.Node) bool {
type Formatter struct {
existsComment bool
tokenToOriginMap map[*token.Token]string
anchorNodeMap map[string]ast.Node
}

func newFormatter(tk *token.Token, existsComment bool) *Formatter {
Expand Down Expand Up @@ -294,6 +305,19 @@ func (f *Formatter) formatAnchor(n *ast.AnchorNode) string {
}

func (f *Formatter) formatAlias(n *ast.AliasNode) string {
if f.anchorNodeMap != nil {
node := f.anchorNodeMap[n.Value.GetToken().Value]
if node != nil {
formatted := f.formatNode(node)
// If formatted text contains newline characters, indentation needs to be considered.
if strings.Contains(formatted, "\n") {
// If the first character is not a newline, the first line should be output without indentation.
isIgnoredFirstLine := !strings.HasPrefix(formatted, "\n")
formatted = f.addIndentSpace(n.GetToken().Position.IndentNum, formatted, isIgnoredFirstLine)
}
return formatted
}
}
return f.origin(n.Start) + f.formatNode(n.Value)
}

Expand Down Expand Up @@ -385,7 +409,7 @@ func (f *Formatter) trimIndentSpace(trimIndentNum int, v string) string {
}
lines := strings.Split(normalizeNewLineChars(v), "\n")
out := make([]string, 0, len(lines))
for _, line := range strings.Split(v, "\n") {
for _, line := range lines {
var cnt int
out = append(out, strings.TrimLeftFunc(line, func(r rune) bool {
cnt++
Expand All @@ -395,6 +419,23 @@ func (f *Formatter) trimIndentSpace(trimIndentNum int, v string) string {
return strings.Join(out, "\n")
}

func (f *Formatter) addIndentSpace(indentNum int, v string, isIgnoredFirstLine bool) string {
if indentNum == 0 {
return v
}
indent := strings.Repeat(" ", indentNum)
lines := strings.Split(normalizeNewLineChars(v), "\n")
out := make([]string, 0, len(lines))
for idx, line := range lines {
if line == "" || (isIgnoredFirstLine && idx == 0) {
out = append(out, line)
continue
}
out = append(out, indent+line)
}
return strings.Join(out, "\n")
}

// normalizeNewLineChars normalize CRLF and CR to LF.
func normalizeNewLineChars(v string) string {
return strings.ReplaceAll(strings.ReplaceAll(v, "\r\n", "\n"), "\r", "\n")
Expand Down

0 comments on commit 8f17441

Please sign in to comment.