diff --git a/builtin/credential/app-id/backend_test.go b/builtin/credential/app-id/backend_test.go index 4ae5d3e1c..e5d335b4f 100644 --- a/builtin/credential/app-id/backend_test.go +++ b/builtin/credential/app-id/backend_test.go @@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep { func testAccLogin(t *testing.T, display string) logicaltest.TestStep { checkTTL := func(resp *logical.Response) error { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { - return fmt.Errorf("invalid TTL") + return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL) } return nil } @@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep { func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep { checkTTL := func(resp *logical.Response) error { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { - return fmt.Errorf("invalid TTL") + return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL) } return nil } diff --git a/builtin/credential/approle/path_login.go b/builtin/credential/approle/path_login.go index 300ee9409..3dd829a84 100644 --- a/builtin/credential/approle/path_login.go +++ b/builtin/credential/approle/path_login.go @@ -3,7 +3,6 @@ package approle import ( "fmt" "strings" - "time" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat Policies: role.Policies, LeaseOptions: logical.LeaseOptions{ Renewable: true, + TTL: role.TokenTTL, }, Alias: &logical.Alias{ Name: role.RoleID, }, } - // If 'Period' is set, use the value of 'Period' as the TTL. - // Otherwise, set the normal TokenTTL. - if role.Period > time.Duration(0) { - auth.TTL = role.Period - } else { - auth.TTL = role.TokenTTL - } - return &logical.Response{ Auth: auth, }, nil @@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData return nil, fmt.Errorf("role %s does not exist during renewal", roleName) } - // If 'Period' is set on the Role, the token should never expire. - // Replenish the TTL with 'Period's value. - if role.Period > time.Duration(0) { - // If 'Period' was updated after the token was issued, - // token will bear the updated 'Period' value as its TTL. - req.Auth.TTL = role.Period - return &logical.Response{Auth: req.Auth}, nil - } else { - return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data) + resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data) + if err != nil { + return nil, err } + resp.Auth.Period = role.Period + return resp, nil } const pathLoginHelpSys = "Issue a token based on the credentials supplied" diff --git a/logical/framework/lease.go b/logical/framework/lease.go index 4fd2ac902..d2678f712 100644 --- a/logical/framework/lease.go +++ b/logical/framework/lease.go @@ -8,7 +8,8 @@ import ( ) // LeaseExtend returns an OperationFunc that can be used to simply extend the -// lease of the auth/secret for the duration that was requested. +// lease of the auth/secret for the duration that was requested. The parameters +// provided are used to determine the lease's new TTL value. // // backendIncrement is the backend's requested increment -- perhaps from a user // request, perhaps from a role/config value. If not set, uses the mount/system diff --git a/vault/expiration.go b/vault/expiration.go index 6ebbb99b8..b42f334d5 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke }, nil } + sysView := m.router.MatchingSystemView(le.Path) + if sysView == nil { + return nil, fmt.Errorf("expiration: unable to retrieve system view from router") + } + + retResp := &logical.Response{} + switch { + case resp.Auth.Period > time.Duration(0): + // If it resp.Period is non-zero, use that as the TTL and override backend's + // call on TTL modification, such as a TTL value determined by + // framework.LeaseExtend call against the request. Also, cap period value to + // the sys/mount max value. + if resp.Auth.Period > sysView.MaxLeaseTTL() { + retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL())) + resp.Auth.Period = sysView.MaxLeaseTTL() + } + resp.Auth.TTL = resp.Auth.Period + case resp.Auth.TTL > time.Duration(0): + // Cap TTL value to the sys/mount max value + if resp.Auth.TTL > sysView.MaxLeaseTTL() { + retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL())) + resp.Auth.TTL = sysView.MaxLeaseTTL() + } + } + // Attach the ClientToken resp.Auth.ClientToken = token resp.Auth.Increment = 0 @@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke // Update the expiration time m.updatePending(le, resp.Auth.LeaseTotal()) - return &logical.Response{ - Auth: resp.Auth, - }, nil + + retResp.Auth = resp.Auth + return retResp, nil } // Register is used to take a request and response with an associated @@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro return err } + // If it resp.Period is non-zero, override the TTL value determined + // by the backend. + if auth.Period > time.Duration(0) { + auth.TTL = auth.Period + } + // Create a lease entry le := leaseEntry{ LeaseID: path.Join(source, saltedID), diff --git a/vault/request_handling.go b/vault/request_handling.go index 7fb7ea1d3..3453ff1b0 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -477,14 +477,27 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon return nil, nil, ErrInternalError } - // Set the default lease if not provided - if auth.TTL == 0 { - auth.TTL = sysView.DefaultLeaseTTL() - } + // Start off with the sys default value, and update according to period/TTL + // from resp.Auth + tokenTTL := sysView.DefaultLeaseTTL() - // Limit the lease duration - if auth.TTL > sysView.MaxLeaseTTL() { - auth.TTL = sysView.MaxLeaseTTL() + switch { + case auth.Period > time.Duration(0): + // Cap the period value to the sys max_ttl value. The auth backend should + // have checked for it on its login path, but we check here again for + // sanity. + if auth.Period > sysView.MaxLeaseTTL() { + auth.Period = sysView.MaxLeaseTTL() + } + tokenTTL = auth.Period + case auth.TTL > time.Duration(0): + // Cap the TTL value. The auth backend should have checked for it on its + // login path (e.g. a call to b.SanitizeTTL), but we check here again for + // sanity. + if auth.TTL > sysView.MaxLeaseTTL() { + auth.TTL = sysView.MaxLeaseTTL() + } + tokenTTL = auth.TTL } // Generate a token @@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon Meta: auth.Metadata, DisplayName: auth.DisplayName, CreationTime: time.Now().Unix(), - TTL: auth.TTL, + TTL: tokenTTL, NumUses: auth.NumUses, EntityID: auth.EntityID, } @@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon return nil, auth, ErrInternalError } - // Populate the client token and accessor + // Populate the client token, accessor, and TTL auth.ClientToken = te.ID auth.Accessor = te.Accessor auth.Policies = te.Policies + auth.TTL = te.TTL // Register with the expiration manager if err := c.expiration.RegisterAuth(te.Path, auth); err != nil { diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 1b3d7c346..db8cf4816 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req.ClientToken = root req.Data = map[string]interface{}{ - "period": 300, + "period": 5, } resp, err := core.HandleRequest(req) @@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d", 5, ttl) } // Let the TTL go down a bit to 3 seconds @@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d", 5, ttl) } } } @@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req.ClientToken = root req.Data = map[string]interface{}{ - "period": 300, + "period": 5, } resp, err := core.HandleRequest(req) @@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl) } // Let the TTL go down a bit @@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 299 { - t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) + if ttl > 5 { + t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl) } } @@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "auth/token/create" req.Data = map[string]interface{}{ - "period": 300, - "explicit_max_ttl": 150, + "period": 5, + "explicit_max_ttl": 4, } resp, err = core.HandleRequest(req) if err != nil { @@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 149 || ttl > 150 { - t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) + if ttl < 3 || ttl > 4 { + t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl) } // Let the TTL go down a bit @@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 140 || ttl > 150 { - t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl) + if ttl > 2 { + t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl) } } @@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) { req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" req.Data = map[string]interface{}{ - "period": 150, + "period": 5, } resp, err = core.HandleRequest(req) if err != nil { @@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 149 || ttl > 150 { - t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) + if ttl < 4 || ttl > 5 { + t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl) } // Let the TTL go down a bit @@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl < 149 { - t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) + if ttl > 5 { + t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl) } } @@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) { { req.Path = "auth/token/roles/test" req.ClientToken = root + req.Operation = logical.UpdateOperation req.Data = map[string]interface{}{ - "period": 300, - "explicit_max_ttl": 150, + "period": 5, + "explicit_max_ttl": 4, + } + + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v %v", err, resp) + } + if resp != nil { + t.Fatalf("expected a nil response") } req.ClientToken = root req.Operation = logical.UpdateOperation req.Path = "auth/token/create/test" - req.Data = map[string]interface{}{ - "period": 150, - "explicit_max_ttl": 130, - } resp, err = core.HandleRequest(req) if err != nil { t.Fatalf("err: %v %v", err, resp) @@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl := resp.Data["ttl"].(int64) - if ttl < 129 || ttl > 130 { - t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl) + if ttl < 3 || ttl > 4 { + t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl) } // Let the TTL go down a bit - time.Sleep(4 * time.Second) + time.Sleep(2 * time.Second) req.Operation = logical.UpdateOperation req.Path = "auth/token/renew-self" @@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) { t.Fatalf("err: %v", err) } ttl = resp.Data["ttl"].(int64) - if ttl > 127 { - t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl) + if ttl > 2 { + t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl) } } }