diff --git a/internal/pgcompare/config.go b/internal/pgcompare/config.go index dbf7a6b..b5b4202 100644 --- a/internal/pgcompare/config.go +++ b/internal/pgcompare/config.go @@ -2,6 +2,7 @@ package pgcompare import ( "fmt" + "net/url" "os" "path/filepath" "strings" @@ -130,11 +131,11 @@ func buildDSN() string { if port == "" { port = "5432" } - return fmt.Sprintf( - "postgres://%s:%s@localhost:%s/%s", - os.Getenv("POSTGRES_USER"), - os.Getenv("POSTGRES_PASSWORD"), - port, - os.Getenv("POSTGRES_DB"), - ) + u := url.URL{ + Scheme: "postgres", + User: url.UserPassword(os.Getenv("POSTGRES_USER"), os.Getenv("POSTGRES_PASSWORD")), + Host: "localhost:" + port, + Path: "/" + os.Getenv("POSTGRES_DB"), + } + return u.String() } diff --git a/internal/pgcompare/config_test.go b/internal/pgcompare/config_test.go index 3e5e9be..7cc08a0 100644 --- a/internal/pgcompare/config_test.go +++ b/internal/pgcompare/config_test.go @@ -104,27 +104,35 @@ func TestConfigValidate(t *testing.T) { func TestBuildDSN(t *testing.T) { tests := []struct { - name string - port string - want string + name, user, pass, db, port, want string }{ { name: "all vars present", - port: "9999", + user: "u", pass: "p", db: "d", port: "9999", want: "postgres://u:p@localhost:9999/d", }, { name: "default port", - port: "", + user: "u", pass: "p", db: "d", port: "", want: "postgres://u:p@localhost:5432/d", }, + { + name: "password with special chars is escaped", + user: "u", pass: "p@ss:w/ord#1", db: "d", port: "5432", + want: "postgres://u:p%40ss%3Aw%2Ford%231@localhost:5432/d", + }, + { + name: "user with special chars is escaped", + user: "u@admin", pass: "p", db: "d", port: "5432", + want: "postgres://u%40admin:p@localhost:5432/d", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Setenv("POSTGRES_USER", "u") - t.Setenv("POSTGRES_PASSWORD", "p") - t.Setenv("POSTGRES_DB", "d") + t.Setenv("POSTGRES_USER", tt.user) + t.Setenv("POSTGRES_PASSWORD", tt.pass) + t.Setenv("POSTGRES_DB", tt.db) t.Setenv("POSTGRES_PORT", tt.port) assert.Equal(t, tt.want, buildDSN())