diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b08e855..60bbd0b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,6 +72,10 @@ jobs: env: TEST_DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_testdb?sslmode=disable + - name: Test riverdatabasesql + working-directory: ./riverdriver/riverdatabasesql + run: go test -race ./... + - name: Test riverpgxv5 working-directory: ./riverdriver/riverpgxv5 run: go test -race ./... @@ -121,6 +125,15 @@ jobs: contents: read # allow read access to pull request. Use with `only-new-issues` option. pull-requests: read + + strategy: + matrix: + submodule: + - . + - riverdriver + - riverdriver/riverdatabasesql + - riverdriver/riverpgxv5 + steps: - uses: actions/setup-go@v4 with: @@ -137,6 +150,7 @@ jobs: only-new-issues: true version: v1.55.2 + working-directory: ${{ matrix.submodule }} producer_sample: runs-on: ubuntu-latest @@ -204,7 +218,6 @@ jobs: sqlc-version: "1.22.0" - name: Run sqlc diff - working-directory: ./internal/dbsqlc run: | echo "Please make sure that all sqlc changes are checked in!" - sqlc diff + make verify diff --git a/.golangci.yml b/.golangci.yml index f6d0c697..c1bf07e9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -64,6 +64,9 @@ linters-settings: - Default - Prefix(github.com/riverqueue) + gomoddirectives: + replace-local: true + gosec: excludes: - G404 # use of non-crypto random; overly broad for our use case diff --git a/CHANGELOG.md b/CHANGELOG.md index af3feb7c..e74a5bcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added `rivermigrate/riverdatabasesql` driver to enable River Go migrations through Go's built in `database/sql` package. + ## [0.0.12] - 2023-12-02 ### Added diff --git a/Makefile b/Makefile index f06ffe8b..ab10a7a4 100644 --- a/Makefile +++ b/Makefile @@ -4,4 +4,16 @@ generate: generate/sqlc .PHONY: generate/sqlc generate/sqlc: - cd internal/dbsqlc && sqlc generate \ No newline at end of file + cd internal/dbsqlc && sqlc generate + cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc generate + cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc generate + +.PHONY: verify +verify: +verify: verify/sqlc + +.PHONY: verify/sqlc +verify/sqlc: + cd internal/dbsqlc && sqlc diff + cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc diff + cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc diff \ No newline at end of file diff --git a/docs/development.md b/docs/development.md index aa5ed01d..a53fa46a 100644 --- a/docs/development.md +++ b/docs/development.md @@ -32,7 +32,9 @@ queries. After changing an sqlc `.sql` file, generate Go with: ```shell git checkout master && git pull --rebase VERSION=v0.0.x +git tag riverdriver/VERSION -m "release riverdriver/VERSION" git tag riverdriver/riverpgxv5/$VERSION -m "release riverdriver/riverpgxv5/$VERSION" +git tag riverdriver/riverdatabasesql/$VERSION -m "release riverdriver/riverdatabasesql/$VERSION" git tag $VERSION git push --tags ``` diff --git a/go.mod b/go.mod index 0c4db7ab..d9b88a5a 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,21 @@ module github.com/riverqueue/river -go 1.21.0 +go 1.21.4 -// replace github.com/riverqueue/river/riverdriver/riverpgxv5 => ./riverdriver/riverpgxv5 +replace github.com/riverqueue/river/riverdriver => ./riverdriver + +replace github.com/riverqueue/river/riverdriver/riverpgxv5 => ./riverdriver/riverpgxv5 + +replace github.com/riverqueue/river/riverdriver/riverdatabasesql => ./riverdriver/riverdatabasesql require ( github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa github.com/jackc/pgx/v5 v5.5.0 github.com/jackc/puddle/v2 v2.2.1 github.com/oklog/ulid/v2 v2.1.0 + github.com/riverqueue/river/riverdriver v0.0.0-00010101000000-000000000000 github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 + github.com/riverqueue/river/riverdriver/riverdatabasesql v0.0.0-00010101000000-000000000000 github.com/robfig/cron/v3 v3.0.1 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 @@ -23,6 +29,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/crypto v0.15.0 // indirect diff --git a/go.sum b/go.sum index 874971f7..21e9a72d 100644 --- a/go.sum +++ b/go.sum @@ -18,13 +18,13 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU= github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 h1:mcDBnqwzEXY9WDOwbkd8xmFdSr/H6oHb1F3NCNCmLDY= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12/go.mod h1:k6hsPkW9Fl3qURzyLHbvxUCqWDpit0WrZ3oEaKezD3E= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= diff --git a/internal/dbsqlc/models.go b/internal/dbsqlc/models.go index 701196e5..0a371bcf 100644 --- a/internal/dbsqlc/models.go +++ b/internal/dbsqlc/models.go @@ -82,9 +82,3 @@ type RiverLeader struct { LeaderID string Name string } - -type RiverMigration struct { - ID int64 - CreatedAt time.Time - Version int64 -} diff --git a/internal/dbsqlc/sqlc.yaml b/internal/dbsqlc/sqlc.yaml index a2820edb..04556ee7 100644 --- a/internal/dbsqlc/sqlc.yaml +++ b/internal/dbsqlc/sqlc.yaml @@ -4,11 +4,9 @@ sql: queries: - river_job.sql - river_leader.sql - - river_migration.sql schema: - river_job.sql - river_leader.sql - - river_migration.sql gen: go: package: "dbsqlc" diff --git a/internal/riverinternaltest/riverinternaltest.go b/internal/riverinternaltest/riverinternaltest.go index cc1edda4..fa51e835 100644 --- a/internal/riverinternaltest/riverinternaltest.go +++ b/internal/riverinternaltest/riverinternaltest.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "log/slog" + "net/url" "os" "runtime" "sync" @@ -52,19 +53,39 @@ func BaseServiceArchetype(tb testing.TB) *baseservice.Archetype { } func DatabaseConfig(databaseName string) *pgxpool.Config { - databaseURL := valutil.ValOrDefault(os.Getenv("TEST_DATABASE_URL"), "postgres:///river_testdb?sslmode=disable") - - config, err := pgxpool.ParseConfig(databaseURL) + config, err := pgxpool.ParseConfig(DatabaseURL(databaseName)) if err != nil { panic(fmt.Sprintf("error parsing database URL: %v", err)) } config.MaxConns = dbPoolMaxConns config.ConnConfig.ConnectTimeout = 10 * time.Second - config.ConnConfig.Database = databaseName config.ConnConfig.RuntimeParams["timezone"] = "UTC" return config } +// DatabaseURL gets a test database URL from TEST_DATABASE_URL or falls back on +// a default pointing to `river_testdb`. If databaseName is set, it replaces the +// database in the URL, although the host and other parameters are preserved. +// +// Most of the time DatabaseConfig should be used instead of this function, but +// it may be useful in non-pgx situations like for examples showing the use of +// `database/sql`. +func DatabaseURL(databaseName string) string { + u, err := url.Parse(valutil.ValOrDefault( + os.Getenv("TEST_DATABASE_URL"), + "postgres://localhost/river_testdb?sslmode=disable"), + ) + if err != nil { + panic(err) + } + + if databaseName != "" { + u.Path = databaseName + } + + return u.String() +} + // DiscardContinuously drains continuously out of the given channel and discards // anything that comes out of it. Returns a stop function that should be invoked // to stop draining. Stop must be invoked before tests finish to stop an diff --git a/internal/util/dbutil/db_util.go b/internal/util/dbutil/db_util.go index 9a31c601..bc50d4b4 100644 --- a/internal/util/dbutil/db_util.go +++ b/internal/util/dbutil/db_util.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/riverqueue/river/internal/dbsqlc" + "github.com/riverqueue/river/riverdriver" ) // Executor is an interface for a type that can begin a transaction and also @@ -56,3 +57,35 @@ func WithTxV[T any](ctx context.Context, txBeginner TxBeginner, innerFunc func(c return res, nil } + +// WithExecutorTx starts and commits a transaction on a driver executor around +// the given function, allowing the return of a generic value. +func WithExecutorTx(ctx context.Context, exec riverdriver.Executor, innerFunc func(ctx context.Context, tx riverdriver.ExecutorTx) error) error { + _, err := WithExecutorTxV(ctx, exec, func(ctx context.Context, tx riverdriver.ExecutorTx) (struct{}, error) { + return struct{}{}, innerFunc(ctx, tx) + }) + return err +} + +// WithExecutorTxV starts and commits a transaction on a driver executor around +// the given function, allowing the return of a generic value. +func WithExecutorTxV[T any](ctx context.Context, exec riverdriver.Executor, innerFunc func(ctx context.Context, tx riverdriver.ExecutorTx) (T, error)) (T, error) { + var defaultRes T + + tx, err := exec.Begin(ctx) + if err != nil { + return defaultRes, fmt.Errorf("error beginning transaction: %w", err) + } + defer tx.Rollback(ctx) + + res, err := innerFunc(ctx, tx) + if err != nil { + return defaultRes, err + } + + if err := tx.Commit(ctx); err != nil { + return defaultRes, fmt.Errorf("error committing transaction: %w", err) + } + + return res, nil +} diff --git a/internal/util/dbutil/db_util_test.go b/internal/util/dbutil/db_util_test.go index 88b6eb99..c0f45cf7 100644 --- a/internal/util/dbutil/db_util_test.go +++ b/internal/util/dbutil/db_util_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/require" "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" ) func TestWithTx(t *testing.T) { @@ -40,3 +42,36 @@ func TestWithTxV(t *testing.T) { require.NoError(t, err) require.Equal(t, 7, ret) } + +func TestWithExecutorTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + dbPool := riverinternaltest.TestDB(ctx, t) + driver := riverpgxv5.New(dbPool) + + err := WithExecutorTx(ctx, driver.GetExecutor(), func(ctx context.Context, tx riverdriver.ExecutorTx) error { + _, err := tx.Exec(ctx, "SELECT 1") + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) +} + +func TestWithExecutorTxV(t *testing.T) { + t.Parallel() + + ctx := context.Background() + dbPool := riverinternaltest.TestDB(ctx, t) + driver := riverpgxv5.New(dbPool) + + ret, err := WithExecutorTxV(ctx, driver.GetExecutor(), func(ctx context.Context, tx riverdriver.ExecutorTx) (int, error) { + _, err := tx.Exec(ctx, "SELECT 1") + require.NoError(t, err) + + return 7, nil + }) + require.NoError(t, err) + require.Equal(t, 7, ret) +} diff --git a/riverdriver/go.mod b/riverdriver/go.mod new file mode 100644 index 00000000..def94ac2 --- /dev/null +++ b/riverdriver/go.mod @@ -0,0 +1,14 @@ +module github.com/riverqueue/river/riverdriver + +go 1.21.4 + +require github.com/jackc/pgx/v5 v5.5.0 + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + golang.org/x/crypto v0.15.0 // indirect + golang.org/x/sync v0.5.0 // indirect + golang.org/x/text v0.14.0 // indirect +) diff --git a/riverdriver/go.sum b/riverdriver/go.sum new file mode 100644 index 00000000..b9c08498 --- /dev/null +++ b/riverdriver/go.sum @@ -0,0 +1,28 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.0 h1:NxstgwndsTRy7eq9/kqYc/BZh5w2hHJV86wjvO+1xPw= +github.com/jackc/pgx/v5 v5.5.0/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index c47c08aa..b944ba6b 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -13,10 +13,20 @@ package riverdriver import ( + "context" + "errors" + "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) +var ( + ErrNotImplemented = errors.New("driver does not implement this functionality") + ErrNoRows = errors.New("no rows found") + ErrSubTxNotSupported = errors.New("subtransactions not supported for this driver") +) + // Driver provides a database driver for use with river.Client. // // Its purpose is to wrap the interface of a third party database package, with @@ -32,10 +42,95 @@ import ( type Driver[TTx any] interface { // GetDBPool returns a database pool.This doesn't make sense in a world // where multiple drivers are supported and is subject to change. + // + // API is not stable. DO NOT USE. GetDBPool() *pgxpool.Pool + // GetExecutor gets an executor for the driver. + // + // API is not stable. DO NOT USE. + GetExecutor() Executor + + // UnwrapExecutor gets unwraps executor from a driver transaction. + // + // API is not stable. DO NOT USE. + UnwrapExecutor(tx TTx) Executor + // UnwrapTx turns a generically typed transaction into a pgx.Tx for use with // internal infrastructure. This doesn't make sense in a world where // multiple drivers are supported and is subject to change. + // + // API is not stable. DO NOT USE. UnwrapTx(tx TTx) pgx.Tx } + +// Executor provides River operations against a database. It may be a database +// pool or transaction. +type Executor interface { + // Begin begins a new subtransaction. ErrSubTxNotSupported may be returned + // if the executor is a transaction and the driver doesn't support + // subtransactions (like riverdriver/riverdatabasesql for database/sql). + // + // API is not stable. DO NOT USE. + Begin(ctx context.Context) (ExecutorTx, error) + + // Exec executes raw SQL. Used for migrations. + // + // API is not stable. DO NOT USE. + Exec(ctx context.Context, sql string) (struct{}, error) + + // MigrationDeleteByVersionMany deletes many migration versions. + // + // API is not stable. DO NOT USE. + MigrationDeleteByVersionMany(ctx context.Context, versions []int) ([]*Migration, error) + + // MigrationGetAll gets all currently applied migrations. + // + // API is not stable. DO NOT USE. + MigrationGetAll(ctx context.Context) ([]*Migration, error) + + // MigrationInsertMany inserts many migration versions. + // + // API is not stable. DO NOT USE. + MigrationInsertMany(ctx context.Context, versions []int) ([]*Migration, error) + + // TableExists checks whether a table exists for the schema in the current + // search schema. + // + // API is not stable. DO NOT USE. + TableExists(ctx context.Context, tableName string) (bool, error) +} + +// ExecutorTx is an executor which is a transaction. In addition to standard +// Executor operations, it may be committed or rolled back. +type ExecutorTx interface { + Executor + + // Commit commits the transaction. + // + // API is not stable. DO NOT USE. + Commit(ctx context.Context) error + + // Rollback rolls back the transaction. + // + // API is not stable. DO NOT USE. + Rollback(ctx context.Context) error +} + +// Migration represents a River migration. +type Migration struct { + // ID is an automatically generated primary key for the migration. + // + // API is not stable. DO NOT USE. + ID int + + // CreatedAt is when the migration was initially created. + // + // API is not stable. DO NOT USE. + CreatedAt time.Time + + // Version is the version of the migration. + // + // API is not stable. DO NOT USE. + Version int +} diff --git a/riverdriver/river_driver_interface_test.go b/riverdriver/river_driver_interface_test.go index 221ace9d..ac16273f 100644 --- a/riverdriver/river_driver_interface_test.go +++ b/riverdriver/river_driver_interface_test.go @@ -1,10 +1 @@ package riverdriver - -import ( - "github.com/jackc/pgx/v5" - - "github.com/riverqueue/river/riverdriver/riverpgxv5" -) - -// Verify interface compliance. -var _ Driver[pgx.Tx] = &riverpgxv5.Driver{} diff --git a/riverdriver/riverdatabasesql/go.mod b/riverdriver/riverdatabasesql/go.mod new file mode 100644 index 00000000..e4663345 --- /dev/null +++ b/riverdriver/riverdatabasesql/go.mod @@ -0,0 +1,25 @@ +module github.com/riverqueue/river/riverdriver/riverdatabasesql + +go 1.21.4 + +replace github.com/riverqueue/river/riverdriver => ../ + +require ( + github.com/jackc/pgx/v5 v5.5.0 + github.com/lib/pq v1.10.9 + github.com/riverqueue/river/riverdriver v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.8.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 // indirect + golang.org/x/crypto v0.15.0 // indirect + golang.org/x/sync v0.5.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/riverdriver/riverdatabasesql/go.sum b/riverdriver/riverdatabasesql/go.sum new file mode 100644 index 00000000..18cd465b --- /dev/null +++ b/riverdriver/riverdatabasesql/go.sum @@ -0,0 +1,44 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.0 h1:NxstgwndsTRy7eq9/kqYc/BZh5w2hHJV86wjvO+1xPw= +github.com/jackc/pgx/v5 v5.5.0/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 h1:mcDBnqwzEXY9WDOwbkd8xmFdSr/H6oHb1F3NCNCmLDY= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12/go.mod h1:k6hsPkW9Fl3qURzyLHbvxUCqWDpit0WrZ3oEaKezD3E= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/db.go b/riverdriver/riverdatabasesql/internal/dbsqlc/db.go new file mode 100644 index 00000000..57d06d7d --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/db.go @@ -0,0 +1,24 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/models.go b/riverdriver/riverdatabasesql/internal/dbsqlc/models.go new file mode 100644 index 00000000..b42c6819 --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "time" +) + +type RiverMigration struct { + ID int64 + CreatedAt time.Time + Version int64 +} diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/river_migration.sql.go b/riverdriver/riverdatabasesql/internal/dbsqlc/river_migration.sql.go new file mode 100644 index 00000000..66048d8f --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/river_migration.sql.go @@ -0,0 +1,129 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: river_migration.sql + +package dbsqlc + +import ( + "context" + + "github.com/lib/pq" +) + +const riverMigrationDeleteByVersionMany = `-- name: RiverMigrationDeleteByVersionMany :many +DELETE FROM river_migration +WHERE version = any($1::bigint[]) +RETURNING id, created_at, version +` + +func (q *Queries) RiverMigrationDeleteByVersionMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationDeleteByVersionMany, pq.Array(version)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetAll = `-- name: RiverMigrationGetAll :many +SELECT id, created_at, version +FROM river_migration +ORDER BY version +` + +func (q *Queries) RiverMigrationGetAll(ctx context.Context, db DBTX) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationGetAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationInsert = `-- name: RiverMigrationInsert :one +INSERT INTO river_migration ( + version +) VALUES ( + $1 +) RETURNING id, created_at, version +` + +func (q *Queries) RiverMigrationInsert(ctx context.Context, db DBTX, version int64) (*RiverMigration, error) { + row := db.QueryRowContext(ctx, riverMigrationInsert, version) + var i RiverMigration + err := row.Scan(&i.ID, &i.CreatedAt, &i.Version) + return &i, err +} + +const riverMigrationInsertMany = `-- name: RiverMigrationInsertMany :many +INSERT INTO river_migration ( + version +) +SELECT + unnest($1::bigint[]) +RETURNING id, created_at, version +` + +func (q *Queries) RiverMigrationInsertMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationInsertMany, pq.Array(version)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const tableExists = `-- name: TableExists :one +SELECT CASE WHEN to_regclass($1) IS NULL THEN false + ELSE true END +` + +func (q *Queries) TableExists(ctx context.Context, db DBTX, tableName string) (bool, error) { + row := db.QueryRowContext(ctx, tableExists, tableName) + var column_1 bool + err := row.Scan(&column_1) + return column_1, err +} diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml b/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml new file mode 100644 index 00000000..ec379388 --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml @@ -0,0 +1,24 @@ +version: "2" +sql: + - engine: "postgresql" + queries: + - ../../../riverpgxv5/internal/dbsqlc/river_migration.sql + schema: + - ../../../riverpgxv5/internal/dbsqlc/river_migration.sql + gen: + go: + package: "dbsqlc" + sql_package: "database/sql" + out: "." + emit_exact_table_names: true + emit_methods_with_db_argument: true + emit_result_struct_pointers: true + + overrides: + - db_type: "timestamptz" + go_type: "time.Time" + - db_type: "timestamptz" + go_type: + type: "time.Time" + pointer: true + nullable: true diff --git a/riverdriver/riverdatabasesql/river_database_sql.go b/riverdriver/riverdatabasesql/river_database_sql.go new file mode 100644 index 00000000..8b1136f6 --- /dev/null +++ b/riverdriver/riverdatabasesql/river_database_sql.go @@ -0,0 +1,139 @@ +// Package riverdatabasesql bundles a River driver for Go's built in database/sql. +// +// This is _not_ a fully functional driver, and only supports use through +// rivermigrate for purposes of interacting with migration frameworks like +// Goose. Using it with a River client will panic. +package riverdatabasesql + +import ( + "context" + "database/sql" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/dbsqlc" +) + +// Driver is an implementation of riverdriver.Driver for database/sql. +type Driver struct { + dbPool *sql.DB + queries *dbsqlc.Queries +} + +// New returns a new database/sql River driver for use with River. +// +// It takes an sql.DB to use for use with River. The pool should already be +// configured to use the schema specified in the client's Schema field. The pool +// must not be closed while associated River objects are running. +// +// This is _not_ a fully functional driver, and only supports use through +// rivermigrate for purposes of interacting with migration frameworks like +// Goose. Using it with a River client will panic. +func New(dbPool *sql.DB) *Driver { + return &Driver{dbPool: dbPool, queries: dbsqlc.New()} +} + +func (d *Driver) GetDBPool() *pgxpool.Pool { panic(riverdriver.ErrNotImplemented) } +func (d *Driver) GetExecutor() riverdriver.Executor { + return &Executor{d.dbPool, d.dbPool, dbsqlc.New()} +} + +func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.Executor { + return &Executor{nil, tx, dbsqlc.New()} +} +func (d *Driver) UnwrapTx(tx *sql.Tx) pgx.Tx { panic(riverdriver.ErrNotImplemented) } + +type Executor struct { + dbPool *sql.DB + dbtx dbsqlc.DBTX + queries *dbsqlc.Queries +} + +func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + if e.dbPool == nil { + return nil, riverdriver.ErrSubTxNotSupported + } + + tx, err := e.dbPool.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &ExecutorTx{Executor: Executor{nil, tx, e.queries}, tx: tx}, nil +} + +func (e *Executor) Exec(ctx context.Context, sql string) (struct{}, error) { + _, err := e.dbtx.ExecContext(ctx, sql) + return struct{}{}, interpretError(err) +} + +func (e *Executor) MigrationDeleteByVersionMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationDeleteByVersionMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationGetAll(ctx context.Context) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationGetAll(ctx, e.dbtx) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationInsertMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationInsertMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) TableExists(ctx context.Context, tableName string) (bool, error) { + exists, err := e.queries.TableExists(ctx, e.dbtx, tableName) + return exists, interpretError(err) +} + +type ExecutorTx struct { + Executor + tx *sql.Tx +} + +func (t *ExecutorTx) Commit(ctx context.Context) error { + // unfortunately, `database/sql` does not take a context ... + return t.tx.Commit() +} + +func (t *ExecutorTx) Rollback(ctx context.Context) error { + // unfortunately, `database/sql` does not take a context ... + return t.tx.Rollback() +} + +func interpretError(err error) error { + if errors.Is(err, sql.ErrNoRows) { + return riverdriver.ErrNoRows + } + return err +} + +func mapMigrations(migrations []*dbsqlc.RiverMigration) []*riverdriver.Migration { + if migrations == nil { + return nil + } + + return mapSlice(migrations, func(m *dbsqlc.RiverMigration) *riverdriver.Migration { + return &riverdriver.Migration{ + ID: int(m.ID), + CreatedAt: m.CreatedAt, + Version: int(m.Version), + } + }) +} + +// mapSlice manipulates a slice and transforms it to a slice of another type. +func mapSlice[T any, R any](collection []T, mapFunc func(T) R) []R { + result := make([]R, len(collection)) + + for i, item := range collection { + result[i] = mapFunc(item) + } + + return result +} diff --git a/riverdriver/riverdatabasesql/river_database_sql_test.go b/riverdriver/riverdatabasesql/river_database_sql_test.go new file mode 100644 index 00000000..f97be426 --- /dev/null +++ b/riverdriver/riverdatabasesql/river_database_sql_test.go @@ -0,0 +1,41 @@ +package riverdatabasesql + +import ( + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdriver" +) + +// Verify interface compliance. +var _ riverdriver.Driver[*sql.Tx] = New(nil) + +func TestNew(t *testing.T) { + t.Parallel() + + t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + + dbPool := &sql.DB{} + driver := New(dbPool) + require.Equal(t, dbPool, driver.dbPool) + }) + + t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + + driver := New(nil) + require.Nil(t, driver.dbPool) + }) +} + +func TestInterpretError(t *testing.T) { + t.Parallel() + + require.EqualError(t, interpretError(errors.New("an error")), "an error") + require.ErrorIs(t, interpretError(sql.ErrNoRows), riverdriver.ErrNoRows) + require.NoError(t, interpretError(nil)) +} diff --git a/riverdriver/riverpgxv5/go.mod b/riverdriver/riverpgxv5/go.mod index e6027196..c782fb86 100644 --- a/riverdriver/riverpgxv5/go.mod +++ b/riverdriver/riverpgxv5/go.mod @@ -1,9 +1,12 @@ module github.com/riverqueue/river/riverdriver/riverpgxv5 -go 1.21.0 +go 1.21.4 + +replace github.com/riverqueue/river/riverdriver => ../ require ( github.com/jackc/pgx/v5 v5.5.0 + github.com/riverqueue/river/riverdriver v0.0.0-00010101000000-000000000000 github.com/stretchr/testify v1.8.1 ) diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/db.go b/riverdriver/riverpgxv5/internal/dbsqlc/db.go new file mode 100644 index 00000000..27ee12e2 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/db.go @@ -0,0 +1,25 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/models.go b/riverdriver/riverpgxv5/internal/dbsqlc/models.go new file mode 100644 index 00000000..b42c6819 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "time" +) + +type RiverMigration struct { + ID int64 + CreatedAt time.Time + Version int64 +} diff --git a/internal/dbsqlc/river_migration.sql b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql similarity index 76% rename from internal/dbsqlc/river_migration.sql rename to riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql index a9ce74f8..fa7a65ba 100644 --- a/internal/dbsqlc/river_migration.sql +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql @@ -5,7 +5,7 @@ CREATE TABLE river_migration( CONSTRAINT version CHECK (version >= 1) ); --- name: RiverMigrationDeleteByVersionMany :one +-- name: RiverMigrationDeleteByVersionMany :many DELETE FROM river_migration WHERE version = any(@version::bigint[]) RETURNING *; @@ -28,4 +28,8 @@ INSERT INTO river_migration ( ) SELECT unnest(@version::bigint[]) -RETURNING *; \ No newline at end of file +RETURNING *; + +-- name: TableExists :one +SELECT CASE WHEN to_regclass(@table_name) IS NULL THEN false + ELSE true END; \ No newline at end of file diff --git a/internal/dbsqlc/river_migration.sql.go b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql.go similarity index 71% rename from internal/dbsqlc/river_migration.sql.go rename to riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql.go index 2bffffe8..583b8d59 100644 --- a/internal/dbsqlc/river_migration.sql.go +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql.go @@ -9,17 +9,30 @@ import ( "context" ) -const riverMigrationDeleteByVersionMany = `-- name: RiverMigrationDeleteByVersionMany :one +const riverMigrationDeleteByVersionMany = `-- name: RiverMigrationDeleteByVersionMany :many DELETE FROM river_migration WHERE version = any($1::bigint[]) RETURNING id, created_at, version ` -func (q *Queries) RiverMigrationDeleteByVersionMany(ctx context.Context, db DBTX, version []int64) (*RiverMigration, error) { - row := db.QueryRow(ctx, riverMigrationDeleteByVersionMany, version) - var i RiverMigration - err := row.Scan(&i.ID, &i.CreatedAt, &i.Version) - return &i, err +func (q *Queries) RiverMigrationDeleteByVersionMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigration, error) { + rows, err := db.Query(ctx, riverMigrationDeleteByVersionMany, version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const riverMigrationGetAll = `-- name: RiverMigrationGetAll :many @@ -91,3 +104,15 @@ func (q *Queries) RiverMigrationInsertMany(ctx context.Context, db DBTX, version } return items, nil } + +const tableExists = `-- name: TableExists :one +SELECT CASE WHEN to_regclass($1) IS NULL THEN false + ELSE true END +` + +func (q *Queries) TableExists(ctx context.Context, db DBTX, tableName string) (bool, error) { + row := db.QueryRow(ctx, tableExists, tableName) + var column_1 bool + err := row.Scan(&column_1) + return column_1, err +} diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml b/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml new file mode 100644 index 00000000..ace0bb62 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml @@ -0,0 +1,24 @@ +version: "2" +sql: + - engine: "postgresql" + queries: + - river_migration.sql + schema: + - river_migration.sql + gen: + go: + package: "dbsqlc" + sql_package: "pgx/v5" + out: "." + emit_exact_table_names: true + emit_methods_with_db_argument: true + emit_result_struct_pointers: true + + overrides: + - db_type: "timestamptz" + go_type: "time.Time" + - db_type: "timestamptz" + go_type: + type: "time.Time" + pointer: true + nullable: true diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index 9c803088..0dc8cc3a 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -6,21 +6,27 @@ package riverpgxv5 import ( + "context" + "errors" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5/internal/dbsqlc" ) // Driver is an implementation of riverdriver.Driver for Pgx v5. type Driver struct { - dbPool *pgxpool.Pool + dbPool *pgxpool.Pool + queries *dbsqlc.Queries } -// New returns a new Pgx v5 River driver for use with river.Client. +// New returns a new Pgx v5 River driver for use with River. // // It takes a pgxpool.Pool to use for use with River. The pool should already be // configured to use the schema specified in the client's Schema field. The pool -// must not be closed while the associated client is running (not until graceful -// shutdown has completed). +// must not be closed while associated River objects are running. // // The database pool may be nil. If it is, a client that it's sent into will not // be able to start up (calls to Start will error) and the Insert and InsertMany @@ -29,8 +35,98 @@ type Driver struct { // in testing so that inserts can be performed and verified on a test // transaction that will be rolled back. func New(dbPool *pgxpool.Pool) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{dbPool: dbPool, queries: dbsqlc.New()} +} + +func (d *Driver) GetDBPool() *pgxpool.Pool { return d.dbPool } +func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, dbsqlc.New()} } +func (d *Driver) UnwrapExecutor(tx pgx.Tx) riverdriver.Executor { return &Executor{tx, dbsqlc.New()} } +func (d *Driver) UnwrapTx(tx pgx.Tx) pgx.Tx { return tx } + +type Executor struct { + dbtx interface { + dbsqlc.DBTX + Begin(ctx context.Context) (pgx.Tx, error) + } + queries *dbsqlc.Queries +} + +func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + tx, err := e.dbtx.Begin(ctx) + if err != nil { + return nil, err + } + return &ExecutorTx{Executor: Executor{tx, e.queries}, tx: tx}, nil +} + +func (e *Executor) Exec(ctx context.Context, sql string) (struct{}, error) { + _, err := e.dbtx.Exec(ctx, sql) + return struct{}{}, interpretError(err) +} + +func (e *Executor) MigrationDeleteByVersionMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationDeleteByVersionMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationGetAll(ctx context.Context) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationGetAll(ctx, e.dbtx) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationInsertMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationInsertMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) } -func (d *Driver) GetDBPool() *pgxpool.Pool { return d.dbPool } -func (d *Driver) UnwrapTx(tx pgx.Tx) pgx.Tx { return tx } +func (e *Executor) TableExists(ctx context.Context, tableName string) (bool, error) { + exists, err := e.queries.TableExists(ctx, e.dbtx, tableName) + return exists, interpretError(err) +} + +type ExecutorTx struct { + Executor + tx pgx.Tx +} + +func (t *ExecutorTx) Commit(ctx context.Context) error { + return t.tx.Commit(ctx) +} + +func (t *ExecutorTx) Rollback(ctx context.Context) error { + return t.tx.Rollback(ctx) +} + +func interpretError(err error) error { + if errors.Is(err, pgx.ErrNoRows) { + return riverdriver.ErrNoRows + } + return err +} + +func mapMigrations(migrations []*dbsqlc.RiverMigration) []*riverdriver.Migration { + if migrations == nil { + return nil + } + + return mapSlice(migrations, func(m *dbsqlc.RiverMigration) *riverdriver.Migration { + return &riverdriver.Migration{ + ID: int(m.ID), + CreatedAt: m.CreatedAt, + Version: int(m.Version), + } + }) +} + +// mapSlice manipulates a slice and transforms it to a slice of another type. +func mapSlice[T any, R any](collection []T, mapFunc func(T) R) []R { + result := make([]R, len(collection)) + + for i, item := range collection { + result[i] = mapFunc(item) + } + + return result +} diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go b/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go index 16e5556c..229b735c 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go @@ -1,21 +1,42 @@ package riverpgxv5 import ( + "errors" "testing" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdriver" ) +// Verify interface compliance. +var _ riverdriver.Driver[pgx.Tx] = New(nil) + func TestNew(t *testing.T) { + t.Parallel() + t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + dbPool := &pgxpool.Pool{} driver := New(dbPool) require.Equal(t, dbPool, driver.dbPool) }) t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + driver := New(nil) require.Nil(t, driver.dbPool) }) } + +func TestInterpretError(t *testing.T) { + t.Parallel() + + require.EqualError(t, interpretError(errors.New("an error")), "an error") + require.ErrorIs(t, interpretError(pgx.ErrNoRows), riverdriver.ErrNoRows) + require.NoError(t, interpretError(nil)) +} diff --git a/rivermigrate/example_migrate_database_sql_test.go b/rivermigrate/example_migrate_database_sql_test.go new file mode 100644 index 00000000..20e26960 --- /dev/null +++ b/rivermigrate/example_migrate_database_sql_test.go @@ -0,0 +1,72 @@ +package rivermigrate_test + +import ( + "context" + "database/sql" + "fmt" + "strings" + + _ "github.com/jackc/pgx/v5/stdlib" + + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver/riverdatabasesql" + "github.com/riverqueue/river/rivermigrate" +) + +// Example_migrateDatabaseSQL demonstrates the use of River's Go migration API +// through Go's built-in database/sql package. +func Example_migrateDatabaseSQL() { + ctx := context.Background() + + dbPool, err := sql.Open("pgx", riverinternaltest.DatabaseURL("river_testdb_example")) + if err != nil { + panic(err) + } + defer dbPool.Close() + + tx, err := dbPool.BeginTx(ctx, nil) + if err != nil { + panic(err) + } + defer tx.Rollback() + + migrator := rivermigrate.New(riverdatabasesql.New(dbPool), nil) + + // Our test database starts with a full River schema. Drop it so that we can + // demonstrate working migrations. This isn't necessary outside this test. + dropRiverSchema(ctx, migrator, tx) + + printVersions := func(res *rivermigrate.MigrateResult) { + for _, version := range res.Versions { + fmt.Printf("Migrated [%s] version %d\n", strings.ToUpper(string(res.Direction)), version.Version) + } + } + + // Migrate to version 3. An actual call may want to omit all MigrateOpts, + // which will default to applying all available up migrations. + res, err := migrator.MigrateTx(ctx, tx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ + TargetVersion: 3, + }) + if err != nil { + panic(err) + } + printVersions(res) + + // Migrate down by three steps. Down migrating defaults to running only one + // step unless overridden by an option like MaxSteps or TargetVersion. + res, err = migrator.MigrateTx(ctx, tx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ + MaxSteps: 3, + }) + if err != nil { + panic(err) + } + printVersions(res) + + // Output: + // Migrated [UP] version 1 + // Migrated [UP] version 2 + // Migrated [UP] version 3 + // Migrated [DOWN] version 3 + // Migrated [DOWN] version 2 + // Migrated [DOWN] version 1 +} diff --git a/rivermigrate/example_migrate_test.go b/rivermigrate/example_migrate_test.go index 8b8bee0e..9bd26c3b 100644 --- a/rivermigrate/example_migrate_test.go +++ b/rivermigrate/example_migrate_test.go @@ -3,35 +3,15 @@ package rivermigrate_test import ( "context" "fmt" - "sort" "strings" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - "github.com/riverqueue/river" "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/rivermigrate" ) -type SortArgs struct { - // Strings is a slice of strings to sort. - Strings []string `json:"strings"` -} - -func (SortArgs) Kind() string { return "sort" } - -type SortWorker struct { - river.WorkerDefaults[SortArgs] -} - -func (w *SortWorker) Work(ctx context.Context, job *river.Job[SortArgs]) error { - sort.Strings(job.Args.Strings) - fmt.Printf("Sorted strings: %+v\n", job.Args.Strings) - return nil -} - // Example_migrate demonstrates the use of River's Go migration API by migrating // up and down. func Example_migrate() { @@ -81,11 +61,6 @@ func Example_migrate() { } printVersions(res) - // Roll back all changes applied so our test database is left unaffected. - if err := tx.Rollback(ctx); err != nil { - panic(err) - } - // Output: // Migrated [UP] version 1 // Migrated [UP] version 2 @@ -95,7 +70,7 @@ func Example_migrate() { // Migrated [DOWN] version 1 } -func dropRiverSchema(ctx context.Context, migrator *rivermigrate.Migrator[pgx.Tx], tx pgx.Tx) { +func dropRiverSchema[TTx any](ctx context.Context, migrator *rivermigrate.Migrator[TTx], tx TTx) { _, err := migrator.MigrateTx(ctx, tx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ TargetVersion: -1, }) diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go index d3d54151..c14b54f7 100644 --- a/rivermigrate/river_migrate.go +++ b/rivermigrate/river_migrate.go @@ -5,7 +5,6 @@ package rivermigrate import ( "context" "embed" - "errors" "fmt" "io" "io/fs" @@ -17,10 +16,6 @@ import ( "strings" "time" - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/riverqueue/river/internal/baseservice" "github.com/riverqueue/river/internal/dbsqlc" "github.com/riverqueue/river/internal/util/dbutil" @@ -67,8 +62,9 @@ type Migrator[TTx any] struct { // New returns a new migrator with the given database driver and configuration. // The config parameter may be omitted as nil. // -// Currently only one driver is supported, which is Pgx v5. See package -// riverpgxv5. +// Two drivers are supported for migrations, one for Pgx v5 and one for the +// built-in database/sql package for use with migration frameworks like Goose. +// See packages riverpgxv5 and riverdatabasesql respectively. // // The function takes a generic parameter TTx representing a transaction type, // but it can be omitted because it'll generally always be inferred from the @@ -155,8 +151,7 @@ type MigrateVersion struct { Version int } -func migrateVersionToInt(version MigrateVersion) int { return version.Version } -func migrateVersionToInt64(version MigrateVersion) int64 { return int64(version.Version) } +func migrateVersionToInt(version MigrateVersion) int { return version.Version } type Direction string @@ -180,12 +175,12 @@ const ( // // handle error // } func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { - return dbutil.WithTxV(ctx, m.driver.GetDBPool(), func(ctx context.Context, tx pgx.Tx) (*MigrateResult, error) { + return dbutil.WithExecutorTxV(ctx, m.driver.GetExecutor(), func(ctx context.Context, tx riverdriver.ExecutorTx) (*MigrateResult, error) { switch direction { case DirectionDown: - return m.migrateDownTx(ctx, tx, direction, opts) + return m.migrateDown(ctx, tx, direction, opts) case DirectionUp: - return m.migrateUpTx(ctx, tx, direction, opts) + return m.migrateUp(ctx, tx, direction, opts) } panic("invalid direction: " + direction) @@ -213,22 +208,22 @@ func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts * func (m *Migrator[TTx]) MigrateTx(ctx context.Context, tx TTx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { switch direction { case DirectionDown: - return m.migrateDownTx(ctx, m.driver.UnwrapTx(tx), direction, opts) + return m.migrateDown(ctx, m.driver.UnwrapExecutor(tx), direction, opts) case DirectionUp: - return m.migrateUpTx(ctx, m.driver.UnwrapTx(tx), direction, opts) + return m.migrateUp(ctx, m.driver.UnwrapExecutor(tx), direction, opts) } panic("invalid direction: " + direction) } -// migrateDownTx runs down migrations. -func (m *Migrator[TTx]) migrateDownTx(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { - existingMigrations, err := m.existingMigrations(ctx, tx) +// migrateDown runs down migrations. +func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + existingMigrations, err := m.existingMigrations(ctx, exec) if err != nil { return nil, err } existingMigrationsMap := sliceutil.KeyBy(existingMigrations, - func(m *dbsqlc.RiverMigration) (int, struct{}) { return int(m.Version), struct{}{} }) + func(m *riverdriver.Migration) (int, struct{}) { return m.Version, struct{}{} }) targetMigrations := maps.Clone(m.migrations) for version := range targetMigrations { @@ -240,7 +235,7 @@ func (m *Migrator[TTx]) migrateDownTx(ctx context.Context, tx pgx.Tx, direction sortedTargetMigrations := maputil.Values(targetMigrations) slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return b.Version - a.Version }) // reverse order - res, err := m.applyMigrations(ctx, tx, direction, opts, sortedTargetMigrations) + res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) if err != nil { return nil, err } @@ -259,34 +254,34 @@ func (m *Migrator[TTx]) migrateDownTx(ctx context.Context, tx pgx.Tx, direction return res, nil } - if _, err := m.queries.RiverMigrationDeleteByVersionMany(ctx, tx, sliceutil.Map(res.Versions, migrateVersionToInt64)); err != nil { + if _, err := exec.MigrationDeleteByVersionMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { return nil, fmt.Errorf("error deleting migration rows for versions %+v: %w", res.Versions, err) } return res, nil } -// migrateUpTx runs up migrations. -func (m *Migrator[TTx]) migrateUpTx(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { - existingMigrations, err := m.existingMigrations(ctx, tx) +// migrateUp runs up migrations. +func (m *Migrator[TTx]) migrateUp(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + existingMigrations, err := m.existingMigrations(ctx, exec) if err != nil { return nil, err } targetMigrations := maps.Clone(m.migrations) for _, migrateRow := range existingMigrations { - delete(targetMigrations, int(migrateRow.Version)) + delete(targetMigrations, migrateRow.Version) } sortedTargetMigrations := maputil.Values(targetMigrations) slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return a.Version - b.Version }) - res, err := m.applyMigrations(ctx, tx, direction, opts, sortedTargetMigrations) + res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) if err != nil { return nil, err } - if _, err := m.queries.RiverMigrationInsertMany(ctx, tx, sliceutil.Map(res.Versions, migrateVersionToInt64)); err != nil { + if _, err := exec.MigrationInsertMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { return nil, fmt.Errorf("error inserting migration rows for versions %+v: %w", res.Versions, err) } @@ -295,7 +290,7 @@ func (m *Migrator[TTx]) migrateUpTx(ctx context.Context, tx pgx.Tx, direction Di // Common code shared between the up and down migration directions that walks // through each target migration and applies it, logging appropriately. -func (m *Migrator[TTx]) applyMigrations(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts, sortedTargetMigrations []*migrationBundle) (*MigrateResult, error) { +func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, sortedTargetMigrations []*migrationBundle) (*MigrateResult, error) { if opts == nil { opts = &MigrateOpts{} } @@ -355,7 +350,7 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, tx pgx.Tx, directio slog.Int("version", versionBundle.Version), ) - _, err := tx.Exec(ctx, sql) + _, err := exec.Exec(ctx, sql) if err != nil { return nil, fmt.Errorf("error applying version %03d [%s]: %w", versionBundle.Version, strings.ToUpper(string(direction)), err) @@ -377,26 +372,18 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, tx pgx.Tx, directio // the `river_migration` table not existing yet. (The subtransaction is needed // because otherwise the existing transaction would become aborted on an // unsuccessful `river_migration` check.) -func (m *Migrator[TTx]) existingMigrations(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { - // We start another inner transaction here because in case this is the first - // ever migration run, the transaction may become aborted if `river_migration` - // doesn't exist, a condition which we must handle gracefully. - migrations, err := dbutil.WithTxV(ctx, tx, func(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { - migrations, err := m.queries.RiverMigrationGetAll(ctx, tx) - if err != nil { - return nil, fmt.Errorf("error getting current migrate rows: %w", err) - } - return migrations, nil - }) +func (m *Migrator[TTx]) existingMigrations(ctx context.Context, exec riverdriver.Executor) ([]*riverdriver.Migration, error) { + exists, err := exec.TableExists(ctx, "river_migration") if err != nil { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - if pgErr.Code == pgerrcode.UndefinedTable && strings.Contains(pgErr.Message, "river_migration") { - return nil, nil - } - } + return nil, fmt.Errorf("error checking if `%s` exists: %w", "river_migration", err) + } + if !exists { + return nil, nil + } - return nil, err + migrations, err := exec.MigrationGetAll(ctx) + if err != nil { + return nil, fmt.Errorf("error getting existing migrations: %w", err) } return migrations, nil diff --git a/rivermigrate/river_migrate_test.go b/rivermigrate/river_migrate_test.go index 7cdb1ac9..f5633834 100644 --- a/rivermigrate/river_migrate_test.go +++ b/rivermigrate/river_migrate_test.go @@ -2,18 +2,23 @@ package rivermigrate import ( "context" + "database/sql" + "log/slog" "slices" "testing" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/require" - "github.com/riverqueue/river/internal/dbsqlc" "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/internal/util/dbutil" "github.com/riverqueue/river/internal/util/sliceutil" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverdatabasesql" "github.com/riverqueue/river/riverdriver/riverpgxv5" ) @@ -44,13 +49,13 @@ var ( func TestMigrator(t *testing.T) { t.Parallel() - var ( - ctx = context.Background() - queries = dbsqlc.New() - ) + ctx := context.Background() type testBundle struct { - tx pgx.Tx + dbPool *pgxpool.Pool + driver *riverpgxv5.Driver + logger *slog.Logger + tx pgx.Tx } setup := func(t *testing.T) (*Migrator[pgx.Tx], *testBundle) { @@ -62,24 +67,45 @@ func TestMigrator(t *testing.T) { // we use test DBs instead of test transactions, but this could be // changed to test transactions as long as test cases were made to run // non-parallel. - testDB := riverinternaltest.TestDB(ctx, t) + dbPool := riverinternaltest.TestDB(ctx, t) // Despite being in an isolated database, we still start a transaction // because we don't want schema changes we make to persist. - tx, err := testDB.Begin(ctx) + tx, err := dbPool.Begin(ctx) require.NoError(t, err) t.Cleanup(func() { _ = tx.Rollback(ctx) }) bundle := &testBundle{ - tx: tx, + dbPool: dbPool, + driver: riverpgxv5.New(dbPool), + logger: riverinternaltest.Logger(t), + tx: tx, } - migrator := New(riverpgxv5.New(testDB), nil) + migrator := New(bundle.driver, &Config{Logger: bundle.logger}) migrator.migrations = riverMigrationsWithtestVersionsMap return migrator, bundle } + // Gets a migrator using the driver for `database/sql`. + setupDatabaseSQLMigrator := func(t *testing.T, bundle *testBundle) (*Migrator[*sql.Tx], *sql.Tx) { + t.Helper() + + stdPool := stdlib.OpenDBFromPool(bundle.dbPool) + t.Cleanup(func() { require.NoError(t, stdPool.Close()) }) + + tx, err := stdPool.BeginTx(ctx, nil) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, tx.Rollback()) }) + + driver := riverdatabasesql.New(stdPool) + migrator := New(driver, &Config{Logger: bundle.logger}) + migrator.migrations = riverMigrationsWithtestVersionsMap + + return migrator, tx + } + t.Run("MigrateDownDefault", func(t *testing.T) { t.Parallel() @@ -135,10 +161,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion, riverMigrationsWithTestVersionsMaxVersion - 1}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-2), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") require.Error(t, err) @@ -156,10 +182,26 @@ func TestMigrator(t *testing.T) { require.NoError(t, err) require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) + }) + + t.Run("MigrateDownWithDatabaseSQLDriver", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + migrator, tx := setupDatabaseSQLMigrator(t, bundle) + + res, err := migrator.MigrateTx(ctx, tx, DirectionDown, &MigrateOpts{MaxSteps: 1}) + require.NoError(t, err) + require.Equal(t, []int{3}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := migrator.driver.UnwrapExecutor(tx).MigrationGetAll(ctx) + require.NoError(t, err) + require.Equal(t, seqOneTo(2), + sliceutil.Map(migrations, migrationToInt)) }) t.Run("MigrateDownWithTargetVersion", func(t *testing.T) { @@ -175,10 +217,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{5, 4}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") require.Error(t, err) @@ -242,10 +284,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion - 1, riverMigrationsWithTestVersionsMaxVersion}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") require.NoError(t, err) @@ -258,10 +300,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, DirectionUp, res.Direction) require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") require.NoError(t, err) @@ -278,10 +320,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion - 1}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-1), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) // Column `name` is only added in the second test version. err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") @@ -304,10 +346,26 @@ func TestMigrator(t *testing.T) { require.NoError(t, err) require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) + }) + + t.Run("MigrateUpWithDatabaseSQLDriver", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + migrator, tx := setupDatabaseSQLMigrator(t, bundle) + + res, err := migrator.MigrateTx(ctx, tx, DirectionUp, &MigrateOpts{MaxSteps: 1}) + require.NoError(t, err) + require.Equal(t, []int{riverMigrationsMaxVersion + 1}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := migrator.driver.UnwrapExecutor(tx).MigrationGetAll(ctx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsMaxVersion+1), + sliceutil.Map(migrations, migrationToInt)) }) t.Run("MigrateUpWithTargetVersion", func(t *testing.T) { @@ -320,9 +378,9 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{4, 5}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) - require.Equal(t, seqOneTo(5), sliceutil.Map(migrations, riverMigrationToInt)) + require.Equal(t, seqOneTo(5), sliceutil.Map(migrations, migrationToInt)) }) t.Run("MigrateUpWithTargetVersionInvalid", func(t *testing.T) { @@ -354,7 +412,7 @@ func dbExecError(ctx context.Context, executor dbutil.Executor, sql string) erro }) } -func riverMigrationToInt(r *dbsqlc.RiverMigration) int { return int(r.Version) } +func migrationToInt(r *riverdriver.Migration) int { return r.Version } func seqOneTo(max int) []int { seq := make([]int, max)