Merge pull request #399 from hashicorp/f-kdf

Support for key derivation in secret/transit
This commit is contained in:
Armon Dadgar 2015-07-05 20:06:28 -06:00
commit fb4a6ff28b
8 changed files with 363 additions and 20 deletions

View file

@ -19,14 +19,14 @@ func TestBackend_basic(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test"),
testAccStepReadPolicy(t, "test", false),
testAccStepReadRaw(t, "test", false),
testAccStepWritePolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepReadRaw(t, "test", false, false),
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true),
testAccStepReadRaw(t, "test", true),
testAccStepReadPolicy(t, "test", true, false),
testAccStepReadRaw(t, "test", true, false),
},
})
}
@ -36,20 +36,40 @@ func TestBackend_upsert(t *testing.T) {
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testAccStepReadPolicy(t, "test", true),
testAccStepReadPolicy(t, "test", true, false),
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepReadPolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true),
testAccStepReadPolicy(t, "test", true, false),
},
})
}
func testAccStepWritePolicy(t *testing.T, name string) logicaltest.TestStep {
func TestBackend_basic_derived(t *testing.T) {
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test", true),
testAccStepReadPolicy(t, "test", false, true),
testAccStepReadRaw(t, "test", false, true),
testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true, true),
testAccStepReadRaw(t, "test", true, true),
},
})
}
func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "keys/" + name,
Data: map[string]interface{}{
"derived": derived,
},
}
}
@ -60,7 +80,7 @@ func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
}
}
func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "keys/" + name,
@ -77,6 +97,8 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicalte
Name string `mapstructure:"name"`
Key []byte `mapstructure:"key"`
CipherMode string `mapstructure:"cipher_mode"`
Derived bool `mapstructure:"derived"`
KDFMode string `mapstructure:"kdf_mode"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
@ -92,12 +114,18 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicalte
if d.Key != nil {
return fmt.Errorf("bad: %#v", d)
}
if d.Derived != derived {
return fmt.Errorf("bad: %#v", d)
}
if derived && d.KDFMode != kdfMode {
return fmt.Errorf("bad: %#v", d)
}
return nil
},
}
}
func testAccStepReadRaw(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
func testAccStepReadRaw(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "raw/" + name,
@ -114,6 +142,8 @@ func testAccStepReadRaw(t *testing.T, name string, expectNone bool) logicaltest.
Name string `mapstructure:"name"`
Key []byte `mapstructure:"key"`
CipherMode string `mapstructure:"cipher_mode"`
Derived bool `mapstructure:"derived"`
KDFMode string `mapstructure:"kdf_mode"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
@ -128,6 +158,12 @@ func testAccStepReadRaw(t *testing.T, name string, expectNone bool) logicaltest.
if len(d.Key) != 32 {
return fmt.Errorf("bad: %#v", d)
}
if d.Derived != derived {
return fmt.Errorf("bad: %#v", d)
}
if derived && d.KDFMode != kdfMode {
return fmt.Errorf("bad: %#v", d)
}
return nil
},
}
@ -157,6 +193,32 @@ func testAccStepEncrypt(
}
}
func testAccStepEncryptContext(
t *testing.T, name, plaintext, context string, decryptData map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "encrypt/" + name,
Data: map[string]interface{}{
"plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)),
"context": base64.StdEncoding.EncodeToString([]byte(context)),
},
Check: func(resp *logical.Response) error {
var d struct {
Ciphertext string `mapstructure:"ciphertext"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Ciphertext == "" {
return fmt.Errorf("missing ciphertext")
}
decryptData["ciphertext"] = d.Ciphertext
decryptData["context"] = base64.StdEncoding.EncodeToString([]byte(context))
return nil
},
}
}
func testAccStepDecrypt(
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{

View file

@ -23,6 +23,11 @@ func pathDecrypt() *framework.Path {
Type: framework.TypeString,
Description: "Ciphertext value to decrypt",
},
"context": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -42,6 +47,17 @@ func pathDecryptWrite(
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
}
// Decode the context if any
contextRaw := d.Get("context").(string)
var context []byte
if len(contextRaw) != 0 {
var err error
context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil {
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
}
}
// Get the policy
p, err := getPolicy(req, name)
if err != nil {
@ -53,6 +69,12 @@ func pathDecryptWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
// Derive the key that should be used
key, err := p.DeriveKey(context)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
// Guard against a potentially invalid cipher-mode
switch p.CipherMode {
case "aes-gcm":
@ -72,7 +94,7 @@ func pathDecryptWrite(
}
// Setup the cipher
aesCipher, err := aes.NewCipher(p.Key)
aesCipher, err := aes.NewCipher(key)
if err != nil {
return nil, err
}

View file

@ -24,6 +24,11 @@ func pathEncrypt() *framework.Path {
Type: framework.TypeString,
Description: "Plaintext value to encrypt",
},
"context": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -49,6 +54,17 @@ func pathEncryptWrite(
return logical.ErrorResponse("failed to decode plaintext as base64"), logical.ErrInvalidRequest
}
// Decode the context if any
contextRaw := d.Get("context").(string)
var context []byte
if len(contextRaw) != 0 {
var err error
context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil {
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
}
}
// Get the policy
p, err := getPolicy(req, name)
if err != nil {
@ -57,12 +73,19 @@ func pathEncryptWrite(
// Error if invalid policy
if p == nil {
p, err = generatePolicy(req.Storage, name)
isDerived := len(context) != 0
p, err = generatePolicy(req.Storage, name, isDerived)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to upsert policy: %v", err)), logical.ErrInvalidRequest
}
}
// Derive the key that should be used
key, err := p.DeriveKey(context)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
// Guard against a potentially invalid cipher-mode
switch p.CipherMode {
case "aes-gcm":
@ -71,7 +94,7 @@ func pathEncryptWrite(
}
// Setup the cipher
aesCipher, err := aes.NewCipher(p.Key)
aesCipher, err := aes.NewCipher(key)
if err != nil {
return nil, err
}

View file

@ -3,22 +3,59 @@ package transit
import (
"crypto/rand"
"encoding/json"
"fmt"
"github.com/hashicorp/vault/helper/kdf"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
const (
// kdfMode is the only KDF mode currently supported
kdfMode = "hmac-sha256-counter"
)
// Policy is the struct used to store metadata
type Policy struct {
Name string `json:"name"`
Key []byte `json:"key"`
CipherMode string `json:"cipher"`
// Derived keys MUST provide a context and the
// master underlying key is never used.
Derived bool `json:"derived"`
KDFMode string `json:"kdf_mode"`
}
func (p *Policy) Serialize() ([]byte, error) {
return json.Marshal(p)
}
// DeriveKey is used to derive the encryption key that should
// be used depending on the policy. If derivation is disabled the
// raw key is used and no context is required, otherwise the KDF
// mode is used with the context to derive the proper key.
func (p *Policy) DeriveKey(context []byte) ([]byte, error) {
// Fast-path non-derived keys
if !p.Derived {
return p.Key, nil
}
// Ensure a context is provided
if len(context) == 0 {
return nil, fmt.Errorf("missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information.")
}
switch p.KDFMode {
case kdfMode:
prf := kdf.HMACSHA256PRF
prfLen := kdf.HMACSHA256PRFLen
return kdf.CounterMode(prf, prfLen, p.Key, context, 256)
default:
return nil, fmt.Errorf("unsupported key derivation mode")
}
}
func DeserializePolicy(buf []byte) (*Policy, error) {
p := new(Policy)
if err := json.Unmarshal(buf, p); err != nil {
@ -47,11 +84,15 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) {
// generatePolicy is used to create a new named policy with
// a randomly generated key
func generatePolicy(storage logical.Storage, name string) (*Policy, error) {
func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) {
// Create the policy object
p := &Policy{
Name: name,
CipherMode: "aes-gcm",
Derived: derived,
}
if derived {
p.KDFMode = kdfMode
}
// Generate a 256bit key
@ -88,6 +129,11 @@ func pathKeys() *framework.Path {
Type: framework.TypeString,
Description: "Name of the key",
},
"derived": &framework.FieldSchema{
Type: framework.TypeBool,
Description: "Enables key derivation mode. This allows for per-transaction unique keys",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -104,6 +150,7 @@ func pathKeys() *framework.Path {
func pathPolicyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
derived := d.Get("derived").(bool)
// Check if the policy already exists
existing, err := getPolicy(req, name)
@ -115,7 +162,7 @@ func pathPolicyWrite(
}
// Generate the policy
_, err = generatePolicy(req.Storage, name)
_, err = generatePolicy(req.Storage, name, derived)
return nil, err
}
@ -135,8 +182,12 @@ func pathPolicyRead(
Data: map[string]interface{}{
"name": p.Name,
"cipher_mode": p.CipherMode,
"derived": p.Derived,
},
}
if p.Derived {
resp.Data["kdf_mode"] = p.KDFMode
}
return resp, nil
}

View file

@ -41,8 +41,12 @@ func pathRawRead(
"name": p.Name,
"key": p.Key,
"cipher_mode": p.CipherMode,
"derived": p.Derived,
},
}
if p.Derived {
resp.Data["kdf_mode"] = p.KDFMode
}
return resp, nil
}

77
helper/kdf/kdf.go Normal file
View file

@ -0,0 +1,77 @@
// This package is used to implement Key Derivation Functions (KDF)
// based on the recommendations of NIST SP 800-108. These are useful
// for generating unique-per-transaction keys, or situations in which
// a key hierarchy may be useful.
package kdf
import (
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"fmt"
)
// PRF is a psuedo-random function that takes a key or seed,
// as well as additional binary data and generates output that is
// indistinguishable from random. Examples are cryptographic hash
// functions or block ciphers.
type PRF func([]byte, []byte) ([]byte, error)
// CounterMode implements the counter mode KDF that uses a psuedo-random-function (PRF)
// along with a counter to generate derived keys. The KDF takes a base key
// a derivation context, and the requried number of output bits.
func CounterMode(prf PRF, prfLen uint32, key []byte, context []byte, bits uint32) ([]byte, error) {
// Ensure the PRF is byte aligned
if prfLen%8 != 0 {
return nil, fmt.Errorf("PRF must be byte aligned")
}
// Ensure the bits required are byte aligned
if bits%8 != 0 {
return nil, fmt.Errorf("bits required must be byte aligned")
}
// Determine the number of rounds required
rounds := bits / prfLen
if bits%prfLen != 0 {
rounds++
}
// Allocate and setup the input
input := make([]byte, 4+len(context)+4)
copy(input[4:], context)
binary.BigEndian.PutUint32(input[4+len(context):], bits)
// Iteratively generate more key material
var out []byte
var i uint32
for i = 0; i < rounds; i++ {
// Update the counter in the input string
binary.BigEndian.PutUint32(input[:4], i)
// Compute a more key material
part, err := prf(key, input)
if err != nil {
return nil, err
}
if uint32(len(part)*8) != prfLen {
return nil, fmt.Errorf("PRF length mis-match (%d vs %d)", len(part)*8, prfLen)
}
out = append(out, part...)
}
// Return the desired number of output bytes
return out[:bits/8], nil
}
const (
// HMACSHA256PRFLen is the length of output from HMACSHA256PRF
HMACSHA256PRFLen uint32 = 256
)
// HMACSHA256PRF is a pseudo-random-function (PRF) that uses an HMAC-SHA256
func HMACSHA256PRF(key []byte, data []byte) ([]byte, error) {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil), nil
}

72
helper/kdf/kdf_test.go Normal file
View file

@ -0,0 +1,72 @@
package kdf
import (
"bytes"
"testing"
)
func TestCounterMode(t *testing.T) {
key := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
context := []byte("the quick brown fox")
prf := HMACSHA256PRF
prfLen := HMACSHA256PRFLen
// Expect256 was generated in python with
// import hashlib, hmac
// hash = hashlib.sha256
// context = "the quick brown fox"
// key = "".join([chr(x) for x in range(1, 17)])
// inp = "\x00\x00\x00\x00"+context+"\x00\x00\x01\x00"
// digest = hmac.HMAC(key, inp, hash).digest()
// print [ord(x) for x in digest]
expect256 := []byte{219, 25, 238, 6, 185, 236, 180, 64, 248, 152, 251,
153, 79, 5, 141, 222, 66, 200, 66, 143, 40, 3, 101, 221, 206, 163, 102,
80, 88, 234, 87, 157}
for _, l := range []uint32{128, 256, 384, 1024} {
out, err := CounterMode(prf, prfLen, key, context, l)
if err != nil {
t.Fatalf("err: %v", err)
}
if uint32(len(out)*8) != l {
t.Fatalf("bad length: %#v", out)
}
if bytes.Contains(out, key) {
t.Fatalf("output contains key")
}
if l == 256 && !bytes.Equal(out, expect256) {
t.Fatalf("mis-match")
}
}
}
func TestHMACSHA256PRF(t *testing.T) {
key := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
data := []byte("foobarbaz")
out, err := HMACSHA256PRF(key, data)
if err != nil {
t.Fatalf("err: %v", err)
}
if uint32(len(out)*8) != HMACSHA256PRFLen {
t.Fatalf("Bad len")
}
// Expect was generated in python with:
// import hashlib, hmac
// hash = hashlib.sha256
// msg = "foobarbaz"
// key = "".join([chr(x) for x in range(1, 17)])
// hm = hmac.HMAC(key, msg, hash)
// print [ord(x) for x in hm.digest()]
expect := []byte{9, 50, 146, 8, 188, 130, 150, 107, 205, 147, 82, 170,
253, 183, 26, 38, 167, 194, 220, 111, 56, 118, 219, 209, 31, 52, 137,
90, 246, 133, 191, 124}
if !bytes.Equal(expect, out) {
t.Fatalf("mis-matched output")
}
}

View file

@ -21,6 +21,11 @@ application developers and pushes the burden onto the operators of Vault.
Operators of Vault generally include the security team at an organization,
which means they can ensure that data is encrypted/decrypted properly.
As of Vault 0.2, the transit backend also supports doing key derivation. This
allows data to be encrypted within a context such that the same context must be
used for decryption. This can be used to enable per transaction unique keys which
further increase the security of data at rest.
Additionally, since encrypt/decrypt operations must enter the audit log,
any decryption event is recorded.
@ -42,7 +47,7 @@ many different applications can use the transit backend with independent keys.
This is done by doing a write against the backend:
```
$ vault write transit/keys/foo test=1
$ vault write -f transit/keys/foo
Success! Data written to: transit/keys/foo
```
@ -54,6 +59,7 @@ $ vault read transit/keys/foo
Key Value
name foo
cipher_mode aes-gcm
derived false
````
We can read from the `raw/` endpoint to see the encryption key itself:
@ -64,6 +70,7 @@ Key Value
name foo
cipher_mode aes-gcm
key PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=
derived false
````
Here we can see that the randomly generated encryption key being used, as
@ -118,7 +125,16 @@ only encrypt or decrypt using the named keys they need access to.
<dt>Parameters</dt>
<dd>
None
<ul>
<li>
<span class="param">derived</span>
<span class="param-flags">optional</span>
Boolean flag indicating if key derivation MUST be used.
If enabled, all encrypt/decrypt requests to this named key
must provide a context which is used for key derivation.
Defaults to false.
</li>
</ul>
</dd>
<dt>Returns</dt>
@ -155,6 +171,8 @@ only encrypt or decrypt using the named keys they need access to.
"data": {
"name": "foo",
"cipher_mode": "aes-gcm",
"derived": "true",
"kdf_mode": "hmac-sha256-counter",
}
}
```
@ -213,6 +231,12 @@ only encrypt or decrypt using the named keys they need access to.
<span class="param-flags">required</span>
The plaintext to encrypt, provided as base64 encoded.
</li>
<li>
<span class="param">context</span>
<span class="param-flags">optional</span>
The key derivation context, provided as base64 encoded.
Must be provided if the derivation enabled.
</li>
</ul>
</dd>
@ -253,6 +277,12 @@ only encrypt or decrypt using the named keys they need access to.
<span class="param-flags">required</span>
The ciphertext to decrypt, provided as returned by encrypt.
</li>
<li>
<span class="param">context</span>
<span class="param-flags">optional</span>
The key derivation context, provided as base64 encoded.
Must be provided if the derivation enabled.
</li>
</ul>
</dd>
@ -300,6 +330,8 @@ only encrypt or decrypt using the named keys they need access to.
"name": "foo",
"cipher_mode": "aes-gcm",
"key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8="
"derived": "true",
"kdf_mode": "hmac-sha256-counter",
}
}
```