diff --git a/cmd/shfmt/json_test.go b/cmd/shfmt/json_test.go deleted file mode 100644 index 35811868e..000000000 --- a/cmd/shfmt/json_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2017, Daniel Martí -// See LICENSE for licensing information - -package main - -import ( - "bytes" - "os" - "strings" - "testing" - - qt "github.com/frankban/quicktest" - - "mvdan.cc/sh/v3/syntax" -) - -func TestRoundtripJSON(t *testing.T) { - t.Parallel() - - // Read testdata files. - inputShell, err := os.ReadFile("testdata/json.sh") - qt.Assert(t, err, qt.IsNil) - inputJSON, err := os.ReadFile("testdata/json.json") - if !*update { // allow it to not exist - qt.Assert(t, err, qt.IsNil) - } - sb := new(strings.Builder) - - // Parse the shell source and check that it is well formatted. - parser := syntax.NewParser(syntax.KeepComments(true)) - node, err := parser.Parse(bytes.NewReader(inputShell), "") - qt.Assert(t, err, qt.IsNil) - - printer := syntax.NewPrinter() - sb.Reset() - err = printer.Print(sb, node) - qt.Assert(t, err, qt.IsNil) - qt.Assert(t, sb.String(), qt.Equals, string(inputShell)) - - // Validate writing the pretty JSON. - sb.Reset() - err = writeJSON(sb, node, true) - qt.Assert(t, err, qt.IsNil) - got := sb.String() - if *update { - err := os.WriteFile("testdata/json.json", []byte(got), 0o666) - qt.Assert(t, err, qt.IsNil) - } else { - qt.Assert(t, got, qt.Equals, string(inputJSON)) - } - - // Ensure we don't use the originally parsed node again. - node = nil - - // Validate reading the pretty JSON and check that it formats the same. - node2, err := readJSON(bytes.NewReader(inputJSON)) - qt.Assert(t, err, qt.IsNil) - - sb.Reset() - err = printer.Print(sb, node2) - qt.Assert(t, err, qt.IsNil) - qt.Assert(t, sb.String(), qt.Equals, string(inputShell)) - - // Validate that emitting the JSON again produces the same result. 
- sb.Reset() - err = writeJSON(sb, node2, true) - qt.Assert(t, err, qt.IsNil) - got = sb.String() - qt.Assert(t, got, qt.Equals, string(inputJSON)) -} diff --git a/cmd/shfmt/main.go b/cmd/shfmt/main.go index c9224b64a..c41e5804f 100644 --- a/cmd/shfmt/main.go +++ b/cmd/shfmt/main.go @@ -24,6 +24,7 @@ import ( "mvdan.cc/sh/v3/fileutil" "mvdan.cc/sh/v3/syntax" + "mvdan.cc/sh/v3/syntax/typedjson" ) // TODO: this flag business screams generics. try again with Go 1.18+. @@ -444,7 +445,7 @@ func formatBytes(src []byte, path string, fileLang syntax.LangVariant) error { var node syntax.Node var err error if fromJSON.val { - node, err = readJSON(bytes.NewReader(src)) + node, err = typedjson.Decode(bytes.NewReader(src)) if err != nil { return err } @@ -462,7 +463,9 @@ func formatBytes(src []byte, path string, fileLang syntax.LangVariant) error { } if toJSON.val { // must be standard input; fine to return - return writeJSON(out, node, true) + // TODO: change the default behavior to be compact, + // and allow using --to-json=pretty or --to-json=indent. 
+ return typedjson.EncodeOptions{Indent: "\t"}.Encode(out, node) } writeBuf.Reset() printer.Print(&writeBuf, node) diff --git a/cmd/shfmt/testdata/scripts/tojson.txt b/cmd/shfmt/testdata/scripts/tojson.txt index b3a96a9c7..9906a85f8 100644 --- a/cmd/shfmt/testdata/scripts/tojson.txt +++ b/cmd/shfmt/testdata/scripts/tojson.txt @@ -17,11 +17,14 @@ cmp stdout comment.sh.json -- empty.sh -- -- empty.sh.json -- -{} +{ + "Type": "File" +} -- simple.sh -- foo -- simple.sh.json -- { + "Type": "File", "Pos": { "Offset": 0, "Line": 1, @@ -109,6 +112,7 @@ foo ((2)) -- arithmetic.sh.json -- { + "Type": "File", "Pos": { "Offset": 0, "Line": 1, @@ -205,6 +209,7 @@ foo # -- comment.sh.json -- { + "Type": "File", "Pos": { "Offset": 0, "Line": 1, diff --git a/cmd/shfmt/json.go b/syntax/typedjson/json.go similarity index 74% rename from cmd/shfmt/json.go rename to syntax/typedjson/json.go index 089ce8491..9e8852c45 100644 --- a/cmd/shfmt/json.go +++ b/syntax/typedjson/json.go @@ -1,7 +1,20 @@ // Copyright (c) 2017, Daniel Martí // See LICENSE for licensing information -package main +// Package typedjson allows encoding and decoding shell syntax trees as JSON. +// The decoding process needs to know what syntax node types to decode into, +// so the "typed JSON" requires "Type" keys in some syntax tree node objects: +// +// - The root node +// - Any node represented as an interface field in the parent Go type +// +// The types of all other nodes can be inferred from context alone. +// +// For the sake of efficiency and simplicity, the "Type" key +// described above must be first in each JSON object. +package typedjson + +// TODO: encoding and decoding nodes other than File is untested. import ( "encoding/json" @@ -12,32 +25,50 @@ import ( "mvdan.cc/sh/v3/syntax" ) -func writeJSON(w io.Writer, node syntax.Node, pretty bool) error { +// Encode is a shortcut for EncodeOptions.Encode, with the default options. 
+func Encode(w io.Writer, node syntax.Node) error { + return EncodeOptions{}.Encode(w, node) +} + +// EncodeOptions allows configuring how syntax nodes are encoded. +type EncodeOptions struct { + Indent string // e.g. "\t" + + // Allows us to add options later. +} + +// Encode writes node to w in its typed JSON form, +// as described in the package documentation. +func (opts EncodeOptions) Encode(w io.Writer, node syntax.Node) error { val := reflect.ValueOf(node) - encVal, _ := encode(val) + encVal, tname := encodeValue(val) + if tname == "" { + panic("node did not contain a named type?") + } + encVal.Elem().Field(0).SetString(tname) enc := json.NewEncoder(w) - if pretty { - enc.SetIndent("", "\t") + if opts.Indent != "" { + enc.SetIndent("", opts.Indent) } return enc.Encode(encVal.Interface()) } -func encode(val reflect.Value) (reflect.Value, string) { +func encodeValue(val reflect.Value) (reflect.Value, string) { switch val.Kind() { case reflect.Ptr: - elem := val.Elem() - if !elem.IsValid() { + if val.IsNil() { break } - return encode(elem) + return encodeValue(val.Elem()) case reflect.Interface: if val.IsNil() { break } - enc, tname := encode(val.Elem()) - if tname != "" { - enc.Elem().Field(0).SetString(tname) + enc, tname := encodeValue(val.Elem()) + if tname == "" { + panic("interface did not contain a named type?") } + enc.Elem().Field(0).SetString(tname) return enc, "" case reflect.Struct: // Construct a new struct with an optional Type, Pos and End, @@ -71,7 +102,7 @@ func encode(val reflect.Value) (reflect.Value, string) { if ftyp.Type == exportedPosType { encodePos(enc.Field(i), fval) } else { - encElem, _ := encode(fval) + encElem, _ := encodeValue(fval) if encElem.IsValid() { enc.Field(i).Set(encElem) } @@ -88,7 +119,7 @@ func encode(val reflect.Value) (reflect.Value, string) { enc := reflect.MakeSlice(anySliceType, n, n) for i := 0; i < n; i++ { elem := val.Index(i) - encElem, _ := encode(elem) + encElem, _ := encodeValue(elem) 
enc.Index(i).Set(encElem) } return enc, "" @@ -161,19 +192,32 @@ func decodePos(val reflect.Value, enc map[string]interface{}) { val.Set(reflect.ValueOf(syntax.NewPos(offset, line, column))) } -func readJSON(r io.Reader) (syntax.Node, error) { +// Decode is a shortcut for DecodeOptions.Decode, with the default options. +func Decode(r io.Reader) (syntax.Node, error) { + return DecodeOptions{}.Decode(r) +} + +// DecodeOptions allows configuring how syntax nodes are decoded. +type DecodeOptions struct { + // Empty for now; allows us to add options later. +} + +// Decode reads a node from r in its typed JSON form, +// as described in the package documentation. +func (opts DecodeOptions) Decode(r io.Reader) (syntax.Node, error) { var enc interface{} if err := json.NewDecoder(r).Decode(&enc); err != nil { return nil, err } - node := &syntax.File{} - if err := decode(reflect.ValueOf(node), enc); err != nil { + node := new(syntax.Node) + if err := decodeValue(reflect.ValueOf(node).Elem(), enc); err != nil { return nil, err } - return node, nil + return *node, nil } var nodeByName = map[string]reflect.Type{ + "File": reflect.TypeOf((*syntax.File)(nil)).Elem(), "Word": reflect.TypeOf((*syntax.Word)(nil)).Elem(), "Lit": reflect.TypeOf((*syntax.Lit)(nil)).Elem(), @@ -215,7 +259,7 @@ var nodeByName = map[string]reflect.Type{ "CStyleLoop": reflect.TypeOf((*syntax.CStyleLoop)(nil)).Elem(), } -func decode(val reflect.Value, enc interface{}) error { +func decodeValue(val reflect.Value, enc interface{}) error { switch enc := enc.(type) { case map[string]interface{}: if val.Kind() == reflect.Ptr && val.IsNil() { @@ -246,14 +290,14 @@ func decode(val reflect.Value, enc interface{}) error { decodePos(fval, fv.(map[string]interface{})) continue } - if err := decode(fval, fv); err != nil { + if err := decodeValue(fval, fv); err != nil { return err } } case []interface{}: for _, encElem := range enc { elem := reflect.New(val.Type().Elem()).Elem() - if err := decode(elem, encElem); err != nil 
{ + if err := decodeValue(elem, encElem); err != nil { return err } val.Set(reflect.Append(val, elem)) diff --git a/syntax/typedjson/json_test.go b/syntax/typedjson/json_test.go new file mode 100644 index 000000000..5292fa7cb --- /dev/null +++ b/syntax/typedjson/json_test.go @@ -0,0 +1,88 @@ +// Copyright (c) 2017, Daniel Martí +// See LICENSE for licensing information + +package typedjson_test + +import ( + "bytes" + "flag" + "os" + "path/filepath" + "strings" + "testing" + + qt "github.com/frankban/quicktest" + + "mvdan.cc/sh/v3/syntax" + "mvdan.cc/sh/v3/syntax/typedjson" +) + +var update = flag.Bool("u", false, "update output files") + +func TestRoundtrip(t *testing.T) { + t.Parallel() + + dir := filepath.Join("testdata", "roundtrip") + shellPaths, err := filepath.Glob(filepath.Join(dir, "*.sh")) + qt.Assert(t, err, qt.IsNil) + for _, shellPath := range shellPaths { + + shellPath := shellPath // do not reuse the range var + name := strings.TrimSuffix(filepath.Base(shellPath), ".sh") + jsonPath := filepath.Join(dir, name+".json") + t.Run(name, func(t *testing.T) { + t.Parallel() + + shellInput, err := os.ReadFile(shellPath) + qt.Assert(t, err, qt.IsNil) + jsonInput, err := os.ReadFile(jsonPath) + if !*update { // allow it to not exist + qt.Assert(t, err, qt.IsNil) + } + sb := new(strings.Builder) + + // Parse the shell source and check that it is well formatted. + parser := syntax.NewParser(syntax.KeepComments(true)) + node, err := parser.Parse(bytes.NewReader(shellInput), "") + qt.Assert(t, err, qt.IsNil) + + printer := syntax.NewPrinter() + sb.Reset() + err = printer.Print(sb, node) + qt.Assert(t, err, qt.IsNil) + qt.Assert(t, sb.String(), qt.Equals, string(shellInput)) + + // Validate writing the pretty JSON. 
+ sb.Reset() + encOpts := typedjson.EncodeOptions{Indent: "\t"} + err = encOpts.Encode(sb, node) + qt.Assert(t, err, qt.IsNil) + got := sb.String() + if *update { + err := os.WriteFile(jsonPath, []byte(got), 0o666) + qt.Assert(t, err, qt.IsNil) + } else { + qt.Assert(t, got, qt.Equals, string(jsonInput)) + } + + // Ensure we don't use the originally parsed node again. + node = nil + + // Validate reading the pretty JSON and check that it formats the same. + node2, err := typedjson.Decode(bytes.NewReader(jsonInput)) + qt.Assert(t, err, qt.IsNil) + + sb.Reset() + err = printer.Print(sb, node2) + qt.Assert(t, err, qt.IsNil) + qt.Assert(t, sb.String(), qt.Equals, string(shellInput)) + + // Validate that emitting the JSON again produces the same result. + sb.Reset() + err = encOpts.Encode(sb, node2) + qt.Assert(t, err, qt.IsNil) + got = sb.String() + qt.Assert(t, got, qt.Equals, string(jsonInput)) + }) + } +} diff --git a/cmd/shfmt/testdata/json.json b/syntax/typedjson/testdata/roundtrip/file.json similarity index 99% rename from cmd/shfmt/testdata/json.json rename to syntax/typedjson/testdata/roundtrip/file.json index a637f1f5f..0fde505e2 100644 --- a/cmd/shfmt/testdata/json.json +++ b/syntax/typedjson/testdata/roundtrip/file.json @@ -1,4 +1,5 @@ { + "Type": "File", "Pos": { "Offset": 0, "Line": 1, diff --git a/cmd/shfmt/testdata/json.sh b/syntax/typedjson/testdata/roundtrip/file.sh similarity index 100% rename from cmd/shfmt/testdata/json.sh rename to syntax/typedjson/testdata/roundtrip/file.sh