diff --git a/chiv.go b/chiv.go
index 017eaea..d642f9e 100644
--- a/chiv.go
+++ b/chiv.go
@@ -4,7 +4,7 @@ package chiv
 import (
 	"context"
 	"database/sql"
-	"encoding/csv"
+	"errors"
 	"fmt"
 	"io"
 
@@ -12,31 +12,39 @@ import (
 	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 )
 
+var (
+	// DefaultFormatFunc is CSV.
+	DefaultFormatFunc = CSV
+	// ErrRecordLength is returned when a record's length does not match the number of columns.
+	ErrRecordLength = errors.New("record length does not match number of columns")
+)
+
 // Archiver archives arbitrarily large relational database tables to Amazon S3. It contains a database connection
 // and upload client. Options set on creation apply to all calls to Archive unless overridden.
 type Archiver struct {
 	db     *sql.DB
 	s3     *s3manager.Uploader
-	format Format
+	config config
+}
+
+type config struct {
+	format FormatterFunc
 	key    string
 	null   []byte
 }
 
-const (
-	// DefaultFormat is CSV.
-	DefaultFormat = CSV
-)
-
 // NewArchiver constructs an Archiver with the given database connection, S3 uploader and options.
 func NewArchiver(db *sql.DB, s3 *s3manager.Uploader, options ...Option) *Archiver {
 	a := Archiver{
-		db:     db,
-		s3:     s3,
-		format: DefaultFormat,
+		db: db,
+		s3: s3,
+		config: config{
+			format: DefaultFormatFunc,
+		},
 	}
 
 	for _, option := range options {
-		option(&a)
+		option(&a.config)
 	}
 
 	return &a
@@ -53,11 +61,11 @@ func (a *Archiver) ArchiveWithContext(ctx context.Context, table, bucket string,
 		db:     a.db,
 		s3:     a.s3,
 		ctx:    ctx,
-		config: a,
+		config: a.config,
 	}
 
 	for _, option := range options {
-		option(archiver.config)
+		option(&archiver.config)
 	}
 
 	return archiver.archive(table, bucket)
@@ -67,7 +75,7 @@ type archiver struct {
 	db     *sql.DB
 	s3     *s3manager.Uploader
 	ctx    context.Context
-	config *Archiver
+	config config
 }
 
 func (a *archiver) archive(table string, bucket string) error {
@@ -87,23 +95,7 @@ func (a *archiver) archive(table string, bucket string) error {
 	}
 }
 
-type formatter interface {
-	Begin([]*sql.ColumnType) error
-	Write([][]byte) error
-	End() error
-}
-
 func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
-	var w formatter
-	switch a.config.format {
-	case YAML:
-		w = &yamlFormatter{}
-	case JSON:
-		w = &jsonFormatter{w: wc}
-	default:
-		w = &csvFormatter{w: csv.NewWriter(wc)}
-	}
-
 	selectAll := fmt.Sprintf(`select * from "%s";`, table)
 	rows, err := a.db.QueryContext(a.ctx, selectAll)
 	if err != nil {
@@ -118,7 +110,8 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
 		return
 	}
 
-	if err := w.Begin(columns); err != nil {
+	f, err := a.config.format(wc, columns)
+	if err != nil {
 		errs <- err
 		return
 	}
@@ -147,7 +140,7 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
 			}
 		}
 
-		if err := w.Write(record); err != nil {
+		if err := f.Format(record); err != nil {
 			errs <- err
 			return
 		}
@@ -158,7 +151,7 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
 		return
 	}
 
-	if err := w.End(); err != nil {
+	if err := f.Close(); err != nil {
 		errs <- err
 		return
 	}
@@ -171,14 +164,8 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
 
 func (a *archiver) upload(r io.Reader, table string, bucket string, errs chan error) {
 	if a.config.key == "" {
-		switch a.config.format {
-		case YAML:
-			a.config.key = fmt.Sprintf("%s.yml", table)
-		case JSON:
-			a.config.key = fmt.Sprintf("%s.json", table)
-		default:
-			a.config.key = fmt.Sprintf("%s.csv", table)
-		}
+		// TODO: derive a default extension for the key, e.g. from a config option ('.json') or from the formatter itself.
+		a.config.key = table
 	}
 
 	if _, err := a.s3.UploadWithContext(a.ctx, &s3manager.UploadInput{
diff --git a/chiv_test.go b/chiv_test.go
index 9e5e78f..60d37f9 100644
--- a/chiv_test.go
+++ b/chiv_test.go
@@ -43,7 +43,7 @@ func TestArchiver_Archive(t *testing.T) {
 			expected: "./testdata/postgres.csv",
 			bucket:   "postgres_bucket",
 			table:    "postgres_table",
-			key:      "postgres_table.csv",
+			key:      "postgres_table",
 			options:  []chiv.Option{},
 		},
 		{
@@ -55,9 +55,9 @@ func TestArchiver_Archive(t *testing.T) {
 			expected: "./testdata/postgres.csv",
 			bucket:   "postgres_bucket",
 			table:    "postgres_table",
-			key:      "postgres_custom_key",
+			key:      "postgres_table.csv",
 			options: []chiv.Option{
-				chiv.WithKey("postgres_custom_key"),
+				chiv.WithKey("postgres_table.csv"),
 			},
 		},
 		{
@@ -69,7 +69,7 @@ func TestArchiver_Archive(t *testing.T) {
 			expected: "./testdata/postgres_with_null.csv",
 			bucket:   "postgres_bucket",
 			table:    "postgres_table",
-			key:      "postgres_table.csv",
+			key:      "postgres_table",
 			options: []chiv.Option{
 				chiv.WithNull("custom_null"),
 			},
@@ -86,6 +86,7 @@ func TestArchiver_Archive(t *testing.T) {
 			key:      "postgres_table.json",
 			options: []chiv.Option{
 				chiv.WithFormat(chiv.JSON),
+				chiv.WithKey("postgres_table.json"),
 			},
 		},
 	}
diff --git a/formatters.go b/formatters.go
index b8154a2..a13caa9 100644
--- a/formatters.go
+++ b/formatters.go
@@ -4,56 +4,64 @@ import (
 	"database/sql"
 	"encoding/csv"
 	"encoding/json"
-	"errors"
 	"io"
 	"regexp"
 	"strconv"
 )
 
-const (
-	openBracket  = byte('[')
-	closeBracket = byte(']')
-	comma        = byte(',')
-)
+// FormatterFunc returns an initialized Formatter.
+type FormatterFunc func(io.Writer, []*sql.ColumnType) (Formatter, error)
 
-var (
-	// ErrRecordLength does not match the number of columns.
-	ErrRecordLength = errors.New("record length does not match number of columns")
-)
-
-// csvFormatter formats columns in CSV format.
-type csvFormatter struct {
-	w     *csv.Writer
-	count int
+// Formatter formats and writes records.
+type Formatter interface {
+	Format([][]byte) error
+	Close() error
 }
 
-func (c *csvFormatter) Begin(columns []*sql.ColumnType) error {
-	c.count = len(columns)
+func CSV(w io.Writer, columns []*sql.ColumnType) (Formatter, error) {
+	f := &csvFormatter{
+		w:     csv.NewWriter(w),
+		count: len(columns),
+	}
 
-	header := make([]string, c.count)
+	header := make([]string, f.count)
 	for i, column := range columns {
 		header[i] = column.Name()
 	}
 
-	return c.w.Write(header)
+	if err := f.w.Write(header); err != nil {
+		return nil, err
+	}
+
+	return f, nil
+}
+
+// csvFormatter formats columns in CSV format.
+type csvFormatter struct {
+	w     *csv.Writer
+	count int
 }
 
-func (c *csvFormatter) Write(record [][]byte) error {
-	if c.count != len(record) {
+func (f *csvFormatter) Format(record [][]byte) error {
+	if f.count != len(record) {
 		return ErrRecordLength
 	}
 
-	strings := make([]string, c.count)
+	strings := make([]string, f.count)
 	for i, item := range record {
 		strings[i] = string(item)
 	}
 
-	return c.w.Write(strings)
+	return f.w.Write(strings)
 }
 
-func (c *csvFormatter) End() error {
-	c.w.Flush()
-	return c.w.Error()
+func (f *csvFormatter) Close() error {
+	f.w.Flush()
+	return f.w.Error()
+}
+
+func YAML(w io.Writer, columns []*sql.ColumnType) (Formatter, error) {
+	return &yamlFormatter{columns: columns}, nil
 }
 
 // yamlFormatter formats columns in YAML format.
@@ -61,11 +69,7 @@ type yamlFormatter struct {
 	columns []*sql.ColumnType
 }
 
-func (c *yamlFormatter) Begin(columns []*sql.ColumnType) error {
-	return nil
-}
-
-func (c *yamlFormatter) Write(record [][]byte) error {
+func (c *yamlFormatter) Format(record [][]byte) error {
 	if len(c.columns) != len(record) {
 		return ErrRecordLength
 	}
@@ -73,10 +77,29 @@ func (c *yamlFormatter) Write(record [][]byte) error {
 	return nil
 }
 
-func (c *yamlFormatter) End() error {
+func (c *yamlFormatter) Close() error {
 	return nil
 }
 
+const (
+	openBracket  = byte('[')
+	closeBracket = byte(']')
+	comma        = byte(',')
+)
+
+func JSON(w io.Writer, columns []*sql.ColumnType) (Formatter, error) {
+	f := &jsonFormatter{
+		w:       w,
+		columns: columns,
+	}
+
+	if err := f.writeByte(openBracket); err != nil {
+		return nil, err
+	}
+
+	return f, nil
+}
+
 // jsonFormatter formats columns in JSON format.
 type jsonFormatter struct {
 	w        io.Writer
@@ -84,19 +107,14 @@ type jsonFormatter struct {
 	notFirst bool
 }
 
-func (c *jsonFormatter) Begin(columns []*sql.ColumnType) error {
-	c.columns = columns
-	return writeByte(c.w, openBracket)
-}
-
-func (c *jsonFormatter) Write(record [][]byte) error {
-	if len(c.columns) != len(record) {
+func (f *jsonFormatter) Format(record [][]byte) error {
+	if len(f.columns) != len(record) {
 		return ErrRecordLength
 	}
 
 	m := make(map[string]interface{})
-	for i, column := range c.columns {
-		r, err := parse(record[i], c.columns[i].DatabaseTypeName())
+	for i, column := range f.columns {
+		r, err := parse(record[i], f.columns[i].DatabaseTypeName())
 		if err != nil {
 			return err
 		}
@@ -108,30 +126,30 @@ func (c *jsonFormatter) Write(record [][]byte) error {
 		return err
 	}
 
-	if c.notFirst {
-		err := writeByte(c.w, comma)
+	if f.notFirst {
+		err := f.writeByte(comma)
 		if err != nil {
 			return err
 		}
 	}
 
-	n, err := c.w.Write(b)
+	n, err := f.w.Write(b)
 	if err != nil {
 		return err
 	} else if n != len(b) {
 		return io.ErrShortWrite
 	}
 
-	c.notFirst = true
+	f.notFirst = true
 	return nil
 }
 
-func (c *jsonFormatter) End() error {
-	return writeByte(c.w, closeBracket)
+func (f *jsonFormatter) Close() error {
+	return f.writeByte(closeBracket)
 }
 
-func writeByte(w io.Writer, b byte) error {
-	n, err := w.Write([]byte{b})
+func (f *jsonFormatter) writeByte(b byte) error {
+	n, err := f.w.Write([]byte{b})
 	if err != nil {
 		return err
 	} else if n != 1 {
diff --git a/options.go b/options.go
index 86935d4..47259ce 100644
--- a/options.go
+++ b/options.go
@@ -1,37 +1,25 @@
 package chiv
 
 // Option configures the Archiver. Options can be provided when creating an Archiver or on each call to Archive.
-type Option func(*Archiver)
-
-// Format uploaded to S3.
-type Format int
-
-const (
-	// CSV file format.
-	CSV Format = iota
-	// YAML file format.
-	YAML
-	// JSON file format.
-	JSON
-)
+type Option func(*config)
 
 // WithFormat configures the upload format.
-func WithFormat(f Format) Option {
-	return func(a *Archiver) {
-		a.format = f
+func WithFormat(f FormatterFunc) Option {
+	return func(c *config) {
+		c.format = f
 	}
 }
 
 // WithKey configures the upload object key in S3.
 func WithKey(s string) Option {
-	return func(a *Archiver) {
-		a.key = s
+	return func(c *config) {
+		c.key = s
 	}
 }
 
 // WithNull configures a custom null string.
 func WithNull(s string) Option {
-	return func(a *Archiver) {
-		a.null = []byte(s)
+	return func(c *config) {
+		c.null = []byte(s)
 	}
 }
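Note on the API change above, for readers of this patch: the Format enum is gone, WithFormat now takes a FormatterFunc (chiv.CSV, chiv.YAML, chiv.JSON, or a custom constructor), and the archiver no longer appends a format-specific extension to the default key. The sketch below shows how a caller would drive the refactored API; it is illustrative only, and the import paths, Postgres driver, DSN, table and bucket names are assumptions rather than part of this diff.

package main

import (
	"context"
	"database/sql"
	"log"

	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	_ "github.com/lib/pq" // assumed Postgres driver

	"github.com/gavincabbage/chiv" // assumed import path for this package
)

func main() {
	// Assumed setup: a database connection and an AWS session.
	db, err := sql.Open("postgres", "postgres://localhost/example?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	uploader := s3manager.NewUploader(session.Must(session.NewSession()))
	archiver := chiv.NewArchiver(db, uploader)

	// WithFormat takes a FormatterFunc instead of a Format constant, and the
	// key must now carry its own extension, as in the updated tests.
	err = archiver.ArchiveWithContext(
		context.Background(),
		"postgres_table",
		"postgres_bucket",
		chiv.WithFormat(chiv.JSON),
		chiv.WithKey("postgres_table.json"),
	)
	if err != nil {
		log.Fatal(err)
	}
}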
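On the TODO left in upload: one caller-side way to keep the key's extension in step with the chosen formatter, without reintroducing the Format enum, is to pair a FormatterFunc with its extension in a small helper. This is only a sketch of that idea, not something this diff implements; the chivext package name, the Format type and the import path are assumptions.

// Package chivext is a hypothetical helper, not part of this change.
package chivext

import "github.com/gavincabbage/chiv" // assumed import path

// Format bundles a formatter constructor with the file extension its output
// should carry.
type Format struct {
	New       chiv.FormatterFunc
	Extension string
}

// Options returns chiv options that select the formatter and derive the S3
// key from the table name, e.g. "postgres_table" + ".json".
func (f Format) Options(table string) []chiv.Option {
	return []chiv.Option{
		chiv.WithFormat(f.New),
		chiv.WithKey(table + f.Extension),
	}
}

A caller would then archive with, for example, archiver.ArchiveWithContext(ctx, table, bucket, chivext.Format{New: chiv.JSON, Extension: ".json"}.Options(table)...).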