From b324da5ffbb2a87e4a611b1703bd03bffbe7ffbf Mon Sep 17 00:00:00 2001 From: Martin Tournoij Date: Thu, 8 Jun 2023 08:12:39 +0200 Subject: [PATCH] Store BURNTSUSHI_TOML_110 in parser and lexer Setting a global is racy when multiple decodes are run in parallel. Fixes #395 --- .github/workflows/test.yml | 2 +- decode_test.go | 21 ++++++++++++++++++++ lex.go | 40 ++++++++++++++++++++------------------ meta.go | 2 +- parse.go | 15 +++++++------- 5 files changed, 51 insertions(+), 29 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0055d054..ac0a3512 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,7 +20,7 @@ "uses": "actions/checkout@v3" }, { "name": "Test", - "run": "go test ./..." + "run": "go test -race ./..." }, { "name": "Test on 32bit", "if": "runner.os == 'Linux'", diff --git a/decode_test.go b/decode_test.go index 45d0cf20..6f08d3ae 100644 --- a/decode_test.go +++ b/decode_test.go @@ -11,6 +11,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -1201,6 +1202,26 @@ func TestMetaKeys(t *testing.T) { } } +func TestDecodeParallel(t *testing.T) { + doc, err := os.ReadFile("testdata/ja-JP.toml") + if err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := Unmarshal(doc, new(map[string]interface{})) + if err != nil { + t.Fatal(err) + } + }() + } + wg.Wait() +} + // errorContains checks if the error message in have contains the text in // want. // diff --git a/lex.go b/lex.go index a2545302..3545a6ad 100644 --- a/lex.go +++ b/lex.go @@ -46,12 +46,13 @@ func (p Position) String() string { } type lexer struct { - input string - start int - pos int - line int - state stateFn - items chan item + input string + start int + pos int + line int + state stateFn + items chan item + tomlNext bool // Allow for backing up up to 4 runes. This is necessary because TOML // contains 3-rune tokens (""" and '''). @@ -87,13 +88,14 @@ func (lx *lexer) nextItem() item { } } -func lex(input string) *lexer { +func lex(input string, tomlNext bool) *lexer { lx := &lexer{ - input: input, - state: lexTop, - items: make(chan item, 10), - stack: make([]stateFn, 0, 10), - line: 1, + input: input, + state: lexTop, + items: make(chan item, 10), + stack: make([]stateFn, 0, 10), + line: 1, + tomlNext: tomlNext, } return lx } @@ -408,7 +410,7 @@ func lexTableNameEnd(lx *lexer) stateFn { // Lexes only one part, e.g. only 'a' inside 'a.b'. func lexBareName(lx *lexer) stateFn { r := lx.next() - if isBareKeyChar(r) { + if isBareKeyChar(r, lx.tomlNext) { return lexBareName } lx.backup() @@ -618,7 +620,7 @@ func lexInlineTableValue(lx *lexer) stateFn { case isWhitespace(r): return lexSkip(lx, lexInlineTableValue) case isNL(r): - if tomlNext { + if lx.tomlNext { return lexSkip(lx, lexInlineTableValue) } return lx.errorPrevLine(errLexInlineTableNL{}) @@ -643,7 +645,7 @@ func lexInlineTableValueEnd(lx *lexer) stateFn { case isWhitespace(r): return lexSkip(lx, lexInlineTableValueEnd) case isNL(r): - if tomlNext { + if lx.tomlNext { return lexSkip(lx, lexInlineTableValueEnd) } return lx.errorPrevLine(errLexInlineTableNL{}) @@ -654,7 +656,7 @@ func lexInlineTableValueEnd(lx *lexer) stateFn { lx.ignore() lx.skip(isWhitespace) if lx.peek() == '}' { - if tomlNext { + if lx.tomlNext { return lexInlineTableValueEnd } return lx.errorf("trailing comma not allowed in inline tables") @@ -838,7 +840,7 @@ func lexStringEscape(lx *lexer) stateFn { r := lx.next() switch r { case 'e': - if !tomlNext { + if !lx.tomlNext { return lx.error(errLexEscape{r}) } fallthrough @@ -861,7 +863,7 @@ func lexStringEscape(lx *lexer) stateFn { case '\\': return lx.pop() case 'x': - if !tomlNext { + if !lx.tomlNext { return lx.error(errLexEscape{r}) } return lexHexEscape @@ -1258,7 +1260,7 @@ func isHexadecimal(r rune) bool { return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F') } -func isBareKeyChar(r rune) bool { +func isBareKeyChar(r rune, tomlNext bool) bool { if tomlNext { return (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || diff --git a/meta.go b/meta.go index 71847a04..2e78b24e 100644 --- a/meta.go +++ b/meta.go @@ -106,7 +106,7 @@ func (k Key) maybeQuoted(i int) string { return `""` } for _, c := range k[i] { - if !isBareKeyChar(c) { + if !isBareKeyChar(c, false) { return `"` + dblQuotedReplacer.Replace(k[i]) + `"` } } diff --git a/parse.go b/parse.go index ceceb1d1..9c191536 100644 --- a/parse.go +++ b/parse.go @@ -11,13 +11,12 @@ import ( "github.com/BurntSushi/toml/internal" ) -var tomlNext bool - type parser struct { lx *lexer context Key // Full key for the current hash in scope. currentKey string // Base key name for everything except hashes. pos Position // Current position in the TOML file. + tomlNext bool ordered []Key // List of keys in the order that they appear in the TOML data. @@ -32,8 +31,7 @@ type keyInfo struct { } func parse(data string) (p *parser, err error) { - _, ok := os.LookupEnv("BURNTSUSHI_TOML_110") - tomlNext = ok + _, tomlNext := os.LookupEnv("BURNTSUSHI_TOML_110") defer func() { if r := recover(); r != nil { @@ -74,9 +72,10 @@ func parse(data string) (p *parser, err error) { p = &parser{ keyInfo: make(map[string]keyInfo), mapping: make(map[string]interface{}), - lx: lex(data), + lx: lex(data, tomlNext), ordered: make([]Key, 0), implicits: make(map[string]struct{}), + tomlNext: tomlNext, } for { item := p.next() @@ -361,7 +360,7 @@ func (p *parser) valueDatetime(it item) (interface{}, tomlType) { err error ) for _, dt := range dtTypes { - if dt.next && !tomlNext { + if dt.next && !p.tomlNext { continue } t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone) @@ -764,7 +763,7 @@ func (p *parser) replaceEscapes(it item, str string) string { replaced = append(replaced, rune(0x000D)) r += 1 case 'e': - if tomlNext { + if p.tomlNext { replaced = append(replaced, rune(0x001B)) r += 1 } @@ -775,7 +774,7 @@ func (p *parser) replaceEscapes(it item, str string) string { replaced = append(replaced, rune(0x005C)) r += 1 case 'x': - if tomlNext { + if p.tomlNext { escaped := p.asciiEscapeToUnicode(it, s[r+1:r+3]) replaced = append(replaced, escaped) r += 3