Skip to content

Commit

Permalink
format func first pass
Browse files Browse the repository at this point in the history
  • Loading branch information
gavincabbage committed Apr 12, 2019
1 parent 145bceb commit 18fb791
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 114 deletions.
69 changes: 28 additions & 41 deletions chiv.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,47 @@ package chiv
import (
"context"
"database/sql"
"encoding/csv"
"errors"
"fmt"
"io"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)

var (
// DefaultFormat is CSV.
DefaultFormatFunc = CSV
// ErrRecordLength does not match the number of columns.
ErrRecordLength = errors.New("record length does not match number of columns")
)

// Archiver archives arbitrarily large relational database tables to Amazon S3. It contains a database connection
// and upload client. Options set on creation apply to all calls to Archive unless overridden.
type Archiver struct {
db *sql.DB
s3 *s3manager.Uploader
format Format
config config
}

type config struct {
format FormatterFunc
key string
null []byte
}

const (
// DefaultFormat is CSV.
DefaultFormat = CSV
)

// NewArchiver constructs an Archiver with the given database connection, S3 uploader and options.
func NewArchiver(db *sql.DB, s3 *s3manager.Uploader, options ...Option) *Archiver {
a := Archiver{
db: db,
s3: s3,
format: DefaultFormat,
db: db,
s3: s3,
config: config{
format: DefaultFormatFunc,
},
}

for _, option := range options {
option(&a)
option(&a.config)
}

return &a
Expand All @@ -53,11 +61,11 @@ func (a *Archiver) ArchiveWithContext(ctx context.Context, table, bucket string,
db: a.db,
s3: a.s3,
ctx: ctx,
config: a,
config: a.config,
}

for _, option := range options {
option(archiver.config)
option(&archiver.config)
}

return archiver.archive(table, bucket)
Expand All @@ -67,7 +75,7 @@ type archiver struct {
db *sql.DB
s3 *s3manager.Uploader
ctx context.Context
config *Archiver
config config
}

func (a *archiver) archive(table string, bucket string) error {
Expand All @@ -87,23 +95,7 @@ func (a *archiver) archive(table string, bucket string) error {
}
}

type formatter interface {
Begin([]*sql.ColumnType) error
Write([][]byte) error
End() error
}

func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
var w formatter
switch a.config.format {
case YAML:
w = &yamlFormatter{}
case JSON:
w = &jsonFormatter{w: wc}
default:
w = &csvFormatter{w: csv.NewWriter(wc)}
}

selectAll := fmt.Sprintf(`select * from "%s";`, table)
rows, err := a.db.QueryContext(a.ctx, selectAll)
if err != nil {
Expand All @@ -118,7 +110,8 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
return
}

if err := w.Begin(columns); err != nil {
f, err := a.config.format(wc, columns)
if err != nil {
errs <- err
return
}
Expand Down Expand Up @@ -147,7 +140,7 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
}
}

