From 049330cd045f372b05aa39db469cd8ae32df7fde Mon Sep 17 00:00:00 2001 From: "A. Stoewer" Date: Mon, 6 May 2024 16:32:44 +1000 Subject: [PATCH] Improve schema sub-command --- cmd/parquet-cli/cmd_schema.go | 131 +------------------------ pkg/inspect/inspect.go | 12 +++ pkg/inspect/schema.go | 165 +++++++++++++++++++++++++++++++ pkg/output/format.go | 26 +++-- pkg/output/output.go | 90 +---------------- pkg/output/print.go | 178 ++++++++++++++++++++++++++++++++++ 6 files changed, 381 insertions(+), 221 deletions(-) create mode 100644 pkg/inspect/schema.go create mode 100644 pkg/output/print.go diff --git a/cmd/parquet-cli/cmd_schema.go b/cmd/parquet-cli/cmd_schema.go index e14e1c3..8821178 100644 --- a/cmd/parquet-cli/cmd_schema.go +++ b/cmd/parquet-cli/cmd_schema.go @@ -1,14 +1,10 @@ package main import ( - "encoding/binary" - "errors" - "fmt" - "io" "os" - "github.com/parquet-go/parquet-go/format" - "github.com/segmentio/encoding/thrift" + "github.com/stoewer/parquet-cli/pkg/inspect" + "github.com/stoewer/parquet-cli/pkg/output" ) @@ -18,128 +14,11 @@ type schema struct { } func (s *schema) Run() error { - f, err := os.Open(s.File) - if err != nil { - return err - } - - fi, err := f.Stat() - if err != nil { - return err - } - - metadata, err := readMetadata(f, fi.Size()) + pf, err := openParquetFile(s.File) if err != nil { return err - - } - - return output.PrintTable(os.Stdout, s.Output, newMetadataTable(metadata)) -} - -// borrowed with love from github.com/segmentio/parquet-go/file.go:OpenFile() -func readMetadata(r io.ReaderAt, size int64) (*format.FileMetaData, error) { - b := make([]byte, 8) - - if _, err := r.ReadAt(b[:4], 0); err != nil { - return nil, fmt.Errorf("reading magic header of parquet file: %w", err) - } - if string(b[:4]) != "PAR1" { - return nil, fmt.Errorf("invalid magic header of parquet file: %q", b[:4]) - } - if n, err := r.ReadAt(b[:8], size-8); n != 8 { - return nil, fmt.Errorf("reading magic footer of parquet file: %w", err) - } - if string(b[4:8]) != "PAR1" { - return nil, fmt.Errorf("invalid magic footer of parquet file: %q", b[4:8]) - } - - footerSize := int64(binary.LittleEndian.Uint32(b[:4])) - footerData := make([]byte, footerSize) - if _, err := r.ReadAt(footerData, size-(footerSize+8)); err != nil { - return nil, fmt.Errorf("reading footer of parquet file: %w", err) - } - - protocol := thrift.CompactProtocol{} - metadata := &format.FileMetaData{} - if err := thrift.Unmarshal(&protocol, footerData, metadata); err != nil { - return nil, fmt.Errorf("reading parquet file metadata: %w", err) - } - if len(metadata.Schema) == 0 { - return nil, errors.New("missing root column") - } - - return metadata, nil -} - -type metadataTable struct { - schema []format.SchemaElement - row int -} - -func newMetadataTable(m *format.FileMetaData) *metadataTable { - return &metadataTable{ - schema: m.Schema, - } -} - -func (t *metadataTable) Header() []any { - return []any{ - "Type", - "TypeLength", - "RepetitionType", - "Name", - "NumChildren", - "ConvertedType", - "Scale", - "Precision", - "FieldID", - "LogicalType", - } -} - -func (t *metadataTable) NextRow() (output.TableRow, error) { - if t.row >= len(t.schema) { - return nil, io.EOF - } - - r := newMetadataRow(0, &t.schema[t.row]) - t.row++ - - return r, nil -} - -type metadataRow struct { - n int - s *format.SchemaElement -} - -func newMetadataRow(n int, s *format.SchemaElement) *metadataRow { - return &metadataRow{ - n: n, - s: s, } -} - -func (r *metadataRow) Row() int { - return r.n -} - -func (r *metadataRow) Cells() []any { - return []any{ - r.s.Type, - r.s.TypeLength, - r.s.RepetitionType, - r.s.Name, - r.s.NumChildren, - r.s.ConvertedType, - r.s.Scale, - r.s.Precision, - r.s.FieldID, - r.s.LogicalType, - } -} -func (r *metadataRow) SerializableData() any { - return r.s + sch := inspect.NewSchema(pf) + return output.Print(os.Stdout, sch, &output.PrintOptions{Format: s.Output}) } diff --git a/pkg/inspect/inspect.go b/pkg/inspect/inspect.go index c9d2f2a..af7efe7 100644 --- a/pkg/inspect/inspect.go +++ b/pkg/inspect/inspect.go @@ -29,3 +29,15 @@ func LeafColumns(file *parquet.File) []*parquet.Column { sort.SliceStable(leafs, func(i, j int) bool { return leafs[i].Index() < leafs[j].Index() }) return leafs } + +func PathToDisplayName(path []string) string { + l := len(path) + if l > 3 { + if path[l-2] == "list" && path[l-1] == "element" { + return path[l-3] + } else if path[l-2] == "key_value" && (path[l-1] == "key" || path[l-1] == "value") { + return path[l-3] + "." + path[l-1] + } + } + return path[l-1] +} diff --git a/pkg/inspect/schema.go b/pkg/inspect/schema.go new file mode 100644 index 0000000..7eeb7ba --- /dev/null +++ b/pkg/inspect/schema.go @@ -0,0 +1,165 @@ +package inspect + +import ( + "fmt" + "io" + "strings" + + "github.com/stoewer/parquet-cli/pkg/output" + + "github.com/parquet-go/parquet-go" +) + +var headers = []any{ + "Index", + "Name", + "Optional", + "Repeated", + "Required", + "Is Leaf", + "Type", + "Go Type", + "Encoding", + "Compression", + "Path", +} + +type Schema struct { + pf *parquet.File + + fields []fieldWithPath + next int +} + +func NewSchema(pf *parquet.File) *Schema { + return &Schema{pf: pf} +} + +func (s *Schema) Text() (string, error) { + textRaw := s.pf.Schema().String() + + var text strings.Builder + for _, r := range textRaw { + if r == '\t' { + text.WriteString(" ") + } else { + text.WriteRune(r) + } + } + + return text.String(), nil +} + +func (s *Schema) Header() []any { + return headers +} + +func (s *Schema) NextRow() (output.TableRow, error) { + if s.fields == nil { + s.fields = fieldsFromSchema(s.pf.Schema()) + } + if s.next >= len(s.fields) { + return nil, fmt.Errorf("no more fields: %w", io.EOF) + } + + nextField := s.fields[s.next] + s.next++ + return toSchemaNode(&nextField), nil +} + +func (s *Schema) NextSerializable() (any, error) { + return s.NextRow() +} + +func toSchemaNode(n *fieldWithPath) *schemaNode { + sn := &schemaNode{ + Index: n.Index, + Name: n.Name(), + Optional: n.Optional(), + Repeated: n.Repeated(), + Required: n.Required(), + IsLeaf: n.Leaf(), + } + + if n.Leaf() { + sn.IsLeaf = true + sn.Type = n.Type().String() + sn.GoType = n.GoType().String() + if n.Encoding() != nil { + sn.Encoding = n.Encoding().String() + } + if n.Compression() != nil { + sn.Compression = n.Compression().String() + } + } + + if len(n.Path) > 0 { + sn.Path = strings.Join(n.Path, ".") + sn.Name = PathToDisplayName(n.Path) + } + + return sn +} + +type schemaNode struct { + Index int `json:"index,omitempty"` + Name string `json:"name"` + Optional bool `json:"optional"` + Repeated bool `json:"repeated"` + Required bool `json:"required"` + IsLeaf bool `json:"is_leaf"` + Type string `json:"type,omitempty"` + GoType string `json:"go_type,omitempty"` + Encoding string `json:"encoding,omitempty"` + Compression string `json:"compression,omitempty"` + Path string `json:"path,omitempty"` +} + +func (sn *schemaNode) Cells() []any { + return []any{ + sn.Index, + sn.Name, + sn.Optional, + sn.Repeated, + sn.Required, + sn.IsLeaf, + sn.Type, + sn.GoType, + sn.Encoding, + sn.Compression, + sn.Path, + } +} + +type fieldWithPath struct { + parquet.Field + Path []string + Index int +} + +func fieldsFromSchema(schema *parquet.Schema) []fieldWithPath { + result := make([]fieldWithPath, 0) + for _, field := range schema.Fields() { + result = fieldsFromPathRecursive(field, []string{}, result) + } + return result +} + +func fieldsFromPathRecursive(field parquet.Field, path []string, result []fieldWithPath) []fieldWithPath { + path = append(path, field.Name()) + + result = append(result, fieldWithPath{Field: field, Path: path}) + for _, child := range field.Fields() { + result = fieldsFromPathRecursive(child, path, result) + } + + colIndex := 0 + for i := range result { + if result[i].Leaf() { + result[i].Index = colIndex + colIndex++ + } + } + + return result +} diff --git a/pkg/output/format.go b/pkg/output/format.go index e0ccfe8..feef17e 100644 --- a/pkg/output/format.go +++ b/pkg/output/format.go @@ -1,6 +1,9 @@ package output -import "errors" +import ( + "errors" + "fmt" +) // Format describes a printable data representation. type Format string @@ -21,7 +24,7 @@ func (f *Format) Validate() error { } } -func formatsFor(data any) []Format { +func supportedFormats(data any) []Format { var formats []Format switch data.(type) { case Serializable, SerializableIterator: @@ -34,11 +37,20 @@ func formatsFor(data any) []Format { return formats } -func supportsFormat(data any, f Format) bool { - for _, format := range formatsFor(data) { - if format == f { - return true +func errUnsupportedFormat(data any, f Format) error { + supported := supportedFormats(data) + + var supportedPretty string + for i, format := range supportedFormats(data) { + if i > 0 { + if i == len(supported)-1 { + supportedPretty += " or " + } else { + supportedPretty += ", " + } } + supportedPretty += "'" + string(format) + "'" } - return false + + return fmt.Errorf("format '%s' is not supported must be %s", f, supportedPretty) } diff --git a/pkg/output/output.go b/pkg/output/output.go index 33b05e6..d668c3d 100644 --- a/pkg/output/output.go +++ b/pkg/output/output.go @@ -2,20 +2,17 @@ package output import ( "bytes" - "encoding/csv" "encoding/json" "errors" "fmt" "io" - "strings" - "text/tabwriter" ) // PrintTable writes the TableIterator data to w using the provided format. func PrintTable(w io.Writer, f Format, data TableIterator) error { switch f { case FormatJSON: - return printJSON(w, data) + return printTableToJSON(w, data) case FormatTab: return printTab(w, data) case FormatCSV: @@ -25,69 +22,7 @@ func PrintTable(w io.Writer, f Format, data TableIterator) error { } } -func printTab(w io.Writer, data TableIterator) error { - tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0) - - formatBuilder := strings.Builder{} - for range data.Header() { - formatBuilder.WriteString("%v\t") - } - formatBuilder.WriteRune('\n') - format := formatBuilder.String() - - _, err := fmt.Fprintf(tw, format, data.Header()...) - if err != nil { - return err - } - - row, err := data.NextRow() - for err == nil { - _, err = fmt.Fprintf(tw, format, row.Cells()...) - if err != nil { - return err - } - - row, err = data.NextRow() - } - if err != nil && !errors.Is(err, io.EOF) { - return err - } - - return tw.Flush() -} - -func printCSV(w io.Writer, data TableIterator) error { - cw := csv.NewWriter(w) - cw.Comma = ';' - - header := data.Header() - lineBuffer := make([]string, len(header)) - - line := toStringSlice(header, lineBuffer) - err := cw.Write(line) - if err != nil { - return err - } - - row, err := data.NextRow() - for err == nil { - line = toStringSlice(row.Cells(), lineBuffer) - err = cw.Write(line) - if err != nil { - return err - } - - row, err = data.NextRow() - } - if err != nil && !errors.Is(err, io.EOF) { - return err - } - - cw.Flush() - return cw.Error() -} - -func printJSON(w io.Writer, data TableIterator) error { +func printTableToJSON(w io.Writer, data TableIterator) error { if serializable, ok := data.(Serializable); ok { enc := json.NewEncoder(w) enc.SetIndent("", " ") @@ -139,24 +74,3 @@ func printJSON(w io.Writer, data TableIterator) error { _, err = fmt.Println("\n]") return err } - -func toStringSlice(in []any, buf []string) []string { - for i, v := range in { - var s string - switch v := v.(type) { - case string: - s = v - case fmt.Stringer: - s = v.String() - default: - s = fmt.Sprint(v) - } - - if i < len(buf) { - buf[i] = s - } else { - buf = append(buf, s) - } - } - return buf[0:len(in)] -} diff --git a/pkg/output/print.go b/pkg/output/print.go new file mode 100644 index 0000000..f902475 --- /dev/null +++ b/pkg/output/print.go @@ -0,0 +1,178 @@ +package output + +import ( + "bytes" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + "text/tabwriter" + "unsafe" +) + +type PrintOptions struct { + Format Format + Color bool +} + +func Print(out io.Writer, data any, opts *PrintOptions) error { + switch opts.Format { + case FormatText: + if text, ok := data.(Text); ok { + return printText(out, text) + } + case FormatTab: + if table, ok := data.(TableIterator); ok { + return printTab(out, table) + } + case FormatCSV: + if table, ok := data.(TableIterator); ok { + return printCSV(out, table) + } + case FormatJSON: + if ser, ok := data.(SerializableIterator); ok { + return printJSON(out, ser) + } + } + return errUnsupportedFormat(data, opts.Format) +} + +func printJSON(w io.Writer, data SerializableIterator) error { + _, err := fmt.Fprintln(w, "[") + if err != nil { + return err + } + + var count int + buf := bytes.NewBuffer(make([]byte, 10240)) + next, err := data.NextSerializable() + + for err == nil { + if count > 0 { + _, err = fmt.Fprint(w, ",\n ") + } else { + _, err = fmt.Fprint(w, " ") + } + if err != nil { + return err + } + + buf.Reset() + err = json.NewEncoder(buf).Encode(next) + if err != nil { + return err + } + buf.Truncate(buf.Len() - 1) // remove the newline + + _, err = fmt.Fprint(w, buf) + if err != nil { + return err + } + + count++ + next, err = data.NextSerializable() + } + if !errors.Is(err, io.EOF) { + return err + } + + _, err = fmt.Println("\n]") + return err +} + +func printTab(w io.Writer, data TableIterator) error { + tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0) + + formatBuilder := strings.Builder{} + for range data.Header() { + formatBuilder.WriteString("%v\t") + } + formatBuilder.WriteRune('\n') + format := formatBuilder.String() + + _, err := fmt.Fprintf(tw, format, data.Header()...) + if err != nil { + return err + } + + row, err := data.NextRow() + for err == nil { + _, err = fmt.Fprintf(tw, format, row.Cells()...) + if err != nil { + return err + } + + row, err = data.NextRow() + } + if !errors.Is(err, io.EOF) { + return err + } + + return tw.Flush() +} + +func printCSV(w io.Writer, data TableIterator) error { + cw := csv.NewWriter(w) + cw.Comma = ';' + + header := data.Header() + lineBuffer := make([]string, len(header)) + + line := toStringSlice(header, lineBuffer) + err := cw.Write(line) + if err != nil { + return err + } + + row, err := data.NextRow() + for err == nil { + line = toStringSlice(row.Cells(), lineBuffer) + err = cw.Write(line) + if err != nil { + return err + } + + row, err = data.NextRow() + } + if !errors.Is(err, io.EOF) { + return err + } + + cw.Flush() + return cw.Error() +} + +func toStringSlice(in []any, buf []string) []string { + for i, v := range in { + var s string + switch v := v.(type) { + case string: + s = v + case fmt.Stringer: + s = v.String() + default: + s = fmt.Sprint(v) + } + + if i < len(buf) { + buf[i] = s + } else { + buf = append(buf, s) + } + } + return buf[0:len(in)] +} + +func printText(out io.Writer, data Text) error { + s, err := data.Text() + if err != nil { + return fmt.Errorf("unable to print text: %w", err) + } + + b := unsafe.Slice(unsafe.StringData(s), len(s)) + + _, err = out.Write(b) + return err +}