From 6438bfc745a204f6a6a5490dc08b43680d32cb7d Mon Sep 17 00:00:00 2001 From: Brandur Date: Tue, 5 May 2026 07:22:20 -0400 Subject: [PATCH] A few tweaks to `QueueBundle.Remove` implementation This builds on #1235 to bring in a few tweaks: * Removing a queue is a blocking operation because it needs to wait for the producer to finish up its jobs and shut down. It'd be better to provide a way for this not to block forever, so here we add a context parameter to `QueueBundle.Remove` similar to the one taken by `Client.Stop`. If the context becomes done before the producer has stopped, `QueueBundle.Remove` falls through with the context's error. * Add a "stress" test case for `QueueBundle.Remove`. It's meant to detect a deadlock or other concurrency bug in case there is one and gives us a little more confidence that what we have here is right. * Rename `addProducer` and `removeProducer` to `producerAdd` and `producerRemove` so they sort more nicely against each other. * Add changelog entry. --- CHANGELOG.md | 4 +++ client.go | 44 +++++++++++++++----------- client_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 109 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 048c1aa5..a9327db1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Add `QueueBundle.Remove` to remove an already added queue/producer. [PR #1235](https://github.com/riverqueue/river/pull/1235) and [PR #1240](https://github.com/riverqueue/river/pull/1240). + ### Fixed - Fix unsafe concurrent producer map access in client. [PR #1236](https://github.com/riverqueue/river/pull/1236). 
diff --git a/client.go b/client.go index c62959c1..fda06a5f 100644 --- a/client.go +++ b/client.go @@ -767,11 +767,11 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client } client.queues = &QueueBundle{ - addProducer: client.addProducer, - removeProducer: client.removeProducer, clientFetchCooldown: config.FetchCooldown, clientFetchPollInterval: config.FetchPollInterval, clientWillExecuteJobs: config.willExecuteJobs(), + producerAdd: client.producerAdd, + producerRemove: client.producerRemove, } baseservice.Init(archetype, &client.baseService) @@ -879,7 +879,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.services = append(client.services, client.elector) for queue, queueConfig := range config.Queues { - if _, err := client.addProducer(queue, queueConfig); err != nil { + if _, err := client.producerAdd(queue, queueConfig); err != nil { return nil, err } } @@ -2177,7 +2177,7 @@ func (c *Client[TTx]) validateJobArgs(args JobArgs) error { return nil } -func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) (*producer, error) { +func (c *Client[TTx]) producerAdd(queueName string, queueConfig QueueConfig) (*producer, error) { c.producersMu.Lock() defer c.producersMu.Unlock() @@ -2210,7 +2210,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) (*p return producer, nil } -func (c *Client[TTx]) removeProducer(queueName string) error { +func (c *Client[TTx]) producerRemove(ctx context.Context, queueName string) error { c.producersMu.Lock() defer c.producersMu.Unlock() @@ -2219,7 +2219,17 @@ func (c *Client[TTx]) removeProducer(queueName string) error { return &QueueNotFoundError{Name: queueName} } - producer.Stop() + shouldStop, stopped, finalizeStop := producer.StopInit() + if shouldStop { + select { + case <-ctx.Done(): + finalizeStop(false) + return ctx.Err() + case <-stopped: + finalizeStop(true) + } + } + delete(c.producersByQueueName, 
queueName) return nil @@ -2812,17 +2822,14 @@ func (c *Client[TTx]) Schema() string { return c.config.Schema } // QueueBundle is a bundle for adding additional queues. It's made accessible // through Client.Queues. type QueueBundle struct { - // Function that adds a producer to the associated client. - addProducer func(queueName string, queueConfig QueueConfig) (*producer, error) - - removeProducer func(queueName string) error - clientFetchCooldown time.Duration clientFetchPollInterval time.Duration clientWillExecuteJobs bool - fetchCtx context.Context //nolint:containedctx + fetchCtx context.Context //nolint:containedctx + producerAdd func(queueName string, queueConfig QueueConfig) (*producer, error) // add producer to associated client + producerRemove func(ctx context.Context, queueName string) error // remove producer from associated client // Mutex that's acquired when client is starting and stopping and when a // queue is being added so that we can be sure that a client is fully @@ -2847,7 +2854,7 @@ func (b *QueueBundle) Add(queueName string, queueConfig QueueConfig) error { b.startStopMu.Lock() defer b.startStopMu.Unlock() - producer, err := b.addProducer(queueName, queueConfig) + producer, err := b.producerAdd(queueName, queueConfig) if err != nil { return err } @@ -2863,13 +2870,14 @@ func (b *QueueBundle) Add(queueName string, queueConfig QueueConfig) error { } // Remove removes a queue from the client, stopping the producer if the client -// is running. The function will block until all jobs currently being worked in -// the queue have completed. This blocking behavior may affect other operations, -// including shutdown timing. +// is running. It waits for any jobs currently being worked in the queue to +// complete before returning. If the provided context is done before the +// producer has stopped, Remove returns the context's error and does not remove +// the queue. 
// // Returns an error if the client is not configured to execute jobs or if the // specified queue does not exist. -func (b *QueueBundle) Remove(queueName string) error { +func (b *QueueBundle) Remove(ctx context.Context, queueName string) error { if !b.clientWillExecuteJobs { return errors.New("client is not configured to execute jobs, cannot remove queue") } @@ -2877,7 +2885,7 @@ func (b *QueueBundle) Remove(queueName string) error { b.startStopMu.Lock() defer b.startStopMu.Unlock() - return b.removeProducer(queueName) + return b.producerRemove(ctx, queueName) } // Generates a default client ID using the current hostname and time. diff --git a/client_test.go b/client_test.go index 551e6160..350d8bfb 100644 --- a/client_test.go +++ b/client_test.go @@ -406,6 +406,34 @@ func Test_Client_Common(t *testing.T) { wg.Wait() }) + t.Run("Queues_Remove_Stress", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + + startClient(ctx, t, client) + riversharedtest.WaitOrTimeout(t, client.baseStartStop.Started()) + + var wg sync.WaitGroup + + for i := range 5 { + wg.Add(1) + workerNum := i + go func() { + defer wg.Done() + + for j := range 5 { + queueName := fmt.Sprintf("stress_queue_%d_%d", workerNum, j) + + require.NoError(t, client.Queues().Add(queueName, QueueConfig{MaxWorkers: 1})) + require.NoError(t, client.Queues().Remove(ctx, queueName)) + } + }() + } + + wg.Wait() + }) + t.Run("Queues_Remove_BeforeStart", func(t *testing.T) { t.Parallel() @@ -427,7 +455,7 @@ func Test_Client_Common(t *testing.T) { }) require.NoError(t, err) - err = client.Queues().Remove(queueName) + err = client.Queues().Remove(ctx, queueName) require.NoError(t, err) startClient(ctx, t, client) @@ -481,7 +509,7 @@ func Test_Client_Common(t *testing.T) { event := riversharedtest.WaitOrTimeout(t, subscribeChan) require.Equal(t, EventKindJobCompleted, event.Kind) - err = client.Queues().Remove(queueName) + err = client.Queues().Remove(ctx, queueName) require.NoError(t, err) insertRes, err := 
client.Insert(ctx, &JobArgs{}, &InsertOpts{ @@ -502,12 +530,56 @@ func Test_Client_Common(t *testing.T) { require.Equal(t, rivertype.JobStateAvailable, job.State) }) + t.Run("Queues_Remove_ContextDone", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + + type JobArgs struct { + testutil.JobArgsReflectKind[JobArgs] + } + + jobStartedChan := make(chan struct{}) + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + close(jobStartedChan) + <-ctx.Done() + return nil + })) + + queueName := "remove_context_done_queue" + require.NoError(t, client.Queues().Add(queueName, QueueConfig{MaxWorkers: 2})) + + startClient(ctx, t, client) + riversharedtest.WaitOrTimeout(t, client.baseStartStop.Started()) + + _, err := client.Insert(ctx, &JobArgs{}, &InsertOpts{Queue: queueName}) + require.NoError(t, err) + + riversharedtest.WaitOrTimeout(t, jobStartedChan) + + // Remove with an already-cancelled context should return immediately + // without removing the queue. + cancelledCtx, cancel := context.WithCancel(ctx) + cancel() + + err = client.Queues().Remove(cancelledCtx, queueName) + require.ErrorIs(t, err, context.Canceled) + + // Remove bailed out before deleting the producer, so the queue is still + // registered. Stop the client and cancel the in-flight job so that the + // worker blocked on ctx.Done() is released. + require.NoError(t, client.StopAndCancel(ctx)) + + // Re-start so startClient's cleanup Stop doesn't fail. 
+ require.NoError(t, client.Start(ctx)) + }) + t.Run("Queues_Remove_NonExistentQueue", func(t *testing.T) { t.Parallel() client, _ := setup(t) - err := client.Queues().Remove("non_existent_queue") + err := client.Queues().Remove(ctx, "non_existent_queue") require.Error(t, err) var queueNotFoundErr *QueueNotFoundError require.ErrorAs(t, err, &queueNotFoundErr) @@ -522,7 +594,7 @@ func Test_Client_Common(t *testing.T) { config.Workers = nil client := newTestClient(t, bundle.dbPool, config) - err := client.Queues().Remove("any_queue") + err := client.Queues().Remove(ctx, "any_queue") require.Error(t, err) require.Contains(t, err.Error(), "client is not configured to execute jobs, cannot remove queue") }) @@ -551,7 +623,7 @@ func Test_Client_Common(t *testing.T) { event := riversharedtest.WaitOrTimeout(t, subscribeChan) require.Equal(t, EventKindJobCompleted, event.Kind) - err = client.Queues().Remove(QueueDefault) + err = client.Queues().Remove(ctx, QueueDefault) require.NoError(t, err) insertRes, err := client.Insert(ctx, &JobArgs{}, nil) @@ -601,7 +673,7 @@ func Test_Client_Common(t *testing.T) { event := riversharedtest.WaitOrTimeout(t, subscribeChan) require.Equal(t, EventKindJobCompleted, event.Kind) - err = client.Queues().Remove(queueName) + err = client.Queues().Remove(ctx, queueName) require.NoError(t, err) err = client.Queues().Add(queueName, QueueConfig{ @@ -634,7 +706,7 @@ func Test_Client_Common(t *testing.T) { require.Equal(t, EventKindJobCompleted, event.Kind) require.Equal(t, insertRes1.Job.ID, event.Job.ID) - require.NoError(t, client.Queues().Remove("test_queue")) + require.NoError(t, client.Queues().Remove(ctx, "test_queue")) insertRes2, err := client.Insert(ctx, &noOpArgs{}, &InsertOpts{Queue: "test_queue"}) require.NoError(t, err)