open-vault/builtin/logical/pki/acme_billing_test.go
hc-github-team-secure-vault-core 7a0badc115
backport of commit 072f0dd7c85be8d4e4390cf417900efce5e38d56 (#21656)
Co-authored-by: Alexander Scheel <alex.scheel@hashicorp.com>
2023-07-07 15:45:01 +00:00

323 lines
14 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"strings"
"testing"
"time"
"golang.org/x/crypto/acme"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki/dnstest"
"github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/activity"
"github.com/stretchr/testify/require"
)
// TestACMEBilling is a basic test that will validate client counts created via ACME workflows.
//
// It issues certificates through four ACME-enabled mounts (pki, pki2, ns1/pki
// and ns2/pki) and re-reads the monthly activity counters after every
// issuance. The steps are strictly order dependent: each one asserts the count
// relative to the value returned by the previous step. At the end it resets
// the activity log on the first core and checks that every client recorded in
// the first returned fragment has the ACME client type.
func TestACMEBilling(t *testing.T) {
	t.Parallel()
	// validateClientCount requires all activity to land in a single month.
	timeutil.SkipAtEndOfMonth(t)

	cluster, client, _ := setupAcmeBackend(t)
	defer cluster.Cleanup()

	dns := dnstest.SetupResolver(t, "dadgarcorp.com")
	defer dns.Cleanup()

	// Enable additional mounts.
	setupAcmeBackendOnClusterAtPath(t, cluster, client, "pki2")
	setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns1/pki")
	setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns2/pki")

	// Enable custom DNS resolver for testing.
	for _, mount := range []string{"pki", "pki2", "ns1/pki", "ns2/pki"} {
		_, err := client.Logical().Write(mount+"/config/acme", map[string]interface{}{
			"dns_resolver": dns.GetLocalAddr(),
		})
		require.NoError(t, err, "failed to set local dns resolver address for testing on mount: "+mount)
	}

	// Enable client counting.
	_, err := client.Logical().Write("/sys/internal/counters/config", map[string]interface{}{
		"enabled": "enable",
	})
	require.NoError(t, err, "failed to enable client counting")

	// Setup ACME clients. We refresh account keys each time for consistency.
	acmeClientPKI := getAcmeClientForCluster(t, cluster, "/v1/pki/acme/", nil)
	acmeClientPKI2 := getAcmeClientForCluster(t, cluster, "/v1/pki2/acme/", nil)
	acmeClientPKINS1 := getAcmeClientForCluster(t, cluster, "/v1/ns1/pki/acme/", nil)
	acmeClientPKINS2 := getAcmeClientForCluster(t, cluster, "/v1/ns2/pki/acme/", nil)

	// Get our initial count. Passing -1 skips the comparison and just
	// returns whatever the counter currently reads.
	expectedCount := validateClientCount(t, client, "", -1, "initial fetch")

	// Unique identifier: should increase by one.
	doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")

	// Different identifier; should increase by one.
	doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"example.dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")

	// While same identifiers, used together and so thus are unique; increase by one.
	doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"example.dadgarcorp.com", "dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")

	// Same identifiers in different order are not unique; keep the same.
	doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"dadgarcorp.com", "example.dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "pki", expectedCount, "different order; same identifiers")

	// Using a different mount shouldn't affect counts.
	doACMEForDomainWithDNS(t, dns, acmeClientPKI2, []string{"dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "", expectedCount, "different mount; same identifiers")

	// But using a different identifier should.
	doACMEForDomainWithDNS(t, dns, acmeClientPKI2, []string{"pki2.dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "pki2", expectedCount+1, "different mount with different identifiers")

	// A new identifier in a unique namespace will affect results.
	doACMEForDomainWithDNS(t, dns, acmeClientPKINS1, []string{"unique.dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "ns1/pki", expectedCount+1, "unique identifier in a namespace")

	// But in a different namespace with the existing identifier will not.
	doACMEForDomainWithDNS(t, dns, acmeClientPKINS2, []string{"unique.dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier in a namespace")
	doACMEForDomainWithDNS(t, dns, acmeClientPKI2, []string{"unique.dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier outside of a namespace")

	// Creating a unique identifier in a namespace with a mount with the
	// same name as another namespace should increase counts as well.
	doACMEForDomainWithDNS(t, dns, acmeClientPKINS2, []string{"very-unique.dadgarcorp.com"})
	expectedCount = validateClientCount(t, client, "ns2/pki", expectedCount+1, "unique identifier in a different namespace")

	// Check the current fragment. The length is checked before indexing so
	// that an empty result is reported as a test failure rather than an
	// index-out-of-range panic.
	fragments := cluster.Cores[0].Core.ResetActivityLog()
	if len(fragments) == 0 || fragments[0] == nil {
		t.Fatal("no fragment created")
	}
	validateAcmeClientTypes(t, fragments[0], expectedCount)
}
// validateAcmeClientTypes fails the test unless fragment holds exactly
// expectedCount clients and each of them carries the ACME client type.
func validateAcmeClientTypes(t *testing.T, fragment *activity.LogFragment, expectedCount int64) {
	t.Helper()

	clients := fragment.Clients

	// The number of recorded clients has to match the counter value.
	if actual := len(clients); int64(actual) != expectedCount {
		t.Fatalf("bad number of entities, expected %v: got %v, entities are: %v", expectedCount, actual, clients)
	}

	// Any client of another type is an immediate failure.
	for idx := range clients {
		if clients[idx].ClientType == vault.ACMEActivityType {
			continue
		}
		t.Fatalf("Couldn't find expected '%v' client_type in %v", vault.ACMEActivityType, clients)
	}
}
// validateClientCount reads /sys/internal/counters/activity/monthly and checks
// the reported non-entity client counts.
//
// Parameters:
//   - mount: mount path to look for in the per-namespace / per-mount
//     breakdown. When empty, only the top-level counter is examined and the
//     breakdown checks are skipped.
//   - expected: required value of the top-level non_entity_clients counter,
//     or -1 to skip that comparison.
//   - message: appended to the failure text if the top-level counter does
//     not equal expected.
//
// It returns the top-level non_entity_clients value that was read. Any
// failed check aborts the test.
func validateClientCount(t *testing.T, client *api.Client, mount string, expected int64, message string) int64 {
	// Mark as a helper so log and failure lines point at the calling test step.
	t.Helper()

	resp, err := client.Logical().Read("/sys/internal/counters/activity/monthly")
	require.NoError(t, err, "failed to fetch client count values")
	t.Logf("got client count numbers: %v", resp)

	require.NotNil(t, resp)
	require.NotNil(t, resp.Data)
	require.Contains(t, resp.Data, "non_entity_clients")
	require.Contains(t, resp.Data, "months")

	rawCount := resp.Data["non_entity_clients"].(json.Number)
	count, err := rawCount.Int64()
	require.NoError(t, err, "failed to parse number as int64: "+rawCount.String())

	// -1 is the "don't compare" sentinel.
	if expected != -1 {
		require.Equal(t, expected, count, "value of client counts did not match expectations: "+message)
	}

	// Without a mount there is nothing to look up in the breakdown.
	if mount == "" {
		return count
	}

	months := resp.Data["months"].([]interface{})
	if len(months) > 1 {
		t.Fatalf("running across a month boundary despite using SkipAtEndOfMonth(...); rerun test from start fully in the next month instead")
	}
	require.Equal(t, 1, len(months), "expected only a single month when running this test")

	monthlyInfo := months[0].(map[string]interface{})

	// Validate this month's aggregate counts match the overall value.
	require.Contains(t, monthlyInfo, "counts", "expected monthly info to contain a count key")
	monthlyCounts := monthlyInfo["counts"].(map[string]interface{})
	require.Contains(t, monthlyCounts, "non_entity_clients", "expected month[0].counts to contain a non_entity_clients key")
	monthlyCountNonEntityRaw := monthlyCounts["non_entity_clients"].(json.Number)
	monthlyCountNonEntity, err := monthlyCountNonEntityRaw.Int64()
	require.NoError(t, err, "failed to parse number as int64: "+monthlyCountNonEntityRaw.String())
	require.Equal(t, count, monthlyCountNonEntity, "expected equal values for non entity client counts")

	// Validate this mount's namespace is included in the namespaces list,
	// if this is enterprise. Otherwise, if it's OSS or we don't have a
	// namespace, we match against the empty namespace path and treat the
	// whole mount string (plus a trailing slash) as the mount path.
	mountNamespace := ""
	mountPath := mount + "/"
	if constants.IsEnterprise && strings.Contains(mount, "/") {
		pieces := strings.Split(mount, "/")
		require.Equal(t, 2, len(pieces), "we do not support nested namespaces in this test")
		mountNamespace = pieces[0] + "/"
		mountPath = pieces[1] + "/"
	}

	require.Contains(t, monthlyInfo, "namespaces", "expected monthly info to contain a namespaces key")
	monthlyNamespaces := monthlyInfo["namespaces"].([]interface{})
	foundNamespace := false
	for index, namespaceRaw := range monthlyNamespaces {
		namespace := namespaceRaw.(map[string]interface{})
		require.Contains(t, namespace, "namespace_path", "expected monthly.namespaces[%v] to contain a namespace_path key", index)
		namespacePath := namespace["namespace_path"].(string)
		if namespacePath != mountNamespace {
			t.Logf("skipping non-matching namespace %v: %v != %v / %v", index, namespacePath, mountNamespace, namespace)
			continue
		}

		foundNamespace = true

		// This namespace must have a non-empty aggregate non-entity count.
		require.Contains(t, namespace, "counts", "expected monthly.namespaces[%v] to contain a counts key", index)
		namespaceCounts := namespace["counts"].(map[string]interface{})
		require.Contains(t, namespaceCounts, "non_entity_clients", "expected namespace counts to contain a non_entity_clients key")
		namespaceCountNonEntityRaw := namespaceCounts["non_entity_clients"].(json.Number)
		namespaceCountNonEntity, err := namespaceCountNonEntityRaw.Int64()
		require.NoError(t, err, "failed to parse number as int64: "+namespaceCountNonEntityRaw.String())
		require.Greater(t, namespaceCountNonEntity, int64(0), "expected at least one non-entity client count value in the namespace")

		// The matching namespace must also list the mount itself.
		require.Contains(t, namespace, "mounts", "expected monthly.namespaces[%v] to contain a mounts key", index)
		namespaceMounts := namespace["mounts"].([]interface{})
		foundMount := false
		for mountIndex, mountRaw := range namespaceMounts {
			mountInfo := mountRaw.(map[string]interface{})
			require.Contains(t, mountInfo, "mount_path", "expected monthly.namespaces[%v].mounts[%v] to contain a mount_path key", index, mountIndex)
			mountInfoPath := mountInfo["mount_path"].(string)
			if mountPath != mountInfoPath {
				t.Logf("skipping non-matching mount path %v in namespace %v: %v != %v / %v of %v", mountIndex, index, mountPath, mountInfoPath, mountInfo, namespace)
				continue
			}

			foundMount = true

			// This mount must also have a non-empty non-entity client count.
			require.Contains(t, mountInfo, "counts", "expected monthly.namespaces[%v].mounts[%v] to contain a counts key", index, mountIndex)
			mountCounts := mountInfo["counts"].(map[string]interface{})
			require.Contains(t, mountCounts, "non_entity_clients", "expected mount counts to contain a non_entity_clients key")
			mountCountNonEntityRaw := mountCounts["non_entity_clients"].(json.Number)
			mountCountNonEntity, err := mountCountNonEntityRaw.Int64()
			require.NoError(t, err, "failed to parse number as int64: "+mountCountNonEntityRaw.String())
			require.Greater(t, mountCountNonEntity, int64(0), "expected at least one non-entity client count value in the mount")
		}
		require.True(t, foundMount, "expected to find the mount "+mountPath+" in the list of mounts for namespace, but did not")
	}
	require.True(t, foundNamespace, "expected to find the namespace "+mountNamespace+" in the list of namespaces, but did not")

	return count
}
// doACMEForDomainWithDNS builds a certificate request whose common name is the
// first entry of domains and whose DNS SANs are all of domains, then hands it
// to doACMEForCSRWithDNS and returns the certificate that call produces.
//
// The test fails immediately if domains is empty.
func doACMEForDomainWithDNS(t *testing.T, dns *dnstest.TestServer, acmeClient *acme.Client, domains []string) *x509.Certificate {
	t.Helper()

	// Guard the domains[0] access below: an empty list should be a clean
	// test failure, not an index-out-of-range panic.
	require.NotEmpty(t, domains, "need at least one domain to request a certificate for")

	cr := &x509.CertificateRequest{
		Subject:  pkix.Name{CommonName: domains[0]},
		DNSNames: domains,
	}
	return doACMEForCSRWithDNS(t, dns, acmeClient, domains, cr)
}
// doACMEForCSRWithDNS runs a full ACME order for domains using acmeClient and
// returns the first certificate of the response, parsed.
//
// Steps, in order:
//  1. generate a new P-256 key, install it as acmeClient.Key (overwriting any
//     previous key) and register the account;
//  2. create an order with one "dns" identifier per entry of domains and
//     fetch its authorizations;
//  3. for every pending dns-01 challenge, publish the TXT record on the test
//     DNS server and push the configuration;
//  4. accept those challenges and wait for the order;
//  5. sign cr with a second new P-256 key and submit the CSR to the order's
//     finalize URL.
//
// The whole exchange shares one context with a two minute timeout. Any
// failure aborts the test.
func doACMEForCSRWithDNS(t *testing.T, dns *dnstest.TestServer, acmeClient *acme.Client, domains []string, cr *x509.CertificateRequest) *x509.Certificate {
	t.Helper()

	accountKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err, "failed to generate account key")
	acmeClient.Key = accountKey

	testCtx, cancelFunc := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancelFunc()

	// Register the client.
	_, err = acmeClient.Register(testCtx, &acme.Account{Contact: []string{"mailto:ipsans@dadgarcorp.com"}}, func(tosURL string) bool { return true })
	require.NoError(t, err, "failed registering account")

	// Create the Order
	var orderIdentifiers []acme.AuthzID
	for _, domain := range domains {
		orderIdentifiers = append(orderIdentifiers, acme.AuthzID{Type: "dns", Value: domain})
	}
	order, err := acmeClient.AuthorizeOrder(testCtx, orderIdentifiers)
	require.NoError(t, err, "failed creating ACME order")

	// Fetch its authorizations.
	var auths []*acme.Authorization
	for _, authUrl := range order.AuthzURLs {
		authorization, err := acmeClient.GetAuthorization(testCtx, authUrl)
		require.NoError(t, err, "failed to lookup authorization at url: %s", authUrl)
		auths = append(auths, authorization)
	}

	// For each dns-01 challenge, place the record in the associated DNS resolver.
	var challengesToAccept []*acme.Challenge
	for _, auth := range auths {
		for _, challenge := range auth.Challenges {
			if challenge.Status != acme.StatusPending {
				t.Logf("ignoring challenge not in status pending: %v", challenge)
				continue
			}

			if challenge.Type == "dns-01" {
				challengeBody, err := acmeClient.DNS01ChallengeRecord(challenge.Token)
				require.NoError(t, err, "failed generating challenge response")

				dns.AddRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody)
				// Removal is deferred to function return, so the record is
				// still present while the challenges are accepted and the
				// order is awaited below.
				defer dns.RemoveRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody)

				challengesToAccept = append(challengesToAccept, challenge)
			}
		}
	}

	dns.PushConfig()
	require.GreaterOrEqual(t, len(challengesToAccept), 1, "Need at least one challenge, got none")

	// Tell the ACME server, that they can now validate those challenges.
	for _, challenge := range challengesToAccept {
		_, err = acmeClient.Accept(testCtx, challenge)
		require.NoError(t, err, "failed to accept challenge: %v", challenge)
	}

	// Wait for the order/challenges to be validated.
	_, err = acmeClient.WaitOrder(testCtx, order.URI)
	require.NoError(t, err, "failed waiting for order to be ready")

	// Create/sign the CSR and ask ACME server to sign it returning us the final certificate
	csrKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	// Check the key generation error before err is reused by the next call.
	require.NoError(t, err, "failed generating csr key")
	csr, err := x509.CreateCertificateRequest(rand.Reader, cr, csrKey)
	require.NoError(t, err, "failed generating csr")

	certs, _, err := acmeClient.CreateOrderCert(testCtx, order.FinalizeURL, csr, false)
	require.NoError(t, err, "failed to get a certificate back from ACME")
	// Make sure there is something to index before parsing certs[0].
	require.NotEmpty(t, certs, "ACME returned no certificates")

	acmeCert, err := x509.ParseCertificate(certs[0])
	require.NoError(t, err, "failed parsing acme cert bytes")

	return acmeCert
}