Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 66 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ type UnsealResponse struct {
Progress int `json:"progress"`
}

type metadataAccessTokenResponse struct {
AccessToken string `json:"access_token"`
}

func main() {
log.Println("Starting the vault-init service...")

Expand Down Expand Up @@ -166,7 +170,14 @@ func main() {
stop()
default:
}
response, err := httpClient.Head(vaultAddr + "/v1/sys/health")
request, err := newVaultRequest(http.MethodHead, vaultAddr+"/v1/sys/health", nil)
if err != nil {
log.Println(err)
time.Sleep(checkInterval)
continue
}

response, err := httpClient.Do(request)

if response != nil && response.Body != nil {
response.Body.Close()
Expand Down Expand Up @@ -238,7 +249,7 @@ func initialize() {
}

r := bytes.NewReader(initRequestData)
request, err := http.NewRequest("PUT", vaultAddr+"/v1/sys/init", r)
request, err := newVaultRequest(http.MethodPut, vaultAddr+"/v1/sys/init", r)
if err != nil {
log.Println(err)
return
Expand Down Expand Up @@ -384,7 +395,7 @@ func unsealOne(key string) (bool, error) {
}

r := bytes.NewReader(unsealRequestData)
request, err := http.NewRequest(http.MethodPut, vaultAddr+"/v1/sys/unseal", r)
request, err := newVaultRequest(http.MethodPut, vaultAddr+"/v1/sys/unseal", r)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -461,6 +472,58 @@ func processTLSConfig(cfg *tls.Config, serverName, caCert, caPath string) error
return nil
}

func newVaultRequest(method, url string, body io.Reader) (*http.Request, error) {
request, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}

accessToken, err := accessTokenFromMetadata()
if err != nil {
return nil, fmt.Errorf("failed to get access token from metadata server: %w", err)
}

request.Header.Set("X-Admin-Token", accessToken)
request.Header.Set("Accept", "application/json")
if method == http.MethodPut {
request.Header.Set("Content-Type", "application/json")
}

return request, nil
}

func accessTokenFromMetadata() (string, error) {
const metadataTokenURL = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token"

request, err := http.NewRequest(http.MethodGet, metadataTokenURL, nil)
if err != nil {
return "", err
}
request.Header.Set("Metadata-Flavor", "Google")

client := &http.Client{Timeout: 5 * time.Second}
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", response.StatusCode)
}

var tokenResponse metadataAccessTokenResponse
if err := json.NewDecoder(response.Body).Decode(&tokenResponse); err != nil {
return "", err
}

if tokenResponse.AccessToken == "" {
return "", fmt.Errorf("metadata server returned empty access token")
}

return tokenResponse.AccessToken, nil
}

func boolFromEnv(env string, def bool) bool {
val := os.Getenv(env)
if val == "" {
Expand Down