Fix identity token panic during invalidation (#8015)
* Fix identity token crash during invalidation * Check for nil namespace * Fix test * Add nil check test * Check OIDC cache errors
This commit is contained in:
parent
02dfa885d4
commit
5821fe48c7
|
@ -315,7 +315,9 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
|
|||
return
|
||||
|
||||
case strings.HasPrefix(key, oidcTokensPrefix):
|
||||
i.oidcCache.Flush(nil)
|
||||
if err := i.oidcCache.Flush(noNamespace); err != nil {
|
||||
i.logger.Error("error flushing oidc cache", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
@ -90,6 +91,8 @@ type oidcCache struct {
|
|||
c *cache.Cache
|
||||
}
|
||||
|
||||
var errNilNamespace = errors.New("nil namespace in oidc cache request")
|
||||
|
||||
const (
|
||||
issuerPath = "identity/oidc"
|
||||
oidcTokensPrefix = "oidc_tokens/"
|
||||
|
@ -111,7 +114,7 @@ var supportedAlgs = []string{
|
|||
}
|
||||
|
||||
// pseudo-namespace for cache items that don't belong to any real namespace.
|
||||
var nilNamespace = &namespace.Namespace{ID: "__NIL_NAMESPACE"}
|
||||
var noNamespace = &namespace.Namespace{ID: "__NO_NAMESPACE"}
|
||||
|
||||
func oidcPaths(i *IdentityStore) []*framework.Path {
|
||||
return []*framework.Path{
|
||||
|
@ -370,7 +373,9 @@ func (i *IdentityStore) pathOIDCUpdateConfig(ctx context.Context, req *logical.R
|
|||
return nil, err
|
||||
}
|
||||
|
||||
i.oidcCache.Flush(ns)
|
||||
if err := i.oidcCache.Flush(ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
@ -381,7 +386,12 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := i.oidcCache.Get(ns, "config"); ok {
|
||||
v, ok, err := i.oidcCache.Get(ns, "config")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok {
|
||||
return v.(*oidcConfig), nil
|
||||
}
|
||||
|
||||
|
@ -404,7 +414,9 @@ func (i *IdentityStore) getOIDCConfig(ctx context.Context, s logical.Storage) (*
|
|||
|
||||
c.effectiveIssuer += "/v1/" + ns.Path + issuerPath
|
||||
|
||||
i.oidcCache.SetDefault(ns, "config", &c)
|
||||
if err := i.oidcCache.SetDefault(ns, "config", &c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
@ -416,8 +428,6 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
|
|||
return nil, err
|
||||
}
|
||||
|
||||
defer i.oidcCache.Flush(ns)
|
||||
|
||||
name := d.Get("name").(string)
|
||||
|
||||
i.oidcLock.Lock()
|
||||
|
@ -494,6 +504,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
|
|||
}
|
||||
}
|
||||
|
||||
if err := i.oidcCache.Flush(ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// store named key
|
||||
entry, err := logical.StorageEntryJSON(namedKeyConfigPath+name, key)
|
||||
if err != nil {
|
||||
|
@ -590,7 +604,9 @@ func (i *IdentityStore) pathOIDCDeleteKey(ctx context.Context, req *logical.Requ
|
|||
return nil, err
|
||||
}
|
||||
|
||||
i.oidcCache.Flush(ns)
|
||||
if err := i.oidcCache.Flush(ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -645,7 +661,9 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
|
|||
return nil, err
|
||||
}
|
||||
|
||||
i.oidcCache.Flush(ns)
|
||||
if err := i.oidcCache.Flush(ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -683,7 +701,12 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
|
|||
|
||||
var key *namedKey
|
||||
|
||||
if keyRaw, found := i.oidcCache.Get(ns, "namedKeys/"+role.Key); found {
|
||||
keyRaw, found, err := i.oidcCache.Get(ns, "namedKeys/"+role.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if found {
|
||||
key = keyRaw.(*namedKey)
|
||||
} else {
|
||||
entry, _ := req.Storage.Get(ctx, namedKeyConfigPath+role.Key)
|
||||
|
@ -695,7 +718,9 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
|
|||
return nil, err
|
||||
}
|
||||
|
||||
i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key)
|
||||
if err := i.oidcCache.SetDefault(ns, "namedKeys/"+role.Key, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Validate that the role is allowed to sign with its key (the key could have been updated)
|
||||
if !strutil.StrListContains(key.AllowedClientIDs, "*") && !strutil.StrListContains(key.AllowedClientIDs, role.ClientID) {
|
||||
|
@ -923,7 +948,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateRole(ctx context.Context, req *logic
|
|||
return nil, err
|
||||
}
|
||||
|
||||
i.oidcCache.Flush(ns)
|
||||
if err := i.oidcCache.Flush(ns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -994,7 +1022,12 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := i.oidcCache.Get(ns, "discoveryResponse"); ok {
|
||||
v, ok, err := i.oidcCache.Get(ns, "discoveryResponse")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok {
|
||||
data = v.([]byte)
|
||||
} else {
|
||||
c, err := i.getOIDCConfig(ctx, req.Storage)
|
||||
|
@ -1015,7 +1048,9 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
|
|||
return nil, err
|
||||
}
|
||||
|
||||
i.oidcCache.SetDefault(ns, "discoveryResponse", data)
|
||||
if err := i.oidcCache.SetDefault(ns, "discoveryResponse", data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
|
@ -1040,7 +1075,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := i.oidcCache.Get(ns, "jwksResponse"); ok {
|
||||
v, ok, err := i.oidcCache.Get(ns, "jwksResponse")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok {
|
||||
data = v.([]byte)
|
||||
} else {
|
||||
jwks, err := i.generatePublicJWKS(ctx, req.Storage)
|
||||
|
@ -1053,7 +1093,9 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
|
|||
return nil, err
|
||||
}
|
||||
|
||||
i.oidcCache.SetDefault(ns, "jwksResponse", data)
|
||||
if err := i.oidcCache.SetDefault(ns, "jwksResponse", data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
resp := &logical.Response{
|
||||
|
@ -1072,7 +1114,12 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
|
|||
return nil, err
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
|
||||
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok {
|
||||
now := time.Now()
|
||||
expireAt := v.(time.Time)
|
||||
if expireAt.After(now) {
|
||||
|
@ -1311,7 +1358,12 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if jwksRaw, ok := i.oidcCache.Get(ns, "jwks"); ok {
|
||||
jwksRaw, ok, err := i.oidcCache.Get(ns, "jwks")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok {
|
||||
return jwksRaw.(*jose.JSONWebKeySet), nil
|
||||
}
|
||||
|
||||
|
@ -1336,7 +1388,9 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
|
|||
jwks.Keys = append(jwks.Keys, *key)
|
||||
}
|
||||
|
||||
i.oidcCache.SetDefault(ns, "jwks", jwks)
|
||||
if err := i.oidcCache.SetDefault(ns, "jwks", jwks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
@ -1435,7 +1489,9 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
|
|||
}
|
||||
|
||||
if didUpdate {
|
||||
i.oidcCache.Flush(ns)
|
||||
if err := i.oidcCache.Flush(ns); err != nil {
|
||||
i.Logger().Error("error flushing oidc cache", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nextExpiration, nil
|
||||
|
@ -1501,7 +1557,13 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
|||
|
||||
nsPaths := i.listNamespacePaths()
|
||||
|
||||
if v, ok := i.oidcCache.Get(nilNamespace, "nextRun"); ok {
|
||||
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
|
||||
if err != nil {
|
||||
i.Logger().Error("error reading oidc cache", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
if ok {
|
||||
nextRun = v.(time.Time)
|
||||
}
|
||||
|
||||
|
@ -1531,7 +1593,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
|||
i.Logger().Warn("error expiring OIDC public keys", "err", err)
|
||||
}
|
||||
|
||||
i.oidcCache.Flush(nilNamespace)
|
||||
if err := i.oidcCache.Flush(noNamespace); err != nil {
|
||||
i.Logger().Error("error flushing oidc cache", "err", err)
|
||||
}
|
||||
|
||||
// re-run at the soonest expiration or rotation time
|
||||
if nextRotation.Before(nextRun) {
|
||||
|
@ -1542,7 +1606,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
|||
nextRun = nextExpiration
|
||||
}
|
||||
}
|
||||
i.oidcCache.SetDefault(nilNamespace, "nextRun", nextRun)
|
||||
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
|
||||
i.Logger().Error("error setting oidc cache", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1556,20 +1622,35 @@ func (c *oidcCache) nskey(ns *namespace.Namespace, key string) string {
|
|||
return fmt.Sprintf("v0:%s:%s", ns.ID, key)
|
||||
}
|
||||
|
||||
func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool) {
|
||||
return c.c.Get(c.nskey(ns, key))
|
||||
func (c *oidcCache) Get(ns *namespace.Namespace, key string) (interface{}, bool, error) {
|
||||
if ns == nil {
|
||||
return nil, false, errNilNamespace
|
||||
}
|
||||
v, found := c.c.Get(c.nskey(ns, key))
|
||||
return v, found, nil
|
||||
}
|
||||
|
||||
func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) {
|
||||
func (c *oidcCache) SetDefault(ns *namespace.Namespace, key string, obj interface{}) error {
|
||||
if ns == nil {
|
||||
return errNilNamespace
|
||||
}
|
||||
c.c.SetDefault(c.nskey(ns, key), obj)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *oidcCache) Flush(ns *namespace.Namespace) {
|
||||
func (c *oidcCache) Flush(ns *namespace.Namespace) error {
|
||||
if ns == nil {
|
||||
return errNilNamespace
|
||||
}
|
||||
|
||||
for itemKey := range c.c.Items() {
|
||||
if isTargetNamespacedKey(itemKey, []string{nilNamespace.ID, ns.ID}) {
|
||||
if isTargetNamespacedKey(itemKey, []string{noNamespace.ID, ns.ID}) {
|
||||
c.c.Delete(itemKey)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isTargetNamespacedKey returns true for a properly constructed namespaced key (<version>:<nsID>:<key>)
|
||||
|
|
|
@ -619,7 +619,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
|
|||
currentCycle = currentCycle + 1
|
||||
|
||||
// sleep until we are in the next cycle - where a next run will happen
|
||||
v, _ := c.identityStore.oidcCache.Get(nilNamespace, "nextRun")
|
||||
v, _, _ := c.identityStore.oidcCache.Get(noNamespace, "nextRun")
|
||||
nextRun := v.(time.Time)
|
||||
now := time.Now()
|
||||
diff := nextRun.Sub(now)
|
||||
|
@ -1012,7 +1012,7 @@ func TestOIDC_isTargetNamespacedKey(t *testing.T) {
|
|||
func TestOIDC_Flush(t *testing.T) {
|
||||
c := newOIDCCache()
|
||||
ns := []*namespace.Namespace{
|
||||
nilNamespace, //ns[0] is nilNamespace
|
||||
noNamespace, //ns[0] is nilNamespace
|
||||
&namespace.Namespace{ID: "ns1"},
|
||||
&namespace.Namespace{ID: "ns2"},
|
||||
}
|
||||
|
@ -1021,7 +1021,9 @@ func TestOIDC_Flush(t *testing.T) {
|
|||
populateNs := func() {
|
||||
for i := range ns {
|
||||
for _, val := range []string{"keyA", "keyB", "keyC"} {
|
||||
c.SetDefault(ns[i], val, struct{}{})
|
||||
if err := c.SetDefault(ns[i], val, struct{}{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1052,17 +1054,37 @@ func TestOIDC_Flush(t *testing.T) {
|
|||
|
||||
// flushing ns1 should flush ns1 and nilNamespace but not ns2
|
||||
populateNs()
|
||||
c.Flush(ns[1])
|
||||
if err := c.Flush(ns[1]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
items := c.c.Items()
|
||||
verify(items, []*namespace.Namespace{ns[2]}, []*namespace.Namespace{ns[0], ns[1]})
|
||||
|
||||
// flushing nilNamespace should flush nilNamespace but not ns1 or ns2
|
||||
populateNs()
|
||||
c.Flush(ns[0])
|
||||
if err := c.Flush(ns[0]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
items = c.c.Items()
|
||||
verify(items, []*namespace.Namespace{ns[1], ns[2]}, []*namespace.Namespace{ns[0]})
|
||||
}
|
||||
|
||||
func TestOIDC_CacheNamespaceNilCheck(t *testing.T) {
|
||||
cache := newOIDCCache()
|
||||
|
||||
if _, _, err := cache.Get(nil, "foo"); err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if err := cache.SetDefault(nil, "foo", 42); err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if err := cache.Flush(nil); err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// some helpers
|
||||
func expectSuccess(t *testing.T, resp *logical.Response, err error) {
|
||||
t.Helper()
|
||||
|
|
Loading…
Reference in New Issue