From 52fd81b5472f4d57c9673b1cd9aa32afc511e8ac Mon Sep 17 00:00:00 2001 From: redkarasik Date: Sun, 19 Apr 2026 21:03:53 +0300 Subject: [PATCH] fix: add graceful shutdown for docker cleanup. --- main.go | 18 +++- tests/e2e_run_test.go | 222 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 236 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index 54d6826..b115ec2 100644 --- a/main.go +++ b/main.go @@ -2,10 +2,13 @@ package main import ( "context" + "errors" "fmt" "log/slog" "os" + "os/signal" "path/filepath" + "syscall" "time" _ "github.com/jackc/pgx/v5/stdlib" @@ -23,7 +26,12 @@ var ( ) func main() { - if err := rootCmd.Execute(); err != nil { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + if err := rootCmd.ExecuteContext(ctx); err != nil { + if errors.Is(err, context.Canceled) { + os.Exit(130) + } os.Exit(1) } } @@ -48,7 +56,7 @@ func init() { rootCmd.AddCommand(runCmd) } -func runBenchmark(_ *cobra.Command, _ []string) error { +func runBenchmark(cmd *cobra.Command, _ []string) error { logLevel := slog.LevelWarn if flagVerbose { logLevel = slog.LevelInfo @@ -91,9 +99,11 @@ func runBenchmark(_ *cobra.Command, _ []string) error { return fmt.Errorf("create docker comparator: %w", err) } - ctx := context.Background() + ctx := cmd.Context() defer func() { - if err := docker.Cleanup(ctx); err != nil { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := docker.Cleanup(cleanupCtx); err != nil { log.Error("final cleanup failed", "err", err) } }() diff --git a/tests/e2e_run_test.go b/tests/e2e_run_test.go index 148c64c..d35fbb3 100644 --- a/tests/e2e_run_test.go +++ b/tests/e2e_run_test.go @@ -3,6 +3,7 @@ package tests import ( + "bytes" "fmt" "math" "net" @@ -13,7 +14,9 @@ import ( "runtime" "strconv" "strings" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -93,6 +96,78 @@ func TestRunCommandE2E(t *testing.T) { assert.Contains(t, diff.AfterText, "Index") } +func TestRunCommandSIGINTTriggersCleanup(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("integration test uses POSIX shell scripts") + } + + repoRoot := repoRoot(t) + composeCmd := requireDockerCompose(t) + requireDockerDaemon(t, composeCmd) + + binaryPath := buildBinary(t, repoRoot) + projectDir := t.TempDir() + port := reservePort(t) + reportPath := filepath.Join(projectDir, "report.html") + composeProject := "pgcompareshutdown" + port + + writeExecutable(t, filepath.Join(projectDir, "setup.sh"), setupScript()) + writeFile(t, filepath.Join(projectDir, ".env"), testEnvFile(port)) + writeFile(t, filepath.Join(projectDir, "pgcompare.yaml"), interruptConfigYAML()) + writeFile(t, filepath.Join(projectDir, "docker-compose.yml"), dockerComposeWithNamedVolumeYAML()) + writeFile(t, filepath.Join(projectDir, "queries_before.sql"), slowQueriesSQL()) + writeFile(t, filepath.Join(projectDir, "queries_after.sql"), slowQueriesSQL()) + writeFile(t, filepath.Join(projectDir, "schema.sql"), testSchemaSQL()) + + cmdEnv := append(os.Environ(), "COMPOSE_PROJECT_NAME="+composeProject) + t.Cleanup(func() { + cleanupDockerProject(projectDir, cmdEnv, composeCmd) + }) + + cmd := exec.Command( + binaryPath, + "run", + "--config", filepath.Join(projectDir, "pgcompare.yaml"), + "--out", reportPath, + "--verbose", + ) + cmd.Dir = projectDir + cmd.Env = cmdEnv + + var stdout lockedBuffer + var stderr lockedBuffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + require.NoError(t, cmd.Start()) + + waitCh := make(chan error, 1) + go func() { + waitCh <- cmd.Wait() + }() + + waitForBenchmarkStart(t, composeProject, waitCh, &stdout, &stderr) + require.NoError(t, cmd.Process.Signal(os.Interrupt)) + + err := waitForCommandExit(t, waitCh, &stdout, &stderr) + var exitErr *exec.ExitError + require.ErrorAs(t, err, &exitErr, combinedOutput(&stdout, &stderr)) + assert.Equal(t, 130, exitErr.ExitCode(), combinedOutput(&stdout, &stderr)) + assert.NoFileExists(t, reportPath) + + assert.Eventually(t, func() bool { + containers, err := dockerProjectContainers(composeProject) + if err != nil { + return false + } + volumes, err := dockerProjectVolumes(composeProject) + if err != nil { + return false + } + return len(containers) == 0 && len(volumes) == 0 + }, 10*time.Second, 200*time.Millisecond, combinedOutput(&stdout, &stderr)) +} + func repoRoot(t *testing.T) string { t.Helper() @@ -175,6 +250,23 @@ type planDiff struct { AfterText string } +type lockedBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (b *lockedBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *lockedBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + func extractReportStat(t *testing.T, html, phase, query string) reportStat { t.Helper() @@ -383,6 +475,26 @@ report: ` + "\n") } +func interruptConfigYAML() string { + return strings.TrimSpace(` +migration: + env_var: MIGRATION_VERSION + before_version: "1" + after_version: "2" + +setup: + command: "./setup.sh" + +benchmark: + before_queries: queries_before.sql + after_queries: queries_after.sql + warmup_iterations: 0 + iterations: 1 + concurrency: 1 + repeats: 1 +` + "\n") +} + func dockerComposeYAML() string { return strings.TrimSpace(` services: @@ -402,6 +514,30 @@ services: ` + "\n") } +func dockerComposeWithNamedVolumeYAML() string { + return strings.TrimSpace(` +services: + postgres: + image: postgres:17-alpine + ports: + - "127.0.0.1:${POSTGRES_PORT}:5432" + environment: + POSTGRES_USER: ${POSTGRES_USER} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD} + POSTGRES_DB: ${POSTGRES_DB} + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U $$POSTGRES_USER -d $$POSTGRES_DB"] + interval: 1s + timeout: 5s + retries: 30 + +volumes: + pgdata: +` + "\n") +} + func testQueriesSQL() string { return strings.TrimSpace(` -- name: find_active_users @@ -413,6 +549,13 @@ LIMIT 5; ` + "\n") } +func slowQueriesSQL() string { + return strings.TrimSpace(` +-- name: wait_for_interrupt +SELECT pg_sleep(10); +` + "\n") +} + func testSchemaSQL() string { return strings.TrimSpace(` CREATE TABLE users ( @@ -430,3 +573,82 @@ SELECT FROM generate_series(1, 100000) AS gs; ` + "\n") } + +func waitForBenchmarkStart(t *testing.T, composeProject string, waitCh <-chan error, stdout, stderr *lockedBuffer) { + t.Helper() + + deadline := time.Now().Add(60 * time.Second) + for time.Now().Before(deadline) { + select { + case err := <-waitCh: + t.Fatalf("pgcompare exited before benchmark started: %v\n%s", err, combinedOutput(stdout, stderr)) + default: + } + + output := combinedOutput(stdout, stderr) + containers, err := dockerProjectContainers(composeProject) + require.NoError(t, err, output) + volumes, err := dockerProjectVolumes(composeProject) + require.NoError(t, err, output) + + if len(containers) > 0 && len(volumes) > 0 && strings.Contains(output, "Benchmarking 'before'") { + time.Sleep(500 * time.Millisecond) + return + } + + time.Sleep(200 * time.Millisecond) + } + + t.Fatalf("timed out waiting for benchmark start\n%s", combinedOutput(stdout, stderr)) +} + +func waitForCommandExit(t *testing.T, waitCh <-chan error, stdout, stderr *lockedBuffer) error { + t.Helper() + + select { + case err := <-waitCh: + return err + case <-time.After(45 * time.Second): + t.Fatalf("pgcompare did not exit after SIGINT\n%s", combinedOutput(stdout, stderr)) + return nil + } +} + +func dockerProjectContainers(composeProject string) ([]string, error) { + return dockerLines("docker", "ps", "-aq", "--filter", "label=com.docker.compose.project="+composeProject) +} + +func dockerProjectVolumes(composeProject string) ([]string, error) { + return dockerLines("docker", "volume", "ls", "-q", "--filter", "label=com.docker.compose.project="+composeProject) +} + +func dockerLines(args ...string) ([]string, error) { + cmd := exec.Command(args[0], args[1:]...) + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("%s: %w: %s", strings.Join(args, " "), err, strings.TrimSpace(string(output))) + } + + text := strings.TrimSpace(string(output)) + if text == "" { + return nil, nil + } + + return strings.Fields(text), nil +} + +func combinedOutput(stdout, stderr *lockedBuffer) string { + out := strings.TrimSpace(stdout.String()) + err := strings.TrimSpace(stderr.String()) + + switch { + case out == "" && err == "": + return "" + case out == "": + return "stderr:\n" + err + case err == "": + return "stdout:\n" + out + default: + return "stdout:\n" + out + "\n\nstderr:\n" + err + } +}