331 lines
8.9 KiB
Go
331 lines
8.9 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package azure
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/url"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Azure/azure-storage-blob-go/azblob"
|
|
"github.com/Azure/go-autorest/autorest/adal"
|
|
"github.com/Azure/go-autorest/autorest/azure"
|
|
"github.com/armon/go-metrics"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-secure-stdlib/strutil"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
)
|
|
|
|
const (
	// MaxBlobSize is the size limit, in bytes, for stored values (4 MiB).
	// Put rejects values of this size or larger and uses it as the upload
	// block size.
	MaxBlobSize = 4 << 20

	// MaxListResults is the page size requested from each blob listing call.
	// It equals the current service default but is set explicitly.
	MaxListResults = 5000
)
|
|
|
|
// AzureBackend is a physical backend that stores data
|
|
// within an Azure blob container.
|
|
type AzureBackend struct {
|
|
container *azblob.ContainerURL
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
}
|
|
|
|
// Verify AzureBackend satisfies the correct interfaces
|
|
var _ physical.Backend = (*AzureBackend)(nil)
|
|
|
|
// NewAzureBackend constructs an Azure backend using a pre-existing
|
|
// bucket. Credentials can be provided to the backend, sourced
|
|
// from the environment, via HCL or by using managed identities.
|
|
func NewAzureBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
|
name := os.Getenv("AZURE_BLOB_CONTAINER")
|
|
useMSI := false
|
|
|
|
if name == "" {
|
|
name = conf["container"]
|
|
if name == "" {
|
|
return nil, fmt.Errorf("'container' must be set")
|
|
}
|
|
}
|
|
|
|
accountName := os.Getenv("AZURE_ACCOUNT_NAME")
|
|
if accountName == "" {
|
|
accountName = conf["accountName"]
|
|
if accountName == "" {
|
|
return nil, fmt.Errorf("'accountName' must be set")
|
|
}
|
|
}
|
|
|
|
accountKey := os.Getenv("AZURE_ACCOUNT_KEY")
|
|
if accountKey == "" {
|
|
accountKey = conf["accountKey"]
|
|
if accountKey == "" {
|
|
logger.Info("accountKey not set, using managed identity auth")
|
|
useMSI = true
|
|
}
|
|
}
|
|
|
|
environmentName := os.Getenv("AZURE_ENVIRONMENT")
|
|
if environmentName == "" {
|
|
environmentName = conf["environment"]
|
|
if environmentName == "" {
|
|
environmentName = "AzurePublicCloud"
|
|
}
|
|
}
|
|
|
|
environmentURL := os.Getenv("AZURE_ARM_ENDPOINT")
|
|
if environmentURL == "" {
|
|
environmentURL = conf["arm_endpoint"]
|
|
}
|
|
|
|
var environment azure.Environment
|
|
var URL *url.URL
|
|
var err error
|
|
|
|
testHost := conf["testHost"]
|
|
switch {
|
|
case testHost != "":
|
|
URL = &url.URL{Scheme: "http", Host: testHost, Path: fmt.Sprintf("/%s/%s", accountName, name)}
|
|
default:
|
|
if environmentURL != "" {
|
|
environment, err = azure.EnvironmentFromURL(environmentURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to look up Azure environment descriptor for URL %q: %w", environmentURL, err)
|
|
}
|
|
} else {
|
|
environment, err = azure.EnvironmentFromName(environmentName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to look up Azure environment descriptor for name %q: %w", environmentName, err)
|
|
}
|
|
}
|
|
URL, err = url.Parse(
|
|
fmt.Sprintf("https://%s.blob.%s/%s", accountName, environment.StorageEndpointSuffix, name))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Azure client: %w", err)
|
|
}
|
|
}
|
|
|
|
var credential azblob.Credential
|
|
if useMSI {
|
|
authToken, err := getAuthTokenFromIMDS(environment.ResourceIdentifiers.Storage)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to obtain auth token from IMDS %q: %w", environmentName, err)
|
|
}
|
|
|
|
credential = azblob.NewTokenCredential(authToken.OAuthToken(), func(c azblob.TokenCredential) time.Duration {
|
|
err = authToken.Refresh()
|
|
if err != nil {
|
|
logger.Error("couldn't refresh token credential", "error", err)
|
|
return 0
|
|
}
|
|
|
|
expIn, err := authToken.Token().ExpiresIn.Int64()
|
|
if err != nil {
|
|
logger.Error("couldn't retrieve jwt claim for 'expiresIn' from refreshed token", "error", err)
|
|
return 0
|
|
}
|
|
|
|
logger.Debug("token refreshed, new token expires in", "access_token_expiry", expIn)
|
|
c.SetToken(authToken.OAuthToken())
|
|
|
|
// tokens are valid for 23h59m (86399s) by default, refresh after ~21h
|
|
return time.Duration(int(float64(expIn)*0.9)) * time.Second
|
|
})
|
|
} else {
|
|
credential, err = azblob.NewSharedKeyCredential(accountName, accountKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Azure client: %w", err)
|
|
}
|
|
}
|
|
|
|
p := azblob.NewPipeline(credential, azblob.PipelineOptions{})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
containerURL := azblob.NewContainerURL(*URL, p)
|
|
_, err = containerURL.GetProperties(ctx, azblob.LeaseAccessConditions{})
|
|
if err != nil {
|
|
var e azblob.StorageError
|
|
if errors.As(err, &e) {
|
|
switch e.ServiceCode() {
|
|
case azblob.ServiceCodeContainerNotFound:
|
|
_, err := containerURL.Create(ctx, azblob.Metadata{}, azblob.PublicAccessNone)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create %q container: %w", name, err)
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("failed to get properties for container %q: %w", name, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
maxParStr, ok := conf["max_parallel"]
|
|
var maxParInt int
|
|
if ok {
|
|
maxParInt, err = strconv.Atoi(maxParStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
|
}
|
|
}
|
|
|
|
a := &AzureBackend{
|
|
container: &containerURL,
|
|
logger: logger,
|
|
permitPool: physical.NewPermitPool(maxParInt),
|
|
}
|
|
return a, nil
|
|
}
|
|
|
|
// Put is used to insert or update an entry
|
|
func (a *AzureBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
|
defer metrics.MeasureSince([]string{"azure", "put"}, time.Now())
|
|
|
|
if len(entry.Value) >= MaxBlobSize {
|
|
return fmt.Errorf("value is bigger than the current supported limit of 4MBytes")
|
|
}
|
|
|
|
a.permitPool.Acquire()
|
|
defer a.permitPool.Release()
|
|
|
|
blobURL := a.container.NewBlockBlobURL(entry.Key)
|
|
_, err := azblob.UploadBufferToBlockBlob(ctx, entry.Value, blobURL, azblob.UploadToBlockBlobOptions{
|
|
BlockSize: MaxBlobSize,
|
|
})
|
|
|
|
return err
|
|
}
|
|
|
|
// Get is used to fetch an entry
|
|
func (a *AzureBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
|
defer metrics.MeasureSince([]string{"azure", "get"}, time.Now())
|
|
|
|
a.permitPool.Acquire()
|
|
defer a.permitPool.Release()
|
|
|
|
blobURL := a.container.NewBlockBlobURL(key)
|
|
clientOptions := azblob.ClientProvidedKeyOptions{}
|
|
|
|
res, err := blobURL.Download(ctx, 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false, clientOptions)
|
|
if err != nil {
|
|
var e azblob.StorageError
|
|
if errors.As(err, &e) {
|
|
switch e.ServiceCode() {
|
|
case azblob.ServiceCodeBlobNotFound:
|
|
return nil, nil
|
|
default:
|
|
return nil, fmt.Errorf("failed to download blob %q: %w", key, err)
|
|
}
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
reader := res.Body(azblob.RetryReaderOptions{})
|
|
defer reader.Close()
|
|
|
|
data, err := ioutil.ReadAll(reader)
|
|
|
|
ent := &physical.Entry{
|
|
Key: key,
|
|
Value: data,
|
|
}
|
|
|
|
return ent, err
|
|
}
|
|
|
|
// Delete is used to permanently delete an entry
|
|
func (a *AzureBackend) Delete(ctx context.Context, key string) error {
|
|
defer metrics.MeasureSince([]string{"azure", "delete"}, time.Now())
|
|
|
|
a.permitPool.Acquire()
|
|
defer a.permitPool.Release()
|
|
|
|
blobURL := a.container.NewBlockBlobURL(key)
|
|
_, err := blobURL.Delete(ctx, azblob.DeleteSnapshotsOptionInclude, azblob.BlobAccessConditions{})
|
|
if err != nil {
|
|
var e azblob.StorageError
|
|
if errors.As(err, &e) {
|
|
switch e.ServiceCode() {
|
|
case azblob.ServiceCodeBlobNotFound:
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("failed to delete blob %q: %w", key, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// List is used to list all the keys under a given
|
|
// prefix, up to the next prefix.
|
|
func (a *AzureBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
|
defer metrics.MeasureSince([]string{"azure", "list"}, time.Now())
|
|
|
|
a.permitPool.Acquire()
|
|
defer a.permitPool.Release()
|
|
|
|
var keys []string
|
|
for marker := (azblob.Marker{}); marker.NotDone(); {
|
|
listBlob, err := a.container.ListBlobsFlatSegment(ctx, marker, azblob.ListBlobsSegmentOptions{
|
|
Prefix: prefix,
|
|
MaxResults: MaxListResults,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, blobInfo := range listBlob.Segment.BlobItems {
|
|
key := strings.TrimPrefix(blobInfo.Name, prefix)
|
|
if i := strings.Index(key, "/"); i == -1 {
|
|
// file
|
|
keys = append(keys, key)
|
|
} else {
|
|
// subdirectory
|
|
keys = strutil.AppendIfMissing(keys, key[:i+1])
|
|
}
|
|
}
|
|
|
|
marker = listBlob.NextMarker
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
return keys, nil
|
|
}
|
|
|
|
// getAuthTokenFromIMDS uses the Azure Instance Metadata Service to retrieve a short-lived credential using OAuth
|
|
// more info on this https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
|
|
func getAuthTokenFromIMDS(resource string) (*adal.ServicePrincipalToken, error) {
|
|
msiEndpoint, err := adal.GetMSIVMEndpoint()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
spt, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := spt.Refresh(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
token := spt.Token()
|
|
if token.IsZero() {
|
|
return nil, err
|
|
}
|
|
|
|
return spt, nil
|
|
}
|