open-vault/helper/keysutil/policy.go

822 lines
21 KiB
Go

package keysutil
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"strconv"
"strings"
"time"
"golang.org/x/crypto/hkdf"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/kdf"
"github.com/hashicorp/vault/logical"
)
// Key derivation function identifiers. Policy.KDF holds one of these values
// and DeriveKey switches on it to pick the derivation routine.
//
// Careful with iota; don't put anything before it in this const block because
// we need the default of zero to be the old-style KDF.
const (
	Kdf_hmac_sha256_counter = iota // built-in helper (kdf.CounterMode with HMAC-SHA256)
	Kdf_hkdf_sha256                // golang.org/x/crypto/hkdf
)
// Key type identifiers. Nothing may be placed ahead of the first entry in
// this const block: the zero value has to stay the original AES256-GCM96.
const (
	KeyType_AES256_GCM96 = iota
	KeyType_ECDSA_P256
)

// ErrTooOld is the message returned when a ciphertext or signature carries a
// key version lower than the policy's minimum decryption version.
const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"

// ecdsaSignature carries the two integers of an ECDSA signature; it is the
// structure that gets ASN.1-marshaled when signing and unmarshaled when
// verifying.
type ecdsaSignature struct {
	R, S *big.Int
}

// KeyType identifies the kind of key material a policy manages.
type KeyType int

// EncryptionSupported reports whether keys of this type can encrypt.
func (kt KeyType) EncryptionSupported() bool {
	return kt == KeyType_AES256_GCM96
}

// DecryptionSupported reports whether keys of this type can decrypt.
func (kt KeyType) DecryptionSupported() bool {
	return kt == KeyType_AES256_GCM96
}

// SigningSupported reports whether keys of this type can sign and verify.
func (kt KeyType) SigningSupported() bool {
	return kt == KeyType_ECDSA_P256
}

// DerivationSupported reports whether keys of this type can be used with
// context-based key derivation.
func (kt KeyType) DerivationSupported() bool {
	return kt == KeyType_AES256_GCM96
}

// String returns the human-readable name of the key type, or "[unknown]" for
// values that do not correspond to a known type.
func (kt KeyType) String() string {
	if kt == KeyType_AES256_GCM96 {
		return "aes256-gcm96"
	}
	if kt == KeyType_ECDSA_P256 {
		return "ecdsa-p256"
	}
	return "[unknown]"
}
// KeyEntry holds the material and metadata for a single key version.
type KeyEntry struct {
	// Symmetric key bytes (serialized under "key").
	AESKey []byte `json:"key"`
	// Key bytes used for HMAC operations.
	HMACKey []byte `json:"hmac_key"`
	// Unix timestamp recorded when the entry was created.
	CreationTime int64 `json:"creation_time"`
	// Elliptic-curve public point (X, Y) and private scalar (D).
	EC_X *big.Int `json:"ec_x"`
	EC_Y *big.Int `json:"ec_y"`
	EC_D *big.Int `json:"ec_d"`
	// PEM-encoded public key, when one has been generated.
	FormattedPublicKey string `json:"public_key"`
}

// keyEntryMap maps a key version to its entry. It exists so that the integer
// keys can be converted to and from the string keys that JSON objects require.
type keyEntryMap map[int]KeyEntry

