Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package main

import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"

_ "github.com/jackc/pgx/v5/stdlib"
Expand All @@ -23,7 +26,12 @@ var (
)

func main() {
if err := rootCmd.Execute(); err != nil {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
if err := rootCmd.ExecuteContext(ctx); err != nil {
if errors.Is(err, context.Canceled) {
os.Exit(130)
}
os.Exit(1)
}
}
Expand All @@ -48,7 +56,7 @@ func init() {
rootCmd.AddCommand(runCmd)
}

func runBenchmark(_ *cobra.Command, _ []string) error {
func runBenchmark(cmd *cobra.Command, _ []string) error {
logLevel := slog.LevelWarn
if flagVerbose {
logLevel = slog.LevelInfo
Expand Down Expand Up @@ -91,9 +99,11 @@ func runBenchmark(_ *cobra.Command, _ []string) error {
return fmt.Errorf("create docker comparator: %w", err)
}

ctx := context.Background()
ctx := cmd.Context()
defer func() {
if err := docker.Cleanup(ctx); err != nil {
cleanupCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := docker.Cleanup(cleanupCtx); err != nil {
log.Error("final cleanup failed", "err", err)
}
}()
Expand Down
222 changes: 222 additions & 0 deletions tests/e2e_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package tests

import (
"bytes"
"fmt"
"math"
"net"
Expand All @@ -13,7 +14,9 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -93,6 +96,78 @@ func TestRunCommandE2E(t *testing.T) {
assert.Contains(t, diff.AfterText, "Index")
}

func TestRunCommandSIGINTTriggersCleanup(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("integration test uses POSIX shell scripts")
}

repoRoot := repoRoot(t)
composeCmd := requireDockerCompose(t)
requireDockerDaemon(t, composeCmd)

binaryPath := buildBinary(t, repoRoot)
projectDir := t.TempDir()
port := reservePort(t)
reportPath := filepath.Join(projectDir, "report.html")
composeProject := "pgcompareshutdown" + port

writeExecutable(t, filepath.Join(projectDir, "setup.sh"), setupScript())
writeFile(t, filepath.Join(projectDir, ".env"), testEnvFile(port))
writeFile(t, filepath.Join(projectDir, "pgcompare.yaml"), interruptConfigYAML())
writeFile(t, filepath.Join(projectDir, "docker-compose.yml"), dockerComposeWithNamedVolumeYAML())
writeFile(t, filepath.Join(projectDir, "queries_before.sql"), slowQueriesSQL())
writeFile(t, filepath.Join(projectDir, "queries_after.sql"), slowQueriesSQL())
writeFile(t, filepath.Join(projectDir, "schema.sql"), testSchemaSQL())

cmdEnv := append(os.Environ(), "COMPOSE_PROJECT_NAME="+composeProject)
t.Cleanup(func() {
cleanupDockerProject(projectDir, cmdEnv, composeCmd)
})

cmd := exec.Command(
binaryPath,
"run",
"--config", filepath.Join(projectDir, "pgcompare.yaml"),
"--out", reportPath,
"--verbose",
)
cmd.Dir = projectDir
cmd.Env = cmdEnv

var stdout lockedBuffer
var stderr lockedBuffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr

require.NoError(t, cmd.Start())

waitCh := make(chan error, 1)
go func() {
waitCh <- cmd.Wait()
}()

waitForBenchmarkStart(t, composeProject, waitCh, &stdout, &stderr)
require.NoError(t, cmd.Process.Signal(os.Interrupt))

err := waitForCommandExit(t, waitCh, &stdout, &stderr)
var exitErr *exec.ExitError
require.ErrorAs(t, err, &exitErr, combinedOutput(&stdout, &stderr))
assert.Equal(t, 130, exitErr.ExitCode(), combinedOutput(&stdout, &stderr))
assert.NoFileExists(t, reportPath)

assert.Eventually(t, func() bool {
containers, err := dockerProjectContainers(composeProject)
if err != nil {
return false
}
volumes, err := dockerProjectVolumes(composeProject)
if err != nil {
return false
}
return len(containers) == 0 && len(volumes) == 0
}, 10*time.Second, 200*time.Millisecond, combinedOutput(&stdout, &stderr))
}

func repoRoot(t *testing.T) string {
t.Helper()

Expand Down Expand Up @@ -175,6 +250,23 @@ type planDiff struct {
AfterText string
}

type lockedBuffer struct {
mu sync.Mutex
buf bytes.Buffer
}

func (b *lockedBuffer) Write(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.Write(p)
}

func (b *lockedBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.String()
}

func extractReportStat(t *testing.T, html, phase, query string) reportStat {
t.Helper()

Expand Down Expand Up @@ -383,6 +475,26 @@ report:
` + "\n")
}

func interruptConfigYAML() string {
return strings.TrimSpace(`
migration:
env_var: MIGRATION_VERSION
before_version: "1"
after_version: "2"

setup:
command: "./setup.sh"

benchmark:
before_queries: queries_before.sql
after_queries: queries_after.sql
warmup_iterations: 0
iterations: 1
concurrency: 1
repeats: 1
` + "\n")
}

func dockerComposeYAML() string {
return strings.TrimSpace(`
services:
Expand All @@ -402,6 +514,30 @@ services:
` + "\n")
}

func dockerComposeWithNamedVolumeYAML() string {
return strings.TrimSpace(`
services:
postgres:
image: postgres:17-alpine
ports:
- "127.0.0.1:${POSTGRES_PORT}:5432"
environment:
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_DB: ${POSTGRES_DB}
volumes:
- pgdata:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U $$POSTGRES_USER -d $$POSTGRES_DB"]
interval: 1s
timeout: 5s
retries: 30

volumes:
pgdata:
` + "\n")
}

func testQueriesSQL() string {
return strings.TrimSpace(`
-- name: find_active_users
Expand All @@ -413,6 +549,13 @@ LIMIT 5;
` + "\n")
}

func slowQueriesSQL() string {
return strings.TrimSpace(`
-- name: wait_for_interrupt
SELECT pg_sleep(10);
` + "\n")
}

func testSchemaSQL() string {
return strings.TrimSpace(`
CREATE TABLE users (
Expand All @@ -430,3 +573,82 @@ SELECT
FROM generate_series(1, 100000) AS gs;
` + "\n")
}

func waitForBenchmarkStart(t *testing.T, composeProject string, waitCh <-chan error, stdout, stderr *lockedBuffer) {
t.Helper()

deadline := time.Now().Add(60 * time.Second)
for time.Now().Before(deadline) {
select {
case err := <-waitCh:
t.Fatalf("pgcompare exited before benchmark started: %v\n%s", err, combinedOutput(stdout, stderr))
default:
}

output := combinedOutput(stdout, stderr)
containers, err := dockerProjectContainers(composeProject)
require.NoError(t, err, output)
volumes, err := dockerProjectVolumes(composeProject)
require.NoError(t, err, output)

if len(containers) > 0 && len(volumes) > 0 && strings.Contains(output, "Benchmarking 'before'") {
time.Sleep(500 * time.Millisecond)
return
}

time.Sleep(200 * time.Millisecond)
}

t.Fatalf("timed out waiting for benchmark start\n%s", combinedOutput(stdout, stderr))
}

func waitForCommandExit(t *testing.T, waitCh <-chan error, stdout, stderr *lockedBuffer) error {
t.Helper()

select {
case err := <-waitCh:
return err
case <-time.After(45 * time.Second):
t.Fatalf("pgcompare did not exit after SIGINT\n%s", combinedOutput(stdout, stderr))
return nil
}
}

func dockerProjectContainers(composeProject string) ([]string, error) {
return dockerLines("docker", "ps", "-aq", "--filter", "label=com.docker.compose.project="+composeProject)
}

func dockerProjectVolumes(composeProject string) ([]string, error) {
return dockerLines("docker", "volume", "ls", "-q", "--filter", "label=com.docker.compose.project="+composeProject)
}

func dockerLines(args ...string) ([]string, error) {
cmd := exec.Command(args[0], args[1:]...)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("%s: %w: %s", strings.Join(args, " "), err, strings.TrimSpace(string(output)))
}

text := strings.TrimSpace(string(output))
if text == "" {
return nil, nil
}

return strings.Fields(text), nil
}

func combinedOutput(stdout, stderr *lockedBuffer) string {
out := strings.TrimSpace(stdout.String())
err := strings.TrimSpace(stderr.String())

switch {
case out == "" && err == "":
return ""
case out == "":
return "stderr:\n" + err
case err == "":
return "stdout:\n" + out
default:
return "stdout:\n" + out + "\n\nstderr:\n" + err
}
}