open-vault/builtin/logical/pki/cert_util_test.go
2020-02-04 23:46:38 +01:00

193 lines
4.4 KiB
Go

package pki
import (
"context"
"fmt"
"reflect"
"testing"
"strings"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
// TestPki_FetchCertBySerial exercises fetchCertBySerial against three storage
// layouts:
//
//  1. entries stored under the colon-separated serial, which must be readable
//     and must afterwards also be present under the normalizeSerial form;
//  2. entries stored directly under the normalizeSerial (hyphen-based) form;
//  3. the special "ca" and "crl" entries, stored under their bare names with
//     an empty prefix.
func TestPki_FetchCertBySerial(t *testing.T) {
	ctx := context.Background()
	storage := &logical.InmemStorage{}

	cases := map[string]struct {
		Req    *logical.Request
		Prefix string
		Serial string
	}{
		"valid cert": {
			&logical.Request{
				Storage: storage,
			},
			"certs/",
			"00:00:00:00:00:00:00:00",
		},
		"revoked cert": {
			&logical.Request{
				Storage: storage,
			},
			"revoked/",
			"11:11:11:11:11:11:11:11",
		},
	}

	// Test for colon-based paths in storage
	for name, tc := range cases {
		storageKey := fmt.Sprintf("%s%s", tc.Prefix, tc.Serial)
		err := storage.Put(ctx, &logical.StorageEntry{
			Key:   storageKey,
			Value: []byte("some data"),
		})
		if err != nil {
			t.Fatalf("error writing to storage on %s colon-based storage path: %s", name, err)
		}

		certEntry, err := fetchCertBySerial(ctx, tc.Req, tc.Prefix, tc.Serial)
		if err != nil {
			t.Fatalf("error on %s for colon-based storage path: %s", name, err)
		}

		// Check for non-nil on valid/revoked certs
		if certEntry == nil {
			t.Fatalf("nil on %s for colon-based storage path", name)
		}

		// Ensure that cert serials are converted/updated after fetch
		expectedKey := tc.Prefix + normalizeSerial(tc.Serial)
		se, err := storage.Get(ctx, expectedKey)
		if err != nil {
			t.Fatalf("error on %s for colon-based storage path:%s", name, err)
		}
		// A missing key yields a nil entry without an error; fail cleanly
		// instead of dereferencing nil below.
		if se == nil {
			t.Fatalf("nil storage entry on %s for converted storage path %s", name, expectedKey)
		}
		if strings.Compare(expectedKey, se.Key) != 0 {
			t.Fatalf("expected: %s, got: %s", expectedKey, se.Key)
		}
	}

	// Reset storage
	storage = &logical.InmemStorage{}

	// Test for hyphen-based paths in storage
	for name, tc := range cases {
		// The requests in cases were built against the previous storage
		// instance. Re-point them at the fresh one, otherwise the fetch below
		// would read the entries left behind by the colon-based loop and never
		// touch what is written here. tc.Req is a pointer, so this updates the
		// shared request.
		tc.Req.Storage = storage

		storageKey := tc.Prefix + normalizeSerial(tc.Serial)
		err := storage.Put(ctx, &logical.StorageEntry{
			Key:   storageKey,
			Value: []byte("some data"),
		})
		if err != nil {
			t.Fatalf("error writing to storage on %s hyphen-based storage path: %s", name, err)
		}

		certEntry, err := fetchCertBySerial(ctx, tc.Req, tc.Prefix, tc.Serial)
		if err != nil || certEntry == nil {
			t.Fatalf("error on %s for hyphen-based storage path: err: %v, entry: %v", name, err, certEntry)
		}
	}

	// These requests are created after the reset, so they already use the
	// current storage instance.
	noConvCases := map[string]struct {
		Req    *logical.Request
		Prefix string
		Serial string
	}{
		"ca": {
			&logical.Request{
				Storage: storage,
			},
			"",
			"ca",
		},
		"crl": {
			&logical.Request{
				Storage: storage,
			},
			"",
			"crl",
		},
	}

	// Test for ca and crl case
	for name, tc := range noConvCases {
		err := storage.Put(ctx, &logical.StorageEntry{
			Key:   tc.Serial,
			Value: []byte("some data"),
		})
		if err != nil {
			t.Fatalf("error writing to storage on %s: %s", name, err)
		}

		certEntry, err := fetchCertBySerial(ctx, tc.Req, tc.Prefix, tc.Serial)
		if err != nil || certEntry == nil {
			t.Fatalf("error on %s: err: %v, entry: %v", name, err, certEntry)
		}
	}
}
// TestPki_MultipleOUs checks that when a role lists several organizational
// units, generateCreationBundle places them in the subject in the same order
// in which the role declared them.
func TestPki_MultipleOUs(t *testing.T) {
	wantOUs := []string{"Z", "E", "V"}

	b := new(backend)
	schema := addCACommonFields(map[string]*framework.FieldSchema{})

	bundleInput := &inputBundle{
		apiData: &framework.FieldData{
			Schema: schema,
			Raw: map[string]interface{}{
				// NOTE(review): TestPki_PermitFQDNs passes this value under
				// "common_name"; confirm whether "cn" is intentional here.
				"cn":  "example.com",
				"ttl": 3600,
			},
		},
		role: &roleEntry{
			MaxTTL: 3600,
			// Deliberately a separate literal from wantOUs so the comparison
			// below cannot be satisfied through a shared backing array.
			OU: []string{"Z", "E", "V"},
		},
	}

	bundle, err := generateCreationBundle(b, bundleInput, nil, nil)
	if err != nil {
		t.Fatalf("Error: %v", err)
	}

	if gotOUs := bundle.Params.Subject.OrganizationalUnit; !reflect.DeepEqual(wantOUs, gotOUs) {
		t.Fatalf("Expected %v, got %v", wantOUs, gotOUs)
	}
}
// TestPki_PermitFQDNs checks that a common name written with a trailing dot
// ("example.com.") is accepted by generateCreationBundle for a role that
// allows any name and enforces hostnames, and that it is reported unchanged
// as the only DNS name in the resulting bundle.
func TestPki_PermitFQDNs(t *testing.T) {
	wantDNSNames := []string{"example.com."}

	b := new(backend)
	schema := addCACommonFields(map[string]*framework.FieldSchema{})

	bundleInput := &inputBundle{
		apiData: &framework.FieldData{
			Schema: schema,
			Raw: map[string]interface{}{
				"common_name": "example.com.",
				"ttl":         3600,
			},
		},
		role: &roleEntry{
			AllowAnyName:     true,
			MaxTTL:           3600,
			EnforceHostnames: true,
		},
	}

	bundle, err := generateCreationBundle(b, bundleInput, nil, nil)
	if err != nil {
		t.Fatalf("Error: %v", err)
	}

	if gotDNSNames := bundle.Params.DNSNames; !reflect.DeepEqual(wantDNSNames, gotDNSNames) {
		t.Fatalf("Expected %v, got %v", wantDNSNames, gotDNSNames)
	}
}