diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index d91df7df1..86838f74a 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -4,10 +4,14 @@ import ( "context" "fmt" "io" + "strconv" "strings" "sync" + "time" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -59,9 +63,10 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) b.pathCacheConfig(), }, - Secrets: []*framework.Secret{}, - Invalidate: b.invalidate, - BackendType: logical.TypeLogical, + Secrets: []*framework.Secret{}, + Invalidate: b.invalidate, + BackendType: logical.TypeLogical, + PeriodicFunc: b.periodicFunc, } // determine cacheSize to use. Defaults to 0 which means unlimited @@ -93,8 +98,10 @@ type backend struct { *framework.Backend lm *keysutil.LockManager // Lock to make changes to any of the backend's cache configuration. - configMutex sync.RWMutex - cacheSizeChanged bool + configMutex sync.RWMutex + cacheSizeChanged bool + checkAutoRotateAfter time.Time + autoRotateOnce sync.Once } func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error) { @@ -162,3 +169,91 @@ func (b *backend) invalidate(ctx context.Context, key string) { b.cacheSizeChanged = true } } + +// periodicFunc is a central collection of functions that run on an interval. +// Anything that should be called regularly can be placed within this method. +func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error { + // These operations ensure the auto-rotate only happens once simultaneously. It's an unlikely edge + // given the time scale, but a safeguard nonetheless. + var err error + didAutoRotate := false + autoRotateOnceFn := func() { + err = b.autoRotateKeys(ctx, req) + didAutoRotate = true + } + b.autoRotateOnce.Do(autoRotateOnceFn) + if didAutoRotate { + b.autoRotateOnce = sync.Once{} + } + + return err +} + +// autoRotateKeys retrieves all transit keys and rotates those which have an +// auto rotate interval defined which has passed. This operation only happens +// on primary nodes and performance secondary nodes which have a local mount. +func (b *backend) autoRotateKeys(ctx context.Context, req *logical.Request) error { + // Only check for autorotation once an hour to avoid unnecessarily iterating + // over all keys too frequently. + if time.Now().Before(b.checkAutoRotateAfter) { + return nil + } + b.checkAutoRotateAfter = time.Now().Add(1 * time.Hour) + + // Early exit if not a primary or performance secondary with a local mount. + if b.System().ReplicationState().HasState(consts.ReplicationDRSecondary|consts.ReplicationPerformanceStandby) || + (!b.System().LocalMount() && b.System().ReplicationState().HasState(consts.ReplicationPerformanceSecondary)) { + return nil + } + + // Retrieve all keys and loop over them to check if they need to be rotated. + keys, err := req.Storage.List(ctx, "policy/") + if err != nil { + return err + } + + // Collect errors in a multierror to ensure a single failure doesn't prevent + // all keys from being rotated. + var errs *multierror.Error + + for _, key := range keys { + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ + Storage: req.Storage, + Name: key, + }, b.GetRandomReader()) + if err != nil { + errs = multierror.Append(errs, err) + continue + } + + // If the policy is nil, move onto the next one. + if p == nil { + continue + } + + // If the policy's automatic rotation interval is 0, it should not + // automatically rotate. + if p.AutoRotateInterval == 0 { + continue + } + + // Retrieve the latest version of the policy and determine if it is time to rotate. + latestKey := p.Keys[strconv.Itoa(p.LatestVersion)] + if time.Now().After(latestKey.CreationTime.Add(p.AutoRotateInterval)) { + if b.Logger().IsDebug() { + b.Logger().Debug("automatically rotating key", "key", key) + } + if !b.System().CachingDisabled() { + p.Lock(true) + } + err = p.Rotate(ctx, req.Storage, b.GetRandomReader()) + p.Unlock() + if err != nil { + errs = multierror.Append(errs, err) + continue + } + } + } + + return errs.ErrorOrNil() +} diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 355b0738c..6ad7b2d7d 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -18,6 +18,7 @@ import ( uuid "github.com/hashicorp/go-uuid" logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical" "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" "github.com/mitchellh/mapstructure" @@ -1519,3 +1520,190 @@ func TestBadInput(t *testing.T) { t.Fatal("expected error") } } + +func TestTransit_AutoRotateKeys(t *testing.T) { + tests := map[string]struct { + isDRSecondary bool + isPerfSecondary bool + isStandby bool + isLocal bool + shouldRotate bool + }{ + "primary, no local mount": { + shouldRotate: true, + }, + "DR secondary, no local mount": { + isDRSecondary: true, + shouldRotate: false, + }, + "perf standby, no local mount": { + isStandby: true, + shouldRotate: false, + }, + "perf secondary, no local mount": { + isPerfSecondary: true, + shouldRotate: false, + }, + "perf secondary, local mount": { + isPerfSecondary: true, + isLocal: true, + shouldRotate: true, + }, + } + + for name, test := range tests { + t.Run( + name, + func(t *testing.T) { + var repState consts.ReplicationState + if test.isDRSecondary { + repState.AddState(consts.ReplicationDRSecondary) + } + if test.isPerfSecondary { + repState.AddState(consts.ReplicationPerformanceSecondary) + } + if test.isStandby { + repState.AddState(consts.ReplicationPerformanceStandby) + } + + sysView := logical.TestSystemView() + sysView.ReplicationStateVal = repState + sysView.LocalMountVal = test.isLocal + + storage := &logical.InmemStorage{} + + conf := &logical.BackendConfig{ + StorageView: storage, + System: sysView, + } + + b, _ := Backend(context.Background(), conf) + if b == nil { + t.Fatal("failed to create backend") + } + + err := b.Backend.Setup(context.Background(), conf) + if err != nil { + t.Fatal(err) + } + + // Write a key with the default auto rotate value (0/disabled) + req := &logical.Request{ + Storage: storage, + Operation: logical.UpdateOperation, + Path: "keys/test1", + } + resp, err := b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp != nil { + t.Fatal("expected nil response") + } + + // Write a key with an auto rotate value one day in the future + req = &logical.Request{ + Storage: storage, + Operation: logical.UpdateOperation, + Path: "keys/test2", + Data: map[string]interface{}{ + "auto_rotate_interval": 24 * time.Hour, + }, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp != nil { + t.Fatal("expected nil response") + } + + // Run the rotation check and ensure none of the keys have rotated + b.checkAutoRotateAfter = time.Now() + if err = b.autoRotateKeys(context.Background(), &logical.Request{Storage: storage}); err != nil { + t.Fatal(err) + } + req = &logical.Request{ + Storage: storage, + Operation: logical.ReadOperation, + Path: "keys/test1", + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if resp.Data["latest_version"] != 1 { + t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], 1) + } + + req.Path = "keys/test2" + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if resp.Data["latest_version"] != 1 { + t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], 1) + } + + // Update auto rotate interval on one key to be one nanosecond + p, _, err := b.GetPolicy(context.Background(), keysutil.PolicyRequest{ + Storage: storage, + Name: "test2", + }, b.GetRandomReader()) + if err != nil { + t.Fatal(err) + } + if p == nil { + t.Fatal("expected non-nil policy") + } + p.AutoRotateInterval = time.Nanosecond + err = p.Persist(context.Background(), storage) + if err != nil { + t.Fatal(err) + } + + // Run the rotation check and validate the state of key rotations + b.checkAutoRotateAfter = time.Now() + if err = b.autoRotateKeys(context.Background(), &logical.Request{Storage: storage}); err != nil { + t.Fatal(err) + } + req = &logical.Request{ + Storage: storage, + Operation: logical.ReadOperation, + Path: "keys/test1", + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if resp.Data["latest_version"] != 1 { + t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], 1) + } + req.Path = "keys/test2" + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + expectedVersion := 1 + if test.shouldRotate { + expectedVersion = 2 + } + if resp.Data["latest_version"] != expectedVersion { + t.Fatalf("incorrect latest_version found, got: %d, want: %d", resp.Data["latest_version"], expectedVersion) + } + }, + ) + } +} diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 1c41cd0d4..336643227 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -3,6 +3,7 @@ package transit import ( "context" "fmt" + "time" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/keysutil" @@ -47,6 +48,13 @@ the latest version of the key is allowed.`, Type: framework.TypeBool, Description: `Enables taking a backup of the named key in plaintext format. Once set, this cannot be disabled.`, }, + + "auto_rotate_interval": { + Type: framework.TypeDurationSecond, + Description: `Amount of time the key should live before +being automatically rotated. A value of 0 +disables automatic rotation for the key.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -185,6 +193,23 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d * } } + autoRotateIntervalRaw, ok, err := d.GetOkErr("auto_rotate_interval") + if err != nil { + return nil, err + } + if ok { + autoRotateInterval := time.Second * time.Duration(autoRotateIntervalRaw.(int)) + // Provided value must be 0 to disable or at least an hour + if autoRotateInterval != 0 && autoRotateInterval < time.Hour { + return logical.ErrorResponse("auto rotate interval must be 0 to disable or at least an hour"), nil + } + + if autoRotateInterval != p.AutoRotateInterval { + p.AutoRotateInterval = autoRotateInterval + persistNeeded = true + } + } + if !persistNeeded { return nil, nil } diff --git a/builtin/logical/transit/path_config_test.go b/builtin/logical/transit/path_config_test.go index c0f0eba0e..62864bb73 100644 --- a/builtin/logical/transit/path_config_test.go +++ b/builtin/logical/transit/path_config_test.go @@ -2,11 +2,18 @@ package transit import ( "context" + "encoding/hex" + "fmt" "strconv" "strings" "testing" + "time" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" + vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault" ) func TestTransit_ConfigSettings(t *testing.T) { @@ -283,3 +290,106 @@ func TestTransit_ConfigSettings(t *testing.T) { testHMAC(3, true) testHMAC(2, false) } + +func TestTransit_UpdateKeyConfigWithAutorotation(t *testing.T) { + tests := map[string]struct { + initialAutoRotateInterval interface{} + newAutoRotateInterval interface{} + shouldError bool + expectedValue time.Duration + }{ + "default (no value)": { + initialAutoRotateInterval: "5h", + shouldError: false, + expectedValue: 5 * time.Hour, + }, + "0 (int)": { + initialAutoRotateInterval: "5h", + newAutoRotateInterval: 0, + shouldError: false, + expectedValue: 0, + }, + "0 (string)": { + initialAutoRotateInterval: "5h", + newAutoRotateInterval: 0, + shouldError: false, + expectedValue: 0, + }, + "5 seconds": { + newAutoRotateInterval: "5s", + shouldError: true, + }, + "5 hours": { + newAutoRotateInterval: "5h", + shouldError: false, + expectedValue: 5 * time.Hour, + }, + "negative value": { + newAutoRotateInterval: "-1800s", + shouldError: true, + }, + "invalid string": { + newAutoRotateInterval: "this shouldn't work", + shouldError: true, + }, + } + + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "transit": Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + cores := cluster.Cores + vault.TestWaitActive(t, cores[0].Core) + client := cores[0].Client + err := client.Sys().Mount("transit", &api.MountInput{ + Type: "transit", + }) + if err != nil { + t.Fatal(err) + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + keyNameBytes, err := uuid.GenerateRandomBytes(16) + if err != nil { + t.Fatal(err) + } + keyName := hex.EncodeToString(keyNameBytes) + + _, err = client.Logical().Write(fmt.Sprintf("transit/keys/%s", keyName), map[string]interface{}{ + "auto_rotate_interval": test.initialAutoRotateInterval, + }) + + resp, err := client.Logical().Write(fmt.Sprintf("transit/keys/%s/config", keyName), map[string]interface{}{ + "auto_rotate_interval": test.newAutoRotateInterval, + }) + switch { + case test.shouldError && err == nil: + t.Fatal("expected non-nil error") + case !test.shouldError && err != nil: + t.Fatal(err) + } + + if !test.shouldError { + resp, err = client.Logical().Read(fmt.Sprintf("transit/keys/%s", keyName)) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + got := resp.Data["auto_rotate_interval"] + want := test.expectedValue.String() + if got != want { + t.Fatalf("incorrect auto_rotate_interval returned, got: %s, want: %s", got, want) + } + } + }) + } +} diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 8c43ab593..1d5e142b7 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -94,6 +94,15 @@ When reading a key with key derivation enabled, if the key type supports public keys, this will return the public key for the given context.`, }, + + "auto_rotate_interval": { + Type: framework.TypeDurationSecond, + Default: 0, + Description: `Amount of time the key should live before +being automatically rotated. A value of 0 +(default) disables automatic rotation for the +key.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -123,6 +132,11 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d * keyType := d.Get("type").(string) exportable := d.Get("exportable").(bool) allowPlaintextBackup := d.Get("allow_plaintext_backup").(bool) + autoRotateInterval := time.Second * time.Duration(d.Get("auto_rotate_interval").(int)) + + if autoRotateInterval != 0 && autoRotateInterval < time.Hour { + return logical.ErrorResponse("auto rotate interval must be 0 to disable or at least an hour"), nil + } if !derived && convergent { return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil @@ -136,6 +150,7 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d * Convergent: convergent, Exportable: exportable, AllowPlaintextBackup: allowPlaintextBackup, + AutoRotateInterval: autoRotateInterval, } switch keyType { case "aes128-gcm96": @@ -223,6 +238,7 @@ func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *f "supports_decryption": p.Type.DecryptionSupported(), "supports_signing": p.Type.SigningSupported(), "supports_derivation": p.Type.DerivationSupported(), + "auto_rotate_interval": p.AutoRotateInterval.String(), }, } diff --git a/builtin/logical/transit/path_keys_test.go b/builtin/logical/transit/path_keys_test.go index a99f64d7a..f90fd4691 100644 --- a/builtin/logical/transit/path_keys_test.go +++ b/builtin/logical/transit/path_keys_test.go @@ -1,8 +1,12 @@ package transit_test import ( + "encoding/hex" + "fmt" "testing" + "time" + uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/builtin/audit/file" @@ -87,3 +91,97 @@ func TestTransit_Issue_2958(t *testing.T) { t.Fatal(err) } } + +func TestTransit_CreateKeyWithAutorotation(t *testing.T) { + tests := map[string]struct { + autoRotateInterval interface{} + shouldError bool + expectedValue time.Duration + }{ + "default (no value)": { + shouldError: false, + }, + "0 (int)": { + autoRotateInterval: 0, + shouldError: false, + expectedValue: 0, + }, + "0 (string)": { + autoRotateInterval: "0", + shouldError: false, + expectedValue: 0, + }, + "5 seconds": { + autoRotateInterval: "5s", + shouldError: true, + }, + "5 hours": { + autoRotateInterval: "5h", + shouldError: false, + expectedValue: 5 * time.Hour, + }, + "negative value": { + autoRotateInterval: "-1800s", + shouldError: true, + }, + "invalid string": { + autoRotateInterval: "this shouldn't work", + shouldError: true, + }, + } + + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "transit": transit.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + cores := cluster.Cores + vault.TestWaitActive(t, cores[0].Core) + client := cores[0].Client + err := client.Sys().Mount("transit", &api.MountInput{ + Type: "transit", + }) + if err != nil { + t.Fatal(err) + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + keyNameBytes, err := uuid.GenerateRandomBytes(16) + if err != nil { + t.Fatal(err) + } + keyName := hex.EncodeToString(keyNameBytes) + + _, err = client.Logical().Write(fmt.Sprintf("transit/keys/%s", keyName), map[string]interface{}{ + "auto_rotate_interval": test.autoRotateInterval, + }) + switch { + case test.shouldError && err == nil: + t.Fatal("expected non-nil error") + case !test.shouldError && err != nil: + t.Fatal(err) + } + + if !test.shouldError { + resp, err := client.Logical().Read(fmt.Sprintf("transit/keys/%s", keyName)) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + got := resp.Data["auto_rotate_interval"] + want := test.expectedValue.String() + if got != want { + t.Fatalf("incorrect auto_rotate_interval returned, got: %s, want: %s", got, want) + } + } + }) + } +} diff --git a/changelog/13691.txt b/changelog/13691.txt new file mode 100644 index 000000000..b3c0cb9c3 --- /dev/null +++ b/changelog/13691.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Transit Time-Based Key Autorotation**: Add support for automatic, time-based key rotation to transit secrets engine. +``` diff --git a/sdk/helper/keysutil/lock_manager.go b/sdk/helper/keysutil/lock_manager.go index c6a0a23d6..71bfcac84 100644 --- a/sdk/helper/keysutil/lock_manager.go +++ b/sdk/helper/keysutil/lock_manager.go @@ -50,6 +50,9 @@ type PolicyRequest struct { // Whether to allow plaintext backup AllowPlaintextBackup bool + + // How frequently the key should automatically rotate + AutoRotateInterval time.Duration } type LockManager struct { @@ -380,6 +383,7 @@ func (lm *LockManager) GetPolicy(ctx context.Context, req PolicyRequest, rand io Derived: req.Derived, Exportable: req.Exportable, AllowPlaintextBackup: req.AllowPlaintextBackup, + AutoRotateInterval: req.AutoRotateInterval, } if req.Derived { diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index 1b198a35f..d4b82ab82 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -374,6 +374,10 @@ type Policy struct { // policy object. StoragePrefix string `json:"storage_prefix"` + // AutoRotateInterval defines how frequently the key should automatically + // rotate. Setting this to zero disables automatic rotation for the key. + AutoRotateInterval time.Duration `json:"auto_rotate_interval"` + // versionPrefixCache stores caches of version prefix strings and the split // version template. versionPrefixCache sync.Map diff --git a/website/content/api-docs/secret/transit.mdx b/website/content/api-docs/secret/transit.mdx index b5e68ef8e..4b0af3bd6 100644 --- a/website/content/api-docs/secret/transit.mdx +++ b/website/content/api-docs/secret/transit.mdx @@ -64,6 +64,11 @@ values set here cannot be changed after key creation. - `rsa-3072` - RSA with bit size of 3072 (asymmetric) - `rsa-4096` - RSA with bit size of 4096 (asymmetric) +- `auto_rotate_interval` `(duration: "0", optional)` – The interval at which + this key should be rotated automatically. Setting this to "0" (the default) + will disable automatic key rotation. This value cannot be shorter than one + hour. + ### Sample Payload ```json @@ -227,6 +232,11 @@ are returned during a read operation on the named key.) - `allow_plaintext_backup` `(bool: false)` - If set, enables taking backup of named key in the plaintext format. Once set, this cannot be disabled. +- `auto_rotate_interval` `(duration: "", optional)` – The interval at which this + key should be rotated automatically. Setting this to "0" will disable automatic + key rotation. This value cannot be shorter than one hour. When no value is + provided, the interval remains unchanged. + ### Sample Payload ```json