From 2e42faba0f888f4ac0450d38a2ac3101b1d51157 Mon Sep 17 00:00:00 2001 From: Gavin Cabbage Date: Sun, 6 Oct 2019 20:01:14 -0400 Subject: [PATCH] implement and test Extensioner --- chiv.go | 6 ++++ chiv_formatters.go | 17 +++++++++- chiv_integration_test.go | 30 +++++++---------- chiv_test.go | 71 ++++++++++++++++++++++++++++++++-------- 4 files changed, 91 insertions(+), 33 deletions(-) diff --git a/chiv.go b/chiv.go index 17d173d..d471a2d 100644 --- a/chiv.go +++ b/chiv.go @@ -129,6 +129,9 @@ func (a *Archiver) archive(ctx context.Context, rows Rows, table, bucket string) formatter = a.format(w, columns) g, gctx = errgroup.WithContext(ctx) ) + if extensioner, ok := formatter.(Extensioner); ok && a.extension == "" { + a.extension = extensioner.Extension() + } g.Go(func() error { return a.download(gctx, rows, columns, formatter, w) }) @@ -218,6 +221,9 @@ func (a *Archiver) upload(ctx context.Context, r io.ReadCloser, table string, bu } }() + if table == "" { + table = "table" + } if a.key == "" { if a.extension != "" { a.key = fmt.Sprintf("%s.%s", table, a.extension) diff --git a/chiv_formatters.go b/chiv_formatters.go index 92d2743..9154027 100644 --- a/chiv_formatters.go +++ b/chiv_formatters.go @@ -10,7 +10,7 @@ import ( "regexp" "strconv" - "gopkg.in/yaml.v2" + yaml "gopkg.in/yaml.v2" ) // Column reports its name, database type name and scan type. @@ -90,6 +90,11 @@ func (f *csvFormatter) Close() error { return nil } +// Extension returns the default CSV formatter extension. +func (f *csvFormatter) Extension() string { + return "csv" +} + type yamlFormatter struct { w io.Writer columns []Column @@ -132,6 +137,11 @@ func (f *yamlFormatter) Close() error { return nil } +// Extension returns the default YAML formatter extension. +func (f *yamlFormatter) Extension() string { + return "yaml" +} + const ( openBracket = byte('[') closeBracket = byte(']') @@ -196,6 +206,11 @@ func (f *jsonFormatter) Close() error { return nil } +// Extension returns the default JSON formatter extension. +func (f *jsonFormatter) Extension() string { + return "json" +} + func (f *jsonFormatter) writeByte(b byte) error { _, err := f.w.Write([]byte{b}) if err != nil { diff --git a/chiv_integration_test.go b/chiv_integration_test.go index ac6faa0..59b00b6 100644 --- a/chiv_integration_test.go +++ b/chiv_integration_test.go @@ -16,7 +16,6 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/stretchr/testify/require" @@ -56,7 +55,7 @@ func TestArchiver_Archive(t *testing.T) { { expected: "./testdata/postgres/postgres.csv", table: "postgres_table", - key: "postgres_table", + key: "postgres_table.csv", options: []chiv.Option{}, }, }, @@ -69,13 +68,13 @@ func TestArchiver_Archive(t *testing.T) { teardown: "./testdata/postgres/postgres_teardown.sql", bucket: "postgres_bucket", options: []chiv.Option{ - chiv.WithKey("postgres_table.csv"), + chiv.WithKey("archive.csv"), }, calls: []call{ { expected: "./testdata/postgres/postgres.csv", table: "postgres_table", - key: "postgres_table.csv", + key: "archive.csv", options: []chiv.Option{}, }, }, @@ -94,7 +93,7 @@ func TestArchiver_Archive(t *testing.T) { { expected: "./testdata/postgres/postgres_with_null.csv", table: "postgres_table", - key: "postgres_table", + key: "postgres_table.csv", options: []chiv.Option{}, }, }, @@ -108,7 +107,6 @@ func TestArchiver_Archive(t *testing.T) { bucket: "postgres_bucket", options: []chiv.Option{ chiv.WithFormat(chiv.JSON), - chiv.WithKey("postgres_table.json"), }, calls: []call{ { @@ -128,7 +126,6 @@ func TestArchiver_Archive(t *testing.T) { bucket: "postgres_bucket", options: []chiv.Option{ chiv.WithFormat(chiv.YAML), - chiv.WithKey("postgres_table.yaml"), }, calls: []call{ { @@ -156,16 +153,13 @@ func TestArchiver_Archive(t *testing.T) { key: "postgres_table.json", options: []chiv.Option{ chiv.WithFormat(chiv.JSON), - chiv.WithKey("postgres_table.json"), }, }, { expected: "./testdata/postgres/postgres.yaml", table: "postgres_table", key: "postgres_table.yaml", - options: []chiv.Option{ - chiv.WithKey("postgres_table.yaml"), - }, + options: []chiv.Option{}, }, }, }, @@ -191,7 +185,7 @@ func TestArchiver_Archive(t *testing.T) { { expected: "./testdata/postgres/postgres.yaml", table: "postgres_table", - key: "postgres_table", + key: "postgres_table.yaml", options: []chiv.Option{}, }, }, @@ -205,7 +199,6 @@ func TestArchiver_Archive(t *testing.T) { bucket: "postgres_bucket", options: []chiv.Option{ chiv.WithFormat(chiv.CSV), - chiv.WithExtension("csv"), }, calls: []call{ { @@ -234,7 +227,7 @@ func TestArchiver_Archive(t *testing.T) { { expected: "./testdata/postgres/postgres_subset.csv", table: "postgres_table", - key: "postgres_table", + key: "postgres_table.csv", options: []chiv.Option{ chiv.WithColumns("id", "text_column", "int_column"), }, @@ -253,7 +246,7 @@ func TestArchiver_Archive(t *testing.T) { { expected: "./testdata/mariadb/happy.csv", table: "test_table", - key: "test_table", + key: "test_table.csv", options: []chiv.Option{}, }, }, @@ -267,13 +260,12 @@ func TestArchiver_Archive(t *testing.T) { bucket: "mariadb_bucket", options: []chiv.Option{ chiv.WithFormat(chiv.YAML), - chiv.WithKey("mariadb_table.yaml"), }, calls: []call{ { expected: "./testdata/mariadb/happy.yaml", table: "test_table", - key: "mariadb_table.yaml", + key: "test_table.yaml", options: []chiv.Option{}, }, }, @@ -287,7 +279,6 @@ func TestArchiver_Archive(t *testing.T) { bucket: "mariadb_bucket", options: []chiv.Option{ chiv.WithFormat(chiv.JSON), - chiv.WithKey("test_table.json"), }, calls: []call{ { @@ -336,6 +327,7 @@ func TestArchiveWithContext(t *testing.T) { driver = "postgres" bucket = "postgres_bucket" table = "postgres_table" + key = "postgres_table.csv" setup = "./testdata/postgres/postgres_setup.sql" teardown = "./testdata/postgres/postgres_teardown.sql" expected = "./testdata/postgres/postgres.csv" @@ -354,7 +346,7 @@ func TestArchiveWithContext(t *testing.T) { require.NoError(t, chiv.ArchiveWithContext(context.Background(), db, uploader, table, bucket)) - actual := download(t, downloader, bucket, table) + actual := download(t, downloader, bucket, key) require.Equal(t, readFile(t, expected), actual) } diff --git a/chiv_test.go b/chiv_test.go index 6a0815c..9fdeca0 100644 --- a/chiv_test.go +++ b/chiv_test.go @@ -22,15 +22,17 @@ func TestArchiveRows(t *testing.T) { rows *rows uploader *uploader bucket string - formatter *formatter + formatter chiv.Formatter options []chiv.Option expectedErr string + expectedKey string }{ { - name: "base case", - rows: &rows{}, - uploader: &uploader{}, - formatter: &formatter{}, + name: "base case", + rows: &rows{}, + expectedKey: "table", + uploader: &uploader{}, + formatter: &formatter{}, }, { name: "happy path one row", @@ -38,8 +40,9 @@ func TestArchiveRows(t *testing.T) { columns: []string{"first_column", "second_column"}, scan: [][]string{{"first", "second"}}, }, - uploader: &uploader{}, - formatter: &formatter{}, + expectedKey: "table", + uploader: &uploader{}, + formatter: &formatter{}, }, { name: "happy path multiple rows", @@ -51,8 +54,9 @@ func TestArchiveRows(t *testing.T) { {"seventh", "eighth", "ninth"}, }, }, - uploader: &uploader{}, - formatter: &formatter{}, + expectedKey: "table", + uploader: &uploader{}, + formatter: &formatter{}, }, { name: "column types error", @@ -62,6 +66,7 @@ func TestArchiveRows(t *testing.T) { columnTypesErr: errors.New("column types"), }, expectedErr: "chiv: getting column types from rows: column types", + expectedKey: "table", uploader: &uploader{}, formatter: &formatter{}, }, @@ -72,6 +77,7 @@ func TestArchiveRows(t *testing.T) { scan: [][]string{{"first", "second"}}, }, expectedErr: "chiv: downloading: opening formatter: opening formatter", + expectedKey: "table", uploader: &uploader{}, formatter: &formatter{ openErr: errors.New("opening formatter"), @@ -85,6 +91,7 @@ func TestArchiveRows(t *testing.T) { scanErr: errors.New("scanning"), }, expectedErr: "chiv: downloading: scanning row: scanning", + expectedKey: "table", uploader: &uploader{}, formatter: &formatter{}, }, @@ -95,6 +102,7 @@ func TestArchiveRows(t *testing.T) { scan: [][]string{{"first", "second"}}, }, expectedErr: "chiv: downloading: formatting row: formatting", + expectedKey: "table", uploader: &uploader{}, formatter: &formatter{ formatErr: errors.New("formatting"), @@ -108,6 +116,7 @@ func TestArchiveRows(t *testing.T) { errErr: errors.New("database"), }, expectedErr: "chiv: downloading: scanning rows: database", + expectedKey: "table", uploader: &uploader{}, formatter: &formatter{}, }, @@ -118,6 +127,7 @@ func TestArchiveRows(t *testing.T) { scan: [][]string{{"first", "second"}}, }, expectedErr: "chiv: downloading: closing formatter: closing formatter", + expectedKey: "table", uploader: &uploader{}, formatter: &formatter{ closeErr: errors.New("closing formatter"), @@ -130,11 +140,24 @@ func TestArchiveRows(t *testing.T) { scan: [][]string{{"first", "second"}}, }, expectedErr: "chiv: uploading: uploading", + expectedKey: "table", uploader: &uploader{ uploadErr: errors.New("uploading"), }, formatter: &formatter{}, }, + { + name: "extension formatter", + rows: &rows{ + columns: []string{"first_column", "second_column"}, + scan: [][]string{{"first", "second"}}, + }, + expectedKey: "NOTEQUALWTFtable.ext", + uploader: &uploader{ + uploadErr: errors.New("uploading"), + }, + formatter: &extensionFormatter{&formatter{}}, + }, } for _, test := range cases { @@ -149,14 +172,25 @@ func TestArchiveRows(t *testing.T) { return } + var f *formatter + switch v := test.formatter.(type) { + case *extensionFormatter: + f = v.formatter + case *formatter: + f = v + default: + t.Fatal("unrecognized formatter type") + } + require.NoError(t, err) - require.True(t, test.formatter.closed) + require.True(t, f.closed) + require.Equal(t, test.expectedKey, test.uploader.uploadKey) for i := range test.rows.scan { for j := range test.rows.scan[i] { - require.True(t, i < len(test.formatter.written) && j < len(test.formatter.written[i]), "formatter written record count") + require.True(t, i < len(f.written) && j < len(f.written[i]), "formatter written record count") expected := test.rows.scan[i][j] - actual := test.formatter.written[i][j] + actual := f.written[i][j] require.Equal(t, expected, actual) } } @@ -173,7 +207,7 @@ type rows struct { } func (r *rows) ColumnTypes() ([]*sql.ColumnType, error) { - return make([](*sql.ColumnType), len(r.columns)), r.columnTypesErr + return make([]*sql.ColumnType, len(r.columns)), r.columnTypesErr } func (r *rows) Next() bool { @@ -201,6 +235,7 @@ func (r *rows) Err() error { } type uploader struct { + uploadKey string uploadErr error } @@ -212,6 +247,7 @@ func (u *uploader) UploadWithContext(ctx aws.Context, input *s3manager.UploadInp } } + u.uploadKey = *input.Key return nil, u.uploadErr } @@ -245,7 +281,16 @@ func (f *formatter) Format(record [][]byte) error { return nil } + func (f *formatter) Close() error { f.closed = true return f.closeErr } + +type extensionFormatter struct { + *formatter +} + +func (f *extensionFormatter) Extension() string { + return "ext" +}