diff --git a/config/config.go b/config/config.go index 768576123..c4db67158 100644 --- a/config/config.go +++ b/config/config.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "reflect" + "strconv" "strings" "sync" "time" @@ -625,6 +626,20 @@ func (c *Config) fixHostIfNeeded() error { if parsedHost.Hostname() == "" { return ErrNoHostConfigured } + // SPOG URLs pasted from the Databricks UI carry the workspace ID as + // ?o= (or ?workspace_id=) and the account ID as ?a= (or ?account_id=). + // Promote those into WorkspaceID/AccountID before we strip the query, + // so requests get the X-Databricks-Org-Id header instead of hitting the + // SPOG without routing and getting back the login HTML page. + if parsedHost.RawQuery != "" { + q := parsedHost.Query() + if c.WorkspaceID == "" { + c.WorkspaceID = workspaceIDFromQuery(q) + } + if c.AccountID == "" { + c.AccountID = accountIDFromQuery(q) + } + } // Create new instance to ensure other fields are initialized as empty. parsedHost = &url.URL{ Scheme: parsedHost.Scheme, @@ -635,6 +650,28 @@ func (c *Config) fixHostIfNeeded() error { return nil } +func workspaceIDFromQuery(q url.Values) string { + for _, key := range []string{"o", "workspace_id"} { + v := q.Get(key) + if v == "" { + continue + } + if _, err := strconv.ParseInt(v, 10, 64); err == nil { + return v + } + } + return "" +} + +func accountIDFromQuery(q url.Values) string { + for _, key := range []string{"a", "account_id"} { + if v := q.Get(key); v != "" { + return v + } + } + return "" +} + // ErrNoHostConfigured is the error returned when a user tries to authenticate // without a host configured. Applications can check for this error to provide // more user-friendly error messages. diff --git a/config/config_test.go b/config/config_test.go index 05f57f280..eaf8661a6 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1307,3 +1307,79 @@ func TestDefaultHostMetadataResolverFactory_NilResolverFromFactoryFallsThroughTo assert.Equal(t, testHMAccountID, cfg.AccountID) } + +func TestConfig_fixHostIfNeeded_extractsWorkspaceIDFromQuery(t *testing.T) { + tests := []struct { + name string + host string + workspaceID string + accountID string + wantHost string + wantWorkspaceID string + wantAccountID string + }{ + { + name: "?o= promoted to WorkspaceID", + host: "https://acme.databricks.net/?o=12345", + wantHost: "https://acme.databricks.net", + wantWorkspaceID: "12345", + }, + { + name: "?workspace_id= promoted to WorkspaceID", + host: "https://acme.databricks.net/?workspace_id=12345", + wantHost: "https://acme.databricks.net", + wantWorkspaceID: "12345", + }, + { + name: "?a= promoted to AccountID", + host: "https://acme.databricks.net/?a=abc", + wantHost: "https://acme.databricks.net", + wantAccountID: "abc", + }, + { + name: "?o= and ?a= both promoted", + host: "https://acme.databricks.net/?o=12345&a=abc", + wantHost: "https://acme.databricks.net", + wantWorkspaceID: "12345", + wantAccountID: "abc", + }, + { + name: "existing WorkspaceID is preserved", + host: "https://acme.databricks.net/?o=12345", + workspaceID: "99999", + wantHost: "https://acme.databricks.net", + wantWorkspaceID: "99999", + }, + { + name: "existing AccountID is preserved", + host: "https://acme.databricks.net/?a=other", + accountID: "kept", + wantHost: "https://acme.databricks.net", + wantAccountID: "kept", + }, + { + name: "non-numeric ?o= is dropped", + host: "https://acme.databricks.net/?o=notanumber", + wantHost: "https://acme.databricks.net", + }, + { + name: "host without query is unchanged", + host: "https://acme.databricks.net", + wantHost: "https://acme.databricks.net", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{ + Host: tt.host, + WorkspaceID: tt.workspaceID, + AccountID: tt.accountID, + } + require.NoError(t, cfg.fixHostIfNeeded()) + assert.Equal(t, tt.wantHost, cfg.Host) + assert.Equal(t, tt.wantWorkspaceID, cfg.WorkspaceID) + assert.Equal(t, tt.wantAccountID, cfg.AccountID) + }) + } +}