diff --git a/main.go b/main.go index 919985b..5f4533e 100644 --- a/main.go +++ b/main.go @@ -76,6 +76,10 @@ type UnsealResponse struct { Progress int `json:"progress"` } +type metadataAccessTokenResponse struct { + AccessToken string `json:"access_token"` +} + func main() { log.Println("Starting the vault-init service...") @@ -166,7 +170,14 @@ func main() { stop() default: } - response, err := httpClient.Head(vaultAddr + "/v1/sys/health") + request, err := newVaultRequest(http.MethodHead, vaultAddr+"/v1/sys/health", nil) + if err != nil { + log.Println(err) + time.Sleep(checkInterval) + continue + } + + response, err := httpClient.Do(request) if response != nil && response.Body != nil { response.Body.Close() @@ -238,7 +249,7 @@ func initialize() { } r := bytes.NewReader(initRequestData) - request, err := http.NewRequest("PUT", vaultAddr+"/v1/sys/init", r) + request, err := newVaultRequest(http.MethodPut, vaultAddr+"/v1/sys/init", r) if err != nil { log.Println(err) return @@ -384,7 +395,7 @@ func unsealOne(key string) (bool, error) { } r := bytes.NewReader(unsealRequestData) - request, err := http.NewRequest(http.MethodPut, vaultAddr+"/v1/sys/unseal", r) + request, err := newVaultRequest(http.MethodPut, vaultAddr+"/v1/sys/unseal", r) if err != nil { return false, err } @@ -461,6 +472,58 @@ func processTLSConfig(cfg *tls.Config, serverName, caCert, caPath string) error return nil } +func newVaultRequest(method, url string, body io.Reader) (*http.Request, error) { + request, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + accessToken, err := accessTokenFromMetadata() + if err != nil { + return nil, fmt.Errorf("failed to get access token from metadata server: %w", err) + } + + request.Header.Set("X-Admin-Token", accessToken) + request.Header.Set("Accept", "application/json") + if method == http.MethodPut { + request.Header.Set("Content-Type", "application/json") + } + + return request, nil +} + +func accessTokenFromMetadata() (string, error) { + const metadataTokenURL = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" + + request, err := http.NewRequest(http.MethodGet, metadataTokenURL, nil) + if err != nil { + return "", err + } + request.Header.Set("Metadata-Flavor", "Google") + + client := &http.Client{Timeout: 5 * time.Second} + response, err := client.Do(request) + if err != nil { + return "", err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", response.StatusCode) + } + + var tokenResponse metadataAccessTokenResponse + if err := json.NewDecoder(response.Body).Decode(&tokenResponse); err != nil { + return "", err + } + + if tokenResponse.AccessToken == "" { + return "", fmt.Errorf("metadata server returned empty access token") + } + + return tokenResponse.AccessToken, nil +} + func boolFromEnv(env string, def bool) bool { val := os.Getenv(env) if val == "" {