diff --git a/internal/riverinternaltest/riverinternaltest.go b/internal/riverinternaltest/riverinternaltest.go index 0cf89122..818a8c40 100644 --- a/internal/riverinternaltest/riverinternaltest.go +++ b/internal/riverinternaltest/riverinternaltest.go @@ -4,7 +4,6 @@ package riverinternaltest import ( "context" - "errors" "fmt" "log" "net/url" @@ -219,44 +218,7 @@ func TestTx(ctx context.Context, tb testing.TB) pgx.Tx { return dbPool } - tx, err := getPool().Begin(ctx) - require.NoError(tb, err) - - tb.Cleanup(func() { - err := tx.Rollback(ctx) - - if err == nil { - return - } - - // Try to look for an error on rollback because it does occasionally - // reveal a real problem in the way a test is written. However, allow - // tests to roll back their transaction early if they like, so ignore - // `ErrTxClosed`. - if errors.Is(err, pgx.ErrTxClosed) { - return - } - - // In case of a cancelled context during a database operation, which - // happens in many tests, pgx seems to not only roll back the - // transaction, but closes the connection, and returns this error on - // rollback. Allow this error since it's hard to prevent it in our flows - // that use contexts heavily. - if err.Error() == "conn closed" { - return - } - - // Similar to the above, but a newly appeared error that wraps the - // above. As far as I can tell, no error variables are available to use - // with `errors.Is`. - if err.Error() == "failed to deallocate cached statement(s): conn closed" { - return - } - - require.NoError(tb, err) - }) - - return tx + return riversharedtest.TestTxPool(ctx, tb, getPool()) } // TruncateRiverTables truncates River tables in the target database. This is diff --git a/rivershared/go.mod b/rivershared/go.mod index 68873f3c..0c98c612 100644 --- a/rivershared/go.mod +++ b/rivershared/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/rivershared/go.sum b/rivershared/go.sum index 0a87d5d6..8bf1bbac 100644 --- a/rivershared/go.sum +++ b/rivershared/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -8,10 +10,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/riverqueue/river v0.18.0 h1:sGHeTOL9MR8+pMIVHRm59fzet8Ron/xjF3Yq/PSGb78= github.com/riverqueue/river v0.18.0/go.mod h1:oapX5xb/L2YnkE801QubDZ0COHxVxEGVY37icPzghhU= +github.com/riverqueue/river v0.19.0/go.mod h1:YJ7LA2uBdqFHQJzKyYc+X6S04KJeiwsS1yU5a1rynlk= github.com/riverqueue/river/riverdriver v0.18.0 h1:a2haR5I0MQLHjLCSVFpUEeJALCLemRl5zCztucysm1E= github.com/riverqueue/river/riverdriver v0.18.0/go.mod h1:Mj45PbHabEnBv/nSah0J1/tg6hrX/SNeXtcYcSqMzxQ= +github.com/riverqueue/river/riverdriver v0.19.0/go.mod h1:Soxi08hHkEvopExAp6ADG2437r4coSiB4QpuIL5E28k= github.com/riverqueue/river/rivertype v0.18.0 h1:YsXR5NbLAzniurGO0+zcISWMKq7Y71xkIe2oi86OAsE= github.com/riverqueue/river/rivertype v0.18.0/go.mod h1:DETcejveWlq6bAb8tHkbgJqmXWVLiFhTiEm8j7co1bE= +github.com/riverqueue/river/rivertype v0.19.0/go.mod h1:DETcejveWlq6bAb8tHkbgJqmXWVLiFhTiEm8j7co1bE= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/rivershared/riversharedtest/riversharedtest.go b/rivershared/riversharedtest/riversharedtest.go index 6b4ec2e6..3c43a372 100644 --- a/rivershared/riversharedtest/riversharedtest.go +++ b/rivershared/riversharedtest/riversharedtest.go @@ -1,6 +1,9 @@ package riversharedtest import ( + "cmp" + "context" + "errors" "fmt" "log/slog" "os" @@ -8,6 +11,8 @@ import ( "testing" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -27,6 +32,47 @@ func BaseServiceArchetype(tb testing.TB) *baseservice.Archetype { } } +// A pool and mutex to protect it, lazily initialized by TestTx. Once open, this +// pool is never explicitly closed, instead closing implicitly as the package +// tests finish. +var ( + dbPool *pgxpool.Pool //nolint:gochecknoglobals + dbPoolMu sync.RWMutex //nolint:gochecknoglobals +) + +// DBPool gets a lazily initialized database pool for `TEST_DATABASE_URL` or +// `river_test` if the former isn't specified. +func DBPool(ctx context.Context, tb testing.TB) *pgxpool.Pool { + tb.Helper() + + tryPool := func() *pgxpool.Pool { + dbPoolMu.RLock() + defer dbPoolMu.RUnlock() + return dbPool + } + + if dbPool := tryPool(); dbPool != nil { + return dbPool + } + + dbPoolMu.Lock() + defer dbPoolMu.Unlock() + + // Multiple goroutines may have passed the initial `nil` check on start + // up, so check once more to make sure pool hasn't been set yet. + if dbPool != nil { + return dbPool + } + + dbPool, err := pgxpool.New(ctx, cmp.Or( + os.Getenv("TEST_DATABASE_URL"), + "postgres://localhost:5432/river_test", + )) + require.NoError(tb, err) + + return dbPool +} + // Logger returns a logger suitable for use in tests. // // Defaults to informational verbosity. If env is set with `RIVER_DEBUG=true`, @@ -48,6 +94,69 @@ func LoggerWarn(tb testing.TB) *slog.Logger { return slogtest.NewLogger(tb, &slog.HandlerOptions{Level: slog.LevelWarn}) } +// TestTx starts a test transaction that's rolled back automatically as the test +// case is cleaning itself up. +// +// This variant uses the default database pool from DBPool that points to +// `TEST_DATABASE_URL` or `river_test` if the former wasn't specified. +func TestTx(ctx context.Context, tb testing.TB) pgx.Tx { + tb.Helper() + return TestTxPool(ctx, tb, DBPool(ctx, tb)) +} + +// TestTxPool starts a test transaction that's rolled back automatically as the +// test case is cleaning itself up. +// +// This variant starts the test transaction on the specified database pool. +func TestTxPool(ctx context.Context, tb testing.TB, dbPool *pgxpool.Pool) pgx.Tx { + tb.Helper() + + tx, err := dbPool.Begin(ctx) + require.NoError(tb, err) + + tb.Cleanup(func() { + // Tests may inerit context from `t.Context()` which is cancelled after + // tests run and before calling clean up. We need a non-cancelled + // context to issue rollback here, so use a bit of a bludgeon to do so + // with `context.WithoutCancel()`. + ctx := context.WithoutCancel(ctx) + + err := tx.Rollback(ctx) + + if err == nil { + return + } + + // Try to look for an error on rollback because it does occasionally + // reveal a real problem in the way a test is written. However, allow + // tests to roll back their transaction early if they like, so ignore + // `ErrTxClosed`. + if errors.Is(err, pgx.ErrTxClosed) { + return + } + + // In case of a cancelled context during a database operation, which + // happens in many tests, pgx seems to not only roll back the + // transaction, but closes the connection, and returns this error on + // rollback. Allow this error since it's hard to prevent it in our flows + // that use contexts heavily. + if err.Error() == "conn closed" { + return + } + + // Similar to the above, but a newly appeared error that wraps the + // above. As far as I can tell, no error variables are available to use + // with `errors.Is`. + if err.Error() == "failed to deallocate cached statement(s): conn closed" { + return + } + + require.NoError(tb, err) + }) + + return tx +} + // TimeStub implements baseservice.TimeGeneratorWithStub to allow time to be // stubbed in tests. // diff --git a/rivershared/riversharedtest/riversharedtest_test.go b/rivershared/riversharedtest/riversharedtest_test.go index de7b5214..51aee978 100644 --- a/rivershared/riversharedtest/riversharedtest_test.go +++ b/rivershared/riversharedtest/riversharedtest_test.go @@ -1,12 +1,67 @@ package riversharedtest import ( + "context" "testing" "time" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/require" ) +func TestDBPool(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + pool := DBPool(ctx, t) + _, err := pool.Exec(ctx, "SELECT 1") + require.NoError(t, err) +} + +func TestTestTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type PoolOrTx interface { + Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) + } + + checkTestTable := func(ctx context.Context, poolOrTx PoolOrTx) error { + _, err := poolOrTx.Exec(ctx, "SELECT * FROM river_shared_test_tx_table") + return err + } + + // Test cleanups are invoked in the order of last added, first called. When + // TestTx is called below it adds a cleanup, so we want to make sure that + // this cleanup, which checks that the database remains pristine, is invoked + // after the TestTx cleanup, so we add it first. + t.Cleanup(func() { + // Tests may inherit context from `t.Context()` which is cancelled after + // tests run and before calling clean up. We need a non-cancelled + // context to issue rollback here, so use a bit of a bludgeon to do so + // with `context.WithoutCancel()`. + ctx := context.WithoutCancel(ctx) + + err := checkTestTable(ctx, DBPool(ctx, t)) + require.Error(t, err) + + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, pgerrcode.UndefinedTable, pgErr.Code) + }) + + tx := TestTx(ctx, t) + + _, err := tx.Exec(ctx, "CREATE TABLE river_shared_test_tx_table (id bigint)") + require.NoError(t, err) + + err = checkTestTable(ctx, tx) + require.NoError(t, err) +} + func TestWaitOrTimeout(t *testing.T) { t.Parallel()