Add logic for using Auth.Period when handling auth login/renew requests (#3677)
* Add logic for using Auth.Period when handling auth login/renew requests * Set auth.TTL if not set in handleLoginRequest * Always set auth.TTL = te.TTL on handleLoginRequest, check TTL and period against sys values on RenewToken * Get sysView from le.Path, revert tests * Add back auth.Policies * Fix TokenStore tests, add resp warning when capping values * Use switch for ttl/period check on RenewToken * Move comments around
This commit is contained in:
parent
9358540d50
commit
79cb82e133
|
@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
|
|||
func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL")
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
|
|||
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
|
||||
checkTTL := func(resp *logical.Response) error {
|
||||
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
|
||||
return fmt.Errorf("invalid TTL")
|
||||
return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package approle
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
|
@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
|
|||
Policies: role.Policies,
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
Renewable: true,
|
||||
TTL: role.TokenTTL,
|
||||
},
|
||||
Alias: &logical.Alias{
|
||||
Name: role.RoleID,
|
||||
},
|
||||
}
|
||||
|
||||
// If 'Period' is set, use the value of 'Period' as the TTL.
|
||||
// Otherwise, set the normal TokenTTL.
|
||||
if role.Period > time.Duration(0) {
|
||||
auth.TTL = role.Period
|
||||
} else {
|
||||
auth.TTL = role.TokenTTL
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Auth: auth,
|
||||
}, nil
|
||||
|
@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
|
|||
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
|
||||
}
|
||||
|
||||
// If 'Period' is set on the Role, the token should never expire.
|
||||
// Replenish the TTL with 'Period's value.
|
||||
if role.Period > time.Duration(0) {
|
||||
// If 'Period' was updated after the token was issued,
|
||||
// token will bear the updated 'Period' value as its TTL.
|
||||
req.Auth.TTL = role.Period
|
||||
return &logical.Response{Auth: req.Auth}, nil
|
||||
} else {
|
||||
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
||||
resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Auth.Period = role.Period
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathLoginHelpSys = "Issue a token based on the credentials supplied"
|
||||
|
|
|
@ -8,7 +8,8 @@ import (
|
|||
)
|
||||
|
||||
// LeaseExtend returns an OperationFunc that can be used to simply extend the
|
||||
// lease of the auth/secret for the duration that was requested.
|
||||
// lease of the auth/secret for the duration that was requested. The parameters
|
||||
// provided are used to determine the lease's new TTL value.
|
||||
//
|
||||
// backendIncrement is the backend's requested increment -- perhaps from a user
|
||||
// request, perhaps from a role/config value. If not set, uses the mount/system
|
||||
|
|
|
@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
|
|||
}, nil
|
||||
}
|
||||
|
||||
sysView := m.router.MatchingSystemView(le.Path)
|
||||
if sysView == nil {
|
||||
return nil, fmt.Errorf("expiration: unable to retrieve system view from router")
|
||||
}
|
||||
|
||||
retResp := &logical.Response{}
|
||||
switch {
|
||||
case resp.Auth.Period > time.Duration(0):
|
||||
// If it resp.Period is non-zero, use that as the TTL and override backend's
|
||||
// call on TTL modification, such as a TTL value determined by
|
||||
// framework.LeaseExtend call against the request. Also, cap period value to
|
||||
// the sys/mount max value.
|
||||
if resp.Auth.Period > sysView.MaxLeaseTTL() {
|
||||
retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
|
||||
resp.Auth.Period = sysView.MaxLeaseTTL()
|
||||
}
|
||||
resp.Auth.TTL = resp.Auth.Period
|
||||
case resp.Auth.TTL > time.Duration(0):
|
||||
// Cap TTL value to the sys/mount max value
|
||||
if resp.Auth.TTL > sysView.MaxLeaseTTL() {
|
||||
retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
|
||||
resp.Auth.TTL = sysView.MaxLeaseTTL()
|
||||
}
|
||||
}
|
||||
|
||||
// Attach the ClientToken
|
||||
resp.Auth.ClientToken = token
|
||||
resp.Auth.Increment = 0
|
||||
|
@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
|
|||
|
||||
// Update the expiration time
|
||||
m.updatePending(le, resp.Auth.LeaseTotal())
|
||||
return &logical.Response{
|
||||
Auth: resp.Auth,
|
||||
}, nil
|
||||
|
||||
retResp.Auth = resp.Auth
|
||||
return retResp, nil
|
||||
}
|
||||
|
||||
// Register is used to take a request and response with an associated
|
||||
|
@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro
|
|||
return err
|
||||
}
|
||||
|
||||
// If it resp.Period is non-zero, override the TTL value determined
|
||||
// by the backend.
|
||||
if auth.Period > time.Duration(0) {
|
||||
auth.TTL = auth.Period
|
||||
}
|
||||
|
||||
// Create a lease entry
|
||||
le := leaseEntry{
|
||||
LeaseID: path.Join(source, saltedID),
|
||||
|
|
|
@ -477,14 +477,27 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
|||
return nil, nil, ErrInternalError
|
||||
}
|
||||
|
||||
// Set the default lease if not provided
|
||||
if auth.TTL == 0 {
|
||||
auth.TTL = sysView.DefaultLeaseTTL()
|
||||
}
|
||||
// Start off with the sys default value, and update according to period/TTL
|
||||
// from resp.Auth
|
||||
tokenTTL := sysView.DefaultLeaseTTL()
|
||||
|
||||
// Limit the lease duration
|
||||
if auth.TTL > sysView.MaxLeaseTTL() {
|
||||
auth.TTL = sysView.MaxLeaseTTL()
|
||||
switch {
|
||||
case auth.Period > time.Duration(0):
|
||||
// Cap the period value to the sys max_ttl value. The auth backend should
|
||||
// have checked for it on its login path, but we check here again for
|
||||
// sanity.
|
||||
if auth.Period > sysView.MaxLeaseTTL() {
|
||||
auth.Period = sysView.MaxLeaseTTL()
|
||||
}
|
||||
tokenTTL = auth.Period
|
||||
case auth.TTL > time.Duration(0):
|
||||
// Cap the TTL value. The auth backend should have checked for it on its
|
||||
// login path (e.g. a call to b.SanitizeTTL), but we check here again for
|
||||
// sanity.
|
||||
if auth.TTL > sysView.MaxLeaseTTL() {
|
||||
auth.TTL = sysView.MaxLeaseTTL()
|
||||
}
|
||||
tokenTTL = auth.TTL
|
||||
}
|
||||
|
||||
// Generate a token
|
||||
|
@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
|||
Meta: auth.Metadata,
|
||||
DisplayName: auth.DisplayName,
|
||||
CreationTime: time.Now().Unix(),
|
||||
TTL: auth.TTL,
|
||||
TTL: tokenTTL,
|
||||
NumUses: auth.NumUses,
|
||||
EntityID: auth.EntityID,
|
||||
}
|
||||
|
@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
|
|||
return nil, auth, ErrInternalError
|
||||
}
|
||||
|
||||
// Populate the client token and accessor
|
||||
// Populate the client token, accessor, and TTL
|
||||
auth.ClientToken = te.ID
|
||||
auth.Accessor = te.Accessor
|
||||
auth.Policies = te.Policies
|
||||
auth.TTL = te.TTL
|
||||
|
||||
// Register with the expiration manager
|
||||
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
||||
|
|
|
@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
|||
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
||||
req.ClientToken = root
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"period": 5,
|
||||
}
|
||||
|
||||
resp, err := core.HandleRequest(req)
|
||||
|
@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit to 3 seconds
|
||||
|
@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
|
||||
req.ClientToken = root
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"period": 5,
|
||||
}
|
||||
|
||||
resp, err := core.HandleRequest(req)
|
||||
|
@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
|
@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 299 {
|
||||
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/create"
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"explicit_max_ttl": 150,
|
||||
"period": 5,
|
||||
"explicit_max_ttl": 4,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
|
@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 149 || ttl > 150 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
||||
if ttl < 3 || ttl > 4 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
|
@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 140 || ttl > 150 {
|
||||
t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl)
|
||||
if ttl > 2 {
|
||||
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/create/test"
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 150,
|
||||
"period": 5,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
|
@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 149 || ttl > 150 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
||||
if ttl < 4 || ttl > 5 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
|
@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl < 149 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl)
|
||||
if ttl > 5 {
|
||||
t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
{
|
||||
req.Path = "auth/token/roles/test"
|
||||
req.ClientToken = root
|
||||
req.Operation = logical.UpdateOperation
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 300,
|
||||
"explicit_max_ttl": 150,
|
||||
"period": 5,
|
||||
"explicit_max_ttl": 4,
|
||||
}
|
||||
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v %v", err, resp)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("expected a nil response")
|
||||
}
|
||||
|
||||
req.ClientToken = root
|
||||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/create/test"
|
||||
req.Data = map[string]interface{}{
|
||||
"period": 150,
|
||||
"explicit_max_ttl": 130,
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v %v", err, resp)
|
||||
|
@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl := resp.Data["ttl"].(int64)
|
||||
if ttl < 129 || ttl > 130 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl)
|
||||
if ttl < 3 || ttl > 4 {
|
||||
t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
|
||||
}
|
||||
|
||||
// Let the TTL go down a bit
|
||||
time.Sleep(4 * time.Second)
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
req.Operation = logical.UpdateOperation
|
||||
req.Path = "auth/token/renew-self"
|
||||
|
@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) {
|
|||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
ttl = resp.Data["ttl"].(int64)
|
||||
if ttl > 127 {
|
||||
t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl)
|
||||
if ttl > 2 {
|
||||
t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue