open-vault/builtin/logical/pki/backend_test.go

3285 lines
92 KiB
Go
Raw Normal View History

package pki
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"io/ioutil"
"math"
"math/big"
mathrand "math/rand"
2015-10-14 15:46:01 +00:00
"net"
2018-06-15 19:32:25 +00:00
"net/url"
"os"
2015-10-14 15:46:01 +00:00
"reflect"
"sort"
"strconv"
"strings"
2018-02-20 05:03:45 +00:00
"sync"
"testing"
"time"
"github.com/fatih/structs"
"github.com/go-test/deep"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/credential/userpass"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/mapstructure"
2018-02-10 15:07:10 +00:00
"golang.org/x/net/idna"
)
var (
stepCount = 0
serialUnderTest string
parsedKeyUsageUnderTest int
)
func TestPKI_RequireCN(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
var err error
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "32h",
},
})
if err != nil {
t.Fatal(err)
}
resp, err := client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected ca info")
}
// Create a role which does require CN (default)
_, err = client.Logical().Write("pki/roles/example", map[string]interface{}{
"allowed_domains": "foobar.com,zipzap.com,abc.com,xyz.com",
"allow_bare_domains": true,
"allow_subdomains": true,
"max_ttl": "2h",
})
if err != nil {
t.Fatal(err)
}
// Issue a cert with require_cn set to true and with common name supplied.
// It should succeed.
resp, err = client.Logical().Write("pki/issue/example", map[string]interface{}{
"common_name": "foobar.com",
})
if err != nil {
t.Fatal(err)
}
// Issue a cert with require_cn set to true and with out supplying the
// common name. It should error out.
resp, err = client.Logical().Write("pki/issue/example", map[string]interface{}{})
if err == nil {
t.Fatalf("expected an error due to missing common_name")
}
// Modify the role to make the common name optional
_, err = client.Logical().Write("pki/roles/example", map[string]interface{}{
"allowed_domains": "foobar.com,zipzap.com,abc.com,xyz.com",
"allow_bare_domains": true,
"allow_subdomains": true,
"max_ttl": "2h",
"require_cn": false,
})
if err != nil {
t.Fatal(err)
}
// Issue a cert with require_cn set to false and without supplying the
// common name. It should succeed.
resp, err = client.Logical().Write("pki/issue/example", map[string]interface{}{})
if err != nil {
t.Fatal(err)
}
if resp.Data["certificate"] == "" {
t.Fatalf("expected a cert to be generated")
}
// Issue a cert with require_cn set to false and with a common name. It
// should succeed.
resp, err = client.Logical().Write("pki/issue/example", map[string]interface{}{})
if err != nil {
t.Fatal(err)
}
if resp.Data["certificate"] == "" {
t.Fatalf("expected a cert to be generated")
}
}
2015-10-14 15:46:01 +00:00
func TestBackend_CSRValues(t *testing.T) {
2018-02-20 05:03:45 +00:00
initTest.Do(setCerts)
2015-10-14 15:46:01 +00:00
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(context.Background(), &logical.BackendConfig{
2015-10-14 15:46:01 +00:00
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
})
if err != nil {
t.Fatalf("Unable to create backend: %s", err)
}
testCase := logicaltest.TestCase{
2018-11-07 01:21:24 +00:00
LogicalBackend: b,
Steps: []logicaltest.TestStep{},
2015-10-14 15:46:01 +00:00
}
intdata := map[string]interface{}{}
reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCSRSteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
logicaltest.Test(t, testCase)
}
func TestBackend_URLsCRUD(t *testing.T) {
2018-02-20 05:03:45 +00:00
initTest.Do(setCerts)
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
})
if err != nil {
t.Fatalf("Unable to create backend: %s", err)
}
testCase := logicaltest.TestCase{
2018-11-07 01:21:24 +00:00
LogicalBackend: b,
Steps: []logicaltest.TestStep{},
}
intdata := map[string]interface{}{}
reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateURLSteps(t, ecCACert, ecCAKey, intdata, reqdata)...)
logicaltest.Test(t, testCase)
}
// Generates and tests steps that walk through the various possibilities
// of role flags to ensure that they are properly restricted
func TestBackend_Roles(t *testing.T) {
cases := []struct {
name string
key, cert *string
useCSR bool
}{
{"RSA", &rsaCAKey, &rsaCACert, false},
{"RSACSR", &rsaCAKey, &rsaCACert, true},
{"EC", &ecCAKey, &ecCACert, false},
{"ECCSR", &ecCAKey, &ecCACert, true},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
initTest.Do(setCerts)
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
})
if err != nil {
t.Fatalf("Unable to create backend: %s", err)
}
testCase := logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: map[string]interface{}{
"pem_bundle": *tc.key + "\n" + *tc.cert,
},
},
},
}
testCase.Steps = append(testCase.Steps, generateRoleSteps(t, tc.useCSR)...)
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
for i, v := range testCase.Steps {
data := map[string]interface{}{}
var keys []string
for k := range v.Data {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
interf := v.Data[k]
switch v := interf.(type) {
case bool:
if !v {
continue
}
case int:
if v == 0 {
continue
}
case []string:
if len(v) == 0 {
continue
}
case string:
if v == "" {
continue
}
lines := strings.Split(v, "\n")
if len(lines) > 1 {
data[k] = lines[0] + " ... (truncated)"
continue
}
}
data[k] = interf
}
t.Logf("Step %d:\n%s %s err=%v %+v\n\n", i+1, v.Operation, v.Path, v.ErrorOk, data)
}
}
logicaltest.Test(t, testCase)
})
}
}
// Performs some validity checking on the returned bundles
func checkCertsAndPrivateKey(keyType string, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration, certBundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("error parsing cert bundle: %s", err)
}
if key != nil {
switch keyType {
case "rsa":
parsedCertBundle.PrivateKeyType = certutil.RSAPrivateKey
parsedCertBundle.PrivateKey = key
parsedCertBundle.PrivateKeyBytes = x509.MarshalPKCS1PrivateKey(key.(*rsa.PrivateKey))
case "ec":
parsedCertBundle.PrivateKeyType = certutil.ECPrivateKey
parsedCertBundle.PrivateKey = key
parsedCertBundle.PrivateKeyBytes, err = x509.MarshalECPrivateKey(key.(*ecdsa.PrivateKey))
if err != nil {
return nil, fmt.Errorf("error parsing EC key: %s", err)
}
}
}
switch {
case parsedCertBundle.Certificate == nil:
return nil, fmt.Errorf("did not find a certificate in the cert bundle")
case len(parsedCertBundle.CAChain) == 0 || parsedCertBundle.CAChain[0].Certificate == nil:
return nil, fmt.Errorf("did not find a CA in the cert bundle")
case parsedCertBundle.PrivateKey == nil:
return nil, fmt.Errorf("did not find a private key in the cert bundle")
case parsedCertBundle.PrivateKeyType == certutil.UnknownPrivateKey:
return nil, fmt.Errorf("could not figure out type of private key")
}
switch {
case parsedCertBundle.PrivateKeyType == certutil.RSAPrivateKey && keyType != "rsa":
fallthrough
case parsedCertBundle.PrivateKeyType == certutil.ECPrivateKey && keyType != "ec":
return nil, fmt.Errorf("given key type does not match type found in bundle")
}
cert := parsedCertBundle.Certificate
if usage != cert.KeyUsage {
return nil, fmt.Errorf("expected usage of %#v, got %#v; ext usage is %#v", usage, cert.KeyUsage, cert.ExtKeyUsage)
}
// There should only be one ext usage type, because only one is requested
// in the tests
if len(cert.ExtKeyUsage) != 1 {
return nil, fmt.Errorf("got wrong size key usage in generated cert; expected 1, values are %#v", cert.ExtKeyUsage)
}
switch extUsage {
case x509.ExtKeyUsageEmailProtection:
if cert.ExtKeyUsage[0] != x509.ExtKeyUsageEmailProtection {
return nil, fmt.Errorf("bad extended key usage")
}
case x509.ExtKeyUsageServerAuth:
if cert.ExtKeyUsage[0] != x509.ExtKeyUsageServerAuth {
return nil, fmt.Errorf("bad extended key usage")
}
case x509.ExtKeyUsageClientAuth:
if cert.ExtKeyUsage[0] != x509.ExtKeyUsageClientAuth {
return nil, fmt.Errorf("bad extended key usage")
}
case x509.ExtKeyUsageCodeSigning:
if cert.ExtKeyUsage[0] != x509.ExtKeyUsageCodeSigning {
return nil, fmt.Errorf("bad extended key usage")
}
}
if math.Abs(float64(time.Now().Add(validity).Unix()-cert.NotAfter.Unix())) > 20 {
return nil, fmt.Errorf("certificate validity end: %s; expected within 20 seconds of %s", cert.NotAfter.Format(time.RFC3339), time.Now().Add(validity).Format(time.RFC3339))
}
return parsedCertBundle, nil
}
func generateURLSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
expected := certutil.URLEntries{
IssuingCertificates: []string{
"http://example.com/ca1",
"http://example.com/ca2",
},
CRLDistributionPoints: []string{
"http://example.com/crl1",
"http://example.com/crl2",
},
OCSPServers: []string{
"http://example.com/ocsp1",
"http://example.com/ocsp2",
},
}
csrTemplate := x509.CertificateRequest{
Subject: pkix.Name{
CommonName: "my@example.com",
},
}
priv1024, _ := rsa.GenerateKey(rand.Reader, 1024)
csr1024, _ := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, priv1024)
csrPem1024 := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csr1024,
})))
priv2048, _ := rsa.GenerateKey(rand.Reader, 2048)
csr2048, _ := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, priv2048)
csrPem2048 := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csr2048,
})))
ret := []logicaltest.TestStep{
{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "root/generate/exported",
Data: map[string]interface{}{
"common_name": "Root Cert",
"ttl": "180h",
},
Check: func(resp *logical.Response) error {
if resp.Secret != nil && resp.Secret.LeaseID != "" {
return fmt.Errorf("root returned with a lease")
}
return nil
},
},
{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "config/urls",
Data: map[string]interface{}{
"issuing_certificates": strings.Join(expected.IssuingCertificates, ","),
"crl_distribution_points": strings.Join(expected.CRLDistributionPoints, ","),
"ocsp_servers": strings.Join(expected.OCSPServers, ","),
},
},
{
Operation: logical.ReadOperation,
Path: "config/urls",
Check: func(resp *logical.Response) error {
if resp.Data == nil {
return fmt.Errorf("no data returned")
}
var entries certutil.URLEntries
err := mapstructure.Decode(resp.Data, &entries)
if err != nil {
return err
}
if !reflect.DeepEqual(entries, expected) {
return fmt.Errorf("expected urls\n%#v\ndoes not match provided\n%#v\n", expected, entries)
}
return nil
},
},
{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "root/sign-intermediate",
Data: map[string]interface{}{
2018-02-10 15:07:10 +00:00
"common_name": "intermediate.cert.com",
"csr": csrPem1024,
"format": "der",
},
ErrorOk: true,
Check: func(resp *logical.Response) error {
if !resp.IsError() {
return fmt.Errorf("expected an error response but did not get one")
}
if !strings.Contains(resp.Data["error"].(string), "2048") {
2018-03-20 18:54:10 +00:00
return fmt.Errorf("received an error but not about a 1024-bit key, error was: %s", resp.Data["error"].(string))
}
return nil
},
},
{
Operation: logical.UpdateOperation,
Path: "root/sign-intermediate",
Data: map[string]interface{}{
2018-02-10 15:07:10 +00:00
"common_name": "intermediate.cert.com",
"csr": csrPem2048,
"format": "der",
},
Check: func(resp *logical.Response) error {
certString := resp.Data["certificate"].(string)
if certString == "" {
return fmt.Errorf("no certificate returned")
}
if resp.Secret != nil && resp.Secret.LeaseID != "" {
return fmt.Errorf("signed intermediate returned with a lease")
}
certBytes, _ := base64.StdEncoding.DecodeString(certString)
certs, err := x509.ParseCertificates(certBytes)
if err != nil {
return fmt.Errorf("returned cert cannot be parsed: %v", err)
}
if len(certs) != 1 {
return fmt.Errorf("unexpected returned length of certificates: %d", len(certs))
}
cert := certs[0]
switch {
case !reflect.DeepEqual(expected.IssuingCertificates, cert.IssuingCertificateURL):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", expected.IssuingCertificates, cert.IssuingCertificateURL)
case !reflect.DeepEqual(expected.CRLDistributionPoints, cert.CRLDistributionPoints):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", expected.CRLDistributionPoints, cert.CRLDistributionPoints)
case !reflect.DeepEqual(expected.OCSPServers, cert.OCSPServer):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", expected.OCSPServers, cert.OCSPServer)
2018-02-10 15:07:10 +00:00
case !reflect.DeepEqual([]string{"intermediate.cert.com"}, cert.DNSNames):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", []string{"intermediate.cert.com"}, cert.DNSNames)
}
return nil
},
},
// Same as above but exclude adding to sans
{
Operation: logical.UpdateOperation,
Path: "root/sign-intermediate",
Data: map[string]interface{}{
2018-02-10 15:07:10 +00:00
"common_name": "intermediate.cert.com",
"csr": csrPem2048,
"format": "der",
"exclude_cn_from_sans": true,
},
Check: func(resp *logical.Response) error {
certString := resp.Data["certificate"].(string)
if certString == "" {
return fmt.Errorf("no certificate returned")
}
if resp.Secret != nil && resp.Secret.LeaseID != "" {
return fmt.Errorf("signed intermediate returned with a lease")
}
certBytes, _ := base64.StdEncoding.DecodeString(certString)
certs, err := x509.ParseCertificates(certBytes)
if err != nil {
return fmt.Errorf("returned cert cannot be parsed: %v", err)
}
if len(certs) != 1 {
return fmt.Errorf("unexpected returned length of certificates: %d", len(certs))
}
cert := certs[0]
switch {
case !reflect.DeepEqual(expected.IssuingCertificates, cert.IssuingCertificateURL):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", expected.IssuingCertificates, cert.IssuingCertificateURL)
case !reflect.DeepEqual(expected.CRLDistributionPoints, cert.CRLDistributionPoints):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", expected.CRLDistributionPoints, cert.CRLDistributionPoints)
case !reflect.DeepEqual(expected.OCSPServers, cert.OCSPServer):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", expected.OCSPServers, cert.OCSPServer)
case !reflect.DeepEqual([]string(nil), cert.DNSNames):
return fmt.Errorf("expected\n%#v\ngot\n%#v\n", []string(nil), cert.DNSNames)
}
return nil
},
},
}
return ret
}
2015-10-14 15:46:01 +00:00
func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
csrTemplate := x509.CertificateRequest{
Subject: pkix.Name{
Country: []string{"MyCountry"},
PostalCode: []string{"MyPostalCode"},
SerialNumber: "MySerialNumber",
CommonName: "my@example.com",
},
DNSNames: []string{
"name1.example.com",
"name2.example.com",
"name3.example.com",
},
EmailAddresses: []string{
"name1@example.com",
"name2@example.com",
"name3@example.com",
},
IPAddresses: []net.IP{
net.ParseIP("::ff:1:2:3:4"),
net.ParseIP("::ff:5:6:7:8"),
},
}
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
csr, _ := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, priv)
csrPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
2015-10-14 15:46:01 +00:00
Type: "CERTIFICATE REQUEST",
Bytes: csr,
})))
2015-10-14 15:46:01 +00:00
ret := []logicaltest.TestStep{
{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "root/generate/exported",
2015-10-14 15:46:01 +00:00
Data: map[string]interface{}{
"common_name": "Root Cert",
"ttl": "180h",
"max_path_length": 0,
},
},
{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "root/sign-intermediate",
Data: map[string]interface{}{
"use_csr_values": true,
"csr": csrPem,
"format": "der",
},
ErrorOk: true,
},
{
Operation: logical.DeleteOperation,
Path: "root",
},
{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "root/generate/exported",
Data: map[string]interface{}{
"common_name": "Root Cert",
"ttl": "180h",
"max_path_length": 1,
2015-10-14 15:46:01 +00:00
},
},
{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "root/sign-intermediate",
2015-10-14 15:46:01 +00:00
Data: map[string]interface{}{
"use_csr_values": true,
"csr": csrPem,
2015-10-14 15:46:01 +00:00
"format": "der",
},
Check: func(resp *logical.Response) error {
certString := resp.Data["certificate"].(string)
if certString == "" {
return fmt.Errorf("no certificate returned")
}
certBytes, _ := base64.StdEncoding.DecodeString(certString)
certs, err := x509.ParseCertificates(certBytes)
if err != nil {
return fmt.Errorf("returned cert cannot be parsed: %v", err)
}
if len(certs) != 1 {
return fmt.Errorf("unexpected returned length of certificates: %d", len(certs))
}
cert := certs[0]
if cert.MaxPathLen != 0 {
return fmt.Errorf("max path length of %d does not match the requested of 3", cert.MaxPathLen)
}
if !cert.MaxPathLenZero {
return fmt.Errorf("max path length zero is not set")
}
2015-10-14 15:46:01 +00:00
// We need to set these as they are filled in with unparsed values in the final cert
csrTemplate.Subject.Names = cert.Subject.Names
csrTemplate.Subject.ExtraNames = cert.Subject.ExtraNames
2015-10-14 15:46:01 +00:00
switch {
case !reflect.DeepEqual(cert.Subject, csrTemplate.Subject):
return fmt.Errorf("cert subject\n%#v\ndoes not match csr subject\n%#v\n", cert.Subject, csrTemplate.Subject)
case !reflect.DeepEqual(cert.DNSNames, csrTemplate.DNSNames):
return fmt.Errorf("cert dns names\n%#v\ndoes not match csr dns names\n%#v\n", cert.DNSNames, csrTemplate.DNSNames)
case !reflect.DeepEqual(cert.EmailAddresses, csrTemplate.EmailAddresses):
return fmt.Errorf("cert email addresses\n%#v\ndoes not match csr email addresses\n%#v\n", cert.EmailAddresses, csrTemplate.EmailAddresses)
case !reflect.DeepEqual(cert.IPAddresses, csrTemplate.IPAddresses):
return fmt.Errorf("cert ip addresses\n%#v\ndoes not match csr ip addresses\n%#v\n", cert.IPAddresses, csrTemplate.IPAddresses)
}
return nil
},
},
}
return ret
}
// Generates steps to test out various role permutations
func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
roleVals := roleEntry{
MaxTTL: 12 * time.Hour,
2018-04-09 19:19:05 +00:00
KeyType: "rsa",
KeyBits: 2048,
RequireCN: true,
SignatureBits: 256,
}
issueVals := certutil.IssueData{}
ret := []logicaltest.TestStep{}
roleTestStep := logicaltest.TestStep{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "roles/test",
}
var issueTestStep logicaltest.TestStep
if useCSRs {
issueTestStep = logicaltest.TestStep{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "sign/test",
}
} else {
issueTestStep = logicaltest.TestStep{
2016-01-07 15:30:47 +00:00
Operation: logical.UpdateOperation,
Path: "issue/test",
}
}
generatedRSAKeys := map[int]crypto.Signer{}
generatedECKeys := map[int]crypto.Signer{}
/*
// For the number of tests being run, a seed of 1 has been tested
// to hit all of the various values below. However, for normal
// testing we use a randomized time for maximum fuzziness.
*/
var seed int64 = 1
fixedSeed := os.Getenv("VAULT_PKITESTS_FIXED_SEED")
if len(fixedSeed) == 0 {
seed = time.Now().UnixNano()
} else {
var err error
seed, err = strconv.ParseInt(fixedSeed, 10, 64)
if err != nil {
t.Fatalf("error parsing fixed seed of %s: %v", fixedSeed, err)
}
}
mathRand := mathrand.New(mathrand.NewSource(seed))
// t.Logf("seed under test: %v", seed)
// Used by tests not toggling common names to turn off the behavior of random key bit fuzziness
keybitSizeRandOff := false
genericErrorOkCheck := func(resp *logical.Response) error {
if resp.IsError() {
return nil
}
return fmt.Errorf("expected an error, but did not seem to get one")
}
// Adds tests with the currently configured issue/role information
addTests := func(testCheck logicaltest.TestCheckFunc) {
2016-07-19 17:54:18 +00:00
stepCount++
// t.Logf("test step %d\nrole vals: %#v\n", stepCount, roleVals)
2016-07-19 17:54:18 +00:00
stepCount++
// t.Logf("test step %d\nissue vals: %#v\n", stepCount, issueTestStep)
roleTestStep.Data = roleVals.ToResponseData()
roleTestStep.Data["generate_lease"] = false
ret = append(ret, roleTestStep)
issueTestStep.Data = structs.New(issueVals).Map()
switch {
case issueTestStep.ErrorOk:
issueTestStep.Check = genericErrorOkCheck
case testCheck != nil:
issueTestStep.Check = testCheck
default:
issueTestStep.Check = nil
}
ret = append(ret, issueTestStep)
}
getCountryCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.RemoveDuplicates(role.Country, true)
if !reflect.DeepEqual(cert.Subject.Country, expected) {
return fmt.Errorf("error: returned certificate has Country of %s but %s was specified in the role", cert.Subject.Country, expected)
}
return nil
}
}
getOuCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.RemoveDuplicatesStable(role.OU, true)
if !reflect.DeepEqual(cert.Subject.OrganizationalUnit, expected) {
return fmt.Errorf("error: returned certificate has OU of %s but %s was specified in the role", cert.Subject.OrganizationalUnit, expected)
}
return nil
}
}
getOrganizationCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.RemoveDuplicates(role.Organization, true)
if !reflect.DeepEqual(cert.Subject.Organization, expected) {
return fmt.Errorf("error: returned certificate has Organization of %s but %s was specified in the role", cert.Subject.Organization, expected)
}
return nil
}
}
getLocalityCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.RemoveDuplicates(role.Locality, true)
if !reflect.DeepEqual(cert.Subject.Locality, expected) {
return fmt.Errorf("error: returned certificate has Locality of %s but %s was specified in the role", cert.Subject.Locality, expected)
}
return nil
}
}
getProvinceCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.RemoveDuplicates(role.Province, true)
if !reflect.DeepEqual(cert.Subject.Province, expected) {
return fmt.Errorf("error: returned certificate has Province of %s but %s was specified in the role", cert.Subject.Province, expected)
}
return nil
}
}
getStreetAddressCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.RemoveDuplicates(role.StreetAddress, true)
if !reflect.DeepEqual(cert.Subject.StreetAddress, expected) {
return fmt.Errorf("error: returned certificate has StreetAddress of %s but %s was specified in the role", cert.Subject.StreetAddress, expected)
}
return nil
}
}
getPostalCodeCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
expected := strutil.RemoveDuplicates(role.PostalCode, true)
if !reflect.DeepEqual(cert.Subject.PostalCode, expected) {
return fmt.Errorf("error: returned certificate has PostalCode of %s but %s was specified in the role", cert.Subject.PostalCode, expected)
}
return nil
}
}
getNotBeforeCheck := func(role roleEntry) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
actualDiff := time.Now().Sub(cert.NotBefore)
certRoleDiff := (role.NotBeforeDuration - actualDiff).Truncate(time.Second)
// These times get truncated, so give a 1 second buffer on each side
if certRoleDiff >= -1*time.Second && certRoleDiff <= 1*time.Second {
return nil
}
return fmt.Errorf("validity period out of range diff: %v", certRoleDiff)
}
}
// Returns a TestCheckFunc that performs various validity checks on the
// returned certificate information, mostly within checkCertsAndPrivateKey
getCnCheck := func(name string, role roleEntry, key crypto.Signer, usage x509.KeyUsage, extUsage x509.ExtKeyUsage, validity time.Duration) logicaltest.TestCheckFunc {
var certBundle certutil.CertBundle
return func(resp *logical.Response) error {
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := checkCertsAndPrivateKey(role.KeyType, key, usage, extUsage, validity, &certBundle)
if err != nil {
return fmt.Errorf("error checking generated certificate: %s", err)
}
cert := parsedCertBundle.Certificate
if cert.Subject.CommonName != name {
return fmt.Errorf("error: returned certificate has CN of %s but %s was requested", cert.Subject.CommonName, name)
}
if strings.Contains(cert.Subject.CommonName, "@") {
if len(cert.DNSNames) != 0 || len(cert.EmailAddresses) != 1 {
return fmt.Errorf("error: found more than one DNS SAN or not one Email SAN but only one was requested, cert.DNSNames = %#v, cert.EmailAddresses = %#v", cert.DNSNames, cert.EmailAddresses)
}
} else {
if len(cert.DNSNames) != 1 || len(cert.EmailAddresses) != 0 {
return fmt.Errorf("error: found more than one Email SAN or not one DNS SAN but only one was requested, cert.DNSNames = %#v, cert.EmailAddresses = %#v", cert.DNSNames, cert.EmailAddresses)
}
}
var retName string
if len(cert.DNSNames) > 0 {
retName = cert.DNSNames[0]
}
if len(cert.EmailAddresses) > 0 {
retName = cert.EmailAddresses[0]
}
if retName != name {
2018-02-10 15:07:10 +00:00
// Check IDNA
p := idna.New(
idna.StrictDomainName(true),
idna.VerifyDNSLength(true),
)
converted, err := p.ToUnicode(retName)
if err != nil {
t.Fatal(err)
}
if converted != name {
return fmt.Errorf("error: returned certificate has a DNS SAN of %s (from idna: %s) but %s was requested", retName, converted, name)
2018-02-10 15:07:10 +00:00
}
}
return nil
}
}
type csrPlan struct {
errorOk bool
roleKeyBits int
cert string
privKey crypto.Signer
}
getCsr := func(keyType string, keyBits int, csrTemplate *x509.CertificateRequest) (*pem.Block, crypto.Signer) {
var privKey crypto.Signer
var ok bool
switch keyType {
case "rsa":
privKey, ok = generatedRSAKeys[keyBits]
if !ok {
privKey, _ = rsa.GenerateKey(rand.Reader, keyBits)
generatedRSAKeys[keyBits] = privKey
}
case "ec":
var curve elliptic.Curve
switch keyBits {
case 224:
curve = elliptic.P224()
case 256:
curve = elliptic.P256()
case 384:
curve = elliptic.P384()
case 521:
curve = elliptic.P521()
}
privKey, ok = generatedECKeys[keyBits]
if !ok {
privKey, _ = ecdsa.GenerateKey(curve, rand.Reader)
generatedECKeys[keyBits] = privKey
}
default:
panic("invalid key type: " + keyType)
}
csr, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, privKey)
if err != nil {
t.Fatalf("Error creating certificate request: %s", err)
}
block := pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csr,
}
return &block, privKey
}
getRandCsr := func(keyType string, errorOk bool, csrTemplate *x509.CertificateRequest) csrPlan {
rsaKeyBits := []int{2048, 4096}
ecKeyBits := []int{224, 256, 384, 521}
plan := csrPlan{errorOk: errorOk}
var testBitSize int
switch keyType {
case "rsa":
plan.roleKeyBits = rsaKeyBits[mathRand.Int()%2]
testBitSize = plan.roleKeyBits
// If we don't expect an error already, randomly choose a
// key size and expect an error if it's less than the role
// setting
if !keybitSizeRandOff && !errorOk {
testBitSize = rsaKeyBits[mathRand.Int()%2]
}
if testBitSize < plan.roleKeyBits {
plan.errorOk = true
}
case "ec":
plan.roleKeyBits = ecKeyBits[mathRand.Int()%4]
testBitSize = plan.roleKeyBits
// If we don't expect an error already, randomly choose a
// key size and expect an error if it's less than the role
// setting
if !keybitSizeRandOff && !errorOk {
testBitSize = ecKeyBits[mathRand.Int()%4]
}
if testBitSize < plan.roleKeyBits {
plan.errorOk = true
}
default:
panic("invalid key type: " + keyType)
}
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
t.Logf("roleKeyBits=%d testBitSize=%d errorOk=%v", plan.roleKeyBits, testBitSize, plan.errorOk)
}
block, privKey := getCsr(keyType, testBitSize, csrTemplate)
plan.cert = strings.TrimSpace(string(pem.EncodeToMemory(block)))
plan.privKey = privKey
return plan
}
// Common names to test with the various role flags toggled
var commonNames struct {
Localhost bool `structs:"localhost"`
BareDomain bool `structs:"example.com"`
SecondDomain bool `structs:"foobar.com"`
SubDomain bool `structs:"foo.example.com"`
Wildcard bool `structs:"*.example.com"`
SubSubdomain bool `structs:"foo.bar.example.com"`
SubSubdomainWildcard bool `structs:"*.bar.example.com"`
GlobDomain bool `structs:"fooexample.com"`
2018-02-10 15:07:10 +00:00
IDN bool `structs:"daɪˈɛrɨsɨs"`
AnyHost bool `structs:"porkslap.beer"`
}
// Adds a series of tests based on the current selection of
// allowed common names; contains some (seeded) randomness
//
// This allows for a variety of common names to be tested in various
// combinations with allowed toggles of the role
addCnTests := func() {
cnMap := structs.New(commonNames).Map()
for name, allowedInt := range cnMap {
roleVals.KeyType = "rsa"
roleVals.KeyBits = 2048
if mathRand.Int()%2 == 1 {
roleVals.KeyType = "ec"
roleVals.KeyBits = 224
}
roleVals.ServerFlag = false
roleVals.ClientFlag = false
roleVals.CodeSigningFlag = false
roleVals.EmailProtectionFlag = false
var usage []string
if mathRand.Int()%2 == 1 {
usage = append(usage, "DigitalSignature")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "ContentCoMmitment")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "KeyEncipherment")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "DataEncipherment")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "KeyAgreemEnt")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "CertSign")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "CRLSign")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "EncipherOnly")
}
if mathRand.Int()%2 == 1 {
usage = append(usage, "DecipherOnly")
}
roleVals.KeyUsage = usage
parsedKeyUsage := parseKeyUsages(roleVals.KeyUsage)
if parsedKeyUsage == 0 && len(usage) != 0 {
2016-06-23 13:49:03 +00:00
panic("parsed key usages was zero")
}
parsedKeyUsageUnderTest = parsedKeyUsage
var extUsage x509.ExtKeyUsage
i := mathRand.Int() % 4
switch {
case i == 0:
2018-02-20 05:03:45 +00:00
// Punt on this for now since I'm not clear the actual proper
// way to format these
if name != "daɪˈɛrɨsɨs" {
extUsage = x509.ExtKeyUsageEmailProtection
roleVals.EmailProtectionFlag = true
break
}
fallthrough
case i == 1:
extUsage = x509.ExtKeyUsageServerAuth
roleVals.ServerFlag = true
case i == 2:
extUsage = x509.ExtKeyUsageClientAuth
roleVals.ClientFlag = true
default:
extUsage = x509.ExtKeyUsageCodeSigning
roleVals.CodeSigningFlag = true
}
allowed := allowedInt.(bool)
issueVals.CommonName = name
if roleVals.EmailProtectionFlag {
if !strings.HasPrefix(name, "*") {
issueVals.CommonName = "user@" + issueVals.CommonName
}
}
issueTestStep.ErrorOk = !allowed
validity := roleVals.MaxTTL
if useCSRs {
templ := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: issueVals.CommonName,
},
}
plan := getRandCsr(roleVals.KeyType, issueTestStep.ErrorOk, templ)
issueVals.CSR = plan.cert
roleVals.KeyBits = plan.roleKeyBits
issueTestStep.ErrorOk = plan.errorOk
addTests(getCnCheck(issueVals.CommonName, roleVals, plan.privKey, x509.KeyUsage(parsedKeyUsage), extUsage, validity))
} else {
addTests(getCnCheck(issueVals.CommonName, roleVals, nil, x509.KeyUsage(parsedKeyUsage), extUsage, validity))
}
}
}
funcs := []interface{}{
addCnTests, getCnCheck, getCountryCheck, getLocalityCheck, getNotBeforeCheck,
getOrganizationCheck, getOuCheck, getPostalCodeCheck, getRandCsr, getStreetAddressCheck,
getProvinceCheck,
}
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
t.Logf("funcs=%d", len(funcs))
}
// Common Name tests
{
// common_name not provided
issueVals.CommonName = ""
issueTestStep.ErrorOk = true
addTests(nil)
// Nothing is allowed
addCnTests()
roleVals.AllowLocalhost = true
commonNames.Localhost = true
addCnTests()
roleVals.AllowedDomains = []string{"foobar.com"}
addCnTests()
roleVals.AllowedDomains = []string{"example.com"}
roleVals.AllowSubdomains = true
commonNames.SubDomain = true
commonNames.Wildcard = true
commonNames.SubSubdomain = true
commonNames.SubSubdomainWildcard = true
addCnTests()
roleVals.AllowedDomains = []string{"foobar.com", "example.com"}
commonNames.SecondDomain = true
roleVals.AllowBareDomains = true
commonNames.BareDomain = true
addCnTests()
roleVals.AllowedDomains = []string{"foobar.com", "*example.com"}
roleVals.AllowGlobDomains = true
commonNames.GlobDomain = true
addCnTests()
roleVals.AllowAnyName = true
roleVals.EnforceHostnames = true
commonNames.AnyHost = true
2018-02-10 15:07:10 +00:00
commonNames.IDN = true
addCnTests()
roleVals.EnforceHostnames = false
addCnTests()
// Ensure that we end up with acceptable key sizes since they won't be
// toggled any longer
keybitSizeRandOff = true
addCnTests()
}
// Country tests
{
roleVals.Country = []string{"foo"}
addTests(getCountryCheck(roleVals))
roleVals.Country = []string{"foo", "bar"}
addTests(getCountryCheck(roleVals))
}
// OU tests
{
roleVals.OU = []string{"foo"}
addTests(getOuCheck(roleVals))
roleVals.OU = []string{"bar", "foo"}
addTests(getOuCheck(roleVals))
}
// Organization tests
{
roleVals.Organization = []string{"system:masters"}
addTests(getOrganizationCheck(roleVals))
roleVals.Organization = []string{"foo", "bar"}
addTests(getOrganizationCheck(roleVals))
}
// Locality tests
{
roleVals.Locality = []string{"foo"}
addTests(getLocalityCheck(roleVals))
roleVals.Locality = []string{"foo", "bar"}
addTests(getLocalityCheck(roleVals))
}
// Province tests
{
roleVals.Province = []string{"foo"}
addTests(getProvinceCheck(roleVals))
roleVals.Province = []string{"foo", "bar"}
addTests(getProvinceCheck(roleVals))
}
// StreetAddress tests
{
roleVals.StreetAddress = []string{"123 foo street"}
addTests(getStreetAddressCheck(roleVals))
roleVals.StreetAddress = []string{"123 foo street", "456 bar avenue"}
addTests(getStreetAddressCheck(roleVals))
}
// PostalCode tests
{
roleVals.PostalCode = []string{"f00"}
addTests(getPostalCodeCheck(roleVals))
roleVals.PostalCode = []string{"f00", "b4r"}
addTests(getPostalCodeCheck(roleVals))
}
// NotBefore tests
{
roleVals.NotBeforeDuration = 10 * time.Second
addTests(getNotBeforeCheck(roleVals))
roleVals.NotBeforeDuration = 30 * time.Second
addTests(getNotBeforeCheck(roleVals))
roleVals.NotBeforeDuration = 0
}
// IP SAN tests
{
getIpCheck := func(expectedIp ...net.IP) logicaltest.TestCheckFunc {
return func(resp *logical.Response) error {
var certBundle certutil.CertBundle
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error parsing cert bundle: %s", err)
}
cert := parsedCertBundle.Certificate
var emptyIPs []net.IP
var expected []net.IP = append(emptyIPs, expectedIp...)
if diff := deep.Equal(cert.IPAddresses, expected); len(diff) > 0 {
return fmt.Errorf("wrong SAN IPs, diff: %v", diff)
}
return nil
}
}
addIPSANTests := func(useCSRs, useCSRSANs, allowIPSANs, errorOk bool, ipSANs string, csrIPSANs []net.IP, check logicaltest.TestCheckFunc) {
if useCSRs {
csrTemplate := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: issueVals.CommonName,
},
IPAddresses: csrIPSANs,
}
block, _ := getCsr(roleVals.KeyType, roleVals.KeyBits, csrTemplate)
issueVals.CSR = strings.TrimSpace(string(pem.EncodeToMemory(block)))
}
oldRoleVals, oldIssueVals, oldIssueTestStep := roleVals, issueVals, issueTestStep
roleVals.UseCSRSANs = useCSRSANs
roleVals.AllowIPSANs = allowIPSANs
issueVals.CommonName = "someone@example.com"
issueVals.IPSANs = ipSANs
issueTestStep.ErrorOk = errorOk
addTests(check)
roleVals, issueVals, issueTestStep = oldRoleVals, oldIssueVals, oldIssueTestStep
}
roleVals.AllowAnyName = true
roleVals.EnforceHostnames = true
roleVals.AllowLocalhost = true
roleVals.UseCSRCommonName = true
commonNames.Localhost = true
netip1, netip2 := net.IP{127, 0, 0, 1}, net.IP{170, 171, 172, 173}
textip1, textip3 := "127.0.0.1", "::1"
// IPSANs not allowed and not provided, should not be an error.
addIPSANTests(useCSRs, false, false, false, "", nil, getIpCheck())
// IPSANs not allowed, valid IPSANs provided, should be an error.
addIPSANTests(useCSRs, false, false, true, textip1+","+textip3, nil, nil)
// IPSANs allowed, bogus IPSANs provided, should be an error.
addIPSANTests(useCSRs, false, true, true, "foobar", nil, nil)
// Given IPSANs as API argument and useCSRSANs false, CSR arg ignored.
addIPSANTests(useCSRs, false, true, false, textip1,
[]net.IP{netip2}, getIpCheck(netip1))
if useCSRs {
// IPSANs not allowed, valid IPSANs provided via CSR, should be an error.
addIPSANTests(useCSRs, true, false, true, "", []net.IP{netip1}, nil)
// Given IPSANs as both API and CSR arguments and useCSRSANs=true, API arg ignored.
addIPSANTests(useCSRs, true, true, false, textip3,
[]net.IP{netip1, netip2}, getIpCheck(netip1, netip2))
}
}
{
getOtherCheck := func(expectedOthers ...otherNameUtf8) logicaltest.TestCheckFunc {
return func(resp *logical.Response) error {
var certBundle certutil.CertBundle
err := mapstructure.Decode(resp.Data, &certBundle)
if err != nil {
return err
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return fmt.Errorf("error parsing cert bundle: %s", err)
}
cert := parsedCertBundle.Certificate
foundOthers, err := getOtherSANsFromX509Extensions(cert.Extensions)
if err != nil {
return err
}
var emptyOthers []otherNameUtf8
var expected []otherNameUtf8 = append(emptyOthers, expectedOthers...)
if diff := deep.Equal(foundOthers, expected); len(diff) > 0 {
return fmt.Errorf("wrong SAN IPs, diff: %v", diff)
}
return nil
}
}
addOtherSANTests := func(useCSRs, useCSRSANs bool, allowedOtherSANs []string, errorOk bool, otherSANs []string, csrOtherSANs []otherNameUtf8, check logicaltest.TestCheckFunc) {
otherSansMap := func(os []otherNameUtf8) map[string][]string {
ret := make(map[string][]string)
for _, o := range os {
ret[o.oid] = append(ret[o.oid], o.value)
}
return ret
}
if useCSRs {
csrTemplate := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: issueVals.CommonName,
},
}
if err := handleOtherCSRSANs(csrTemplate, otherSansMap(csrOtherSANs)); err != nil {
t.Fatal(err)
}
block, _ := getCsr(roleVals.KeyType, roleVals.KeyBits, csrTemplate)
issueVals.CSR = strings.TrimSpace(string(pem.EncodeToMemory(block)))
}
oldRoleVals, oldIssueVals, oldIssueTestStep := roleVals, issueVals, issueTestStep
roleVals.UseCSRSANs = useCSRSANs
roleVals.AllowedOtherSANs = allowedOtherSANs
issueVals.CommonName = "someone@example.com"
issueVals.OtherSANs = strings.Join(otherSANs, ",")
issueTestStep.ErrorOk = errorOk
addTests(check)
roleVals, issueVals, issueTestStep = oldRoleVals, oldIssueVals, oldIssueTestStep
}
roleVals.AllowAnyName = true
roleVals.EnforceHostnames = true
roleVals.AllowLocalhost = true
roleVals.UseCSRCommonName = true
commonNames.Localhost = true
newOtherNameUtf8 := func(s string) (ret otherNameUtf8) {
pieces := strings.Split(s, ";")
if len(pieces) == 2 {
piecesRest := strings.Split(pieces[1], ":")
if len(piecesRest) == 2 {
switch strings.ToUpper(piecesRest[0]) {
case "UTF-8", "UTF8":
return otherNameUtf8{oid: pieces[0], value: piecesRest[1]}
}
}
}
t.Fatalf("error parsing otherName: %q", s)
return
}
oid1 := "1.3.6.1.4.1.311.20.2.3"
oth1str := oid1 + ";utf8:devops@nope.com"
oth1 := newOtherNameUtf8(oth1str)
oth2 := otherNameUtf8{oid1, "me@example.com"}
// allowNone, allowAll := []string{}, []string{oid1 + ";UTF-8:*"}
allowNone, allowAll := []string{}, []string{"*"}
// OtherSANs not allowed and not provided, should not be an error.
addOtherSANTests(useCSRs, false, allowNone, false, nil, nil, getOtherCheck())
// OtherSANs not allowed, valid OtherSANs provided, should be an error.
addOtherSANTests(useCSRs, false, allowNone, true, []string{oth1str}, nil, nil)
// OtherSANs allowed, bogus OtherSANs provided, should be an error.
addOtherSANTests(useCSRs, false, allowAll, true, []string{"foobar"}, nil, nil)
// Given OtherSANs as API argument and useCSRSANs false, CSR arg ignored.
addOtherSANTests(useCSRs, false, allowAll, false, []string{oth1str},
[]otherNameUtf8{oth2}, getOtherCheck(oth1))
if useCSRs {
// OtherSANs not allowed, valid OtherSANs provided via CSR, should be an error.
addOtherSANTests(useCSRs, true, allowNone, true, nil, []otherNameUtf8{oth1}, nil)
// Given OtherSANs as both API and CSR arguments and useCSRSANs=true, API arg ignored.
addOtherSANTests(useCSRs, false, allowAll, false, []string{oth2.String()},
[]otherNameUtf8{oth1}, getOtherCheck(oth2))
}
}
// Lease tests
{
roleTestStep.ErrorOk = true
roleVals.Lease = ""
roleVals.MaxTTL = 0
addTests(nil)
roleVals.Lease = "12h"
roleVals.MaxTTL = 6 * time.Hour
addTests(nil)
roleTestStep.ErrorOk = false
2018-05-09 16:47:00 +00:00
roleVals.TTL = 0
roleVals.MaxTTL = 12 * time.Hour
}
2016-01-28 20:18:07 +00:00
// Listing test
ret = append(ret, logicaltest.TestStep{
Operation: logical.ListOperation,
Path: "roles/",
Check: func(resp *logical.Response) error {
if resp.Data == nil {
return fmt.Errorf("nil data")
}
keysRaw, ok := resp.Data["keys"]
if !ok {
return fmt.Errorf("no keys found")
}
keys, ok := keysRaw.([]string)
if !ok {
return fmt.Errorf("could not convert keys to a string list")
}
if len(keys) != 1 {
return fmt.Errorf("unexpected keys length of %d", len(keys))
}
if keys[0] != "test" {
2016-04-13 19:38:29 +00:00
return fmt.Errorf("unexpected key value of %s", keys[0])
2016-01-28 20:18:07 +00:00
}
return nil
},
})
return ret
}
2016-06-08 15:46:58 +00:00
func TestBackend_PathFetchCertList(t *testing.T) {
// create the backend
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend(config)
err := b.Setup(context.Background(), config)
2016-06-08 15:46:58 +00:00
if err != nil {
t.Fatal(err)
}
// generate root
rootData := map[string]interface{}{
"common_name": "test.com",
"ttl": "6h",
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
2016-06-08 15:46:58 +00:00
Operation: logical.UpdateOperation,
Path: "root/generate/internal",
Storage: storage,
Data: rootData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to generate root, %#v", resp)
}
if err != nil {
t.Fatal(err)
}
// config urls
urlsData := map[string]interface{}{
"issuing_certificates": "http://127.0.0.1:8200/v1/pki/ca",
"crl_distribution_points": "http://127.0.0.1:8200/v1/pki/crl",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
2016-06-08 15:46:58 +00:00
Operation: logical.UpdateOperation,
Path: "config/urls",
Storage: storage,
Data: urlsData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to config urls, %#v", resp)
}
if err != nil {
t.Fatal(err)
}
// create a role entry
roleData := map[string]interface{}{
"allowed_domains": "test.com",
"allow_subdomains": "true",
"max_ttl": "4h",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
2016-06-08 15:46:58 +00:00
Operation: logical.UpdateOperation,
Path: "roles/test-example",
Storage: storage,
Data: roleData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to create a role, %#v", resp)
}
if err != nil {
t.Fatal(err)
}
2016-06-08 16:49:10 +00:00
// issue some certs
i := 1
for i < 10 {
certData := map[string]interface{}{
"common_name": "example.test.com",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
2016-06-08 16:49:10 +00:00
Operation: logical.UpdateOperation,
Path: "issue/test-example",
Storage: storage,
Data: certData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to issue a cert, %#v", resp)
}
if err != nil {
t.Fatal(err)
}
i = i + 1
2016-06-08 15:46:58 +00:00
}
2016-06-08 16:49:10 +00:00
// list certs
resp, err = b.HandleRequest(context.Background(), &logical.Request{
2016-06-08 16:49:10 +00:00
Operation: logical.ListOperation,
Path: "certs",
2016-06-08 15:46:58 +00:00
Storage: storage,
})
if resp != nil && resp.IsError() {
2016-06-08 16:49:10 +00:00
t.Fatalf("failed to list certs, %#v", resp)
2016-06-08 15:46:58 +00:00
}
if err != nil {
t.Fatal(err)
}
2016-06-08 16:49:10 +00:00
// check that the root and 9 additional certs are all listed
if len(resp.Data["keys"].([]string)) != 10 {
t.Fatalf("failed to list all 10 certs")
}
2016-06-08 15:46:58 +00:00
2016-06-08 16:49:10 +00:00
// list certs/
resp, err = b.HandleRequest(context.Background(), &logical.Request{
2016-06-08 15:46:58 +00:00
Operation: logical.ListOperation,
2016-06-08 16:49:10 +00:00
Path: "certs/",
2016-06-08 15:46:58 +00:00
Storage: storage,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to list certs, %#v", resp)
}
if err != nil {
t.Fatal(err)
}
2016-06-08 16:49:10 +00:00
// check that the root and 9 additional certs are all listed
if len(resp.Data["keys"].([]string)) != 10 {
t.Fatalf("failed to list all 10 certs")
}
2016-06-08 15:46:58 +00:00
}
func TestBackend_SignVerbatim(t *testing.T) {
// create the backend
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend(config)
err := b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
// generate root
rootData := map[string]interface{}{
"common_name": "test.com",
"ttl": "172800",
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/generate/internal",
Storage: storage,
Data: rootData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to generate root, %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
// create a CSR and key
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
csrReq := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: "foo.bar.com",
},
}
csr, err := x509.CreateCertificateRequest(rand.Reader, csrReq, key)
if err != nil {
t.Fatal(err)
}
if len(csr) == 0 {
t.Fatal("generated csr is empty")
}
pemCSR := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csr,
})))
if len(pemCSR) == 0 {
t.Fatal("pem csr is empty")
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "sign-verbatim",
Storage: storage,
Data: map[string]interface{}{
"csr": pemCSR,
},
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to sign-verbatim basic CSR: %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
if resp.Secret != nil {
t.Fatal("secret is not nil")
}
// create a role entry; we use this to check that sign-verbatim when used with a role is still honoring TTLs
roleData := map[string]interface{}{
"ttl": "4h",
"max_ttl": "8h",
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/test",
Storage: storage,
Data: roleData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to create a role, %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "sign-verbatim/test",
Storage: storage,
Data: map[string]interface{}{
"csr": pemCSR,
"ttl": "5h",
},
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to sign-verbatim ttl'd CSR: %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
if resp.Secret != nil {
t.Fatal("got a lease when we should not have")
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "sign-verbatim/test",
Storage: storage,
Data: map[string]interface{}{
"csr": pemCSR,
"ttl": "12h",
},
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf(resp.Error().Error())
}
if resp.Data == nil || resp.Data["certificate"] == nil {
t.Fatal("did not get expected data")
}
certString := resp.Data["certificate"].(string)
block, _ := pem.Decode([]byte(certString))
if block == nil {
t.Fatal("nil pem block")
}
certs, err := x509.ParseCertificates(block.Bytes)
if err != nil {
t.Fatal(err)
}
if len(certs) != 1 {
t.Fatalf("expected a single cert, got %d", len(certs))
}
cert := certs[0]
if math.Abs(float64(time.Now().Add(12*time.Hour).Unix()-cert.NotAfter.Unix())) < 10 {
t.Fatalf("sign-verbatim did not properly cap validity period on signed CSR")
}
// now check that if we set generate-lease it takes it from the role and the TTLs match
roleData = map[string]interface{}{
"ttl": "4h",
"max_ttl": "8h",
"generate_lease": true,
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/test",
Storage: storage,
Data: roleData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to create a role, %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "sign-verbatim/test",
Storage: storage,
Data: map[string]interface{}{
"csr": pemCSR,
"ttl": "5h",
},
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to sign-verbatim role-leased CSR: %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
if resp.Secret == nil {
t.Fatalf("secret is nil, response is %#v", *resp)
}
if math.Abs(float64(resp.Secret.TTL-(5*time.Hour))) > float64(5*time.Hour) {
t.Fatalf("ttl not default; wanted %v, got %v", b.System().DefaultLeaseTTL(), resp.Secret.TTL)
}
}
2018-03-20 18:54:10 +00:00
func TestBackend_Root_Idempotency(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
var err error
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "32h",
},
})
if err != nil {
t.Fatal(err)
}
resp, err := client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected ca info")
}
resp, err = client.Logical().Read("pki/cert/ca_chain")
if err != nil {
t.Fatalf("error reading ca_chain: %v", err)
}
r1Data := resp.Data
// Try again, make sure it's a 204 and same CA
resp, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
2018-05-09 16:51:34 +00:00
if resp == nil {
t.Fatal("expected a warning")
}
if resp.Data != nil || len(resp.Warnings) == 0 {
t.Fatalf("bad response: %#v", *resp)
}
resp, err = client.Logical().Read("pki/cert/ca_chain")
if err != nil {
t.Fatalf("error reading ca_chain: %v", err)
}
r2Data := resp.Data
if !reflect.DeepEqual(r1Data, r2Data) {
t.Fatal("got different ca certs")
}
resp, err = client.Logical().Delete("pki/root")
if err != nil {
t.Fatal(err)
}
if resp != nil {
t.Fatal("expected nil response")
}
// Make sure it behaves the same
resp, err = client.Logical().Delete("pki/root")
if err != nil {
t.Fatal(err)
}
if resp != nil {
t.Fatal("expected nil response")
}
_, err = client.Logical().Read("pki/cert/ca_chain")
if err == nil {
t.Fatal("expected error")
}
resp, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected ca info")
}
_, err = client.Logical().Read("pki/cert/ca_chain")
if err != nil {
t.Fatal(err)
}
}
func TestBackend_SignIntermediate_AllowedPastCA(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
var err error
err = client.Sys().Mount("root", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
})
if err != nil {
t.Fatal(err)
}
err = client.Sys().Mount("int", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "4h",
MaxLeaseTTL: "20h",
},
})
if err != nil {
t.Fatal(err)
}
// Direct issuing from root
_, err = client.Logical().Write("root/root/generate/internal", map[string]interface{}{
"ttl": "40h",
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("root/roles/test", map[string]interface{}{
"allow_bare_domains": true,
"allow_subdomains": true,
})
if err != nil {
t.Fatal(err)
}
resp, err := client.Logical().Write("int/intermediate/generate/internal", map[string]interface{}{
"common_name": "myint.com",
})
if err != nil {
t.Fatal(err)
}
csr := resp.Data["csr"]
_, err = client.Logical().Write("root/sign/test", map[string]interface{}{
"common_name": "myint.com",
"csr": csr,
"ttl": "60h",
})
if err == nil {
t.Fatal("expected error")
}
_, err = client.Logical().Write("root/sign-verbatim/test", map[string]interface{}{
"common_name": "myint.com",
"csr": csr,
"ttl": "60h",
})
if err == nil {
t.Fatal("expected error")
}
resp, err = client.Logical().Write("root/root/sign-intermediate", map[string]interface{}{
"common_name": "myint.com",
"csr": csr,
"ttl": "60h",
})
if err != nil {
t.Fatalf("got error: %v", err)
}
if resp == nil {
t.Fatal("got nil response")
}
if len(resp.Warnings) == 0 {
t.Fatalf("expected warnings, got %#v", *resp)
}
}
func TestBackend_SignSelfIssued(t *testing.T) {
// create the backend
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend(config)
err := b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
// generate root
rootData := map[string]interface{}{
"common_name": "test.com",
"ttl": "172800",
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/generate/internal",
Storage: storage,
Data: rootData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to generate root, %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
template := &x509.Certificate{
Subject: pkix.Name{
CommonName: "foo.bar.com",
},
SerialNumber: big.NewInt(1234),
IsCA: false,
BasicConstraintsValid: true,
}
ss, _ := getSelfSigned(t, template, template, key)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/sign-self-issued",
Storage: storage,
Data: map[string]interface{}{
"certificate": ss,
},
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response")
}
if !resp.IsError() {
t.Fatalf("expected error due to non-CA; got: %#v", *resp)
}
// Set CA to true, but leave issuer alone
template.IsCA = true
issuer := &x509.Certificate{
Subject: pkix.Name{
CommonName: "bar.foo.com",
},
SerialNumber: big.NewInt(2345),
IsCA: true,
BasicConstraintsValid: true,
}
ss, ssCert := getSelfSigned(t, template, issuer, key)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/sign-self-issued",
Storage: storage,
Data: map[string]interface{}{
"certificate": ss,
},
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response")
}
if !resp.IsError() {
t.Fatalf("expected error due to different issuer; cert info is\nIssuer\n%#v\nSubject\n%#v\n", ssCert.Issuer, ssCert.Subject)
}
ss, ssCert = getSelfSigned(t, template, template, key)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/sign-self-issued",
Storage: storage,
Data: map[string]interface{}{
"certificate": ss,
},
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response")
}
if resp.IsError() {
t.Fatalf("error in response: %s", resp.Error().Error())
}
newCertString := resp.Data["certificate"].(string)
block, _ := pem.Decode([]byte(newCertString))
newCert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
signingBundle, err := fetchCAInfo(context.Background(), &logical.Request{Storage: storage})
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(newCert.Subject, newCert.Issuer) {
t.Fatal("expected different subject/issuer")
}
if !reflect.DeepEqual(newCert.Issuer, signingBundle.Certificate.Subject) {
t.Fatalf("expected matching issuer/CA subject\n\nIssuer:\n%#v\nSubject:\n%#v\n", newCert.Issuer, signingBundle.Certificate.Subject)
}
if bytes.Equal(newCert.AuthorityKeyId, newCert.SubjectKeyId) {
t.Fatal("expected different authority/subject")
}
if !bytes.Equal(newCert.AuthorityKeyId, signingBundle.Certificate.SubjectKeyId) {
t.Fatal("expected authority on new cert to be same as signing subject")
}
if newCert.Subject.CommonName != "foo.bar.com" {
t.Fatalf("unexpected common name on new cert: %s", newCert.Subject.CommonName)
}
}
// TestBackend_SignSelfIssued_DifferentTypes is a copy of
// TestBackend_SignSelfIssued, but uses a different key type for the internal
// root (EC instead of RSA). This verifies that we can cross-sign CAs that are
// different key types, at the cost of verifying the algorithm used
func TestBackend_SignSelfIssued_DifferentTypes(t *testing.T) {
// create the backend
config := logical.TestBackendConfig()
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend(config)
err := b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
// generate root
rootData := map[string]interface{}{
"common_name": "test.com",
"ttl": "172800",
"key_type": "ec",
"key_bits": "521",
}
resp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/generate/internal",
Storage: storage,
Data: rootData,
})
if resp != nil && resp.IsError() {
t.Fatalf("failed to generate root, %#v", *resp)
}
if err != nil {
t.Fatal(err)
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
template := &x509.Certificate{
Subject: pkix.Name{
CommonName: "foo.bar.com",
},
2018-11-07 21:52:01 +00:00
SerialNumber: big.NewInt(1234),
IsCA: false,
BasicConstraintsValid: true,
}
ss, _ := getSelfSigned(t, template, template, key)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/sign-self-issued",
Storage: storage,
Data: map[string]interface{}{
"certificate": ss,
},
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response")
}
if !resp.IsError() {
t.Fatalf("expected error due to non-CA; got: %#v", *resp)
}
// Set CA to true, but leave issuer alone
template.IsCA = true
issuer := &x509.Certificate{
Subject: pkix.Name{
CommonName: "bar.foo.com",
},
2018-11-07 21:52:01 +00:00
SerialNumber: big.NewInt(2345),
IsCA: true,
BasicConstraintsValid: true,
}
ss, ssCert := getSelfSigned(t, template, issuer, key)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/sign-self-issued",
Storage: storage,
Data: map[string]interface{}{
"certificate": ss,
},
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response")
}
if !resp.IsError() {
t.Fatalf("expected error due to different issuer; cert info is\nIssuer\n%#v\nSubject\n%#v\n", ssCert.Issuer, ssCert.Subject)
}
ss, ssCert = getSelfSigned(t, template, template, key)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/sign-self-issued",
Storage: storage,
Data: map[string]interface{}{
"certificate": ss,
},
})
if err == nil {
t.Fatal("expected error due to different signature algo but not opted-in")
}
ss, ssCert = getSelfSigned(t, template, template, key)
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "root/sign-self-issued",
Storage: storage,
Data: map[string]interface{}{
"certificate": ss,
"allow_different_signature_algorithm": "true",
},
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response")
}
if resp.IsError() {
t.Fatalf("error in response: %s", resp.Error().Error())
}
newCertString := resp.Data["certificate"].(string)
block, _ := pem.Decode([]byte(newCertString))
newCert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
signingBundle, err := fetchCAInfo(context.Background(), &logical.Request{Storage: storage})
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(newCert.Subject, newCert.Issuer) {
t.Fatal("expected different subject/issuer")
}
if !reflect.DeepEqual(newCert.Issuer, signingBundle.Certificate.Subject) {
t.Fatalf("expected matching issuer/CA subject\n\nIssuer:\n%#v\nSubject:\n%#v\n", newCert.Issuer, signingBundle.Certificate.Subject)
}
if bytes.Equal(newCert.AuthorityKeyId, newCert.SubjectKeyId) {
t.Fatal("expected different authority/subject")
}
if !bytes.Equal(newCert.AuthorityKeyId, signingBundle.Certificate.SubjectKeyId) {
t.Fatal("expected authority on new cert to be same as signing subject")
}
if newCert.Subject.CommonName != "foo.bar.com" {
t.Fatalf("unexpected common name on new cert: %s", newCert.Subject.CommonName)
}
}
func getSelfSigned(t *testing.T, subject, issuer *x509.Certificate, key *rsa.PrivateKey) (string, *x509.Certificate) {
t.Helper()
selfSigned, err := x509.CreateCertificate(rand.Reader, subject, issuer, key.Public(), key)
if err != nil {
t.Fatal(err)
}
cert, err := x509.ParseCertificate(selfSigned)
if err != nil {
t.Fatal(err)
}
pemSS := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: selfSigned,
})))
return pemSS, cert
}
2018-02-16 22:19:34 +00:00
// This is a really tricky test because the Go stdlib asn1 package is incapable
// of doing the right thing with custom OID SANs (see comments in the package,
// it's readily admitted that it's too magic) but that means that any
// validation logic written for this test isn't being independently verified,
// as in, if cryptobytes is used to decode it to make the test work, that
// doesn't mean we're encoding and decoding correctly, only that we made the
// test pass. Instead, when run verbosely it will first perform a bunch of
// checks to verify that the OID SAN logic doesn't screw up other SANs, then
// will spit out the PEM. This can be validated independently.
//
// You want the hex dump of the octet string corresponding to the X509v3
// Subject Alternative Name. There's a nice online utility at
// https://lapo.it/asn1js that can be used to view the structure of an
// openssl-generated other SAN at
// https://lapo.it/asn1js/#3022A020060A2B060104018237140203A0120C106465766F7073406C6F63616C686F7374
// (openssl asn1parse can also be used with -strparse using an offset of the
// hex blob for the subject alternative names extension).
//
// The structure output from here should match that precisely (even if the OID
// itself doesn't) in the second test.
//
// The test that encodes two should have them be in separate elements in the
// top-level sequence; see
// https://lapo.it/asn1js/#3046A020060A2B060104018237140203A0120C106465766F7073406C6F63616C686F7374A022060A2B060104018237140204A0140C12322D6465766F7073406C6F63616C686F7374 for an openssl-generated example.
//
// The good news is that it's valid to simply copy and paste the PEM output from
2018-02-16 22:19:34 +00:00
// here into the form at that site as it will do the right thing so it's pretty
// easy to validate.
func TestBackend_OID_SANs(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
var err error
err = client.Sys().Mount("root", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
})
if err != nil {
t.Fatal(err)
}
var resp *api.Secret
var certStr string
var block *pem.Block
var cert *x509.Certificate
_, err = client.Logical().Write("root/root/generate/internal", map[string]interface{}{
"ttl": "40h",
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("root/roles/test", map[string]interface{}{
"allowed_domains": []string{"foobar.com", "zipzap.com"},
"allow_bare_domains": true,
"allow_subdomains": true,
"allow_ip_sans": true,
"allowed_other_sans": "1.3.6.1.4.1.311.20.2.3;UTF8:devops@*,1.3.6.1.4.1.311.20.2.4;utf8:d*e@foobar.com",
})
if err != nil {
t.Fatal(err)
}
// Get a baseline before adding OID SANs. In the next sections we'll verify
// that the SANs are all added even as the OID SAN inclusion forces other
// adding logic (custom rather than built-in Golang logic)
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foobar.com,foo.foobar.com,bar.foobar.com",
2018-02-16 22:19:34 +00:00
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
certStr = resp.Data["certificate"].(string)
block, _ = pem.Decode([]byte(certStr))
cert, err = x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
if cert.IPAddresses[0].String() != "1.2.3.4" {
t.Fatalf("unexpected IP SAN %q", cert.IPAddresses[0].String())
}
if len(cert.DNSNames) != 3 ||
cert.DNSNames[0] != "bar.foobar.com" ||
cert.DNSNames[1] != "foo.foobar.com" ||
cert.DNSNames[2] != "foobar.com" {
2018-02-16 22:19:34 +00:00
t.Fatalf("unexpected DNS SANs %v", cert.DNSNames)
}
// First test some bad stuff that shouldn't work
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
// Not a valid value for the first possibility
"other_sans": "1.3.6.1.4.1.311.20.2.3;UTF8:devop@nope.com",
})
if err == nil {
t.Fatal("expected error")
}
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
// Not a valid OID for the first possibility
"other_sans": "1.3.6.1.4.1.311.20.2.5;UTF8:devops@nope.com",
})
if err == nil {
t.Fatal("expected error")
}
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
// Not a valid name for the second possibility
"other_sans": "1.3.6.1.4.1.311.20.2.4;UTF8:d34g@foobar.com",
})
if err == nil {
t.Fatal("expected error")
}
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
// Not a valid OID for the second possibility
"other_sans": "1.3.6.1.4.1.311.20.2.5;UTF8:d34e@foobar.com",
})
if err == nil {
t.Fatal("expected error")
}
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
// Not a valid type
"other_sans": "1.3.6.1.4.1.311.20.2.5;UTF2:d34e@foobar.com",
})
if err == nil {
t.Fatal("expected error")
}
// Valid for first possibility
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
"other_sans": "1.3.6.1.4.1.311.20.2.3;utf8:devops@nope.com",
})
if err != nil {
t.Fatal(err)
}
certStr = resp.Data["certificate"].(string)
block, _ = pem.Decode([]byte(certStr))
cert, err = x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
if cert.IPAddresses[0].String() != "1.2.3.4" {
t.Fatalf("unexpected IP SAN %q", cert.IPAddresses[0].String())
}
if len(cert.DNSNames) != 3 ||
cert.DNSNames[0] != "bar.foobar.com" ||
cert.DNSNames[1] != "foo.foobar.com" ||
cert.DNSNames[2] != "foobar.com" {
2018-02-16 22:19:34 +00:00
t.Fatalf("unexpected DNS SANs %v", cert.DNSNames)
}
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
t.Logf("certificate 1 to check:\n%s", certStr)
}
2018-02-16 22:19:34 +00:00
// Valid for second possibility
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
"other_sans": "1.3.6.1.4.1.311.20.2.4;UTF8:d234e@foobar.com",
})
if err != nil {
t.Fatal(err)
}
certStr = resp.Data["certificate"].(string)
block, _ = pem.Decode([]byte(certStr))
cert, err = x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
if cert.IPAddresses[0].String() != "1.2.3.4" {
t.Fatalf("unexpected IP SAN %q", cert.IPAddresses[0].String())
}
if len(cert.DNSNames) != 3 ||
cert.DNSNames[0] != "bar.foobar.com" ||
cert.DNSNames[1] != "foo.foobar.com" ||
cert.DNSNames[2] != "foobar.com" {
2018-02-16 22:19:34 +00:00
t.Fatalf("unexpected DNS SANs %v", cert.DNSNames)
}
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
t.Logf("certificate 2 to check:\n%s", certStr)
}
2018-02-16 22:19:34 +00:00
// Valid for both
oid1, type1, val1 := "1.3.6.1.4.1.311.20.2.3", "utf8", "devops@nope.com"
oid2, type2, val2 := "1.3.6.1.4.1.311.20.2.4", "utf-8", "d234e@foobar.com"
otherNames := []string{
fmt.Sprintf("%s;%s:%s", oid1, type1, val1),
fmt.Sprintf("%s;%s:%s", oid2, type2, val2),
}
2018-02-16 22:19:34 +00:00
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
"other_sans": strings.Join(otherNames, ","),
2018-02-16 22:19:34 +00:00
})
if err != nil {
t.Fatal(err)
}
certStr = resp.Data["certificate"].(string)
block, _ = pem.Decode([]byte(certStr))
cert, err = x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
if cert.IPAddresses[0].String() != "1.2.3.4" {
t.Fatalf("unexpected IP SAN %q", cert.IPAddresses[0].String())
}
if len(cert.DNSNames) != 3 ||
cert.DNSNames[0] != "bar.foobar.com" ||
cert.DNSNames[1] != "foo.foobar.com" ||
cert.DNSNames[2] != "foobar.com" {
2018-02-16 22:19:34 +00:00
t.Fatalf("unexpected DNS SANs %v", cert.DNSNames)
}
expectedOtherNames := []otherNameUtf8{{oid1, val1}, {oid2, val2}}
foundOtherNames, err := getOtherSANsFromX509Extensions(cert.Extensions)
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(expectedOtherNames, foundOtherNames); len(diff) != 0 {
t.Errorf("unexpected otherNames: %v", diff)
}
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
t.Logf("certificate 3 to check:\n%s", certStr)
}
}
func TestBackend_AllowedSerialNumbers(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
var err error
err = client.Sys().Mount("root", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
})
if err != nil {
t.Fatal(err)
}
var resp *api.Secret
var certStr string
var block *pem.Block
var cert *x509.Certificate
_, err = client.Logical().Write("root/root/generate/internal", map[string]interface{}{
"ttl": "40h",
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
// First test that Serial Numbers are not allowed
_, err = client.Logical().Write("root/roles/test", map[string]interface{}{
"allow_any_name": true,
"enforce_hostnames": false,
})
if err != nil {
t.Fatal(err)
}
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar",
"ttl": "1h",
})
if err != nil {
t.Fatal(err)
}
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar",
"ttl": "1h",
"serial_number": "foobar",
})
if err == nil {
t.Fatal("expected error")
}
// Update the role to allow serial numbers
_, err = client.Logical().Write("root/roles/test", map[string]interface{}{
"allow_any_name": true,
"enforce_hostnames": false,
"allowed_serial_numbers": "f00*,b4r*",
})
if err != nil {
t.Fatal(err)
}
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar",
"ttl": "1h",
// Not a valid serial number
"serial_number": "foobar",
})
if err == nil {
t.Fatal("expected error")
}
// Valid for first possibility
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar",
"serial_number": "f00bar",
})
if err != nil {
t.Fatal(err)
}
certStr = resp.Data["certificate"].(string)
block, _ = pem.Decode([]byte(certStr))
cert, err = x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
if cert.Subject.SerialNumber != "f00bar" {
t.Fatalf("unexpected Subject SerialNumber %s", cert.Subject.SerialNumber)
}
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
t.Logf("certificate 1 to check:\n%s", certStr)
}
// Valid for second possibility
resp, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar",
"serial_number": "b4rf00",
})
if err != nil {
t.Fatal(err)
}
certStr = resp.Data["certificate"].(string)
block, _ = pem.Decode([]byte(certStr))
cert, err = x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
if cert.Subject.SerialNumber != "b4rf00" {
t.Fatalf("unexpected Subject SerialNumber %s", cert.Subject.SerialNumber)
}
if len(os.Getenv("VAULT_VERBOSE_PKITESTS")) > 0 {
t.Logf("certificate 2 to check:\n%s", certStr)
}
2018-02-16 22:19:34 +00:00
}
2018-06-15 19:32:25 +00:00
func TestBackend_URI_SANs(t *testing.T) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
var err error
err = client.Sys().Mount("root", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("root/root/generate/internal", map[string]interface{}{
"ttl": "40h",
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("root/roles/test", map[string]interface{}{
"allowed_domains": []string{"foobar.com", "zipzap.com"},
"allow_bare_domains": true,
"allow_subdomains": true,
"allow_ip_sans": true,
"allowed_uri_sans": []string{"http://someuri/abc", "spiffe://host.com/*"},
})
if err != nil {
t.Fatal(err)
}
// First test some bad stuff that shouldn't work
_, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
"uri_sans": "http://www.mydomain.com/zxf",
})
if err == nil {
t.Fatal("expected error")
}
// Test valid single entry
_, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
"uri_sans": "http://someuri/abc",
})
if err != nil {
t.Fatal(err)
}
// Test globed entry
_, err = client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
"uri_sans": "spiffe://host.com/something",
})
if err != nil {
t.Fatal(err)
}
// Test multiple entries
resp, err := client.Logical().Write("root/issue/test", map[string]interface{}{
"common_name": "foobar.com",
"ip_sans": "1.2.3.4",
"alt_names": "foo.foobar.com,bar.foobar.com",
"ttl": "1h",
"uri_sans": "spiffe://host.com/something,http://someuri/abc",
})
if err != nil {
t.Fatal(err)
}
certStr := resp.Data["certificate"].(string)
block, _ := pem.Decode([]byte(certStr))
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}
URI0, _ := url.Parse("spiffe://host.com/something")
URI1, _ := url.Parse("http://someuri/abc")
if len(cert.URIs) != 2 {
t.Fatalf("expected 2 valid URIs SANs %v", cert.URIs)
}
if cert.URIs[0].String() != URI0.String() || cert.URIs[1].String() != URI1.String() {
t.Fatalf(
"expected URIs SANs %v to equal provided values spiffe://host.com/something, http://someuri/abc",
cert.URIs)
}
}
func TestBackend_AllowedDomainsTemplate(t *testing.T) {
coreConfig := &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{
"userpass": userpass.Factory,
},
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
// Write test policy for userpass auth method.
err := client.Sys().PutPolicy("test", `
path "pki/*" {
capabilities = ["update"]
}`)
if err != nil {
t.Fatal(err)
}
// Enable userpass auth method.
if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil {
t.Fatal(err)
}
// Configure test role for userpass.
if _, err := client.Logical().Write("auth/userpass/users/userpassname", map[string]interface{}{
"password": "test",
"policies": "test",
}); err != nil {
t.Fatal(err)
}
// Login userpass for test role and keep client token.
secret, err := client.Logical().Write("auth/userpass/login/userpassname", map[string]interface{}{
"password": "test",
})
if err != nil || secret == nil {
t.Fatal(err)
}
userpassToken := secret.Auth.ClientToken
// Get auth accessor for identity template.
auths, err := client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}
userpassAccessor := auths["userpass/"].Accessor
// Mount PKI.
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
})
if err != nil {
t.Fatal(err)
}
// Generate internal CA.
_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"ttl": "40h",
"common_name": "myvault.com",
})
if err != nil {
t.Fatal(err)
}
// Write role PKI.
_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
"allowed_domains": []string{
"foobar.com", "zipzap.com", "{{identity.entity.aliases." + userpassAccessor + ".name}}",
"foo.{{identity.entity.aliases." + userpassAccessor + ".name}}.example.com",
},
"allowed_domains_template": true,
"allow_bare_domains": true,
})
if err != nil {
t.Fatal(err)
}
// Issue certificate with userpassToken.
client.SetToken(userpassToken)
_, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "userpassname"})
if err != nil {
t.Fatal(err)
}
// Issue certificate for foobar.com to verify allowed_domain_templae doesnt break plain domains.
_, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "foobar.com"})
if err != nil {
t.Fatal(err)
}
// Issue certificate for unknown userpassname.
_, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "unknownuserpassname"})
if err == nil {
t.Fatal("expected error")
}
// Issue certificate for foo.userpassname.domain.
_, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "foo.userpassname.example.com"})
if err != nil {
t.Fatal("expected error")
}
// Set allowed_domains_template to false.
_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
"allowed_domains_template": false,
})
if err != nil {
t.Fatal(err)
}
// Issue certificate with userpassToken.
_, err = client.Logical().Write("pki/issue/test", map[string]interface{}{"common_name": "userpassname"})
if err == nil {
t.Fatal("expected error")
}
}
2018-02-20 05:03:45 +00:00
func setCerts() {
cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
marshaledKey, err := x509.MarshalECPrivateKey(cak)
if err != nil {
panic(err)
}
keyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: marshaledKey,
}
ecCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
2018-02-20 05:03:45 +00:00
if err != nil {
panic(err)
}
subjKeyID, err := certutil.GetSubjKeyID(cak)
if err != nil {
panic(err)
}
caCertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "root.localhost",
},
SubjectKeyId: subjKeyID,
DNSNames: []string{"root.localhost"},
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
SerialNumber: big.NewInt(mathrand.Int63()),
NotAfter: time.Now().Add(262980 * time.Hour),
BasicConstraintsValid: true,
2018-11-07 21:52:01 +00:00
IsCA: true,
2018-02-20 05:03:45 +00:00
}
caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak)
if err != nil {
panic(err)
}
caCertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
}
ecCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
2018-02-20 05:03:45 +00:00
rak, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
marshaledKey = x509.MarshalPKCS1PrivateKey(rak)
keyPEMBlock = &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: marshaledKey,
}
rsaCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
2018-02-20 05:03:45 +00:00
if err != nil {
panic(err)
}
subjKeyID, err = certutil.GetSubjKeyID(rak)
if err != nil {
panic(err)
}
caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, rak.Public(), rak)
if err != nil {
panic(err)
}
caCertPEMBlock = &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
}
rsaCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
2018-02-20 05:03:45 +00:00
}
func TestBackend_RevokePlusTidy_Intermediate(t *testing.T) {
// Enable PKI secret engine
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
vault.TestWaitActive(t, cores[0].Core)
client := cores[0].Client
var err error
// Mount /pki as a root CA
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "32h",
},
})
if err != nil {
t.Fatal(err)
}
// Set the cluster's certificate as the root CA in /pki
pemBundleRootCA := string(cluster.CACertPEM) + string(cluster.CAKeyPEM)
_, err = client.Logical().Write("pki/config/ca", map[string]interface{}{
"pem_bundle": pemBundleRootCA,
})
if err != nil {
t.Fatal(err)
}
// Mount /pki2 to operate as an intermediate CA
err = client.Sys().Mount("pki2", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "32h",
},
})
if err != nil {
t.Fatal(err)
}
// Create a CSR for the intermediate CA
secret, err := client.Logical().Write("pki2/intermediate/generate/internal", nil)
if err != nil {
t.Fatal(err)
}
intermediateCSR := secret.Data["csr"].(string)
// Sign the intermediate CSR using /pki
secret, err = client.Logical().Write("pki/root/sign-intermediate", map[string]interface{}{
"permitted_dns_domains": ".myvault.com",
"csr": intermediateCSR,
"ttl": "10s",
})
if err != nil {
t.Fatal(err)
}
intermediateCertSerial := secret.Data["serial_number"].(string)
intermediateCASerialColon := strings.ReplaceAll(strings.ToLower(intermediateCertSerial), ":", "-")
// Get the intermediate cert after signing
secret, err = client.Logical().Read("pki/cert/" + intermediateCASerialColon)
if err != nil {
t.Fatal(err)
}
if secret == nil || len(secret.Data) == 0 || len(secret.Data["certificate"].(string)) == 0 {
t.Fatal("expected certificate information from read operation")
}
// Issue a revoke on on /pki
_, err = client.Logical().Write("pki/revoke", map[string]interface{}{
"serial_number": intermediateCertSerial,
})
if err != nil {
t.Fatal(err)
}
// Revoke adds a fixed 2s buffer, so we sleep for a bit longer to ensure
// the revocation time is past the current time.
time.Sleep(3 * time.Second)
// Issue a tidy on /pki
_, err = client.Logical().Write("pki/tidy", map[string]interface{}{
"tidy_cert_store": true,
"tidy_revoked_certs": true,
"safety_buffer": "1s",
})
if err != nil {
t.Fatal(err)
}
// Sleep a bit to make sure we're past the safety buffer
time.Sleep(2 * time.Second)
// Get CRL and ensure the tidied cert is still in the list after the tidy
// operation since it's not past the NotAfter (ttl) value yet.
req := client.NewRequest("GET", "/v1/pki/crl")
resp, err := client.RawRequest(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
crlBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("err: %s", err)
}
if len(crlBytes) == 0 {
t.Fatalf("expected CRL in response body")
}
crl, err := x509.ParseDERCRL(crlBytes)
if err != nil {
t.Fatal(err)
}
revokedCerts := crl.TBSCertList.RevokedCertificates
if len(revokedCerts) == 0 {
t.Fatal("expected CRL to be non-empty")
}
sn := certutil.GetHexFormatted(revokedCerts[0].SerialNumber.Bytes(), ":")
if sn != intermediateCertSerial {
t.Fatalf("expected: %v, got: %v", intermediateCertSerial, sn)
}
// Wait for cert to expire
time.Sleep(10 * time.Second)
// Issue a tidy on /pki
_, err = client.Logical().Write("pki/tidy", map[string]interface{}{
"tidy_cert_store": true,
"tidy_revoked_certs": true,
"safety_buffer": "1s",
})
if err != nil {
t.Fatal(err)
}
// Sleep a bit to make sure we're past the safety buffer
time.Sleep(2 * time.Second)
req = client.NewRequest("GET", "/v1/pki/crl")
resp, err = client.RawRequest(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
crlBytes, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("err: %s", err)
}
if len(crlBytes) == 0 {
t.Fatalf("expected CRL in response body")
}
crl, err = x509.ParseDERCRL(crlBytes)
if err != nil {
t.Fatal(err)
}
revokedCerts = crl.TBSCertList.RevokedCertificates
if len(revokedCerts) != 0 {
t.Fatal("expected CRL to be empty")
}
}
2018-02-20 05:03:45 +00:00
var (
initTest sync.Once
rsaCAKey string
rsaCACert string
ecCAKey string
ecCACert string
)