Add logic for using Auth.Period when handling auth login/renew requests (#3677)

* Add logic for using Auth.Period when handling auth login/renew requests

* Set auth.TTL if not set in handleLoginRequest

* Always set auth.TTL = te.TTL on handleLoginRequest, check TTL and period against sys values on RenewToken

* Get sysView from le.Path, revert tests

* Add back auth.Policies

* Fix TokenStore tests, add resp warning when capping values

* Use switch for ttl/period check on RenewToken

* Move comments around
This commit is contained in:
Calvin Leung Huang 2017-12-15 13:30:05 -05:00 committed by GitHub
parent 9358540d50
commit 79cb82e133
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 104 additions and 65 deletions

View file

@ -141,7 +141,7 @@ func testAccStepMapUserIdCidr(t *testing.T, cidr string) logicaltest.TestStep {
func testAccLogin(t *testing.T, display string) logicaltest.TestStep { func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error { checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL") return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
} }
return nil return nil
} }
@ -165,7 +165,7 @@ func testAccLogin(t *testing.T, display string) logicaltest.TestStep {
func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep { func testAccLoginAppIDInPath(t *testing.T, display string) logicaltest.TestStep {
checkTTL := func(resp *logical.Response) error { checkTTL := func(resp *logical.Response) error {
if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" { if resp.Auth.LeaseOptions.TTL.String() != "768h0m0s" {
return fmt.Errorf("invalid TTL") return fmt.Errorf("invalid TTL: got %s", resp.Auth.LeaseOptions.TTL)
} }
return nil return nil
} }

View file

@ -3,7 +3,6 @@ package approle
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@ -68,20 +67,13 @@ func (b *backend) pathLoginUpdate(req *logical.Request, data *framework.FieldDat
Policies: role.Policies, Policies: role.Policies,
LeaseOptions: logical.LeaseOptions{ LeaseOptions: logical.LeaseOptions{
Renewable: true, Renewable: true,
TTL: role.TokenTTL,
}, },
Alias: &logical.Alias{ Alias: &logical.Alias{
Name: role.RoleID, Name: role.RoleID,
}, },
} }
// If 'Period' is set, use the value of 'Period' as the TTL.
// Otherwise, set the normal TokenTTL.
if role.Period > time.Duration(0) {
auth.TTL = role.Period
} else {
auth.TTL = role.TokenTTL
}
return &logical.Response{ return &logical.Response{
Auth: auth, Auth: auth,
}, nil }, nil
@ -107,16 +99,12 @@ func (b *backend) pathLoginRenew(req *logical.Request, data *framework.FieldData
return nil, fmt.Errorf("role %s does not exist during renewal", roleName) return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
} }
// If 'Period' is set on the Role, the token should never expire. resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
// Replenish the TTL with 'Period's value. if err != nil {
if role.Period > time.Duration(0) { return nil, err
// If 'Period' was updated after the token was issued,
// token will bear the updated 'Period' value as its TTL.
req.Auth.TTL = role.Period
return &logical.Response{Auth: req.Auth}, nil
} else {
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(req, data)
} }
resp.Auth.Period = role.Period
return resp, nil
} }
const pathLoginHelpSys = "Issue a token based on the credentials supplied" const pathLoginHelpSys = "Issue a token based on the credentials supplied"

View file

@ -8,7 +8,8 @@ import (
) )
// LeaseExtend returns an OperationFunc that can be used to simply extend the // LeaseExtend returns an OperationFunc that can be used to simply extend the
// lease of the auth/secret for the duration that was requested. // lease of the auth/secret for the duration that was requested. The parameters
// provided are used to determine the lease's new TTL value.
// //
// backendIncrement is the backend's requested increment -- perhaps from a user // backendIncrement is the backend's requested increment -- perhaps from a user
// request, perhaps from a role/config value. If not set, uses the mount/system // request, perhaps from a role/config value. If not set, uses the mount/system

View file

