// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package ca
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/hashicorp/consul/agent/structs"
|
|
"github.com/hashicorp/go-cleanhttp"
|
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
|
)
|
|
|
|
func NewAzureAuthClient(authMethod *structs.VaultAuthMethod) (*VaultAuthClient, error) {
|
|
params := authMethod.Params
|
|
authClient := NewVaultAPIAuthClient(authMethod, "")
|
|
// check for login data already in params (for backwards compability)
|
|
legacyKeys := []string{
|
|
"vm_name", "vmss_name", "resource_group_name", "subscription_id", "jwt",
|
|
}
|
|
if legacyCheck(params, legacyKeys...) {
|
|
return authClient, nil
|
|
}
|
|
|
|
role, ok := params["role"].(string)
|
|
if !ok || strings.TrimSpace(role) == "" {
|
|
return nil, fmt.Errorf("missing 'role' value")
|
|
}
|
|
resource, ok := params["resource"].(string)
|
|
if !ok || strings.TrimSpace(resource) == "" {
|
|
return nil, fmt.Errorf("missing 'resource' value")
|
|
}
|
|
|
|
authClient.LoginDataGen = AzureLoginDataGen
|
|
return authClient, nil
|
|
}
|
|
|
|
// Metadata service settings. These are variables rather than constants so
// that tests can change them.
var (
	// instanceEndpoint is the URL queried for the instance metadata.
	instanceEndpoint = "http://169.254.169.254/metadata/instance"

	// identityEndpoint is the URL queried for the identity token.
	identityEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"

	// apiVersion is sent as the api-version query parameter. Identity
	// metadata needs at least version 2018-02-01.
	apiVersion = "2018-02-01"
)
|
|
|
|
type instanceData struct {
|
|
Compute Compute
|
|
}
|
|
// Compute is the compute section of the instance metadata. Its fields carry
// no JSON tags, so decoding relies on the field names themselves.
type Compute struct {
	// Name is reported to Vault as "vm_name".
	Name string

	// ResourceGroupName is reported to Vault as "resource_group_name".
	ResourceGroupName string

	// SubscriptionID is reported to Vault as "subscription_id".
	SubscriptionID string

	// VMScaleSetName is reported to Vault as "vmss_name".
	VMScaleSetName string
}
|
|
// identityData is the decoding target for the response of the identity
// metadata endpoint.
type identityData struct {
	// AccessToken is read from the "access_token" key and is passed to Vault
	// as the "jwt" login value.
	AccessToken string `json:"access_token"`
}
|
|
|
|
func AzureLoginDataGen(authMethod *structs.VaultAuthMethod) (map[string]any, error) {
|
|
params := authMethod.Params
|
|
role := params["role"].(string)
|
|
metaConf := map[string]string{
|
|
"role": role,
|
|
"resource": params["resource"].(string),
|
|
}
|
|
if objectID, ok := params["object_id"].(string); ok {
|
|
metaConf["object_id"] = objectID
|
|
}
|
|
if clientID, ok := params["client_id"].(string); ok {
|
|
metaConf["client_id"] = clientID
|
|
}
|
|
|
|
// Fetch instance data
|
|
var instance instanceData
|
|
body, err := getMetadataInfo(instanceEndpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = jsonutil.DecodeJSON(body, &instance)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing instance metadata response: %w", err)
|
|
}
|
|
|
|
// Fetch JWT
|
|
var identity identityData
|
|
body, err = getMetadataInfo(identityEndpoint, metaConf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = jsonutil.DecodeJSON(body, &identity)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing instance metadata response: %w", err)
|
|
}
|
|
|
|
data := map[string]interface{}{
|
|
"role": role,
|
|
"vm_name": instance.Compute.Name,
|
|
"vmss_name": instance.Compute.VMScaleSetName,
|
|
"resource_group_name": instance.Compute.ResourceGroupName,
|
|
"subscription_id": instance.Compute.SubscriptionID,
|
|
"jwt": identity.AccessToken,
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
func getMetadataInfo(endpoint string, query map[string]string) ([]byte, error) {
|
|
req, err := http.NewRequest("GET", endpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
q := req.URL.Query()
|
|
q.Add("api-version", apiVersion)
|
|
for k, v := range query {
|
|
q.Add(k, v)
|
|
}
|
|
req.URL.RawQuery = q.Encode()
|
|
req.Header.Set("Metadata", "true")
|
|
req.Header.Set("User-Agent", "Consul")
|
|
|
|
client := cleanhttp.DefaultClient()
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error fetching metadata from %s: %w", endpoint, err)
|
|
}
|
|
|
|
if resp == nil {
|
|
return nil, fmt.Errorf("empty response fetching metadata from %s", endpoint)
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading metadata from %s: %w", endpoint, err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("error response in metadata from %s: %s", endpoint, body)
|
|
}
|
|
|
|
return body, nil
|
|
}
|