Skip to content

Commit

Permalink
Merge pull request #151 from pressly/87
Browse files Browse the repository at this point in the history
Refactor SQL parser
  • Loading branch information
VojtechVitek authored Mar 5, 2019
2 parents e4b9895 + 02bb13b commit f54a6e4
Show file tree
Hide file tree
Showing 7 changed files with 565 additions and 441 deletions.
2 changes: 1 addition & 1 deletion goose.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"sync"
)

const VERSION = "v2.6.0"
const VERSION = "v2.7.0-rc1"

var (
duplicateCheckOnce sync.Once
Expand Down
2 changes: 0 additions & 2 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ func newMigration(v int64, src string) *Migration {
}

func TestMigrationSort(t *testing.T) {

ms := Migrations{}

// insert in any order
Expand All @@ -26,7 +25,6 @@ func TestMigrationSort(t *testing.T) {
}

func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) {

for i, m := range ms {
if sorted[i] != m.Version {
t.Error("incorrect sorted version")
Expand Down
44 changes: 34 additions & 10 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package goose
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -38,7 +39,6 @@ func (m *Migration) Up(db *sql.DB) error {
if err := m.run(db, true); err != nil {
return err
}
log.Println("OK ", filepath.Base(m.Source))
return nil
}

Expand All @@ -47,51 +47,75 @@ func (m *Migration) Down(db *sql.DB) error {
if err := m.run(db, false); err != nil {
return err
}
log.Println("OK ", filepath.Base(m.Source))
return nil
}

func (m *Migration) run(db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) {
case ".sql":
if err := runSQLMigration(db, m.Source, m.Version, direction); err != nil {
return errors.Wrapf(err, "failed to run SQL migration %q", filepath.Base(m.Source))
f, err := os.Open(m.Source)
if err != nil {
return errors.Wrapf(err, "ERROR %v: failed to open SQL migration file", filepath.Base(m.Source))
}
defer f.Close()

statements, useTx, err := parseSQLMigration(f, direction)
if err != nil {
return errors.Wrapf(err, "ERROR %v: failed to parse SQL migration file", filepath.Base(m.Source))
}

if err := runSQLMigration(db, statements, useTx, m.Version, direction); err != nil {
return errors.Wrapf(err, "ERROR %v: failed to run SQL migration", filepath.Base(m.Source))
}

if len(statements) > 0 {
log.Println("OK ", filepath.Base(m.Source))
} else {
log.Println("EMPTY", filepath.Base(m.Source))
}

case ".go":
if !m.Registered {
return errors.Errorf("failed to run Go migration %q: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
return errors.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
}
tx, err := db.Begin()
if err != nil {
return errors.Wrap(err, "failed to begin transaction")
return errors.Wrap(err, "ERROR failed to begin transaction")
}

fn := m.UpFn
if !direction {
fn = m.DownFn
}

if fn != nil {
// Run Go migration function.
if err := fn(tx); err != nil {
tx.Rollback()
return errors.Wrapf(err, "failed to run Go migration %q", filepath.Base(m.Source))
return errors.Wrapf(err, "ERROR %v: failed to run Go migration function %T", filepath.Base(m.Source), fn)
}
}

if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
tx.Rollback()
return errors.Wrap(err, "failed to execute transaction")
return errors.Wrap(err, "ERROR failed to execute transaction")
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil {
tx.Rollback()
return errors.Wrap(err, "failed to execute transaction")
return errors.Wrap(err, "ERROR failed to execute transaction")
}
}

if err := tx.Commit(); err != nil {
return errors.Wrap(err, "failed to commit transaction")
return errors.Wrap(err, "ERROR failed to commit transaction")
}

if fn != nil {
log.Println("OK ", filepath.Base(m.Source))
} else {
log.Println("EMPTY", filepath.Base(m.Source))
}

return nil
Expand Down
181 changes: 17 additions & 164 deletions migration_sql.go
Original file line number Diff line number Diff line change
@@ -1,153 +1,12 @@
package goose

import (
"bufio"
"bytes"
"database/sql"
"fmt"
"io"
"os"
"regexp"
"strings"
"sync"

"github.com/pkg/errors"
)

const sqlCmdPrefix = "-- +goose "
const scanBufSize = 4 * 1024 * 1024

var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, scanBufSize)
},
}

// Checks the line to see if the line has a statement-ending semicolon
// or if the line contains a double-dash comment.
func endsWithSemicolon(line string) bool {
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)

prev := ""
scanner := bufio.NewScanner(strings.NewReader(line))
scanner.Buffer(scanBuf, scanBufSize)
scanner.Split(bufio.ScanWords)

for scanner.Scan() {
word := scanner.Text()
if strings.HasPrefix(word, "--") {
break
}
prev = word
}

return strings.HasSuffix(prev, ";")
}

// Split the given sql script into individual statements.
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
func getSQLStatements(r io.Reader, direction bool) ([]string, bool, error) {
var buf bytes.Buffer
scanBuf := bufferPool.Get().([]byte)
defer bufferPool.Put(scanBuf)

scanner := bufio.NewScanner(r)
scanner.Buffer(scanBuf, scanBufSize)

// track the count of each section
// so we can diagnose scripts with no annotations
upSections := 0
downSections := 0

statementEnded := false
ignoreSemicolons := false
directionIsActive := false
tx := true
stmts := []string{}

for scanner.Scan() {

line := scanner.Text()

// handle any goose-specific commands
if strings.HasPrefix(line, sqlCmdPrefix) {
cmd := strings.TrimSpace(line[len(sqlCmdPrefix):])
switch cmd {
case "Up":
directionIsActive = (direction == true)
upSections++
break

case "Down":
directionIsActive = (direction == false)
downSections++
break

case "StatementBegin":
if directionIsActive {
ignoreSemicolons = true
}
break

case "StatementEnd":
if directionIsActive {
statementEnded = (ignoreSemicolons == true)
ignoreSemicolons = false
}
break

case "NO TRANSACTION":
tx = false
break
}
}

if !directionIsActive {
continue
}

if _, err := buf.WriteString(line + "\n"); err != nil {
return nil, false, fmt.Errorf("io err: %v", err)
}

// Wrap up the two supported cases: 1) basic with semicolon; 2) psql statement
// Lines that end with semicolon that are in a statement block
// do not conclude statement.
if (!ignoreSemicolons && endsWithSemicolon(line)) || statementEnded {
statementEnded = false
stmts = append(stmts, buf.String())
buf.Reset()
}
}

if err := scanner.Err(); err != nil {
return nil, false, fmt.Errorf("scanning migration: %v", err)
}

// diagnose likely migration script errors
if ignoreSemicolons {
return nil, false, fmt.Errorf("parsing migration: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'")
}

if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
return nil, false, fmt.Errorf("parsing migration: unexpected unfinished SQL query: %s. potential missing semicolon", bufferRemaining)
}

if upSections == 0 && downSections == 0 {
return nil, false, fmt.Errorf("parsing migration: no Up/Down annotations found, so no statements were executed. See https://bitbucket.org/liamstask/goose/overview for details")
}

return stmts, tx, nil
}

// Run a migration specified in raw SQL.
//
// Sections of the script can be annotated with a special comment,
Expand All @@ -156,52 +15,41 @@ func getSQLStatements(r io.Reader, direction bool) ([]string, bool, error) {
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error {
f, err := os.Open(sqlFile)
if err != nil {
return errors.Wrap(err, "failed to open SQL migration file")
}
defer f.Close()

statements, useTx, err := getSQLStatements(f, direction)
if err != nil {
return err
}

func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direction bool) error {
if useTx {
// TRANSACTION.

printInfo("Begin transaction\n")
verboseInfo("Begin transaction")

tx, err := db.Begin()
if err != nil {
errors.Wrap(err, "failed to begin transaction")
}

for _, query := range statements {
printInfo("Executing statement: %s\n", clearStatement(query))
verboseInfo("Executing statement: %s\n", clearStatement(query))
if _, err = tx.Exec(query); err != nil {
printInfo("Rollback transaction\n")
verboseInfo("Rollback transaction")
tx.Rollback()
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
}
}

if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
printInfo("Rollback transaction\n")
verboseInfo("Rollback transaction")
tx.Rollback()
return errors.Wrap(err, "failed to insert new goose version")
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), v); err != nil {
printInfo("Rollback transaction\n")
verboseInfo("Rollback transaction")
tx.Rollback()
return errors.Wrap(err, "failed to delete goose version")
}
}

printInfo("Commit transaction\n")
verboseInfo("Commit transaction")
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "failed to commit transaction")
}
Expand All @@ -211,7 +59,7 @@ func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error

// NO TRANSACTION.
for _, query := range statements {
printInfo("Executing statement: %s\n", clearStatement(query))
verboseInfo("Executing statement: %s", clearStatement(query))
if _, err := db.Exec(query); err != nil {
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
}
Expand All @@ -223,18 +71,23 @@ func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error
return nil
}

func printInfo(s string, args ...interface{}) {
const (
grayColor = "\033[90m"
resetColor = "\033[00m"
)

func verboseInfo(s string, args ...interface{}) {
if verbose {
log.Printf(s, args...)
log.Printf(grayColor+s+resetColor, args...)
}
}

var (
matchSQLComments = regexp.MustCompile(`(?m)^--.*$[\r\n]*`)
matchEmptyLines = regexp.MustCompile(`(?m)^$[\r\n]*`)
matchEmptyEOL = regexp.MustCompile(`(?m)^$[\r\n]*`) // TODO: Duplicate
)

func clearStatement(s string) string {
s = matchSQLComments.ReplaceAllString(s, ``)
return matchEmptyLines.ReplaceAllString(s, ``)
return matchEmptyEOL.ReplaceAllString(s, ``)
}
Loading

0 comments on commit f54a6e4

Please sign in to comment.