diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go
index 0bdb0a100..f2b700ed2 100644
--- a/builtin/logical/pki/backend.go
+++ b/builtin/logical/pki/backend.go
@@ -115,6 +115,10 @@ func Backend(conf *logical.BackendConfig) *backend {
             legacyCertBundleBackupPath,
             keyPrefix,
         },
+
+        WriteForwardedStorage: []string{
+            crossRevocationPath,
+        },
     },
 
     Paths: []*framework.Path{
@@ -138,6 +142,7 @@ func Backend(conf *logical.BackendConfig) *backend {
         pathRevoke(&b),
         pathRevokeWithKey(&b),
         pathListCertsRevoked(&b),
+        pathListCertsRevocationQueue(&b),
         pathTidy(&b),
         pathTidyCancel(&b),
         pathTidyStatus(&b),
@@ -430,6 +435,9 @@ func (b *backend) updatePkiStorageVersion(ctx context.Context, grabIssuersLock b
 }
 
 func (b *backend) invalidate(ctx context.Context, key string) {
+    isNotPerfPrimary := b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
+        (!b.System().LocalMount() && b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary))
+
     switch {
     case strings.HasPrefix(key, legacyMigrationBundleLogKey):
         // This is for a secondary cluster to pick up that the migration has completed
@@ -460,6 +468,25 @@ func (b *backend) invalidate(ctx context.Context, key string) {
         b.crlBuilder.markConfigDirty()
     case key == storageIssuerConfig:
         b.crlBuilder.invalidateCRLBuildTime()
+    case strings.HasPrefix(key, crossRevocationPrefix):
+        split := strings.Split(key, "/")
+
+        if !strings.HasSuffix(key, "/confirmed") {
+            cluster := split[len(split)-2]
+            serial := split[len(split)-1]
+            b.crlBuilder.addCertForRevocationCheck(cluster, serial)
+        } else {
+            if len(split) >= 3 {
+                cluster := split[len(split)-3]
+                serial := split[len(split)-2]
+                // Only process confirmations on the perf primary.
+                if !isNotPerfPrimary {
+                    b.crlBuilder.addCertForRevocationRemoval(cluster, serial)
+                }
+            }
+        }
+
+        b.Logger().Debug("got replicated cross-cluster revocation: " + key)
     }
 }
 
@@ -479,6 +506,11 @@ func (b *backend) periodicFunc(ctx context.Context, request *logical.Request) er
         return nil
     }
 
+    // First handle any global revocation queue entries.
+    if err := b.crlBuilder.processRevocationQueue(sc); err != nil {
+        return err
+    }
+
     // Check if we're set to auto rebuild and a CRL is set to expire.
     if err := b.crlBuilder.checkForAutoRebuild(sc); err != nil {
         return err
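The new `invalidate` branch distinguishes two key shapes: a request at `cross-revocation-queue/<cluster>/<serial>` and a confirmation at `cross-revocation-queue/<cluster>/<serial>/confirmed`. A minimal standalone sketch of that parsing logic (the cluster name and serial below are hypothetical, not values from this change):

```go
package main

import (
	"fmt"
	"strings"
)

// parseCrossRevocationKey mirrors the new branch in invalidate above:
// request keys look like cross-revocation-queue/<cluster>/<serial>;
// confirmations append a trailing /confirmed component.
func parseCrossRevocationKey(key string) (cluster, serial string, confirmed bool) {
	split := strings.Split(key, "/")
	if !strings.HasSuffix(key, "/confirmed") {
		return split[len(split)-2], split[len(split)-1], false
	}
	if len(split) >= 3 {
		return split[len(split)-3], split[len(split)-2], true
	}
	return "", "", true
}

func main() {
	for _, key := range []string{
		"cross-revocation-queue/cluster-a/17-de-ad-be-ef",           // hypothetical request
		"cross-revocation-queue/cluster-a/17-de-ad-be-ef/confirmed", // hypothetical confirmation
	} {
		cluster, serial, confirmed := parseCrossRevocationKey(key)
		fmt.Printf("cluster=%s serial=%s confirmed=%v\n", cluster, serial, confirmed)
	}
}
```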
diff --git a/builtin/logical/pki/crl_util.go b/builtin/logical/pki/crl_util.go
index 7a293a4a2..a0157f585 100644
--- a/builtin/logical/pki/crl_util.go
+++ b/builtin/logical/pki/crl_util.go
@@ -13,12 +13,15 @@ import (
     atomic2 "go.uber.org/atomic"
 
     "github.com/hashicorp/vault/sdk/helper/certutil"
+    "github.com/hashicorp/vault/sdk/helper/consts"
     "github.com/hashicorp/vault/sdk/helper/errutil"
     "github.com/hashicorp/vault/sdk/logical"
 )
 
 const (
     revokedPath                   = "revoked/"
+    crossRevocationPrefix         = "cross-revocation-queue/"
+    crossRevocationPath           = crossRevocationPrefix + "{{clusterId}}/"
     deltaWALLastBuildSerialName   = "last-build-serial"
     deltaWALLastRevokedSerialName = "last-revoked-serial"
     localDeltaWALPath             = "delta-wal/"
@@ -33,6 +36,20 @@ type revocationInfo struct {
     CertificateIssuer issuerID  `json:"issuer_id"`
 }
 
+type revocationRequest struct {
+    RequestedAt time.Time `json:"requested_at"`
+}
+
+type revocationConfirmed struct {
+    RevokedAt string `json:"revoked_at"`
+    Source    string `json:"source"`
+}
+
+type revocationQueueEntry struct {
+    Cluster string
+    Serial  string
+}
+
 type (
     // Placeholder in case of migrations needing more data. Currently
     // we use the path name to store the serial number that was revoked.
@@ -76,6 +93,12 @@ type crlBuilder struct {
     // Whether to invalidate our LastModifiedTime due to write on the
     // global issuance config.
     invalidate *atomic2.Bool
+
+    // Global revocation queue entries get accepted by the invalidate func
+    // and passed to the crlBuilder for processing.
+    haveInitializedQueue bool
+    revQueue             *revocationQueue
+    removalQueue         *revocationQueue
 }
 
 const (
@@ -94,6 +117,8 @@ func newCRLBuilder(canRebuild bool) *crlBuilder {
         dirty:      atomic2.NewBool(true),
         config:     defaultCrlConfig,
         invalidate: atomic2.NewBool(false),
+        revQueue:     newRevocationQueue(),
+        removalQueue: newRevocationQueue(),
     }
 }
 
@@ -422,6 +447,187 @@ func (cb *crlBuilder) rebuildDeltaCRLsHoldingLock(sc *storageContext, forceNew b
     return buildAnyCRLs(sc, forceNew, true /* building delta */)
 }
 
+func (cb *crlBuilder) addCertForRevocationCheck(cluster, serial string) {
+    entry := &revocationQueueEntry{
+        Cluster: cluster,
+        Serial:  serial,
+    }
+    cb.revQueue.Add(entry)
+}
+
+func (cb *crlBuilder) addCertForRevocationRemoval(cluster, serial string) {
+    entry := &revocationQueueEntry{
+        Cluster: cluster,
+        Serial:  serial,
+    }
+    cb.removalQueue.Add(entry)
+}
+
+func (cb *crlBuilder) maybeGatherQueueForFirstProcess(sc *storageContext, isNotPerfPrimary bool) error {
+    // Assumes the caller is holding the revocation storage lock.
+    if cb.haveInitializedQueue {
+        return nil
+    }
+
+    sc.Backend.Logger().Debug("gathering existing revocation requests for first processing")
+
+    clusters, err := sc.Storage.List(sc.Context, crossRevocationPrefix)
+    if err != nil {
+        return fmt.Errorf("failed to list cross-cluster revocation queue participating clusters: %w", err)
+    }
+
+    sc.Backend.Logger().Debug(fmt.Sprintf("found %v clusters: %v", len(clusters), clusters))
+
+    for cIndex, cluster := range clusters {
+        cluster = cluster[0 : len(cluster)-1]
+        cPath := crossRevocationPrefix + cluster + "/"
+        serials, err := sc.Storage.List(sc.Context, cPath)
+        if err != nil {
+            return fmt.Errorf("failed to list cross-cluster revocation queue entries for cluster %v (%v): %w", cluster, cIndex, err)
+        }
+
+        sc.Backend.Logger().Debug(fmt.Sprintf("found %v serials for cluster %v: %v", len(serials), cluster, serials))
+
+        for _, serial := range serials {
+            if serial[len(serial)-1] == '/' {
+                serial = serial[0 : len(serial)-1]
+            }
+
+            ePath := cPath + serial
+            eConfirmPath := ePath + "/confirmed"
+            removalEntry, err := sc.Storage.Get(sc.Context, eConfirmPath)
+
+            entry := &revocationQueueEntry{
+                Cluster: cluster,
+                Serial:  serial,
+            }
+
+            // No removal entry yet; add to the regular queue. Otherwise,
+            // slate it for removal if we're the perf primary.
+            if err != nil || removalEntry == nil {
+                cb.revQueue.Add(entry)
+            } else if !isNotPerfPrimary {
+                cb.removalQueue.Add(entry)
+            } else {
+                sc.Backend.Logger().Debug(fmt.Sprintf("ignoring confirmed revoked serial %v: %v vs %v", serial, err, removalEntry))
+            }
+
+            // Overwrite the error; we don't really care about its contents
+            // at this step.
+            err = nil
+        }
+    }
+
+    return nil
+}
+
+func (cb *crlBuilder) processRevocationQueue(sc *storageContext) error {
+    sc.Backend.Logger().Debug("starting to process revocation requests")
+
+    isNotPerfPrimary := sc.Backend.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
+        (!sc.Backend.System().LocalMount() && sc.Backend.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary))
+
+    // Before revoking certificates, we need to hold the lock for certificate
+    // storage. This prevents any parallel revocations and prevents us from
+    // revoking the same certificate from multiple places. We do this before
+    // grabbing the contents of the revocation queues themselves, to ensure
+    // we interleave well with other invocations of this function and avoid
+    // duplicate work.
+    sc.Backend.revokeStorageLock.Lock()
+    defer sc.Backend.revokeStorageLock.Unlock()
+
+    if err := cb.maybeGatherQueueForFirstProcess(sc, isNotPerfPrimary); err != nil {
+        return fmt.Errorf("failed to gather first queue: %w", err)
+    }
+
+    revQueue := cb.revQueue.Iterate()
+    removalQueue := cb.removalQueue.Iterate()
+
+    sc.Backend.Logger().Debug(fmt.Sprintf("gathered %v revocations and %v confirmation entries", len(revQueue), len(removalQueue)))
+
+    for _, req := range revQueue {
+        sc.Backend.Logger().Debug(fmt.Sprintf("handling revocation request: %v", req))
+        rPath := crossRevocationPrefix + req.Cluster + "/" + req.Serial
+        entry, err := sc.Storage.Get(sc.Context, rPath)
+        if err != nil {
+            return fmt.Errorf("failed to read cross-cluster revocation queue entry: %w", err)
+        }
+        if entry == nil {
+            // Skip this entry; it was likely an incorrect invalidation
+            // caused by the primary cluster removing the confirmation.
+            cb.revQueue.Remove(req)
+            continue
+        }
+
+        resp, err := tryRevokeCertBySerial(sc, req.Serial)
+        sc.Backend.Logger().Debug(fmt.Sprintf("checked local revocation entry: %v / %v", resp, err))
+        if err == nil && resp != nil && !resp.IsError() && resp.Data != nil && resp.Data["state"].(string) == "revoked" {
+            if isNotPerfPrimary {
+                // Write a revocation queue removal entry.
+                confirmed := revocationConfirmed{
+                    RevokedAt: resp.Data["revocation_time_rfc3339"].(string),
+                    Source:    req.Cluster,
+                }
+                path := crossRevocationPath + req.Serial + "/confirmed"
+                confirmedEntry, err := logical.StorageEntryJSON(path, confirmed)
+                if err != nil {
+                    return fmt.Errorf("failed to create storage entry for cross-cluster revocation confirmed response: %w", err)
+                }
+
+                if err := sc.Storage.Put(sc.Context, confirmedEntry); err != nil {
+                    return fmt.Errorf("error persisting cross-cluster revocation confirmation: %w\nThis may occur when the active node of the primary performance replication cluster is unavailable.", err)
+                }
+            } else {
+                // Since we're the active node of the primary cluster, go ahead
+                // and just remove it.
+                path := crossRevocationPrefix + req.Cluster + "/" + req.Serial
+                if err := sc.Storage.Delete(sc.Context, path); err != nil {
+                    return fmt.Errorf("failed to delete processed revocation request: %w", err)
+                }
+            }
+        } else if err != nil {
+            // Because we fake being from a lease, we get the guarantee that
+            // err == nil == resp if the cert was already revoked; this means
+            // this err should actually be fatal.
+            return err
+        }
+        cb.revQueue.Remove(req)
+    }
+
+ return err + } + cb.revQueue.Remove(req) + } + + if isNotPerfPrimary { + sc.Backend.Logger().Debug(fmt.Sprintf("not on perf primary so done; ignoring any revocation confirmations")) + cb.removalQueue.RemoveAll() + cb.haveInitializedQueue = true + return nil + } + + clusters, err := sc.Storage.List(sc.Context, crossRevocationPrefix) + if err != nil { + return err + } + + for _, entry := range removalQueue { + sc.Backend.Logger().Debug(fmt.Sprintf("handling revocation confirmation: %v", entry)) + // First remove the revocation request. + for cIndex, cluster := range clusters { + eEntry := crossRevocationPrefix + cluster + entry.Serial + if err := sc.Storage.Delete(sc.Context, eEntry); err != nil { + return fmt.Errorf("failed to delete potential cross-cluster revocation entry for cluster %v (%v) and serial %v: %w", cluster, cIndex, entry.Serial, err) + } + } + + // Then remove the confirmation. + if err := sc.Storage.Delete(sc.Context, crossRevocationPrefix+entry.Cluster+"/"+entry.Serial+"/confirmed"); err != nil { + return fmt.Errorf("failed to delete cross-cluster revocation confirmation entry for cluster %v and serial %v: %w", entry.Cluster, entry.Serial, err) + } + + cb.removalQueue.Remove(entry) + } + + cb.haveInitializedQueue = true + + return nil +} + // Helper function to fetch a map of issuerID->parsed cert for revocation // usage. Unlike other paths, this needs to handle the legacy bundle // more gracefully than rejecting it outright. @@ -466,6 +672,31 @@ func fetchIssuerMapForRevocationChecking(sc *storageContext) (map[issuerID]*x509 return issuerIDCertMap, nil } +// Revoke a certificate from a given serial number if it is present in local +// storage. +func tryRevokeCertBySerial(sc *storageContext, serial string) (*logical.Response, error) { + certEntry, err := fetchCertBySerial(sc, "certs/", serial) + if err != nil { + switch err.(type) { + case errutil.UserError: + return logical.ErrorResponse(err.Error()), nil + default: + return nil, err + } + } + + if certEntry == nil { + return nil, nil + } + + cert, err := x509.ParseCertificate(certEntry.Value) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %w", err) + } + + return revokeCert(sc, cert) +} + // Revokes a cert, and tries to be smart about error recovery func revokeCert(sc *storageContext, cert *x509.Certificate) (*logical.Response, error) { // As this backend is self-contained and this function does not hook into @@ -517,6 +748,7 @@ func revokeCert(sc *storageContext, cert *x509.Certificate) (*logical.Response, resp := &logical.Response{ Data: map[string]interface{}{ "revocation_time": revInfo.RevocationTime, + "state": "revoked", }, } if !revInfo.RevocationTimeUTC.IsZero() { @@ -618,6 +850,7 @@ func revokeCert(sc *storageContext, cert *x509.Certificate) (*logical.Response, Data: map[string]interface{}{ "revocation_time": revInfo.RevocationTime, "revocation_time_rfc3339": revInfo.RevocationTimeUTC.Format(time.RFC3339Nano), + "state": "revoked", }, }, nil } diff --git a/builtin/logical/pki/fields.go b/builtin/logical/pki/fields.go index d42240e57..1f3022ab4 100644 --- a/builtin/logical/pki/fields.go +++ b/builtin/logical/pki/fields.go @@ -508,5 +508,23 @@ greater period of time. By default this is zero seconds.`, Default: "0s", } + fields["tidy_revocation_queue"] = &framework.FieldSchema{ + Type: framework.TypeBool, + Description: `Set to true to remove stale revocation queue entries +that haven't been confirmed by any active cluster. 
diff --git a/builtin/logical/pki/fields.go b/builtin/logical/pki/fields.go
index d42240e57..1f3022ab4 100644
--- a/builtin/logical/pki/fields.go
+++ b/builtin/logical/pki/fields.go
@@ -508,5 +508,23 @@ greater period of time. By default this is zero seconds.`,
         Default: "0s",
     }
 
+    fields["tidy_revocation_queue"] = &framework.FieldSchema{
+        Type: framework.TypeBool,
+        Description: `Set to true to remove stale revocation queue entries
+that haven't been confirmed by any active cluster. Only runs on the
+active primary node.`,
+        Default: defaultTidyConfig.RevocationQueue,
+    }
+
+    fields["revocation_queue_safety_buffer"] = &framework.FieldSchema{
+        Type: framework.TypeDurationSecond,
+        Description: `The amount of time that must pass from the
+cross-cluster revocation request being initiated to when it will be
+slated for removal. Setting this too low may remove valid revocation
+requests before the owning cluster has a chance to process them,
+especially if the cluster is offline.`,
+        Default: int(defaultTidyConfig.QueueSafetyBuffer / time.Second), // TypeDurationSecond currently requires defaults to be int
+    }
+
     return fields
 }
diff --git a/builtin/logical/pki/path_config_crl.go b/builtin/logical/pki/path_config_crl.go
index ed73ce5c4..b9cfb44cc 100644
--- a/builtin/logical/pki/path_config_crl.go
+++ b/builtin/logical/pki/path_config_crl.go
@@ -5,6 +5,7 @@ import (
     "fmt"
     "time"
 
+    "github.com/hashicorp/vault/helper/constants"
     "github.com/hashicorp/vault/sdk/framework"
     "github.com/hashicorp/vault/sdk/helper/errutil"
     "github.com/hashicorp/vault/sdk/logical"
@@ -23,6 +24,7 @@ type crlConfig struct {
     OcspExpiry           string `json:"ocsp_expiry"`
     EnableDelta          bool   `json:"enable_delta"`
     DeltaRebuildInterval string `json:"delta_rebuild_interval"`
+    UseGlobalQueue       bool   `json:"cross_cluster_revocation"`
 }
 
 // Implicit default values for the config if it does not exist.
@@ -36,6 +38,7 @@ var defaultCrlConfig = crlConfig{
     AutoRebuildGracePeriod: "12h",
     EnableDelta:            false,
     DeltaRebuildInterval:   "15m",
+    UseGlobalQueue:         false,
 }
 
 func pathConfigCRL(b *backend) *framework.Path {
@@ -80,6 +83,11 @@ the NextUpdate field); defaults to 12 hours`,
             Description: `The time between delta CRL rebuilds if a new revocation has occurred. Must be shorter than the CRL expiry. Defaults to 15m.`,
             Default:     "15m",
         },
+        "cross_cluster_revocation": {
+            Type: framework.TypeBool,
+            Description: `Whether to enable a global, cross-cluster revocation queue.
+Must be used with auto_rebuild=true.`,
+        },
     },
 
     Operations: map[logical.Operation]framework.OperationHandler{
@@ -116,6 +124,7 @@ func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, _ *fram
             "auto_rebuild_grace_period": config.AutoRebuildGracePeriod,
             "enable_delta":              config.EnableDelta,
             "delta_rebuild_interval":    config.DeltaRebuildInterval,
+            "cross_cluster_revocation":  config.UseGlobalQueue,
         },
     }, nil
 }
@@ -182,6 +191,10 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
         config.DeltaRebuildInterval = deltaRebuildInterval
     }
 
+    if useGlobalQueue, ok := d.GetOk("cross_cluster_revocation"); ok {
+        config.UseGlobalQueue = useGlobalQueue.(bool)
+    }
+
     expiry, _ := time.ParseDuration(config.Expiry)
     if config.AutoRebuild {
         gracePeriod, _ := time.ParseDuration(config.AutoRebuildGracePeriod)
@@ -197,8 +210,18 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
         }
     }
 
-    if config.EnableDelta && !config.AutoRebuild {
-        return logical.ErrorResponse("Delta CRLs cannot be enabled when auto rebuilding is disabled as the complete CRL is always regenerated!"), nil
+    if !config.AutoRebuild {
+        if config.EnableDelta {
+            return logical.ErrorResponse("Delta CRLs cannot be enabled when auto rebuilding is disabled as the complete CRL is always regenerated!"), nil
+        }
+
+        if config.UseGlobalQueue {
+            return logical.ErrorResponse("Global, cross-cluster revocation queue cannot be enabled when auto rebuilding is disabled as the local cluster may not have the certificate entry!"), nil
+        }
+    }
+
+    if !constants.IsEnterprise && config.UseGlobalQueue {
+        return logical.ErrorResponse("Global, cross-cluster revocation queue can only be enabled on Vault Enterprise."), nil
     }
 
     entry, err := logical.StorageEntryJSON("config/crl", config)
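Enabling the queue is a single write to `config/crl`; per the checks above it requires `auto_rebuild=true` and an Enterprise build. A hedged sketch using the official Go API client (assumes `VAULT_ADDR`/`VAULT_TOKEN` in the environment and a PKI mount at `pki`):

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// cross_cluster_revocation must be paired with auto_rebuild=true,
	// otherwise pathCRLWrite above rejects the configuration.
	resp, err := client.Logical().Write("pki/config/crl", map[string]interface{}{
		"auto_rebuild":             true,
		"cross_cluster_revocation": true,
	})
	if err != nil {
		log.Fatal(err) // expect an error on non-Enterprise builds
	}
	fmt.Println(resp)
}
```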
diff --git a/builtin/logical/pki/path_revoke.go b/builtin/logical/pki/path_revoke.go
index a4c971521..bedf83665 100644
--- a/builtin/logical/pki/path_revoke.go
+++ b/builtin/logical/pki/path_revoke.go
@@ -9,6 +9,7 @@ import (
     "crypto/x509"
     "encoding/pem"
     "fmt"
+    "time"
 
     "github.com/hashicorp/vault/sdk/framework"
     "github.com/hashicorp/vault/sdk/helper/certutil"
@@ -31,6 +32,21 @@ func pathListCertsRevoked(b *backend) *framework.Path {
     }
 }
 
+func pathListCertsRevocationQueue(b *backend) *framework.Path {
+    return &framework.Path{
+        Pattern: "certs/revocation-queue/?$",
+
+        Operations: map[logical.Operation]framework.OperationHandler{
+            logical.ListOperation: &framework.PathOperation{
+                Callback: b.pathListRevocationQueueHandler,
+            },
+        },
+
+        HelpSynopsis:    pathListRevocationQueueHelpSyn,
+        HelpDescription: pathListRevocationQueueHelpDesc,
+    }
+}
+
 func pathRevoke(b *backend) *framework.Path {
     return &framework.Path{
         Pattern: `revoke`,
@@ -315,9 +331,47 @@ func (b *backend) pathRevokeWriteHandleKey(req *logical.Request, certReference *
     return nil
 }
 
+func (b *backend) maybeRevokeCrossCluster(ctx context.Context, sc *storageContext, serial string) (*logical.Response, error) {
+    config, err := b.crlBuilder.getConfigWithUpdate(sc)
+    if err != nil {
+        return nil, err
+    }
+
+    if !config.UseGlobalQueue {
+        return logical.ErrorResponse(fmt.Sprintf("certificate with serial %s not found or was already revoked", serial)), nil
+    }
+
+    // Here, we have to use the global revocation queue as the cert
+    // was not found on this current cluster.
+    currTime := time.Now()
+    nSerial := normalizeSerial(serial)
+    queueReq := revocationRequest{
+        RequestedAt: currTime,
+    }
+    path := crossRevocationPath + nSerial
+
+    reqEntry, err := logical.StorageEntryJSON(path, queueReq)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create storage entry for cross-cluster revocation request: %w", err)
+    }
+
+    if err := sc.Storage.Put(ctx, reqEntry); err != nil {
+        return nil, fmt.Errorf("error persisting cross-cluster revocation request: %w\nThis may occur when the active node of the primary performance replication cluster is unavailable.", err)
+    }
+
+    resp := &logical.Response{
+        Data: map[string]interface{}{
+            "state": "pending",
+        },
+    }
+    resp.AddWarning("Revocation request was not found on this node. This request will be in a pending state until the PR cluster which issued this certificate sees the request and revokes the certificate. If no online cluster has this certificate, the request will eventually be removed without revoking any certificates.")
+    return resp, nil
+}
+
 func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, data *framework.FieldData, _ *roleEntry) (*logical.Response, error) {
     rawSerial, haveSerial := data.GetOk("serial_number")
     rawCertificate, haveCert := data.GetOk("certificate")
+    sc := b.makeStorageContext(ctx, req.Storage)
 
     if !haveSerial && !haveCert {
         return logical.ErrorResponse("The serial number or certificate to revoke must be provided."), nil
@@ -343,8 +397,6 @@ func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, dat
     var cert *x509.Certificate
     var serial string
 
-    sc := b.makeStorageContext(ctx, req.Storage)
-
     if haveCert {
         var err error
         serial, writeCert, cert, err = b.pathRevokeWriteHandleCertificate(ctx, req, rawCertificate.(string))
@@ -373,6 +425,9 @@ func (b *backend) pathRevokeWrite(ctx context.Context, req *logical.Request, dat
             if err != nil {
                 return nil, fmt.Errorf("error parsing certificate: %w", err)
             }
+        } else if keyPem == "" {
+            // Cross-cluster revocation can only happen without PoP currently.
+            return b.maybeRevokeCrossCluster(ctx, sc, serial)
         }
     }
 
@@ -477,6 +532,57 @@ func (b *backend) pathListRevokedCertsHandler(ctx context.Context, request *logi
     return logical.ListResponse(revokedCerts), nil
 }
 
+func (b *backend) pathListRevocationQueueHandler(ctx context.Context, request *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
+    var responseKeys []string
+    responseInfo := make(map[string]interface{})
+
+    sc := b.makeStorageContext(ctx, request.Storage)
+
+    clusters, err := sc.Storage.List(sc.Context, crossRevocationPrefix)
+    if err != nil {
+        return nil, fmt.Errorf("failed to list cross-cluster revocation queue participating clusters: %w", err)
+    }
+
+    for cIndex, cluster := range clusters {
+        cluster = cluster[0 : len(cluster)-1]
+        cPath := crossRevocationPrefix + cluster + "/"
+        serials, err := sc.Storage.List(sc.Context, cPath)
+        if err != nil {
+            return nil, fmt.Errorf("failed to list cross-cluster revocation queue entries for cluster %v (%v): %w", cluster, cIndex, err)
+        }
+
+        for _, serial := range serials {
+            // Always strip the slash out; it indicates the presence of
+            // a confirmed revocation, which we add to the main serial's
+            // entry.
+            hasSlash := serial[len(serial)-1] == '/'
+            if hasSlash {
+                serial = serial[0 : len(serial)-1]
+            }
+
+            var data map[string]interface{}
+            rawData, isPresent := responseInfo[serial]
+            if !isPresent {
+                data = map[string]interface{}{}
+                responseKeys = append(responseKeys, serial)
+            } else {
+                data = rawData.(map[string]interface{})
+            }
+
+            if hasSlash {
+                data["confirmed"] = true
+                data["confirmation_cluster"] = cluster
+            } else {
+                data["requesting_cluster"] = cluster
+            }
+
+            responseInfo[serial] = data
+        }
+    }
+
+    return logical.ListResponseWithInfo(responseKeys, responseInfo), nil
+}
+
 const pathRevokeHelpSyn = `
 Revoke a certificate by serial number or with explicit certificate.
@@ -512,3 +618,12 @@ List all revoked serial numbers within the local cluster
 const pathListRevokedHelpDesc = `
 Returns a list of serial numbers for revoked certificates in the local cluster.
 `
+
+const pathListRevocationQueueHelpSyn = `
+List all pending, cross-cluster revocations known to the local cluster.
+`
+
+const pathListRevocationQueueHelpDesc = `
+Returns a detailed list containing serial number, requesting cluster, and
+optionally a confirming cluster.
+`
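Because the handler uses `ListResponseWithInfo`, the serials come back under `keys` with per-serial details under `key_info`. A hedged consumer sketch with the Go API client (assumes env configuration and a `pki` mount):

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	resp, err := client.Logical().List("pki/certs/revocation-queue")
	if err != nil {
		log.Fatal(err)
	}
	if resp == nil {
		fmt.Println("revocation queue is empty")
		return
	}

	// keys holds the serials; key_info holds requesting_cluster,
	// confirmed, and confirmation_cluster per serial, as built above.
	keyInfo, _ := resp.Data["key_info"].(map[string]interface{})
	for _, raw := range resp.Data["keys"].([]interface{}) {
		serial := raw.(string)
		fmt.Printf("%s => %v\n", serial, keyInfo[serial])
	}
}
```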
diff --git a/builtin/logical/pki/path_tidy.go b/builtin/logical/pki/path_tidy.go
index 977934788..c2cb47a7a 100644
--- a/builtin/logical/pki/path_tidy.go
+++ b/builtin/logical/pki/path_tidy.go
@@ -30,6 +30,8 @@ type tidyConfig struct {
     SafetyBuffer       time.Duration `json:"safety_buffer"`
     IssuerSafetyBuffer time.Duration `json:"issuer_safety_buffer"`
     PauseDuration      time.Duration `json:"pause_duration"`
+    RevocationQueue    bool          `json:"tidy_revocation_queue"`
+    QueueSafetyBuffer  time.Duration `json:"revocation_queue_safety_buffer"`
 }
 
 var defaultTidyConfig = tidyConfig{
@@ -43,6 +45,8 @@ var defaultTidyConfig = tidyConfig{
     SafetyBuffer:       72 * time.Hour,
     IssuerSafetyBuffer: 365 * 24 * time.Hour,
     PauseDuration:      0 * time.Second,
+    RevocationQueue:    false,
+    QueueSafetyBuffer:  48 * time.Hour,
 }
 
 func pathTidy(b *backend) *framework.Path {
@@ -128,6 +132,8 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
     issuerSafetyBuffer := d.Get("issuer_safety_buffer").(int)
     pauseDurationStr := d.Get("pause_duration").(string)
     pauseDuration := 0 * time.Second
+    tidyRevocationQueue := d.Get("tidy_revocation_queue").(bool)
+    queueSafetyBuffer := d.Get("revocation_queue_safety_buffer").(int)
 
     if safetyBuffer < 1 {
         return logical.ErrorResponse("safety_buffer must be greater than zero"), nil
@@ -137,6 +143,10 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
         return logical.ErrorResponse("issuer_safety_buffer must be greater than zero"), nil
     }
 
+    if queueSafetyBuffer < 1 {
+        return logical.ErrorResponse("revocation_queue_safety_buffer must be greater than zero"), nil
+    }
+
     if pauseDurationStr != "" {
         var err error
         pauseDuration, err = time.ParseDuration(pauseDurationStr)
@@ -151,6 +161,7 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
 
     bufferDuration := time.Duration(safetyBuffer) * time.Second
     issuerBufferDuration := time.Duration(issuerSafetyBuffer) * time.Second
+    queueSafetyBufferDuration := time.Duration(queueSafetyBuffer) * time.Second
 
     // Manual run with constructed configuration.
     config := &tidyConfig{
@@ -164,6 +175,8 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
         SafetyBuffer:       bufferDuration,
         IssuerSafetyBuffer: issuerBufferDuration,
         PauseDuration:      pauseDuration,
+        RevocationQueue:    tidyRevocationQueue,
+        QueueSafetyBuffer:  queueSafetyBufferDuration,
     }
 
     if !atomic.CompareAndSwapUint32(b.tidyCASGuard, 0, 1) {
@@ -188,12 +201,20 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
     b.startTidyOperation(req, config)
 
     resp := &logical.Response{}
-    if !tidyCertStore && !tidyRevokedCerts && !tidyRevokedAssocs && !tidyExpiredIssuers && !tidyBackupBundle {
-        resp.AddWarning("No targets to tidy; specify tidy_cert_store=true or tidy_revoked_certs=true or tidy_revoked_cert_issuer_associations=true or tidy_expired_issuers=true or tidy_move_legacy_ca_bundle=true to start a tidy operation.")
+    if !tidyCertStore && !tidyRevokedCerts && !tidyRevokedAssocs && !tidyExpiredIssuers && !tidyBackupBundle && !tidyRevocationQueue {
+        resp.AddWarning("No targets to tidy; specify tidy_cert_store=true or tidy_revoked_certs=true or tidy_revoked_cert_issuer_associations=true or tidy_expired_issuers=true or tidy_move_legacy_ca_bundle=true or tidy_revocation_queue=true to start a tidy operation.")
     } else {
         resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
     }
 
+    if tidyRevocationQueue {
+        isNotPerfPrimary := b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
+            (!b.System().LocalMount() && b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary))
+        if isNotPerfPrimary {
+            resp.AddWarning("tidy_revocation_queue=true can only be set on the active node of the primary cluster unless a local mount is used; this option has been ignored.")
+        }
+    }
+
     return logical.RespondWithStatusCode(resp, req, http.StatusAccepted)
 }
 
@@ -227,18 +248,39 @@ func (b *backend) startTidyOperation(req *logical.Request, config *tidyConfig) {
             }
         }
 
+        // Check for cancel before continuing.
+        if atomic.CompareAndSwapUint32(b.tidyCancelCAS, 1, 0) {
+            return tidyCancelledError
+        }
+
         if config.ExpiredIssuers {
             if err := b.doTidyExpiredIssuers(ctx, req, logger, config); err != nil {
                 return err
             }
         }
 
+        // Check for cancel before continuing.
+        if atomic.CompareAndSwapUint32(b.tidyCancelCAS, 1, 0) {
+            return tidyCancelledError
+        }
+
         if config.BackupBundle {
             if err := b.doTidyMoveCABundle(ctx, req, logger, config); err != nil {
                 return err
             }
         }
 
+        // Check for cancel before continuing.
+        if atomic.CompareAndSwapUint32(b.tidyCancelCAS, 1, 0) {
+            return tidyCancelledError
+        }
+
+        if config.RevocationQueue {
+            if err := b.doTidyRevocationQueue(ctx, req, logger, config); err != nil {
+                return err
+            }
+        }
+
         return nil
     }
 
@@ -625,6 +667,70 @@ func (b *backend) doTidyMoveCABundle(ctx context.Context, req *logical.Request,
     }
 
     b.Logger().Info("legacy CA bundle successfully moved to backup location")
+    return nil
+}
+
+func (b *backend) doTidyRevocationQueue(ctx context.Context, req *logical.Request, logger hclog.Logger, config *tidyConfig) error {
+    if b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) ||
+        (!b.System().LocalMount() && b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) {
+        b.Logger().Debug("skipping cross-cluster revocation queue tidy as we're not on the primary or secondary with a local mount")
+        return nil
+    }
+
+    sc := b.makeStorageContext(ctx, req.Storage)
+    clusters, err := sc.Storage.List(sc.Context, crossRevocationPrefix)
+    if err != nil {
+        return fmt.Errorf("failed to list cross-cluster revocation queue participating clusters: %w", err)
+    }
+
+    for cIndex, cluster := range clusters {
+        cluster = cluster[0 : len(cluster)-1]
+        cPath := crossRevocationPrefix + cluster + "/"
+        serials, err := sc.Storage.List(sc.Context, cPath)
+        if err != nil {
+            return fmt.Errorf("failed to list cross-cluster revocation queue entries for cluster %v (%v): %w", cluster, cIndex, err)
+        }
+
+        for _, serial := range serials {
+            // Confirmation entries _should_ be handled by this cluster's
+            // processRevocationQueue(...) invocation; if not, when the plugin
+            // reloads, maybeGatherQueueForFirstProcess(...) will remove all
+            // stale confirmation requests.
+            if serial[len(serial)-1] == '/' {
+                continue
+            }
+
+            // Check for pause duration to reduce resource consumption.
+            if config.PauseDuration > (0 * time.Second) {
+                time.Sleep(config.PauseDuration)
+            }
+
+            ePath := cPath + serial
+            entry, err := sc.Storage.Get(sc.Context, ePath)
+            if err != nil {
+                return fmt.Errorf("error reading revocation request (%v) to tidy: %w", ePath, err)
+            }
+            if entry == nil || entry.Value == nil {
+                continue
+            }
+
+            var revRequest revocationRequest
+            if err := entry.DecodeJSON(&revRequest); err != nil {
+                return fmt.Errorf("error reading revocation request (%v) to tidy: %w", ePath, err)
+            }
+
+            now := time.Now()
+            afterBuffer := now.Add(-1 * config.QueueSafetyBuffer)
+            if revRequest.RequestedAt.After(afterBuffer) {
+                continue
+            }
+
+            // Safe to remove this entry.
+            if err := sc.Storage.Delete(sc.Context, ePath); err != nil {
+                return fmt.Errorf("error deleting revocation request (%v): %w", ePath, err)
+            }
+        }
+    }
 
     return nil
 }
@@ -747,6 +853,8 @@ func (b *backend) pathConfigAutoTidyRead(ctx context.Context, req *logical.Reque
             "safety_buffer":         int(config.SafetyBuffer / time.Second),
             "issuer_safety_buffer":  int(config.IssuerSafetyBuffer / time.Second),
             "pause_duration":        config.PauseDuration.String(),
+            "tidy_revocation_queue": config.RevocationQueue,
+            "revocation_queue_safety_buffer": int(config.QueueSafetyBuffer / time.Second),
         },
     }, nil
 }
@@ -814,8 +922,19 @@ func (b *backend) pathConfigAutoTidyWrite(ctx context.Context, req *logical.Requ
         config.BackupBundle = backupBundle.(bool)
     }
 
-    if config.Enabled && !(config.CertStore || config.RevokedCerts || config.IssuerAssocs || config.ExpiredIssuers || config.BackupBundle) {
-        return logical.ErrorResponse("Auto-tidy enabled but no tidy operations were requested. Enable at least one tidy operation to be run (tidy_cert_store / tidy_revoked_certs / tidy_revoked_cert_issuer_associations / tidy_move_legacy_ca_bundle)."), nil
+    if revocationQueueRaw, ok := d.GetOk("tidy_revocation_queue"); ok {
+        config.RevocationQueue = revocationQueueRaw.(bool)
+    }
+
+    if queueSafetyBufferRaw, ok := d.GetOk("revocation_queue_safety_buffer"); ok {
+        config.QueueSafetyBuffer = time.Duration(queueSafetyBufferRaw.(int)) * time.Second
+        if config.QueueSafetyBuffer < 1*time.Second {
+            return logical.ErrorResponse(fmt.Sprintf("given revocation_queue_safety_buffer must be at least one second; got: %v", queueSafetyBufferRaw)), nil
+        }
+    }
+
+    if config.Enabled && !(config.CertStore || config.RevokedCerts || config.IssuerAssocs || config.ExpiredIssuers || config.BackupBundle || config.RevocationQueue) {
+        return logical.ErrorResponse("Auto-tidy enabled but no tidy operations were requested. Enable at least one tidy operation to be run (tidy_cert_store / tidy_revoked_certs / tidy_revoked_cert_issuer_associations / tidy_expired_issuers / tidy_move_legacy_ca_bundle / tidy_revocation_queue)."), nil
     }
 
     if err := sc.writeAutoTidyConfig(config); err != nil {
@@ -834,6 +953,8 @@ func (b *backend) pathConfigAutoTidyWrite(ctx context.Context, req *logical.Requ
             "safety_buffer":         int(config.SafetyBuffer / time.Second),
             "issuer_safety_buffer":  int(config.IssuerSafetyBuffer / time.Second),
             "pause_duration":        config.PauseDuration.String(),
+            "tidy_revocation_queue": config.RevocationQueue,
+            "revocation_queue_safety_buffer": int(config.QueueSafetyBuffer / time.Second),
         },
     }, nil
 }
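With the new tidy knobs, stale queue entries older than `revocation_queue_safety_buffer` get cleaned up automatically. A hedged sketch of enabling this through `config/auto-tidy` with the Go API client (env configuration and a `pki` mount assumed; `48h` mirrors the default `QueueSafetyBuffer` above):

```go
package main

import (
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Periodically remove cross-cluster revocation requests that no
	// cluster has confirmed within the safety buffer.
	_, err = client.Logical().Write("pki/config/auto-tidy", map[string]interface{}{
		"enabled":                        true,
		"tidy_revocation_queue":          true,
		"revocation_queue_safety_buffer": "48h",
	})
	if err != nil {
		log.Fatal(err)
	}
}
```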
diff --git a/builtin/logical/pki/path_tidy_test.go b/builtin/logical/pki/path_tidy_test.go
index 4cd137a21..45d3d3a6a 100644
--- a/builtin/logical/pki/path_tidy_test.go
+++ b/builtin/logical/pki/path_tidy_test.go
@@ -392,6 +392,7 @@ func TestTidyIssuerConfig(t *testing.T) {
     defaultConfigMap["issuer_safety_buffer"] = int(time.Duration(defaultConfigMap["issuer_safety_buffer"].(float64)) / time.Second)
     defaultConfigMap["safety_buffer"] = int(time.Duration(defaultConfigMap["safety_buffer"].(float64)) / time.Second)
     defaultConfigMap["pause_duration"] = time.Duration(defaultConfigMap["pause_duration"].(float64)).String()
+    defaultConfigMap["revocation_queue_safety_buffer"] = int(time.Duration(defaultConfigMap["revocation_queue_safety_buffer"].(float64)) / time.Second)
 
     require.Equal(t, defaultConfigMap, resp.Data)
 
diff --git a/builtin/logical/pki/storage.go b/builtin/logical/pki/storage.go
index 92bb309c1..c5aa1e0f9 100644
--- a/builtin/logical/pki/storage.go
+++ b/builtin/logical/pki/storage.go
@@ -182,6 +182,7 @@ type internalCRLConfigEntry struct {
     CRLExpirationMap  map[crlID]time.Time `json:"crl_expiration_map"`
     LastModified      time.Time           `json:"last_modified"`
     DeltaLastModified time.Time           `json:"delta_last_modified"`
+    UseGlobalQueue    bool                `json:"cross_cluster_revocation"`
 }
 
 type keyConfigEntry struct {
diff --git a/builtin/logical/pki/util.go b/builtin/logical/pki/util.go
index a71a4d017..78886c54b 100644
--- a/builtin/logical/pki/util.go
+++ b/builtin/logical/pki/util.go
@@ -8,6 +8,7 @@ import (
     "net/http"
     "regexp"
     "strings"
+    "sync"
     "time"
 
     "github.com/hashicorp/vault/sdk/framework"
@@ -359,3 +360,99 @@ func addWarnings(resp *logical.Response, warnings []string) *logical.Response {
     }
     return resp
 }
+
+// revocationQueue is a type for allowing invalidateFunc to continue operating
+// quickly, while letting periodicFunc slowly sort through all open
+// revocations to process. In particular, we do not wish to be holding this
+// lock while periodicFunc is running, so iteration returns a full copy of
+// the data in this queue. We use a map from serial->[]clusterId, allowing us
+// to quickly insert and remove items, without using a slice of tuples. One
+// serial might be present on two clusters, if two clusters both have the cert
+// stored locally (e.g., via BYOC), which would result in two confirmation
+// entries and thus dictates the need for []clusterId. This also lets us
+// avoid having duplicate entries.
+type revocationQueue struct {
+    _l    sync.Mutex
+    queue map[string][]string
+}
+
+func newRevocationQueue() *revocationQueue {
+    return &revocationQueue{
+        queue: make(map[string][]string),
+    }
+}
+
+func (q *revocationQueue) Add(items ...*revocationQueueEntry) {
+    q._l.Lock()
+    defer q._l.Unlock()
+
+    for _, item := range items {
+        var found bool
+        for _, cluster := range q.queue[item.Serial] {
+            if cluster == item.Cluster {
+                found = true
+                break
+            }
+        }
+
+        if !found {
+            q.queue[item.Serial] = append(q.queue[item.Serial], item.Cluster)
+        }
+    }
+}
+
+func (q *revocationQueue) Remove(item *revocationQueueEntry) {
+    q._l.Lock()
+    defer q._l.Unlock()
+
+    clusters, present := q.queue[item.Serial]
+    if !present {
+        return
+    }
+
+    if len(clusters) == 0 || (len(clusters) == 1 && clusters[0] == item.Cluster) {
+        delete(q.queue, item.Serial)
+        return
+    }
+
+    result := clusters
+    for index, cluster := range clusters {
+        if cluster == item.Cluster {
+            result = append(clusters[0:index], clusters[index+1:]...)
+            break
+        }
+    }
+
+    q.queue[item.Serial] = result
+}
+
+// As this doesn't depend on any internal state, it should not be called
+// unless it is OK to remove any items added since the last Iterate()
+// function call.
+func (q *revocationQueue) RemoveAll() {
+    q._l.Lock()
+    defer q._l.Unlock()
+
+    q.queue = make(map[string][]string)
+}
+
+func (q *revocationQueue) Iterate() []*revocationQueueEntry {
+    q._l.Lock()
+    defer q._l.Unlock()
+
+    // Heuristic: by storing by serial, occasionally we'll get double entries
+    // if it was already revoked, but otherwise we'll be off by fewer when
+    // building this list.
+    ret := make([]*revocationQueueEntry, 0, len(q.queue))
+    for serial, clusters := range q.queue {
+        for _, cluster := range clusters {
+            ret = append(ret, &revocationQueueEntry{
+                Serial:  serial,
+                Cluster: cluster,
+            })
+        }
+    }
+
+    return ret
+}
diff --git a/vault/testing.go b/vault/testing.go
index 0b3d63dd7..afc46a399 100644
--- a/vault/testing.go
+++ b/vault/testing.go
@@ -1115,6 +1115,20 @@ func (c *TestClusterCore) stop() error {
     return nil
 }
 
+func (c *TestClusterCore) GrabRollbackLock() {
+    // Ensure we don't hold this lock while there are in flight rollbacks.
+    c.rollback.inflightAll.Wait()
+    c.rollback.inflightLock.Lock()
+}
+
+func (c *TestClusterCore) ReleaseRollbackLock() {
+    c.rollback.inflightLock.Unlock()
+}
+
+func (c *TestClusterCore) TriggerRollbacks() {
+    c.rollback.triggerRollbacks()
+}
+
 func (c *TestCluster) Cleanup() {
     c.Logger.Info("cleaning up vault cluster")
     if tl, ok := c.Logger.(*TestLogger); ok {
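The `revocationQueue` doc comment promises two properties: `Add` deduplicates identical (serial, cluster) pairs, and `Iterate` returns a snapshot that later mutations cannot affect. A sketch of a unit test (not part of this diff) exercising those semantics against the types introduced above; it would live in the `pki` package:

```go
package pki

import "testing"

// TestRevocationQueueSemantics checks the dedupe and copy-on-iterate
// behavior documented on revocationQueue.
func TestRevocationQueueSemantics(t *testing.T) {
	q := newRevocationQueue()
	entry := &revocationQueueEntry{Cluster: "cluster-a", Serial: "11-22-33"} // hypothetical values

	// Adding the same (serial, cluster) pair twice must not duplicate it.
	q.Add(entry)
	q.Add(entry)
	if items := q.Iterate(); len(items) != 1 {
		t.Fatalf("expected 1 entry after duplicate Add, got %d", len(items))
	}

	// Iterate returns a full copy: removing the entry afterwards must not
	// affect a snapshot taken earlier.
	snapshot := q.Iterate()
	q.Remove(entry)
	if len(snapshot) != 1 {
		t.Fatalf("snapshot changed after Remove")
	}
	if items := q.Iterate(); len(items) != 0 {
		t.Fatalf("expected empty queue after Remove, got %d entries", len(items))
	}
}
```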