Fix identity token panic during invalidation (#8015)

* Fix identity token crash during invalidation

* Check for nil namespace

* Fix test

* Add nil check test

* Check OIDC cache errors
This commit is contained in:
Jim Kalafut 2019-12-17 10:43:38 -08:00 committed by Brian Kassouf
parent 02dfa885d4
commit 5821fe48c7
3 changed files with 138 additions and 33 deletions

View File

@ -315,7 +315,9 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
return
case strings.HasPrefix(key, oidcTokensPrefix):
i.oidcCache.Flush(nil)
if err := i.oidcCache.Flush(noNamespace); err != nil {
i.logger.Error("error flushing oidc cache", "error", err)
}
}
}

View File

@ -8,6 +8,7 @@ import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
@ -90,6 +91,8 @@ type oidcCache struct {
c *cache.Cache
}
var errNilNamespace = errors.New("nil namespace in oidc cache request")
const (
issuerPath = "identity/oidc"
oidcTokensPrefix = "oidc_tokens/"
@ -111,7 +114,7 @@ var supportedAlgs = []string{
}
// pseudo-namespace for cache items that don't belong to any real namespace.
var nilNamespace = &namespace.Namespace{ID: "__NIL_NAMESPACE"}
var noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"}
func oidcPaths(i *IdentityStore) []*framework.Path {
return []*framework.Path{
@ -370,7 +373,9 @@ func (i *IdentityStore) pathOIDCUpdateConfig(ctx context.Context, req *logical.R
return nil, err
}
i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}
return resp, nil
}
@ -381,7 +386,12 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*
return nil, err
}
if v, ok := i.oidcCache.Get(ns, "config"); ok {
v, ok, err := i.oidcCache.Get(ns, "config")
if err != nil {
return nil, err
}
if ok {
return v.(*oidcConfig), nil
}
@ -404,7 +414,9 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*
c.effectiveIssuer += "/v1/" + ns.Path + issuerPath
i.oidcCache.SetDefault(ns, "config", &c)
if err := i.oidcCache.SetDefault(ns, "config", &c); err != nil {
return nil, err
}
return &c, nil
}
@ -416,8 +428,6 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
return nil, err
}
defer i.oidcCache.Flush(ns)
name := d.Get("name").(string)
i.oidcLock.Lock()
@ -494,6 +504,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
}
}
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}
// store named key
entry, err := logical.StorageEntryJSON(namedKeyConfigPath+name, key)
if err != nil {
@ -590,7 +604,9 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ
return nil, err
}
i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}
return nil, nil
}
@ -645,7 +661,9 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
return nil, err
}
i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}
return nil, nil
}
@ -683,7 +701,12 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
var key *namedKey
if keyRaw, found := i.oidcCache.Get(ns, "namedKeys/"+role.Key); found {
keyRaw, found, err := i.oidcCache.Get(ns, "namedKeys/"+role.Key)
if err != nil {
return nil, err
}
if found {
key = keyRaw.(*namedKey)
} else {
entry, _ := req.Storage.Get(ctx, namedKeyConfigPath+role.Key)
@ -695,7 +718,9 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
return nil, err
}
i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key)
if err := i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key); err != nil {
return nil, err
}
}
// Validate that the role is allowed to sign with its key (the key could have been updated)
if !strutil.StrListContains(key.AllowedClientIDs, "*") && !strutil.StrListContains(key.AllowedClientIDs, role.ClientID) {
@ -923,7 +948,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateRole(ctx context.Context, req *logic
return nil, err
}
i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
return nil, err
}
return nil, nil
}
@ -994,7 +1022,12 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return nil, err
}
if v, ok := i.oidcCache.Get(ns, "discoveryResponse"); ok {
v, ok, err := i.oidcCache.Get(ns, "discoveryResponse")
if err != nil {
return nil, err
}
if ok {
data = v.([]byte)
} else {
c, err := i.getOIDCConfig(ctx, req.Storage)
@ -1015,7 +1048,9 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return nil, err
}
i.oidcCache.SetDefault(ns, "discoveryResponse", data)
if err := i.oidcCache.SetDefault(ns, "discoveryResponse", data); err != nil {
return nil, err
}
}
resp := &logical.Response{
@ -1040,7 +1075,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}
if v, ok := i.oidcCache.Get(ns, "jwksResponse"); ok {
v, ok, err := i.oidcCache.Get(ns, "jwksResponse")
if err != nil {
return nil, err
}
if ok {
data = v.([]byte)
} else {
jwks, err := i.generatePublicJWKS(ctx, req.Storage)
@ -1053,7 +1093,9 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}
i.oidcCache.SetDefault(ns, "jwksResponse", data)
if err := i.oidcCache.SetDefault(ns, "jwksResponse", data); err != nil {
return nil, err
}
}
resp := &logical.Response{
@ -1072,7 +1114,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}
if len(keys) > 0 {
if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
return nil, err
}
if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
@ -1311,7 +1358,12 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
return nil, err
}
if jwksRaw, ok := i.oidcCache.Get(ns, "jwks"); ok {
jwksRaw, ok, err := i.oidcCache.Get(ns, "jwks")
if err != nil {
return nil, err
}
if ok {
return jwksRaw.(*jose.JSONWebKeySet), nil
}
@ -1336,7 +1388,9 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
jwks.Keys = append(jwks.Keys, *key)
}
i.oidcCache.SetDefault(ns, "jwks", jwks)
if err := i.oidcCache.SetDefault(ns, "jwks", jwks); err != nil {
return nil, err
}
return jwks, nil
}
@ -1435,7 +1489,9 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
}
if didUpdate {
i.oidcCache.Flush(ns)
if err := i.oidcCache.Flush(ns); err != nil {
i.Logger().Error("error flushing oidc cache", "error", err)
}
}
return nextExpiration, nil
@ -1501,7 +1557,13 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
nsPaths := i.listNamespacePaths()
if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
i.Logger().Error("error reading oidc cache", "err", err)
return
}
if ok {
nextRun = v.(time.Time)
}
@ -1531,7 +1593,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
i.Logger().Warn("error expiring OIDC public keys", "err", err)
}
i.oidcCache.Flush(nilNamespace)
if err := i.oidcCache.Flush(noNamespace); err != nil {
i.Logger().Error("error flushing oidc cache", "err", err)
}
// re-run at the soonest expiration or rotation time
if nextRotation.Before(nextRun) {
@ -1542,7 +1606,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
nextRun = nextExpiration
}
}
i.oidcCache.SetDefault(nilNamespace, "nextRun", nextRun)
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
i.Logger().Error("error setting oidc cache", "err", err)
}
}
}
@ -1556,20 +1622,35 @@ func (c *oidcCache) nskey(ns *namespace.Namespace, key string) string {
return fmt.Sprintf("v0:%s:%s", ns.ID, key)
}
func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool) {
return c.c.Get(c.nskey(ns, key))
func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool, error) {
if ns == nil {
return nil, false, errNilNamespace
}
v, found := c.c.Get(c.nskey(ns, key))
return v, found, nil
}
func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) {
func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) error {
if ns == nil {
return errNilNamespace
}
c.c.SetDefault(c.nskey(ns, key), obj)
return nil
}
func (c *oidcCache) Flush(ns *namespace.Namespace) {
func (c *oidcCache) Flush(ns *namespace.Namespace) error {
if ns == nil {
return errNilNamespace
}
for itemKey := range c.c.Items() {
if isTargetNamespacedKey(itemKey, []string{nilNamespace.ID, ns.ID}) {
if isTargetNamespacedKey(itemKey, []string{noNamespace.ID, ns.ID}) {
c.c.Delete(itemKey)
}
}
return nil
}
// isTargetNamespacedKey returns true for a properly constructed namespaced key (<version>:<nsID>:<key>)

View File

@ -619,7 +619,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
currentCycle = currentCycle + 1
// sleep until we are in the next cycle - where a next run will happen
v, _ := c.identityStore.oidcCache.Get(nilNamespace, "nextRun")
v, _, _ := c.identityStore.oidcCache.Get(noNamespace, "nextRun")
nextRun := v.(time.Time)
now := time.Now()
diff := nextRun.Sub(now)
@ -1012,7 +1012,7 @@ func TestOIDC_isTargetNamespacedKey(t *testing.T) {
func TestOIDC_Flush(t *testing.T) {
c := newOIDCCache()
ns := []*namespace.Namespace{
nilNamespace, //ns[0] is nilNamespace
noNamespace, //ns[0] is nilNamespace
&namespace.Namespace{ID: "ns1"},
&namespace.Namespace{ID: "ns2"},
}
@ -1021,7 +1021,9 @@ func TestOIDC_Flush(t *testing.T) {
populateNs := func() {
for i := range ns {
for _, val := range []string{"keyA", "keyB", "keyC"} {
c.SetDefault(ns[i], val, struct{}{})
if err := c.SetDefault(ns[i], val, struct{}{}); err != nil {
t.Fatal(err)
}
}
}
}
@ -1052,17 +1054,37 @@ func TestOIDC_Flush(t *testing.T) {
// flushing ns1 should flush ns1 and nilNamespace but not ns2
populateNs()
c.Flush(ns[1])
if err := c.Flush(ns[1]); err != nil {
t.Fatal(err)
}
items := c.c.Items()
verify(items, []*namespace.Namespace{ns[2]}, []*namespace.Namespace{ns[0], ns[1]})
// flushing nilNamespace should flush nilNamespace but not ns1 or ns2
populateNs()
c.Flush(ns[0])
if err := c.Flush(ns[0]); err != nil {
t.Fatal(err)
}
items = c.c.Items()
verify(items, []*namespace.Namespace{ns[1], ns[2]}, []*namespace.Namespace{ns[0]})
}
func TestOIDC_CacheNamespaceNilCheck(t *testing.T) {
cache := newOIDCCache()
if _, _, err := cache.Get(nil, "foo"); err == nil {
t.Fatal("expected error, got nil")
}
if err := cache.SetDefault(nil, "foo", 42); err == nil {
t.Fatal("expected error, got nil")
}
if err := cache.Flush(nil); err == nil {
t.Fatal("expected error, got nil")
}
}
// some helpers
func expectSuccess(t *testing.T, resp *logical.Response, err error) {
t.Helper()