diff --git a/cmd/sqlc-test-setup/main.go b/cmd/sqlc-test-setup/main.go index 3a816f4502..2a0d04dc5b 100644 --- a/cmd/sqlc-test-setup/main.go +++ b/cmd/sqlc-test-setup/main.go @@ -1,14 +1,43 @@ package main import ( + "crypto/sha256" + "encoding/hex" "fmt" + "io" "log" + "net/http" "os" "os/exec" + "path/filepath" + "runtime" "strings" "time" ) +const ( + // pgVersion is the PostgreSQL version to install. + pgVersion = "18.2.0" +) + +// pgBinary contains the download information for a PostgreSQL binary release. +type pgBinary struct { + URL string + SHA256 string +} + +// pgBinaries maps "/" to the corresponding binary download info. +var pgBinaries = map[string]pgBinary{ + "linux/amd64": { + URL: "https://github.com/theseus-rs/postgresql-binaries/releases/download/" + pgVersion + "/postgresql-" + pgVersion + "-x86_64-unknown-linux-gnu.tar.gz", + SHA256: "cc2674e1641aa2a62b478971a22c131a768eb783f313e6a3385888f58a604074", + }, + "linux/arm64": { + URL: "https://github.com/theseus-rs/postgresql-binaries/releases/download/" + pgVersion + "/postgresql-" + pgVersion + "-aarch64-unknown-linux-gnu.tar.gz", + SHA256: "8b415a11c7a5484e5fbf7a57fca71554d2d1d7acd34faf066606d2fee1261854", + }, +} + func main() { log.SetFlags(log.Ltime) log.SetPrefix("[sqlc-test-setup] ") @@ -77,6 +106,31 @@ func isMySQLVersionOK(versionOutput string) bool { return false } +// pgBaseDir returns the sqlc-specific directory where PostgreSQL is installed, +// using the user's cache directory (~/.cache/sqlc/postgresql on Linux). +func pgBaseDir() string { + cacheDir, err := os.UserCacheDir() + if err != nil { + cacheDir = filepath.Join(os.Getenv("HOME"), ".cache") + } + return filepath.Join(cacheDir, "sqlc", "postgresql") +} + +// pgBinDir returns the path to the PostgreSQL bin directory. +func pgBinDir() string { + return filepath.Join(pgBaseDir(), "bin") +} + +// pgDataDir returns the path to the PostgreSQL data directory. +func pgDataDir() string { + return filepath.Join(pgBaseDir(), "data") +} + +// pgBin returns the full path to a PostgreSQL binary. +func pgBin(name string) string { + return filepath.Join(pgBinDir(), name) +} + // ---- install ---- func runInstall() error { @@ -120,8 +174,15 @@ func installAptProxy() error { func installPostgreSQL() error { log.Println("--- Installing PostgreSQL ---") - if commandExists("psql") { - out, err := runOutput("psql", "--version") + // Install runtime dependencies needed by PostgreSQL extensions (e.g. + // uuid-ossp requires libossp-uuid16). + if err := installPgDeps(); err != nil { + return fmt.Errorf("installing postgresql dependencies: %w", err) + } + + // Check if already installed in our directory + if _, err := os.Stat(pgBin("postgres")); err == nil { + out, err := runOutput(pgBin("postgres"), "--version") if err == nil { log.Printf("postgresql is already installed: %s", strings.TrimSpace(out)) log.Println("skipping postgresql installation") @@ -129,20 +190,117 @@ func installPostgreSQL() error { } } - log.Println("updating apt package lists") - if err := run("sudo", "apt-get", "update", "-qq"); err != nil { - return fmt.Errorf("apt-get update: %w", err) + platform := runtime.GOOS + "/" + runtime.GOARCH + bin, ok := pgBinaries[platform] + if !ok { + return fmt.Errorf("unsupported platform: %s (supported: %s)", platform, supportedPlatforms()) } - log.Println("installing postgresql package") - if err := run("sudo", "apt-get", "install", "-y", "-qq", "postgresql"); err != nil { - return fmt.Errorf("apt-get install postgresql: %w", err) + // Download to a temp file + tarball := filepath.Join(os.TempDir(), fmt.Sprintf("postgresql-%s.tar.gz", pgVersion)) + + if _, err := os.Stat(tarball); err != nil { + log.Printf("downloading PostgreSQL %s from %s", pgVersion, bin.URL) + if err := downloadFile(tarball, bin.URL); err != nil { + os.Remove(tarball) + return fmt.Errorf("downloading postgresql: %w", err) + } + } else { + log.Printf("postgresql tarball already downloaded at %s", tarball) } - log.Println("postgresql installed successfully") + // Verify SHA256 checksum + log.Printf("verifying SHA256 checksum") + actualHash, err := sha256File(tarball) + if err != nil { + return fmt.Errorf("computing sha256: %w", err) + } + if actualHash != bin.SHA256 { + os.Remove(tarball) + return fmt.Errorf("SHA256 mismatch: expected %s, got %s", bin.SHA256, actualHash) + } + log.Printf("SHA256 checksum verified: %s", actualHash) + + baseDir := pgBaseDir() + + // Create the base directory in the user cache + if err := os.MkdirAll(baseDir, 0o755); err != nil { + return fmt.Errorf("creating %s: %w", baseDir, err) + } + + // Extract the tarball - it contains a top-level directory like + // postgresql-18.2.0-x86_64-unknown-linux-gnu/ with bin/, lib/, share/ inside. + // We strip that top-level directory and extract directly into the base dir. + log.Printf("extracting postgresql to %s", baseDir) + if err := run("tar", "-xzf", tarball, "-C", baseDir, "--strip-components=1"); err != nil { + return fmt.Errorf("extracting postgresql: %w", err) + } + + // Verify the binary works + out, err := runOutput(pgBin("postgres"), "--version") + if err != nil { + return fmt.Errorf("postgres --version failed after install: %w", err) + } + log.Printf("postgresql installed successfully: %s", strings.TrimSpace(out)) + return nil +} + +// installPgDeps installs shared libraries required by PostgreSQL extensions at +// runtime (e.g. libossp-uuid16 for uuid-ossp). +func installPgDeps() error { + log.Println("installing postgresql runtime dependencies") + if err := run("sudo", "apt-get", "install", "-y", "--no-install-recommends", "libossp-uuid16"); err != nil { + return fmt.Errorf("apt-get install libossp-uuid16: %w", err) + } return nil } +// supportedPlatforms returns a comma-separated list of supported platforms. +func supportedPlatforms() string { + platforms := make([]string, 0, len(pgBinaries)) + for p := range pgBinaries { + platforms = append(platforms, p) + } + return strings.Join(platforms, ", ") +} + +// downloadFile downloads a URL to a local file path. +func downloadFile(filepath string, url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status) + } + + out, err := os.Create(filepath) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + return err +} + +// sha256File computes the SHA256 hash of a file and returns the hex string. +func sha256File(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + func installMySQL() error { log.Println("--- Installing MySQL 9 ---") @@ -246,33 +404,98 @@ func runStart() error { func startPostgreSQL() error { log.Println("--- Starting PostgreSQL ---") - log.Println("starting postgresql service") - if err := run("sudo", "service", "postgresql", "start"); err != nil { - return fmt.Errorf("service postgresql start: %w", err) + dataDir := pgDataDir() + logFile := filepath.Join(pgBaseDir(), "postgresql.log") + + // Check if already running + if pgIsReady() { + log.Println("postgresql is already running and accepting connections") + return nil } - log.Println("setting password for postgres user") - if err := run("sudo", "-u", "postgres", "psql", "-c", "ALTER USER postgres PASSWORD 'postgres';"); err != nil { - return fmt.Errorf("setting postgres password: %w", err) + // Initialize data directory if needed + if _, err := os.Stat(filepath.Join(dataDir, "PG_VERSION")); os.IsNotExist(err) { + log.Println("initializing postgresql data directory") + if err := os.MkdirAll(dataDir, 0o700); err != nil { + return fmt.Errorf("creating data directory: %w", err) + } + if err := run(pgBin("initdb"), + "-D", dataDir, + "--username=postgres", + "--auth=trust", + ); err != nil { + return fmt.Errorf("initdb: %w", err) + } + + // Configure pg_hba.conf for md5 password authentication on TCP + hbaPath := filepath.Join(dataDir, "pg_hba.conf") + if err := configurePgHBA(hbaPath); err != nil { + return fmt.Errorf("configuring pg_hba.conf: %w", err) + } + + // Configure postgresql.conf to listen on localhost + confPath := filepath.Join(dataDir, "postgresql.conf") + if err := appendToFile(confPath, + "\n# sqlc-test-setup configuration\n"+ + "listen_addresses = '127.0.0.1'\n"+ + "port = 5432\n", + ); err != nil { + return fmt.Errorf("configuring postgresql.conf: %w", err) + } + } else { + log.Println("postgresql data directory already initialized") } - log.Println("detecting postgresql config directory") - hbaPath, err := detectPgHBAPath() - if err != nil { - return fmt.Errorf("detecting pg_hba.conf path: %w", err) + // Start PostgreSQL using pg_ctl + log.Println("starting postgresql") + if err := run(pgBin("pg_ctl"), + "-D", dataDir, + "-l", logFile, + "-o", fmt.Sprintf("-k %s", dataDir), + "start", + ); err != nil { + return fmt.Errorf("pg_ctl start: %w", err) + } + + // Wait for PostgreSQL to be ready + log.Println("waiting for postgresql to accept connections") + if err := waitForPostgreSQL(30 * time.Second); err != nil { + return fmt.Errorf("postgresql did not start in time: %w", err) + } + + // Set the postgres user password + log.Println("setting password for postgres user") + if err := run(pgBin("psql"), + "-h", "127.0.0.1", + "-U", "postgres", + "-c", "ALTER USER postgres PASSWORD 'postgres';", + ); err != nil { + return fmt.Errorf("setting postgres password: %w", err) } - if err := ensurePgHBAEntry(hbaPath); err != nil { - return fmt.Errorf("configuring pg_hba.conf: %w", err) + // Update pg_hba.conf to require md5 auth now that password is set + hbaPath := filepath.Join(dataDir, "pg_hba.conf") + if err := configurePgHBAWithMD5(hbaPath); err != nil { + return fmt.Errorf("updating pg_hba.conf for md5: %w", err) } + // Reload configuration log.Println("reloading postgresql configuration") - if err := run("sudo", "service", "postgresql", "reload"); err != nil { - return fmt.Errorf("reloading postgresql: %w", err) + if err := run(pgBin("pg_ctl"), "-D", dataDir, "reload"); err != nil { + return fmt.Errorf("pg_ctl reload: %w", err) } + // Verify connection with password log.Println("verifying postgresql connection") - if err := run("bash", "-c", "PGPASSWORD=postgres psql -h 127.0.0.1 -U postgres -c 'SELECT 1;'"); err != nil { + cmd := exec.Command(pgBin("psql"), + "-h", "127.0.0.1", + "-U", "postgres", + "-c", "SELECT 1;", + ) + cmd.Env = append(os.Environ(), "PGPASSWORD=postgres") + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { return fmt.Errorf("postgresql connection test failed: %w", err) } @@ -280,37 +503,56 @@ func startPostgreSQL() error { return nil } -// detectPgHBAPath finds the pg_hba.conf file across different PostgreSQL versions. -func detectPgHBAPath() (string, error) { - out, err := runOutput("bash", "-c", "sudo -u postgres psql -t -c 'SHOW hba_file;'") - if err != nil { - return "", fmt.Errorf("querying hba_file: %w (output: %s)", err, out) - } - path := strings.TrimSpace(out) - if path == "" { - return "", fmt.Errorf("pg_hba.conf path is empty") - } - log.Printf("found pg_hba.conf at %s", path) - return path, nil +// configurePgHBA writes a pg_hba.conf that allows trust auth initially (for +// setting the password), then we switch to md5. +func configurePgHBA(hbaPath string) error { + content := `# pg_hba.conf - generated by sqlc-test-setup +# TYPE DATABASE USER ADDRESS METHOD +local all all trust +host all all 127.0.0.1/32 trust +host all all ::1/128 trust +` + return os.WriteFile(hbaPath, []byte(content), 0o600) } -// ensurePgHBAEntry adds the md5 auth line to pg_hba.conf if it's not already present. -func ensurePgHBAEntry(hbaPath string) error { - hbaLine := "host all all 127.0.0.1/32 md5" +// configurePgHBAWithMD5 rewrites pg_hba.conf to use md5 for TCP connections. +func configurePgHBAWithMD5(hbaPath string) error { + content := `# pg_hba.conf - generated by sqlc-test-setup +# TYPE DATABASE USER ADDRESS METHOD +local all all trust +host all all 127.0.0.1/32 md5 +host all all ::1/128 md5 +` + return os.WriteFile(hbaPath, []byte(content), 0o600) +} - out, err := runOutput("sudo", "cat", hbaPath) +// appendToFile appends text to a file. +func appendToFile(path, text string) error { + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { - return fmt.Errorf("reading pg_hba.conf: %w", err) + return err } + defer f.Close() + _, err = f.WriteString(text) + return err +} - if strings.Contains(out, "127.0.0.1/32 md5") { - log.Println("md5 authentication for 127.0.0.1/32 already configured in pg_hba.conf, skipping") - return nil - } +// pgIsReady checks if PostgreSQL is running and accepting connections. +func pgIsReady() bool { + cmd := exec.Command(pgBin("pg_isready"), "-h", "127.0.0.1", "-p", "5432") + return cmd.Run() == nil +} - log.Printf("enabling md5 authentication in %s", hbaPath) - cmd := fmt.Sprintf("echo '%s' | sudo tee -a %s", hbaLine, hbaPath) - return run("bash", "-c", cmd) +// waitForPostgreSQL polls until PostgreSQL accepts connections or times out. +func waitForPostgreSQL(timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if pgIsReady() { + return nil + } + time.Sleep(500 * time.Millisecond) + } + return fmt.Errorf("timed out after %s waiting for postgresql", timeout) } func startMySQL() error {