Merge pull request #626 from hashicorp/f-transit-enhancements
Enhancements to the transit backend
This commit is contained in:
commit
61e331200c
|
@ -15,15 +15,19 @@ func Backend() *framework.Backend {
|
|||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"keys/*",
|
||||
"raw/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
// Rotate/Config needs to come before Keys
|
||||
// as the handler is greedy
|
||||
pathConfig(),
|
||||
pathRotate(),
|
||||
pathRewrap(),
|
||||
pathKeys(),
|
||||
pathRaw(),
|
||||
pathEncrypt(),
|
||||
pathDecrypt(),
|
||||
pathDatakey(),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{},
|
||||
|
|
|
@ -3,6 +3,8 @@ package transit
|
|||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
|
@ -21,12 +23,90 @@ func TestBackend_basic(t *testing.T) {
|
|||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
testAccStepReadRaw(t, "test", false, false),
|
||||
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepDeleteNotDisabledPolicy(t, "test"),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDisableDeletion(t, "test"),
|
||||
testAccStepDeleteNotDisabledPolicy(t, "test"),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_datakey(t *testing.T) {
|
||||
dataKeyInfo := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
testAccStepWriteDatakey(t, "test", false, 256, dataKeyInfo),
|
||||
testAccStepDecryptDatakey(t, "test", dataKeyInfo),
|
||||
testAccStepWriteDatakey(t, "test", true, 128, dataKeyInfo),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_rotation(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
encryptHistory := make(map[int]map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory),
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 1, encryptHistory),
|
||||
testAccStepRotate(t, "test"), // now v2
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 2, encryptHistory),
|
||||
testAccStepRotate(t, "test"), // now v3
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 3, encryptHistory),
|
||||
testAccStepRotate(t, "test"), // now v4
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 4, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 99, encryptHistory),
|
||||
testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 2, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 3, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 99, encryptHistory),
|
||||
testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 4, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepDeleteNotDisabledPolicy(t, "test"),
|
||||
testAccStepAdjustPolicy(t, "test", 3),
|
||||
testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory),
|
||||
testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory),
|
||||
testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 2, encryptHistory),
|
||||
testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 3, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 4, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepAdjustPolicy(t, "test", 1),
|
||||
testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepLoadVX(t, "test", decryptData, 2, encryptHistory),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepRewrap(t, "test", decryptData, 4),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
testAccStepReadRaw(t, "test", true, false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -40,6 +120,7 @@ func TestBackend_upsert(t *testing.T) {
|
|||
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
},
|
||||
|
@ -53,12 +134,11 @@ func TestBackend_basic_derived(t *testing.T) {
|
|||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", true),
|
||||
testAccStepReadPolicy(t, "test", false, true),
|
||||
testAccStepReadRaw(t, "test", false, true),
|
||||
testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, true),
|
||||
testAccStepReadRaw(t, "test", true, true),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
@ -73,6 +153,36 @@ func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest
|
|||
}
|
||||
}
|
||||
|
||||
func testAccStepAdjustPolicy(t *testing.T, name string, minVer int) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "keys/" + name + "/config",
|
||||
Data: map[string]interface{}{
|
||||
"min_decryption_version": minVer,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDisableDeletion(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "keys/" + name + "/config",
|
||||
Data: map[string]interface{}{
|
||||
"deletion_allowed": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepEnableDeletion(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "keys/" + name + "/config",
|
||||
Data: map[string]interface{}{
|
||||
"deletion_allowed": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
|
@ -80,6 +190,23 @@ func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
|
|||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteNotDisabledPolicy(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "keys/" + name,
|
||||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
return fmt.Errorf("Got nil response instead of error")
|
||||
}
|
||||
if resp.IsError() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected error but did not get one")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
|
@ -94,11 +221,13 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool)
|
|||
return nil
|
||||
}
|
||||
var d struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Key []byte `mapstructure:"key"`
|
||||
CipherMode string `mapstructure:"cipher_mode"`
|
||||
Derived bool `mapstructure:"derived"`
|
||||
KDFMode string `mapstructure:"kdf_mode"`
|
||||
Name string `mapstructure:"name"`
|
||||
Key []byte `mapstructure:"key"`
|
||||
Keys map[string]int64 `mapstructure:"keys"`
|
||||
CipherMode string `mapstructure:"cipher_mode"`
|
||||
Derived bool `mapstructure:"derived"`
|
||||
KDFMode string `mapstructure:"kdf_mode"`
|
||||
DeletionAllowed bool `mapstructure:"deletion_allowed"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
|
@ -114,48 +243,10 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool)
|
|||
if d.Key != nil {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
if d.Derived != derived {
|
||||
if d.Keys == nil {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
if derived && d.KDFMode != kdfMode {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRaw(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "raw/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil && !expectNone {
|
||||
return fmt.Errorf("missing response")
|
||||
} else if expectNone {
|
||||
if resp != nil {
|
||||
return fmt.Errorf("response when expecting none")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var d struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Key []byte `mapstructure:"key"`
|
||||
CipherMode string `mapstructure:"cipher_mode"`
|
||||
Derived bool `mapstructure:"derived"`
|
||||
KDFMode string `mapstructure:"kdf_mode"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.Name != name {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
if d.CipherMode != "aes-gcm" {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
if len(d.Key) != 32 {
|
||||
if d.DeletionAllowed == true {
|
||||
return fmt.Errorf("bad: %#v", d)
|
||||
}
|
||||
if d.Derived != derived {
|
||||
|
@ -240,9 +331,192 @@ func testAccStepDecrypt(
|
|||
}
|
||||
|
||||
if string(plainRaw) != plaintext {
|
||||
return fmt.Errorf("plaintext mismatch: %s expect: %s", plainRaw, plaintext)
|
||||
return fmt.Errorf("plaintext mismatch: %s expect: %s, decryptData was %#v", plainRaw, plaintext, decryptData)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRewrap(
|
||||
t *testing.T, name string, decryptData map[string]interface{}, expectedVer int) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "rewrap/" + name,
|
||||
Data: decryptData,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Ciphertext string `mapstructure:"ciphertext"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
if d.Ciphertext == "" {
|
||||
return fmt.Errorf("missing ciphertext")
|
||||
}
|
||||
splitStrings := strings.Split(d.Ciphertext, ":")
|
||||
verString := splitStrings[1][1:]
|
||||
ver, err := strconv.Atoi(verString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error pulling out version from verString '%s', ciphertext was %s", verString, d.Ciphertext)
|
||||
}
|
||||
if ver != expectedVer {
|
||||
return fmt.Errorf("Did not get expected version")
|
||||
}
|
||||
decryptData["ciphertext"] = d.Ciphertext
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepEncryptVX(
|
||||
t *testing.T, name, plaintext string, decryptData map[string]interface{},
|
||||
ver int, encryptHistory map[int]map[string]interface{}) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "encrypt/" + name,
|
||||
Data: map[string]interface{}{
|
||||
"plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)),
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Ciphertext string `mapstructure:"ciphertext"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
if d.Ciphertext == "" {
|
||||
return fmt.Errorf("missing ciphertext")
|
||||
}
|
||||
splitStrings := strings.Split(d.Ciphertext, ":")
|
||||
splitStrings[1] = "v" + strconv.Itoa(ver)
|
||||
ciphertext := strings.Join(splitStrings, ":")
|
||||
decryptData["ciphertext"] = ciphertext
|
||||
encryptHistory[ver] = map[string]interface{}{
|
||||
"ciphertext": ciphertext,
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepLoadVX(
|
||||
t *testing.T, name string, decryptData map[string]interface{},
|
||||
ver int, encryptHistory map[int]map[string]interface{}) logicaltest.TestStep {
|
||||
// This is really a no-op to allow us to do data manip in the check function
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "keys/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
decryptData["ciphertext"] = encryptHistory[ver]["ciphertext"].(string)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDecryptExpectFailure(
|
||||
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "decrypt/" + name,
|
||||
Data: decryptData,
|
||||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if !resp.IsError() {
|
||||
return fmt.Errorf("expected error")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepRotate(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "keys/" + name + "/rotate",
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepWriteDatakey(t *testing.T, name string,
|
||||
noPlaintext bool, bits int,
|
||||
dataKeyInfo map[string]interface{}) logicaltest.TestStep {
|
||||
data := map[string]interface{}{}
|
||||
subPath := "plaintext"
|
||||
if noPlaintext {
|
||||
subPath = "wrapped"
|
||||
}
|
||||
if bits != 256 {
|
||||
data["bits"] = bits
|
||||
}
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "datakey/" + subPath + "/" + name,
|
||||
Data: data,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Plaintext string `mapstructure:"plaintext"`
|
||||
Ciphertext string `mapstructure:"ciphertext"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
if noPlaintext && len(d.Plaintext) != 0 {
|
||||
return fmt.Errorf("received plaintxt when we disabled it")
|
||||
}
|
||||
if !noPlaintext {
|
||||
if len(d.Plaintext) == 0 {
|
||||
return fmt.Errorf("did not get plaintext when we expected it")
|
||||
}
|
||||
dataKeyInfo["plaintext"] = d.Plaintext
|
||||
plainBytes, err := base64.StdEncoding.DecodeString(d.Plaintext)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not base64 decode plaintext string '%s'", d.Plaintext)
|
||||
}
|
||||
if len(plainBytes)*8 != bits {
|
||||
return fmt.Errorf("returned key does not have correct bit length")
|
||||
}
|
||||
}
|
||||
dataKeyInfo["ciphertext"] = d.Ciphertext
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDecryptDatakey(t *testing.T, name string,
|
||||
dataKeyInfo map[string]interface{}) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.WriteOperation,
|
||||
Path: "decrypt/" + name,
|
||||
Data: dataKeyInfo,
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Plaintext string `mapstructure:"plaintext"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.Plaintext != dataKeyInfo["plaintext"].(string) {
|
||||
return fmt.Errorf("plaintext mismatch: got '%s', expected '%s', decryptData was %#v", d.Plaintext, dataKeyInfo["plaintext"].(string))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyUpgrade(t *testing.T) {
|
||||
p := &Policy{
|
||||
Name: "test",
|
||||
Key: []byte(testPlaintext),
|
||||
CipherMode: "aes-gcm",
|
||||
}
|
||||
|
||||
p.migrateKeyToKeysMap()
|
||||
|
||||
if p.Key != nil ||
|
||||
p.Keys == nil ||
|
||||
len(p.Keys) != 1 ||
|
||||
string(p.Keys[1].Key) != testPlaintext {
|
||||
t.Errorf("bad key migration, result is %#v", p.Keys)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathConfig() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/" + framework.GenericNameRegex("name") + "/config",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the key",
|
||||
},
|
||||
|
||||
"min_decryption_version": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
Description: `If set, the minimum version of the key allowed
|
||||
to be decrypted.`,
|
||||
},
|
||||
|
||||
"deletion_allowed": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Description: "Whether to allow deletion of the key",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.WriteOperation: pathConfigWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigHelpSyn,
|
||||
HelpDescription: pathConfigHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathConfigWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
// Check if the policy already exists
|
||||
policy, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if policy == nil {
|
||||
return logical.ErrorResponse(
|
||||
fmt.Sprintf("no existing role named %s could be found", name)),
|
||||
logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
persistNeeded := false
|
||||
|
||||
minDecryptionVersion := d.Get("min_decryption_version").(int)
|
||||
if minDecryptionVersion != 0 &&
|
||||
minDecryptionVersion != policy.MinDecryptionVersion {
|
||||
policy.MinDecryptionVersion = minDecryptionVersion
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
allowDeletionInt, ok := d.GetOk("deletion_allowed")
|
||||
if ok {
|
||||
allowDeletion := allowDeletionInt.(bool)
|
||||
if allowDeletion != policy.DeletionAllowed {
|
||||
policy.DeletionAllowed = allowDeletion
|
||||
persistNeeded = true
|
||||
}
|
||||
}
|
||||
|
||||
if !persistNeeded {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, policy.Persist(req.Storage, name)
|
||||
}
|
||||
|
||||
const pathConfigHelpSyn = `Configure a named encryption key`
|
||||
|
||||
const pathConfigHelpDesc = `
|
||||
This path is used to configure the named key. Currently, this
|
||||
supports adjusting the minimum version of the key allowed to
|
||||
be used for decryption via the min_decryption_version paramter.
|
||||
`
|
|
@ -0,0 +1,142 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathDatakey() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "datakey/" + framework.GenericNameRegex("plaintext") + "/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "The backend key used for encrypting the data key",
|
||||
},
|
||||
|
||||
"plaintext": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `"plaintext" will return the key in both plaintext and
|
||||
ciphertext; "wrapped" will return the ciphertext only.`,
|
||||
},
|
||||
|
||||
"context": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Context for key derivation. Required for derived keys.",
|
||||
},
|
||||
|
||||
"bits": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
Description: `Number of bits for the key; currently 128 and
|
||||
256 are supported. Defaults to 256.`,
|
||||
Default: 256,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.WriteOperation: pathDatakeyWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathDatakeyHelpSyn,
|
||||
HelpDescription: pathDatakeyHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathDatakeyWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
plaintext := d.Get("plaintext").(string)
|
||||
plaintextAllowed := false
|
||||
switch plaintext {
|
||||
case "plaintext":
|
||||
plaintextAllowed = true
|
||||
case "wrapped":
|
||||
default:
|
||||
return logical.ErrorResponse("Invalid path, must be 'plaintext' or 'wrapped'"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Decode the context if any
|
||||
contextRaw := d.Get("context").(string)
|
||||
var context []byte
|
||||
if len(contextRaw) != 0 {
|
||||
var err error
|
||||
context, err = base64.StdEncoding.DecodeString(contextRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
newKey := make([]byte, 32)
|
||||
bits := d.Get("bits").(int)
|
||||
switch bits {
|
||||
case 512:
|
||||
newKey = make([]byte, 64)
|
||||
case 256:
|
||||
case 128:
|
||||
newKey = make([]byte, 16)
|
||||
default:
|
||||
return logical.ErrorResponse("invalid bit length"), logical.ErrInvalidRequest
|
||||
}
|
||||
_, err = rand.Read(newKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
case certutil.InternalError:
|
||||
return nil, err
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if ciphertext == "" {
|
||||
return nil, fmt.Errorf("empty ciphertext returned")
|
||||
}
|
||||
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"ciphertext": ciphertext,
|
||||
},
|
||||
}
|
||||
|
||||
if plaintextAllowed {
|
||||
resp.Data["plaintext"] = base64.StdEncoding.EncodeToString(newKey)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathDatakeyHelpSyn = `Generate a data key`
|
||||
|
||||
const pathDatakeyHelpDesc = `
|
||||
This path can be used to generate a data key: a random
|
||||
key of a certain length that can be used for encryption
|
||||
and decryption, protected by the named backend key. 128, 256,
|
||||
or 512 bits can be specified; if not specified, the default
|
||||
is 256 bits. Call with the the "wrapped" path to prevent the
|
||||
(base64-encoded) plaintext key from being returned along with
|
||||
the encrypted key, the "plaintext" path returns both.
|
||||
`
|
|
@ -1,11 +1,10 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -42,8 +41,8 @@ func pathDecrypt() *framework.Path {
|
|||
func pathDecryptWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
value := d.Get("ciphertext").(string)
|
||||
if len(value) == 0 {
|
||||
ciphertext := d.Get("ciphertext").(string)
|
||||
if len(ciphertext) == 0 {
|
||||
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
|
@ -69,56 +68,26 @@ func pathDecryptWrite(
|
|||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Derive the key that should be used
|
||||
key, err := p.DeriveKey(context)
|
||||
plaintext, err := p.Decrypt(context, ciphertext)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
case certutil.InternalError:
|
||||
return nil, err
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Guard against a potentially invalid cipher-mode
|
||||
switch p.CipherMode {
|
||||
case "aes-gcm":
|
||||
default:
|
||||
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Verify the prefix
|
||||
if !strings.HasPrefix(value, "vault:v0:") {
|
||||
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Decode the base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(value, "vault:v0:"))
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Setup the cipher
|
||||
aesCipher, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup the GCM AEAD
|
||||
gcm, err := cipher.NewGCM(aesCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract the nonce and ciphertext
|
||||
nonce := decoded[:gcm.NonceSize()]
|
||||
ciphertext := decoded[gcm.NonceSize():]
|
||||
|
||||
// Verify and Decrypt
|
||||
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
|
||||
if plaintext == "" {
|
||||
return nil, fmt.Errorf("empty plaintext returned")
|
||||
}
|
||||
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"plaintext": base64.StdEncoding.EncodeToString(plain),
|
||||
"plaintext": plaintext,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
@ -48,10 +46,10 @@ func pathEncryptWrite(
|
|||
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Decode the plaintext value
|
||||
plaintext, err := base64.StdEncoding.DecodeString(value)
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("failed to decode plaintext as base64"), logical.ErrInvalidRequest
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decode the context if any
|
||||
|
@ -65,12 +63,6 @@ func pathEncryptWrite(
|
|||
}
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
isDerived := len(context) != 0
|
||||
|
@ -80,54 +72,26 @@ func pathEncryptWrite(
|
|||
}
|
||||
}
|
||||
|
||||
// Derive the key that should be used
|
||||
key, err := p.DeriveKey(context)
|
||||
ciphertext, err := p.Encrypt(context, value)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
case certutil.InternalError:
|
||||
return nil, err
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Guard against a potentially invalid cipher-mode
|
||||
switch p.CipherMode {
|
||||
case "aes-gcm":
|
||||
default:
|
||||
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
|
||||
if ciphertext == "" {
|
||||
return nil, fmt.Errorf("empty ciphertext returned")
|
||||
}
|
||||
|
||||
// Setup the cipher
|
||||
aesCipher, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup the GCM AEAD
|
||||
gcm, err := cipher.NewGCM(aesCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Compute random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
_, err = rand.Read(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encrypt and tag with GCM
|
||||
out := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
|
||||
// Place the encrypted data after the nonce
|
||||
full := append(nonce, out...)
|
||||
|
||||
// Convert to base64
|
||||
encoded := base64.StdEncoding.EncodeToString(full)
|
||||
|
||||
// Prepend some information
|
||||
encoded = "vault:v0:" + encoded
|
||||
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"ciphertext": encoded,
|
||||
"ciphertext": ciphertext,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
|
|
|
@ -1,126 +1,13 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/vault/helper/kdf"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const (
|
||||
// kdfMode is the only KDF mode currently supported
|
||||
kdfMode = "hmac-sha256-counter"
|
||||
)
|
||||
|
||||
// Policy is the struct used to store metadata
|
||||
type Policy struct {
|
||||
Name string `json:"name"`
|
||||
Key []byte `json:"key"`
|
||||
CipherMode string `json:"cipher"`
|
||||
|
||||
// Derived keys MUST provide a context and the
|
||||
// master underlying key is never used.
|
||||
Derived bool `json:"derived"`
|
||||
KDFMode string `json:"kdf_mode"`
|
||||
}
|
||||
|
||||
func (p *Policy) Serialize() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// DeriveKey is used to derive the encryption key that should
|
||||
// be used depending on the policy. If derivation is disabled the
|
||||
// raw key is used and no context is required, otherwise the KDF
|
||||
// mode is used with the context to derive the proper key.
|
||||
func (p *Policy) DeriveKey(context []byte) ([]byte, error) {
|
||||
// Fast-path non-derived keys
|
||||
if !p.Derived {
|
||||
return p.Key, nil
|
||||
}
|
||||
|
||||
// Ensure a context is provided
|
||||
if len(context) == 0 {
|
||||
return nil, fmt.Errorf("missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information.")
|
||||
}
|
||||
|
||||
switch p.KDFMode {
|
||||
case kdfMode:
|
||||
prf := kdf.HMACSHA256PRF
|
||||
prfLen := kdf.HMACSHA256PRFLen
|
||||
return kdf.CounterMode(prf, prfLen, p.Key, context, 256)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key derivation mode")
|
||||
}
|
||||
}
|
||||
|
||||
func DeserializePolicy(buf []byte) (*Policy, error) {
|
||||
p := new(Policy)
|
||||
if err := json.Unmarshal(buf, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getPolicy(req *logical.Request, name string) (*Policy, error) {
|
||||
// Check if the policy already exists
|
||||
raw, err := req.Storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode the policy
|
||||
p, err := DeserializePolicy(raw.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// generatePolicy is used to create a new named policy with
|
||||
// a randomly generated key
|
||||
func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) {
|
||||
// Create the policy object
|
||||
p := &Policy{
|
||||
Name: name,
|
||||
CipherMode: "aes-gcm",
|
||||
Derived: derived,
|
||||
}
|
||||
if derived {
|
||||
p.KDFMode = kdfMode
|
||||
}
|
||||
|
||||
// Generate a 256bit key
|
||||
p.Key = make([]byte, 32)
|
||||
_, err := rand.Read(p.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encode the policy
|
||||
buf, err := p.Serialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the policy into storage
|
||||
err = storage.Put(&logical.StorageEntry{
|
||||
Key: "policy/" + name,
|
||||
Value: buf,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the policy
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func pathKeys() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/" + framework.GenericNameRegex("name"),
|
||||
|
@ -169,6 +56,7 @@ func pathPolicyWrite(
|
|||
func pathPolicyRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -180,14 +68,22 @@ func pathPolicyRead(
|
|||
// Return the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"name": p.Name,
|
||||
"cipher_mode": p.CipherMode,
|
||||
"derived": p.Derived,
|
||||
"name": p.Name,
|
||||
"cipher_mode": p.CipherMode,
|
||||
"derived": p.Derived,
|
||||
"deletion_allowed": p.DeletionAllowed,
|
||||
},
|
||||
}
|
||||
if p.Derived {
|
||||
resp.Data["kdf_mode"] = p.KDFMode
|
||||
}
|
||||
|
||||
retKeys := map[string]int64{}
|
||||
for k, v := range p.Keys {
|
||||
retKeys[strconv.Itoa(k)] = v.CreationTime
|
||||
}
|
||||
resp.Data["keys"] = retKeys
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
@ -195,14 +91,26 @@ func pathPolicyDelete(
|
|||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
err := req.Storage.Delete("policy/" + name)
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err
|
||||
}
|
||||
if p == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
if !p.DeletionAllowed {
|
||||
return logical.ErrorResponse(fmt.Sprintf("'allow_deletion' config value is not set")), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
err = req.Storage.Delete("policy/" + name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
const pathPolicyHelpSyn = `Managed named encrption keys`
|
||||
const pathPolicyHelpSyn = `Managed named encryption keys`
|
||||
|
||||
const pathPolicyHelpDesc = `
|
||||
This path is used to manage the named keys that are available.
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathRaw() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "raw/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the key",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: pathRawRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathPolicyHelpSyn,
|
||||
HelpDescription: pathPolicyHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRawRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Return the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"name": p.Name,
|
||||
"key": p.Key,
|
||||
"cipher_mode": p.CipherMode,
|
||||
"derived": p.Derived,
|
||||
},
|
||||
}
|
||||
if p.Derived {
|
||||
resp.Data["kdf_mode"] = p.KDFMode
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathRawHelpSyn = `Fetch raw keys for named encrption keys`
|
||||
|
||||
const pathRawHelpDesc = `
|
||||
This path is used to get the underlying encryption keys used for the
|
||||
named keys that are available.
|
||||
`
|
|
@ -0,0 +1,120 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathRewrap() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "rewrap/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the key",
|
||||
},
|
||||
|
||||
"ciphertext": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Ciphertext value to rewrap",
|
||||
},
|
||||
|
||||
"context": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Context for key derivation. Required for derived keys.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.WriteOperation: pathRewrapWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRewrapHelpSyn,
|
||||
HelpDescription: pathRewrapHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRewrapWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
value := d.Get("ciphertext").(string)
|
||||
if len(value) == 0 {
|
||||
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Decode the context if any
|
||||
contextRaw := d.Get("context").(string)
|
||||
var context []byte
|
||||
if len(contextRaw) != 0 {
|
||||
var err error
|
||||
context, err = base64.StdEncoding.DecodeString(contextRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
|
||||
}
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
plaintext, err := p.Decrypt(context, value)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
case certutil.InternalError:
|
||||
return nil, err
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if plaintext == "" {
|
||||
return nil, fmt.Errorf("empty plaintext returned during rewrap")
|
||||
}
|
||||
|
||||
ciphertext, err := p.Encrypt(context, plaintext)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
case certutil.InternalError:
|
||||
return nil, err
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if ciphertext == "" {
|
||||
return nil, fmt.Errorf("empty ciphertext returned")
|
||||
}
|
||||
|
||||
// Generate the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"ciphertext": ciphertext,
|
||||
},
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathRewrapHelpSyn = `Rewrap ciphertext`
|
||||
|
||||
const pathRewrapHelpDesc = `
|
||||
After key rotation, this function can be used to rewrap the
|
||||
given ciphertext with the latest version of the named key.
|
||||
If the given ciphertext is already using the latest version
|
||||
of the key, this function is a no-op.
|
||||
`
|
|
@ -0,0 +1,56 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathRotate() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/" + framework.GenericNameRegex("name") + "/rotate",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the key",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.WriteOperation: pathRotateWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRotateHelpSyn,
|
||||
HelpDescription: pathRotateHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRotateWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
// Check if the policy already exists
|
||||
policy, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if policy == nil {
|
||||
return logical.ErrorResponse(
|
||||
fmt.Sprintf("no existing role named %s could be found", name)),
|
||||
logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Generate the policy
|
||||
err = policy.rotate(req.Storage)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
const pathRotateHelpSyn = `Rotate named encryption key`
|
||||
|
||||
const pathRotateHelpDesc = `
|
||||
This path is used to rotate the named key. After rotation,
|
||||
new encryption requests using this name will use the new key,
|
||||
but decryption will still be supported for older versions.
|
||||
`
|
|
@ -0,0 +1,360 @@
|
|||
package transit
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/kdf"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
const (
|
||||
// kdfMode is the only KDF mode currently supported
|
||||
kdfMode = "hmac-sha256-counter"
|
||||
)
|
||||
|
||||
// KeyEntry stores the key and metadata
|
||||
type KeyEntry struct {
|
||||
Key []byte `json:"key"`
|
||||
CreationTime int64 `json:"creation_time"`
|
||||
}
|
||||
|
||||
// KeyEntryMap is used to allow JSON marshal/unmarshal
|
||||
type KeyEntryMap map[int]KeyEntry
|
||||
|
||||
// MarshalJSON implements JSON marshaling
|
||||
func (kem KeyEntryMap) MarshalJSON() ([]byte, error) {
|
||||
intermediate := map[string]KeyEntry{}
|
||||
for k, v := range kem {
|
||||
intermediate[strconv.Itoa(k)] = v
|
||||
}
|
||||
return json.Marshal(&intermediate)
|
||||
}
|
||||
|
||||
// MarshalJSON implements JSON unmarshaling
|
||||
func (kem KeyEntryMap) UnmarshalJSON(data []byte) error {
|
||||
intermediate := map[string]KeyEntry{}
|
||||
err := json.Unmarshal(data, &intermediate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range intermediate {
|
||||
keyval, err := strconv.Atoi(k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kem[keyval] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Policy is the struct used to store metadata
|
||||
type Policy struct {
|
||||
Name string `json:"name"`
|
||||
Key []byte `json:"key,omitempty"` //DEPRECATED
|
||||
Keys KeyEntryMap `json:"keys"`
|
||||
CipherMode string `json:"cipher"`
|
||||
|
||||
// Derived keys MUST provide a context and the
|
||||
// master underlying key is never used.
|
||||
Derived bool `json:"derived"`
|
||||
KDFMode string `json:"kdf_mode"`
|
||||
|
||||
// The minimum version of the key allowed to be used
|
||||
// for decryption
|
||||
MinDecryptionVersion int `json:"min_decryption_version"`
|
||||
|
||||
// Whether the key is allowed to be deleted
|
||||
DeletionAllowed bool `json:"deletion_allowed"`
|
||||
}
|
||||
|
||||
func (p *Policy) Persist(storage logical.Storage, name string) error {
|
||||
// Encode the policy
|
||||
buf, err := p.Serialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the policy into storage
|
||||
err = storage.Put(&logical.StorageEntry{
|
||||
Key: "policy/" + name,
|
||||
Value: buf,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Policy) Serialize() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// DeriveKey is used to derive the encryption key that should
|
||||
// be used depending on the policy. If derivation is disabled the
|
||||
// raw key is used and no context is required, otherwise the KDF
|
||||
// mode is used with the context to derive the proper key.
|
||||
func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
|
||||
if p.Keys == nil || len(p.Keys) == 0 {
|
||||
if p.Key == nil || len(p.Key) == 0 {
|
||||
return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
|
||||
}
|
||||
p.migrateKeyToKeysMap()
|
||||
}
|
||||
|
||||
if len(p.Keys) == 0 {
|
||||
return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"}
|
||||
}
|
||||
|
||||
if ver <= 0 || ver > len(p.Keys) {
|
||||
return nil, certutil.UserError{Err: "invalid key version"}
|
||||
}
|
||||
|
||||
// Fast-path non-derived keys
|
||||
if !p.Derived {
|
||||
return p.Keys[ver].Key, nil
|
||||
}
|
||||
|
||||
// Ensure a context is provided
|
||||
if len(context) == 0 {
|
||||
return nil, certutil.UserError{Err: "missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information"}
|
||||
}
|
||||
|
||||
switch p.KDFMode {
|
||||
case kdfMode:
|
||||
prf := kdf.HMACSHA256PRF
|
||||
prfLen := kdf.HMACSHA256PRFLen
|
||||
return kdf.CounterMode(prf, prfLen, p.Keys[ver].Key, context, 256)
|
||||
default:
|
||||
return nil, certutil.InternalError{Err: "unsupported key derivation mode"}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Policy) Encrypt(context []byte, value string) (string, error) {
|
||||
// Decode the plaintext value
|
||||
plaintext, err := base64.StdEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return "", certutil.UserError{Err: "failed to decode plaintext as base64"}
|
||||
}
|
||||
|
||||
// Derive the key that should be used
|
||||
key, err := p.DeriveKey(context, len(p.Keys))
|
||||
if err != nil {
|
||||
return "", certutil.InternalError{Err: err.Error()}
|
||||
}
|
||||
|
||||
// Guard against a potentially invalid cipher-mode
|
||||
switch p.CipherMode {
|
||||
case "aes-gcm":
|
||||
default:
|
||||
return "", certutil.InternalError{Err: "unsupported cipher mode"}
|
||||
}
|
||||
|
||||
// Setup the cipher
|
||||
aesCipher, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", certutil.InternalError{Err: err.Error()}
|
||||
}
|
||||
|
||||
// Setup the GCM AEAD
|
||||
gcm, err := cipher.NewGCM(aesCipher)
|
||||
if err != nil {
|
||||
return "", certutil.InternalError{Err: err.Error()}
|
||||
}
|
||||
|
||||
// Compute random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
_, err = rand.Read(nonce)
|
||||
if err != nil {
|
||||
return "", certutil.InternalError{Err: err.Error()}
|
||||
}
|
||||
|
||||
// Encrypt and tag with GCM
|
||||
out := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
|
||||
// Place the encrypted data after the nonce
|
||||
full := append(nonce, out...)
|
||||
|
||||
// Convert to base64
|
||||
encoded := base64.StdEncoding.EncodeToString(full)
|
||||
|
||||
// Prepend some information
|
||||
encoded = "vault:v" + strconv.Itoa(len(p.Keys)) + ":" + encoded
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (p *Policy) Decrypt(context []byte, value string) (string, error) {
|
||||
// Verify the prefix
|
||||
if !strings.HasPrefix(value, "vault:v") {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
}
|
||||
|
||||
splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
|
||||
if len(splitVerCiphertext) != 2 {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
}
|
||||
|
||||
ver, err := strconv.Atoi(splitVerCiphertext[0])
|
||||
if err != nil {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
}
|
||||
|
||||
if ver == 0 {
|
||||
// Compatibility mode with initial implementation, where keys start at zero
|
||||
ver = 1
|
||||
}
|
||||
|
||||
if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
|
||||
return "", certutil.UserError{Err: "ciphertext version is disallowed by policy (too old)"}
|
||||
}
|
||||
|
||||
// Derive the key that should be used
|
||||
key, err := p.DeriveKey(context, ver)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Guard against a potentially invalid cipher-mode
|
||||
switch p.CipherMode {
|
||||
case "aes-gcm":
|
||||
default:
|
||||
return "", certutil.InternalError{Err: "unsupported cipher mode"}
|
||||
}
|
||||
|
||||
// Decode the base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
|
||||
if err != nil {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
}
|
||||
|
||||
// Setup the cipher
|
||||
aesCipher, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", certutil.InternalError{Err: err.Error()}
|
||||
}
|
||||
|
||||
// Setup the GCM AEAD
|
||||
gcm, err := cipher.NewGCM(aesCipher)
|
||||
if err != nil {
|
||||
return "", certutil.InternalError{Err: err.Error()}
|
||||
}
|
||||
|
||||
// Extract the nonce and ciphertext
|
||||
nonce := decoded[:gcm.NonceSize()]
|
||||
ciphertext := decoded[gcm.NonceSize():]
|
||||
|
||||
// Verify and Decrypt
|
||||
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", certutil.UserError{Err: "invalid ciphertext"}
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(plain), nil
|
||||
}
|
||||
|
||||
func (p *Policy) rotate(storage logical.Storage) error {
|
||||
if p.Keys == nil {
|
||||
p.migrateKeyToKeysMap()
|
||||
}
|
||||
|
||||
// Generate a 256bit key
|
||||
newKey := make([]byte, 32)
|
||||
_, err := rand.Read(newKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Keys[len(p.Keys)+1] = KeyEntry{
|
||||
Key: newKey,
|
||||
CreationTime: time.Now().Unix(),
|
||||
}
|
||||
|
||||
return p.Persist(storage, p.Name)
|
||||
}
|
||||
|
||||
func (p *Policy) migrateKeyToKeysMap() {
|
||||
if p.Key == nil || len(p.Key) == 0 {
|
||||
p.Key = nil
|
||||
p.Keys = KeyEntryMap{}
|
||||
return
|
||||
}
|
||||
|
||||
p.Keys = KeyEntryMap{
|
||||
1: KeyEntry{
|
||||
Key: p.Key,
|
||||
CreationTime: time.Now().Unix(),
|
||||
},
|
||||
}
|
||||
p.Key = nil
|
||||
}
|
||||
|
||||
func deserializePolicy(buf []byte) (*Policy, error) {
|
||||
p := &Policy{
|
||||
Keys: KeyEntryMap{},
|
||||
}
|
||||
if err := json.Unmarshal(buf, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getPolicy(req *logical.Request, name string) (*Policy, error) {
|
||||
// Check if the policy already exists
|
||||
raw, err := req.Storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode the policy
|
||||
p, err := deserializePolicy(raw.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if p.Key != nil && len(p.Key) > 0 {
|
||||
p.migrateKeyToKeysMap()
|
||||
|
||||
err = p.Persist(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// generatePolicy is used to create a new named policy with
|
||||
// a randomly generated key
|
||||
func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) {
|
||||
// Create the policy object
|
||||
p := &Policy{
|
||||
Name: name,
|
||||
CipherMode: "aes-gcm",
|
||||
Derived: derived,
|
||||
}
|
||||
if derived {
|
||||
p.KDFMode = kdfMode
|
||||
}
|
||||
|
||||
err := p.rotate(storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the policy
|
||||
return p, nil
|
||||
}
|
|
@ -209,7 +209,18 @@ func Test(t TestT, c TestCase) {
|
|||
Path: "sys/revoke/" + resp.Secret.LeaseID,
|
||||
})
|
||||
}
|
||||
if err == nil && resp.IsError() && !s.ErrorOk {
|
||||
// If it's an error, but an error is expected, and one is also
|
||||
// returned as a logical.ErrorResponse, let it go to the check
|
||||
if err != nil {
|
||||
if !resp.IsError() || (resp.IsError() && !s.ErrorOk) {
|
||||
t.Error(fmt.Sprintf("Failed step %d: %s", i+1, err))
|
||||
break
|
||||
}
|
||||
// Set it to nil here as we're catching on the
|
||||
// logical.ErrorResponse instead
|
||||
err = nil
|
||||
}
|
||||
if resp.IsError() && !s.ErrorOk {
|
||||
err = fmt.Errorf("Erroneous response:\n\n%#v", resp)
|
||||
}
|
||||
if err == nil && s.Check != nil {
|
||||
|
|
|
@ -62,21 +62,6 @@ cipher_mode aes-gcm
|
|||
derived false
|
||||
````
|
||||
|
||||
We can read from the `raw/` endpoint to see the encryption key itself:
|
||||
|
||||
```
|
||||
$ vault read transit/raw/foo
|
||||
Key Value
|
||||
name foo
|
||||
cipher_mode aes-gcm
|
||||
key PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=
|
||||
derived false
|
||||
````
|
||||
|
||||
Here we can see that the randomly generated encryption key being used, as
|
||||
well as the AES-GCM cipher mode. We don't need to know any of this to use
|
||||
the key however.
|
||||
|
||||
Now, if we wanted to encrypt a piece of plain text, we use the encrypt
|
||||
endpoint using our named key:
|
||||
|
||||
|
@ -299,44 +284,3 @@ only encrypt or decrypt using the named keys they need access to.
|
|||
|
||||
</dd>
|
||||
</dl>
|
||||
|
||||
### /transit/raw/
|
||||
#### GET
|
||||
|
||||
<dl class="api">
|
||||
<dt>Description</dt>
|
||||
<dd>
|
||||
Returns raw information about a named encryption key,
|
||||
Including the underlying encryption key. This is a root protected endpoint.
|
||||
</dd>
|
||||
|
||||
<dt>Method</dt>
|
||||
<dd>GET</dd>
|
||||
|
||||
<dt>URL</dt>
|
||||
<dd>`/transit/raw/<name>`</dd>
|
||||
|
||||
<dt>Parameters</dt>
|
||||
<dd>
|
||||
None
|
||||
</dd>
|
||||
|
||||
<dt>Returns</dt>
|
||||
<dd>
|
||||
|
||||
```javascript
|
||||
{
|
||||
"data": {
|
||||
"name": "foo",
|
||||
"cipher_mode": "aes-gcm",
|
||||
"key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8="
|
||||
"derived": "true",
|
||||
"kdf_mode": "hmac-sha256-counter",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue