// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package pki import ( "context" "fmt" "reflect" "strings" "testing" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" ) func TestPki_FetchCertBySerial(t *testing.T) { t.Parallel() b, storage := CreateBackendWithStorage(t) sc := b.makeStorageContext(ctx, storage) cases := map[string]struct { Req *logical.Request Prefix string Serial string }{ "valid cert": { &logical.Request{ Storage: storage, }, "certs/", "00:00:00:00:00:00:00:00", }, "revoked cert": { &logical.Request{ Storage: storage, }, "revoked/", "11:11:11:11:11:11:11:11", }, } // Test for colon-based paths in storage for name, tc := range cases { storageKey := fmt.Sprintf("%s%s", tc.Prefix, tc.Serial) err := storage.Put(context.Background(), &logical.StorageEntry{ Key: storageKey, Value: []byte("some data"), }) if err != nil { t.Fatalf("error writing to storage on %s colon-based storage path: %s", name, err) } certEntry, err := fetchCertBySerial(sc, tc.Prefix, tc.Serial) if err != nil { t.Fatalf("error on %s for colon-based storage path: %s", name, err) } // Check for non-nil on valid/revoked certs if certEntry == nil { t.Fatalf("nil on %s for colon-based storage path", name) } // Ensure that cert serials are converted/updated after fetch expectedKey := tc.Prefix + normalizeSerial(tc.Serial) se, err := storage.Get(context.Background(), expectedKey) if err != nil { t.Fatalf("error on %s for colon-based storage path:%s", name, err) } if strings.Compare(expectedKey, se.Key) != 0 { t.Fatalf("expected: %s, got: %s", expectedKey, certEntry.Key) } } // Reset storage storage = &logical.InmemStorage{} // Test for hyphen-base paths in storage for name, tc := range cases { storageKey := tc.Prefix + normalizeSerial(tc.Serial) err := storage.Put(context.Background(), &logical.StorageEntry{ Key: storageKey, Value: []byte("some data"), }) if err != nil { t.Fatalf("error writing to storage on %s hyphen-based storage path: %s", name, err) } certEntry, err := fetchCertBySerial(sc, tc.Prefix, tc.Serial) if err != nil || certEntry == nil { t.Fatalf("error on %s for hyphen-based storage path: err: %v, entry: %v", name, err, certEntry) } } } // Demonstrate that multiple OUs in the name are handled in an // order-preserving way. func TestPki_MultipleOUs(t *testing.T) { t.Parallel() var b backend fields := addCACommonFields(map[string]*framework.FieldSchema{}) apiData := &framework.FieldData{ Schema: fields, Raw: map[string]interface{}{ "cn": "example.com", "ttl": 3600, }, } input := &inputBundle{ apiData: apiData, role: &roleEntry{ MaxTTL: 3600, OU: []string{"Z", "E", "V"}, }, } cb, _, err := generateCreationBundle(&b, input, nil, nil) if err != nil { t.Fatalf("Error: %v", err) } expected := []string{"Z", "E", "V"} actual := cb.Params.Subject.OrganizationalUnit if !reflect.DeepEqual(expected, actual) { t.Fatalf("Expected %v, got %v", expected, actual) } } func TestPki_PermitFQDNs(t *testing.T) { t.Parallel() var b backend fields := addCACommonFields(map[string]*framework.FieldSchema{}) cases := map[string]struct { input *inputBundle expectedDnsNames []string expectedEmails []string }{ "base valid case": { input: &inputBundle{ apiData: &framework.FieldData{ Schema: fields, Raw: map[string]interface{}{ "common_name": "example.com.", "ttl": 3600, }, }, role: &roleEntry{ AllowAnyName: true, MaxTTL: 3600, EnforceHostnames: true, }, }, expectedDnsNames: []string{"example.com."}, expectedEmails: []string{}, }, "case insensitivity validation": { input: &inputBundle{ apiData: &framework.FieldData{ Schema: fields, Raw: map[string]interface{}{ "common_name": "Example.Net", "alt_names": "eXaMPLe.COM", "ttl": 3600, }, }, role: &roleEntry{ AllowedDomains: []string{"example.net", "EXAMPLE.COM"}, AllowBareDomains: true, MaxTTL: 3600, }, }, expectedDnsNames: []string{"Example.Net", "eXaMPLe.COM"}, expectedEmails: []string{}, }, "case email as AllowedDomain with bare domains": { input: &inputBundle{ apiData: &framework.FieldData{ Schema: fields, Raw: map[string]interface{}{ "common_name": "test@testemail.com", "ttl": 3600, }, }, role: &roleEntry{ AllowedDomains: []string{"test@testemail.com"}, AllowBareDomains: true, MaxTTL: 3600, }, }, expectedDnsNames: []string{}, expectedEmails: []string{"test@testemail.com"}, }, "case email common name with bare domains": { input: &inputBundle{ apiData: &framework.FieldData{ Schema: fields, Raw: map[string]interface{}{ "common_name": "test@testemail.com", "ttl": 3600, }, }, role: &roleEntry{ AllowedDomains: []string{"testemail.com"}, AllowBareDomains: true, MaxTTL: 3600, }, }, expectedDnsNames: []string{}, expectedEmails: []string{"test@testemail.com"}, }, } for name, testCase := range cases { name := name testCase := testCase t.Run(name, func(t *testing.T) { cb, _, err := generateCreationBundle(&b, testCase.input, nil, nil) if err != nil { t.Fatalf("Error: %v", err) } actualDnsNames := cb.Params.DNSNames if !reflect.DeepEqual(testCase.expectedDnsNames, actualDnsNames) { t.Fatalf("Expected dns names %v, got %v", testCase.expectedDnsNames, actualDnsNames) } actualEmails := cb.Params.EmailAddresses if !reflect.DeepEqual(testCase.expectedEmails, actualEmails) { t.Fatalf("Expected email addresses %v, got %v", testCase.expectedEmails, actualEmails) } }) } }