8b6e07d960
To reduce the chance of some tests not being run because their names do not match the regex passed to '-run'. Also document why some tests are allowed to be skipped on CI.
520 lines
13 KiB
Go
520 lines
13 KiB
Go
package ca
|
|
|
|
import (
|
|
"crypto/x509"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"os"
|
|
"os/exec"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/consul/agent/connect"
|
|
"github.com/hashicorp/consul/agent/structs"
|
|
"github.com/hashicorp/consul/sdk/freeport"
|
|
"github.com/hashicorp/consul/sdk/testutil/retry"
|
|
vaultapi "github.com/hashicorp/vault/api"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestVaultCAProvider_VaultTLSConfig(t *testing.T) {
|
|
config := &structs.VaultCAProviderConfig{
|
|
CAFile: "/capath/ca.pem",
|
|
CAPath: "/capath/",
|
|
CertFile: "/certpath/cert.pem",
|
|
KeyFile: "/certpath/key.pem",
|
|
TLSServerName: "server.name",
|
|
TLSSkipVerify: true,
|
|
}
|
|
tlsConfig := vaultTLSConfig(config)
|
|
require := require.New(t)
|
|
require.Equal(config.CAFile, tlsConfig.CACert)
|
|
require.Equal(config.CAPath, tlsConfig.CAPath)
|
|
require.Equal(config.CertFile, tlsConfig.ClientCert)
|
|
require.Equal(config.KeyFile, tlsConfig.ClientKey)
|
|
require.Equal(config.TLSServerName, tlsConfig.TLSServerName)
|
|
require.Equal(config.TLSSkipVerify, tlsConfig.Insecure)
|
|
}
|
|
|
|
func TestVaultCAProvider_Bootstrap(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfVaultNotPresent(t)
|
|
|
|
provider, testVault := testVaultProvider(t)
|
|
defer testVault.Stop()
|
|
client := testVault.client
|
|
|
|
require := require.New(t)
|
|
|
|
cases := []struct {
|
|
certFunc func() (string, error)
|
|
backendPath string
|
|
}{
|
|
{
|
|
certFunc: provider.ActiveRoot,
|
|
backendPath: "pki-root/",
|
|
},
|
|
{
|
|
certFunc: provider.ActiveIntermediate,
|
|
backendPath: "pki-intermediate/",
|
|
},
|
|
}
|
|
|
|
// Verify the root and intermediate certs match the ones in the vault backends
|
|
for _, tc := range cases {
|
|
cert, err := tc.certFunc()
|
|
require.NoError(err)
|
|
req := client.NewRequest("GET", "/v1/"+tc.backendPath+"ca/pem")
|
|
resp, err := client.RawRequest(req)
|
|
require.NoError(err)
|
|
bytes, err := ioutil.ReadAll(resp.Body)
|
|
require.NoError(err)
|
|
require.Equal(cert, string(bytes))
|
|
|
|
// Should be a valid CA cert
|
|
parsed, err := connect.ParseCert(cert)
|
|
require.NoError(err)
|
|
require.True(parsed.IsCA)
|
|
require.Len(parsed.URIs, 1)
|
|
require.Equal(fmt.Sprintf("spiffe://%s.consul", provider.clusterID), parsed.URIs[0].String())
|
|
}
|
|
}
|
|
|
|
func assertCorrectKeyType(t *testing.T, want, certPEM string) {
|
|
t.Helper()
|
|
|
|
cert, err := connect.ParseCert(certPEM)
|
|
require.NoError(t, err)
|
|
|
|
switch want {
|
|
case "ec":
|
|
require.Equal(t, x509.ECDSA, cert.PublicKeyAlgorithm)
|
|
case "rsa":
|
|
require.Equal(t, x509.RSA, cert.PublicKeyAlgorithm)
|
|
default:
|
|
t.Fatal("test doesn't support key type")
|
|
}
|
|
}
|
|
|
|
func TestVaultCAProvider_SignLeaf(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfVaultNotPresent(t)
|
|
|
|
for _, tc := range KeyTestCases {
|
|
tc := tc
|
|
t.Run(tc.Desc, func(t *testing.T) {
|
|
require := require.New(t)
|
|
provider, testVault := testVaultProviderWithConfig(t, true, map[string]interface{}{
|
|
"LeafCertTTL": "1h",
|
|
"PrivateKeyType": tc.KeyType,
|
|
"PrivateKeyBits": tc.KeyBits,
|
|
})
|
|
defer testVault.Stop()
|
|
|
|
spiffeService := &connect.SpiffeIDService{
|
|
Host: "node1",
|
|
Namespace: "default",
|
|
Datacenter: "dc1",
|
|
Service: "foo",
|
|
}
|
|
|
|
rootPEM, err := provider.ActiveRoot()
|
|
require.NoError(err)
|
|
assertCorrectKeyType(t, tc.KeyType, rootPEM)
|
|
|
|
intPEM, err := provider.ActiveIntermediate()
|
|
require.NoError(err)
|
|
assertCorrectKeyType(t, tc.KeyType, intPEM)
|
|
|
|
// Generate a leaf cert for the service.
|
|
var firstSerial uint64
|
|
{
|
|
raw, _ := connect.TestCSR(t, spiffeService)
|
|
|
|
csr, err := connect.ParseCSR(raw)
|
|
require.NoError(err)
|
|
|
|
cert, err := provider.Sign(csr)
|
|
require.NoError(err)
|
|
|
|
parsed, err := connect.ParseCert(cert)
|
|
require.NoError(err)
|
|
require.Equal(parsed.URIs[0], spiffeService.URI())
|
|
firstSerial = parsed.SerialNumber.Uint64()
|
|
|
|
// Ensure the cert is valid now and expires within the correct limit.
|
|
now := time.Now()
|
|
require.True(parsed.NotAfter.Sub(now) < time.Hour)
|
|
require.True(parsed.NotBefore.Before(now))
|
|
|
|
// Make sure we can validate the cert as expected.
|
|
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
|
|
}
|
|
|
|
// Generate a new cert for another service and make sure
|
|
// the serial number is unique.
|
|
spiffeService.Service = "bar"
|
|
{
|
|
raw, _ := connect.TestCSR(t, spiffeService)
|
|
|
|
csr, err := connect.ParseCSR(raw)
|
|
require.NoError(err)
|
|
|
|
cert, err := provider.Sign(csr)
|
|
require.NoError(err)
|
|
|
|
parsed, err := connect.ParseCert(cert)
|
|
require.NoError(err)
|
|
require.Equal(parsed.URIs[0], spiffeService.URI())
|
|
require.NotEqual(firstSerial, parsed.SerialNumber.Uint64())
|
|
|
|
// Ensure the cert is valid now and expires within the correct limit.
|
|
require.True(time.Until(parsed.NotAfter) < time.Hour)
|
|
require.True(parsed.NotBefore.Before(time.Now()))
|
|
|
|
// Make sure we can validate the cert as expected.
|
|
require.NoError(connect.ValidateLeaf(rootPEM, cert, []string{intPEM}))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestVaultCAProvider_CrossSignCA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfVaultNotPresent(t)
|
|
|
|
tests := CASigningKeyTypeCases()
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.Desc, func(t *testing.T) {
|
|
require := require.New(t)
|
|
|
|
if tc.SigningKeyType != tc.CSRKeyType {
|
|
// See https://github.com/hashicorp/vault/issues/7709
|
|
t.Skip("Vault doesn't support cross-signing different key types yet.")
|
|
}
|
|
provider1, testVault1 := testVaultProviderWithConfig(t, true, map[string]interface{}{
|
|
"LeafCertTTL": "1h",
|
|
"PrivateKeyType": tc.SigningKeyType,
|
|
"PrivateKeyBits": tc.SigningKeyBits,
|
|
})
|
|
defer testVault1.Stop()
|
|
|
|
{
|
|
rootPEM, err := provider1.ActiveRoot()
|
|
require.NoError(err)
|
|
assertCorrectKeyType(t, tc.SigningKeyType, rootPEM)
|
|
|
|
intPEM, err := provider1.ActiveIntermediate()
|
|
require.NoError(err)
|
|
assertCorrectKeyType(t, tc.SigningKeyType, intPEM)
|
|
}
|
|
|
|
provider2, testVault2 := testVaultProviderWithConfig(t, true, map[string]interface{}{
|
|
"LeafCertTTL": "1h",
|
|
"PrivateKeyType": tc.CSRKeyType,
|
|
"PrivateKeyBits": tc.CSRKeyBits,
|
|
})
|
|
defer testVault2.Stop()
|
|
|
|
{
|
|
rootPEM, err := provider2.ActiveRoot()
|
|
require.NoError(err)
|
|
assertCorrectKeyType(t, tc.CSRKeyType, rootPEM)
|
|
|
|
intPEM, err := provider2.ActiveIntermediate()
|
|
require.NoError(err)
|
|
assertCorrectKeyType(t, tc.CSRKeyType, intPEM)
|
|
}
|
|
|
|
testCrossSignProviders(t, provider1, provider2)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestVaultProvider_SignIntermediate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfVaultNotPresent(t)
|
|
|
|
tests := CASigningKeyTypeCases()
|
|
|
|
for _, tc := range tests {
|
|
tc := tc
|
|
t.Run(tc.Desc, func(t *testing.T) {
|
|
provider1, testVault1 := testVaultProviderWithConfig(t, true, map[string]interface{}{
|
|
"LeafCertTTL": "1h",
|
|
"PrivateKeyType": tc.SigningKeyType,
|
|
"PrivateKeyBits": tc.SigningKeyBits,
|
|
})
|
|
defer testVault1.Stop()
|
|
|
|
provider2, testVault2 := testVaultProviderWithConfig(t, false, map[string]interface{}{
|
|
"LeafCertTTL": "1h",
|
|
"PrivateKeyType": tc.CSRKeyType,
|
|
"PrivateKeyBits": tc.CSRKeyBits,
|
|
})
|
|
defer testVault2.Stop()
|
|
|
|
testSignIntermediateCrossDC(t, provider1, provider2)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestVaultProvider_SignIntermediateConsul(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfVaultNotPresent(t)
|
|
|
|
// primary = Vault, secondary = Consul
|
|
t.Run("pri=vault,sec=consul", func(t *testing.T) {
|
|
provider1, testVault1 := testVaultProviderWithConfig(t, true, nil)
|
|
defer testVault1.Stop()
|
|
|
|
conf := testConsulCAConfig()
|
|
delegate := newMockDelegate(t, conf)
|
|
provider2 := TestConsulProvider(t, delegate)
|
|
cfg := testProviderConfig(conf)
|
|
cfg.IsPrimary = false
|
|
cfg.Datacenter = "dc2"
|
|
require.NoError(t, provider2.Configure(cfg))
|
|
|
|
testSignIntermediateCrossDC(t, provider1, provider2)
|
|
})
|
|
|
|
// primary = Consul, secondary = Vault
|
|
t.Run("pri=consul,sec=vault", func(t *testing.T) {
|
|
conf := testConsulCAConfig()
|
|
delegate := newMockDelegate(t, conf)
|
|
provider1 := TestConsulProvider(t, delegate)
|
|
require.NoError(t, provider1.Configure(testProviderConfig(conf)))
|
|
require.NoError(t, provider1.GenerateRoot())
|
|
|
|
// Ensure that we don't configure vault to try and mint leafs that
|
|
// outlive their CA during the test (which hard fails in vault).
|
|
intermediateCertTTL := getIntermediateCertTTL(t, conf)
|
|
leafCertTTL := intermediateCertTTL - 4*time.Hour
|
|
|
|
overrideConf := map[string]interface{}{
|
|
"LeafCertTTL": []uint8(leafCertTTL.String()),
|
|
}
|
|
|
|
provider2, testVault2 := testVaultProviderWithConfig(t, false, overrideConf)
|
|
defer testVault2.Stop()
|
|
|
|
testSignIntermediateCrossDC(t, provider1, provider2)
|
|
})
|
|
}
|
|
|
|
func getIntermediateCertTTL(t *testing.T, caConf *structs.CAConfiguration) time.Duration {
|
|
t.Helper()
|
|
|
|
require.NotNil(t, caConf)
|
|
require.NotNil(t, caConf.Config)
|
|
|
|
iface, ok := caConf.Config["IntermediateCertTTL"]
|
|
require.True(t, ok)
|
|
|
|
ttlBytes, ok := iface.([]uint8)
|
|
require.True(t, ok)
|
|
|
|
ttlString := string(ttlBytes)
|
|
|
|
dur, err := time.ParseDuration(ttlString)
|
|
require.NoError(t, err)
|
|
return dur
|
|
}
|
|
|
|
func testVaultProvider(t *testing.T) (*VaultProvider, *testVaultServer) {
|
|
return testVaultProviderWithConfig(t, true, nil)
|
|
}
|
|
|
|
func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) {
|
|
testVault, err := runTestVault()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
testVault.WaitUntilReady(t)
|
|
|
|
conf := map[string]interface{}{
|
|
"Address": testVault.addr,
|
|
"Token": testVault.rootToken,
|
|
"RootPKIPath": "pki-root/",
|
|
"IntermediatePKIPath": "pki-intermediate/",
|
|
// Tests duration parsing after msgpack type mangling during raft apply.
|
|
"LeafCertTTL": []uint8("72h"),
|
|
}
|
|
for k, v := range rawConf {
|
|
conf[k] = v
|
|
}
|
|
|
|
provider := &VaultProvider{}
|
|
|
|
cfg := ProviderConfig{
|
|
ClusterID: connect.TestClusterID,
|
|
Datacenter: "dc1",
|
|
IsPrimary: true,
|
|
RawConfig: conf,
|
|
}
|
|
|
|
if !isPrimary {
|
|
cfg.IsPrimary = false
|
|
cfg.Datacenter = "dc2"
|
|
}
|
|
|
|
if err := provider.Configure(cfg); err != nil {
|
|
testVault.Stop()
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if isPrimary {
|
|
if err = provider.GenerateRoot(); err != nil {
|
|
testVault.Stop()
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if _, err := provider.GenerateIntermediate(); err != nil {
|
|
testVault.Stop()
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
}
|
|
return provider, testVault
|
|
}
|
|
|
|
// skipIfVaultNotPresent skips the test if the vault binary is not in PATH.
// The binary name defaults to "vault" and can be overridden with the
// VAULT_BINARY_NAME environment variable.
//
// These tests may be skipped in CI. They are run as part of a separate
// integration test suite.
func skipIfVaultNotPresent(t *testing.T) {
	// Attribute the skip message to the calling test, not to this helper.
	t.Helper()

	vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
	if vaultBinaryName == "" {
		vaultBinaryName = "vault"
	}

	path, err := exec.LookPath(vaultBinaryName)
	if err != nil || path == "" {
		t.Skipf("%q not found on $PATH - download and install to run this test", vaultBinaryName)
	}
}
|
|
|
|
func runTestVault() (*testVaultServer, error) {
|
|
vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
|
|
if vaultBinaryName == "" {
|
|
vaultBinaryName = "vault"
|
|
}
|
|
|
|
path, err := exec.LookPath(vaultBinaryName)
|
|
if err != nil || path == "" {
|
|
return nil, fmt.Errorf("%q not found on $PATH", vaultBinaryName)
|
|
}
|
|
|
|
ports := freeport.MustTake(2)
|
|
returnPortsFn := func() {
|
|
freeport.Return(ports)
|
|
}
|
|
|
|
var (
|
|
clientAddr = fmt.Sprintf("127.0.0.1:%d", ports[0])
|
|
clusterAddr = fmt.Sprintf("127.0.0.1:%d", ports[1])
|
|
)
|
|
|
|
const token = "root"
|
|
|
|
client, err := vaultapi.NewClient(&vaultapi.Config{
|
|
Address: "http://" + clientAddr,
|
|
})
|
|
if err != nil {
|
|
returnPortsFn()
|
|
return nil, err
|
|
}
|
|
client.SetToken(token)
|
|
|
|
args := []string{
|
|
"server",
|
|
"-dev",
|
|
"-dev-root-token-id",
|
|
token,
|
|
"-dev-listen-address",
|
|
clientAddr,
|
|
"-address",
|
|
clusterAddr,
|
|
}
|
|
|
|
cmd := exec.Command(vaultBinaryName, args...)
|
|
cmd.Stdout = ioutil.Discard
|
|
cmd.Stderr = ioutil.Discard
|
|
if err := cmd.Start(); err != nil {
|
|
returnPortsFn()
|
|
return nil, err
|
|
}
|
|
|
|
return &testVaultServer{
|
|
rootToken: token,
|
|
addr: "http://" + clientAddr,
|
|
cmd: cmd,
|
|
client: client,
|
|
returnPortsFn: returnPortsFn,
|
|
}, nil
|
|
}
|
|
|
|
type testVaultServer struct {
|
|
rootToken string
|
|
addr string
|
|
cmd *exec.Cmd
|
|
client *vaultapi.Client
|
|
|
|
// returnPortsFn will put the ports claimed for the test back into the
|
|
returnPortsFn func()
|
|
}
|
|
|
|
var printedVaultVersion sync.Once
|
|
|
|
func (v *testVaultServer) WaitUntilReady(t *testing.T) {
|
|
var version string
|
|
retry.Run(t, func(r *retry.R) {
|
|
resp, err := v.client.Sys().Health()
|
|
if err != nil {
|
|
r.Fatalf("err: %v", err)
|
|
}
|
|
if !resp.Initialized {
|
|
r.Fatalf("vault server is not initialized")
|
|
}
|
|
if resp.Sealed {
|
|
r.Fatalf("vault server is sealed")
|
|
}
|
|
version = resp.Version
|
|
})
|
|
printedVaultVersion.Do(func() {
|
|
fmt.Fprintf(os.Stderr, "[INFO] agent/connect/ca: testing with vault server version: %s\n", version)
|
|
})
|
|
}
|
|
|
|
func (v *testVaultServer) Stop() error {
|
|
// There was no process
|
|
if v.cmd == nil {
|
|
return nil
|
|
}
|
|
|
|
if v.cmd.Process != nil {
|
|
if err := v.cmd.Process.Signal(os.Interrupt); err != nil {
|
|
return fmt.Errorf("failed to kill vault server: %v", err)
|
|
}
|
|
}
|
|
|
|
// wait for the process to exit to be sure that the data dir can be
|
|
// deleted on all platforms.
|
|
if err := v.cmd.Wait(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if v.returnPortsFn != nil {
|
|
v.returnPortsFn()
|
|
}
|
|
|
|
return nil
|
|
}
|