open-vault/api/auth/azure/azure.go
Hamid Ghaf 27bb03bbc0
adding copyright header (#19555)
* adding copyright header

* fix fmt and a test
2023-03-15 09:00:52 -07:00

247 lines
7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package azure
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/hashicorp/vault/api"
)
// AzureAuth holds the settings used to log in to Vault through the Azure auth
// method. Instances are created with NewAzureAuth.
type AzureAuth struct {
// roleName is sent as the "role" field of the login request.
roleName string
// mountPath is the mount segment of the login path: auth/<mountPath>/login.
mountPath string
// resource is sent as the "resource" query parameter when requesting the
// access token from the identity endpoint.
resource string
}
// Compile-time check that *AzureAuth satisfies the api.AuthMethod interface.
var _ api.AuthMethod = (*AzureAuth)(nil)
// LoginOption customizes an AzureAuth. NewAzureAuth calls each option in the
// order given and stops at the first one that returns a non-nil error.
type LoginOption func(a *AzureAuth) error
// responseJSON is the structure a successful (HTTP 200) response body from the
// identity token endpoint is decoded into. Within this file only AccessToken
// is read after decoding; the remaining fields are decoded but unused here.
type responseJSON struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn string `json:"expires_in"`
ExpiresOn string `json:"expires_on"`
NotBefore string `json:"not_before"`
Resource string `json:"resource"`
TokenType string `json:"token_type"`
}
// errorJSON is the structure a non-200 response body from either the token
// endpoint or the instance metadata endpoint is decoded into. Both fields are
// interpolated into the error returned to the caller.
type errorJSON struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
// metadataJSON is the structure a successful response body from the instance
// metadata endpoint is decoded into. Only the "compute" section is captured.
type metadataJSON struct {
Compute computeJSON `json:"compute"`
}
// computeJSON holds the "compute" fields of the instance metadata response
// that Login forwards to Vault in the login request.
type computeJSON struct {
// VMName is sent to Vault as "vm_name".
VMName string `json:"name"`
// VMScaleSetName is sent to Vault as "vmss_name".
VMScaleSetName string `json:"vmScaleSetName"`
// SubscriptionID is sent to Vault as "subscription_id".
SubscriptionID string `json:"subscriptionId"`
// ResourceGroupName is sent to Vault as "resource_group_name".
ResourceGroupName string `json:"resourceGroupName"`
}
const (
// defaultMountPath is the mount path used unless WithMountPath overrides it.
defaultMountPath = "azure"
// defaultResourceURL is the resource requested for the access token unless
// WithResource overrides it.
defaultResourceURL = "https://management.azure.com/"
// metadataEndpoint is the base URL for both the identity token request and
// the instance metadata request.
metadataEndpoint = "http://169.254.169.254"
// metadataAPIVersion is the value sent as the api-version query parameter
// on both requests.
metadataAPIVersion = "2021-05-01"
// apiVersionQueryParam is the query parameter name carrying metadataAPIVersion.
apiVersionQueryParam = "api-version"
// resourceQueryParam is the query parameter name carrying the resource URL
// on the token request.
resourceQueryParam = "resource"
// clientTimeout is the http.Client timeout applied to each metadata request.
clientTimeout = 10 * time.Second
)
// NewAzureAuth builds an Azure auth method value suitable for passing to the
// client.Auth().Login method.
//
// roleName must be non-empty. The mount path and resource URL start at their
// package defaults and may be overridden through options, which are applied
// in the order given; the first option to fail aborts construction.
//
// Supported options: WithMountPath, WithResource
func NewAzureAuth(roleName string, opts ...LoginOption) (*AzureAuth, error) {
	if roleName == "" {
		return nil, fmt.Errorf("no role name provided for login")
	}

	auth := &AzureAuth{
		roleName:  roleName,
		mountPath: defaultMountPath,
		resource:  defaultResourceURL,
	}

	// Let every option mutate the freshly built struct in turn.
	for i := range opts {
		if err := opts[i](auth); err != nil {
			return nil, fmt.Errorf("error with login option: %w", err)
		}
	}

	return auth, nil
}
// Login gathers an access token and the instance metadata from the Azure
// metadata service, assembles the request body expected by the Azure auth
// method's /login endpoint, and writes it to auth/<mountPath>/login.
//
// A nil ctx is replaced with context.Background(). The returned secret is
// whatever the write to the login endpoint yields.
func (a *AzureAuth) Login(ctx context.Context, client *api.Client) (*api.Secret, error) {
	if ctx == nil {
		ctx = context.Background()
	}

	token, err := a.getJWT()
	if err != nil {
		return nil, fmt.Errorf("unable to get access token: %w", err)
	}

	metadata, err := getMetadata()
	if err != nil {
		return nil, fmt.Errorf("unable to get instance metadata: %w", err)
	}

	compute := metadata.Compute
	payload := map[string]interface{}{
		"role":                a.roleName,
		"jwt":                 token,
		"vm_name":             compute.VMName,
		"vmss_name":           compute.VMScaleSetName,
		"subscription_id":     compute.SubscriptionID,
		"resource_group_name": compute.ResourceGroupName,
	}

	loginPath := "auth/" + a.mountPath + "/login"
	secret, err := client.Logical().WriteWithContext(ctx, loginPath, payload)
	if err != nil {
		return nil, fmt.Errorf("unable to log in with Azure auth: %w", err)
	}

	return secret, nil
}
// WithMountPath returns a LoginOption that replaces the default mount path
// with mountPath. The value is stored as given and later used to build the
// login path auth/<mountPath>/login. The option itself never fails.
func WithMountPath(mountPath string) LoginOption {
	setMountPath := func(auth *AzureAuth) error {
		auth.mountPath = mountPath
		return nil
	}
	return setMountPath
}
// WithResource returns a LoginOption that overrides the resource URL requested
// for the JWT (its aud value), replacing the default of Azure Public Cloud's
// ARM URL. The value should match the resource URI that an administrator
// configured your Vault server to use. The option itself never fails.
//
// See https://github.com/Azure/go-autorest/blob/master/autorest/azure/environments.go
// for a list of valid environments.
func WithResource(url string) LoginOption {
	setResource := func(auth *AzureAuth) error {
		auth.resource = url
		return nil
	}
	return setResource
}
// getJWT retrieves an access token from Managed Identities for Azure Resources.
//
// It sends a GET to <metadataEndpoint>/metadata/identity/oauth2/token carrying
// the api-version and resource query parameters and the "Metadata: true"
// header, with the client bounded by clientTimeout. A non-200 response is
// decoded as errorJSON and reported as an error; otherwise the access_token
// field of the response is returned.
//
// Learn more here: https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token
func (a *AzureAuth) getJWT() (string, error) {
	tokenURL, err := url.Parse(fmt.Sprintf("%s/metadata/identity/oauth2/token", metadataEndpoint))
	if err != nil {
		return "", fmt.Errorf("error creating metadata URL: %w", err)
	}

	query := tokenURL.Query()
	query.Add(apiVersionQueryParam, metadataAPIVersion)
	query.Add(resourceQueryParam, a.resource)
	tokenURL.RawQuery = query.Encode()

	req, err := http.NewRequest(http.MethodGet, tokenURL.String(), nil)
	if err != nil {
		return "", fmt.Errorf("error creating HTTP request: %w", err)
	}
	req.Header.Add("Metadata", "true")

	httpClient := &http.Client{Timeout: clientTimeout}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("error calling Azure token endpoint: %w", err)
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading response body from Azure token endpoint: %w", err)
	}

	// Anything other than 200 is expected to carry an error document.
	if resp.StatusCode != http.StatusOK {
		var failure errorJSON
		if err := json.Unmarshal(body, &failure); err != nil {
			return "", fmt.Errorf("received error message but was unable to unmarshal its contents")
		}
		return "", fmt.Errorf("%s error from Azure token endpoint: %s", failure.Error, failure.ErrorDescription)
	}

	var token responseJSON
	if err := json.Unmarshal(body, &token); err != nil {
		return "", fmt.Errorf("error unmarshaling response from Azure token endpoint: %w", err)
	}

	return token.AccessToken, nil
}
// getMetadata fetches instance metadata from the metadata service at
// metadataEndpoint and decodes the "compute" section of the response.
//
// It sends a GET to <metadataEndpoint>/metadata/instance carrying the
// api-version query parameter and the "Metadata: true" header, with the
// client bounded by clientTimeout. A non-200 response is decoded as errorJSON
// and reported as an error; if that body cannot be decoded, a generic error is
// returned instead. On any error the returned metadataJSON is the zero value.
func getMetadata() (metadataJSON, error) {
	// Named instanceURL so the local does not shadow the package-level
	// metadataEndpoint constant it is built from.
	instanceURL, err := url.Parse(fmt.Sprintf("%s/metadata/instance", metadataEndpoint))
	if err != nil {
		return metadataJSON{}, fmt.Errorf("error creating metadata URL: %w", err)
	}

	metadataParameters := instanceURL.Query()
	metadataParameters.Add(apiVersionQueryParam, metadataAPIVersion)
	instanceURL.RawQuery = metadataParameters.Encode()

	req, err := http.NewRequest(http.MethodGet, instanceURL.String(), nil)
	if err != nil {
		return metadataJSON{}, fmt.Errorf("error creating HTTP Request for metadata endpoint: %w", err)
	}
	req.Header.Add("Metadata", "true")

	client := &http.Client{
		Timeout: clientTimeout,
	}
	resp, err := client.Do(req)
	if err != nil {
		return metadataJSON{}, fmt.Errorf("error calling metadata endpoint: %w", err)
	}
	defer resp.Body.Close()

	responseBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return metadataJSON{}, fmt.Errorf("error reading response body from metadata endpoint: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		var errResp errorJSON
		// The decode error must be captured here: discarding it left the
		// check below testing the already-nil read error, so an undecodable
		// error body produced an empty, misleading message.
		if err := json.Unmarshal(responseBytes, &errResp); err != nil {
			return metadataJSON{}, fmt.Errorf("received error message but was unable to unmarshal its contents")
		}
		return metadataJSON{}, fmt.Errorf("%s error from metadata endpoint: %s", errResp.Error, errResp.ErrorDescription)
	}

	var r metadataJSON
	if err := json.Unmarshal(responseBytes, &r); err != nil {
		return metadataJSON{}, fmt.Errorf("error unmarshaling the response from metadata endpoint: %w", err)
	}

	return r, nil
}