Skip to content

Commit

Permalink
Fix quoted comments (#370)
Browse files Browse the repository at this point in the history
* Make path filtering work with quotes
  • Loading branch information
WillAbides authored Sep 15, 2023
1 parent 1c0fdf0 commit 4df8923
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 10 deletions.
7 changes: 0 additions & 7 deletions parser/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ type context struct {
idx int
size int
tokens token.Tokens
mode Mode
path string
}

Expand Down Expand Up @@ -56,7 +55,6 @@ func (c *context) copy() *context {
idx: c.idx,
size: c.size,
tokens: append(token.Tokens{}, c.tokens...),
mode: c.mode,
path: c.path,
}
}
Expand Down Expand Up @@ -145,10 +143,6 @@ func (c *context) afterNextNotCommentToken() *token.Token {
return nil
}

func (c *context) enabledComment() bool {
return c.mode&ParseComments != 0
}

func (c *context) isCurrentCommentToken() bool {
tk := c.currentToken()
if tk == nil {
Expand Down Expand Up @@ -193,7 +187,6 @@ func newContext(tokens token.Tokens, mode Mode) *context {
idx: 0,
size: len(filteredTokens),
tokens: token.Tokens(filteredTokens),
mode: mode,
path: "$",
}
}
4 changes: 4 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,8 @@ a: # commentA
i: fuga # commentI
j: piyo # commentJ
k.l.m.n: moge # commentKLMN
o#p: hogera # commentOP
q#.r: hogehoge # commentQR
`
f, err := parser.ParseBytes([]byte(yml), parser.ParseComments)
if err != nil {
Expand Down Expand Up @@ -922,6 +924,8 @@ k.l.m.n: moge # commentKLMN
"$.a.i",
"$.j",
"$.'k.l.m.n'",
"$.o#p",
"$.'q#.r'",
}
if !reflect.DeepEqual(expectedPaths, commentPaths) {
t.Fatalf("failed to get YAMLPath to the comment node:\nexpected[%s]\ngot [%s]", expectedPaths, commentPaths)
Expand Down
26 changes: 23 additions & 3 deletions path.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,29 @@ func newSelectorNode(selector string) *selectorNode {
}

func (n *selectorNode) filter(node ast.Node) (ast.Node, error) {
selector := n.selector
if len(selector) > 1 && selector[0] == '\'' && selector[len(selector)-1] == '\'' {
selector = selector[1 : len(selector)-1]
}
switch node.Type() {
case ast.MappingType:
for _, value := range node.(*ast.MappingNode).Values {
key := value.Key.GetToken().Value
if key == n.selector {
if len(key) > 0 {
switch key[0] {
case '"':
var err error
key, err = strconv.Unquote(key)
if err != nil {
return nil, errors.Wrapf(err, "failed to unquote")
}
case '\'':
if len(key) > 1 && key[len(key)-1] == '\'' {
key = key[1 : len(key)-1]
}
}
}
if key == selector {
if n.child == nil {
return value.Value, nil
}
Expand All @@ -518,7 +536,7 @@ func (n *selectorNode) filter(node ast.Node) (ast.Node, error) {
case ast.MappingValueType:
value := node.(*ast.MappingValueNode)
key := value.Key.GetToken().Value
if key == n.selector {
if key == selector {
if n.child == nil {
return value.Value, nil
}
Expand Down Expand Up @@ -571,7 +589,9 @@ func (n *selectorNode) replace(node ast.Node, target ast.Node) error {
}

func (n *selectorNode) String() string {
s := fmt.Sprintf(".%s", n.selector)
var builder PathBuilder
selector := builder.normalizeSelectorName(n.selector)
s := fmt.Sprintf(".%s", selector)
if n.child != nil {
s += n.child.String()
}
Expand Down
7 changes: 7 additions & 0 deletions path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ store:
bicycle:
color: red
price: 19.95
bicycle*unicycle:
price: 20.25
`
tests := []struct {
name string
Expand Down Expand Up @@ -97,6 +99,11 @@ store:
path: builder().Root().Child("store").Child("bicycle").Child("price").Build(),
expected: float64(19.95),
},
{
name: `$.store.'bicycle*unicycle'.price`,
path: builder().Root().Child("store").Child(`bicycle*unicycle`).Child("price").Build(),
expected: float64(20.25),
},
}
t.Run("PathString", func(t *testing.T) {
for _, test := range tests {
Expand Down
89 changes: 89 additions & 0 deletions yaml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package yaml_test

import (
"bytes"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -1161,6 +1162,94 @@ hoge:
})
}

func TestCommentMapRoundTrip(t *testing.T) {
// test that an unmarshal and marshal round trip retains comments.
// if expect is empty, the test will use the input as the expected result.
tests := []struct {
name string
source string
expect string
encodeOptions []yaml.EncodeOption
}{
{
name: "simple map",
source: `
# head
a: 1 # line
# foot
`,
},
{
name: "nesting",
source: `
- 1 # one
- foo:
a: b
# c comment
c: d # d comment
"e#f": g # g comment
h.i: j # j comment
"k.#l": m # m comment
`,
},
{
name: "single quotes",
source: `'a#b': c # c comment`,
encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)},
},
{
name: "single quotes added in encode",
source: `a#b: c # c comment`,
encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)},
expect: `'a#b': c # c comment`,
},
{
name: "double quotes quotes transformed to single quotes",
source: `"a#b": c # c comment`,
encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)},
expect: `'a#b': c # c comment`,
},
{
name: "single quotes quotes transformed to double quotes",
source: `'a#b': c # c comment`,
expect: `"a#b": c # c comment`,
},
{
name: "single quotes removed",
source: `'a': b # b comment`,
expect: `a: b # b comment`,
},
{
name: "double quotes removed",
source: `"a": b # b comment`,
expect: `a: b # b comment`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var val any
cm := yaml.CommentMap{}
source := strings.TrimSpace(test.source)
if err := yaml.UnmarshalWithOptions([]byte(source), &val, yaml.CommentToMap(cm)); err != nil {
t.Fatalf("%+v", err)
}
marshaled, err := yaml.MarshalWithOptions(val, append(test.encodeOptions, yaml.WithComment(cm))...)
if err != nil {
t.Fatalf("%+v", err)
}
got := strings.TrimSpace(string(marshaled))
expect := strings.TrimSpace(test.expect)
if expect == "" {
expect = source
}
if got != expect {
t.Fatalf("expected:\n%s\ngot:\n%s\n", expect, got)
}
})

}
}

func TestRegisterCustomMarshaler(t *testing.T) {
type T struct {
Foo []byte `yaml:"foo"`
Expand Down

0 comments on commit 4df8923

Please sign in to comment.