@ -750,6 +750,31 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
}, nil }, nil
} }
sysView := m.router.MatchingSystemView(le.Path)
if sysView == nil {
return nil, fmt.Errorf("expiration: unable to retrieve system view from router")
}
retResp := &logical.Response{}
switch {
case resp.Auth.Period > time.Duration(0):
// If it resp.Period is non-zero, use that as the TTL and override backend's
// call on TTL modification, such as a TTL value determined by
// framework.LeaseExtend call against the request. Also, cap period value to
// the sys/mount max value.
if resp.Auth.Period > sysView.MaxLeaseTTL() {
retResp.AddWarning(fmt.Sprintf("Period of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
resp.Auth.Period = sysView.MaxLeaseTTL()
}
resp.Auth.TTL = resp.Auth.Period
case resp.Auth.TTL > time.Duration(0):
// Cap TTL value to the sys/mount max value
if resp.Auth.TTL > sysView.MaxLeaseTTL() {
retResp.AddWarning(fmt.Sprintf("TTL of %d seconds is greater than current mount/system default of %d seconds, value will be truncated.", resp.Auth.TTL, sysView.MaxLeaseTTL()))
resp.Auth.TTL = sysView.MaxLeaseTTL()
}
}
// Attach the ClientToken // Attach the ClientToken
resp.Auth.ClientToken = token resp.Auth.ClientToken = token
resp.Auth.Increment = 0 resp.Auth.Increment = 0
@ -764,9 +789,9 @@ func (m *ExpirationManager) RenewToken(req *logical.Request, source string, toke
// Update the expiration time // Update the expiration time
m.updatePending(le, resp.Auth.LeaseTotal()) m.updatePending(le, resp.Auth.LeaseTotal())
return &logical.Response{
Auth: resp.Auth, retResp.Auth = resp.Auth
}, nil return retResp, nil
} }
// Register is used to take a request and response with an associated // Register is used to take a request and response with an associated
@ -866,6 +891,12 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro
return err return err
} }
// If it resp.Period is non-zero, override the TTL value determined
// by the backend.
if auth.Period > time.Duration(0) {
auth.TTL = auth.Period
}
// Create a lease entry // Create a lease entry
le := leaseEntry{ le := leaseEntry{
LeaseID: path.Join(source, saltedID), LeaseID: path.Join(source, saltedID),

View file

@ -477,15 +477,28 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
return nil, nil, ErrInternalError return nil, nil, ErrInternalError
} }
// Set the default lease if not provided // Start off with the sys default value, and update according to period/TTL
if auth.TTL == 0 { // from resp.Auth
auth.TTL = sysView.DefaultLeaseTTL() tokenTTL := sysView.DefaultLeaseTTL()
}
// Limit the lease duration switch {
case auth.Period > time.Duration(0):
// Cap the period value to the sys max_ttl value. The auth backend should
// have checked for it on its login path, but we check here again for
// sanity.
if auth.Period > sysView.MaxLeaseTTL() {
auth.Period = sysView.MaxLeaseTTL()
}
tokenTTL = auth.Period
case auth.TTL > time.Duration(0):
// Cap the TTL value. The auth backend should have checked for it on its
// login path (e.g. a call to b.SanitizeTTL), but we check here again for
// sanity.
if auth.TTL > sysView.MaxLeaseTTL() { if auth.TTL > sysView.MaxLeaseTTL() {
auth.TTL = sysView.MaxLeaseTTL() auth.TTL = sysView.MaxLeaseTTL()
} }
tokenTTL = auth.TTL
}
// Generate a token // Generate a token
te := TokenEntry{ te := TokenEntry{
@ -494,7 +507,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
Meta: auth.Metadata, Meta: auth.Metadata,
DisplayName: auth.DisplayName, DisplayName: auth.DisplayName,
CreationTime: time.Now().Unix(), CreationTime: time.Now().Unix(),
TTL: auth.TTL, TTL: tokenTTL,
NumUses: auth.NumUses, NumUses: auth.NumUses,
EntityID: auth.EntityID, EntityID: auth.EntityID,
} }
@ -513,10 +526,11 @@ func (c *Core) handleLoginRequest(req *logical.Request) (retResp *logical.Respon
return nil, auth, ErrInternalError return nil, auth, ErrInternalError
} }
// Populate the client token and accessor // Populate the client token, accessor, and TTL
auth.ClientToken = te.ID auth.ClientToken = te.ID
auth.Accessor = te.Accessor auth.Accessor = te.Accessor
auth.Policies = te.Policies auth.Policies = te.Policies
auth.TTL = te.TTL
// Register with the expiration manager // Register with the expiration manager
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil { if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {

View file

@ -2312,7 +2312,7 @@ func TestTokenStore_RolePeriod(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
req.ClientToken = root req.ClientToken = root
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
} }
resp, err := core.HandleRequest(req) resp, err := core.HandleRequest(req)
@ -2425,8 +2425,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
} }
// Let the TTL go down a bit to 3 seconds // Let the TTL go down a bit to 3 seconds
@ -2449,8 +2449,8 @@ func TestTokenStore_RolePeriod(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d", 5, ttl)
} }
} }
} }
@ -2677,7 +2677,7 @@ func TestTokenStore_Periodic(t *testing.T) {
req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test")
req.ClientToken = root req.ClientToken = root
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
} }
resp, err := core.HandleRequest(req) resp, err := core.HandleRequest(req)
@ -2715,8 +2715,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
@ -2739,8 +2739,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 299 { if ttl > 5 {
t.Fatalf("TTL too small (expected %d, got %d)", 299, ttl) t.Fatalf("TTL too large (expected %d, got %d)", 5, ttl)
} }
} }
@ -2750,8 +2750,8 @@ func TestTokenStore_Periodic(t *testing.T) {
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/create" req.Path = "auth/token/create"
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
"explicit_max_ttl": 150, "explicit_max_ttl": 4,
} }
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
@ -2775,8 +2775,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 149 || ttl > 150 { if ttl < 3 || ttl > 4 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
@ -2799,8 +2799,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 140 || ttl > 150 { if ttl > 2 {
t.Fatalf("TTL bad (expected around %d, got %d)", 145, ttl) t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
} }
} }
@ -2812,7 +2812,7 @@ func TestTokenStore_Periodic(t *testing.T) {
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/create/test" req.Path = "auth/token/create/test"
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 150, "period": 5,
} }
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
@ -2836,8 +2836,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 149 || ttl > 150 { if ttl < 4 || ttl > 5 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) t.Fatalf("TTL bad (expected %d, got %d)", 4, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
@ -2860,8 +2860,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl < 149 { if ttl > 5 {
t.Fatalf("TTL bad (expected %d, got %d)", 149, ttl) t.Fatalf("TTL bad (expected less than %d, got %d)", 5, ttl)
} }
} }
@ -2869,18 +2869,23 @@ func TestTokenStore_Periodic(t *testing.T) {
{ {
req.Path = "auth/token/roles/test" req.Path = "auth/token/roles/test"
req.ClientToken = root req.ClientToken = root
req.Operation = logical.UpdateOperation
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"period": 300, "period": 5,
"explicit_max_ttl": 150, "explicit_max_ttl": 4,
}
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v %v", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
} }
req.ClientToken = root req.ClientToken = root
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/create/test" req.Path = "auth/token/create/test"
req.Data = map[string]interface{}{
"period": 150,
"explicit_max_ttl": 130,
}
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if err != nil { if err != nil {
t.Fatalf("err: %v %v", err, resp) t.Fatalf("err: %v %v", err, resp)
@ -2903,12 +2908,12 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl := resp.Data["ttl"].(int64) ttl := resp.Data["ttl"].(int64)
if ttl < 129 || ttl > 130 { if ttl < 3 || ttl > 4 {
t.Fatalf("TTL bad (expected %d, got %d)", 129, ttl) t.Fatalf("TTL bad (expected %d, got %d)", 3, ttl)
} }
// Let the TTL go down a bit // Let the TTL go down a bit
time.Sleep(4 * time.Second) time.Sleep(2 * time.Second)
req.Operation = logical.UpdateOperation req.Operation = logical.UpdateOperation
req.Path = "auth/token/renew-self" req.Path = "auth/token/renew-self"
@ -2927,8 +2932,8 @@ func TestTokenStore_Periodic(t *testing.T) {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ttl = resp.Data["ttl"].(int64) ttl = resp.Data["ttl"].(int64)
if ttl > 127 { if ttl > 2 {
t.Fatalf("TTL bad (expected < %d, got %d)", 128, ttl) t.Fatalf("TTL bad (expected less than %d, got %d)", 2, ttl)
} }
} }
} }