Pulled out transit's lock manager and policy structs into a helper
This commit is contained in:
parent
c74303dd59
commit
6d1e1a3ba5
|
@ -1,6 +1,7 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault/helper/keysutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -39,12 +40,12 @@ func Backend(conf *logical.BackendConfig) *backend {
|
|||
Secrets: []*framework.Secret{},
|
||||
}
|
||||
|
||||
b.lm = newLockManager(conf.System.CachingDisabled())
|
||||
b.lm = keysutil.NewLockManager(conf.System.CachingDisabled())
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
lm *lockManager
|
||||
lm *keysutil.LockManager
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/keysutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
|
@ -289,7 +290,7 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool)
|
|||
if d.Name != name {
|
||||
return fmt.Errorf("bad name: %#v", d)
|
||||
}
|
||||
if d.Type != KeyType(keyType_AES256_GCM96).String() {
|
||||
if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
|
||||
return fmt.Errorf("bad key type: %#v", d)
|
||||
}
|
||||
// Should NOT get a key back
|
||||
|
@ -583,13 +584,13 @@ func testAccStepDecryptDatakey(t *testing.T, name string,
|
|||
|
||||
func TestKeyUpgrade(t *testing.T) {
|
||||
key, _ := uuid.GenerateRandomBytes(32)
|
||||
p := &policy{
|
||||
p := &keysutil.Policy{
|
||||
Name: "test",
|
||||
Key: key,
|
||||
Type: keyType_AES256_GCM96,
|
||||
Type: keysutil.KeyType_AES256_GCM96,
|
||||
}
|
||||
|
||||
p.migrateKeyToKeysMap()
|
||||
p.MigrateKeyToKeysMap()
|
||||
|
||||
if p.Key != nil ||
|
||||
p.Keys == nil ||
|
||||
|
@ -604,18 +605,18 @@ func TestDerivedKeyUpgrade(t *testing.T) {
|
|||
key, _ := uuid.GenerateRandomBytes(32)
|
||||
context, _ := uuid.GenerateRandomBytes(32)
|
||||
|
||||
p := &policy{
|
||||
p := &keysutil.Policy{
|
||||
Name: "test",
|
||||
Key: key,
|
||||
Type: keyType_AES256_GCM96,
|
||||
Type: keysutil.KeyType_AES256_GCM96,
|
||||
Derived: true,
|
||||
}
|
||||
|
||||
p.migrateKeyToKeysMap()
|
||||
p.upgrade(storage) // Need to run the upgrade code to make the migration stick
|
||||
p.MigrateKeyToKeysMap()
|
||||
p.Upgrade(storage) // Need to run the upgrade code to make the migration stick
|
||||
|
||||
if p.KDF != kdf_hmac_sha256_counter {
|
||||
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", kdf_hmac_sha256_counter, p.KDF, *p)
|
||||
if p.KDF != keysutil.Kdf_hmac_sha256_counter {
|
||||
t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", keysutil.Kdf_hmac_sha256_counter, p.KDF, *p)
|
||||
}
|
||||
|
||||
derBytesOld, err := p.DeriveKey(context, 1)
|
||||
|
@ -632,8 +633,8 @@ func TestDerivedKeyUpgrade(t *testing.T) {
|
|||
t.Fatal("mismatch of same context alg")
|
||||
}
|
||||
|
||||
p.KDF = kdf_hkdf_sha256
|
||||
if p.needsUpgrade() {
|
||||
p.KDF = keysutil.Kdf_hkdf_sha256
|
||||
if p.NeedsUpgrade() {
|
||||
t.Fatal("expected no upgrade needed")
|
||||
}
|
||||
|
||||
|
@ -692,15 +693,15 @@ func testConvergentEncryptionCommon(t *testing.T, ver int) {
|
|||
t.Fatalf("bad: expected error response, got %#v", *resp)
|
||||
}
|
||||
|
||||
p := &policy{
|
||||
p := &keysutil.Policy{
|
||||
Name: "testkey",
|
||||
Type: keyType_AES256_GCM96,
|
||||
Type: keysutil.KeyType_AES256_GCM96,
|
||||
Derived: true,
|
||||
ConvergentEncryption: true,
|
||||
ConvergentVersion: ver,
|
||||
}
|
||||
|
||||
err = p.rotate(storage)
|
||||
err = p.Rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -976,7 +977,7 @@ func testPolicyFuzzingCommon(t *testing.T, be *backend) {
|
|||
resp, err := be.pathDecryptWrite(req, fd)
|
||||
if err != nil {
|
||||
// This could well happen since the min version is jumping around
|
||||
if resp.Data["error"].(string) == ErrTooOld {
|
||||
if resp.Data["error"].(string) == keysutil.ErrTooOld {
|
||||
continue
|
||||
}
|
||||
t.Fatalf("got an error: %v, resp is %#v, ciphertext was %s, chosenKey is %s, id is %d", err, *resp, ct, chosenKey, id)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/errutil"
|
||||
"github.com/hashicorp/vault/helper/keysutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -116,7 +117,7 @@ func (b *backend) pathEncryptWrite(
|
|||
}
|
||||
|
||||
// Get the policy
|
||||
var p *policy
|
||||
var p *keysutil.Policy
|
||||
var lock *sync.RWMutex
|
||||
var upserted bool
|
||||
if req.Operation == logical.CreateOperation {
|
||||
|
@ -125,17 +126,17 @@ func (b *backend) pathEncryptWrite(
|
|||
return logical.ErrorResponse("convergent encryption requires derivation to be enabled, so context is required"), nil
|
||||
}
|
||||
|
||||
polReq := policyRequest{
|
||||
storage: req.Storage,
|
||||
name: name,
|
||||
derived: len(context) != 0,
|
||||
convergent: convergent,
|
||||
polReq := keysutil.PolicyRequest{
|
||||
Storage: req.Storage,
|
||||
Name: name,
|
||||
Derived: len(context) != 0,
|
||||
Convergent: convergent,
|
||||
}
|
||||
|
||||
keyType := d.Get("type").(string)
|
||||
switch keyType {
|
||||
case "aes256-gcm96":
|
||||
polReq.keyType = keyType_AES256_GCM96
|
||||
polReq.KeyType = keysutil.KeyType_AES256_GCM96
|
||||
case "ecdsa-p256":
|
||||
return logical.ErrorResponse(fmt.Sprintf("key type %v not supported for this operation", keyType)), logical.ErrInvalidRequest
|
||||
default:
|
||||
|
|
|
@ -124,7 +124,7 @@ func TestTransit_HMAC(t *testing.T) {
|
|||
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
|
||||
// Rotate
|
||||
err = p.rotate(storage)
|
||||
err = p.Rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/vault/helper/keysutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -95,17 +96,17 @@ func (b *backend) pathPolicyWrite(
|
|||
return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
|
||||
}
|
||||
|
||||
polReq := policyRequest{
|
||||
storage: req.Storage,
|
||||
name: name,
|
||||
derived: derived,
|
||||
convergent: convergent,
|
||||
polReq := keysutil.PolicyRequest{
|
||||
Storage: req.Storage,
|
||||
Name: name,
|
||||
Derived: derived,
|
||||
Convergent: convergent,
|
||||
}
|
||||
switch keyType {
|
||||
case "aes256-gcm96":
|
||||
polReq.keyType = keyType_AES256_GCM96
|
||||
polReq.KeyType = keysutil.KeyType_AES256_GCM96
|
||||
case "ecdsa-p256":
|
||||
polReq.keyType = keyType_ECDSA_P256
|
||||
polReq.KeyType = keysutil.KeyType_ECDSA_P256
|
||||
default:
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown key type %v", keyType)), logical.ErrInvalidRequest
|
||||
}
|
||||
|
@ -158,10 +159,10 @@ func (b *backend) pathPolicyRead(
|
|||
|
||||
if p.Derived {
|
||||
switch p.KDF {
|
||||
case kdf_hmac_sha256_counter:
|
||||
case keysutil.Kdf_hmac_sha256_counter:
|
||||
resp.Data["kdf"] = "hmac-sha256-counter"
|
||||
resp.Data["kdf_mode"] = "hmac-sha256-counter"
|
||||
case kdf_hkdf_sha256:
|
||||
case keysutil.Kdf_hkdf_sha256:
|
||||
resp.Data["kdf"] = "hkdf_sha256"
|
||||
}
|
||||
resp.Data["convergent_encryption"] = p.ConvergentEncryption
|
||||
|
@ -171,14 +172,14 @@ func (b *backend) pathPolicyRead(
|
|||
}
|
||||
|
||||
switch p.Type {
|
||||
case keyType_AES256_GCM96:
|
||||
case keysutil.KeyType_AES256_GCM96:
|
||||
retKeys := map[string]int64{}
|
||||
for k, v := range p.Keys {
|
||||
retKeys[strconv.Itoa(k)] = v.CreationTime
|
||||
}
|
||||
resp.Data["keys"] = retKeys
|
||||
|
||||
case keyType_ECDSA_P256:
|
||||
case keysutil.KeyType_ECDSA_P256:
|
||||
type ecdsaKey struct {
|
||||
Name string `json:"name"`
|
||||
PublicKey string `json:"public_key"`
|
||||
|
|
|
@ -41,7 +41,7 @@ func (b *backend) pathRotateWrite(
|
|||
}
|
||||
|
||||
// Rotate the policy
|
||||
err = p.rotate(req.Storage)
|
||||
err = p.Rotate(req.Storage)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -177,11 +177,11 @@ func TestTransit_SignVerify(t *testing.T) {
|
|||
signRequest(req, true, "")
|
||||
|
||||
// Rotate and set min decryption version
|
||||
err = p.rotate(storage)
|
||||
err = p.Rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = p.rotate(storage)
|
||||
err = p.Rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -4,28 +4,29 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/helper/keysutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
var (
|
||||
keysArchive []keyEntry
|
||||
keysArchive []keysutil.KeyEntry
|
||||
)
|
||||
|
||||
func resetKeysArchive() {
|
||||
keysArchive = []keyEntry{keyEntry{}}
|
||||
keysArchive = []keysutil.KeyEntry{keysutil.KeyEntry{}}
|
||||
}
|
||||
|
||||
func Test_KeyUpgrade(t *testing.T) {
|
||||
testKeyUpgradeCommon(t, newLockManager(false))
|
||||
testKeyUpgradeCommon(t, newLockManager(true))
|
||||
testKeyUpgradeCommon(t, keysutil.NewLockManager(false))
|
||||
testKeyUpgradeCommon(t, keysutil.NewLockManager(true))
|
||||
}
|
||||
|
||||
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
|
||||
func testKeyUpgradeCommon(t *testing.T, lm *keysutil.LockManager) {
|
||||
storage := &logical.InmemStorage{}
|
||||
p, lock, upserted, err := lm.GetPolicyUpsert(policyRequest{
|
||||
storage: storage,
|
||||
keyType: keyType_AES256_GCM96,
|
||||
name: "test",
|
||||
p, lock, upserted, err := lm.GetPolicyUpsert(keysutil.PolicyRequest{
|
||||
Storage: storage,
|
||||
KeyType: keysutil.KeyType_AES256_GCM96,
|
||||
Name: "test",
|
||||
})
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
|
@ -45,7 +46,7 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
|
||||
p.Key = p.Keys[1].AESKey
|
||||
p.Keys = nil
|
||||
p.migrateKeyToKeysMap()
|
||||
p.MigrateKeyToKeysMap()
|
||||
if p.Key != nil {
|
||||
t.Fatal("policy.Key is not nil")
|
||||
}
|
||||
|
@ -58,11 +59,11 @@ func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
}
|
||||
|
||||
func Test_ArchivingUpgrade(t *testing.T) {
|
||||
testArchivingUpgradeCommon(t, newLockManager(false))
|
||||
testArchivingUpgradeCommon(t, newLockManager(true))
|
||||
testArchivingUpgradeCommon(t, keysutil.NewLockManager(false))
|
||||
testArchivingUpgradeCommon(t, keysutil.NewLockManager(true))
|
||||
}
|
||||
|
||||
func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
||||
func testArchivingUpgradeCommon(t *testing.T, lm *keysutil.LockManager) {
|
||||
resetKeysArchive()
|
||||
|
||||
// First, we generate a policy and rotate it a number of times. Each time
|
||||
|
@ -71,10 +72,10 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
// zero and latest, respectively
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{
|
||||
storage: storage,
|
||||
keyType: keyType_AES256_GCM96,
|
||||
name: "test",
|
||||
p, lock, _, err := lm.GetPolicyUpsert(keysutil.PolicyRequest{
|
||||
Storage: storage,
|
||||
KeyType: keysutil.KeyType_AES256_GCM96,
|
||||
Name: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -89,7 +90,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
checkKeys(t, p, storage, "initial", 1, 1, 1)
|
||||
|
||||
for i := 2; i <= 10; i++ {
|
||||
err = p.rotate(storage)
|
||||
err = p.Rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -123,7 +124,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
// If we're caching, expire from the cache since we modified it
|
||||
// under-the-hood
|
||||
if lm.CacheActive() {
|
||||
delete(lm.cache, "test")
|
||||
lm.CacheDelete("test")
|
||||
}
|
||||
|
||||
// Now get the policy again; the upgrade should happen automatically
|
||||
|
@ -141,7 +142,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
// Let's check some deletion logic while we're at it
|
||||
|
||||
// The policy should be in there
|
||||
if lm.CacheActive() && lm.cache["test"] == nil {
|
||||
if lm.CacheActive() && lm.Cache("test") == nil {
|
||||
t.Fatal("nil policy in cache")
|
||||
}
|
||||
|
||||
|
@ -152,7 +153,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
}
|
||||
|
||||
// The policy should still be in there
|
||||
if lm.CacheActive() && lm.cache["test"] == nil {
|
||||
if lm.CacheActive() && lm.Cache("test") == nil {
|
||||
t.Fatal("nil policy in cache")
|
||||
}
|
||||
|
||||
|
@ -177,7 +178,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
}
|
||||
|
||||
// The policy should *not* be in there
|
||||
if lm.CacheActive() && lm.cache["test"] != nil {
|
||||
if lm.CacheActive() && lm.Cache("test") != nil {
|
||||
t.Fatal("non-nil policy in cache")
|
||||
}
|
||||
|
||||
|
@ -191,11 +192,11 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
|
|||
}
|
||||
|
||||
func Test_Archiving(t *testing.T) {
|
||||
testArchivingCommon(t, newLockManager(false))
|
||||
testArchivingCommon(t, newLockManager(true))
|
||||
testArchivingCommon(t, keysutil.NewLockManager(false))
|
||||
testArchivingCommon(t, keysutil.NewLockManager(true))
|
||||
}
|
||||
|
||||
func testArchivingCommon(t *testing.T, lm *lockManager) {
|
||||
func testArchivingCommon(t *testing.T, lm *keysutil.LockManager) {
|
||||
resetKeysArchive()
|
||||
|
||||
// First, we generate a policy and rotate it a number of times. Each time // we'll ensure that we have the expected number of keys in the archive and
|
||||
|
@ -203,10 +204,10 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
|
|||
// zero and latest, respectively
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
p, lock, _, err := lm.GetPolicyUpsert(policyRequest{
|
||||
storage: storage,
|
||||
keyType: keyType_AES256_GCM96,
|
||||
name: "test",
|
||||
p, lock, _, err := lm.GetPolicyUpsert(keysutil.PolicyRequest{
|
||||
Storage: storage,
|
||||
KeyType: keysutil.KeyType_AES256_GCM96,
|
||||
Name: "test",
|
||||
})
|
||||
if lock != nil {
|
||||
defer lock.RUnlock()
|
||||
|
@ -223,7 +224,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
|
|||
checkKeys(t, p, storage, "initial", 1, 1, 1)
|
||||
|
||||
for i := 2; i <= 10; i++ {
|
||||
err = p.rotate(storage)
|
||||
err = p.Rotate(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -271,7 +272,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
|
|||
}
|
||||
|
||||
func checkKeys(t *testing.T,
|
||||
p *policy,
|
||||
p *keysutil.Policy,
|
||||
storage logical.Storage,
|
||||
action string,
|
||||
archiveVer, latestVer, keysSize int) {
|
||||
|
@ -282,7 +283,7 @@ func checkKeys(t *testing.T,
|
|||
"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
|
||||
}
|
||||
|
||||
archive, err := p.loadArchive(storage)
|
||||
archive, err := p.LoadArchive(storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package transit
|
||||
package keysutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -18,29 +18,29 @@ var (
|
|||
errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
|
||||
)
|
||||
|
||||
// policyRequest holds values used when requesting a policy. Most values are
|
||||
// PolicyRequest holds values used when requesting a policy. Most values are
|
||||
// only used during an upsert.
|
||||
type policyRequest struct {
|
||||
type PolicyRequest struct {
|
||||
// The storage to use
|
||||
storage logical.Storage
|
||||
Storage logical.Storage
|
||||
|
||||
// The name of the policy
|
||||
name string
|
||||
Name string
|
||||
|
||||
// The key type
|
||||
keyType KeyType
|
||||
KeyType KeyType
|
||||
|
||||
// Whether it should be derived
|
||||
derived bool
|
||||
Derived bool
|
||||
|
||||
// Whether to enable convergent encryption
|
||||
convergent bool
|
||||
Convergent bool
|
||||
|
||||
// Whether to upsert
|
||||
upsert bool
|
||||
Upsert bool
|
||||
}
|
||||
|
||||
type lockManager struct {
|
||||
type LockManager struct {
|
||||
// A lock for each named key
|
||||
locks map[string]*sync.RWMutex
|
||||
|
||||
|
@ -48,27 +48,35 @@ type lockManager struct {
|
|||
locksMutex sync.RWMutex
|
||||
|
||||
// If caching is enabled, the map of name to in-memory policy cache
|
||||
cache map[string]*policy
|
||||
cache map[string]*Policy
|
||||
|
||||
// Used for global locking, and as the cache map mutex
|
||||
cacheMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func newLockManager(cacheDisabled bool) *lockManager {
|
||||
lm := &lockManager{
|
||||
func NewLockManager(cacheDisabled bool) *LockManager {
|
||||
lm := &LockManager{
|
||||
locks: map[string]*sync.RWMutex{},
|
||||
}
|
||||
if !cacheDisabled {
|
||||
lm.cache = map[string]*policy{}
|
||||
lm.cache = map[string]*Policy{}
|
||||
}
|
||||
return lm
|
||||
}
|
||||
|
||||
func (lm *lockManager) CacheActive() bool {
|
||||
func (lm *LockManager) CacheActive() bool {
|
||||
return lm.cache != nil
|
||||
}
|
||||
|
||||
func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
|
||||
func (lm *LockManager) CacheDelete(name string) {
|
||||
delete(lm.cache, name)
|
||||
}
|
||||
|
||||
func (lm *LockManager) Cache(name string) *Policy {
|
||||
return lm.cache[name]
|
||||
}
|
||||
|
||||
func (lm *LockManager) policyLock(name string, lockType bool) *sync.RWMutex {
|
||||
lm.locksMutex.RLock()
|
||||
lock := lm.locks[name]
|
||||
if lock != nil {
|
||||
|
@ -115,7 +123,7 @@ func (lm *lockManager) policyLock(name string, lockType bool) *sync.RWMutex {
|
|||
return lock
|
||||
}
|
||||
|
||||
func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
|
||||
func (lm *LockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
|
||||
if lockType == exclusive {
|
||||
lock.Unlock()
|
||||
} else {
|
||||
|
@ -126,10 +134,10 @@ func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
|
|||
// Get the policy with a read lock. If we get an error saying an exclusive lock
|
||||
// is needed (for instance, for an upgrade/migration), give up the read lock,
|
||||
// call again with an exclusive lock, then swap back out for a read lock.
|
||||
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) {
|
||||
p, lock, _, err := lm.getPolicyCommon(policyRequest{
|
||||
storage: storage,
|
||||
name: name,
|
||||
func (lm *LockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
|
||||
p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
|
||||
Storage: storage,
|
||||
Name: name,
|
||||
}, shared)
|
||||
if err == nil ||
|
||||
(err != nil && err != errNeedExclusiveLock) {
|
||||
|
@ -137,9 +145,9 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
|
|||
}
|
||||
|
||||
// Try again while asking for an exlusive lock
|
||||
p, lock, _, err = lm.getPolicyCommon(policyRequest{
|
||||
storage: storage,
|
||||
name: name,
|
||||
p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
|
||||
Storage: storage,
|
||||
Name: name,
|
||||
}, exclusive)
|
||||
if err != nil || p == nil || lock == nil {
|
||||
return p, lock, err
|
||||
|
@ -147,18 +155,18 @@ func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*p
|
|||
|
||||
lock.Unlock()
|
||||
|
||||
p, lock, _, err = lm.getPolicyCommon(policyRequest{
|
||||
storage: storage,
|
||||
name: name,
|
||||
p, lock, _, err = lm.getPolicyCommon(PolicyRequest{
|
||||
Storage: storage,
|
||||
Name: name,
|
||||
}, shared)
|
||||
return p, lock, err
|
||||
}
|
||||
|
||||
// Get the policy with an exclusive lock
|
||||
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*policy, *sync.RWMutex, error) {
|
||||
p, lock, _, err := lm.getPolicyCommon(policyRequest{
|
||||
storage: storage,
|
||||
name: name,
|
||||
func (lm *LockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
|
||||
p, lock, _, err := lm.getPolicyCommon(PolicyRequest{
|
||||
Storage: storage,
|
||||
Name: name,
|
||||
}, exclusive)
|
||||
return p, lock, err
|
||||
}
|
||||
|
@ -166,8 +174,8 @@ func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string)
|
|||
// Get the policy with a read lock; if it returns that an exclusive lock is
|
||||
// needed, retry. If successful, call one more time to get a read lock and
|
||||
// return the value.
|
||||
func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMutex, bool, error) {
|
||||
req.upsert = true
|
||||
func (lm *LockManager) GetPolicyUpsert(req PolicyRequest) (*Policy, *sync.RWMutex, bool, error) {
|
||||
req.Upsert = true
|
||||
|
||||
p, lock, _, err := lm.getPolicyCommon(req, shared)
|
||||
if err == nil ||
|
||||
|
@ -182,7 +190,7 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
|
|||
}
|
||||
lock.Unlock()
|
||||
|
||||
req.upsert = false
|
||||
req.Upsert = false
|
||||
// Now get a shared lock for the return, but preserve the value of upserted
|
||||
p, lock, _, err = lm.getPolicyCommon(req, shared)
|
||||
|
||||
|
@ -191,16 +199,16 @@ func (lm *lockManager) GetPolicyUpsert(req policyRequest) (*policy, *sync.RWMute
|
|||
|
||||
// When the function returns, a lock will be held on the policy if err == nil.
|
||||
// It is the caller's responsibility to unlock.
|
||||
func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*policy, *sync.RWMutex, bool, error) {
|
||||
lock := lm.policyLock(req.name, lockType)
|
||||
func (lm *LockManager) getPolicyCommon(req PolicyRequest, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
|
||||
lock := lm.policyLock(req.Name, lockType)
|
||||
|
||||
var p *policy
|
||||
var p *Policy
|
||||
var err error
|
||||
|
||||
// Check if it's in our cache. If so, return right away.
|
||||
if lm.CacheActive() {
|
||||
lm.cacheMutex.RLock()
|
||||
p = lm.cache[req.name]
|
||||
p = lm.cache[req.Name]
|
||||
if p != nil {
|
||||
lm.cacheMutex.RUnlock()
|
||||
return p, lock, false, nil
|
||||
|
@ -209,7 +217,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
|
|||
}
|
||||
|
||||
// Load it from storage
|
||||
p, err = lm.getStoredPolicy(req.storage, req.name)
|
||||
p, err = lm.getStoredPolicy(req.Storage, req.Name)
|
||||
if err != nil {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, err
|
||||
|
@ -218,7 +226,7 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
|
|||
if p == nil {
|
||||
// This is the only place we upsert a new policy, so if upsert is not
|
||||
// specified, or the lock type is wrong, unlock before returning
|
||||
if !req.upsert {
|
||||
if !req.Upsert {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
@ -228,33 +236,33 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
|
|||
return nil, nil, false, errNeedExclusiveLock
|
||||
}
|
||||
|
||||
switch req.keyType {
|
||||
case keyType_AES256_GCM96:
|
||||
if req.convergent && !req.derived {
|
||||
switch req.KeyType {
|
||||
case KeyType_AES256_GCM96:
|
||||
if req.Convergent && !req.Derived {
|
||||
return nil, nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
|
||||
}
|
||||
|
||||
case keyType_ECDSA_P256:
|
||||
if req.derived || req.convergent {
|
||||
return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", keyType_ECDSA_P256)
|
||||
case KeyType_ECDSA_P256:
|
||||
if req.Derived || req.Convergent {
|
||||
return nil, nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %s", KeyType_ECDSA_P256)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, nil, false, fmt.Errorf("unsupported key type %v", req.keyType)
|
||||
return nil, nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
|
||||
}
|
||||
|
||||
p = &policy{
|
||||
Name: req.name,
|
||||
Type: req.keyType,
|
||||
Derived: req.derived,
|
||||
p = &Policy{
|
||||
Name: req.Name,
|
||||
Type: req.KeyType,
|
||||
Derived: req.Derived,
|
||||
}
|
||||
if req.derived {
|
||||
p.KDF = kdf_hkdf_sha256
|
||||
p.ConvergentEncryption = req.convergent
|
||||
if req.Derived {
|
||||
p.KDF = Kdf_hkdf_sha256
|
||||
p.ConvergentEncryption = req.Convergent
|
||||
p.ConvergentVersion = 2
|
||||
}
|
||||
|
||||
err = p.rotate(req.storage)
|
||||
err = p.Rotate(req.Storage)
|
||||
if err != nil {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, err
|
||||
|
@ -267,12 +275,12 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
|
|||
defer lm.cacheMutex.Unlock()
|
||||
// Make sure a policy didn't appear. If so, it will only be set if
|
||||
// there was no error, so assume it's good and return that
|
||||
exp := lm.cache[req.name]
|
||||
exp := lm.cache[req.Name]
|
||||
if exp != nil {
|
||||
return exp, lock, false, nil
|
||||
}
|
||||
if err == nil {
|
||||
lm.cache[req.name] = p
|
||||
lm.cache[req.Name] = p
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -280,13 +288,13 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
|
|||
return p, lock, true, nil
|
||||
}
|
||||
|
||||
if p.needsUpgrade() {
|
||||
if p.NeedsUpgrade() {
|
||||
if lockType == shared {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, errNeedExclusiveLock
|
||||
}
|
||||
|
||||
err = p.upgrade(req.storage)
|
||||
err = p.Upgrade(req.Storage)
|
||||
if err != nil {
|
||||
lm.UnlockPolicy(lock, lockType)
|
||||
return nil, nil, false, err
|
||||
|
@ -300,25 +308,25 @@ func (lm *lockManager) getPolicyCommon(req policyRequest, lockType bool) (*polic
|
|||
defer lm.cacheMutex.Unlock()
|
||||
// Make sure a policy didn't appear. If so, it will only be set if
|
||||
// there was no error, so assume it's good and return that
|
||||
exp := lm.cache[req.name]
|
||||
exp := lm.cache[req.Name]
|
||||
if exp != nil {
|
||||
return exp, lock, false, nil
|
||||
}
|
||||
if err == nil {
|
||||
lm.cache[req.name] = p
|
||||
lm.cache[req.Name] = p
|
||||
}
|
||||
}
|
||||
|
||||
return p, lock, false, nil
|
||||
}
|
||||
|
||||
func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error {
|
||||
func (lm *LockManager) DeletePolicy(storage logical.Storage, name string) error {
|
||||
lm.cacheMutex.Lock()
|
||||
lock := lm.policyLock(name, exclusive)
|
||||
defer lock.Unlock()
|
||||
defer lm.cacheMutex.Unlock()
|
||||
|
||||
var p *policy
|
||||
var p *Policy
|
||||
var err error
|
||||
|
||||
if lm.CacheActive() {
|
||||
|
@ -355,7 +363,7 @@ func (lm *lockManager) DeletePolicy(storage logical.Storage, name string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*policy, error) {
|
||||
func (lm *LockManager) getStoredPolicy(storage logical.Storage, name string) (*Policy, error) {
|
||||
// Check if the policy already exists
|
||||
raw, err := storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
|
@ -366,7 +374,7 @@ func (lm *lockManager) getStoredPolicy(storage logical.Storage, name string) (*p
|
|||
}
|
||||
|
||||
// Decode the policy
|
||||
policy := &policy{
|
||||
policy := &Policy{
|
||||
Keys: keyEntryMap{},
|
||||
}
|
||||
err = jsonutil.DecodeJSON(raw.Value, policy)
|
|
@ -1,4 +1,4 @@
|
|||
package transit
|
||||
package keysutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -33,14 +33,14 @@ import (
|
|||
// Careful with iota; don't put anything before it in this const block because
|
||||
// we need the default of zero to be the old-style KDF
|
||||
const (
|
||||
kdf_hmac_sha256_counter = iota // built-in helper
|
||||
kdf_hkdf_sha256 // golang.org/x/crypto/hkdf
|
||||
Kdf_hmac_sha256_counter = iota // built-in helper
|
||||
Kdf_hkdf_sha256 // golang.org/x/crypto/hkdf
|
||||
)
|
||||
|
||||
// Or this one...we need the default of zero to be the original AES256-GCM96
|
||||
const (
|
||||
keyType_AES256_GCM96 = iota
|
||||
keyType_ECDSA_P256
|
||||
KeyType_AES256_GCM96 = iota
|
||||
KeyType_ECDSA_P256
|
||||
)
|
||||
|
||||
const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
|
||||
|
@ -53,7 +53,7 @@ type KeyType int
|
|||
|
||||
func (kt KeyType) EncryptionSupported() bool {
|
||||
switch kt {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -61,7 +61,7 @@ func (kt KeyType) EncryptionSupported() bool {
|
|||
|
||||
func (kt KeyType) DecryptionSupported() bool {
|
||||
switch kt {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -69,7 +69,7 @@ func (kt KeyType) DecryptionSupported() bool {
|
|||
|
||||
func (kt KeyType) SigningSupported() bool {
|
||||
switch kt {
|
||||
case keyType_ECDSA_P256:
|
||||
case KeyType_ECDSA_P256:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -77,7 +77,7 @@ func (kt KeyType) SigningSupported() bool {
|
|||
|
||||
func (kt KeyType) DerivationSupported() bool {
|
||||
switch kt {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
@ -85,17 +85,17 @@ func (kt KeyType) DerivationSupported() bool {
|
|||
|
||||
func (kt KeyType) String() string {
|
||||
switch kt {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
return "aes256-gcm96"
|
||||
case keyType_ECDSA_P256:
|
||||
case KeyType_ECDSA_P256:
|
||||
return "ecdsa-p256"
|
||||
}
|
||||
|
||||
return "[unknown]"
|
||||
}
|
||||
|
||||
// keyEntry stores the key and metadata
|
||||
type keyEntry struct {
|
||||
// KeyEntry stores the key and metadata
|
||||
type KeyEntry struct {
|
||||
AESKey []byte `json:"key"`
|
||||
HMACKey []byte `json:"hmac_key"`
|
||||
CreationTime int64 `json:"creation_time"`
|
||||
|
@ -106,11 +106,11 @@ type keyEntry struct {
|
|||
}
|
||||
|
||||
// keyEntryMap is used to allow JSON marshal/unmarshal
|
||||
type keyEntryMap map[int]keyEntry
|
||||
type keyEntryMap map[int]KeyEntry
|
||||
|
||||
// MarshalJSON implements JSON marshaling
|
||||
func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
|
||||
intermediate := map[string]keyEntry{}
|
||||
intermediate := map[string]KeyEntry{}
|
||||
for k, v := range kem {
|
||||
intermediate[strconv.Itoa(k)] = v
|
||||
}
|
||||
|
@ -119,7 +119,7 @@ func (kem keyEntryMap) MarshalJSON() ([]byte, error) {
|
|||
|
||||
// MarshalJSON implements JSON unmarshaling
|
||||
func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
|
||||
intermediate := map[string]keyEntry{}
|
||||
intermediate := map[string]KeyEntry{}
|
||||
if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -135,7 +135,7 @@ func (kem keyEntryMap) UnmarshalJSON(data []byte) error {
|
|||
}
|
||||
|
||||
// Policy is the struct used to store metadata
|
||||
type policy struct {
|
||||
type Policy struct {
|
||||
Name string `json:"name"`
|
||||
Key []byte `json:"key,omitempty"` //DEPRECATED
|
||||
Keys keyEntryMap `json:"keys"`
|
||||
|
@ -171,10 +171,10 @@ type policy struct {
|
|||
// ArchivedKeys stores old keys. This is used to keep the key loading time sane
|
||||
// when there are huge numbers of rotations.
|
||||
type archivedKeys struct {
|
||||
Keys []keyEntry `json:"keys"`
|
||||
Keys []KeyEntry `json:"keys"`
|
||||
}
|
||||
|
||||
func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
|
||||
func (p *Policy) LoadArchive(storage logical.Storage) (*archivedKeys, error) {
|
||||
archive := &archivedKeys{}
|
||||
|
||||
raw, err := storage.Get("archive/" + p.Name)
|
||||
|
@ -182,7 +182,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
|
|||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
archive.Keys = make([]keyEntry, 0)
|
||||
archive.Keys = make([]KeyEntry, 0)
|
||||
return archive, nil
|
||||
}
|
||||
|
||||
|
@ -193,7 +193,7 @@ func (p *policy) loadArchive(storage logical.Storage) (*archivedKeys, error) {
|
|||
return archive, nil
|
||||
}
|
||||
|
||||
func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
|
||||
func (p *Policy) storeArchive(archive *archivedKeys, storage logical.Storage) error {
|
||||
// Encode the policy
|
||||
buf, err := json.Marshal(archive)
|
||||
if err != nil {
|
||||
|
@ -215,7 +215,7 @@ func (p *policy) storeArchive(archive *archivedKeys, storage logical.Storage) er
|
|||
// handleArchiving manages the movement of keys to and from the policy archive.
|
||||
// This should *ONLY* be called from Persist() since it assumes that the policy
|
||||
// will be persisted afterwards.
|
||||
func (p *policy) handleArchiving(storage logical.Storage) error {
|
||||
func (p *Policy) handleArchiving(storage logical.Storage) error {
|
||||
// We need to move keys that are no longer accessible to archivedKeys, and keys
|
||||
// that now need to be accessible back here.
|
||||
//
|
||||
|
@ -241,7 +241,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
|
|||
p.MinDecryptionVersion, p.LatestVersion)
|
||||
}
|
||||
|
||||
archive, err := p.loadArchive(storage)
|
||||
archive, err := p.LoadArchive(storage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -263,7 +263,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
|
|||
// key version
|
||||
if len(archive.Keys) < p.LatestVersion+1 {
|
||||
// Increase the size of the archive slice
|
||||
newKeys := make([]keyEntry, p.LatestVersion+1)
|
||||
newKeys := make([]KeyEntry, p.LatestVersion+1)
|
||||
copy(newKeys, archive.Keys)
|
||||
archive.Keys = newKeys
|
||||
}
|
||||
|
@ -289,7 +289,7 @@ func (p *policy) handleArchiving(storage logical.Storage) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *policy) Persist(storage logical.Storage) error {
|
||||
func (p *Policy) Persist(storage logical.Storage) error {
|
||||
err := p.handleArchiving(storage)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -313,11 +313,11 @@ func (p *policy) Persist(storage logical.Storage) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *policy) Serialize() ([]byte, error) {
|
||||
func (p *Policy) Serialize() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
func (p *policy) needsUpgrade() bool {
|
||||
func (p *Policy) NeedsUpgrade() bool {
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if p.Key != nil && len(p.Key) > 0 {
|
||||
return true
|
||||
|
@ -352,11 +352,11 @@ func (p *policy) needsUpgrade() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (p *policy) upgrade(storage logical.Storage) error {
|
||||
func (p *Policy) Upgrade(storage logical.Storage) error {
|
||||
persistNeeded := false
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if p.Key != nil && len(p.Key) > 0 {
|
||||
p.migrateKeyToKeysMap()
|
||||
p.MigrateKeyToKeysMap()
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
|
@ -409,7 +409,7 @@ func (p *policy) upgrade(storage logical.Storage) error {
|
|||
// on the policy. If derivation is disabled the raw key is used and no context
|
||||
// is required, otherwise the KDF mode is used with the context to derive the
|
||||
// proper key.
|
||||
func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
|
||||
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
|
||||
if !p.Type.DerivationSupported() {
|
||||
return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
|
||||
}
|
||||
|
@ -433,11 +433,11 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
|
|||
}
|
||||
|
||||
switch p.KDF {
|
||||
case kdf_hmac_sha256_counter:
|
||||
case Kdf_hmac_sha256_counter:
|
||||
prf := kdf.HMACSHA256PRF
|
||||
prfLen := kdf.HMACSHA256PRFLen
|
||||
return kdf.CounterMode(prf, prfLen, p.Keys[ver].AESKey, context, 256)
|
||||
case kdf_hkdf_sha256:
|
||||
case Kdf_hkdf_sha256:
|
||||
reader := hkdf.New(sha256.New, p.Keys[ver].AESKey, nil, context)
|
||||
derBytes := bytes.NewBuffer(nil)
|
||||
derBytes.Grow(32)
|
||||
|
@ -458,14 +458,14 @@ func (p *policy) DeriveKey(context []byte, ver int) ([]byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
|
||||
func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
|
||||
if !p.Type.EncryptionSupported() {
|
||||
return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
|
||||
}
|
||||
|
||||
// Guard against a potentially invalid key type
|
||||
switch p.Type {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
default:
|
||||
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
|
||||
}
|
||||
|
@ -484,7 +484,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
|
|||
|
||||
// Guard against a potentially invalid key type
|
||||
switch p.Type {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
default:
|
||||
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
|
||||
}
|
||||
|
@ -539,7 +539,7 @@ func (p *policy) Encrypt(context, nonce []byte, value string) (string, error) {
|
|||
return encoded, nil
|
||||
}
|
||||
|
||||
func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
|
||||
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
|
||||
if !p.Type.DecryptionSupported() {
|
||||
return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
|
||||
}
|
||||
|
@ -585,7 +585,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
|
|||
|
||||
// Guard against a potentially invalid key type
|
||||
switch p.Type {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
default:
|
||||
return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
|
||||
}
|
||||
|
@ -626,7 +626,7 @@ func (p *policy) Decrypt(context, nonce []byte, value string) (string, error) {
|
|||
return base64.StdEncoding.EncodeToString(plain), nil
|
||||
}
|
||||
|
||||
func (p *policy) HMACKey(version int) ([]byte, error) {
|
||||
func (p *Policy) HMACKey(version int) ([]byte, error) {
|
||||
if version < p.MinDecryptionVersion {
|
||||
return nil, fmt.Errorf("key version disallowed by policy (minimum is %d)", p.MinDecryptionVersion)
|
||||
}
|
||||
|
@ -642,14 +642,14 @@ func (p *policy) HMACKey(version int) ([]byte, error) {
|
|||
return p.Keys[version].HMACKey, nil
|
||||
}
|
||||
|
||||
func (p *policy) Sign(hashedInput []byte) (string, error) {
|
||||
func (p *Policy) Sign(hashedInput []byte) (string, error) {
|
||||
if !p.Type.SigningSupported() {
|
||||
return "", fmt.Errorf("message signing not supported for key type %v", p.Type)
|
||||
}
|
||||
|
||||
var sig []byte
|
||||
switch p.Type {
|
||||
case keyType_ECDSA_P256:
|
||||
case KeyType_ECDSA_P256:
|
||||
keyParams := p.Keys[p.LatestVersion]
|
||||
key := &ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
|
@ -685,7 +685,7 @@ func (p *policy) Sign(hashedInput []byte) (string, error) {
|
|||
return encoded, nil
|
||||
}
|
||||
|
||||
func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
|
||||
func (p *Policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
|
||||
if !p.Type.SigningSupported() {
|
||||
return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
|
||||
}
|
||||
|
@ -714,7 +714,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
|
|||
}
|
||||
|
||||
switch p.Type {
|
||||
case keyType_ECDSA_P256:
|
||||
case KeyType_ECDSA_P256:
|
||||
asn1Sig, err := base64.StdEncoding.DecodeString(splitVerSig[1])
|
||||
if err != nil {
|
||||
return false, errutil.UserError{Err: "invalid base64 signature value"}
|
||||
|
@ -744,7 +744,7 @@ func (p *policy) VerifySignature(hashedInput []byte, sig string) (bool, error) {
|
|||
return false, errutil.InternalError{Err: "no valid key type found"}
|
||||
}
|
||||
|
||||
func (p *policy) rotate(storage logical.Storage) error {
|
||||
func (p *Policy) Rotate(storage logical.Storage) error {
|
||||
if p.Keys == nil {
|
||||
// This is an initial key rotation when generating a new policy. We
|
||||
// don't need to call migrate here because if we've called getPolicy to
|
||||
|
@ -753,7 +753,7 @@ func (p *policy) rotate(storage logical.Storage) error {
|
|||
}
|
||||
|
||||
p.LatestVersion += 1
|
||||
entry := keyEntry{
|
||||
entry := KeyEntry{
|
||||
CreationTime: time.Now().Unix(),
|
||||
}
|
||||
|
||||
|
@ -764,7 +764,7 @@ func (p *policy) rotate(storage logical.Storage) error {
|
|||
entry.HMACKey = hmacKey
|
||||
|
||||
switch p.Type {
|
||||
case keyType_AES256_GCM96:
|
||||
case KeyType_AES256_GCM96:
|
||||
// Generate a 256bit key
|
||||
newKey, err := uuid.GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
|
@ -772,7 +772,7 @@ func (p *policy) rotate(storage logical.Storage) error {
|
|||
}
|
||||
entry.AESKey = newKey
|
||||
|
||||
case keyType_ECDSA_P256:
|
||||
case KeyType_ECDSA_P256:
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -807,9 +807,9 @@ func (p *policy) rotate(storage logical.Storage) error {
|
|||
return p.Persist(storage)
|
||||
}
|
||||
|
||||
func (p *policy) migrateKeyToKeysMap() {
|
||||
func (p *Policy) MigrateKeyToKeysMap() {
|
||||
p.Keys = keyEntryMap{
|
||||
1: keyEntry{
|
||||
1: KeyEntry{
|
||||
AESKey: p.Key,
|
||||
CreationTime: time.Now().Unix(),
|
||||
},
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/hashicorp/vault/api"
|
||||
credCert "github.com/hashicorp/vault/builtin/credential/cert"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
"github.com/hashicorp/vault/helper/keysutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
@ -381,7 +382,7 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, rpc, parallel bool, num uin
|
|||
secret, err := doResp(resp)
|
||||
if err != nil {
|
||||
// This could well happen since the min version is jumping around
|
||||
if strings.Contains(err.Error(), transit.ErrTooOld) {
|
||||
if strings.Contains(err.Error(), keysutil.ErrTooOld) {
|
||||
mySuccessfulOps++
|
||||
continue
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue