From a27cf5dbe45fc3cffc8081aff755d8d5a258910a Mon Sep 17 00:00:00 2001 From: Brandur Date: Sat, 2 Dec 2023 18:46:32 -0800 Subject: [PATCH] Support for `database/sql` in migrations + framework for multi-driver River Here, add a new minimal driver called `riverdriver/riversql` that supports Go's built-in `database/sql` package, but only for purposes of migrations. The idea here is to fully complete #57 by providing a way of making `rivermigrate` interoperable with Go migration frameworks that support Go-based migrations like Goose, which provides hooks for `*sql.Tx` [1] rather than pgx. `riverdriver/riversql` is not a full driver and is only meant to be used with `rivermigrate`. We document this clearly in a number of places. To make a multi-driver world possible with River, we have to start the work of building a platform that does more than `riverpgxv5`'s "cheat" workaround. This works by having each driver implement specific database operations like `MigrationGetAll`, which target their wrapped database package of choice. This is accomplished by having each driver bundle in its own sqlc that targets its package. So `riverpgxv5` has an `sqlc.yaml` that targets `pgx/v5`, while `riversql` has one that targets `database/sql`. There's some `sqlc.yaml` duplication involved here, but luckily both drivers can share a `river_migration.sql` file that contains all queries involved, so you only need to change one place. `river_migration.sql` also migrates entirely out of the main `./internal/dbsqlc`. The idea here is that eventually `./internal/dbsqlc` will disappear completely, usurped entirely by driver-specific versions. As this is done, all references to `pgx` will disappear from the top-level package. There are some complications here to figure out like `LISTEN`/`NOTIFY` though, and I'm not clear whether `database/sql` could ever become a fully functional driver as it might be missing some needed functionality (e.g. subtransactions are still not supported after talking about them for ten f*ing years [2]. However, even if it's not, the system would let us support other fully functional packages or future major versions of pgx (or even past ones like `pgx/v4` if there's demand). `river/riverdriver` becomes a package as it now has types in it that need to be referenced by driver implementations, and this would otherwise not be possible without introducing a circular dependency. Notably, this development branch has to use some `go.mod` `replace` directives to demonstrate that it works correctly. If we go this direction, we'll need to break it into chunks to release it without them: 1. Break out changes to `river/riverdriver`. Tag and release it. 2. Break out changes to `riverdriver/river*` drivers. Have them target the release in (1), comment out `replace`s, then tag and release them. 3. Target the remaining River changes to the releases in (1) and (2), comment out `replace`s, then tag and release the top-level driver. Unfortunately future deep incisions to drivers will require similar gymnastics, but I don't think there's a way around it (we already have this process except it's currently two steps instead of three). The hope is that these will change relatively rarely, so it won't be too painful. [1] https://github.com/pressly/goose#go-migrations [2] https://github.com/golang/go/issues/7898 --- .github/workflows/ci.yml | 44 +++++- .golangci.yml | 3 + CHANGELOG.md | 8 + Makefile | 14 +- docs/development.md | 2 + go.mod | 11 +- go.sum | 4 +- internal/dbsqlc/models.go | 6 - internal/dbsqlc/sqlc.yaml | 2 - .../riverinternaltest/riverinternaltest.go | 29 +++- internal/util/dbutil/db_util.go | 33 +++++ internal/util/dbutil/db_util_test.go | 35 +++++ riverdriver/go.mod | 14 ++ riverdriver/go.sum | 28 ++++ riverdriver/river_driver_interface.go | 95 ++++++++++++ riverdriver/river_driver_interface_test.go | 9 -- riverdriver/riverdatabasesql/go.mod | 25 ++++ riverdriver/riverdatabasesql/go.sum | 44 ++++++ .../riverdatabasesql/internal/dbsqlc/db.go | 24 +++ .../internal/dbsqlc/models.go | 15 ++ .../internal/dbsqlc/river_migration.sql.go | 129 ++++++++++++++++ .../internal/dbsqlc/sqlc.yaml | 24 +++ .../riverdatabasesql/river_database_sql.go | 139 ++++++++++++++++++ .../river_database_sql_test.go | 41 ++++++ riverdriver/riverpgxv5/go.mod | 5 +- riverdriver/riverpgxv5/internal/dbsqlc/db.go | 25 ++++ .../riverpgxv5/internal/dbsqlc/models.go | 15 ++ .../internal}/dbsqlc/river_migration.sql | 8 +- .../internal}/dbsqlc/river_migration.sql.go | 37 ++++- .../riverpgxv5/internal/dbsqlc/sqlc.yaml | 24 +++ riverdriver/riverpgxv5/river_pgx_v5_driver.go | 110 +++++++++++++- .../riverpgxv5/river_pgx_v5_driver_test.go | 21 +++ .../example_migrate_database_sql_test.go | 72 +++++++++ rivermigrate/example_migrate_test.go | 27 +--- rivermigrate/river_migrate.go | 79 +++++----- rivermigrate/river_migrate_test.go | 112 ++++++++++---- 36 files changed, 1165 insertions(+), 148 deletions(-) create mode 100644 riverdriver/go.mod create mode 100644 riverdriver/go.sum create mode 100644 riverdriver/riverdatabasesql/go.mod create mode 100644 riverdriver/riverdatabasesql/go.sum create mode 100644 riverdriver/riverdatabasesql/internal/dbsqlc/db.go create mode 100644 riverdriver/riverdatabasesql/internal/dbsqlc/models.go create mode 100644 riverdriver/riverdatabasesql/internal/dbsqlc/river_migration.sql.go create mode 100644 riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml create mode 100644 riverdriver/riverdatabasesql/river_database_sql.go create mode 100644 riverdriver/riverdatabasesql/river_database_sql_test.go create mode 100644 riverdriver/riverpgxv5/internal/dbsqlc/db.go create mode 100644 riverdriver/riverpgxv5/internal/dbsqlc/models.go rename {internal => riverdriver/riverpgxv5/internal}/dbsqlc/river_migration.sql (76%) rename {internal => riverdriver/riverpgxv5/internal}/dbsqlc/river_migration.sql.go (71%) create mode 100644 riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml create mode 100644 rivermigrate/example_migrate_database_sql_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b08e855..46bc6165 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,7 +72,15 @@ jobs: env: TEST_DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_testdb?sslmode=disable - - name: Test riverpgxv5 + - name: Test riverdriver + working-directory: ./riverdriver + run: go test -race ./... + + - name: Test riverdriver/riverdatabasesql + working-directory: ./riverdriver/riverdatabasesql + run: go test -race ./... + + - name: Test riverdriver/riverpgxv5 working-directory: ./riverdriver/riverpgxv5 run: go test -race ./... @@ -117,10 +125,13 @@ jobs: golangci: name: lint runs-on: ubuntu-latest + env: + GOLANGCI_LINT_VERSION: v1.55.2 permissions: contents: read # allow read access to pull request. Use with `only-new-issues` option. pull-requests: read + steps: - uses: actions/setup-go@v4 with: @@ -130,13 +141,33 @@ jobs: - name: Checkout uses: actions/checkout@v3 - - name: golangci-lint + - name: Lint + uses: golangci/golangci-lint-action@v3 + with: + only-new-issues: true # Optional: show only new issues if it's a pull request. The default value is `false`. + version: ${{ env.GOLANGCI_LINT_VERSION }} + working-directory: . + + - name: Lint riverdriver + uses: golangci/golangci-lint-action@v3 + with: + only-new-issues: true # Optional: show only new issues if it's a pull request. The default value is `false`. + version: ${{ env.GOLANGCI_LINT_VERSION }} + working-directory: ./riverdriver + + - name: Lint riverdriver/riverdatabasesql uses: golangci/golangci-lint-action@v3 with: - # Optional: show only new issues if it's a pull request. The default value is `false`. - only-new-issues: true + only-new-issues: true # Optional: show only new issues if it's a pull request. The default value is `false`. + version: ${{ env.GOLANGCI_LINT_VERSION }} + working-directory: ./riverdriver/riverdatabasesql - version: v1.55.2 + - name: Lint riverdriver/riverpgxv5 + uses: golangci/golangci-lint-action@v3 + with: + only-new-issues: true # Optional: show only new issues if it's a pull request. The default value is `false`. + version: ${{ env.GOLANGCI_LINT_VERSION }} + working-directory: ./riverdriver/riverpgxv5 producer_sample: runs-on: ubuntu-latest @@ -204,7 +235,6 @@ jobs: sqlc-version: "1.22.0" - name: Run sqlc diff - working-directory: ./internal/dbsqlc run: | echo "Please make sure that all sqlc changes are checked in!" - sqlc diff + make verify diff --git a/.golangci.yml b/.golangci.yml index f6d0c697..c1bf07e9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -64,6 +64,9 @@ linters-settings: - Default - Prefix(github.com/riverqueue) + gomoddirectives: + replace-local: true + gosec: excludes: - G404 # use of non-crypto random; overly broad for our use case diff --git a/CHANGELOG.md b/CHANGELOG.md index af3feb7c..e103c1eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added `riverdriver/riverdatabasesql` driver to enable River Go migrations through Go's built in `database/sql` package. [PR #98](https://github.com/riverqueue/river/pull/98). + +### Changed + +- `riverdriver` becomes its own submodule. It contains types that `riverdriver/riverdatabasesql` and `riverdriver/riverpgxv5` need to reference. [PR #98](https://github.com/riverqueue/river/pull/98). + ## [0.0.12] - 2023-12-02 ### Added diff --git a/Makefile b/Makefile index f06ffe8b..ab10a7a4 100644 --- a/Makefile +++ b/Makefile @@ -4,4 +4,16 @@ generate: generate/sqlc .PHONY: generate/sqlc generate/sqlc: - cd internal/dbsqlc && sqlc generate \ No newline at end of file + cd internal/dbsqlc && sqlc generate + cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc generate + cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc generate + +.PHONY: verify +verify: +verify: verify/sqlc + +.PHONY: verify/sqlc +verify/sqlc: + cd internal/dbsqlc && sqlc diff + cd riverdriver/riverdatabasesql/internal/dbsqlc && sqlc diff + cd riverdriver/riverpgxv5/internal/dbsqlc && sqlc diff \ No newline at end of file diff --git a/docs/development.md b/docs/development.md index aa5ed01d..a53fa46a 100644 --- a/docs/development.md +++ b/docs/development.md @@ -32,7 +32,9 @@ queries. After changing an sqlc `.sql` file, generate Go with: ```shell git checkout master && git pull --rebase VERSION=v0.0.x +git tag riverdriver/VERSION -m "release riverdriver/VERSION" git tag riverdriver/riverpgxv5/$VERSION -m "release riverdriver/riverpgxv5/$VERSION" +git tag riverdriver/riverdatabasesql/$VERSION -m "release riverdriver/riverdatabasesql/$VERSION" git tag $VERSION git push --tags ``` diff --git a/go.mod b/go.mod index 0c4db7ab..d9b88a5a 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,21 @@ module github.com/riverqueue/river -go 1.21.0 +go 1.21.4 -// replace github.com/riverqueue/river/riverdriver/riverpgxv5 => ./riverdriver/riverpgxv5 +replace github.com/riverqueue/river/riverdriver => ./riverdriver + +replace github.com/riverqueue/river/riverdriver/riverpgxv5 => ./riverdriver/riverpgxv5 + +replace github.com/riverqueue/river/riverdriver/riverdatabasesql => ./riverdriver/riverdatabasesql require ( github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa github.com/jackc/pgx/v5 v5.5.0 github.com/jackc/puddle/v2 v2.2.1 github.com/oklog/ulid/v2 v2.1.0 + github.com/riverqueue/river/riverdriver v0.0.0-00010101000000-000000000000 github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 + github.com/riverqueue/river/riverdriver/riverdatabasesql v0.0.0-00010101000000-000000000000 github.com/robfig/cron/v3 v3.0.1 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 @@ -23,6 +29,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/crypto v0.15.0 // indirect diff --git a/go.sum b/go.sum index 874971f7..21e9a72d 100644 --- a/go.sum +++ b/go.sum @@ -18,13 +18,13 @@ github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/oklog/ulid/v2 v2.1.0 h1:+9lhoxAP56we25tyYETBBY1YLA2SaoLvUFgrP2miPJU= github.com/oklog/ulid/v2 v2.1.0/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 h1:mcDBnqwzEXY9WDOwbkd8xmFdSr/H6oHb1F3NCNCmLDY= -github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12/go.mod h1:k6hsPkW9Fl3qURzyLHbvxUCqWDpit0WrZ3oEaKezD3E= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= diff --git a/internal/dbsqlc/models.go b/internal/dbsqlc/models.go index 701196e5..0a371bcf 100644 --- a/internal/dbsqlc/models.go +++ b/internal/dbsqlc/models.go @@ -82,9 +82,3 @@ type RiverLeader struct { LeaderID string Name string } - -type RiverMigration struct { - ID int64 - CreatedAt time.Time - Version int64 -} diff --git a/internal/dbsqlc/sqlc.yaml b/internal/dbsqlc/sqlc.yaml index a2820edb..04556ee7 100644 --- a/internal/dbsqlc/sqlc.yaml +++ b/internal/dbsqlc/sqlc.yaml @@ -4,11 +4,9 @@ sql: queries: - river_job.sql - river_leader.sql - - river_migration.sql schema: - river_job.sql - river_leader.sql - - river_migration.sql gen: go: package: "dbsqlc" diff --git a/internal/riverinternaltest/riverinternaltest.go b/internal/riverinternaltest/riverinternaltest.go index cc1edda4..fa51e835 100644 --- a/internal/riverinternaltest/riverinternaltest.go +++ b/internal/riverinternaltest/riverinternaltest.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "log/slog" + "net/url" "os" "runtime" "sync" @@ -52,19 +53,39 @@ func BaseServiceArchetype(tb testing.TB) *baseservice.Archetype { } func DatabaseConfig(databaseName string) *pgxpool.Config { - databaseURL := valutil.ValOrDefault(os.Getenv("TEST_DATABASE_URL"), "postgres:///river_testdb?sslmode=disable") - - config, err := pgxpool.ParseConfig(databaseURL) + config, err := pgxpool.ParseConfig(DatabaseURL(databaseName)) if err != nil { panic(fmt.Sprintf("error parsing database URL: %v", err)) } config.MaxConns = dbPoolMaxConns config.ConnConfig.ConnectTimeout = 10 * time.Second - config.ConnConfig.Database = databaseName config.ConnConfig.RuntimeParams["timezone"] = "UTC" return config } +// DatabaseURL gets a test database URL from TEST_DATABASE_URL or falls back on +// a default pointing to `river_testdb`. If databaseName is set, it replaces the +// database in the URL, although the host and other parameters are preserved. +// +// Most of the time DatabaseConfig should be used instead of this function, but +// it may be useful in non-pgx situations like for examples showing the use of +// `database/sql`. +func DatabaseURL(databaseName string) string { + u, err := url.Parse(valutil.ValOrDefault( + os.Getenv("TEST_DATABASE_URL"), + "postgres://localhost/river_testdb?sslmode=disable"), + ) + if err != nil { + panic(err) + } + + if databaseName != "" { + u.Path = databaseName + } + + return u.String() +} + // DiscardContinuously drains continuously out of the given channel and discards // anything that comes out of it. Returns a stop function that should be invoked // to stop draining. Stop must be invoked before tests finish to stop an diff --git a/internal/util/dbutil/db_util.go b/internal/util/dbutil/db_util.go index 9a31c601..bc50d4b4 100644 --- a/internal/util/dbutil/db_util.go +++ b/internal/util/dbutil/db_util.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/riverqueue/river/internal/dbsqlc" + "github.com/riverqueue/river/riverdriver" ) // Executor is an interface for a type that can begin a transaction and also @@ -56,3 +57,35 @@ func WithTxV[T any](ctx context.Context, txBeginner TxBeginner, innerFunc func(c return res, nil } + +// WithExecutorTx starts and commits a transaction on a driver executor around +// the given function, allowing the return of a generic value. +func WithExecutorTx(ctx context.Context, exec riverdriver.Executor, innerFunc func(ctx context.Context, tx riverdriver.ExecutorTx) error) error { + _, err := WithExecutorTxV(ctx, exec, func(ctx context.Context, tx riverdriver.ExecutorTx) (struct{}, error) { + return struct{}{}, innerFunc(ctx, tx) + }) + return err +} + +// WithExecutorTxV starts and commits a transaction on a driver executor around +// the given function, allowing the return of a generic value. +func WithExecutorTxV[T any](ctx context.Context, exec riverdriver.Executor, innerFunc func(ctx context.Context, tx riverdriver.ExecutorTx) (T, error)) (T, error) { + var defaultRes T + + tx, err := exec.Begin(ctx) + if err != nil { + return defaultRes, fmt.Errorf("error beginning transaction: %w", err) + } + defer tx.Rollback(ctx) + + res, err := innerFunc(ctx, tx) + if err != nil { + return defaultRes, err + } + + if err := tx.Commit(ctx); err != nil { + return defaultRes, fmt.Errorf("error committing transaction: %w", err) + } + + return res, nil +} diff --git a/internal/util/dbutil/db_util_test.go b/internal/util/dbutil/db_util_test.go index 88b6eb99..c0f45cf7 100644 --- a/internal/util/dbutil/db_util_test.go +++ b/internal/util/dbutil/db_util_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/require" "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5" ) func TestWithTx(t *testing.T) { @@ -40,3 +42,36 @@ func TestWithTxV(t *testing.T) { require.NoError(t, err) require.Equal(t, 7, ret) } + +func TestWithExecutorTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + dbPool := riverinternaltest.TestDB(ctx, t) + driver := riverpgxv5.New(dbPool) + + err := WithExecutorTx(ctx, driver.GetExecutor(), func(ctx context.Context, tx riverdriver.ExecutorTx) error { + _, err := tx.Exec(ctx, "SELECT 1") + require.NoError(t, err) + + return nil + }) + require.NoError(t, err) +} + +func TestWithExecutorTxV(t *testing.T) { + t.Parallel() + + ctx := context.Background() + dbPool := riverinternaltest.TestDB(ctx, t) + driver := riverpgxv5.New(dbPool) + + ret, err := WithExecutorTxV(ctx, driver.GetExecutor(), func(ctx context.Context, tx riverdriver.ExecutorTx) (int, error) { + _, err := tx.Exec(ctx, "SELECT 1") + require.NoError(t, err) + + return 7, nil + }) + require.NoError(t, err) + require.Equal(t, 7, ret) +} diff --git a/riverdriver/go.mod b/riverdriver/go.mod new file mode 100644 index 00000000..def94ac2 --- /dev/null +++ b/riverdriver/go.mod @@ -0,0 +1,14 @@ +module github.com/riverqueue/river/riverdriver + +go 1.21.4 + +require github.com/jackc/pgx/v5 v5.5.0 + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + golang.org/x/crypto v0.15.0 // indirect + golang.org/x/sync v0.5.0 // indirect + golang.org/x/text v0.14.0 // indirect +) diff --git a/riverdriver/go.sum b/riverdriver/go.sum new file mode 100644 index 00000000..b9c08498 --- /dev/null +++ b/riverdriver/go.sum @@ -0,0 +1,28 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.0 h1:NxstgwndsTRy7eq9/kqYc/BZh5w2hHJV86wjvO+1xPw= +github.com/jackc/pgx/v5 v5.5.0/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index c47c08aa..b944ba6b 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -13,10 +13,20 @@ package riverdriver import ( + "context" + "errors" + "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) +var ( + ErrNotImplemented = errors.New("driver does not implement this functionality") + ErrNoRows = errors.New("no rows found") + ErrSubTxNotSupported = errors.New("subtransactions not supported for this driver") +) + // Driver provides a database driver for use with river.Client. // // Its purpose is to wrap the interface of a third party database package, with @@ -32,10 +42,95 @@ import ( type Driver[TTx any] interface { // GetDBPool returns a database pool.This doesn't make sense in a world // where multiple drivers are supported and is subject to change. + // + // API is not stable. DO NOT USE. GetDBPool() *pgxpool.Pool + // GetExecutor gets an executor for the driver. + // + // API is not stable. DO NOT USE. + GetExecutor() Executor + + // UnwrapExecutor gets unwraps executor from a driver transaction. + // + // API is not stable. DO NOT USE. + UnwrapExecutor(tx TTx) Executor + // UnwrapTx turns a generically typed transaction into a pgx.Tx for use with // internal infrastructure. This doesn't make sense in a world where // multiple drivers are supported and is subject to change. + // + // API is not stable. DO NOT USE. UnwrapTx(tx TTx) pgx.Tx } + +// Executor provides River operations against a database. It may be a database +// pool or transaction. +type Executor interface { + // Begin begins a new subtransaction. ErrSubTxNotSupported may be returned + // if the executor is a transaction and the driver doesn't support + // subtransactions (like riverdriver/riverdatabasesql for database/sql). + // + // API is not stable. DO NOT USE. + Begin(ctx context.Context) (ExecutorTx, error) + + // Exec executes raw SQL. Used for migrations. + // + // API is not stable. DO NOT USE. + Exec(ctx context.Context, sql string) (struct{}, error) + + // MigrationDeleteByVersionMany deletes many migration versions. + // + // API is not stable. DO NOT USE. + MigrationDeleteByVersionMany(ctx context.Context, versions []int) ([]*Migration, error) + + // MigrationGetAll gets all currently applied migrations. + // + // API is not stable. DO NOT USE. + MigrationGetAll(ctx context.Context) ([]*Migration, error) + + // MigrationInsertMany inserts many migration versions. + // + // API is not stable. DO NOT USE. + MigrationInsertMany(ctx context.Context, versions []int) ([]*Migration, error) + + // TableExists checks whether a table exists for the schema in the current + // search schema. + // + // API is not stable. DO NOT USE. + TableExists(ctx context.Context, tableName string) (bool, error) +} + +// ExecutorTx is an executor which is a transaction. In addition to standard +// Executor operations, it may be committed or rolled back. +type ExecutorTx interface { + Executor + + // Commit commits the transaction. + // + // API is not stable. DO NOT USE. + Commit(ctx context.Context) error + + // Rollback rolls back the transaction. + // + // API is not stable. DO NOT USE. + Rollback(ctx context.Context) error +} + +// Migration represents a River migration. +type Migration struct { + // ID is an automatically generated primary key for the migration. + // + // API is not stable. DO NOT USE. + ID int + + // CreatedAt is when the migration was initially created. + // + // API is not stable. DO NOT USE. + CreatedAt time.Time + + // Version is the version of the migration. + // + // API is not stable. DO NOT USE. + Version int +} diff --git a/riverdriver/river_driver_interface_test.go b/riverdriver/river_driver_interface_test.go index 221ace9d..ac16273f 100644 --- a/riverdriver/river_driver_interface_test.go +++ b/riverdriver/river_driver_interface_test.go @@ -1,10 +1 @@ package riverdriver - -import ( - "github.com/jackc/pgx/v5" - - "github.com/riverqueue/river/riverdriver/riverpgxv5" -) - -// Verify interface compliance. -var _ Driver[pgx.Tx] = &riverpgxv5.Driver{} diff --git a/riverdriver/riverdatabasesql/go.mod b/riverdriver/riverdatabasesql/go.mod new file mode 100644 index 00000000..e4663345 --- /dev/null +++ b/riverdriver/riverdatabasesql/go.mod @@ -0,0 +1,25 @@ +module github.com/riverqueue/river/riverdriver/riverdatabasesql + +go 1.21.4 + +replace github.com/riverqueue/river/riverdriver => ../ + +require ( + github.com/jackc/pgx/v5 v5.5.0 + github.com/lib/pq v1.10.9 + github.com/riverqueue/river/riverdriver v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.8.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 // indirect + golang.org/x/crypto v0.15.0 // indirect + golang.org/x/sync v0.5.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/riverdriver/riverdatabasesql/go.sum b/riverdriver/riverdatabasesql/go.sum new file mode 100644 index 00000000..18cd465b --- /dev/null +++ b/riverdriver/riverdatabasesql/go.sum @@ -0,0 +1,44 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.0 h1:NxstgwndsTRy7eq9/kqYc/BZh5w2hHJV86wjvO+1xPw= +github.com/jackc/pgx/v5 v5.5.0/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12 h1:mcDBnqwzEXY9WDOwbkd8xmFdSr/H6oHb1F3NCNCmLDY= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.12/go.mod h1:k6hsPkW9Fl3qURzyLHbvxUCqWDpit0WrZ3oEaKezD3E= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/db.go b/riverdriver/riverdatabasesql/internal/dbsqlc/db.go new file mode 100644 index 00000000..57d06d7d --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/db.go @@ -0,0 +1,24 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/models.go b/riverdriver/riverdatabasesql/internal/dbsqlc/models.go new file mode 100644 index 00000000..b42c6819 --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "time" +) + +type RiverMigration struct { + ID int64 + CreatedAt time.Time + Version int64 +} diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/river_migration.sql.go b/riverdriver/riverdatabasesql/internal/dbsqlc/river_migration.sql.go new file mode 100644 index 00000000..66048d8f --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/river_migration.sql.go @@ -0,0 +1,129 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 +// source: river_migration.sql + +package dbsqlc + +import ( + "context" + + "github.com/lib/pq" +) + +const riverMigrationDeleteByVersionMany = `-- name: RiverMigrationDeleteByVersionMany :many +DELETE FROM river_migration +WHERE version = any($1::bigint[]) +RETURNING id, created_at, version +` + +func (q *Queries) RiverMigrationDeleteByVersionMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationDeleteByVersionMany, pq.Array(version)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationGetAll = `-- name: RiverMigrationGetAll :many +SELECT id, created_at, version +FROM river_migration +ORDER BY version +` + +func (q *Queries) RiverMigrationGetAll(ctx context.Context, db DBTX) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationGetAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const riverMigrationInsert = `-- name: RiverMigrationInsert :one +INSERT INTO river_migration ( + version +) VALUES ( + $1 +) RETURNING id, created_at, version +` + +func (q *Queries) RiverMigrationInsert(ctx context.Context, db DBTX, version int64) (*RiverMigration, error) { + row := db.QueryRowContext(ctx, riverMigrationInsert, version) + var i RiverMigration + err := row.Scan(&i.ID, &i.CreatedAt, &i.Version) + return &i, err +} + +const riverMigrationInsertMany = `-- name: RiverMigrationInsertMany :many +INSERT INTO river_migration ( + version +) +SELECT + unnest($1::bigint[]) +RETURNING id, created_at, version +` + +func (q *Queries) RiverMigrationInsertMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigration, error) { + rows, err := db.QueryContext(ctx, riverMigrationInsertMany, pq.Array(version)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const tableExists = `-- name: TableExists :one +SELECT CASE WHEN to_regclass($1) IS NULL THEN false + ELSE true END +` + +func (q *Queries) TableExists(ctx context.Context, db DBTX, tableName string) (bool, error) { + row := db.QueryRowContext(ctx, tableExists, tableName) + var column_1 bool + err := row.Scan(&column_1) + return column_1, err +} diff --git a/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml b/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml new file mode 100644 index 00000000..ec379388 --- /dev/null +++ b/riverdriver/riverdatabasesql/internal/dbsqlc/sqlc.yaml @@ -0,0 +1,24 @@ +version: "2" +sql: + - engine: "postgresql" + queries: + - ../../../riverpgxv5/internal/dbsqlc/river_migration.sql + schema: + - ../../../riverpgxv5/internal/dbsqlc/river_migration.sql + gen: + go: + package: "dbsqlc" + sql_package: "database/sql" + out: "." + emit_exact_table_names: true + emit_methods_with_db_argument: true + emit_result_struct_pointers: true + + overrides: + - db_type: "timestamptz" + go_type: "time.Time" + - db_type: "timestamptz" + go_type: + type: "time.Time" + pointer: true + nullable: true diff --git a/riverdriver/riverdatabasesql/river_database_sql.go b/riverdriver/riverdatabasesql/river_database_sql.go new file mode 100644 index 00000000..8b1136f6 --- /dev/null +++ b/riverdriver/riverdatabasesql/river_database_sql.go @@ -0,0 +1,139 @@ +// Package riverdatabasesql bundles a River driver for Go's built in database/sql. +// +// This is _not_ a fully functional driver, and only supports use through +// rivermigrate for purposes of interacting with migration frameworks like +// Goose. Using it with a River client will panic. +package riverdatabasesql + +import ( + "context" + "database/sql" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/dbsqlc" +) + +// Driver is an implementation of riverdriver.Driver for database/sql. +type Driver struct { + dbPool *sql.DB + queries *dbsqlc.Queries +} + +// New returns a new database/sql River driver for use with River. +// +// It takes an sql.DB to use for use with River. The pool should already be +// configured to use the schema specified in the client's Schema field. The pool +// must not be closed while associated River objects are running. +// +// This is _not_ a fully functional driver, and only supports use through +// rivermigrate for purposes of interacting with migration frameworks like +// Goose. Using it with a River client will panic. +func New(dbPool *sql.DB) *Driver { + return &Driver{dbPool: dbPool, queries: dbsqlc.New()} +} + +func (d *Driver) GetDBPool() *pgxpool.Pool { panic(riverdriver.ErrNotImplemented) } +func (d *Driver) GetExecutor() riverdriver.Executor { + return &Executor{d.dbPool, d.dbPool, dbsqlc.New()} +} + +func (d *Driver) UnwrapExecutor(tx *sql.Tx) riverdriver.Executor { + return &Executor{nil, tx, dbsqlc.New()} +} +func (d *Driver) UnwrapTx(tx *sql.Tx) pgx.Tx { panic(riverdriver.ErrNotImplemented) } + +type Executor struct { + dbPool *sql.DB + dbtx dbsqlc.DBTX + queries *dbsqlc.Queries +} + +func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + if e.dbPool == nil { + return nil, riverdriver.ErrSubTxNotSupported + } + + tx, err := e.dbPool.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &ExecutorTx{Executor: Executor{nil, tx, e.queries}, tx: tx}, nil +} + +func (e *Executor) Exec(ctx context.Context, sql string) (struct{}, error) { + _, err := e.dbtx.ExecContext(ctx, sql) + return struct{}{}, interpretError(err) +} + +func (e *Executor) MigrationDeleteByVersionMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationDeleteByVersionMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationGetAll(ctx context.Context) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationGetAll(ctx, e.dbtx) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationInsertMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationInsertMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) TableExists(ctx context.Context, tableName string) (bool, error) { + exists, err := e.queries.TableExists(ctx, e.dbtx, tableName) + return exists, interpretError(err) +} + +type ExecutorTx struct { + Executor + tx *sql.Tx +} + +func (t *ExecutorTx) Commit(ctx context.Context) error { + // unfortunately, `database/sql` does not take a context ... + return t.tx.Commit() +} + +func (t *ExecutorTx) Rollback(ctx context.Context) error { + // unfortunately, `database/sql` does not take a context ... + return t.tx.Rollback() +} + +func interpretError(err error) error { + if errors.Is(err, sql.ErrNoRows) { + return riverdriver.ErrNoRows + } + return err +} + +func mapMigrations(migrations []*dbsqlc.RiverMigration) []*riverdriver.Migration { + if migrations == nil { + return nil + } + + return mapSlice(migrations, func(m *dbsqlc.RiverMigration) *riverdriver.Migration { + return &riverdriver.Migration{ + ID: int(m.ID), + CreatedAt: m.CreatedAt, + Version: int(m.Version), + } + }) +} + +// mapSlice manipulates a slice and transforms it to a slice of another type. +func mapSlice[T any, R any](collection []T, mapFunc func(T) R) []R { + result := make([]R, len(collection)) + + for i, item := range collection { + result[i] = mapFunc(item) + } + + return result +} diff --git a/riverdriver/riverdatabasesql/river_database_sql_test.go b/riverdriver/riverdatabasesql/river_database_sql_test.go new file mode 100644 index 00000000..f97be426 --- /dev/null +++ b/riverdriver/riverdatabasesql/river_database_sql_test.go @@ -0,0 +1,41 @@ +package riverdatabasesql + +import ( + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdriver" +) + +// Verify interface compliance. +var _ riverdriver.Driver[*sql.Tx] = New(nil) + +func TestNew(t *testing.T) { + t.Parallel() + + t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + + dbPool := &sql.DB{} + driver := New(dbPool) + require.Equal(t, dbPool, driver.dbPool) + }) + + t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + + driver := New(nil) + require.Nil(t, driver.dbPool) + }) +} + +func TestInterpretError(t *testing.T) { + t.Parallel() + + require.EqualError(t, interpretError(errors.New("an error")), "an error") + require.ErrorIs(t, interpretError(sql.ErrNoRows), riverdriver.ErrNoRows) + require.NoError(t, interpretError(nil)) +} diff --git a/riverdriver/riverpgxv5/go.mod b/riverdriver/riverpgxv5/go.mod index e6027196..c782fb86 100644 --- a/riverdriver/riverpgxv5/go.mod +++ b/riverdriver/riverpgxv5/go.mod @@ -1,9 +1,12 @@ module github.com/riverqueue/river/riverdriver/riverpgxv5 -go 1.21.0 +go 1.21.4 + +replace github.com/riverqueue/river/riverdriver => ../ require ( github.com/jackc/pgx/v5 v5.5.0 + github.com/riverqueue/river/riverdriver v0.0.0-00010101000000-000000000000 github.com/stretchr/testify v1.8.1 ) diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/db.go b/riverdriver/riverpgxv5/internal/dbsqlc/db.go new file mode 100644 index 00000000..27ee12e2 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/db.go @@ -0,0 +1,25 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New() *Queries { + return &Queries{} +} + +type Queries struct { +} diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/models.go b/riverdriver/riverpgxv5/internal/dbsqlc/models.go new file mode 100644 index 00000000..b42c6819 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.22.0 + +package dbsqlc + +import ( + "time" +) + +type RiverMigration struct { + ID int64 + CreatedAt time.Time + Version int64 +} diff --git a/internal/dbsqlc/river_migration.sql b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql similarity index 76% rename from internal/dbsqlc/river_migration.sql rename to riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql index a9ce74f8..fa7a65ba 100644 --- a/internal/dbsqlc/river_migration.sql +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql @@ -5,7 +5,7 @@ CREATE TABLE river_migration( CONSTRAINT version CHECK (version >= 1) ); --- name: RiverMigrationDeleteByVersionMany :one +-- name: RiverMigrationDeleteByVersionMany :many DELETE FROM river_migration WHERE version = any(@version::bigint[]) RETURNING *; @@ -28,4 +28,8 @@ INSERT INTO river_migration ( ) SELECT unnest(@version::bigint[]) -RETURNING *; \ No newline at end of file +RETURNING *; + +-- name: TableExists :one +SELECT CASE WHEN to_regclass(@table_name) IS NULL THEN false + ELSE true END; \ No newline at end of file diff --git a/internal/dbsqlc/river_migration.sql.go b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql.go similarity index 71% rename from internal/dbsqlc/river_migration.sql.go rename to riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql.go index 2bffffe8..583b8d59 100644 --- a/internal/dbsqlc/river_migration.sql.go +++ b/riverdriver/riverpgxv5/internal/dbsqlc/river_migration.sql.go @@ -9,17 +9,30 @@ import ( "context" ) -const riverMigrationDeleteByVersionMany = `-- name: RiverMigrationDeleteByVersionMany :one +const riverMigrationDeleteByVersionMany = `-- name: RiverMigrationDeleteByVersionMany :many DELETE FROM river_migration WHERE version = any($1::bigint[]) RETURNING id, created_at, version ` -func (q *Queries) RiverMigrationDeleteByVersionMany(ctx context.Context, db DBTX, version []int64) (*RiverMigration, error) { - row := db.QueryRow(ctx, riverMigrationDeleteByVersionMany, version) - var i RiverMigration - err := row.Scan(&i.ID, &i.CreatedAt, &i.Version) - return &i, err +func (q *Queries) RiverMigrationDeleteByVersionMany(ctx context.Context, db DBTX, version []int64) ([]*RiverMigration, error) { + rows, err := db.Query(ctx, riverMigrationDeleteByVersionMany, version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*RiverMigration + for rows.Next() { + var i RiverMigration + if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Version); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const riverMigrationGetAll = `-- name: RiverMigrationGetAll :many @@ -91,3 +104,15 @@ func (q *Queries) RiverMigrationInsertMany(ctx context.Context, db DBTX, version } return items, nil } + +const tableExists = `-- name: TableExists :one +SELECT CASE WHEN to_regclass($1) IS NULL THEN false + ELSE true END +` + +func (q *Queries) TableExists(ctx context.Context, db DBTX, tableName string) (bool, error) { + row := db.QueryRow(ctx, tableExists, tableName) + var column_1 bool + err := row.Scan(&column_1) + return column_1, err +} diff --git a/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml b/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml new file mode 100644 index 00000000..ace0bb62 --- /dev/null +++ b/riverdriver/riverpgxv5/internal/dbsqlc/sqlc.yaml @@ -0,0 +1,24 @@ +version: "2" +sql: + - engine: "postgresql" + queries: + - river_migration.sql + schema: + - river_migration.sql + gen: + go: + package: "dbsqlc" + sql_package: "pgx/v5" + out: "." + emit_exact_table_names: true + emit_methods_with_db_argument: true + emit_result_struct_pointers: true + + overrides: + - db_type: "timestamptz" + go_type: "time.Time" + - db_type: "timestamptz" + go_type: + type: "time.Time" + pointer: true + nullable: true diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index 9c803088..0dc8cc3a 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -6,21 +6,27 @@ package riverpgxv5 import ( + "context" + "errors" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverpgxv5/internal/dbsqlc" ) // Driver is an implementation of riverdriver.Driver for Pgx v5. type Driver struct { - dbPool *pgxpool.Pool + dbPool *pgxpool.Pool + queries *dbsqlc.Queries } -// New returns a new Pgx v5 River driver for use with river.Client. +// New returns a new Pgx v5 River driver for use with River. // // It takes a pgxpool.Pool to use for use with River. The pool should already be // configured to use the schema specified in the client's Schema field. The pool -// must not be closed while the associated client is running (not until graceful -// shutdown has completed). +// must not be closed while associated River objects are running. // // The database pool may be nil. If it is, a client that it's sent into will not // be able to start up (calls to Start will error) and the Insert and InsertMany @@ -29,8 +35,98 @@ type Driver struct { // in testing so that inserts can be performed and verified on a test // transaction that will be rolled back. func New(dbPool *pgxpool.Pool) *Driver { - return &Driver{dbPool: dbPool} + return &Driver{dbPool: dbPool, queries: dbsqlc.New()} +} + +func (d *Driver) GetDBPool() *pgxpool.Pool { return d.dbPool } +func (d *Driver) GetExecutor() riverdriver.Executor { return &Executor{d.dbPool, dbsqlc.New()} } +func (d *Driver) UnwrapExecutor(tx pgx.Tx) riverdriver.Executor { return &Executor{tx, dbsqlc.New()} } +func (d *Driver) UnwrapTx(tx pgx.Tx) pgx.Tx { return tx } + +type Executor struct { + dbtx interface { + dbsqlc.DBTX + Begin(ctx context.Context) (pgx.Tx, error) + } + queries *dbsqlc.Queries +} + +func (e *Executor) Begin(ctx context.Context) (riverdriver.ExecutorTx, error) { + tx, err := e.dbtx.Begin(ctx) + if err != nil { + return nil, err + } + return &ExecutorTx{Executor: Executor{tx, e.queries}, tx: tx}, nil +} + +func (e *Executor) Exec(ctx context.Context, sql string) (struct{}, error) { + _, err := e.dbtx.Exec(ctx, sql) + return struct{}{}, interpretError(err) +} + +func (e *Executor) MigrationDeleteByVersionMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationDeleteByVersionMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationGetAll(ctx context.Context) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationGetAll(ctx, e.dbtx) + return mapMigrations(migrations), interpretError(err) +} + +func (e *Executor) MigrationInsertMany(ctx context.Context, versions []int) ([]*riverdriver.Migration, error) { + migrations, err := e.queries.RiverMigrationInsertMany(ctx, e.dbtx, + mapSlice(versions, func(v int) int64 { return int64(v) })) + return mapMigrations(migrations), interpretError(err) } -func (d *Driver) GetDBPool() *pgxpool.Pool { return d.dbPool } -func (d *Driver) UnwrapTx(tx pgx.Tx) pgx.Tx { return tx } +func (e *Executor) TableExists(ctx context.Context, tableName string) (bool, error) { + exists, err := e.queries.TableExists(ctx, e.dbtx, tableName) + return exists, interpretError(err) +} + +type ExecutorTx struct { + Executor + tx pgx.Tx +} + +func (t *ExecutorTx) Commit(ctx context.Context) error { + return t.tx.Commit(ctx) +} + +func (t *ExecutorTx) Rollback(ctx context.Context) error { + return t.tx.Rollback(ctx) +} + +func interpretError(err error) error { + if errors.Is(err, pgx.ErrNoRows) { + return riverdriver.ErrNoRows + } + return err +} + +func mapMigrations(migrations []*dbsqlc.RiverMigration) []*riverdriver.Migration { + if migrations == nil { + return nil + } + + return mapSlice(migrations, func(m *dbsqlc.RiverMigration) *riverdriver.Migration { + return &riverdriver.Migration{ + ID: int(m.ID), + CreatedAt: m.CreatedAt, + Version: int(m.Version), + } + }) +} + +// mapSlice manipulates a slice and transforms it to a slice of another type. +func mapSlice[T any, R any](collection []T, mapFunc func(T) R) []R { + result := make([]R, len(collection)) + + for i, item := range collection { + result[i] = mapFunc(item) + } + + return result +} diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go b/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go index 16e5556c..229b735c 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver_test.go @@ -1,21 +1,42 @@ package riverpgxv5 import ( + "errors" "testing" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/riverdriver" ) +// Verify interface compliance. +var _ riverdriver.Driver[pgx.Tx] = New(nil) + func TestNew(t *testing.T) { + t.Parallel() + t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + dbPool := &pgxpool.Pool{} driver := New(dbPool) require.Equal(t, dbPool, driver.dbPool) }) t.Run("AllowsNilDatabasePool", func(t *testing.T) { + t.Parallel() + driver := New(nil) require.Nil(t, driver.dbPool) }) } + +func TestInterpretError(t *testing.T) { + t.Parallel() + + require.EqualError(t, interpretError(errors.New("an error")), "an error") + require.ErrorIs(t, interpretError(pgx.ErrNoRows), riverdriver.ErrNoRows) + require.NoError(t, interpretError(nil)) +} diff --git a/rivermigrate/example_migrate_database_sql_test.go b/rivermigrate/example_migrate_database_sql_test.go new file mode 100644 index 00000000..20e26960 --- /dev/null +++ b/rivermigrate/example_migrate_database_sql_test.go @@ -0,0 +1,72 @@ +package rivermigrate_test + +import ( + "context" + "database/sql" + "fmt" + "strings" + + _ "github.com/jackc/pgx/v5/stdlib" + + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver/riverdatabasesql" + "github.com/riverqueue/river/rivermigrate" +) + +// Example_migrateDatabaseSQL demonstrates the use of River's Go migration API +// through Go's built-in database/sql package. +func Example_migrateDatabaseSQL() { + ctx := context.Background() + + dbPool, err := sql.Open("pgx", riverinternaltest.DatabaseURL("river_testdb_example")) + if err != nil { + panic(err) + } + defer dbPool.Close() + + tx, err := dbPool.BeginTx(ctx, nil) + if err != nil { + panic(err) + } + defer tx.Rollback() + + migrator := rivermigrate.New(riverdatabasesql.New(dbPool), nil) + + // Our test database starts with a full River schema. Drop it so that we can + // demonstrate working migrations. This isn't necessary outside this test. + dropRiverSchema(ctx, migrator, tx) + + printVersions := func(res *rivermigrate.MigrateResult) { + for _, version := range res.Versions { + fmt.Printf("Migrated [%s] version %d\n", strings.ToUpper(string(res.Direction)), version.Version) + } + } + + // Migrate to version 3. An actual call may want to omit all MigrateOpts, + // which will default to applying all available up migrations. + res, err := migrator.MigrateTx(ctx, tx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ + TargetVersion: 3, + }) + if err != nil { + panic(err) + } + printVersions(res) + + // Migrate down by three steps. Down migrating defaults to running only one + // step unless overridden by an option like MaxSteps or TargetVersion. + res, err = migrator.MigrateTx(ctx, tx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ + MaxSteps: 3, + }) + if err != nil { + panic(err) + } + printVersions(res) + + // Output: + // Migrated [UP] version 1 + // Migrated [UP] version 2 + // Migrated [UP] version 3 + // Migrated [DOWN] version 3 + // Migrated [DOWN] version 2 + // Migrated [DOWN] version 1 +} diff --git a/rivermigrate/example_migrate_test.go b/rivermigrate/example_migrate_test.go index 8b8bee0e..9bd26c3b 100644 --- a/rivermigrate/example_migrate_test.go +++ b/rivermigrate/example_migrate_test.go @@ -3,35 +3,15 @@ package rivermigrate_test import ( "context" "fmt" - "sort" "strings" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - "github.com/riverqueue/river" "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/rivermigrate" ) -type SortArgs struct { - // Strings is a slice of strings to sort. - Strings []string `json:"strings"` -} - -func (SortArgs) Kind() string { return "sort" } - -type SortWorker struct { - river.WorkerDefaults[SortArgs] -} - -func (w *SortWorker) Work(ctx context.Context, job *river.Job[SortArgs]) error { - sort.Strings(job.Args.Strings) - fmt.Printf("Sorted strings: %+v\n", job.Args.Strings) - return nil -} - // Example_migrate demonstrates the use of River's Go migration API by migrating // up and down. func Example_migrate() { @@ -81,11 +61,6 @@ func Example_migrate() { } printVersions(res) - // Roll back all changes applied so our test database is left unaffected. - if err := tx.Rollback(ctx); err != nil { - panic(err) - } - // Output: // Migrated [UP] version 1 // Migrated [UP] version 2 @@ -95,7 +70,7 @@ func Example_migrate() { // Migrated [DOWN] version 1 } -func dropRiverSchema(ctx context.Context, migrator *rivermigrate.Migrator[pgx.Tx], tx pgx.Tx) { +func dropRiverSchema[TTx any](ctx context.Context, migrator *rivermigrate.Migrator[TTx], tx TTx) { _, err := migrator.MigrateTx(ctx, tx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ TargetVersion: -1, }) diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go index d3d54151..c14b54f7 100644 --- a/rivermigrate/river_migrate.go +++ b/rivermigrate/river_migrate.go @@ -5,7 +5,6 @@ package rivermigrate import ( "context" "embed" - "errors" "fmt" "io" "io/fs" @@ -17,10 +16,6 @@ import ( "strings" "time" - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/riverqueue/river/internal/baseservice" "github.com/riverqueue/river/internal/dbsqlc" "github.com/riverqueue/river/internal/util/dbutil" @@ -67,8 +62,9 @@ type Migrator[TTx any] struct { // New returns a new migrator with the given database driver and configuration. // The config parameter may be omitted as nil. // -// Currently only one driver is supported, which is Pgx v5. See package -// riverpgxv5. +// Two drivers are supported for migrations, one for Pgx v5 and one for the +// built-in database/sql package for use with migration frameworks like Goose. +// See packages riverpgxv5 and riverdatabasesql respectively. // // The function takes a generic parameter TTx representing a transaction type, // but it can be omitted because it'll generally always be inferred from the @@ -155,8 +151,7 @@ type MigrateVersion struct { Version int } -func migrateVersionToInt(version MigrateVersion) int { return version.Version } -func migrateVersionToInt64(version MigrateVersion) int64 { return int64(version.Version) } +func migrateVersionToInt(version MigrateVersion) int { return version.Version } type Direction string @@ -180,12 +175,12 @@ const ( // // handle error // } func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { - return dbutil.WithTxV(ctx, m.driver.GetDBPool(), func(ctx context.Context, tx pgx.Tx) (*MigrateResult, error) { + return dbutil.WithExecutorTxV(ctx, m.driver.GetExecutor(), func(ctx context.Context, tx riverdriver.ExecutorTx) (*MigrateResult, error) { switch direction { case DirectionDown: - return m.migrateDownTx(ctx, tx, direction, opts) + return m.migrateDown(ctx, tx, direction, opts) case DirectionUp: - return m.migrateUpTx(ctx, tx, direction, opts) + return m.migrateUp(ctx, tx, direction, opts) } panic("invalid direction: " + direction) @@ -213,22 +208,22 @@ func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts * func (m *Migrator[TTx]) MigrateTx(ctx context.Context, tx TTx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { switch direction { case DirectionDown: - return m.migrateDownTx(ctx, m.driver.UnwrapTx(tx), direction, opts) + return m.migrateDown(ctx, m.driver.UnwrapExecutor(tx), direction, opts) case DirectionUp: - return m.migrateUpTx(ctx, m.driver.UnwrapTx(tx), direction, opts) + return m.migrateUp(ctx, m.driver.UnwrapExecutor(tx), direction, opts) } panic("invalid direction: " + direction) } -// migrateDownTx runs down migrations. -func (m *Migrator[TTx]) migrateDownTx(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { - existingMigrations, err := m.existingMigrations(ctx, tx) +// migrateDown runs down migrations. +func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + existingMigrations, err := m.existingMigrations(ctx, exec) if err != nil { return nil, err } existingMigrationsMap := sliceutil.KeyBy(existingMigrations, - func(m *dbsqlc.RiverMigration) (int, struct{}) { return int(m.Version), struct{}{} }) + func(m *riverdriver.Migration) (int, struct{}) { return m.Version, struct{}{} }) targetMigrations := maps.Clone(m.migrations) for version := range targetMigrations { @@ -240,7 +235,7 @@ func (m *Migrator[TTx]) migrateDownTx(ctx context.Context, tx pgx.Tx, direction sortedTargetMigrations := maputil.Values(targetMigrations) slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return b.Version - a.Version }) // reverse order - res, err := m.applyMigrations(ctx, tx, direction, opts, sortedTargetMigrations) + res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) if err != nil { return nil, err } @@ -259,34 +254,34 @@ func (m *Migrator[TTx]) migrateDownTx(ctx context.Context, tx pgx.Tx, direction return res, nil } - if _, err := m.queries.RiverMigrationDeleteByVersionMany(ctx, tx, sliceutil.Map(res.Versions, migrateVersionToInt64)); err != nil { + if _, err := exec.MigrationDeleteByVersionMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { return nil, fmt.Errorf("error deleting migration rows for versions %+v: %w", res.Versions, err) } return res, nil } -// migrateUpTx runs up migrations. -func (m *Migrator[TTx]) migrateUpTx(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { - existingMigrations, err := m.existingMigrations(ctx, tx) +// migrateUp runs up migrations. +func (m *Migrator[TTx]) migrateUp(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + existingMigrations, err := m.existingMigrations(ctx, exec) if err != nil { return nil, err } targetMigrations := maps.Clone(m.migrations) for _, migrateRow := range existingMigrations { - delete(targetMigrations, int(migrateRow.Version)) + delete(targetMigrations, migrateRow.Version) } sortedTargetMigrations := maputil.Values(targetMigrations) slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return a.Version - b.Version }) - res, err := m.applyMigrations(ctx, tx, direction, opts, sortedTargetMigrations) + res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) if err != nil { return nil, err } - if _, err := m.queries.RiverMigrationInsertMany(ctx, tx, sliceutil.Map(res.Versions, migrateVersionToInt64)); err != nil { + if _, err := exec.MigrationInsertMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { return nil, fmt.Errorf("error inserting migration rows for versions %+v: %w", res.Versions, err) } @@ -295,7 +290,7 @@ func (m *Migrator[TTx]) migrateUpTx(ctx context.Context, tx pgx.Tx, direction Di // Common code shared between the up and down migration directions that walks // through each target migration and applies it, logging appropriately. -func (m *Migrator[TTx]) applyMigrations(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts, sortedTargetMigrations []*migrationBundle) (*MigrateResult, error) { +func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, sortedTargetMigrations []*migrationBundle) (*MigrateResult, error) { if opts == nil { opts = &MigrateOpts{} } @@ -355,7 +350,7 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, tx pgx.Tx, directio slog.Int("version", versionBundle.Version), ) - _, err := tx.Exec(ctx, sql) + _, err := exec.Exec(ctx, sql) if err != nil { return nil, fmt.Errorf("error applying version %03d [%s]: %w", versionBundle.Version, strings.ToUpper(string(direction)), err) @@ -377,26 +372,18 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, tx pgx.Tx, directio // the `river_migration` table not existing yet. (The subtransaction is needed // because otherwise the existing transaction would become aborted on an // unsuccessful `river_migration` check.) -func (m *Migrator[TTx]) existingMigrations(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { - // We start another inner transaction here because in case this is the first - // ever migration run, the transaction may become aborted if `river_migration` - // doesn't exist, a condition which we must handle gracefully. - migrations, err := dbutil.WithTxV(ctx, tx, func(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { - migrations, err := m.queries.RiverMigrationGetAll(ctx, tx) - if err != nil { - return nil, fmt.Errorf("error getting current migrate rows: %w", err) - } - return migrations, nil - }) +func (m *Migrator[TTx]) existingMigrations(ctx context.Context, exec riverdriver.Executor) ([]*riverdriver.Migration, error) { + exists, err := exec.TableExists(ctx, "river_migration") if err != nil { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - if pgErr.Code == pgerrcode.UndefinedTable && strings.Contains(pgErr.Message, "river_migration") { - return nil, nil - } - } + return nil, fmt.Errorf("error checking if `%s` exists: %w", "river_migration", err) + } + if !exists { + return nil, nil + } - return nil, err + migrations, err := exec.MigrationGetAll(ctx) + if err != nil { + return nil, fmt.Errorf("error getting existing migrations: %w", err) } return migrations, nil diff --git a/rivermigrate/river_migrate_test.go b/rivermigrate/river_migrate_test.go index 7cdb1ac9..f5633834 100644 --- a/rivermigrate/river_migrate_test.go +++ b/rivermigrate/river_migrate_test.go @@ -2,18 +2,23 @@ package rivermigrate import ( "context" + "database/sql" + "log/slog" "slices" "testing" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/require" - "github.com/riverqueue/river/internal/dbsqlc" "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/internal/util/dbutil" "github.com/riverqueue/river/internal/util/sliceutil" + "github.com/riverqueue/river/riverdriver" + "github.com/riverqueue/river/riverdriver/riverdatabasesql" "github.com/riverqueue/river/riverdriver/riverpgxv5" ) @@ -44,13 +49,13 @@ var ( func TestMigrator(t *testing.T) { t.Parallel() - var ( - ctx = context.Background() - queries = dbsqlc.New() - ) + ctx := context.Background() type testBundle struct { - tx pgx.Tx + dbPool *pgxpool.Pool + driver *riverpgxv5.Driver + logger *slog.Logger + tx pgx.Tx } setup := func(t *testing.T) (*Migrator[pgx.Tx], *testBundle) { @@ -62,24 +67,45 @@ func TestMigrator(t *testing.T) { // we use test DBs instead of test transactions, but this could be // changed to test transactions as long as test cases were made to run // non-parallel. - testDB := riverinternaltest.TestDB(ctx, t) + dbPool := riverinternaltest.TestDB(ctx, t) // Despite being in an isolated database, we still start a transaction // because we don't want schema changes we make to persist. - tx, err := testDB.Begin(ctx) + tx, err := dbPool.Begin(ctx) require.NoError(t, err) t.Cleanup(func() { _ = tx.Rollback(ctx) }) bundle := &testBundle{ - tx: tx, + dbPool: dbPool, + driver: riverpgxv5.New(dbPool), + logger: riverinternaltest.Logger(t), + tx: tx, } - migrator := New(riverpgxv5.New(testDB), nil) + migrator := New(bundle.driver, &Config{Logger: bundle.logger}) migrator.migrations = riverMigrationsWithtestVersionsMap return migrator, bundle } + // Gets a migrator using the driver for `database/sql`. + setupDatabaseSQLMigrator := func(t *testing.T, bundle *testBundle) (*Migrator[*sql.Tx], *sql.Tx) { + t.Helper() + + stdPool := stdlib.OpenDBFromPool(bundle.dbPool) + t.Cleanup(func() { require.NoError(t, stdPool.Close()) }) + + tx, err := stdPool.BeginTx(ctx, nil) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, tx.Rollback()) }) + + driver := riverdatabasesql.New(stdPool) + migrator := New(driver, &Config{Logger: bundle.logger}) + migrator.migrations = riverMigrationsWithtestVersionsMap + + return migrator, tx + } + t.Run("MigrateDownDefault", func(t *testing.T) { t.Parallel() @@ -135,10 +161,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion, riverMigrationsWithTestVersionsMaxVersion - 1}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-2), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") require.Error(t, err) @@ -156,10 +182,26 @@ func TestMigrator(t *testing.T) { require.NoError(t, err) require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) + }) + + t.Run("MigrateDownWithDatabaseSQLDriver", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + migrator, tx := setupDatabaseSQLMigrator(t, bundle) + + res, err := migrator.MigrateTx(ctx, tx, DirectionDown, &MigrateOpts{MaxSteps: 1}) + require.NoError(t, err) + require.Equal(t, []int{3}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := migrator.driver.UnwrapExecutor(tx).MigrationGetAll(ctx) + require.NoError(t, err) + require.Equal(t, seqOneTo(2), + sliceutil.Map(migrations, migrationToInt)) }) t.Run("MigrateDownWithTargetVersion", func(t *testing.T) { @@ -175,10 +217,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{5, 4}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") require.Error(t, err) @@ -242,10 +284,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion - 1, riverMigrationsWithTestVersionsMaxVersion}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") require.NoError(t, err) @@ -258,10 +300,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, DirectionUp, res.Direction) require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") require.NoError(t, err) @@ -278,10 +320,10 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion - 1}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-1), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) // Column `name` is only added in the second test version. err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") @@ -304,10 +346,26 @@ func TestMigrator(t *testing.T) { require.NoError(t, err) require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, riverMigrationToInt)) + sliceutil.Map(migrations, migrationToInt)) + }) + + t.Run("MigrateUpWithDatabaseSQLDriver", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + migrator, tx := setupDatabaseSQLMigrator(t, bundle) + + res, err := migrator.MigrateTx(ctx, tx, DirectionUp, &MigrateOpts{MaxSteps: 1}) + require.NoError(t, err) + require.Equal(t, []int{riverMigrationsMaxVersion + 1}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := migrator.driver.UnwrapExecutor(tx).MigrationGetAll(ctx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsMaxVersion+1), + sliceutil.Map(migrations, migrationToInt)) }) t.Run("MigrateUpWithTargetVersion", func(t *testing.T) { @@ -320,9 +378,9 @@ func TestMigrator(t *testing.T) { require.Equal(t, []int{4, 5}, sliceutil.Map(res.Versions, migrateVersionToInt)) - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) - require.Equal(t, seqOneTo(5), sliceutil.Map(migrations, riverMigrationToInt)) + require.Equal(t, seqOneTo(5), sliceutil.Map(migrations, migrationToInt)) }) t.Run("MigrateUpWithTargetVersionInvalid", func(t *testing.T) { @@ -354,7 +412,7 @@ func dbExecError(ctx context.Context, executor dbutil.Executor, sql string) erro }) } -func riverMigrationToInt(r *dbsqlc.RiverMigration) int { return int(r.Version) } +func migrationToInt(r *riverdriver.Migration) int { return r.Version } func seqOneTo(max int) []int { seq := make([]int, max)