diff --git a/builtin/logical/pki/acme_billing.go b/builtin/logical/pki/acme_billing.go new file mode 100644 index 000000000..642e0f4fc --- /dev/null +++ b/builtin/logical/pki/acme_billing.go @@ -0,0 +1,25 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pki + +import ( + "context" + "fmt" + + "github.com/hashicorp/vault/sdk/logical" +) + +func (b *backend) doTrackBilling(ctx context.Context, identifiers []*ACMEIdentifier) error { + billingView, ok := b.System().(logical.ACMEBillingSystemView) + if !ok { + return fmt.Errorf("failed to perform cast to ACME billing system view interface") + } + + var realized []string + for _, identifier := range identifiers { + realized = append(realized, fmt.Sprintf("%s/%s", identifier.Type, identifier.OriginalValue)) + } + + return billingView.CreateActivityCountEventForIdentifiers(ctx, realized) +} diff --git a/builtin/logical/pki/acme_billing_test.go b/builtin/logical/pki/acme_billing_test.go new file mode 100644 index 000000000..959e4a383 --- /dev/null +++ b/builtin/logical/pki/acme_billing_test.go @@ -0,0 +1,296 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pki + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "strings" + "testing" + "time" + + "golang.org/x/crypto/acme" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/pki/dnstest" + "github.com/hashicorp/vault/helper/constants" + "github.com/hashicorp/vault/helper/timeutil" + + "github.com/stretchr/testify/require" +) + +// TestACMEBilling is a basic test that will validate client counts created via ACME workflows. +func TestACMEBilling(t *testing.T) { + t.Parallel() + timeutil.SkipAtEndOfMonth(t) + + cluster, client, _ := setupAcmeBackend(t) + defer cluster.Cleanup() + + dns := dnstest.SetupResolver(t, "dadgarcorp.com") + defer dns.Cleanup() + + // Enable additional mounts. + setupAcmeBackendOnClusterAtPath(t, cluster, client, "pki2") + setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns1/pki") + setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns2/pki") + + // Enable custom DNS resolver for testing. + for _, mount := range []string{"pki", "pki2", "ns1/pki", "ns2/pki"} { + _, err := client.Logical().Write(mount+"/config/acme", map[string]interface{}{ + "dns_resolver": dns.GetLocalAddr(), + }) + require.NoError(t, err, "failed to set local dns resolver address for testing on mount: "+mount) + } + + // Enable client counting. + _, err := client.Logical().Write("/sys/internal/counters/config", map[string]interface{}{ + "enabled": "enable", + }) + require.NoError(t, err, "failed to enable client counting") + + // Setup ACME clients. We refresh account keys each time for consistency. + acmeClientPKI := getAcmeClientForCluster(t, cluster, "/v1/pki/acme/", nil) + acmeClientPKI2 := getAcmeClientForCluster(t, cluster, "/v1/pki2/acme/", nil) + acmeClientPKINS1 := getAcmeClientForCluster(t, cluster, "/v1/ns1/pki/acme/", nil) + acmeClientPKINS2 := getAcmeClientForCluster(t, cluster, "/v1/ns2/pki/acme/", nil) + + // Get our initial count. + expectedCount := validateClientCount(t, client, "", -1, "initial fetch") + + // Unique identifier: should increase by one. + doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate") + + // Different identifier; should increase by one. + doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"example.dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate") + + // While same identifiers, used together and so thus are unique; increase by one. + doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"example.dadgarcorp.com", "dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate") + + // Same identifiers in different order are not unique; keep the same. + doACMEForDomainWithDNS(t, dns, &acmeClientPKI, []string{"dadgarcorp.com", "example.dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "pki", expectedCount, "different order; same identifiers") + + // Using a different mount shouldn't affect counts. + doACMEForDomainWithDNS(t, dns, &acmeClientPKI2, []string{"dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "", expectedCount, "different mount; same identifiers") + + // But using a different identifier should. + doACMEForDomainWithDNS(t, dns, &acmeClientPKI2, []string{"pki2.dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "pki2", expectedCount+1, "different mount with different identifiers") + + // A new identifier in a unique namespace will affect results. + doACMEForDomainWithDNS(t, dns, &acmeClientPKINS1, []string{"unique.dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "ns1/pki", expectedCount+1, "unique identifier in a namespace") + + // But in a different namespace with the existing identifier will not. + doACMEForDomainWithDNS(t, dns, &acmeClientPKINS2, []string{"unique.dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier in a namespace") + doACMEForDomainWithDNS(t, dns, &acmeClientPKI2, []string{"unique.dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier outside of a namespace") + + // Creating a unique identifier in a namespace with a mount with the + // same name as another namespace should increase counts as well. + doACMEForDomainWithDNS(t, dns, &acmeClientPKINS2, []string{"very-unique.dadgarcorp.com"}) + expectedCount = validateClientCount(t, client, "ns2/pki", expectedCount+1, "unique identifier in a different namespace") +} + +func validateClientCount(t *testing.T, client *api.Client, mount string, expected int64, message string) int64 { + resp, err := client.Logical().Read("/sys/internal/counters/activity/monthly") + require.NoError(t, err, "failed to fetch client count values") + t.Logf("got client count numbers: %v", resp) + + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + require.Contains(t, resp.Data, "non_entity_clients") + require.Contains(t, resp.Data, "months") + + rawCount := resp.Data["non_entity_clients"].(json.Number) + count, err := rawCount.Int64() + require.NoError(t, err, "failed to parse number as int64: "+rawCount.String()) + + if expected != -1 { + require.Equal(t, expected, count, "value of client counts did not match expectations: "+message) + } + + if mount == "" { + return count + } + + months := resp.Data["months"].([]interface{}) + if len(months) > 1 { + t.Fatalf("running across a month boundary despite using SkipAtEndOfMonth(...); rerun test from start fully in the next month instead") + } + + require.Equal(t, 1, len(months), "expected only a single month when running this test") + + monthlyInfo := months[0].(map[string]interface{}) + + // Validate this month's aggregate counts match the overall value. + require.Contains(t, monthlyInfo, "counts", "expected monthly info to contain a count key") + monthlyCounts := monthlyInfo["counts"].(map[string]interface{}) + require.Contains(t, monthlyCounts, "non_entity_clients", "expected month[0].counts to contain a non_entity_clients key") + monthlyCountNonEntityRaw := monthlyCounts["non_entity_clients"].(json.Number) + monthlyCountNonEntity, err := monthlyCountNonEntityRaw.Int64() + require.NoError(t, err, "failed to parse number as int64: "+monthlyCountNonEntityRaw.String()) + require.Equal(t, count, monthlyCountNonEntity, "expected equal values for non entity client counts") + + // Validate this mount's namespace is included in the namespaces list, + // if this is enterprise. Otherwise, if its OSS or we don't have a + // namespace, we default to the value root. + mountNamespace := "root" + mountPath := mount + "/" + if constants.IsEnterprise && strings.Contains(mount, "/") { + pieces := strings.Split(mount, "/") + require.Equal(t, 2, len(pieces), "we do not support nested namespaces in this test") + mountNamespace = pieces[0] + mountPath = pieces[1] + "/" + } + + require.Contains(t, monthlyInfo, "namespaces", "expected monthly info to contain a namespaces key") + monthlyNamespaces := monthlyInfo["namespaces"].([]interface{}) + foundNamespace := false + for index, namespaceRaw := range monthlyNamespaces { + namespace := namespaceRaw.(map[string]interface{}) + require.Contains(t, namespace, "namespace_id", "expected monthly.namespaces[%v] to contain a namespace_id key", index) + namespaceId := namespace["namespace_id"].(string) + + if namespaceId != mountNamespace { + t.Logf("skipping non-matching namespace %v: %v != %v / %v", index, namespaceId, mountNamespace, namespace) + continue + } + + foundNamespace = true + + // This namespace must have a non-empty aggregate non-entity count. + require.Contains(t, namespace, "counts", "expected monthly.namespaces[%v] to contain a counts key", index) + namespaceCounts := namespace["counts"].(map[string]interface{}) + require.Contains(t, namespaceCounts, "non_entity_clients", "expected namespace counts to contain a non_entity_clients key") + namespaceCountNonEntityRaw := namespaceCounts["non_entity_clients"].(json.Number) + namespaceCountNonEntity, err := namespaceCountNonEntityRaw.Int64() + require.NoError(t, err, "failed to parse number as int64: "+namespaceCountNonEntityRaw.String()) + require.Greater(t, namespaceCountNonEntity, int64(0), "expected at least one non-entity client count value in the namespace") + + require.Contains(t, namespace, "mounts", "expected monthly.namespaces[%v] to contain a mounts key", index) + namespaceMounts := namespace["mounts"].([]interface{}) + foundMount := false + for mountIndex, mountRaw := range namespaceMounts { + mountInfo := mountRaw.(map[string]interface{}) + require.Contains(t, mountInfo, "mount_path", "expected monthly.namespaces[%v].mounts[%v] to contain a mount_path key", index, mountIndex) + mountInfoPath := mountInfo["mount_path"].(string) + if mountPath != mountInfoPath { + t.Logf("skipping non-matching mount path %v in namespace %v: %v != %v / %v of %v", mountIndex, index, mountPath, mountInfoPath, mountInfo, namespace) + continue + } + + foundMount = true + + // This mount must also have a non-empty non-entity client count. + require.Contains(t, mountInfo, "counts", "expected monthly.namespaces[%v].mounts[%v] to contain a counts key", index, mountIndex) + mountCounts := mountInfo["counts"].(map[string]interface{}) + require.Contains(t, mountCounts, "non_entity_clients", "expected mount counts to contain a non_entity_clients key") + mountCountNonEntityRaw := mountCounts["non_entity_clients"].(json.Number) + mountCountNonEntity, err := mountCountNonEntityRaw.Int64() + require.NoError(t, err, "failed to parse number as int64: "+mountCountNonEntityRaw.String()) + require.Greater(t, mountCountNonEntity, int64(0), "expected at least one non-entity client count value in the mount") + } + + require.True(t, foundMount, "expected to find the mount "+mountPath+" in the list of mounts for namespace, but did not") + } + + require.True(t, foundNamespace, "expected to find the namespace "+mountNamespace+" in the list of namespaces, but did not") + + return count +} + +func doACMEForDomainWithDNS(t *testing.T, dns *dnstest.TestServer, acmeClient *acme.Client, domains []string) *x509.Certificate { + cr := &x509.CertificateRequest{ + Subject: pkix.Name{CommonName: domains[0]}, + DNSNames: domains, + } + + accountKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err, "failed to generate account key") + acmeClient.Key = accountKey + + testCtx, cancelFunc := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancelFunc() + + // Register the client. + _, err = acmeClient.Register(testCtx, &acme.Account{Contact: []string{"mailto:ipsans@dadgarcorp.com"}}, func(tosURL string) bool { return true }) + require.NoError(t, err, "failed registering account") + + // Create the Order + var orderIdentifiers []acme.AuthzID + for _, domain := range domains { + orderIdentifiers = append(orderIdentifiers, acme.AuthzID{Type: "dns", Value: domain}) + } + order, err := acmeClient.AuthorizeOrder(testCtx, orderIdentifiers) + require.NoError(t, err, "failed creating ACME order") + + // Fetch its authorizations. + var auths []*acme.Authorization + for _, authUrl := range order.AuthzURLs { + authorization, err := acmeClient.GetAuthorization(testCtx, authUrl) + require.NoError(t, err, "failed to lookup authorization at url: %s", authUrl) + auths = append(auths, authorization) + } + + // For each dns-01 challenge, place the record in the associated DNS resolver. + var challengesToAccept []*acme.Challenge + for _, auth := range auths { + for _, challenge := range auth.Challenges { + if challenge.Status != acme.StatusPending { + t.Logf("ignoring challenge not in status pending: %v", challenge) + continue + } + + if challenge.Type == "dns-01" { + challengeBody, err := acmeClient.DNS01ChallengeRecord(challenge.Token) + require.NoError(t, err, "failed generating challenge response") + + dns.AddRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody) + defer dns.RemoveRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody) + + require.NoError(t, err, "failed setting DNS record") + + challengesToAccept = append(challengesToAccept, challenge) + } + } + } + + dns.PushConfig() + require.GreaterOrEqual(t, len(challengesToAccept), 1, "Need at least one challenge, got none") + + // Tell the ACME server, that they can now validate those challenges. + for _, challenge := range challengesToAccept { + _, err = acmeClient.Accept(testCtx, challenge) + require.NoError(t, err, "failed to accept challenge: %v", challenge) + } + + // Wait for the order/challenges to be validated. + _, err = acmeClient.WaitOrder(testCtx, order.URI) + require.NoError(t, err, "failed waiting for order to be ready") + + // Create/sign the CSR and ask ACME server to sign it returning us the final certificate + csrKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + csr, err := x509.CreateCertificateRequest(rand.Reader, cr, csrKey) + require.NoError(t, err, "failed generating csr") + + certs, _, err := acmeClient.CreateOrderCert(testCtx, order.FinalizeURL, csr, false) + require.NoError(t, err, "failed to get a certificate back from ACME") + + acmeCert, err := x509.ParseCertificate(certs[0]) + require.NoError(t, err, "failed parsing acme cert bytes") + + return acmeCert +} diff --git a/builtin/logical/pki/path_acme_order.go b/builtin/logical/pki/path_acme_order.go index d64979560..22ba2917a 100644 --- a/builtin/logical/pki/path_acme_order.go +++ b/builtin/logical/pki/path_acme_order.go @@ -285,6 +285,11 @@ func (b *backend) acmeFinalizeOrderHandler(ac *acmeContext, _ *logical.Request, return nil, fmt.Errorf("failed saving updated order: %w", err) } + if err := b.doTrackBilling(ac.sc.Context, order.Identifiers); err != nil { + b.Logger().Error("failed to track billing for order", "order", orderId, "error", err) + err = nil + } + return formatOrderResponse(ac, order), nil } diff --git a/builtin/logical/pki/path_acme_test.go b/builtin/logical/pki/path_acme_test.go index 5a4f26357..5d2796cc9 100644 --- a/builtin/logical/pki/path_acme_test.go +++ b/builtin/logical/pki/path_acme_test.go @@ -21,20 +21,18 @@ import ( "testing" "time" - "github.com/hashicorp/vault/sdk/helper/jsonutil" - - "github.com/go-test/deep" - - "github.com/hashicorp/go-cleanhttp" "golang.org/x/crypto/acme" "golang.org/x/net/http2" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/constants" vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/jsonutil" + "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" - "github.com/hashicorp/vault/sdk/logical" - + "github.com/go-test/deep" + "github.com/hashicorp/go-cleanhttp" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2/json" ) @@ -572,28 +570,66 @@ func TestAcmeConfigChecksPublicAcmeEnv(t *testing.T) { func setupAcmeBackend(t *testing.T) (*vault.TestCluster, *api.Client, string) { cluster, client := setupTestPkiCluster(t) - // Setting templated AIAs should succeed. - pathConfig := client.Address() + "/v1/pki" + return setupAcmeBackendOnClusterAtPath(t, cluster, client, "pki") +} - _, err := client.Logical().WriteWithContext(context.Background(), "pki/config/cluster", map[string]interface{}{ +func setupAcmeBackendOnClusterAtPath(t *testing.T, cluster *vault.TestCluster, client *api.Client, mount string) (*vault.TestCluster, *api.Client, string) { + mount = strings.Trim(mount, "/") + + // Setting templated AIAs should succeed. + pathConfig := client.Address() + "/v1/" + mount + + namespace := "" + mountName := mount + if mount != "pki" { + if strings.Contains(mount, "/") && constants.IsEnterprise { + ns_pieces := strings.Split(mount, "/") + c := len(ns_pieces) + // mount is c-1 + ns_name := ns_pieces[c-2] + if len(ns_pieces) > 2 { + // Parent's namespaces + parent := strings.Join(ns_pieces[0:c-2], "/") + _, err := client.WithNamespace(parent).Logical().Write("/sys/namespaces/"+ns_name, nil) + require.NoError(t, err, "failed to create nested namespaces "+parent+" -> "+ns_name) + } else { + _, err := client.Logical().Write("/sys/namespaces/"+ns_name, nil) + require.NoError(t, err, "failed to create nested namespace "+ns_name) + } + namespace = strings.Join(ns_pieces[0:c-1], "/") + mountName = ns_pieces[c-1] + } + + err := client.WithNamespace(namespace).Sys().Mount(mountName, &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }) + require.NoError(t, err, "failed to mount new PKI instance at "+mount) + } + + _, err := client.Logical().WriteWithContext(context.Background(), mount+"/config/cluster", map[string]interface{}{ "path": pathConfig, - "aia_path": "http://localhost:8200/cdn/pki", + "aia_path": "http://localhost:8200/cdn/" + mount, }) require.NoError(t, err) - _, err = client.Logical().WriteWithContext(context.Background(), "pki/config/acme", map[string]interface{}{ - "enabled": true, + _, err = client.Logical().WriteWithContext(context.Background(), mount+"/config/acme", map[string]interface{}{ + "enabled": true, + "eab_policy": "not-required", }) require.NoError(t, err) // Allow certain headers to pass through for ACME support - _, err = client.Logical().WriteWithContext(context.Background(), "sys/mounts/pki/tune", map[string]interface{}{ + _, err = client.WithNamespace(namespace).Logical().WriteWithContext(context.Background(), "sys/mounts/"+mountName+"/tune", map[string]interface{}{ "allowed_response_headers": []string{"Last-Modified", "Replay-Nonce", "Link", "Location"}, "max_lease_ttl": "920000h", }) require.NoError(t, err, "failed tuning mount response headers") - resp, err := client.Logical().WriteWithContext(context.Background(), "/pki/issuers/generate/root/internal", + resp, err := client.Logical().WriteWithContext(context.Background(), mount+"/issuers/generate/root/internal", map[string]interface{}{ "issuer_name": "root-ca", "key_name": "root-key", @@ -604,7 +640,7 @@ func setupAcmeBackend(t *testing.T) (*vault.TestCluster, *api.Client, string) { }) require.NoError(t, err, "failed creating root CA") - resp, err = client.Logical().WriteWithContext(context.Background(), "/pki/issuers/generate/intermediate/internal", + resp, err = client.Logical().WriteWithContext(context.Background(), mount+"/issuers/generate/intermediate/internal", map[string]interface{}{ "key_name": "int-key", "key_type": "ec", @@ -614,7 +650,7 @@ func setupAcmeBackend(t *testing.T) (*vault.TestCluster, *api.Client, string) { intermediateCSR := resp.Data["csr"].(string) // Sign the intermediate CSR using /pki - resp, err = client.Logical().Write("pki/issuer/root-ca/sign-intermediate", map[string]interface{}{ + resp, err = client.Logical().Write(mount+"/issuer/root-ca/sign-intermediate", map[string]interface{}{ "csr": intermediateCSR, "ttl": "720h", "max_ttl": "7200h", @@ -623,7 +659,7 @@ func setupAcmeBackend(t *testing.T) (*vault.TestCluster, *api.Client, string) { intermediateCertPEM := resp.Data["certificate"].(string) // Configure the intermediate cert as the CA in /pki2 - resp, err = client.Logical().Write("/pki/issuers/import/cert", map[string]interface{}{ + resp, err = client.Logical().Write(mount+"/issuers/import/cert", map[string]interface{}{ "pem_bundle": intermediateCertPEM, }) require.NoError(t, err, "failed importing intermediary cert") @@ -631,17 +667,17 @@ func setupAcmeBackend(t *testing.T) (*vault.TestCluster, *api.Client, string) { require.Len(t, importedIssuersRaw, 1) intCaUuid := importedIssuersRaw[0].(string) - _, err = client.Logical().Write("/pki/issuer/"+intCaUuid, map[string]interface{}{ + _, err = client.Logical().Write(mount+"/issuer/"+intCaUuid, map[string]interface{}{ "issuer_name": "int-ca", }) require.NoError(t, err, "failed updating issuer name") - _, err = client.Logical().Write("/pki/config/issuers", map[string]interface{}{ + _, err = client.Logical().Write(mount+"/config/issuers", map[string]interface{}{ "default": "int-ca", }) require.NoError(t, err, "failed updating default issuer") - _, err = client.Logical().Write("/pki/roles/test-role", map[string]interface{}{ + _, err = client.Logical().Write(mount+"/roles/test-role", map[string]interface{}{ "ttl_duration": "365h", "max_ttl_duration": "720h", "key_type": "any", @@ -651,7 +687,7 @@ func setupAcmeBackend(t *testing.T) (*vault.TestCluster, *api.Client, string) { }) require.NoError(t, err, "failed creating role test-role") - _, err = client.Logical().Write("/pki/roles/acme", map[string]interface{}{ + _, err = client.Logical().Write(mount+"/roles/acme", map[string]interface{}{ "ttl_duration": "365h", "max_ttl_duration": "720h", "key_type": "any", diff --git a/sdk/logical/acme_billing.go b/sdk/logical/acme_billing.go new file mode 100644 index 000000000..6e4f6ef39 --- /dev/null +++ b/sdk/logical/acme_billing.go @@ -0,0 +1,10 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package logical + +import "context" + +type ACMEBillingSystemView interface { + CreateActivityCountEventForIdentifiers(ctx context.Context, identifiers []string) error +} diff --git a/vault/acme_billing_system_view.go b/vault/acme_billing_system_view.go new file mode 100644 index 000000000..cf87833e4 --- /dev/null +++ b/vault/acme_billing_system_view.go @@ -0,0 +1,60 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package vault + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "sort" + "strings" + "time" + + "github.com/hashicorp/vault/sdk/logical" +) + +// See comment in command/format.go +const hopeDelim = "♨" + +type acmeBillingSystemViewImpl struct { + extendedSystemView + logical.ManagedKeySystemView + + core *Core + entry *MountEntry +} + +var _ logical.ACMEBillingSystemView = (*acmeBillingSystemViewImpl)(nil) + +func (c *Core) NewAcmeBillingSystemView(sysView interface{}, managed logical.ManagedKeySystemView) *acmeBillingSystemViewImpl { + es := sysView.(extendedSystemViewImpl) + des := es.dynamicSystemView + + return &acmeBillingSystemViewImpl{ + extendedSystemView: es, + ManagedKeySystemView: managed, + core: c, + entry: des.mountEntry, + } +} + +func (a *acmeBillingSystemViewImpl) CreateActivityCountEventForIdentifiers(ctx context.Context, identifiers []string) error { + // Fake our clientID from the identifiers, but ensure it is + // independent of ordering. + // + // TODO: Because of prefixing currently handled by AddActivityToFragment, + // we do not need to ensure it is globally unique. + sort.Strings(identifiers) + joinedIdentifiers := "[" + strings.Join(identifiers, "]"+hopeDelim+"[") + "]" + identifiersHash := sha256.Sum256([]byte(joinedIdentifiers)) + clientID := base64.RawURLEncoding.EncodeToString(identifiersHash[:]) + + // Log so users can correlate ACME requests to client count tokens. + activityType := "acme" + a.core.activityLog.logger.Debug(fmt.Sprintf("Handling ACME client count event for [%v] -> %v", identifiers, clientID)) + a.core.activityLog.AddActivityToFragment(clientID, a.entry.NamespaceID, time.Now().Unix(), activityType, a.entry.Accessor) + + return nil +} diff --git a/vault/activity_log.go b/vault/activity_log.go index cf87287fa..73ddb33b7 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -78,6 +78,12 @@ const ( // all fragments and segments no longer storing token counts in the directtokens // storage path. trackedTWESegmentPeriod = 35 * 24 + + // Known types of activity events; there's presently two internal event + // types (tokens/clients with and without entities), but we're beginning + // to support additional buckets for e.g., ACME requests. + nonEntityTokenActivityType = "non-entity-token" + entityActivityType = "entity" ) type segmentInfo struct { @@ -1401,12 +1407,36 @@ func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, t // AddClientToFragment checks a client ID for uniqueness and // if not already present, adds it to the current fragment. -// The timestamp is a Unix timestamp *without* nanoseconds, as that -// is what token.CreationTime uses. +// +// See note below about AddActivityToFragment. func (a *ActivityLog) AddClientToFragment(clientID string, namespaceID string, timestamp int64, isTWE bool, mountAccessor string) { + // TWE == token without entity + if isTWE { + a.AddActivityToFragment(clientID, namespaceID, timestamp, nonEntityTokenActivityType, mountAccessor) + return + } + + a.AddActivityToFragment(clientID, namespaceID, timestamp, entityActivityType, mountAccessor) +} + +// AddActivityToFragment adds a client count event of any type to +// add to the current fragment. ClientIDs must be unique across +// all types; if not already present, we will add it to the current +// fragment. The timestamp is a Unix timestamp *without* nanoseconds, +// as that is what token.CreationTime uses. +func (a *ActivityLog) AddActivityToFragment(clientID string, namespaceID string, timestamp int64, activityType string, mountAccessor string) { // Check whether entity ID already recorded var present bool + // TODO: This hack enables separate tracking of events without handling + // separate storage buckets for counting these event types. Consider + // removing if the event type is otherwise clear; notably though, this + // does help ensure clientID uniqueness across different types of tokens, + // assuming it does not break any other downstream systems. + if activityType != nonEntityTokenActivityType && activityType != entityActivityType { + clientID = activityType + "." + clientID + } + a.fragmentLock.RLock() if a.enabled { _, present = a.partialMonthClientTracker[clientID] @@ -1440,7 +1470,10 @@ func (a *ActivityLog) AddClientToFragment(clientID string, namespaceID string, t // Track whether the clientID corresponds to a token without an entity or not. // This field is backward compatible, as the default is 0, so records created // from pre-1.9 activityLog code will automatically be marked as having an entity. - if isTWE { + if activityType != entityActivityType { + // TODO: This part needs to be modified potentially for separate + // storage buckets of custom event types. Consider setting the above + // condition to activityType == nonEntityTokenEventType in the future. clientRecord.NonEntity = true } diff --git a/vault/mount_util.go b/vault/mount_util.go index c0ebbdbe4..d31c16e68 100644 --- a/vault/mount_util.go +++ b/vault/mount_util.go @@ -56,11 +56,13 @@ func verifyNamespace(*Core, *namespace.Namespace, *MountEntry) error { return ni // mount-specific entries; because this should be called when setting // up a mountEntry, it doesn't check to ensure that me is not nil func (c *Core) mountEntrySysView(entry *MountEntry) extendedSystemView { - return extendedSystemViewImpl{ + esi := extendedSystemViewImpl{ dynamicSystemView{ core: c, mountEntry: entry, perfStandby: c.perfStandby, }, } + + return c.NewAcmeBillingSystemView(esi, nil /* managed keys system view */) }