From cdea4b3445cbebde833fc2d4b01cf6337ccb7888 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Sat, 13 Aug 2016 14:03:22 -0400 Subject: [PATCH] Add some tests and fix some bugs --- vault/token_store.go | 28 +++++- vault/token_store_test.go | 201 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+), 5 deletions(-) diff --git a/vault/token_store.go b/vault/token_store.go index d08498c64..f81a5d774 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -1325,8 +1325,8 @@ func (ts *TokenStore) handleCreateCommon( // Run some bounding checks if the explicit max TTL is set; we do not check // period as it's defined to escape the max TTL if te.ExplicitMaxTTL > 0 { - // Limit the lease duration - if sysView.MaxLeaseTTL() != 0 && te.ExplicitMaxTTL > sysView.MaxLeaseTTL() { + // Limit the lease duration, except for periodic tokens -- in that case the explicit max limits the period, which itself can escape normal max + if sysView.MaxLeaseTTL() != 0 && te.ExplicitMaxTTL > sysView.MaxLeaseTTL() && te.Period == 0 { resp.AddWarning(fmt.Sprintf( "Explicit max TTL of %d seconds is greater than system/mount allowed value; value is being capped to %d seconds", int64(te.ExplicitMaxTTL.Seconds()), int64(sysView.MaxLeaseTTL().Seconds()))) @@ -1334,8 +1334,10 @@ func (ts *TokenStore) handleCreateCommon( } if te.TTL == 0 { + // This won't be the case if it's periodic -- it will be set above te.TTL = te.ExplicitMaxTTL } else { + // Limit even in the periodic case if te.TTL > te.ExplicitMaxTTL { resp.AddWarning(fmt.Sprintf( "Requested TTL of %d seconds higher than explicit max TTL; value being capped to %d seconds", @@ -1617,7 +1619,13 @@ func (ts *TokenStore) authRenew( req.Auth.TTL = te.Period return &logical.Response{Auth: req.Auth}, nil } else { - f = framework.LeaseExtend(te.Period, te.ExplicitMaxTTL, ts.System()) + maxTime := time.Unix(te.CreationTime, 0).Add(te.ExplicitMaxTTL) + if maxTime.Add(-1 * te.Period).Before(time.Now()) { + req.Auth.TTL = maxTime.Sub(time.Now()) + } else { + req.Auth.TTL = te.Period + } + return &logical.Response{Auth: req.Auth}, nil } } return f(req, d) @@ -1634,11 +1642,21 @@ func (ts *TokenStore) authRenew( // Same deal here, but using the role period if role.Period != 0 { + periodToUse := role.Period + if te.Period < role.Period { + periodToUse = te.Period + } if te.ExplicitMaxTTL == 0 { - req.Auth.TTL = role.Period + req.Auth.TTL = periodToUse return &logical.Response{Auth: req.Auth}, nil } else { - f = framework.LeaseExtend(role.Period, te.ExplicitMaxTTL, ts.System()) + maxTime := time.Unix(te.CreationTime, 0).Add(te.ExplicitMaxTTL) + if maxTime.Add(-1 * periodToUse).Before(time.Now()) { + req.Auth.TTL = maxTime.Sub(time.Now()) + } else { + req.Auth.TTL = periodToUse + } + return &logical.Response{Auth: req.Auth}, nil } } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 54435414d..5923bc978 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -2196,3 +2196,204 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) { } } } + +func TestTokenStore_Periodic(t *testing.T) { + core, _, _, root := TestCoreWithTokenStore(t) + + core.defaultLeaseTTL = 10 * time.Second + core.maxLeaseTTL = 10 * time.Second + + // Note: these requests are sent to Core since Core handles registration + // with the expiration manager and we need the storage to be consistent + + req := logical.TestRequest(t, logical.UpdateOperation, "auth/token/roles/test") + req.ClientToken = root + req.Data = map[string]interface{}{ + "period": 300, + } + + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v %v", err, resp) + } + if resp != nil { + t.Fatalf("expected a nil response") + } + + // First make one directly and verify on renew it uses the period. + { + req.ClientToken = root + req.Operation = logical.UpdateOperation + req.Path = "auth/token/create" + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("response was nil") + } + if resp.Auth == nil { + t.Fatal(fmt.Sprintf("response auth was nil, resp is %#v", *resp)) + } + if resp.Auth.ClientToken == "" { + t.Fatalf("bad: %#v", resp) + } + + req.ClientToken = resp.Auth.ClientToken + req.Operation = logical.ReadOperation + req.Path = "auth/token/lookup-self" + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + ttl := resp.Data["ttl"].(int64) + if ttl < 299 { + t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) + } + + // Let the TTL go down a bit to 3 seconds + time.Sleep(2 * time.Second) + + req.Operation = logical.UpdateOperation + req.Path = "auth/token/renew-self" + req.Data = map[string]interface{}{ + "increment": 1, + } + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v %v", err, resp) + } + + req.Operation = logical.ReadOperation + req.Path = "auth/token/lookup-self" + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + ttl = resp.Data["ttl"].(int64) + if ttl < 299 { + t.Fatalf("TTL too small (expected %d, got %d", 299, ttl) + } + } + + // Do the same with an explicit max TTL + { + req.ClientToken = root + req.Operation = logical.UpdateOperation + req.Path = "auth/token/create" + req.Data = map[string]interface{}{ + "period": 300, + "explicit_max_ttl": 150, + } + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("response was nil") + } + if resp.Auth == nil { + t.Fatal(fmt.Sprintf("response auth was nil, resp is %#v", *resp)) + } + if resp.Auth.ClientToken == "" { + t.Fatalf("bad: %#v", resp) + } + + req.ClientToken = resp.Auth.ClientToken + req.Operation = logical.ReadOperation + req.Path = "auth/token/lookup-self" + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + ttl := resp.Data["ttl"].(int64) + if ttl < 149 || ttl > 150 { + t.Fatalf("TTL bad (expected %d, got %d", 149, ttl) + } + + // Let the TTL go down a bit to 3 seconds + time.Sleep(2 * time.Second) + + req.Operation = logical.UpdateOperation + req.Path = "auth/token/renew-self" + req.Data = map[string]interface{}{ + "increment": 76, + } + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v %v", err, resp) + } + + req.Operation = logical.ReadOperation + req.Path = "auth/token/lookup-self" + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + ttl = resp.Data["ttl"].(int64) + if ttl < 140 || ttl > 150 { + t.Fatalf("TTL bad (expected around %d, got %d", 145, ttl) + } + } + + // Now we create a token against the role and also set the te value + // directly. We should be able to renew; increment should be ignored as + // well. + { + req.ClientToken = root + req.Operation = logical.UpdateOperation + req.Path = "auth/token/create/test" + req.Data = map[string]interface{}{ + "period": 150, + } + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v %v", err, resp) + } + if resp == nil { + t.Fatal("response was nil") + } + if resp.Auth == nil { + t.Fatal(fmt.Sprintf("response auth was nil, resp is %#v", *resp)) + } + if resp.Auth.ClientToken == "" { + t.Fatalf("bad: %#v", resp) + } + + req.ClientToken = resp.Auth.ClientToken + req.Operation = logical.ReadOperation + req.Path = "auth/token/lookup-self" + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + ttl := resp.Data["ttl"].(int64) + if ttl < 149 || ttl > 150 { + t.Fatalf("TTL bad (expected %d, got %d", 149, ttl) + } + + // Let the TTL go down a bit to 3 seconds + time.Sleep(2 * time.Second) + + req.Operation = logical.UpdateOperation + req.Path = "auth/token/renew-self" + req.Data = map[string]interface{}{ + "increment": 1, + } + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v %v", err, resp) + } + + req.Operation = logical.ReadOperation + req.Path = "auth/token/lookup-self" + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + ttl = resp.Data["ttl"].(int64) + if ttl < 149 { + t.Fatalf("TTL bad (expected %d, got %d", 149, ttl) + } + } +}