// MarshalJSON implements JSON marshaling by re-keying the map with the
// decimal string form of every version number.
func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
	byVersion := make(map[string]KeyEntry, len(kem))
	for version, entry := range kem {
		byVersion[strconv.Itoa(version)] = entry
	}
	return json.Marshal(byVersion)
}
// UnmarshalJSON implements JSON unmarshaling. The JSON object's string keys
// are parsed back into integer key versions and stored in the receiver.
//
// The receiver is a map value, so it must already be initialized: there is no
// way for this method to allocate a map that the caller would see. When
// entries need to be stored into a nil receiver an error is returned rather
// than panicking on the map assignment. All keys are parsed before anything
// is stored, so a malformed key leaves the receiver untouched.
func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
	intermediate := map[string]KeyEntry{}
	if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
		return err
	}

	if len(intermediate) == 0 {
		return nil
	}
	if kem == nil {
		return fmt.Errorf("cannot unmarshal key entries into a nil key map")
	}

	// Parse every version number first so that a bad key cannot leave the
	// receiver partially populated.
	parsed := make(map[int]KeyEntry, len(intermediate))
	for k, v := range intermediate {
		keyval, err := strconv.Atoi(k)
		if err != nil {
			return err
		}
		parsed[keyval] = v
	}

	for keyval, v := range parsed {
		kem[keyval] = v
	}

	return nil
}
// Policy is the struct used to store metadata about a named key along with
// the key versions that are currently loaded. It is serialized to JSON when
// persisted.
type Policy struct {
	// Name of the policy; it is appended to the "policy/" and "archive/"
	// prefixes to form the storage keys.
	Name string `json:"name"`
	// Legacy single-key field. A non-empty value makes NeedsUpgrade report
	// true, and Upgrade moves it into Keys.
	Key []byte `json:"key,omitempty"` //DEPRECATED
	// Key entries indexed by key version.
	Keys keyEntryMap `json:"keys"`
	// Derived keys MUST provide a context and the master underlying key is
	// never used. If convergent encryption is true, the context will be used
	// as the nonce as well.
	Derived bool `json:"derived"`
	// Which key derivation function DeriveKey uses (one of the Kdf_* constants)
	KDF int `json:"kdf"`
	// Whether convergent encryption is enabled for this key
	ConvergentEncryption bool `json:"convergent_encryption"`
	// Whether the key is exportable
	Exportable bool `json:"exportable"`
	// The minimum version of the key allowed to be used
	// for decryption
	MinDecryptionVersion int `json:"min_decryption_version"`
	// The latest key version in this policy
	LatestVersion int `json:"latest_version"`
	// The latest key version in the archive. We never delete these, so this is
	// a max.
	ArchiveVersion int `json:"archive_version"`
	// Whether the key is allowed to be deleted
	DeletionAllowed bool `json:"deletion_allowed"`
	// The version of the convergent nonce to use
	ConvergentVersion int `json:"convergent_version"`
	// The type of key
	Type KeyType `json:"type"`
}
// archivedKeys stores old keys. This is used to keep the key loading time sane
// when there are huge numbers of rotations.
type archivedKeys struct {
	// Key entries indexed by key version; index 0 is unused because
	// versions start at 1.
	Keys []KeyEntry `json:"keys"`
}
// LoadArchive reads the key archive for this policy from storage under
// "archive/<name>". When nothing is stored, an archive with an empty,
// non-nil key slice is returned. Storage and decoding errors are returned
// as-is.
func (p *Policy) LoadArchive(storage logical.Storage) (*archivedKeys, error) {
	entry, err := storage.Get("archive/" + p.Name)
	if err != nil {
		return nil, err
	}

	// No archive has been written yet
	if entry == nil {
		return &archivedKeys{Keys: make([]KeyEntry, 0)}, nil
	}

	loaded := &archivedKeys{}
	if decodeErr := jsonutil.DecodeJSON(entry.Value, loaded); decodeErr != nil {
		return nil, decodeErr
	}
	return loaded, nil
}
// storeArchive JSON-encodes the given archive and writes it to storage under
// "archive/<name>". Encoding and storage errors are returned as-is.
func (p *Policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
	// Encode the archive
	encoded, err := json.Marshal(archive)
	if err != nil {
		return err
	}

	// Write the archive into storage
	return storage.Put(&logical.StorageEntry{
		Key:   "archive/" + p.Name,
		Value: encoded,
	})
}
// handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards.
//
// If the key for MinDecryptionVersion is not in p.Keys, the versions from
// MinDecryptionVersion through LatestVersion are copied back from the
// archive. Otherwise, any versions newer than ArchiveVersion are copied into
// the archive, the archive is stored, and versions older than
// MinDecryptionVersion are dropped from p.Keys.
//
// An error is returned when the version fields are inconsistent, when the
// archive cannot be loaded or stored, or when the archive is too short to
// restore the required versions.
func (p *Policy) handleArchiving(storage logical.Storage) error {
	// We need to move keys that are no longer accessible to archivedKeys, and keys
	// that now need to be accessible back here.
	//
	// For safety, because there isn't really a good reason to, we never delete
	// keys from the archive even when we move them back.

	// Check if we have the latest minimum version in the current set of keys
	_, keysContainsMinimum := p.Keys[p.MinDecryptionVersion]

	// Sanity checks
	switch {
	case p.MinDecryptionVersion < 1:
		return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
	case p.LatestVersion < 1:
		return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
	case !keysContainsMinimum && p.ArchiveVersion != p.LatestVersion:
		return fmt.Errorf("need to move keys from archive but archive version not up-to-date")
	case p.ArchiveVersion > p.LatestVersion:
		return fmt.Errorf("archive version of %d is greater than the latest version %d",
			p.ArchiveVersion, p.LatestVersion)
	case p.MinDecryptionVersion > p.LatestVersion:
		return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
			p.MinDecryptionVersion, p.LatestVersion)
	}

	archive, err := p.LoadArchive(storage)
	if err != nil {
		return err
	}

	if !keysContainsMinimum {
		// Need to move keys *from* archive

		// The archive is indexed by key version, so it must hold at least
		// LatestVersion+1 entries. A missing or truncated archive would
		// otherwise cause an out-of-range panic below.
		if len(archive.Keys) < p.LatestVersion+1 {
			return fmt.Errorf("archive holds %d entries, which is too few to restore key version %d",
				len(archive.Keys), p.LatestVersion)
		}

		// Writing into a nil map panics, so make sure there is one
		if p.Keys == nil {
			p.Keys = keyEntryMap{}
		}

		for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
			p.Keys[i] = archive.Keys[i]
		}

		return nil
	}

	// Need to move keys *to* archive

	// We need a size that is equivalent to the latest version (number of keys)
	// but adding one since slice numbering starts at 0 and we're indexing by
	// key version
	if len(archive.Keys) < p.LatestVersion+1 {
		// Increase the size of the archive slice
		newKeys := make([]KeyEntry, p.LatestVersion+1)
		copy(newKeys, archive.Keys)
		archive.Keys = newKeys
	}

	// We are storing all keys in the archive, so we ensure that it is up to
	// date up to p.LatestVersion. The new archive version is only recorded on
	// the policy once the archive has actually been stored, so a failed write
	// is retried on the next call.
	newArchiveVersion := p.ArchiveVersion
	for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
		archive.Keys[i] = p.Keys[i]
		newArchiveVersion = i
	}

	if err := p.storeArchive(archive, storage); err != nil {
		return err
	}
	p.ArchiveVersion = newArchiveVersion

	// Perform deletion afterwards so that if there is an error saving we
	// haven't messed with the current policy
	for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
		delete(p.Keys, i)
	}

	return nil
}
// Persist reconciles the key archive with the policy, then serializes the
// policy and writes it to storage under "policy/<name>". The first error
// encountered is returned.
func (p *Policy) Persist(storage logical.Storage) error {
	if err := p.handleArchiving(storage); err != nil {
		return err
	}

	// Encode the policy
	serialized, err := p.Serialize()
	if err != nil {
		return err
	}

	// Write the policy into storage
	return storage.Put(&logical.StorageEntry{
		Key:   "policy/" + p.Name,
		Value: serialized,
	})
}
// Serialize returns the JSON encoding of the policy.
func (p *Policy) Serialize() ([]byte, error) {
	encoded, err := json.Marshal(p)
	return encoded, err
}
// NeedsUpgrade reports whether the policy is in an older layout that Upgrade
// would change or re-persist.
func (p *Policy) NeedsUpgrade() bool {
	switch {
	// Ensure we've moved from Key -> Keys
	case len(p.Key) > 0:
		return true

	// With archiving, past assumptions about the length of the keys map are
	// no longer valid
	case p.LatestVersion == 0 && len(p.Keys) != 0:
		return true

	// We disallow setting the version to 0, since they start at 1 since
	// moving to rotate-able keys, so update if it's set to 0
	case p.MinDecryptionVersion == 0:
		return true

	// On first load after an upgrade, copy keys to the archive
	case p.ArchiveVersion == 0:
		return true

	// Need to write the version
	case p.ConvergentEncryption && p.ConvergentVersion == 0:
		return true

	// The latest key version has no HMAC key yet
	case len(p.Keys[p.LatestVersion].HMACKey) == 0:
		return true
	}

	return false
}
// Upgrade brings an older policy layout up to date and persists the policy
// if anything changed. The steps run in a fixed order because later steps
// rely on the results of earlier ones (for example, LatestVersion is derived
// from Keys after the legacy Key has been migrated).
func (p *Policy) Upgrade(storage logical.Storage) error {
	dirty := false

	// Ensure we've moved from Key -> Keys
	if len(p.Key) > 0 {
		p.MigrateKeyToKeysMap()
		dirty = true
	}

	// With archiving, past assumptions about the length of the keys map are no
	// longer valid
	if p.LatestVersion == 0 && len(p.Keys) != 0 {
		p.LatestVersion = len(p.Keys)
		dirty = true
	}

	// We disallow setting the version to 0, since they start at 1 since moving
	// to rotate-able keys, so update if it's set to 0
	if p.MinDecryptionVersion == 0 {
		p.MinDecryptionVersion = 1
		dirty = true
	}

	// On first load after an upgrade, copy keys to the archive; persisting
	// takes care of that
	if p.ArchiveVersion == 0 {
		dirty = true
	}

	if p.ConvergentEncryption && p.ConvergentVersion == 0 {
		p.ConvergentVersion = 1
		dirty = true
	}

	// Give the latest key version an HMAC key if it lacks one
	if latest := p.Keys[p.LatestVersion]; len(latest.HMACKey) == 0 {
		generated, err := uuid.GenerateRandomBytes(32)
		if err != nil {
			return err
		}
		latest.HMACKey = generated
		p.Keys[p.LatestVersion] = latest
		dirty = true
	}

	if !dirty {
		return nil
	}
	return p.Persist(storage)
}
// DeriveKey is used to derive the encryption key that should be used depending
// on the policy. If derivation is disabled the raw key is used and no context
// is required, otherwise the KDF mode is used with the context to derive the
// proper key.
//
// ver selects the key version and must be between 1 and LatestVersion. An
// error is returned when the key type does not support derivation, when the
// version is invalid, when no key material is loaded for the version, when a
// derived key is requested without a context, or when the configured KDF is
// unknown.
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
	if !p.Type.DerivationSupported() {
		return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
	}

	if p.Keys == nil || p.LatestVersion == 0 {
		return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
	}

	if ver <= 0 || ver > p.LatestVersion {
		return nil, errutil.UserError{Err: "invalid key version"}
	}

	// A version inside the valid range can still be absent from p.Keys (for
	// instance once it has been moved to the archive). Refuse to continue
	// rather than hand back a nil key or derive from an empty secret.
	entry, ok := p.Keys[ver]
	if !ok || len(entry.AESKey) == 0 {
		return nil, errutil.InternalError{Err: fmt.Sprintf("no key material loaded for key version %d", ver)}
	}

	// Fast-path non-derived keys
	if !p.Derived {
		return entry.AESKey, nil
	}

	// Ensure a context is provided
	if len(context) == 0 {
		return nil, errutil.UserError{Err: "missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information"}
	}

	switch p.KDF {
	case Kdf_hmac_sha256_counter:
		prf := kdf.HMACSHA256PRF
		prfLen := kdf.HMACSHA256PRFLen
		return kdf.CounterMode(prf, prfLen, entry.AESKey, context, 256)

	case Kdf_hkdf_sha256:
		reader := hkdf.New(sha256.New, entry.AESKey, nil, context)
		derBytes := bytes.NewBuffer(nil)
		derBytes.Grow(32)
		// Read exactly 32 bytes (a 256-bit key) from the HKDF stream
		limReader := &io.LimitedReader{
			R: reader,
			N: 32,
		}

		n, err := derBytes.ReadFrom(limReader)
		if err != nil {
			return nil, errutil.InternalError{Err: fmt.Sprintf("error reading returned derived bytes: %v", err)}
		}
		if n != 32 {
			return nil, errutil.InternalError{Err: fmt.Sprintf("unable to read enough derived bytes, needed 32, got %d", n)}
		}
		return derBytes.Bytes(), nil

	default:
		return nil, errutil.InternalError{Err: "unsupported key derivation mode"}
	}
}
// Encrypt encrypts a base64-encoded plaintext with the latest key version
// using AES-GCM and returns it as "vault:v<version>:<base64 payload>".
//
// context is passed to DeriveKey. The nonce is handled as follows:
//   - without convergent encryption a random nonce is generated and
//     prepended to the ciphertext;
//   - with convergent encryption version 1 the caller-supplied nonce is used
//     and must be exactly the GCM nonce size; it is not stored in the output;
//   - with any other convergent version the nonce is the leading bytes of an
//     HMAC-SHA256 over the plaintext keyed with the context, and it is
//     prepended to the ciphertext when the version is greater than 1.
func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
	if !p.Type.EncryptionSupported() {
		return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
	}

	// Guard against a potentially invalid key type
	if p.Type != KeyType_AES256_GCM96 {
		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}

	// Decode the plaintext value
	plaintext, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		return "", errutil.UserError{Err: "failed to base64-decode plaintext"}
	}

	// Derive the key that should be used
	key, err := p.DeriveKey(context, p.LatestVersion)
	if err != nil {
		return "", err
	}

	// Setup the cipher
	block, err := aes.NewCipher(key)
	if err != nil {
		return "", errutil.InternalError{Err: err.Error()}
	}

	// Setup the GCM AEAD
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return "", errutil.InternalError{Err: err.Error()}
	}
	nonceLen := aead.NonceSize()

	switch {
	case !p.ConvergentEncryption:
		// Compute random nonce
		nonce, err = uuid.GenerateRandomBytes(nonceLen)
		if err != nil {
			return "", errutil.InternalError{Err: err.Error()}
		}

	case p.ConvergentVersion == 1:
		if len(nonce) != nonceLen {
			return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", nonceLen)}
		}

	default:
		mac := hmac.New(sha256.New, context)
		mac.Write(plaintext)
		nonce = mac.Sum(nil)[:nonceLen]
	}

	// Encrypt and tag with GCM
	sealed := aead.Seal(nil, nonce, plaintext, nil)

	// Place the encrypted data after the nonce, unless the nonce is supplied
	// by the caller on decryption (convergent version 1 and below)
	payload := sealed
	if !p.ConvergentEncryption || p.ConvergentVersion > 1 {
		payload = append(nonce, sealed...)
	}

	// Convert to base64 and prepend the key version information
	return "vault:v" + strconv.Itoa(p.LatestVersion) + ":" + base64.StdEncoding.EncodeToString(payload), nil
}
// Decrypt decrypts a value of the form "vault:v<version>:<base64 payload>"
// and returns the plaintext base64-encoded.
//
// context is passed to DeriveKey. nonce is only used when convergent
// encryption is enabled with a convergent version below 2; in every other
// case the nonce is read from the front of the decoded payload.
//
// User errors are returned for malformed input, versions outside the allowed
// range, a nonce of the wrong length, a payload too short to contain a nonce,
// and authentication failure. Lengths are validated before use because both
// slicing a short payload and handing GCM a nonce of the wrong size would
// otherwise panic.
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
	if !p.Type.DecryptionSupported() {
		return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
	}

	// Verify the prefix
	if !strings.HasPrefix(value, "vault:v") {
		return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
	}

	if p.ConvergentEncryption && p.ConvergentVersion == 1 && len(nonce) == 0 {
		return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
	}

	splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
	if len(splitVerCiphertext) != 2 {
		return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
	}

	ver, err := strconv.Atoi(splitVerCiphertext[0])
	if err != nil {
		return "", errutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
	}

	if ver == 0 {
		// Compatibility mode with initial implementation, where keys start at
		// zero
		ver = 1
	}

	if ver > p.LatestVersion {
		return "", errutil.UserError{Err: "invalid ciphertext: version is too new"}
	}

	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
		return "", errutil.UserError{Err: ErrTooOld}
	}

	// Derive the key that should be used
	key, err := p.DeriveKey(context, ver)
	if err != nil {
		return "", err
	}

	// Guard against a potentially invalid key type
	switch p.Type {
	case KeyType_AES256_GCM96:
	default:
		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}

	// Decode the base64
	decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
	if err != nil {
		return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"}
	}

	// Setup the cipher
	aesCipher, err := aes.NewCipher(key)
	if err != nil {
		return "", errutil.InternalError{Err: err.Error()}
	}

	// Setup the GCM AEAD
	gcm, err := cipher.NewGCM(aesCipher)
	if err != nil {
		return "", errutil.InternalError{Err: err.Error()}
	}
	nonceSize := gcm.NonceSize()

	// Extract the nonce and ciphertext
	var ciphertext []byte
	if p.ConvergentEncryption && p.ConvergentVersion < 2 {
		// The nonce comes from the caller; reject anything that GCM would
		// refuse
		if len(nonce) != nonceSize {
			return "", errutil.UserError{Err: fmt.Sprintf("invalid convergent nonce supplied: must be %d bytes long", nonceSize)}
		}
		ciphertext = decoded
	} else {
		// The nonce is stored in front of the ciphertext
		if len(decoded) < nonceSize {
			return "", errutil.UserError{Err: "invalid ciphertext: too short to contain a nonce"}
		}
		nonce = decoded[:nonceSize]
		ciphertext = decoded[nonceSize:]
	}

	// Verify and Decrypt
	plain, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"}
	}

	return base64.StdEncoding.EncodeToString(plain), nil
}
// HMACKey returns the HMAC key stored for the given key version. An error is
// returned when the version is below the minimum decryption version, above
// the latest version, or has no HMAC key.
func (p *Policy) HMACKey(version int) ([]byte, error) {
	switch {
	case version < p.MinDecryptionVersion:
		return nil, fmt.Errorf("key version disallowed by policy (minimum is %d)", p.MinDecryptionVersion)
	case version > p.LatestVersion:
		return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion)
	}

	key := p.Keys[version].HMACKey
	if key == nil {
		return nil, fmt.Errorf("no HMAC key exists for that key version")
	}
	return key, nil
}
// Sign signs an already-hashed input with the latest key version and returns
// the signature as "vault:v<version>:<base64 signature>". For ECDSA P-256
// keys the signature is the ASN.1 encoding of the (R, S) pair. An error is
// returned when the key type does not support signing or when signing or
// encoding fails.
func (p *Policy) Sign(hashedInput []byte) (string, error) {
	if !p.Type.SigningSupported() {
		return "", fmt.Errorf("message signing not supported for key type %v", p.Type)
	}

	// Only ECDSA P-256 is implemented below
	if p.Type != KeyType_ECDSA_P256 {
		return "", fmt.Errorf("unsupported key type %v", p.Type)
	}

	// Rebuild the private key from the stored parameters
	params := p.Keys[p.LatestVersion]
	priv := &ecdsa.PrivateKey{
		PublicKey: ecdsa.PublicKey{
			Curve: elliptic.P256(),
			X:     params.EC_X,
			Y:     params.EC_Y,
		},
		D: params.EC_D,
	}

	r, s, err := ecdsa.Sign(rand.Reader, priv, hashedInput)
	if err != nil {
		return "", err
	}

	der, err := asn1.Marshal(ecdsaSignature{R: r, S: s})
	if err != nil {
		return "", err
	}

	// Convert to base64 and prepend the key version information
	return "vault:v" + strconv.Itoa(p.LatestVersion) + ":" + base64.StdEncoding.EncodeToString(der), nil
}
// VerifySignature checks a signature of the form
// "vault:v<version>:<base64 signature>" against an already-hashed input,
// using the public key of the version named in the signature.
//
// It returns true only when the signature is valid. User errors are returned
// for malformed signatures, versions outside the allowed range, and versions
// for which no public key is loaded; that last check keeps nil curve
// coordinates from being handed to the ECDSA verifier.
func (p *Policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
	if !p.Type.SigningSupported() {
		return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
	}

	// Verify the prefix
	if !strings.HasPrefix(sig, "vault:v") {
		return false, errutil.UserError{Err: "invalid signature: no prefix"}
	}

	splitVerSig := strings.SplitN(strings.TrimPrefix(sig, "vault:v"), ":", 2)
	if len(splitVerSig) != 2 {
		return false, errutil.UserError{Err: "invalid signature: wrong number of fields"}
	}

	ver, err := strconv.Atoi(splitVerSig[0])
	if err != nil {
		return false, errutil.UserError{Err: "invalid signature: version number could not be decoded"}
	}

	if ver > p.LatestVersion {
		return false, errutil.UserError{Err: "invalid signature: version is too new"}
	}

	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
		return false, errutil.UserError{Err: ErrTooOld}
	}

	switch p.Type {
	case KeyType_ECDSA_P256:
		asn1Sig, err := base64.StdEncoding.DecodeString(splitVerSig[1])
		if err != nil {
			return false, errutil.UserError{Err: "invalid base64 signature value"}
		}

		var ecdsaSig ecdsaSignature
		rest, err := asn1.Unmarshal(asn1Sig, &ecdsaSig)
		if err != nil {
			return false, errutil.UserError{Err: "supplied signature is invalid"}
		}
		if len(rest) != 0 {
			return false, errutil.UserError{Err: "supplied signature contains extra data"}
		}
		if ecdsaSig.R == nil || ecdsaSig.S == nil {
			return false, errutil.UserError{Err: "supplied signature is invalid"}
		}

		// The version may pass the range checks above and still have no
		// entry (for example a non-positive version while the minimum
		// decryption version is unset).
		keyParams, ok := p.Keys[ver]
		if !ok || keyParams.EC_X == nil || keyParams.EC_Y == nil {
			return false, errutil.UserError{Err: "invalid signature: no public key available for the given version"}
		}

		key := &ecdsa.PublicKey{
			Curve: elliptic.P256(),
			X:     keyParams.EC_X,
			Y:     keyParams.EC_Y,
		}

		return ecdsa.Verify(key, hashedInput, ecdsaSig.R, ecdsaSig.S), nil

	default:
		return false, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}
}
// Rotate generates key material for a new key version, records it as the
// latest version and persists the policy.
//
// Every new entry receives a fresh 32-byte HMAC key. AES256-GCM96 policies
// additionally get a 32-byte key; ECDSA P-256 policies get a key pair plus
// the PEM-encoded public key. All material is generated before the policy is
// touched, so a generation failure leaves LatestVersion and Keys unchanged.
// An error is also returned for key types that have no generation routine.
//
// NOTE: if persisting fails, the new version stays in the in-memory policy.
func (p *Policy) Rotate(storage logical.Storage) error {
	entry := KeyEntry{
		CreationTime: time.Now().Unix(),
	}

	hmacKey, err := uuid.GenerateRandomBytes(32)
	if err != nil {
		return err
	}
	entry.HMACKey = hmacKey

	switch p.Type {
	case KeyType_AES256_GCM96:
		// Generate a 256bit key
		newKey, err := uuid.GenerateRandomBytes(32)
		if err != nil {
			return err
		}
		entry.AESKey = newKey

	case KeyType_ECDSA_P256:
		privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			return err
		}
		entry.EC_D = privKey.D
		entry.EC_X = privKey.X
		entry.EC_Y = privKey.Y
		derBytes, err := x509.MarshalPKIXPublicKey(privKey.Public())
		if err != nil {
			return fmt.Errorf("error marshaling public key: %s", err)
		}
		pemBlock := &pem.Block{
			Type:  "PUBLIC KEY",
			Bytes: derBytes,
		}
		pemBytes := pem.EncodeToMemory(pemBlock)
		if len(pemBytes) == 0 {
			return fmt.Errorf("error PEM-encoding public key")
		}
		entry.FormattedPublicKey = string(pemBytes)

	default:
		// Do not record a version that carries no usable key material
		return fmt.Errorf("unsupported key type %v", p.Type)
	}

	if p.Keys == nil {
		// This is an initial key rotation when generating a new policy. We
		// don't need to call migrate here because if we've called getPolicy to
		// get the policy in the first place it will have been run.
		p.Keys = keyEntryMap{}
	}

	// Only now that the entry is complete do we advance the version
	p.LatestVersion++
	p.Keys[p.LatestVersion] = entry

	// This ensures that with new key creations min decryption version is set
	// to 1 rather than the int default of 0, since keys start at 1 (either
	// fresh or after migration to the key map)
	if p.MinDecryptionVersion == 0 {
		p.MinDecryptionVersion = 1
	}

	return p.Persist(storage)
}
// MigrateKeyToKeysMap replaces Keys with a map holding the legacy Key as
// version 1, stamped with the current time, and then clears Key.
func (p *Policy) MigrateKeyToKeysMap() {
	legacy := KeyEntry{
		AESKey:       p.Key,
		CreationTime: time.Now().Unix(),
	}
	p.Keys = keyEntryMap{1: legacy}
	p.Key = nil
}