09939f0ba9
* Allow passing AssociatedData factories in keysutil This allows the high-level, algorithm-agnostic Encrypt/Decrypt with Factory to pass in AssociatedData, and potentially take multiple factories (to allow KMS keys to work). On AEAD ciphers with a relevant factory, an AssociatedData factory will be used to populate the AdditionalData field of the SymmetricOpts struct, using it in the AEAD Seal process. Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Add associated_data to Transit Encrypt/Decrypt API This allows passing the associated_data (the last AD in AEAD) to Transit's encrypt/decrypt when using an AEAD cipher (currently aes128-gcm96, aes256-gcm96, and chacha20-poly1305). We err if this parameter is passed on non-AEAD ciphers presently. This associated data can be safely transited in plaintext, without risk of modifications. In the event of tampering with either the ciphertext or the associated data, decryption will fail. Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Add changelog Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> * Add to documentation Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com> Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
1976 lines
55 KiB
Go
1976 lines
55 KiB
Go
package keysutil
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto"
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"crypto/x509"
|
|
"encoding/asn1"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
"golang.org/x/crypto/ed25519"
|
|
"golang.org/x/crypto/hkdf"
|
|
|
|
"github.com/hashicorp/errwrap"
|
|
uuid "github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/sdk/helper/errutil"
|
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
|
"github.com/hashicorp/vault/sdk/helper/kdf"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
)
|
|
|
|
// Careful with iota; don't put anything before it in this const block because
// we need the default of zero to be the old-style KDF
const (
	// Kdf_hmac_sha256_counter selects the counter-mode HMAC-SHA256 KDF. It
	// must remain the first entry so that the zero value keeps selecting it.
	Kdf_hmac_sha256_counter = iota // built-in helper
	// Kdf_hkdf_sha256 selects HKDF with SHA-256.
	Kdf_hkdf_sha256 // golang.org/x/crypto/hkdf

	// HmacMinKeySize and HmacMaxKeySize are written as a bit count divided by
	// eight, i.e. 32 and 512 respectively.
	HmacMinKeySize = 256 / 8
	HmacMaxKeySize = 4096 / 8
)
|
|
|
|
// Or this one...we need the default of zero to be the original AES256-GCM96
//
// NOTE(review): Policy.Type is JSON-encoded as this integer, so new key types
// presumably have to be appended at the end rather than inserted — confirm
// before reordering.
const (
	KeyType_AES256_GCM96 = iota
	KeyType_ECDSA_P256
	KeyType_ED25519
	KeyType_RSA2048
	KeyType_RSA4096
	KeyType_ChaCha20_Poly1305
	KeyType_ECDSA_P384
	KeyType_ECDSA_P521
	KeyType_AES128_GCM96
	KeyType_RSA3072
	KeyType_MANAGED_KEY
	KeyType_HMAC
)
|
|
|
|
const (
	// ErrTooOld is returned when the ciphertext's or signature's key version
	// is too old.
	ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"

	// DefaultVersionTemplate is used when no version template is provided.
	DefaultVersionTemplate = "vault:v{{version}}:"
)
|
|
|
|
// AEADFactory supplies a cipher.AEAD for a given IV. When one is passed to
// DecryptWithFactory it is stored in SymmetricOpts.AEADFactory.
type AEADFactory interface {
	// GetAEAD returns the AEAD to use for the supplied iv.
	GetAEAD(iv []byte) (cipher.AEAD, error)
}
|
|
|
|
// AssociatedDataFactory supplies the associated data for an AEAD operation.
// DecryptWithFactory copies the returned bytes into
// SymmetricOpts.AdditionalData.
type AssociatedDataFactory interface {
	// GetAssociatedData returns the associated data, or an error if it
	// cannot be produced.
	GetAssociatedData() ([]byte, error)
}
|
|
|
|
// RestoreInfo records a timestamp and a version number for a restore action;
// it is referenced from Policy.RestoreInfo.
type RestoreInfo struct {
	Time    time.Time `json:"time"`
	Version int       `json:"version"`
}
|
|
|
|
// BackupInfo records a timestamp and a version number for a backup action;
// it is referenced from Policy.BackupInfo.
type BackupInfo struct {
	Time    time.Time `json:"time"`
	Version int       `json:"version"`
}
|
|
|
|
// SigningOptions bundles the optional parameters accepted by
// Policy.SignWithOptions.
type SigningOptions struct {
	// HashAlgorithm and Marshaling are read by SignWithOptions.
	HashAlgorithm HashType
	Marshaling    MarshalingType
	// SaltLength is the RSA-PSS salt length; Policy.Sign passes
	// rsa.PSSSaltLengthAuto.
	SaltLength int
	// SigAlgorithm names the signature algorithm requested by the caller.
	SigAlgorithm string
}
|
|
|
|
// SigningResult is the value returned by Policy.Sign and
// Policy.SignWithOptions.
type SigningResult struct {
	// Signature is the produced signature in string form.
	Signature string
	// PublicKey carries public key bytes alongside the signature.
	// NOTE(review): confirm for which key types this is populated.
	PublicKey []byte
}
|
|
|
|
// ecdsaSignature holds the two integers (R, S) of an ECDSA signature.
// NOTE(review): presumably marshaled with encoding/asn1 — confirm at the
// call sites.
type ecdsaSignature struct {
	R, S *big.Int
}
|
|
|
|
// KeyType identifies the kind of key a policy holds; its values are the
// KeyType_* constants.
type KeyType int
|
|
|
|
func (kt KeyType) EncryptionSupported() bool {
|
|
switch kt {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (kt KeyType) DecryptionSupported() bool {
|
|
switch kt {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (kt KeyType) SigningSupported() bool {
|
|
switch kt {
|
|
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_ED25519, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (kt KeyType) HashSignatureInput() bool {
|
|
switch kt {
|
|
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (kt KeyType) DerivationSupported() bool {
|
|
switch kt {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_ED25519:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (kt KeyType) AssociatedDataSupported() bool {
|
|
switch kt {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (kt KeyType) String() string {
|
|
switch kt {
|
|
case KeyType_AES128_GCM96:
|
|
return "aes128-gcm96"
|
|
case KeyType_AES256_GCM96:
|
|
return "aes256-gcm96"
|
|
case KeyType_ChaCha20_Poly1305:
|
|
return "chacha20-poly1305"
|
|
case KeyType_ECDSA_P256:
|
|
return "ecdsa-p256"
|
|
case KeyType_ECDSA_P384:
|
|
return "ecdsa-p384"
|
|
case KeyType_ECDSA_P521:
|
|
return "ecdsa-p521"
|
|
case KeyType_ED25519:
|
|
return "ed25519"
|
|
case KeyType_RSA2048:
|
|
return "rsa-2048"
|
|
case KeyType_RSA3072:
|
|
return "rsa-3072"
|
|
case KeyType_RSA4096:
|
|
return "rsa-4096"
|
|
case KeyType_HMAC:
|
|
return "hmac"
|
|
}
|
|
|
|
return "[unknown]"
|
|
}
|
|
|
|
// KeyData pairs a policy with its archived keys in a single JSON document.
// NOTE(review): presumably the payload used for backup/restore — confirm.
type KeyData struct {
	Policy       *Policy       `json:"policy"`
	ArchivedKeys *archivedKeys `json:"archived_keys"`
}
|
|
|
|
// KeyEntry stores the key and metadata for a single key version. Which of the
// key fields is populated depends on the policy's key type.
type KeyEntry struct {
	// AES or some other kind that is a pure byte slice like ED25519
	Key []byte `json:"key"`

	// Key used for HMAC functions
	HMACKey []byte `json:"hmac_key"`

	// Time of creation
	CreationTime time.Time `json:"time"`

	// Elliptic-curve key material.
	// NOTE(review): presumably the public point (X, Y) and private scalar D
	// of an ECDSA key — confirm.
	EC_X *big.Int `json:"ec_x"`
	EC_Y *big.Int `json:"ec_y"`
	EC_D *big.Int `json:"ec_d"`

	// RSA private key; DecryptWithFactory uses it for the RSA key types.
	RSAKey *rsa.PrivateKey `json:"rsa_key"`

	// The public key in an appropriate format for the type of key
	FormattedPublicKey string `json:"public_key"`

	// If convergent is enabled, the version (falling back to what's in the
	// policy)
	ConvergentVersion int `json:"convergent_version"`

	// This is deprecated (but still filled) in favor of the value above which
	// is more precise
	DeprecatedCreationTime int64 `json:"creation_time"`
}
|
|
|
|
// deprecatedKeyEntryMap is used to allow JSON marshal/unmarshal of a map
// keyed by integer version: its MarshalJSON/UnmarshalJSON methods convert the
// integer keys to and from decimal strings.
type deprecatedKeyEntryMap map[int]KeyEntry
|
|
|
|
// MarshalJSON implements JSON marshaling
|
|
func (kem deprecatedKeyEntryMap) MarshalJSON() ([]byte, error) {
|
|
intermediate := map[string]KeyEntry{}
|
|
for k, v := range kem {
|
|
intermediate[strconv.Itoa(k)] = v
|
|
}
|
|
return json.Marshal(&intermediate)
|
|
}
|
|
|
|
// MarshalJSON implements JSON unmarshalling
|
|
func (kem deprecatedKeyEntryMap) UnmarshalJSON(data []byte) error {
|
|
intermediate := map[string]KeyEntry{}
|
|
if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
|
|
return err
|
|
}
|
|
for k, v := range intermediate {
|
|
keyval, err := strconv.Atoi(k)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
kem[keyval] = v
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// keyEntryMap is used to allow JSON marshal/unmarshal. Its keys are key
// versions in decimal string form (callers index it with strconv.Itoa(ver)).
type keyEntryMap map[string]KeyEntry
|
|
|
|
// PolicyConfig is used to create a new policy. NewPolicy copies each field
// into the Policy field of the same name.
type PolicyConfig struct {
	// The name of the policy
	Name string `json:"name"`

	// The type of key
	Type KeyType

	// Derived keys MUST provide a context and the master underlying key is
	// never used. KDF selects the derivation function (one of the Kdf_*
	// constants).
	Derived              bool
	KDF                  int
	ConvergentEncryption bool

	// Whether the key is exportable
	Exportable bool

	// Whether the key is allowed to be deleted
	DeletionAllowed bool

	// AllowPlaintextBackup allows taking backup of the policy in plaintext
	AllowPlaintextBackup bool

	// VersionTemplate is used to prefix the ciphertext with information about
	// the key version. It must include {{version}} and a delimiter between the
	// version prefix and the ciphertext.
	VersionTemplate string

	// StoragePrefix is used to add a prefix when storing and retrieving the
	// policy object.
	StoragePrefix string
}
|
|
|
|
// NewPolicy takes a policy config and returns a Policy with those settings.
|
|
func NewPolicy(config PolicyConfig) *Policy {
|
|
return &Policy{
|
|
l: new(sync.RWMutex),
|
|
Name: config.Name,
|
|
Type: config.Type,
|
|
Derived: config.Derived,
|
|
KDF: config.KDF,
|
|
ConvergentEncryption: config.ConvergentEncryption,
|
|
ConvergentVersion: -1,
|
|
Exportable: config.Exportable,
|
|
DeletionAllowed: config.DeletionAllowed,
|
|
AllowPlaintextBackup: config.AllowPlaintextBackup,
|
|
VersionTemplate: config.VersionTemplate,
|
|
StoragePrefix: config.StoragePrefix,
|
|
}
|
|
}
|
|
|
|
// LoadPolicy will load a policy from the provided storage path and set the
|
|
// necessary un-exported variables. It is particularly useful when accessing a
|
|
// policy without the lock manager.
|
|
func LoadPolicy(ctx context.Context, s logical.Storage, path string) (*Policy, error) {
|
|
raw, err := s.Get(ctx, path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if raw == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
var policy Policy
|
|
err = jsonutil.DecodeJSON(raw.Value, &policy)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
policy.l = new(sync.RWMutex)
|
|
|
|
return &policy, nil
|
|
}
|
|
|
|
// Policy is the struct used to store metadata about a named key: its type,
// the active key versions, the version bookkeeping used for archiving, and the
// flags controlling derivation, export, deletion and backup. The un-exported
// fields hold runtime state and are not serialized.
type Policy struct {
	// This is a pointer on purpose: if we are running with cache disabled we
	// need to actually swap in the lock manager's lock for this policy with
	// the local lock.
	l *sync.RWMutex
	// writeLocked allows us to implement Lock() and Unlock()
	writeLocked bool
	// Stores whether it's been deleted. This acts as a guard for operations
	// that may write data, e.g. if one request rotates and that request is
	// served after a delete.
	deleted uint32

	// Name identifies the policy and is used to build its storage paths. Keys
	// maps a key version (as a decimal string) to that version's entry.
	Name    string      `json:"name"`
	Key     []byte      `json:"key,omitempty"`      // DEPRECATED
	KeySize int         `json:"key_size,omitempty"` // For algorithms with variable key sizes
	Keys    keyEntryMap `json:"keys"`

	// Derived keys MUST provide a context and the master underlying key is
	// never used. If convergent encryption is true, the context will be used
	// as the nonce as well.
	Derived              bool `json:"derived"`
	KDF                  int  `json:"kdf"`
	ConvergentEncryption bool `json:"convergent_encryption"`

	// Whether the key is exportable
	Exportable bool `json:"exportable"`

	// The minimum version of the key allowed to be used for decryption
	MinDecryptionVersion int `json:"min_decryption_version"`

	// The minimum version of the key allowed to be used for encryption
	MinEncryptionVersion int `json:"min_encryption_version"`

	// The latest key version in this policy
	LatestVersion int `json:"latest_version"`

	// The latest key version in the archive. We never delete these, so this is
	// a max.
	ArchiveVersion int `json:"archive_version"`

	// ArchiveMinVersion is the minimum version of the key in the archive.
	ArchiveMinVersion int `json:"archive_min_version"`

	// MinAvailableVersion is the minimum version of the key present. All key
	// versions before this would have been deleted.
	MinAvailableVersion int `json:"min_available_version"`

	// Whether the key is allowed to be deleted
	DeletionAllowed bool `json:"deletion_allowed"`

	// The version of the convergent nonce to use
	ConvergentVersion int `json:"convergent_version"`

	// The type of key
	Type KeyType `json:"type"`

	// BackupInfo indicates the information about the backup action taken on
	// this policy
	BackupInfo *BackupInfo `json:"backup_info"`

	// RestoreInfo indicates the information about the restore action taken on
	// this policy
	RestoreInfo *RestoreInfo `json:"restore_info"`

	// AllowPlaintextBackup allows taking backup of the policy in plaintext
	AllowPlaintextBackup bool `json:"allow_plaintext_backup"`

	// VersionTemplate is used to prefix the ciphertext with information about
	// the key version. It must include {{version}} and a delimiter between the
	// version prefix and the ciphertext.
	VersionTemplate string `json:"version_template"`

	// StoragePrefix is used to add a prefix when storing and retrieving the
	// policy object.
	StoragePrefix string `json:"storage_prefix"`

	// AutoRotatePeriod defines how frequently the key should automatically
	// rotate. Setting this to zero disables automatic rotation for the key.
	AutoRotatePeriod time.Duration `json:"auto_rotate_period"`

	// versionPrefixCache stores caches of version prefix strings and the split
	// version template.
	versionPrefixCache sync.Map

	// Imported indicates whether the key was generated by Vault or imported
	// from an external source
	Imported bool

	// AllowImportedKeyRotation indicates whether an imported key may be rotated by Vault
	AllowImportedKeyRotation bool

	// ManagedKeyName is omitted from the JSON encoding when empty.
	// NOTE(review): presumably names the external key backing a
	// KeyType_MANAGED_KEY policy — confirm.
	ManagedKeyName string `json:"managed_key_name,omitempty"`
}
|
|
|
|
func (p *Policy) Lock(exclusive bool) {
|
|
if exclusive {
|
|
p.l.Lock()
|
|
p.writeLocked = true
|
|
} else {
|
|
p.l.RLock()
|
|
}
|
|
}
|
|
|
|
func (p *Policy) Unlock() {
|
|
if p.writeLocked {
|
|
p.writeLocked = false
|
|
p.l.Unlock()
|
|
} else {
|
|
p.l.RUnlock()
|
|
}
|
|
}
|
|
|
|
// archivedKeys stores old keys. This is used to keep the key loading time sane
// when there are huge numbers of rotations.
//
// handleArchiving indexes Keys by (key version - Policy.MinAvailableVersion).
type archivedKeys struct {
	Keys []KeyEntry `json:"keys"`
}
|
|
|
|
func (p *Policy) LoadArchive(ctx context.Context, storage logical.Storage) (*archivedKeys, error) {
|
|
archive := &archivedKeys{}
|
|
|
|
raw, err := storage.Get(ctx, path.Join(p.StoragePrefix, "archive", p.Name))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if raw == nil {
|
|
archive.Keys = make([]KeyEntry, 0)
|
|
return archive, nil
|
|
}
|
|
|
|
if err := jsonutil.DecodeJSON(raw.Value, archive); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return archive, nil
|
|
}
|
|
|
|
func (p *Policy) storeArchive(ctx context.Context, storage logical.Storage, archive *archivedKeys) error {
|
|
// Encode the policy
|
|
buf, err := json.Marshal(archive)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Write the policy into storage
|
|
err = storage.Put(ctx, &logical.StorageEntry{
|
|
Key: path.Join(p.StoragePrefix, "archive", p.Name),
|
|
Value: buf,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards.
//
// It first validates the policy's version counters, then loads the archive
// from storage. If the key for MinDecryptionVersion is not in p.Keys, versions
// MinDecryptionVersion..LatestVersion are copied from the archive into p.Keys
// and nothing is written. Otherwise any versions newer than ArchiveVersion are
// added to the archive, the archive is trimmed and stored, and versions below
// MinDecryptionVersion are removed from p.Keys.
//
// It mutates p.Keys, p.ArchiveVersion and p.ArchiveMinVersion, and returns an
// error if a sanity check fails or the archive cannot be loaded or stored.
func (p *Policy) handleArchiving(ctx context.Context, storage logical.Storage) error {
	// We need to move keys that are no longer accessible to archivedKeys, and keys
	// that now need to be accessible back here.
	//
	// For safety, because there isn't really a good reason to, we never delete
	// keys from the archive even when we move them back.

	// Check if we have the latest minimum version in the current set of keys
	_, keysContainsMinimum := p.Keys[strconv.Itoa(p.MinDecryptionVersion)]

	// Sanity checks
	switch {
	case p.MinDecryptionVersion < 1:
		return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
	case p.LatestVersion < 1:
		return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
	case !keysContainsMinimum && p.ArchiveVersion != p.LatestVersion:
		// Restoring from the archive is only possible if it already holds
		// every version up to the latest one.
		return fmt.Errorf("need to move keys from archive but archive version not up-to-date")
	case p.ArchiveVersion > p.LatestVersion:
		return fmt.Errorf("archive version of %d is greater than the latest version %d",
			p.ArchiveVersion, p.LatestVersion)
	case p.MinEncryptionVersion > 0 && p.MinEncryptionVersion < p.MinDecryptionVersion:
		return fmt.Errorf("minimum decryption version of %d is greater than minimum encryption version %d",
			p.MinDecryptionVersion, p.MinEncryptionVersion)
	case p.MinDecryptionVersion > p.LatestVersion:
		return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
			p.MinDecryptionVersion, p.LatestVersion)
	}

	archive, err := p.LoadArchive(ctx, storage)
	if err != nil {
		return err
	}

	if !keysContainsMinimum {
		// Need to move keys *from* archive
		// The archive slice is indexed by (version - MinAvailableVersion).
		for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
			p.Keys[strconv.Itoa(i)] = archive.Keys[i-p.MinAvailableVersion]
		}

		return nil
	}

	// Need to move keys *to* archive

	// We need a size that is equivalent to the latest version (number of keys)
	// but adding one since slice numbering starts at 0 and we're indexing by
	// key version
	if len(archive.Keys)+p.MinAvailableVersion < p.LatestVersion+1 {
		// Increase the size of the archive slice
		newKeys := make([]KeyEntry, p.LatestVersion-p.MinAvailableVersion+1)
		copy(newKeys, archive.Keys)
		archive.Keys = newKeys
	}

	// We are storing all keys in the archive, so we ensure that it is up to
	// date up to p.LatestVersion
	for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
		archive.Keys[i-p.MinAvailableVersion] = p.Keys[strconv.Itoa(i)]
		p.ArchiveVersion = i
	}

	// Trim the keys if required
	if p.ArchiveMinVersion < p.MinAvailableVersion {
		// Drop the leading entries that fall below MinAvailableVersion.
		archive.Keys = archive.Keys[p.MinAvailableVersion-p.ArchiveMinVersion:]
		p.ArchiveMinVersion = p.MinAvailableVersion
	}

	err = p.storeArchive(ctx, storage, archive)
	if err != nil {
		return err
	}

	// Perform deletion afterwards so that if there is an error saving we
	// haven't messed with the current policy
	// The loop starts at the version computed as LatestVersion-len(p.Keys)+1.
	for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
		delete(p.Keys, strconv.Itoa(i))
	}

	return nil
}
|
|
|
|
func (p *Policy) Persist(ctx context.Context, storage logical.Storage) (retErr error) {
|
|
if atomic.LoadUint32(&p.deleted) == 1 {
|
|
return errors.New("key has been deleted, not persisting")
|
|
}
|
|
|
|
// Other functions will take care of restoring other values; this is just
|
|
// responsible for archiving and keys since the archive function can modify
|
|
// keys. At the moment one of the other functions calling persist will also
|
|
// roll back keys, but better safe than sorry and this doesn't happen
|
|
// enough to worry about the speed tradeoff.
|
|
priorArchiveVersion := p.ArchiveVersion
|
|
var priorKeys keyEntryMap
|
|
|
|
if p.Keys != nil {
|
|
priorKeys = keyEntryMap{}
|
|
for k, v := range p.Keys {
|
|
priorKeys[k] = v
|
|
}
|
|
}
|
|
|
|
defer func() {
|
|
if retErr != nil {
|
|
p.ArchiveVersion = priorArchiveVersion
|
|
p.Keys = priorKeys
|
|
}
|
|
}()
|
|
|
|
err := p.handleArchiving(ctx, storage)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Encode the policy
|
|
buf, err := p.Serialize()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Write the policy into storage
|
|
err = storage.Put(ctx, &logical.StorageEntry{
|
|
Key: path.Join(p.StoragePrefix, "policy", p.Name),
|
|
Value: buf,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Policy) Serialize() ([]byte, error) {
|
|
return json.Marshal(p)
|
|
}
|
|
|
|
func (p *Policy) NeedsUpgrade() bool {
|
|
// Ensure we've moved from Key -> Keys
|
|
if p.Key != nil && len(p.Key) > 0 {
|
|
return true
|
|
}
|
|
|
|
// With archiving, past assumptions about the length of the keys map are no
|
|
// longer valid
|
|
if p.LatestVersion == 0 && len(p.Keys) != 0 {
|
|
return true
|
|
}
|
|
|
|
// We disallow setting the version to 0, since they start at 1 since moving
|
|
// to rotate-able keys, so update if it's set to 0
|
|
if p.MinDecryptionVersion == 0 {
|
|
return true
|
|
}
|
|
|
|
// On first load after an upgrade, copy keys to the archive
|
|
if p.ArchiveVersion == 0 {
|
|
return true
|
|
}
|
|
|
|
// Need to write the version if zero; for version 3 on we set this to -1 to
|
|
// ignore it since we store this information in each key entry
|
|
if p.ConvergentEncryption && p.ConvergentVersion == 0 {
|
|
return true
|
|
}
|
|
|
|
if p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey == nil || len(p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey) == 0 {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (p *Policy) Upgrade(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
|
|
priorKey := p.Key
|
|
priorLatestVersion := p.LatestVersion
|
|
priorMinDecryptionVersion := p.MinDecryptionVersion
|
|
priorConvergentVersion := p.ConvergentVersion
|
|
var priorKeys keyEntryMap
|
|
|
|
if p.Keys != nil {
|
|
priorKeys = keyEntryMap{}
|
|
for k, v := range p.Keys {
|
|
priorKeys[k] = v
|
|
}
|
|
}
|
|
|
|
defer func() {
|
|
if retErr != nil {
|
|
p.Key = priorKey
|
|
p.LatestVersion = priorLatestVersion
|
|
p.MinDecryptionVersion = priorMinDecryptionVersion
|
|
p.ConvergentVersion = priorConvergentVersion
|
|
p.Keys = priorKeys
|
|
}
|
|
}()
|
|
|
|
persistNeeded := false
|
|
// Ensure we've moved from Key -> Keys
|
|
if p.Key != nil && len(p.Key) > 0 {
|
|
p.MigrateKeyToKeysMap()
|
|
persistNeeded = true
|
|
}
|
|
|
|
// With archiving, past assumptions about the length of the keys map are no
|
|
// longer valid
|
|
if p.LatestVersion == 0 && len(p.Keys) != 0 {
|
|
p.LatestVersion = len(p.Keys)
|
|
persistNeeded = true
|
|
}
|
|
|
|
// We disallow setting the version to 0, since they start at 1 since moving
|
|
// to rotate-able keys, so update if it's set to 0
|
|
if p.MinDecryptionVersion == 0 {
|
|
p.MinDecryptionVersion = 1
|
|
persistNeeded = true
|
|
}
|
|
|
|
// On first load after an upgrade, copy keys to the archive
|
|
if p.ArchiveVersion == 0 {
|
|
persistNeeded = true
|
|
}
|
|
|
|
if p.ConvergentEncryption && p.ConvergentVersion == 0 {
|
|
p.ConvergentVersion = 1
|
|
persistNeeded = true
|
|
}
|
|
|
|
if p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey == nil || len(p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey) == 0 {
|
|
entry := p.Keys[strconv.Itoa(p.LatestVersion)]
|
|
hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
entry.HMACKey = hmacKey
|
|
p.Keys[strconv.Itoa(p.LatestVersion)] = entry
|
|
persistNeeded = true
|
|
}
|
|
|
|
if persistNeeded {
|
|
err := p.Persist(ctx, storage)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetKey is used to derive the encryption key that should be used depending
|
|
// on the policy. If derivation is disabled the raw key is used and no context
|
|
// is required, otherwise the KDF mode is used with the context to derive the
|
|
// proper key.
|
|
func (p *Policy) GetKey(context []byte, ver, numBytes int) ([]byte, error) {
|
|
// Fast-path non-derived keys
|
|
if !p.Derived {
|
|
keyEntry, err := p.safeGetKeyEntry(ver)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return keyEntry.Key, nil
|
|
}
|
|
|
|
return p.DeriveKey(context, nil, ver, numBytes)
|
|
}
|
|
|
|
// DeriveKey is used to derive a symmetric key given a context and salt. This does not
|
|
// check the policies Derived flag, but just implements the derivation logic. GetKey
|
|
// is responsible for switching on the policy config.
|
|
func (p *Policy) DeriveKey(context, salt []byte, ver int, numBytes int) ([]byte, error) {
|
|
if !p.Type.DerivationSupported() {
|
|
return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
|
|
}
|
|
|
|
if p.Keys == nil || p.LatestVersion == 0 {
|
|
return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
|
|
}
|
|
|
|
if ver <= 0 || ver > p.LatestVersion {
|
|
return nil, errutil.UserError{Err: "invalid key version"}
|
|
}
|
|
|
|
// Ensure a context is provided
|
|
if len(context) == 0 {
|
|
return nil, errutil.UserError{Err: "missing 'context' for key derivation; the key was created using a derived key, which means additional, per-request information must be included in order to perform operations with the key"}
|
|
}
|
|
|
|
keyEntry, err := p.safeGetKeyEntry(ver)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch p.KDF {
|
|
case Kdf_hmac_sha256_counter:
|
|
prf := kdf.HMACSHA256PRF
|
|
prfLen := kdf.HMACSHA256PRFLen
|
|
return kdf.CounterMode(prf, prfLen, keyEntry.Key, append(context, salt...), 256)
|
|
|
|
case Kdf_hkdf_sha256:
|
|
reader := hkdf.New(sha256.New, keyEntry.Key, salt, context)
|
|
derBytes := bytes.NewBuffer(nil)
|
|
derBytes.Grow(numBytes)
|
|
limReader := &io.LimitedReader{
|
|
R: reader,
|
|
N: int64(numBytes),
|
|
}
|
|
|
|
switch p.Type {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
|
|
n, err := derBytes.ReadFrom(limReader)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: fmt.Sprintf("error reading returned derived bytes: %v", err)}
|
|
}
|
|
if n != int64(numBytes) {
|
|
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to read enough derived bytes, needed %d, got %d", numBytes, n)}
|
|
}
|
|
return derBytes.Bytes(), nil
|
|
|
|
case KeyType_ED25519:
|
|
// We use the limited reader containing the derived bytes as the
|
|
// "random" input to the generation function
|
|
_, pri, err := ed25519.GenerateKey(limReader)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: fmt.Sprintf("error generating derived key: %v", err)}
|
|
}
|
|
return pri, nil
|
|
|
|
default:
|
|
return nil, errutil.InternalError{Err: "unsupported key type for derivation"}
|
|
}
|
|
|
|
default:
|
|
return nil, errutil.InternalError{Err: "unsupported key derivation mode"}
|
|
}
|
|
}
|
|
|
|
func (p *Policy) safeGetKeyEntry(ver int) (KeyEntry, error) {
|
|
keyVerStr := strconv.Itoa(ver)
|
|
keyEntry, ok := p.Keys[keyVerStr]
|
|
if !ok {
|
|
return keyEntry, errutil.UserError{Err: "no such key version"}
|
|
}
|
|
return keyEntry, nil
|
|
}
|
|
|
|
func (p *Policy) convergentVersion(ver int) int {
|
|
if !p.ConvergentEncryption {
|
|
return 0
|
|
}
|
|
|
|
convergentVersion := p.ConvergentVersion
|
|
if convergentVersion == 0 {
|
|
// For some reason, not upgraded yet
|
|
convergentVersion = 1
|
|
}
|
|
currKey := p.Keys[strconv.Itoa(ver)]
|
|
if currKey.ConvergentVersion != 0 {
|
|
convergentVersion = currKey.ConvergentVersion
|
|
}
|
|
|
|
return convergentVersion
|
|
}
|
|
|
|
func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, error) {
|
|
return p.EncryptWithFactory(ver, context, nonce, value, nil)
|
|
}
|
|
|
|
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
|
|
return p.DecryptWithFactory(context, nonce, value, nil)
|
|
}
|
|
|
|
// DecryptWithFactory decrypts value, a ciphertext made of a version prefix
// followed by base64-encoded data, and returns the plaintext base64-encoded.
//
// The key version is parsed from the prefix and must lie between
// MinDecryptionVersion and LatestVersion; a version of 0 is treated as 1 for
// compatibility. context is passed to GetKey for the symmetric key types.
//
// Each non-nil element of factories must be an AEADFactory or an
// AssociatedDataFactory; they only affect the symmetric AEAD key types, where
// they populate SymmetricOpts.AEADFactory and SymmetricOpts.AdditionalData.
// Any other factory type is an InternalError. RSA key types are decrypted with
// OAEP using SHA-256.
//
// Malformed input is reported as errutil.UserError; internal failures as
// errutil.InternalError.
func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factories ...interface{}) (string, error) {
	if !p.Type.DecryptionSupported() {
		return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
	}

	// tplParts[0] is used below as the literal ciphertext prefix and
	// tplParts[1] as the separator between the version number and the payload.
	tplParts, err := p.getTemplateParts()
	if err != nil {
		return "", err
	}

	// Verify the prefix
	if !strings.HasPrefix(value, tplParts[0]) {
		return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
	}

	splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, tplParts[0]), tplParts[1], 2)
	if len(splitVerCiphertext) != 2 {
		return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
	}

	ver, err := strconv.Atoi(splitVerCiphertext[0])
	if err != nil {
		return "", errutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
	}

	if ver == 0 {
		// Compatibility mode with initial implementation, where keys start at
		// zero
		ver = 1
	}

	if ver > p.LatestVersion {
		return "", errutil.UserError{Err: "invalid ciphertext: version is too new"}
	}

	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
		return "", errutil.UserError{Err: ErrTooOld}
	}

	// Convergent version 1 requires the caller to supply the nonce.
	// NOTE(review): nonce is validated here but is not placed into symopts
	// below — confirm SymmetricDecryptRaw does not need it.
	convergentVersion := p.convergentVersion(ver)
	if convergentVersion == 1 && (nonce == nil || len(nonce) == 0) {
		return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
	}

	// Decode the base64
	decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
	if err != nil {
		return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"}
	}

	var plain []byte

	switch p.Type {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
		// Key length in bytes: 16 for AES-128, 32 for the other AEAD types.
		numBytes := 32
		if p.Type == KeyType_AES128_GCM96 {
			numBytes = 16
		}

		encKey, err := p.GetKey(context, ver, numBytes)
		if err != nil {
			return "", err
		}

		if len(encKey) != numBytes {
			return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
		}

		symopts := SymmetricOpts{
			Convergent:        p.ConvergentEncryption,
			ConvergentVersion: p.ConvergentVersion,
		}
		// Apply the caller-supplied factories; nil entries (as passed by
		// Decrypt) are ignored.
		for index, rawFactory := range factories {
			if rawFactory == nil {
				continue
			}
			switch factory := rawFactory.(type) {
			case AEADFactory:
				symopts.AEADFactory = factory
			case AssociatedDataFactory:
				symopts.AdditionalData, err = factory.GetAssociatedData()
				if err != nil {
					return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
				}
			default:
				return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
			}
		}

		plain, err = p.SymmetricDecryptRaw(encKey, decoded, symopts)
		if err != nil {
			return "", err
		}
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return "", err
		}
		key := keyEntry.RSAKey
		// OAEP with SHA-256 and no label.
		plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil)
		if err != nil {
			return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA decrypt the ciphertext: %v", err)}
		}

	default:
		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}

	return base64.StdEncoding.EncodeToString(plain), nil
}
|
|
|
|
func (p *Policy) HMACKey(version int) ([]byte, error) {
|
|
switch {
|
|
case version < 0:
|
|
return nil, fmt.Errorf("key version does not exist (cannot be negative)")
|
|
case version > p.LatestVersion:
|
|
return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion)
|
|
}
|
|
keyEntry, err := p.safeGetKeyEntry(version)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if p.Type == KeyType_HMAC {
|
|
return keyEntry.Key, nil
|
|
}
|
|
if keyEntry.HMACKey == nil {
|
|
return nil, fmt.Errorf("no HMAC key exists for that key version")
|
|
}
|
|
return keyEntry.HMACKey, nil
|
|
}
|
|
|
|
func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) {
|
|
return p.SignWithOptions(ver, context, input, &SigningOptions{
|
|
HashAlgorithm: hashAlgorithm,
|
|
Marshaling: marshaling,
|
|
SaltLength: rsa.PSSSaltLengthAuto,
|
|
SigAlgorithm: sigAlgorithm,
|
|
})
|
|
}
|
|
|
|
func (p *Policy) minRSAPSSSaltLength() int {
|
|
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=247
|
|
return rsa.PSSSaltLengthEqualsHash
|
|
}
|
|
|
|
func (p *Policy) maxRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash) int {
|
|
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288
|
|
return (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
|
|
}
|
|
|
|
func (p *Policy) validRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash, saltLength int) bool {
|
|
return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(priv, hash)
|
|
}
|
|
|
|
func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) {
|
|
if !p.Type.SigningSupported() {
|
|
return nil, fmt.Errorf("message signing not supported for key type %v", p.Type)
|
|
}
|
|
|
|
switch {
|
|
case ver == 0:
|
|
ver = p.LatestVersion
|
|
case ver < 0:
|
|
return nil, errutil.UserError{Err: "requested version for signing is negative"}
|
|
case ver > p.LatestVersion:
|
|
return nil, errutil.UserError{Err: "requested version for signing is higher than the latest key version"}
|
|
case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
|
|
return nil, errutil.UserError{Err: "requested version for signing is less than the minimum encryption key version"}
|
|
}
|
|
|
|
var sig []byte
|
|
var pubKey []byte
|
|
var err error
|
|
keyParams, err := p.safeGetKeyEntry(ver)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hashAlgorithm := options.HashAlgorithm
|
|
marshaling := options.Marshaling
|
|
saltLength := options.SaltLength
|
|
sigAlgorithm := options.SigAlgorithm
|
|
|
|
switch p.Type {
|
|
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
|
|
var curveBits int
|
|
var curve elliptic.Curve
|
|
switch p.Type {
|
|
case KeyType_ECDSA_P384:
|
|
curveBits = 384
|
|
curve = elliptic.P384()
|
|
case KeyType_ECDSA_P521:
|
|
curveBits = 521
|
|
curve = elliptic.P521()
|
|
default:
|
|
curveBits = 256
|
|
curve = elliptic.P256()
|
|
}
|
|
|
|
key := &ecdsa.PrivateKey{
|
|
PublicKey: ecdsa.PublicKey{
|
|
Curve: curve,
|
|
X: keyParams.EC_X,
|
|
Y: keyParams.EC_Y,
|
|
},
|
|
D: keyParams.EC_D,
|
|
}
|
|
|
|
r, s, err := ecdsa.Sign(rand.Reader, key, input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch marshaling {
|
|
case MarshalingTypeASN1:
|
|
// This is used by openssl and X.509
|
|
sig, err = asn1.Marshal(ecdsaSignature{
|
|
R: r,
|
|
S: s,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
case MarshalingTypeJWS:
|
|
// This is used by JWS
|
|
|
|
// First we have to get the length of the curve in bytes. Although
|
|
// we only support 256 now, we'll do this in an agnostic way so we
|
|
// can reuse this marshaling if we support e.g. 521. Getting the
|
|
// number of bytes without rounding up would be 65.125 so we need
|
|
// to add one in that case.
|
|
keyLen := curveBits / 8
|
|
if curveBits%8 > 0 {
|
|
keyLen++
|
|
}
|
|
|
|
// Now create the output array
|
|
sig = make([]byte, keyLen*2)
|
|
rb := r.Bytes()
|
|
sb := s.Bytes()
|
|
copy(sig[keyLen-len(rb):], rb)
|
|
copy(sig[2*keyLen-len(sb):], sb)
|
|
|
|
default:
|
|
return nil, errutil.UserError{Err: "requested marshaling type is invalid"}
|
|
}
|
|
|
|
case KeyType_ED25519:
|
|
var key ed25519.PrivateKey
|
|
|
|
if p.Derived {
|
|
// Derive the key that should be used
|
|
var err error
|
|
key, err = p.GetKey(context, ver, 32)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
|
|
}
|
|
pubKey = key.Public().(ed25519.PublicKey)
|
|
} else {
|
|
key = ed25519.PrivateKey(keyParams.Key)
|
|
}
|
|
|
|
// Per docs, do not pre-hash ed25519; it does two passes and performs
|
|
// its own hashing
|
|
sig, err = key.Sign(rand.Reader, input, crypto.Hash(0))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
key := keyParams.RSAKey
|
|
|
|
algo, ok := CryptoHashMap[hashAlgorithm]
|
|
if !ok {
|
|
return nil, errutil.InternalError{Err: "unsupported hash algorithm"}
|
|
}
|
|
|
|
if sigAlgorithm == "" {
|
|
sigAlgorithm = "pss"
|
|
}
|
|
|
|
switch sigAlgorithm {
|
|
case "pss":
|
|
if !p.validRSAPSSSaltLength(key, algo, saltLength) {
|
|
return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
|
|
}
|
|
sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case "pkcs1v15":
|
|
sig, err = rsa.SignPKCS1v15(rand.Reader, key, algo, input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
default:
|
|
return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
|
|
}
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported key type %v", p.Type)
|
|
}
|
|
|
|
// Convert to base64
|
|
var encoded string
|
|
switch marshaling {
|
|
case MarshalingTypeASN1:
|
|
encoded = base64.StdEncoding.EncodeToString(sig)
|
|
case MarshalingTypeJWS:
|
|
encoded = base64.RawURLEncoding.EncodeToString(sig)
|
|
}
|
|
res := &SigningResult{
|
|
Signature: p.getVersionPrefix(ver) + encoded,
|
|
PublicKey: pubKey,
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType, sig string) (bool, error) {
|
|
return p.VerifySignatureWithOptions(context, input, sig, &SigningOptions{
|
|
HashAlgorithm: hashAlgorithm,
|
|
Marshaling: marshaling,
|
|
SaltLength: rsa.PSSSaltLengthAuto,
|
|
SigAlgorithm: sigAlgorithm,
|
|
})
|
|
}
|
|
|
|
func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, options *SigningOptions) (bool, error) {
|
|
if !p.Type.SigningSupported() {
|
|
return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
|
|
}
|
|
|
|
tplParts, err := p.getTemplateParts()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// Verify the prefix
|
|
if !strings.HasPrefix(sig, tplParts[0]) {
|
|
return false, errutil.UserError{Err: "invalid signature: no prefix"}
|
|
}
|
|
|
|
splitVerSig := strings.SplitN(strings.TrimPrefix(sig, tplParts[0]), tplParts[1], 2)
|
|
if len(splitVerSig) != 2 {
|
|
return false, errutil.UserError{Err: "invalid signature: wrong number of fields"}
|
|
}
|
|
|
|
ver, err := strconv.Atoi(splitVerSig[0])
|
|
if err != nil {
|
|
return false, errutil.UserError{Err: "invalid signature: version number could not be decoded"}
|
|
}
|
|
|
|
if ver > p.LatestVersion {
|
|
return false, errutil.UserError{Err: "invalid signature: version is too new"}
|
|
}
|
|
|
|
if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
|
|
return false, errutil.UserError{Err: ErrTooOld}
|
|
}
|
|
|
|
hashAlgorithm := options.HashAlgorithm
|
|
marshaling := options.Marshaling
|
|
saltLength := options.SaltLength
|
|
sigAlgorithm := options.SigAlgorithm
|
|
|
|
var sigBytes []byte
|
|
switch marshaling {
|
|
case MarshalingTypeASN1:
|
|
sigBytes, err = base64.StdEncoding.DecodeString(splitVerSig[1])
|
|
case MarshalingTypeJWS:
|
|
sigBytes, err = base64.RawURLEncoding.DecodeString(splitVerSig[1])
|
|
default:
|
|
return false, errutil.UserError{Err: "requested marshaling type is invalid"}
|
|
}
|
|
if err != nil {
|
|
return false, errutil.UserError{Err: "invalid base64 signature value"}
|
|
}
|
|
|
|
switch p.Type {
|
|
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
|
|
var curve elliptic.Curve
|
|
switch p.Type {
|
|
case KeyType_ECDSA_P384:
|
|
curve = elliptic.P384()
|
|
case KeyType_ECDSA_P521:
|
|
curve = elliptic.P521()
|
|
default:
|
|
curve = elliptic.P256()
|
|
}
|
|
|
|
var ecdsaSig ecdsaSignature
|
|
|
|
switch marshaling {
|
|
case MarshalingTypeASN1:
|
|
rest, err := asn1.Unmarshal(sigBytes, &ecdsaSig)
|
|
if err != nil {
|
|
return false, errutil.UserError{Err: "supplied signature is invalid"}
|
|
}
|
|
if rest != nil && len(rest) != 0 {
|
|
return false, errutil.UserError{Err: "supplied signature contains extra data"}
|
|
}
|
|
|
|
case MarshalingTypeJWS:
|
|
paramLen := len(sigBytes) / 2
|
|
rb := sigBytes[:paramLen]
|
|
sb := sigBytes[paramLen:]
|
|
ecdsaSig.R = new(big.Int)
|
|
ecdsaSig.R.SetBytes(rb)
|
|
ecdsaSig.S = new(big.Int)
|
|
ecdsaSig.S.SetBytes(sb)
|
|
}
|
|
|
|
keyParams, err := p.safeGetKeyEntry(ver)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
key := &ecdsa.PublicKey{
|
|
Curve: curve,
|
|
X: keyParams.EC_X,
|
|
Y: keyParams.EC_Y,
|
|
}
|
|
|
|
return ecdsa.Verify(key, input, ecdsaSig.R, ecdsaSig.S), nil
|
|
|
|
case KeyType_ED25519:
|
|
var key ed25519.PrivateKey
|
|
|
|
if p.Derived {
|
|
// Derive the key that should be used
|
|
var err error
|
|
key, err = p.GetKey(context, ver, 32)
|
|
if err != nil {
|
|
return false, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
|
|
}
|
|
} else {
|
|
key = ed25519.PrivateKey(p.Keys[strconv.Itoa(ver)].Key)
|
|
}
|
|
|
|
return ed25519.Verify(key.Public().(ed25519.PublicKey), input, sigBytes), nil
|
|
|
|
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
keyEntry, err := p.safeGetKeyEntry(ver)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
key := keyEntry.RSAKey
|
|
|
|
algo, ok := CryptoHashMap[hashAlgorithm]
|
|
if !ok {
|
|
return false, errutil.InternalError{Err: "unsupported hash algorithm"}
|
|
}
|
|
|
|
if sigAlgorithm == "" {
|
|
sigAlgorithm = "pss"
|
|
}
|
|
|
|
switch sigAlgorithm {
|
|
case "pss":
|
|
if !p.validRSAPSSSaltLength(key, algo, saltLength) {
|
|
return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
|
|
}
|
|
err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength})
|
|
case "pkcs1v15":
|
|
err = rsa.VerifyPKCS1v15(&key.PublicKey, algo, input, sigBytes)
|
|
default:
|
|
return false, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
|
|
}
|
|
|
|
return err == nil, nil
|
|
|
|
default:
|
|
return false, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
|
|
}
|
|
}
|
|
|
|
func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte, randReader io.Reader) error {
|
|
now := time.Now()
|
|
entry := KeyEntry{
|
|
CreationTime: now,
|
|
DeprecatedCreationTime: now.Unix(),
|
|
}
|
|
|
|
if p.Type != KeyType_HMAC {
|
|
hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
entry.HMACKey = hmacKey
|
|
}
|
|
|
|
if (p.Type == KeyType_AES128_GCM96 && len(key) != 16) ||
|
|
((p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305) && len(key) != 32) ||
|
|
(p.Type == KeyType_HMAC && (len(key) < HmacMinKeySize || len(key) > HmacMaxKeySize)) {
|
|
return fmt.Errorf("invalid key size %d bytes for key type %s", len(key), p.Type)
|
|
}
|
|
|
|
if p.Type == KeyType_AES128_GCM96 || p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305 || p.Type == KeyType_HMAC {
|
|
entry.Key = key
|
|
if p.Type == KeyType_HMAC {
|
|
p.KeySize = len(key)
|
|
}
|
|
} else {
|
|
parsedPrivateKey, err := x509.ParsePKCS8PrivateKey(key)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "unknown elliptic curve") {
|
|
var edErr error
|
|
parsedPrivateKey, edErr = ParsePKCS8Ed25519PrivateKey(key)
|
|
if edErr != nil {
|
|
return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err)
|
|
}
|
|
|
|
// Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded!
|
|
} else {
|
|
return fmt.Errorf("error parsing asymmetric key: %s", err)
|
|
}
|
|
}
|
|
|
|
switch parsedPrivateKey.(type) {
|
|
case *ecdsa.PrivateKey:
|
|
if p.Type != KeyType_ECDSA_P256 && p.Type != KeyType_ECDSA_P384 && p.Type != KeyType_ECDSA_P521 {
|
|
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey)
|
|
}
|
|
|
|
ecdsaKey := parsedPrivateKey.(*ecdsa.PrivateKey)
|
|
curve := elliptic.P256()
|
|
if p.Type == KeyType_ECDSA_P384 {
|
|
curve = elliptic.P384()
|
|
} else if p.Type == KeyType_ECDSA_P521 {
|
|
curve = elliptic.P521()
|
|
}
|
|
|
|
if ecdsaKey.Curve != curve {
|
|
return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name)
|
|
}
|
|
|
|
entry.EC_D = ecdsaKey.D
|
|
entry.EC_X = ecdsaKey.X
|
|
entry.EC_Y = ecdsaKey.Y
|
|
derBytes, err := x509.MarshalPKIXPublicKey(ecdsaKey.Public())
|
|
if err != nil {
|
|
return errwrap.Wrapf("error marshaling public key: {{err}}", err)
|
|
}
|
|
pemBlock := &pem.Block{
|
|
Type: "PUBLIC KEY",
|
|
Bytes: derBytes,
|
|
}
|
|
pemBytes := pem.EncodeToMemory(pemBlock)
|
|
if pemBytes == nil || len(pemBytes) == 0 {
|
|
return fmt.Errorf("error PEM-encoding public key")
|
|
}
|
|
entry.FormattedPublicKey = string(pemBytes)
|
|
case ed25519.PrivateKey:
|
|
if p.Type != KeyType_ED25519 {
|
|
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey)
|
|
}
|
|
privateKey := parsedPrivateKey.(ed25519.PrivateKey)
|
|
|
|
entry.Key = privateKey
|
|
publicKey := privateKey.Public().(ed25519.PublicKey)
|
|
entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey)
|
|
case *rsa.PrivateKey:
|
|
if p.Type != KeyType_RSA2048 && p.Type != KeyType_RSA3072 && p.Type != KeyType_RSA4096 {
|
|
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey)
|
|
}
|
|
|
|
keyBytes := 256
|
|
if p.Type == KeyType_RSA3072 {
|
|
keyBytes = 384
|
|
} else if p.Type == KeyType_RSA4096 {
|
|
keyBytes = 512
|
|
}
|
|
rsaKey := parsedPrivateKey.(*rsa.PrivateKey)
|
|
if rsaKey.Size() != keyBytes {
|
|
return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size())
|
|
}
|
|
|
|
entry.RSAKey = rsaKey
|
|
default:
|
|
return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey)
|
|
}
|
|
}
|
|
|
|
p.LatestVersion += 1
|
|
|
|
if p.Keys == nil {
|
|
// This is an initial key rotation when generating a new policy. We
|
|
// don't need to call migrate here because if we've called getPolicy to
|
|
// get the policy in the first place it will have been run.
|
|
p.Keys = keyEntryMap{}
|
|
}
|
|
p.Keys[strconv.Itoa(p.LatestVersion)] = entry
|
|
|
|
// This ensures that with new key creations min decryption version is set
|
|
// to 1 rather than the int default of 0, since keys start at 1 (either
|
|
// fresh or after migration to the key map)
|
|
if p.MinDecryptionVersion == 0 {
|
|
p.MinDecryptionVersion = 1
|
|
}
|
|
|
|
return p.Persist(ctx, storage)
|
|
}
|
|
|
|
// Rotate rotates the policy and persists it to storage.
|
|
// If the rotation partially fails, the policy state will be restored.
|
|
func (p *Policy) Rotate(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
|
|
priorLatestVersion := p.LatestVersion
|
|
priorMinDecryptionVersion := p.MinDecryptionVersion
|
|
var priorKeys keyEntryMap
|
|
|
|
if p.Imported && !p.AllowImportedKeyRotation {
|
|
return fmt.Errorf("imported key %s does not allow rotation within Vault", p.Name)
|
|
}
|
|
|
|
if p.Keys != nil {
|
|
priorKeys = keyEntryMap{}
|
|
for k, v := range p.Keys {
|
|
priorKeys[k] = v
|
|
}
|
|
}
|
|
|
|
defer func() {
|
|
if retErr != nil {
|
|
p.LatestVersion = priorLatestVersion
|
|
p.MinDecryptionVersion = priorMinDecryptionVersion
|
|
p.Keys = priorKeys
|
|
}
|
|
}()
|
|
|
|
if err := p.RotateInMemory(randReader); err != nil {
|
|
return err
|
|
}
|
|
|
|
p.Imported = false
|
|
return p.Persist(ctx, storage)
|
|
}
|
|
|
|
// RotateInMemory rotates the policy but does not persist it to storage.
|
|
func (p *Policy) RotateInMemory(randReader io.Reader) (retErr error) {
|
|
now := time.Now()
|
|
entry := KeyEntry{
|
|
CreationTime: now,
|
|
DeprecatedCreationTime: now.Unix(),
|
|
}
|
|
|
|
hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
entry.HMACKey = hmacKey
|
|
|
|
switch p.Type {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_HMAC:
|
|
// Default to 256 bit key
|
|
numBytes := 32
|
|
if p.Type == KeyType_AES128_GCM96 {
|
|
numBytes = 16
|
|
} else if p.Type == KeyType_HMAC {
|
|
numBytes := p.KeySize
|
|
if numBytes < HmacMinKeySize || numBytes > HmacMaxKeySize {
|
|
return fmt.Errorf("invalid key size for HMAC key, must be between %d and %d bytes", HmacMinKeySize, HmacMaxKeySize)
|
|
}
|
|
}
|
|
newKey, err := uuid.GenerateRandomBytesWithReader(numBytes, randReader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
entry.Key = newKey
|
|
|
|
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
|
|
var curve elliptic.Curve
|
|
switch p.Type {
|
|
case KeyType_ECDSA_P384:
|
|
curve = elliptic.P384()
|
|
case KeyType_ECDSA_P521:
|
|
curve = elliptic.P521()
|
|
default:
|
|
curve = elliptic.P256()
|
|
}
|
|
|
|
privKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
entry.EC_D = privKey.D
|
|
entry.EC_X = privKey.X
|
|
entry.EC_Y = privKey.Y
|
|
derBytes, err := x509.MarshalPKIXPublicKey(privKey.Public())
|
|
if err != nil {
|
|
return errwrap.Wrapf("error marshaling public key: {{err}}", err)
|
|
}
|
|
pemBlock := &pem.Block{
|
|
Type: "PUBLIC KEY",
|
|
Bytes: derBytes,
|
|
}
|
|
pemBytes := pem.EncodeToMemory(pemBlock)
|
|
if pemBytes == nil || len(pemBytes) == 0 {
|
|
return fmt.Errorf("error PEM-encoding public key")
|
|
}
|
|
entry.FormattedPublicKey = string(pemBytes)
|
|
|
|
case KeyType_ED25519:
|
|
pub, pri, err := ed25519.GenerateKey(randReader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
entry.Key = pri
|
|
entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(pub)
|
|
|
|
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
bitSize := 2048
|
|
if p.Type == KeyType_RSA3072 {
|
|
bitSize = 3072
|
|
}
|
|
if p.Type == KeyType_RSA4096 {
|
|
bitSize = 4096
|
|
}
|
|
|
|
entry.RSAKey, err = rsa.GenerateKey(randReader, bitSize)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if p.ConvergentEncryption {
|
|
if p.ConvergentVersion == -1 || p.ConvergentVersion > 1 {
|
|
entry.ConvergentVersion = currentConvergentVersion
|
|
}
|
|
}
|
|
|
|
p.LatestVersion += 1
|
|
|
|
if p.Keys == nil {
|
|
// This is an initial key rotation when generating a new policy. We
|
|
// don't need to call migrate here because if we've called getPolicy to
|
|
// get the policy in the first place it will have been run.
|
|
p.Keys = keyEntryMap{}
|
|
}
|
|
p.Keys[strconv.Itoa(p.LatestVersion)] = entry
|
|
|
|
// This ensures that with new key creations min decryption version is set
|
|
// to 1 rather than the int default of 0, since keys start at 1 (either
|
|
// fresh or after migration to the key map)
|
|
if p.MinDecryptionVersion == 0 {
|
|
p.MinDecryptionVersion = 1
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Policy) MigrateKeyToKeysMap() {
|
|
now := time.Now()
|
|
p.Keys = keyEntryMap{
|
|
"1": KeyEntry{
|
|
Key: p.Key,
|
|
CreationTime: now,
|
|
DeprecatedCreationTime: now.Unix(),
|
|
},
|
|
}
|
|
p.Key = nil
|
|
}
|
|
|
|
// Backup should be called with an exclusive lock held on the policy
|
|
func (p *Policy) Backup(ctx context.Context, storage logical.Storage) (out string, retErr error) {
|
|
if !p.Exportable {
|
|
return "", fmt.Errorf("exporting is disallowed on the policy")
|
|
}
|
|
|
|
if !p.AllowPlaintextBackup {
|
|
return "", fmt.Errorf("plaintext backup is disallowed on the policy")
|
|
}
|
|
|
|
priorBackupInfo := p.BackupInfo
|
|
|
|
defer func() {
|
|
if retErr != nil {
|
|
p.BackupInfo = priorBackupInfo
|
|
}
|
|
}()
|
|
|
|
// Create a record of this backup operation in the policy
|
|
p.BackupInfo = &BackupInfo{
|
|
Time: time.Now(),
|
|
Version: p.LatestVersion,
|
|
}
|
|
err := p.Persist(ctx, storage)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("failed to persist policy with backup info: {{err}}", err)
|
|
}
|
|
|
|
// Load the archive only after persisting the policy as the archive can get
|
|
// adjusted while persisting the policy
|
|
archivedKeys, err := p.LoadArchive(ctx, storage)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
keyData := &KeyData{
|
|
Policy: p,
|
|
ArchivedKeys: archivedKeys,
|
|
}
|
|
|
|
encodedBackup, err := jsonutil.EncodeJSON(keyData)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return base64.StdEncoding.EncodeToString(encodedBackup), nil
|
|
}
|
|
|
|
func (p *Policy) getTemplateParts() ([]string, error) {
|
|
partsRaw, ok := p.versionPrefixCache.Load("template-parts")
|
|
if ok {
|
|
return partsRaw.([]string), nil
|
|
}
|
|
|
|
template := p.VersionTemplate
|
|
if template == "" {
|
|
template = DefaultVersionTemplate
|
|
}
|
|
|
|
tplParts := strings.Split(template, "{{version}}")
|
|
if len(tplParts) != 2 {
|
|
return nil, errutil.InternalError{Err: "error parsing version template"}
|
|
}
|
|
|
|
p.versionPrefixCache.Store("template-parts", tplParts)
|
|
return tplParts, nil
|
|
}
|
|
|
|
func (p *Policy) getVersionPrefix(ver int) string {
|
|
prefixRaw, ok := p.versionPrefixCache.Load(ver)
|
|
if ok {
|
|
return prefixRaw.(string)
|
|
}
|
|
|
|
template := p.VersionTemplate
|
|
if template == "" {
|
|
template = DefaultVersionTemplate
|
|
}
|
|
|
|
prefix := strings.ReplaceAll(template, "{{version}}", strconv.Itoa(ver))
|
|
p.versionPrefixCache.Store(ver, prefix)
|
|
|
|
return prefix
|
|
}
|
|
|
|
// SymmetricOpts are the arguments to symmetric operations that are "optional", e.g.
|
|
// not always used. This improves the aesthetics of calls to those functions.
|
|
type SymmetricOpts struct {
|
|
// Whether to use convergent encryption
|
|
Convergent bool
|
|
// The version of the convergent encryption scheme
|
|
ConvergentVersion int
|
|
// The nonce, if not randomly generated
|
|
Nonce []byte
|
|
// Additional data to include in AEAD authentication
|
|
AdditionalData []byte
|
|
// The HMAC key, for generating IVs in convergent encryption
|
|
HMACKey []byte
|
|
// Allows an external provider of the AEAD, for e.g. managed keys
|
|
AEADFactory AEADFactory
|
|
}
|
|
|
|
// Symmetrically encrypt a plaintext given the convergence configuration and appropriate keys
|
|
func (p *Policy) SymmetricEncryptRaw(ver int, encKey, plaintext []byte, opts SymmetricOpts) ([]byte, error) {
|
|
var aead cipher.AEAD
|
|
var err error
|
|
nonce := opts.Nonce
|
|
|
|
switch p.Type {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
|
|
// Setup the cipher
|
|
aesCipher, err := aes.NewCipher(encKey)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: err.Error()}
|
|
}
|
|
|
|
// Setup the GCM AEAD
|
|
gcm, err := cipher.NewGCM(aesCipher)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: err.Error()}
|
|
}
|
|
|
|
aead = gcm
|
|
|
|
case KeyType_ChaCha20_Poly1305:
|
|
cha, err := chacha20poly1305.New(encKey)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: err.Error()}
|
|
}
|
|
|
|
aead = cha
|
|
case KeyType_MANAGED_KEY:
|
|
if opts.Convergent || len(opts.Nonce) != 0 {
|
|
return nil, errutil.UserError{Err: "cannot use convergent encryption or provide a nonce to managed-key backed encryption"}
|
|
}
|
|
if opts.AEADFactory == nil {
|
|
return nil, errors.New("expected AEAD factory from managed key, none provided")
|
|
}
|
|
aead, err = opts.AEADFactory.GetAEAD(nonce)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if opts.Convergent {
|
|
convergentVersion := p.convergentVersion(ver)
|
|
switch convergentVersion {
|
|
case 1:
|
|
if len(opts.Nonce) != aead.NonceSize() {
|
|
return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())}
|
|
}
|
|
case 2, 3:
|
|
if len(opts.HMACKey) == 0 {
|
|
return nil, errutil.InternalError{Err: fmt.Sprintf("invalid hmac key length of zero")}
|
|
}
|
|
nonceHmac := hmac.New(sha256.New, opts.HMACKey)
|
|
nonceHmac.Write(plaintext)
|
|
nonceSum := nonceHmac.Sum(nil)
|
|
nonce = nonceSum[:aead.NonceSize()]
|
|
default:
|
|
return nil, errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)}
|
|
}
|
|
} else if len(nonce) == 0 {
|
|
// Compute random nonce
|
|
nonce, err = uuid.GenerateRandomBytes(aead.NonceSize())
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: err.Error()}
|
|
}
|
|
} else if len(nonce) != aead.NonceSize() {
|
|
return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long but given %d bytes", aead.NonceSize(), len(nonce))}
|
|
}
|
|
|
|
// Encrypt and tag with AEAD
|
|
ciphertext := aead.Seal(nil, nonce, plaintext, opts.AdditionalData)
|
|
|
|
// Place the encrypted data after the nonce
|
|
if !opts.Convergent || p.convergentVersion(ver) > 1 {
|
|
ciphertext = append(nonce, ciphertext...)
|
|
}
|
|
return ciphertext, nil
|
|
}
|
|
|
|
// Symmetrically decrypt a ciphertext given the convergence configuration and appropriate keys
|
|
func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOpts) ([]byte, error) {
|
|
var aead cipher.AEAD
|
|
var err error
|
|
var nonce []byte
|
|
|
|
switch p.Type {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
|
|
// Setup the cipher
|
|
aesCipher, err := aes.NewCipher(encKey)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: err.Error()}
|
|
}
|
|
|
|
// Setup the GCM AEAD
|
|
gcm, err := cipher.NewGCM(aesCipher)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: err.Error()}
|
|
}
|
|
|
|
aead = gcm
|
|
|
|
case KeyType_ChaCha20_Poly1305:
|
|
cha, err := chacha20poly1305.New(encKey)
|
|
if err != nil {
|
|
return nil, errutil.InternalError{Err: err.Error()}
|
|
}
|
|
|
|
aead = cha
|
|
case KeyType_MANAGED_KEY:
|
|
aead, err = opts.AEADFactory.GetAEAD(nonce)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if len(ciphertext) < aead.NonceSize() {
|
|
return nil, errutil.UserError{Err: "invalid ciphertext length"}
|
|
}
|
|
|
|
// Extract the nonce and ciphertext
|
|
var trueCT []byte
|
|
if opts.Convergent && opts.ConvergentVersion == 1 {
|
|
trueCT = ciphertext
|
|
} else {
|
|
nonce = ciphertext[:aead.NonceSize()]
|
|
trueCT = ciphertext[aead.NonceSize():]
|
|
}
|
|
|
|
// Verify and Decrypt
|
|
plain, err := aead.Open(nil, nonce, trueCT, opts.AdditionalData)
|
|
if err != nil {
|
|
return nil, errutil.UserError{Err: err.Error()}
|
|
}
|
|
return plain, nil
|
|
}
|
|
|
|
func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factories ...interface{}) (string, error) {
|
|
if !p.Type.EncryptionSupported() {
|
|
return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
|
|
}
|
|
|
|
// Decode the plaintext value
|
|
plaintext, err := base64.StdEncoding.DecodeString(value)
|
|
if err != nil {
|
|
return "", errutil.UserError{Err: err.Error()}
|
|
}
|
|
|
|
switch {
|
|
case ver == 0:
|
|
ver = p.LatestVersion
|
|
case ver < 0:
|
|
return "", errutil.UserError{Err: "requested version for encryption is negative"}
|
|
case ver > p.LatestVersion:
|
|
return "", errutil.UserError{Err: "requested version for encryption is higher than the latest key version"}
|
|
case ver < p.MinEncryptionVersion:
|
|
return "", errutil.UserError{Err: "requested version for encryption is less than the minimum encryption key version"}
|
|
}
|
|
|
|
var ciphertext []byte
|
|
|
|
switch p.Type {
|
|
case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
|
|
hmacKey := context
|
|
|
|
var encKey []byte
|
|
var deriveHMAC bool
|
|
|
|
encBytes := 32
|
|
hmacBytes := 0
|
|
if p.convergentVersion(ver) > 2 {
|
|
deriveHMAC = true
|
|
hmacBytes = 32
|
|
}
|
|
if p.Type == KeyType_AES128_GCM96 {
|
|
encBytes = 16
|
|
}
|
|
|
|
key, err := p.GetKey(context, ver, encBytes+hmacBytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if len(key) < encBytes+hmacBytes {
|
|
return "", errutil.InternalError{Err: "could not derive key, length too small"}
|
|
}
|
|
|
|
encKey = key[:encBytes]
|
|
if len(encKey) != encBytes {
|
|
return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
|
|
}
|
|
if deriveHMAC {
|
|
hmacKey = key[encBytes:]
|
|
if len(hmacKey) != hmacBytes {
|
|
return "", errutil.InternalError{Err: "could not derive hmac key, length not correct"}
|
|
}
|
|
}
|
|
|
|
symopts := SymmetricOpts{
|
|
Convergent: p.ConvergentEncryption,
|
|
HMACKey: hmacKey,
|
|
Nonce: nonce,
|
|
}
|
|
for index, rawFactory := range factories {
|
|
if rawFactory == nil {
|
|
continue
|
|
}
|
|
switch factory := rawFactory.(type) {
|
|
case AEADFactory:
|
|
symopts.AEADFactory = factory
|
|
case AssociatedDataFactory:
|
|
symopts.AdditionalData, err = factory.GetAssociatedData()
|
|
if err != nil {
|
|
return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
|
|
}
|
|
default:
|
|
return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
|
|
}
|
|
}
|
|
|
|
ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, symopts)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
|
keyEntry, err := p.safeGetKeyEntry(ver)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
key := keyEntry.RSAKey
|
|
ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, &key.PublicKey, plaintext, nil)
|
|
if err != nil {
|
|
return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA encrypt the plaintext: %v", err)}
|
|
}
|
|
|
|
default:
|
|
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
|
|
}
|
|
|
|
// Convert to base64
|
|
encoded := base64.StdEncoding.EncodeToString(ciphertext)
|
|
|
|
// Prepend some information
|
|
encoded = p.getVersionPrefix(ver) + encoded
|
|
|
|
return encoded, nil
|
|
}
|