Skip to content

Commit

Permalink
works with postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
gavincabbage committed Mar 26, 2019
1 parent 924f651 commit 12522b3
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 96 deletions.
132 changes: 107 additions & 25 deletions chiv.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ package chiv
import (
"context"
"database/sql"
"encoding/csv"
"fmt"
"io"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)

// Archiver archives database tables to Amazon S3.
type Archiver struct {
db *sql.DB // source database handle queried for table contents
s3 *s3manager.Uploader // upload manager used to stream results into S3
key string // object key; when empty a key is derived from the table name and format
format Format // output format, used to derive the default object key extension
}

Expand All @@ -35,12 +39,12 @@ func NewArchiver(db *sql.DB, s3 *s3manager.Uploader, options ...Option) *Archive
}

// Archive a database table to S3.
func (a *Archiver) Archive(table string, bucket string, options ...Option) error {
func (a *Archiver) Archive(table, bucket string, options ...Option) error {
return a.ArchiveWithContext(context.Background(), table, bucket, options...)
}

// Archive a database table to S3 with context.
func (a *Archiver) ArchiveWithContext(ctx context.Context, table string, bucket string, options ...Option) error {
func (a *Archiver) ArchiveWithContext(ctx context.Context, table, bucket string, options ...Option) error {
archiver := archiver{
db: a.db,
s3: a.s3,
Expand All @@ -52,7 +56,7 @@ func (a *Archiver) ArchiveWithContext(ctx context.Context, table string, bucket
option(archiver.config)
}

return archiver.archive(table)
return archiver.archive(table, bucket)
}

type archiver struct {
Expand All @@ -62,28 +66,106 @@ type archiver struct {
config *Archiver
}

func (a *archiver) archive(table string) error {
const selectAll = "SELECT * FROM $1"

rows, err := a.db.QueryContext(a.ctx, selectAll, table)
if err != nil {
return err
}
defer rows.Close()

r, w := io.Pipe() // TODO figuring this all out...

for rows.Next() {

}

if err := rows.Err(); err != nil {
// archive streams the contents of a database table into an object in the
// given S3 bucket. A producer goroutine encodes the table as CSV into an
// io.Pipe while a consumer goroutine uploads the pipe's read end, so the
// table is never buffered in memory all at once. It returns the first
// error reported by either side, or the context's error on cancellation.
func (a *archiver) archive(table, bucket string) error {
	// Each goroutine sends at most one value; buffering prevents the
	// goroutine that loses the select below from leaking, blocked on send.
	errs := make(chan error, 2)
	r, w := io.Pipe()
	defer r.Close()

	go a.produce(table, w, errs)
	go a.upload(table, bucket, r, errs)

	select {
	case err := <-errs:
		return err
	case <-a.ctx.Done():
		return a.ctx.Err() // report cancellation instead of swallowing it
	}
}

// produce selects every row of table and writes it to w as CSV. On success
// it closes w so the uploader sees EOF and completes; on failure it closes
// w with the error (failing the upload fast) and reports it on errs.
func (a *archiver) produce(table string, w *io.PipeWriter, errs chan<- error) {
	fail := func(err error) {
		w.CloseWithError(err) // unblock the consumer reading from the pipe
		errs <- err
	}

	// NOTE(review): identifiers cannot be bound as query parameters, so
	// the table name is interpolated; callers must not pass untrusted names.
	rows, err := a.db.QueryContext(a.ctx, fmt.Sprintf(`select * from "%s";`, table))
	if err != nil {
		fail(err)
		return
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		fail(err)
		return
	}

	cw := csv.NewWriter(w)
	if err := cw.Write(columns); err != nil {
		fail(err)
		return
	}

	// Scan into RawBytes via dest, which aliases rawBytes, so one Scan
	// call fills every column without per-row allocations.
	var (
		rawBytes = make([]sql.RawBytes, len(columns))
		record   = make([]string, len(columns))
		dest     = make([]interface{}, len(columns))
	)
	for i := range rawBytes {
		dest[i] = &rawBytes[i]
	}

	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			fail(err)
			return
		}

		for i, raw := range rawBytes {
			if raw == nil {
				record[i] = "\\N" // render NULL columns as \N
			} else {
				record[i] = string(raw)
			}
		}

		if err := cw.Write(record); err != nil {
			fail(err)
			return
		}
	}
	if err := rows.Err(); err != nil {
		fail(err)
		return
	}

	cw.Flush()
	if err := cw.Error(); err != nil {
		fail(err)
		return
	}

	// Closing the write end delivers EOF to the uploader, whose nil send
	// on errs signals overall success.
	if err := w.Close(); err != nil {
		errs <- err
	}
}

// upload streams r into bucket under the configured key, deriving a key
// from the table name and format when none was configured. It sends
// exactly one value on errs: nil on success, the upload error otherwise.
func (a *archiver) upload(table, bucket string, r *io.PipeReader, errs chan<- error) {
	// Derive the key locally rather than mutating shared config, so an
	// Archiver can be reused across tables without racing on the key.
	key := a.config.key
	if key == "" {
		switch a.config.format {
		case CSV:
			key = fmt.Sprintf("%s.csv", table)
		case JSON:
			key = fmt.Sprintf("%s.json", table)
		}
	}

	_, err := a.s3.UploadWithContext(a.ctx, &s3manager.UploadInput{
		Body:   r,
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	})
	errs <- err
}
182 changes: 112 additions & 70 deletions chiv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,111 +4,153 @@ package chiv_test

import (
"database/sql"
"io/ioutil"
"os"
"strings"
"testing"

"github.com/aws/aws-sdk-go/service/s3/s3manager"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gavincabbage/chiv"
)

func TestArchive(t *testing.T) {
var (
db = newDB(t)
s3client = newS3(t)
uploader = s3manager.NewUploaderWithClient(s3client)
downloader = s3manager.NewDownloaderWithClient(s3client)
)

mustExec(t, db, `
CREATE TABLE IF NOT EXISTS "test_table" (
id UUID PRIMARY KEY,
text_column TEXT,
char_column VARCHAR(50),
int_column INTEGER,
bool_column BOOLEAN,
ts_column TIMESTAMP
);`)
defer mustExec(t, db, `DROP TABLE "test_table";`)

mustExec(t, db, `
INSERT INTO "test_table" VALUES (
'ea09d13c-f441-4550-9492-115f8b409c96',
'some text',
'some chars',
42,
true,
'2018-01-04'::timestamp
);`)

mustExec(t, db, `
INSERT INTO "test_table" VALUES (
'7530a381-526a-42aa-a9ba-97fb2bca283f',
'some more text',
'some more chars',
101,
false,
'2018-02-05'::timestamp
);`)

expected := `id,text_column,char_column,int_column,bool_column,ts_column
ea09d13c-f441-4550-9492-115f8b409c96,some text,some chars,42,true,SOMETIMESTAMP
7530a381-526a-42aa-a9ba-97fb2bca283f,some more text,some more chars,101,false,OTHERTIMESTAMP`

if _, err := s3client.CreateBucket(&s3.CreateBucketInput{
Bucket: aws.String("test_bucket"),
}); err != nil {
t.Error(err)
func TestArchiver_Archive(t *testing.T) {
cases := []struct {
name string
driver string
database string
setup string
teardown string
expected string
bucket string
table string
key string
options []chiv.Option
}{
{
name: "postgres to csv",
driver: "postgres",
database: os.Getenv("POSTGRES_URL"),
setup: "./testdata/postgres_to_csv_setup.sql",
teardown: "./testdata/postgres_to_csv_teardown.sql",
expected: "./testdata/postgres_to_csv.csv",
bucket: "postgres_to_csv_bucket",
table: "postgres_to_csv_table",
key: "postgres_to_csv_table.csv",
options: []chiv.Option{},
},
{
name: "postgres to csv key override",
driver: "postgres",
database: os.Getenv("POSTGRES_URL"),
setup: "./testdata/postgres_to_csv_setup.sql",
teardown: "./testdata/postgres_to_csv_teardown.sql",
expected: "./testdata/postgres_to_csv.csv",
bucket: "postgres_to_csv_bucket",
table: "postgres_to_csv_table",
key: "postgres_to_csv_custom_key",
options: []chiv.Option{
chiv.WithKey("postgres_to_csv_custom_key"),
},
},
}

subject := chiv.NewArchiver(db, uploader)
assert.NotNil(t, subject)
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
var (
db = newDB(t, test.driver, test.database)
s3client = newS3Client(t, os.Getenv("AWS_REGION"), os.Getenv("AWS_ENDPOINT"))
uploader = s3manager.NewUploaderWithClient(s3client)
downloader = s3manager.NewDownloaderWithClient(s3client)
)

err := subject.Archive("test_table", "test_bucket")
require.NoError(t, err)
exec(t, db, test.setup)
defer exec(t, db, test.teardown)

b := &aws.WriteAtBuffer{}
n, err := downloader.Download(b, &s3.GetObjectInput{
Bucket: aws.String("test_bucket"),
Key: aws.String("test_table.csv"),
})
require.NoError(t, err)
require.Equal(t, len([]byte(expected)), n)
require.Equal(t, expected, string(b.Bytes()))
createBucket(t, s3client, test.bucket)
expected := readFile(t, test.expected)

subject := chiv.NewArchiver(db, uploader)
assert.NotNil(t, subject)

require.NoError(t, subject.Archive(test.table, test.bucket, test.options...))

n, actual := download(t, downloader, test.bucket, test.key)
require.Equal(t, len([]byte(expected)), n)
require.Equal(t, expected, actual)
})
}
}

func newDB(t *testing.T) *sql.DB {
db, err := sql.Open("postgres", os.Getenv("DATABASE_URL"))
require.NoError(t, err)
// newDB opens a database handle for the given driver and URL, aborting
// the test immediately on failure — returning a nil *sql.DB here would
// only surface later as a confusing nil dereference.
func newDB(t *testing.T, driver string, url string) *sql.DB {
	t.Helper()
	db, err := sql.Open(driver, url)
	if err != nil {
		t.Fatalf("opening %s database: %v", driver, err)
	}

	return db
}

func newS3(t *testing.T) *s3.S3 {
func newS3Client(t *testing.T, region string, endpoint string) *s3.S3 {
awsConfig := aws.NewConfig().
WithRegion(os.Getenv("AWS_REGION")).
WithRegion(region).
WithDisableSSL(true).
WithCredentials(credentials.NewEnvCredentials())

awsSession, err := session.NewSession(awsConfig)
require.NoError(t, err)
if err != nil {
t.Error(err)
}

client := s3.New(awsSession)
client.Endpoint = os.Getenv("AWS_ENDPOINT")
client.Endpoint = endpoint

return client
}

func mustExec(t *testing.T, db *sql.DB, query string) {
if _, err := db.Exec(query); err != nil {
// exec executes each statement in the SQL file at path against db,
// aborting the test on the first failure so later statements do not
// cascade misleading errors. Statements are separated by a semicolon at
// end of line.
func exec(t *testing.T, db *sql.DB, path string) {
	t.Helper()
	// readFile already returns a string; no extra conversion needed.
	statements := strings.Split(readFile(t, path), ";\n")
	for _, statement := range statements {
		if _, err := db.Exec(statement); err != nil {
			t.Fatalf("executing %q: %v", statement, err)
		}
	}
}

// createBucket creates the named S3 bucket, aborting the test if
// creation fails — continuing would make every subsequent upload to the
// missing bucket fail confusingly.
func createBucket(t *testing.T, client *s3.S3, name string) {
	t.Helper()
	if _, err := client.CreateBucket(&s3.CreateBucketInput{
		Bucket: aws.String(name),
	}); err != nil {
		t.Fatalf("creating bucket %s: %v", name, err)
	}
}

// readFile returns the contents of the file at path, aborting the test
// if it cannot be read — the original t.Error path silently returned ""
// and masked the real failure downstream.
func readFile(t *testing.T, path string) string {
	t.Helper()
	contents, err := ioutil.ReadFile(path)
	if err != nil {
		t.Fatalf("reading %s: %v", path, err)
	}

	return string(contents)
}

// download fetches the object at bucket/key into memory and returns its
// size and contents, aborting the test on failure rather than returning
// a partial result.
func download(t *testing.T, downloader *s3manager.Downloader, bucket string, key string) (int, string) {
	t.Helper()
	b := &aws.WriteAtBuffer{}
	n, err := downloader.Download(b, &s3.GetObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		t.Fatalf("downloading s3://%s/%s: %v", bucket, key, err)
	}

	return int(n), string(b.Bytes())
}
Loading

0 comments on commit 12522b3

Please sign in to comment.