diff --git a/CHANGELOG.md b/CHANGELOG.md index dfb4e414..43a4e5b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Add transaction variants for queue-related client functions: `QueueGetTx`, `QueueListTx`, `QueuePauseTx`, and `QueueResumeTx`. [PR #402](https://github.com/riverqueue/river/pull/402). + ### Fixed - Fix possible Client shutdown panics if the user-provided context is cancelled while jobs are still running. [PR #401](https://github.com/riverqueue/river/pull/401). diff --git a/client.go b/client.go index b9cdfd8d..88efec97 100644 --- a/client.go +++ b/client.go @@ -1607,6 +1607,15 @@ func (c *Client[TTx]) QueueGet(ctx context.Context, name string) (*rivertype.Que return c.driver.GetExecutor().QueueGet(ctx, name) } +// QueueGetTx returns the queue with the given name. If the queue has not recently +// been active or does not exist, returns ErrNotFound. +// +// The provided context is used for the underlying Postgres query and can be +// used to cancel the operation or apply a timeout. +func (c *Client[TTx]) QueueGetTx(ctx context.Context, tx TTx, name string) (*rivertype.Queue, error) { + return c.driver.UnwrapExecutor(tx).QueueGet(ctx, name) +} + // QueueListResult is the result of a job list operation. It contains a list of // jobs and leaves room for future cursor functionality. type QueueListResult struct { @@ -1635,8 +1644,31 @@ func (c *Client[TTx]) QueueList(ctx context.Context, params *QueueListParams) (* return nil, err } - listRes := &QueueListResult{Queues: queues} - return listRes, nil + return &QueueListResult{Queues: queues}, nil +} + +// QueueListTx returns a list of all queues that are currently active or were +// recently active. Limit and offset can be used to paginate the results. +// +// The provided context is used for the underlying Postgres query and can be +// used to cancel the operation or apply a timeout. +// +// params := river.NewQueueListParams().First(10) +// queueRows, err := client.QueueListTx(ctx, tx, params) +// if err != nil { +// // handle error +// } +func (c *Client[TTx]) QueueListTx(ctx context.Context, tx TTx, params *QueueListParams) (*QueueListResult, error) { + if params == nil { + params = NewQueueListParams() + } + + queues, err := c.driver.UnwrapExecutor(tx).QueueList(ctx, int(params.paginationCount)) + if err != nil { + return nil, err + } + + return &QueueListResult{Queues: queues}, nil } // QueuePause pauses the queue with the given name. When a queue is paused, @@ -1668,6 +1700,31 @@ func (c *Client[TTx]) QueuePause(ctx context.Context, name string, opts *QueuePa return tx.Commit(ctx) } +// QueuePauseTx pauses the queue with the given name. When a queue is paused, +// clients will not fetch any more jobs for that particular queue. To pause all +// queues at once, use the special queue name "*". +// +// Clients with a configured notifier should receive a notification about the +// paused queue(s) within a few milliseconds of the transaction commit. Clients +// in poll-only mode will pause after their next poll for queue configuration. +// +// The provided context is used for the underlying Postgres update and can be +// used to cancel the operation or apply a timeout. The opts are reserved for +// future functionality. +func (c *Client[TTx]) QueuePauseTx(ctx context.Context, tx TTx, name string, opts *QueuePauseOpts) error { + executorTx := c.driver.UnwrapExecutor(tx) + + if err := executorTx.QueuePause(ctx, name); err != nil { + return err + } + + if err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionPause, name, opts); err != nil { + return err + } + + return nil +} + // QueueResume resumes the queue with the given name. If the queue was // previously paused, any clients configured to work that queue will resume // fetching additional jobs. To resume all queues at once, use the special queue @@ -1698,6 +1755,32 @@ func (c *Client[TTx]) QueueResume(ctx context.Context, name string, opts *QueueP return tx.Commit(ctx) } +// QueueResume resumes the queue with the given name. If the queue was +// previously paused, any clients configured to work that queue will resume +// fetching additional jobs. To resume all queues at once, use the special queue +// name "*". +// +// Clients with a configured notifier should receive a notification about the +// resumed queue(s) within a few milliseconds of the transaction commit. Clients +// in poll-only mode will resume after their next poll for queue configuration. +// +// The provided context is used for the underlying Postgres update and can be +// used to cancel the operation or apply a timeout. The opts are reserved for +// future functionality. +func (c *Client[TTx]) QueueResumeTx(ctx context.Context, tx TTx, name string, opts *QueuePauseOpts) error { + executorTx := c.driver.UnwrapExecutor(tx) + + if err := executorTx.QueueResume(ctx, name); err != nil { + return err + } + + if err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionResume, name, opts); err != nil { + return err + } + + return nil +} + // Generates a default client ID using the current hostname and time. func defaultClientID(startedAt time.Time) string { host, _ := os.Hostname() diff --git a/client_test.go b/client_test.go index b7815f26..009d1a27 100644 --- a/client_test.go +++ b/client_test.go @@ -572,6 +572,36 @@ func Test_Client(t *testing.T) { require.Equal(t, insertRes2.Job.ID, event.Job.ID) }) + t.Run("PauseAndResumeSingleQueueTx", func(t *testing.T) { + t.Parallel() + + config, bundle := setupConfig(t) + client := newTestClient(t, bundle.dbPool, config) + + queue := testfactory.Queue(ctx, t, client.driver.GetExecutor(), nil) + + tx, err := bundle.dbPool.Begin(ctx) + require.NoError(t, err) + t.Cleanup(func() { tx.Rollback(ctx) }) + + require.NoError(t, client.QueuePauseTx(ctx, tx, queue.Name, nil)) + + queueRes, err := client.QueueGetTx(ctx, tx, queue.Name) + require.NoError(t, err) + require.WithinDuration(t, time.Now(), *queueRes.PausedAt, 2*time.Second) + + // Not paused outside transaction. + queueRes, err = client.QueueGet(ctx, queue.Name) + require.NoError(t, err) + require.Nil(t, queueRes.PausedAt) + + require.NoError(t, client.QueueResumeTx(ctx, tx, queue.Name, nil)) + + queueRes, err = client.QueueGetTx(ctx, tx, queue.Name) + require.NoError(t, err) + require.Nil(t, queueRes.PausedAt) + }) + t.Run("PausedBeforeStart", func(t *testing.T) { t.Parallel() @@ -2691,18 +2721,68 @@ func Test_Client_QueueGet(t *testing.T) { client, _ := setup(t) - now := time.Now().UTC() - insertedQueue := testfactory.Queue(ctx, t, client.driver.GetExecutor(), nil) + queue := testfactory.Queue(ctx, t, client.driver.GetExecutor(), nil) - queue, err := client.QueueGet(ctx, insertedQueue.Name) + queueRes, err := client.QueueGet(ctx, queue.Name) require.NoError(t, err) - require.NotNil(t, queue) + require.WithinDuration(t, time.Now(), queueRes.CreatedAt, 2*time.Second) + require.WithinDuration(t, queue.CreatedAt, queueRes.CreatedAt, time.Millisecond) + require.Equal(t, []byte("{}"), queueRes.Metadata) + require.Equal(t, queue.Name, queueRes.Name) + require.Nil(t, queueRes.PausedAt) + }) + + t.Run("ReturnsErrNotFoundIfQueueDoesNotExist", func(t *testing.T) { + t.Parallel() - require.WithinDuration(t, now, queue.CreatedAt, 2*time.Second) - require.WithinDuration(t, insertedQueue.CreatedAt, queue.CreatedAt, time.Millisecond) - require.Equal(t, []byte("{}"), queue.Metadata) - require.Equal(t, insertedQueue.Name, queue.Name) - require.Nil(t, queue.PausedAt) + client, _ := setup(t) + + queueRes, err := client.QueueGet(ctx, "a_queue_that_does_not_exist") + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) + require.Nil(t, queueRes) + }) +} + +func Test_Client_QueueGetTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + executorTx riverdriver.ExecutorTx + tx pgx.Tx + } + + setup := func(t *testing.T) (*Client[pgx.Tx], *testBundle) { + t.Helper() + + dbPool := riverinternaltest.TestDB(ctx, t) + config := newTestConfig(t, nil) + client := newTestClient(t, dbPool, config) + + tx, err := dbPool.Begin(ctx) + require.NoError(t, err) + t.Cleanup(func() { tx.Rollback(ctx) }) + + return client, &testBundle{executorTx: client.driver.UnwrapExecutor(tx), tx: tx} + } + + t.Run("FetchesAnExistingQueue", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + queue := testfactory.Queue(ctx, t, bundle.executorTx, nil) + + queueRes, err := client.QueueGetTx(ctx, bundle.tx, queue.Name) + require.NoError(t, err) + require.Equal(t, queue.Name, queueRes.Name) + + // Not visible outside of transaction. + _, err = client.QueueGet(ctx, queue.Name) + require.Error(t, err) + require.ErrorIs(t, err, ErrNotFound) }) t.Run("ReturnsErrNotFoundIfQueueDoesNotExist", func(t *testing.T) { @@ -2710,10 +2790,10 @@ func Test_Client_QueueGet(t *testing.T) { client, _ := setup(t) - queue, err := client.QueueGet(ctx, "a_queue_that_does_not_exist") + queueRes, err := client.QueueGet(ctx, "a_queue_that_does_not_exist") require.Error(t, err) require.ErrorIs(t, err, ErrNotFound) - require.Nil(t, queue) + require.Nil(t, queueRes) }) } @@ -2782,6 +2862,53 @@ func Test_Client_QueueList(t *testing.T) { }) } +func Test_Client_QueueListTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + executorTx riverdriver.ExecutorTx + tx pgx.Tx + } + + setup := func(t *testing.T) (*Client[pgx.Tx], *testBundle) { + t.Helper() + + dbPool := riverinternaltest.TestDB(ctx, t) + config := newTestConfig(t, nil) + client := newTestClient(t, dbPool, config) + + tx, err := dbPool.Begin(ctx) + require.NoError(t, err) + t.Cleanup(func() { tx.Rollback(ctx) }) + + return client, &testBundle{executorTx: client.driver.UnwrapExecutor(tx), tx: tx} + } + + t.Run("ListsQueues", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + listRes, err := client.QueueListTx(ctx, bundle.tx, NewQueueListParams()) + require.NoError(t, err) + require.Empty(t, listRes.Queues) + + queue := testfactory.Queue(ctx, t, bundle.executorTx, nil) + + listRes, err = client.QueueListTx(ctx, bundle.tx, NewQueueListParams()) + require.NoError(t, err) + require.Len(t, listRes.Queues, 1) + require.Equal(t, queue.Name, listRes.Queues[0].Name) + + // Not visible outside of transaction. + listRes, err = client.QueueList(ctx, NewQueueListParams()) + require.NoError(t, err) + require.Empty(t, listRes.Queues) + }) +} + func Test_Client_RetryPolicy(t *testing.T) { t.Parallel()