open-vault/vault/seal/azurekeyvault/azurekeyvault.go
2019-04-12 17:54:35 -04:00

307 lines
8.5 KiB
Go

package azurekeyvault
import (
"context"
"encoding/base64"
"errors"
"fmt"
"os"
"strings"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
"github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/Azure/go-autorest/autorest/to"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/vault/seal"
)
// AzureKeyVaultSeal is an auto-seal that uses Azure Key Vault
// for crypto operations. Azure Key Vault currently does not support
// keys that can encrypt long data (RSA keys). Due to this fact, we generate
// an AES key, wrap that key using Key Vault, and store the wrapped key
// with the data.
type AzureKeyVaultSeal struct {
// tenantID, clientID and clientSecret are the client credentials used to
// build the Key Vault authorizer. When clientID or clientSecret is empty,
// getKeyVaultClient falls back to Managed Service Identity.
tenantID string
clientID string
clientSecret string
// vaultName is the Key Vault name; it forms the host prefix of the base URL.
vaultName string
// keyName is the name of the Key Vault key used to wrap and unwrap.
keyName string
// currentKeyID holds a string: the key version most recently reported by
// Key Vault. NewSeal initializes it to "".
currentKeyID *atomic.Value
// environment is the Azure cloud environment; SetConfig defaults it to
// azure.PublicCloud when none is configured.
environment azure.Environment
// client is the Key Vault client. It stays nil until SetConfig has
// created it and verified access to the key.
client *keyvault.BaseClient
// logger is the logger supplied to NewSeal.
logger log.Logger
}
// Compile-time check that *AzureKeyVaultSeal implements seal.Access.
var _ seal.Access = (*AzureKeyVaultSeal)(nil)
// NewSeal returns an AzureKeyVaultSeal that uses the given logger. The
// current key ID starts out as the empty string, so KeyID can be called
// before any key version is known. The seal has no Key Vault client until
// SetConfig is called.
func NewSeal(logger log.Logger) *AzureKeyVaultSeal {
	keyID := new(atomic.Value)
	keyID.Store("")

	return &AzureKeyVaultSeal{
		logger:       logger,
		currentKeyID: keyID,
	}
}
// SetConfig populates the AzureKeyVaultSeal from the process environment and
// the supplied config map, then (if no client exists yet) creates the Key
// Vault client and verifies it by fetching the configured key.
//
// Order of precedence for each setting:
// * Environment variable
// * Value from Vault configuration file
// * Managed Service Identity for instance (credentials only)
//
// The vault name and key name are mandatory; an error is returned when
// either is missing. On success the returned map holds only non-sensitive
// settings: "environment", "vault_name" and "key_name".
func (v *AzureKeyVaultSeal) SetConfig(config map[string]string) (map[string]string, error) {
	// lookup prefers a non-empty environment variable over the config entry.
	// Reading a nil map yields "", so a nil config needs no special casing.
	lookup := func(envVar, configKey string) string {
		if fromEnv := os.Getenv(envVar); fromEnv != "" {
			return fromEnv
		}
		return config[configKey]
	}

	// Credentials are optional: a field is only overwritten when a value was
	// actually supplied.
	if tenantID := lookup("AZURE_TENANT_ID", "tenant_id"); tenantID != "" {
		v.tenantID = tenantID
	}
	if clientID := lookup("AZURE_CLIENT_ID", "client_id"); clientID != "" {
		v.clientID = clientID
	}
	if clientSecret := lookup("AZURE_CLIENT_SECRET", "client_secret"); clientSecret != "" {
		v.clientSecret = clientSecret
	}

	// Resolve the Azure cloud, falling back to the public cloud.
	if envName := lookup("AZURE_ENVIRONMENT", "environment"); envName != "" {
		var err error
		v.environment, err = azure.EnvironmentFromName(envName)
		if err != nil {
			return nil, err
		}
	} else {
		v.environment = azure.PublicCloud
	}

	vaultName := lookup("VAULT_AZUREKEYVAULT_VAULT_NAME", "vault_name")
	if vaultName == "" {
		return nil, errors.New("vault name is required")
	}
	v.vaultName = vaultName

	keyName := lookup("VAULT_AZUREKEYVAULT_KEY_NAME", "key_name")
	if keyName == "" {
		return nil, errors.New("key name is required")
	}
	v.keyName = keyName

	if v.client == nil {
		kvClient, err := v.getKeyVaultClient()
		if err != nil {
			return nil, errwrap.Wrapf("error initializing Azure Key Vault seal client: {{err}}", err)
		}

		// Fetching the key doubles as a connectivity/permission check; the
		// empty version string matches the original call.
		keyBundle, err := kvClient.GetKey(context.Background(), v.buildBaseURL(), v.keyName, "")
		if err != nil {
			return nil, errwrap.Wrapf("error fetching Azure Key Vault seal key information: {{err}}", err)
		}
		if keyBundle.Key == nil {
			return nil, errors.New("no key information returned")
		}

		// Only publish the client once the key lookup has succeeded.
		v.currentKeyID.Store(parseKeyVersion(to.String(keyBundle.Key.Kid)))
		v.client = kvClient
	}

	// Non-sensitive configuration info only.
	return map[string]string{
		"environment": v.environment.Name,
		"vault_name":  v.vaultName,
		"key_name":    v.keyName,
	}, nil
}
// Init is invoked during core.Initialize. This seal needs no
// initialization, so it always returns nil.
func (v *AzureKeyVaultSeal) Init(_ context.Context) error {
	return nil
}
// Finalize is invoked during shutdown. This seal holds nothing that needs
// releasing, so it always returns nil.
func (v *AzureKeyVaultSeal) Finalize(_ context.Context) error {
	return nil
}
// SealType reports the seal type identifier for this implementation,
// which is seal.AzureKeyVault.
func (v *AzureKeyVaultSeal) SealType() string {
	sealType := seal.AzureKeyVault
	return sealType
}
// KeyID reports the key version most recently stored by SetConfig or
// Encrypt. It is the empty string until one of them has stored a version.
func (v *AzureKeyVaultSeal) KeyID() string {
	latest := v.currentKeyID.Load()
	return latest.(string)
}
// Encrypt is used to encrypt plaintext using Azure Key Vault.
//
// The plaintext is envelope-encrypted locally via seal.NewEnvelope, and the
// resulting data encryption key (DEK) is wrapped by Key Vault with
// RSA-OAEP-256 using the latest version of the configured key. The returned
// blob carries the ciphertext, the IV, the key version that did the wrapping,
// and the wrapped DEK. The key version is also recorded as the current key ID.
//
// An error is returned when plaintext is nil, when envelope encryption
// fails, when the WrapKey call fails, or when Key Vault returns no wrapped
// key. Timing and error metrics are emitted for every call.
func (v *AzureKeyVaultSeal) Encrypt(ctx context.Context, plaintext []byte) (blob *physical.EncryptedBlobInfo, err error) {
	// The deferred closure reads the named result err, so every error return
	// below is counted.
	defer func(now time.Time) {
		metrics.MeasureSince([]string{"seal", "encrypt", "time"}, now)
		metrics.MeasureSince([]string{"seal", "azurekeyvault", "encrypt", "time"}, now)
		if err != nil {
			metrics.IncrCounter([]string{"seal", "encrypt", "error"}, 1)
			metrics.IncrCounter([]string{"seal", "azurekeyvault", "encrypt", "error"}, 1)
		}
	}(time.Now())
	metrics.IncrCounter([]string{"seal", "encrypt"}, 1)
	metrics.IncrCounter([]string{"seal", "azurekeyvault", "encrypt"}, 1)

	if plaintext == nil {
		return nil, errors.New("given plaintext for encryption is nil")
	}

	env, err := seal.NewEnvelope().Encrypt(plaintext)
	if err != nil {
		return nil, errwrap.Wrapf("error wrapping data: {{err}}", err)
	}

	// Encrypt the DEK using Key Vault. The value is sent as unpadded
	// base64url, the same form Decrypt expects back from the service.
	params := keyvault.KeyOperationsParameters{
		Algorithm: keyvault.RSAOAEP256,
		Value:     to.StringPtr(base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(env.Key)),
	}

	// Wrap key with the latest version for the key name (empty version).
	resp, err := v.client.WrapKey(ctx, v.buildBaseURL(), v.keyName, "", params)
	if err != nil {
		return nil, errwrap.Wrapf("error wrapping data encryption key with Azure Key Vault: {{err}}", err)
	}
	// Without the wrapped DEK the ciphertext could never be decrypted, so
	// refuse to hand back a blob that silently stores an empty key.
	if resp.Result == nil {
		return nil, errors.New("no wrapped key returned by Azure Key Vault")
	}

	// Store the current key version
	keyVersion := parseKeyVersion(to.String(resp.Kid))
	v.currentKeyID.Store(keyVersion)

	ret := &physical.EncryptedBlobInfo{
		Ciphertext: env.Ciphertext,
		IV:         env.IV,
		KeyInfo: &physical.SealKeyInfo{
			KeyID:      keyVersion,
			WrappedKey: []byte(to.String(resp.Result)),
		},
	}
	return ret, nil
}
// Decrypt recovers the plaintext from an encrypted blob. The wrapped key in
// in.KeyInfo is unwrapped by Key Vault with RSA-OAEP-256 using the key
// version recorded in in.KeyInfo.KeyID; the result is base64url-decoded
// (no padding) and used to open the envelope made of in.IV and
// in.Ciphertext.
//
// An error is returned when in or in.KeyInfo is nil, when the UnwrapKey call
// fails, when the unwrapped key cannot be decoded, or when envelope
// decryption fails. Timing and error metrics are emitted for every call.
func (v *AzureKeyVaultSeal) Decrypt(ctx context.Context, in *physical.EncryptedBlobInfo) (pt []byte, err error) {
	start := time.Now()
	// The closure reads the named result err, so every failing return is counted.
	defer func() {
		metrics.MeasureSince([]string{"seal", "decrypt", "time"}, start)
		metrics.MeasureSince([]string{"seal", "azurekeyvault", "decrypt", "time"}, start)
		if err == nil {
			return
		}
		metrics.IncrCounter([]string{"seal", "decrypt", "error"}, 1)
		metrics.IncrCounter([]string{"seal", "azurekeyvault", "decrypt", "error"}, 1)
	}()
	metrics.IncrCounter([]string{"seal", "decrypt"}, 1)
	metrics.IncrCounter([]string{"seal", "azurekeyvault", "decrypt"}, 1)

	switch {
	case in == nil:
		return nil, errors.New("given input for decryption is nil")
	case in.KeyInfo == nil:
		return nil, errors.New("key info is nil")
	}

	// Ask Key Vault to unwrap the stored key with the version that wrapped it.
	wrappedKey := string(in.KeyInfo.WrappedKey)
	unwrapParams := keyvault.KeyOperationsParameters{
		Algorithm: keyvault.RSAOAEP256,
		Value:     to.StringPtr(wrappedKey),
	}
	unwrapped, err := v.client.UnwrapKey(ctx, v.buildBaseURL(), v.keyName, in.KeyInfo.KeyID, unwrapParams)
	if err != nil {
		return nil, err
	}

	// The service result is decoded as unpadded base64url.
	unpaddedURL := base64.URLEncoding.WithPadding(base64.NoPadding)
	dek, err := unpaddedURL.DecodeString(to.String(unwrapped.Result))
	if err != nil {
		return nil, err
	}

	return seal.NewEnvelope().Decrypt(&seal.EnvelopeInfo{
		Key:        dek,
		IV:         in.IV,
		Ciphertext: in.Ciphertext,
	})
}
// buildBaseURL returns the HTTPS base URL of the configured vault, in the
// form "https://<vaultName>.<KeyVaultDNSSuffix>/", where the DNS suffix
// comes from the configured Azure environment.
func (v *AzureKeyVaultSeal) buildBaseURL() string {
	host := fmt.Sprintf("%s.%s", v.vaultName, v.environment.KeyVaultDNSSuffix)
	return "https://" + host + "/"
}
// getKeyVaultClient builds a Key Vault client authorized for the configured
// environment's Key Vault endpoint (with any trailing "/" removed from the
// resource). When both clientID and clientSecret are set, client credentials
// are used against the environment's Active Directory endpoint; otherwise
// Managed Service Identity is used. Any error from creating the authorizer
// is returned unchanged.
func (v *AzureKeyVaultSeal) getKeyVaultClient() (*keyvault.BaseClient, error) {
	// Both authentication modes request a token for the same resource.
	resource := strings.TrimSuffix(v.environment.KeyVaultEndpoint, "/")

	var (
		authorizer autorest.Authorizer
		err        error
	)
	if v.clientID != "" && v.clientSecret != "" {
		credsConfig := auth.NewClientCredentialsConfig(v.clientID, v.clientSecret, v.tenantID)
		credsConfig.AADEndpoint = v.environment.ActiveDirectoryEndpoint
		credsConfig.Resource = resource
		authorizer, err = credsConfig.Authorizer()
	} else {
		// By default use MSI
		msiConfig := auth.NewMSIConfig()
		msiConfig.Resource = resource
		authorizer, err = msiConfig.Authorizer()
	}
	if err != nil {
		return nil, err
	}

	kvClient := keyvault.New()
	kvClient.Authorizer = authorizer
	return &kvClient, nil
}
// parseKeyVersion extracts the key version from a key identifier. The kid is
// returned as a full URL whose final path segment is the version, so this
// returns everything after the last "/". A kid without any "/" is returned
// unchanged, and a kid ending in "/" yields the empty string.
func parseKeyVersion(kid string) string {
	lastSlash := strings.LastIndex(kid, "/")
	if lastSlash < 0 {
		return kid
	}
	return kid[lastSlash+1:]
}