Skip to content

Commit

Permalink
feat: Make goose annotations case-insensitive (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
obalunenko authored Mar 4, 2024
1 parent 76946cc commit 48100ea
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 13 deletions.
95 changes: 84 additions & 11 deletions internal/sqlparser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,19 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
if stateMachine.get() == start && strings.TrimSpace(line) == "" {
continue
}
// TODO(mf): validate annotations to avoid common user errors:
// https://github.com/pressly/goose/issues/163#issuecomment-501736725
if strings.HasPrefix(line, "--") {
cmd := strings.TrimSpace(strings.TrimPrefix(line, "--"))

// Check for annotations.
// All annotations must be in format: "-- +goose [annotation]"
if strings.HasPrefix(strings.TrimSpace(line), "--") && strings.Contains(line, "+goose") {
var cmd annotation

cmd, err = extractAnnotation(line)
if err != nil {
return nil, false, fmt.Errorf("failed to parse annotation line %q: %w", line, err)
}

switch cmd {
case "+goose Up":
case annotationUp:
switch stateMachine.get() {
case start:
stateMachine.set(gooseUp)
Expand All @@ -136,7 +142,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
}
continue

case "+goose Down":
case annotationDown:
switch stateMachine.get() {
case gooseUp, gooseStatementEndUp:
// If we hit a down annotation, but the buffer is not empty, we have an unfinished SQL query from a
Expand All @@ -151,7 +157,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
}
continue

case "+goose StatementBegin":
case annotationStatementBegin:
switch stateMachine.get() {
case gooseUp, gooseStatementEndUp:
stateMachine.set(gooseStatementBeginUp)
Expand All @@ -162,7 +168,7 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
}
continue

case "+goose StatementEnd":
case annotationStatementEnd:
switch stateMachine.get() {
case gooseStatementBeginUp:
stateMachine.set(gooseStatementEndUp)
Expand All @@ -172,17 +178,20 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
return nil, false, errors.New("'-- +goose StatementEnd' must be defined after '-- +goose StatementBegin', see https://github.com/pressly/goose#sql-migrations")
}

case "+goose NO TRANSACTION":
case annotationNoTransaction:
useTx = false
continue

case "+goose ENVSUB ON":
case annotationEnvsubOn:
useEnvsub = true
continue

case "+goose ENVSUB OFF":
case annotationEnvsubOff:
useEnvsub = false
continue

default:
return nil, false, fmt.Errorf("unknown annotation: %q", cmd)
}
}
// Once we've started parsing a statement the buffer is no longer empty,
Expand Down Expand Up @@ -277,6 +286,70 @@ func ParseSQLMigration(r io.Reader, direction Direction, debug bool) (stmts []st
return stmts, useTx, nil
}

type annotation string

const (
annotationUp annotation = "Up"
annotationDown annotation = "Down"
annotationStatementBegin annotation = "StatementBegin"
annotationStatementEnd annotation = "StatementEnd"
annotationNoTransaction annotation = "NO TRANSACTION"
annotationEnvsubOn annotation = "ENVSUB ON"
annotationEnvsubOff annotation = "ENVSUB OFF"
)

var supportedAnnotations = map[annotation]struct{}{
annotationUp: {},
annotationDown: {},
annotationStatementBegin: {},
annotationStatementEnd: {},
annotationNoTransaction: {},
annotationEnvsubOn: {},
annotationEnvsubOff: {},
}

var (
errEmptyAnnotation = errors.New("empty annotation")
errInvalidAnnotation = errors.New("invalid annotation")
)

// extractAnnotation extracts the annotation from the line.
// All annotations must be in format: "-- +goose [annotation]"
// Allowed annotations: Up, Down, StatementBegin, StatementEnd, NO TRANSACTION, ENVSUB ON, ENVSUB OFF
func extractAnnotation(line string) (annotation, error) {
// If line contains leading whitespace - return error.
if strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") {
return "", fmt.Errorf("%q contains leading whitespace: %w", line, errInvalidAnnotation)
}

// Extract the annotation from the line, by removing the leading "--"
cmd := strings.ReplaceAll(line, "--", "")

// Extract the annotation from the line, by removing the leading "+goose"
cmd = strings.Replace(cmd, "+goose", "", 1)

if strings.Contains(cmd, "+goose") {
return "", fmt.Errorf("%q contains multiple '+goose' annotations: %w", cmd, errInvalidAnnotation)
}

// Remove leading and trailing whitespace from the annotation command.
cmd = strings.TrimSpace(cmd)

if cmd == "" {
return "", errEmptyAnnotation
}

a := annotation(cmd)

for s := range supportedAnnotations {
if strings.EqualFold(string(s), string(a)) {
return s, nil
}
}

return "", fmt.Errorf("%q not supported: %w", cmd, errInvalidAnnotation)
}

func missingSemicolonError(state parserState, direction Direction, s string) error {
return fmt.Errorf("failed to parse migration: state %d, direction: %v: unexpected unfinished SQL query: %q: missing semicolon?",
state,
Expand Down
79 changes: 79 additions & 0 deletions internal/sqlparser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,3 +507,82 @@ CREATE TABLE post (
check.HasError(t, err)
check.Contains(t, err.Error(), "variable substitution failed: $SOME_UNSET_VAR: required env var not set:")
}

func Test_extractAnnotation(t *testing.T) {
tests := []struct {
name string
input string
want annotation
wantErr func(t *testing.T, err error)
}{
{
name: "Up",
input: "-- +goose Up",
want: annotationUp,
wantErr: check.NoError,
},
{
name: "Down",
input: "-- +goose Down",
want: annotationDown,
wantErr: check.NoError,
},
{
name: "StmtBegin",
input: "-- +goose StatementBegin",
want: annotationStatementBegin,
wantErr: check.NoError,
},
{
name: "NoTransact",
input: "-- +goose NO TRANSACTION",
want: annotationNoTransaction,
wantErr: check.NoError,
},
{
name: "Unsupported",
input: "-- +goose unsupported",
want: "",
wantErr: check.HasError,
},
{
name: "Empty",
input: "-- +goose",
want: "",
wantErr: check.HasError,
},
{
name: "statement with spaces and Uppercase",
input: "-- +goose UP ",
want: annotationUp,
wantErr: check.NoError,
},
{
name: "statement with leading whitespace - error",
input: " -- +goose UP ",
want: "",
wantErr: check.HasError,
},
{
name: "statement with leading \t - error",
input: "\t-- +goose UP ",
want: "",
wantErr: check.HasError,
},
{
name: "multiple +goose annotations - error",
input: "-- +goose +goose Up",
want: "",
wantErr: check.HasError,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractAnnotation(tt.input)
tt.wantErr(t, err)

check.Equal(t, got, tt.want)
})
}
}
4 changes: 2 additions & 2 deletions internal/sqlparser/testdata/valid-up/test01/input.sql
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
-- +goose Up
-- +goose UP
CREATE TABLE emp (
empname text,
salary integer,
last_date timestamp,
last_user text
);

-- +goose StatementBegin
-- +goose statementBegin
CREATE FUNCTION emp_stamp() RETURNS trigger AS $emp_stamp$
BEGIN
-- Check that empname and salary are given
Expand Down

0 comments on commit 48100ea

Please sign in to comment.