128 lines
2.9 KiB
Go
128 lines
2.9 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package transit
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/hashicorp/vault/helper/random"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
)
|
|
|
|
func TestTransit_Random(t *testing.T) {
|
|
var b *backend
|
|
sysView := logical.TestSystemView()
|
|
storage := &logical.InmemStorage{}
|
|
sysView.CachingDisabledVal = true
|
|
|
|
b, _ = Backend(context.Background(), &logical.BackendConfig{
|
|
StorageView: storage,
|
|
System: sysView,
|
|
})
|
|
|
|
req := &logical.Request{
|
|
Storage: storage,
|
|
Operation: logical.UpdateOperation,
|
|
Path: "random",
|
|
Data: map[string]interface{}{},
|
|
}
|
|
|
|
doRequest := func(req *logical.Request, errExpected bool, format string, numBytes int) {
|
|
getResponse := func() []byte {
|
|
resp, err := b.HandleRequest(context.Background(), req)
|
|
if err != nil && !errExpected {
|
|
t.Fatal(err)
|
|
}
|
|
if resp == nil {
|
|
t.Fatal("expected non-nil response")
|
|
}
|
|
if errExpected {
|
|
if !resp.IsError() {
|
|
t.Fatalf("bad: got error response: %#v", *resp)
|
|
}
|
|
return nil
|
|
}
|
|
if resp.IsError() {
|
|
t.Fatalf("bad: got error response: %#v", *resp)
|
|
}
|
|
if _, ok := resp.Data["random_bytes"]; !ok {
|
|
t.Fatal("no random_bytes found in response")
|
|
}
|
|
|
|
outputStr := resp.Data["random_bytes"].(string)
|
|
var outputBytes []byte
|
|
switch format {
|
|
case "base64":
|
|
outputBytes, err = base64.StdEncoding.DecodeString(outputStr)
|
|
case "hex":
|
|
outputBytes, err = hex.DecodeString(outputStr)
|
|
default:
|
|
t.Fatal("unknown format")
|
|
}
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return outputBytes
|
|
}
|
|
|
|
rand1 := getResponse()
|
|
// Expected error
|
|
if rand1 == nil {
|
|
return
|
|
}
|
|
rand2 := getResponse()
|
|
if len(rand1) != numBytes || len(rand2) != numBytes {
|
|
t.Fatal("length of output random bytes not what is expected")
|
|
}
|
|
if reflect.DeepEqual(rand1, rand2) {
|
|
t.Fatal("found identical ouputs")
|
|
}
|
|
}
|
|
|
|
for _, source := range []string{"", "platform", "seal", "all"} {
|
|
req.Data["source"] = source
|
|
req.Data["bytes"] = 32
|
|
req.Data["format"] = "base64"
|
|
req.Path = "random"
|
|
// Test defaults
|
|
doRequest(req, false, "base64", 32)
|
|
|
|
// Test size selection in the path
|
|
req.Path = "random/24"
|
|
req.Data["format"] = "hex"
|
|
doRequest(req, false, "hex", 24)
|
|
|
|
if source != "" {
|
|
// Test source selection in the path
|
|
req.Path = fmt.Sprintf("random/%s", source)
|
|
req.Data["format"] = "hex"
|
|
doRequest(req, false, "hex", 32)
|
|
|
|
req.Path = fmt.Sprintf("random/%s/24", source)
|
|
req.Data["format"] = "hex"
|
|
doRequest(req, false, "hex", 24)
|
|
}
|
|
|
|
// Test bad input/format
|
|
req.Path = "random"
|
|
req.Data["format"] = "base92"
|
|
doRequest(req, true, "", 0)
|
|
|
|
req.Data["format"] = "hex"
|
|
req.Data["bytes"] = -1
|
|
doRequest(req, true, "", 0)
|
|
|
|
req.Data["format"] = "hex"
|
|
req.Data["bytes"] = random.APIMaxBytes + 1
|
|
|
|
doRequest(req, true, "", 0)
|
|
}
|
|
}
|