if err := w.Write(record); err != nil {
if err := f.Format(record); err != nil {
errs <- err
return
}
Expand All @@ -158,7 +151,7 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {
return
}

if err := w.End(); err != nil {
if err := f.Close(); err != nil {
errs <- err
return
}
Expand All @@ -171,14 +164,8 @@ func (a *archiver) download(wc io.WriteCloser, table string, errs chan error) {

func (a *archiver) upload(r io.Reader, table string, bucket string, errs chan error) {
if a.config.key == "" {
switch a.config.format {
case YAML:
a.config.key = fmt.Sprintf("%s.yml", table)
case JSON:
a.config.key = fmt.Sprintf("%s.json", table)
default:
a.config.key = fmt.Sprintf("%s.csv", table)
}
// TODO if a.config.extension or something? can pass in '.json'? wish i could connect to formatter hm
a.config.key = table
}

if _, err := a.s3.UploadWithContext(a.ctx, &s3manager.UploadInput{
Expand Down
9 changes: 5 additions & 4 deletions chiv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestArchiver_Archive(t *testing.T) {
expected: "./testdata/postgres.csv",
bucket: "postgres_bucket",
table: "postgres_table",
key: "postgres_table.csv",
key: "postgres_table",
options: []chiv.Option{},
},
{
Expand All @@ -55,9 +55,9 @@ func TestArchiver_Archive(t *testing.T) {
expected: "./testdata/postgres.csv",
bucket: "postgres_bucket",
table: "postgres_table",
key: "postgres_custom_key",
key: "postgres_table.csv",
options: []chiv.Option{
chiv.WithKey("postgres_custom_key"),
chiv.WithKey("postgres_table.csv"),
},
},
{
Expand All @@ -69,7 +69,7 @@ func TestArchiver_Archive(t *testing.T) {
expected: "./testdata/postgres_with_null.csv",
bucket: "postgres_bucket",
table: "postgres_table",
key: "postgres_table.csv",
key: "postgres_table",
options: []chiv.Option{
chiv.WithNull("custom_null"),
},
Expand All @@ -86,6 +86,7 @@ func TestArchiver_Archive(t *testing.T) {
key: "postgres_table.json",
options: []chiv.Option{
chiv.WithFormat(chiv.JSON),
chiv.WithKey("postgres_table.json"),
},
},
}
Expand Down
116 changes: 67 additions & 49 deletions formatters.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,99 +4,117 @@ import (
"database/sql"
"encoding/csv"
"encoding/json"
"errors"
"io"
"regexp"
"strconv"
)

const (
openBracket = byte('[')
closeBracket = byte(']')
comma = byte(',')
)
// FormatterFunc returns an initialized Formatter.
type FormatterFunc func(io.Writer, []*sql.ColumnType) (Formatter, error)

var (
// ErrRecordLength does not match the number of columns.
ErrRecordLength = errors.New("record length does not match number of columns")
)

// csvFormatter formats columns in CSV format.
type csvFormatter struct {
w *csv.Writer
count int
// Formatter formats and writes records.
type Formatter interface {
Format([][]byte) error
Close() error
}

func (c *csvFormatter) Begin(columns []*sql.ColumnType) error {
c.count = len(columns)
func CSV(w io.Writer, columns []*sql.ColumnType) (Formatter, error) {
f := &csvFormatter{
w: csv.NewWriter(w),
count: len(columns),
}

header := make([]string, c.count)
header := make([]string, f.count)
for i, column := range columns {
header[i] = column.Name()
}

return c.w.Write(header)
if err := f.w.Write(header); err != nil {
return nil, err
}

return f, nil
}

// csvFormatter formats columns in CSV format.
type csvFormatter struct {
w *csv.Writer
count int
}

func (c *csvFormatter) Write(record [][]byte) error {
if c.count != len(record) {
func (f *csvFormatter) Format(record [][]byte) error {
if f.count != len(record) {
return ErrRecordLength
}

strings := make([]string, c.count)
strings := make([]string, f.count)
for i, item := range record {
strings[i] = string(item)
}

return c.w.Write(strings)
return f.w.Write(strings)
}

func (c *csvFormatter) End() error {
c.w.Flush()
return c.w.Error()
func (f *csvFormatter) Close() error {
f.w.Flush()
return f.w.Error()
}

func YAML(w io.Writer, columns []*sql.ColumnType) (Formatter, error) {
return &yamlFormatter{columns: columns}, nil
}

// yamlFormatter formats columns in YAML format.
type yamlFormatter struct {
columns []*sql.ColumnType
}

func (c *yamlFormatter) Begin(columns []*sql.ColumnType) error {
return nil
}

func (c *yamlFormatter) Write(record [][]byte) error {
func (c *yamlFormatter) Format(record [][]byte) error {
if len(c.columns) != len(record) {
return ErrRecordLength
}

return nil
}

func (c *yamlFormatter) End() error {
func (c *yamlFormatter) Close() error {
return nil
}

const (
openBracket = byte('[')
closeBracket = byte(']')
comma = byte(',')
)

func JSON(w io.Writer, columns []*sql.ColumnType) (Formatter, error) {
f := &jsonFormatter{
w: w,
columns: columns,
}

if err := f.writeByte(openBracket); err != nil {
return nil, err
}

return f, nil
}

// jsonFormatter formats columns in JSON format.
type jsonFormatter struct {
w io.Writer
columns []*sql.ColumnType
notFirst bool
}

func (c *jsonFormatter) Begin(columns []*sql.ColumnType) error {
c.columns = columns
return writeByte(c.w, openBracket)
}

func (c *jsonFormatter) Write(record [][]byte) error {
if len(c.columns) != len(record) {
func (f *jsonFormatter) Format(record [][]byte) error {
if len(f.columns) != len(record) {
return ErrRecordLength
}

m := make(map[string]interface{})
for i, column := range c.columns {
r, err := parse(record[i], c.columns[i].DatabaseTypeName())
for i, column := range f.columns {
r, err := parse(record[i], f.columns[i].DatabaseTypeName())
if err != nil {
return err
}
Expand All @@ -108,30 +126,30 @@ func (c *jsonFormatter) Write(record [][]byte) error {
return err
}

if c.notFirst {
err := writeByte(c.w, comma)
if f.notFirst {
err := f.writeByte(comma)
if err != nil {
return err
}
}

n, err := c.w.Write(b)
n, err := f.w.Write(b)
if err != nil {
return err
} else if n != len(b) {
return io.ErrShortWrite
}

c.notFirst = true
f.notFirst = true
return nil
}

func (c *jsonFormatter) End() error {
return writeByte(c.w, closeBracket)
func (f *jsonFormatter) Close() error {
return f.writeByte(closeBracket)
}

func writeByte(w io.Writer, b byte) error {
n, err := w.Write([]byte{b})
func (f *jsonFormatter) writeByte(b byte) error {
n, err := f.w.Write([]byte{b})
if err != nil {
return err
} else if n != 1 {
Expand Down
Loading

0 comments on commit 18fb791

Please sign in to comment.