// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package azure import ( "context" "encoding/json" "fmt" "io/ioutil" "net/http" "net/url" "time" "github.com/hashicorp/vault/api" ) type AzureAuth struct { roleName string mountPath string resource string } var _ api.AuthMethod = (*AzureAuth)(nil) type LoginOption func(a *AzureAuth) error type responseJSON struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn string `json:"expires_in"` ExpiresOn string `json:"expires_on"` NotBefore string `json:"not_before"` Resource string `json:"resource"` TokenType string `json:"token_type"` } type errorJSON struct { Error string `json:"error"` ErrorDescription string `json:"error_description"` } type metadataJSON struct { Compute computeJSON `json:"compute"` } type computeJSON struct { VMName string `json:"name"` VMScaleSetName string `json:"vmScaleSetName"` SubscriptionID string `json:"subscriptionId"` ResourceGroupName string `json:"resourceGroupName"` } const ( defaultMountPath = "azure" defaultResourceURL = "https://management.azure.com/" metadataEndpoint = "http://169.254.169.254" metadataAPIVersion = "2021-05-01" apiVersionQueryParam = "api-version" resourceQueryParam = "resource" clientTimeout = 10 * time.Second ) // NewAzureAuth initializes a new Azure auth method interface to be // passed as a parameter to the client.Auth().Login method. // // Supported options: WithMountPath, WithResource func NewAzureAuth(roleName string, opts ...LoginOption) (*AzureAuth, error) { if roleName == "" { return nil, fmt.Errorf("no role name provided for login") } a := &AzureAuth{ roleName: roleName, mountPath: defaultMountPath, resource: defaultResourceURL, } // Loop through each option for _, opt := range opts { // Call the option giving the instantiated // *AzureAuth as the argument err := opt(a) if err != nil { return nil, fmt.Errorf("error with login option: %w", err) } } // return the modified auth struct instance return a, nil } // Login sets up the required request body for the Azure auth method's /login // endpoint, and performs a write to it. func (a *AzureAuth) Login(ctx context.Context, client *api.Client) (*api.Secret, error) { if ctx == nil { ctx = context.Background() } jwtResp, err := a.getJWT() if err != nil { return nil, fmt.Errorf("unable to get access token: %w", err) } metadataRespJSON, err := getMetadata() if err != nil { return nil, fmt.Errorf("unable to get instance metadata: %w", err) } loginData := map[string]interface{}{ "role": a.roleName, "jwt": jwtResp, "vm_name": metadataRespJSON.Compute.VMName, "vmss_name": metadataRespJSON.Compute.VMScaleSetName, "subscription_id": metadataRespJSON.Compute.SubscriptionID, "resource_group_name": metadataRespJSON.Compute.ResourceGroupName, } path := fmt.Sprintf("auth/%s/login", a.mountPath) resp, err := client.Logical().WriteWithContext(ctx, path, loginData) if err != nil { return nil, fmt.Errorf("unable to log in with Azure auth: %w", err) } return resp, nil } func WithMountPath(mountPath string) LoginOption { return func(a *AzureAuth) error { a.mountPath = mountPath return nil } } // WithResource allows you to specify a different resource URL to use as the aud value // on the JWT token than the default of Azure Public Cloud's ARM URL. // This should match the resource URI that an administrator configured your // Vault server to use. // // See https://github.com/Azure/go-autorest/blob/master/autorest/azure/environments.go // for a list of valid environments. func WithResource(url string) LoginOption { return func(a *AzureAuth) error { a.resource = url return nil } } // Retrieves an access token from Managed Identities for Azure Resources // // Learn more here: https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token func (a *AzureAuth) getJWT() (string, error) { identityEndpoint, err := url.Parse(fmt.Sprintf("%s/metadata/identity/oauth2/token", metadataEndpoint)) if err != nil { return "", fmt.Errorf("error creating metadata URL: %w", err) } identityParameters := identityEndpoint.Query() identityParameters.Add(apiVersionQueryParam, metadataAPIVersion) identityParameters.Add(resourceQueryParam, a.resource) identityEndpoint.RawQuery = identityParameters.Encode() req, err := http.NewRequest(http.MethodGet, identityEndpoint.String(), nil) if err != nil { return "", fmt.Errorf("error creating HTTP request: %w", err) } req.Header.Add("Metadata", "true") client := &http.Client{ Timeout: clientTimeout, } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("error calling Azure token endpoint: %w", err) } defer resp.Body.Close() responseBytes, err := ioutil.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("error reading response body from Azure token endpoint: %w", err) } if resp.StatusCode != http.StatusOK { var errResp errorJSON err = json.Unmarshal(responseBytes, &errResp) if err != nil { return "", fmt.Errorf("received error message but was unable to unmarshal its contents") } return "", fmt.Errorf("%s error from Azure token endpoint: %s", errResp.Error, errResp.ErrorDescription) } var r responseJSON err = json.Unmarshal(responseBytes, &r) if err != nil { return "", fmt.Errorf("error unmarshaling response from Azure token endpoint: %w", err) } return r.AccessToken, nil } func getMetadata() (metadataJSON, error) { metadataEndpoint, err := url.Parse(fmt.Sprintf("%s/metadata/instance", metadataEndpoint)) if err != nil { return metadataJSON{}, err } metadataParameters := metadataEndpoint.Query() metadataParameters.Add(apiVersionQueryParam, metadataAPIVersion) metadataEndpoint.RawQuery = metadataParameters.Encode() req, err := http.NewRequest(http.MethodGet, metadataEndpoint.String(), nil) if err != nil { return metadataJSON{}, fmt.Errorf("error creating HTTP Request for metadata endpoint: %w", err) } req.Header.Add("Metadata", "true") client := &http.Client{ Timeout: clientTimeout, } resp, err := client.Do(req) if err != nil { return metadataJSON{}, fmt.Errorf("error calling metadata endpoint: %w", err) } defer resp.Body.Close() responseBytes, err := ioutil.ReadAll(resp.Body) if err != nil { return metadataJSON{}, fmt.Errorf("error reading response body from metadata endpoint: %w", err) } if resp.StatusCode != http.StatusOK { var errResp errorJSON _ = json.Unmarshal(responseBytes, &errResp) if err != nil { return metadataJSON{}, fmt.Errorf("received error message but was unable to unmarshal its contents") } return metadataJSON{}, fmt.Errorf("%s error from metadata endpoint: %s", errResp.Error, errResp.ErrorDescription) } var r metadataJSON err = json.Unmarshal(responseBytes, &r) if err != nil { return metadataJSON{}, fmt.Errorf("error unmarshaling the response from metadata endpoint: %w", err) } return r, nil }