Merge pull request #1028 from hashicorp/transit-fuzz-test
Add transit fuzz test
This commit is contained in:
commit
1adc2fbf77
2
Makefile
2
Makefile
|
@ -19,7 +19,7 @@ dev: generate
|
|||
|
||||
# test runs the unit tests and vets the code
|
||||
test: generate
|
||||
VAULT_TOKEN= TF_ACC= godep go test $(TEST) $(TESTARGS) -timeout=60s -parallel=4
|
||||
VAULT_TOKEN= TF_ACC= godep go test $(TEST) $(TESTARGS) -timeout=120s -parallel=4
|
||||
|
||||
# testacc runs acceptance tests
|
||||
testacc: generate
|
||||
|
|
|
@ -3,11 +3,15 @@ package transit
|
|||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
@ -541,3 +545,135 @@ func TestKeyUpgrade(t *testing.T) {
|
|||
t.Errorf("bad key migration, result is %#v", p.Keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyFuzzing(t *testing.T) {
|
||||
be := Backend()
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
|
||||
keys := []string{"test1", "test2", "test3"}
|
||||
|
||||
// This is the goroutine loop
|
||||
doFuzzy := func() {
|
||||
// Check for panics, otherwise notify we're done
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
t.Fatalf("got a panic: %v", err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Holds the latest encrypted value for each key
|
||||
latestEncryptedText := map[string]string{}
|
||||
|
||||
startTime := time.Now()
|
||||
req := &logical.Request{
|
||||
Storage: storage,
|
||||
Data: map[string]interface{}{},
|
||||
}
|
||||
fd := &framework.FieldData{}
|
||||
|
||||
for {
|
||||
// Stop after 10 seconds
|
||||
if time.Now().Sub(startTime) > 10*time.Second {
|
||||
return
|
||||
}
|
||||
|
||||
// Pick a function and a key
|
||||
chosenFunc := funcs[rand.Int()%len(funcs)]
|
||||
chosenKey := keys[rand.Int()%len(keys)]
|
||||
|
||||
fd.Raw = map[string]interface{}{
|
||||
"name": chosenKey,
|
||||
}
|
||||
fd.Schema = be.pathKeys().Fields
|
||||
|
||||
// Try to write the key to make sure it exists
|
||||
_, err := be.pathPolicyWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch chosenFunc {
|
||||
// Encrypt our plaintext and store the result
|
||||
case "encrypt":
|
||||
fd.Raw["plaintext"] = base64.StdEncoding.EncodeToString([]byte(testPlaintext))
|
||||
fd.Schema = be.pathEncrypt().Fields
|
||||
resp, err := be.pathEncryptWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error: %v, resp is %#v", err, *resp)
|
||||
return
|
||||
}
|
||||
latestEncryptedText[chosenKey] = resp.Data["ciphertext"].(string)
|
||||
|
||||
// Rotate to a new key version
|
||||
case "rotate":
|
||||
fd.Schema = be.pathRotate().Fields
|
||||
resp, err := be.pathRotateWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error: %v, resp is %#v", err, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt the ciphertext and compare the result
|
||||
case "decrypt":
|
||||
ct := latestEncryptedText[chosenKey]
|
||||
if ct == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fd.Raw["ciphertext"] = ct
|
||||
fd.Schema = be.pathDecrypt().Fields
|
||||
resp, err := be.pathDecryptWrite(req, fd)
|
||||
if err != nil {
|
||||
// This could well happen since the min version is jumping around
|
||||
if resp.Data["error"].(string) == ErrTooOld {
|
||||
continue
|
||||
}
|
||||
t.Errorf("got an error: %v, resp is %#v, ciphertext was %s", err, *resp, latestEncryptedText[chosenKey])
|
||||
return
|
||||
}
|
||||
ptb64 := resp.Data["plaintext"].(string)
|
||||
pt, err := base64.StdEncoding.DecodeString(ptb64)
|
||||
if err != nil {
|
||||
t.Errorf("got an error decoding base64 plaintext: %v", err)
|
||||
return
|
||||
}
|
||||
if string(pt) != testPlaintext {
|
||||
t.Fatalf("got bad plaintext back: %s", pt)
|
||||
}
|
||||
|
||||
// Change the min version, which also tests the archive functionality
|
||||
case "change_min_version":
|
||||
resp, err := be.pathPolicyRead(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error reading policy %s: %v", chosenKey, err)
|
||||
return
|
||||
}
|
||||
latestVersion := resp.Data["latest_version"].(int)
|
||||
|
||||
// keys start at version 1 so we want [1, latestVersion] not [0, latestVersion)
|
||||
setVersion := (rand.Int() % latestVersion) + 1
|
||||
fd.Raw["min_decryption_version"] = setVersion
|
||||
fd.Schema = be.pathConfig().Fields
|
||||
resp, err = be.pathConfigWrite(req, fd)
|
||||
if err != nil {
|
||||
t.Errorf("got an error setting min decryption version: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn 1000 of these workers for 10 seconds
|
||||
for i := 0; i < 1000; i++ {
|
||||
wg.Add(1)
|
||||
go doFuzzy()
|
||||
}
|
||||
|
||||
// Wait for them all to finish
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
const (
|
||||
// kdfMode is the only KDF mode currently supported
|
||||
kdfMode = "hmac-sha256-counter"
|
||||
|
||||
ErrTooOld = "ciphertext version is disallowed by policy (too old)"
|
||||
)
|
||||
|
||||
// policyCache implements a simple locking cache of policies
|
||||
|
@ -539,7 +541,7 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
|
|||
}
|
||||
|
||||
if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
|
||||
return "", certutil.UserError{Err: "ciphertext version is disallowed by policy (too old)"}
|
||||
return "", certutil.UserError{Err: ErrTooOld}
|
||||
}
|
||||
|
||||
// Derive the key that should be used
|
||||
|
|
Loading…
Reference in New Issue