Identity: prepublish jwt signing keys (#12414)
* pre-publish new signing keys for `rotation_period` of time before using * Work In Progress: Prepublish JWKS and even cache control * remove comments * use math/rand instead of math/big * update tests * remove debug comment * refactor cache control logic into func * don't set expiry when create/update key * update cachecontrol name in oidccache for test * fix bug in periodicfunc test case * add changelog * remove confusing comment * add logging and comments * update change log from bug to improvement Co-authored-by: Ian Ferguson <ian.ferguson@datadoghq.com>
This commit is contained in:
parent
d4656971b1
commit
c42bbb369c
3
changelog/12414.txt
Normal file
3
changelog/12414.txt
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:improvement
|
||||||
|
identity: fix issue where Cache-Control header causes stampede of requests for JWKS keys
|
||||||
|
```
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
mathrand "math/rand"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -50,6 +52,7 @@ type namedKey struct {
|
||||||
RotationPeriod time.Duration `json:"rotation_period"`
|
RotationPeriod time.Duration `json:"rotation_period"`
|
||||||
KeyRing []*expireableKey `json:"key_ring"`
|
KeyRing []*expireableKey `json:"key_ring"`
|
||||||
SigningKey *jose.JSONWebKey `json:"signing_key"`
|
SigningKey *jose.JSONWebKey `json:"signing_key"`
|
||||||
|
NextSigningKey *jose.JSONWebKey `json:"next_signing_key"`
|
||||||
NextRotation time.Time `json:"next_rotation"`
|
NextRotation time.Time `json:"next_rotation"`
|
||||||
AllowedClientIDs []string `json:"allowed_client_ids"`
|
AllowedClientIDs []string `json:"allowed_client_ids"`
|
||||||
}
|
}
|
||||||
|
@ -510,13 +513,15 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
|
||||||
return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil
|
return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
// Update next rotation time if it is unset or now earlier than previously set.
|
// Update next rotation time if it is unset or now earlier than previously set.
|
||||||
nextRotation := time.Now().Add(key.RotationPeriod)
|
nextRotation := now.Add(key.RotationPeriod)
|
||||||
if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) {
|
if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) {
|
||||||
key.NextRotation = nextRotation
|
key.NextRotation = nextRotation
|
||||||
}
|
}
|
||||||
|
|
||||||
// generate keys if creating a new key or changing algorithms
|
// generate current and next keys if creating a new key or changing algorithms
|
||||||
if key.Algorithm != prevAlgorithm {
|
if key.Algorithm != prevAlgorithm {
|
||||||
signingKey, err := generateKeys(key.Algorithm)
|
signingKey, err := generateKeys(key.Algorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -529,6 +534,20 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
|
||||||
if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
|
if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)
|
||||||
|
|
||||||
|
nextSigningKey, err := generateKeys(key.Algorithm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
key.NextSigningKey = nextSigningKey
|
||||||
|
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID})
|
||||||
|
|
||||||
|
if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := i.oidcCache.Flush(ns); err != nil {
|
if err := i.oidcCache.Flush(ns); err != nil {
|
||||||
|
@ -727,7 +746,7 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
|
||||||
verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second
|
verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := storedNamedKey.rotate(ctx, req.Storage, verificationTTLOverride); err != nil {
|
if err := storedNamedKey.rotate(ctx, i.Logger(), req.Storage, verificationTTLOverride); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1168,6 +1187,40 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getKeysCacheControlHeader returns the cache control header for all public
|
||||||
|
// keys at the .well-known/keys endpoint
|
||||||
|
func (i *IdentityStore) getKeysCacheControlHeader() (string, error) {
|
||||||
|
// if jwksCacheControlMaxAge is set use that, otherwise fall back on the
|
||||||
|
// more conservative nextRun values
|
||||||
|
jwksCacheControlMaxAge, ok, err := i.oidcCache.Get(noNamespace, "jwksCacheControlMaxAge")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
maxDuration := int64(jwksCacheControlMaxAge.(time.Duration))
|
||||||
|
randDuration := mathrand.Int63n(maxDuration)
|
||||||
|
durationInSeconds := time.Duration(randDuration).Seconds()
|
||||||
|
return fmt.Sprintf("max-age=%.0f", durationInSeconds), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nextRun, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
now := time.Now()
|
||||||
|
expireAt := nextRun.(time.Time)
|
||||||
|
if expireAt.After(now) {
|
||||||
|
i.Logger().Debug("use nextRun value for Cache Control header", "nextRun", nextRun)
|
||||||
|
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
|
||||||
|
return fmt.Sprintf("max-age=%.0f", expireInSeconds), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
// pathOIDCReadPublicKeys is used to retrieve all public keys so that clients can
|
// pathOIDCReadPublicKeys is used to retrieve all public keys so that clients can
|
||||||
// verify the validity of a signed OIDC token.
|
// verify the validity of a signed OIDC token.
|
||||||
func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||||
|
@ -1209,27 +1262,19 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// set a Cache-Control header only if there are keys, if there aren't keys
|
// set a Cache-Control header only if there are keys
|
||||||
// then nextRun should not be used to set Cache-Control header because it chooses
|
|
||||||
// a time in the future that isn't based on key rotation/expiration values
|
|
||||||
keys, err := listOIDCPublicKeys(ctx, req.Storage)
|
keys, err := listOIDCPublicKeys(ctx, req.Storage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(keys) > 0 {
|
if len(keys) > 0 {
|
||||||
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
|
header, err := i.getKeysCacheControlHeader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ok {
|
if header != "" {
|
||||||
now := time.Now()
|
resp.Data[logical.HTTPRawCacheControl] = header
|
||||||
expireAt := v.(time.Time)
|
|
||||||
if expireAt.After(now) {
|
|
||||||
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
|
|
||||||
expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds)
|
|
||||||
resp.Data[logical.HTTPRawCacheControl] = expireInString
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1326,10 +1371,9 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
|
||||||
return introspectionResp("")
|
return introspectionResp("")
|
||||||
}
|
}
|
||||||
|
|
||||||
// namedKey.rotate(overrides) performs a key rotation on a namedKey and returns the
|
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
|
||||||
// verification_ttl that was applied. verification_ttl can be overridden with an
|
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
|
||||||
// overrideVerificationTTL value >= 0
|
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
|
||||||
func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerificationTTL time.Duration) error {
|
|
||||||
verificationTTL := k.VerificationTTL
|
verificationTTL := k.VerificationTTL
|
||||||
|
|
||||||
if overrideVerificationTTL >= 0 {
|
if overrideVerificationTTL >= 0 {
|
||||||
|
@ -1337,16 +1381,16 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
|
||||||
}
|
}
|
||||||
|
|
||||||
// generate new key
|
// generate new key
|
||||||
signingKey, err := generateKeys(k.Algorithm)
|
nextSigningKey, err := generateKeys(k.Algorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
|
if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// set the previous public key's expiry time
|
// set the previous public key's expiry time
|
||||||
for _, key := range k.KeyRing {
|
for _, key := range k.KeyRing {
|
||||||
if key.KeyID == k.SigningKey.KeyID {
|
if key.KeyID == k.SigningKey.KeyID {
|
||||||
|
@ -1354,8 +1398,10 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
k.SigningKey = signingKey
|
|
||||||
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.KeyID})
|
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID})
|
||||||
|
k.SigningKey = k.NextSigningKey
|
||||||
|
k.NextSigningKey = nextSigningKey
|
||||||
k.NextRotation = now.Add(k.RotationPeriod)
|
k.NextRotation = now.Add(k.RotationPeriod)
|
||||||
|
|
||||||
// store named key (it was modified when rotate was called on it)
|
// store named key (it was modified when rotate was called on it)
|
||||||
|
@ -1367,6 +1413,7 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug("rotated OIDC public key, now using", "key_id", k.SigningKey.Public().KeyID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1599,24 +1646,30 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
|
||||||
return nextExpiration, nil
|
return nextExpiration, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, error) {
|
// oidcKeyRotation will rotate any keys that are due to be rotated.
|
||||||
|
//
|
||||||
|
// It will return the time of the soonest rotation and the minimum
|
||||||
|
// verificationTTL or minimum rotationPeriod out of all the current keys.
|
||||||
|
func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, time.Duration, error) {
|
||||||
// soonestRotation will be the soonest rotation time of all keys. Initialize
|
// soonestRotation will be the soonest rotation time of all keys. Initialize
|
||||||
// here to a relatively distant time.
|
// here to a relatively distant time.
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
soonestRotation := now.Add(24 * time.Hour)
|
soonestRotation := now.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
jwksClientCacheDuration := time.Duration(math.MaxInt64)
|
||||||
|
|
||||||
i.oidcLock.Lock()
|
i.oidcLock.Lock()
|
||||||
defer i.oidcLock.Unlock()
|
defer i.oidcLock.Unlock()
|
||||||
|
|
||||||
keys, err := s.List(ctx, namedKeyConfigPath)
|
keys, err := s.List(ctx, namedKeyConfigPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return now, err
|
return now, jwksClientCacheDuration, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
entry, err := s.Get(ctx, namedKeyConfigPath+k)
|
entry, err := s.Get(ctx, namedKeyConfigPath+k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return now, err
|
return now, jwksClientCacheDuration, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
|
@ -1625,10 +1678,18 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
|
||||||
|
|
||||||
var key namedKey
|
var key namedKey
|
||||||
if err := entry.DecodeJSON(&key); err != nil {
|
if err := entry.DecodeJSON(&key); err != nil {
|
||||||
return now, err
|
return now, jwksClientCacheDuration, err
|
||||||
}
|
}
|
||||||
key.name = k
|
key.name = k
|
||||||
|
|
||||||
|
if key.VerificationTTL < jwksClientCacheDuration {
|
||||||
|
jwksClientCacheDuration = key.VerificationTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.RotationPeriod < jwksClientCacheDuration {
|
||||||
|
jwksClientCacheDuration = key.RotationPeriod
|
||||||
|
}
|
||||||
|
|
||||||
// Future key rotation that is the earliest we've seen.
|
// Future key rotation that is the earliest we've seen.
|
||||||
if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) {
|
if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) {
|
||||||
soonestRotation = key.NextRotation
|
soonestRotation = key.NextRotation
|
||||||
|
@ -1637,8 +1698,8 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
|
||||||
// Key that is due to be rotated.
|
// Key that is due to be rotated.
|
||||||
if now.After(key.NextRotation) {
|
if now.After(key.NextRotation) {
|
||||||
i.Logger().Debug("rotating OIDC key", "key", key.name)
|
i.Logger().Debug("rotating OIDC key", "key", key.name)
|
||||||
if err := key.rotate(ctx, s, -1); err != nil {
|
if err := key.rotate(ctx, i.Logger(), s, -1); err != nil {
|
||||||
return now, err
|
return now, jwksClientCacheDuration, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Possibly save the new rotation time
|
// Possibly save the new rotation time
|
||||||
|
@ -1648,12 +1709,13 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return soonestRotation, nil
|
return soonestRotation, jwksClientCacheDuration, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key
|
// oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key
|
||||||
// rotations and expiration actions.
|
// rotations and expiration actions.
|
||||||
func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
||||||
|
i.Logger().Debug("begin oidcPeriodicFunc")
|
||||||
var nextRun time.Time
|
var nextRun time.Time
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
|
@ -1675,6 +1737,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
||||||
// Initialize to a fairly distant next run time. This will be brought in
|
// Initialize to a fairly distant next run time. This will be brought in
|
||||||
// based on key rotation times.
|
// based on key rotation times.
|
||||||
nextRun = now.Add(24 * time.Hour)
|
nextRun = now.Add(24 * time.Hour)
|
||||||
|
minJwksClientCacheDuration := time.Duration(math.MaxInt64)
|
||||||
|
|
||||||
for _, ns := range i.namespacer.ListNamespaces() {
|
for _, ns := range i.namespacer.ListNamespaces() {
|
||||||
nsPath := ns.Path
|
nsPath := ns.Path
|
||||||
|
@ -1685,7 +1748,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
nextRotation, err := i.oidcKeyRotation(ctx, s)
|
nextRotation, jwksClientCacheDuration, err := i.oidcKeyRotation(ctx, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
i.Logger().Warn("error rotating OIDC keys", "err", err)
|
i.Logger().Warn("error rotating OIDC keys", "err", err)
|
||||||
}
|
}
|
||||||
|
@ -1707,10 +1770,31 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
|
||||||
if nextExpiration.Before(nextRun) {
|
if nextExpiration.Before(nextRun) {
|
||||||
nextRun = nextExpiration
|
nextRun = nextExpiration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if jwksClientCacheDuration < minJwksClientCacheDuration {
|
||||||
|
minJwksClientCacheDuration = jwksClientCacheDuration
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
|
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
|
||||||
i.Logger().Error("error setting oidc cache", "err", err)
|
i.Logger().Error("error setting oidc cache", "err", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if minJwksClientCacheDuration < math.MaxInt64 {
|
||||||
|
// the OIDC JWKS endpoint returns a Cache-Control HTTP header time between
|
||||||
|
// 0 and the minimum verificationTTL or minimum rotationPeriod out of all
|
||||||
|
// keys, whichever value is lower.
|
||||||
|
//
|
||||||
|
// This smooths calls from services validating JWTs to Vault, while
|
||||||
|
// ensuring that operators can assert that servers honoring the
|
||||||
|
// Cache-Control header will always have a superset of all valid keys, and
|
||||||
|
// not trust any keys longer than a jwksCacheControlMaxAge duration after a
|
||||||
|
// key is rotated out of signing use
|
||||||
|
if err := i.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", minJwksClientCacheDuration); err != nil {
|
||||||
|
i.Logger().Error("error setting jwksCacheControlMaxAge in oidc cache", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -626,7 +628,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
|
||||||
Storage: storage,
|
Storage: storage,
|
||||||
})
|
})
|
||||||
|
|
||||||
// .well-known/keys should contain 1 public key
|
// .well-known/keys should contain 2 public keys
|
||||||
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
|
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
|
||||||
Path: "oidc/.well-known/keys",
|
Path: "oidc/.well-known/keys",
|
||||||
Operation: logical.ReadOperation,
|
Operation: logical.ReadOperation,
|
||||||
|
@ -636,8 +638,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
|
||||||
// parse response
|
// parse response
|
||||||
responseJWKS := &jose.JSONWebKeySet{}
|
responseJWKS := &jose.JSONWebKeySet{}
|
||||||
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
||||||
if len(responseJWKS.Keys) != 1 {
|
if len(responseJWKS.Keys) != 2 {
|
||||||
t.Fatalf("expected 1 public key but instead got %d", len(responseJWKS.Keys))
|
t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys))
|
||||||
}
|
}
|
||||||
|
|
||||||
// rotate test-key a few times, each rotate should increase the length of public keys returned
|
// rotate test-key a few times, each rotate should increase the length of public keys returned
|
||||||
|
@ -655,7 +657,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
|
||||||
})
|
})
|
||||||
expectSuccess(t, resp, err)
|
expectSuccess(t, resp, err)
|
||||||
|
|
||||||
// .well-known/keys should contain 3 public keys
|
// .well-known/keys should contain 4 public keys
|
||||||
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
|
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
|
||||||
Path: "oidc/.well-known/keys",
|
Path: "oidc/.well-known/keys",
|
||||||
Operation: logical.ReadOperation,
|
Operation: logical.ReadOperation,
|
||||||
|
@ -664,8 +666,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
|
||||||
expectSuccess(t, resp, err)
|
expectSuccess(t, resp, err)
|
||||||
// parse response
|
// parse response
|
||||||
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
||||||
if len(responseJWKS.Keys) != 3 {
|
if len(responseJWKS.Keys) != 4 {
|
||||||
t.Fatalf("expected 3 public keys but instead got %d", len(responseJWKS.Keys))
|
t.Fatalf("expected 4 public keys but instead got %d", len(responseJWKS.Keys))
|
||||||
}
|
}
|
||||||
|
|
||||||
// create another named key
|
// create another named key
|
||||||
|
@ -682,7 +684,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
|
||||||
Storage: storage,
|
Storage: storage,
|
||||||
})
|
})
|
||||||
|
|
||||||
// .well-known/keys should contain 1 public key, all of the public keys
|
// .well-known/keys should contain 2 public key, all of the public keys
|
||||||
// from named key "test-key" should have been deleted
|
// from named key "test-key" should have been deleted
|
||||||
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
|
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
|
||||||
Path: "oidc/.well-known/keys",
|
Path: "oidc/.well-known/keys",
|
||||||
|
@ -692,8 +694,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
|
||||||
expectSuccess(t, resp, err)
|
expectSuccess(t, resp, err)
|
||||||
// parse response
|
// parse response
|
||||||
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
||||||
if len(responseJWKS.Keys) != 1 {
|
if len(responseJWKS.Keys) != 2 {
|
||||||
t.Fatalf("expected 1 public keys but instead got %d", len(responseJWKS.Keys))
|
t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -814,10 +816,18 @@ func TestOIDC_SignIDToken(t *testing.T) {
|
||||||
responseJWKS := &jose.JSONWebKeySet{}
|
responseJWKS := &jose.JSONWebKeySet{}
|
||||||
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
|
||||||
|
|
||||||
|
keyCount := len(responseJWKS.Keys)
|
||||||
|
errorCount := 0
|
||||||
|
for _, key := range responseJWKS.Keys {
|
||||||
// Validate the signature
|
// Validate the signature
|
||||||
claims := &jwt.Claims{}
|
claims := &jwt.Claims{}
|
||||||
if err := parsedToken.Claims(responseJWKS.Keys[0], claims); err != nil {
|
if err := parsedToken.Claims(key, claims); err != nil {
|
||||||
t.Fatalf("unable to validate signed token, err:\n%#v", err)
|
t.Logf("unable to validate signed token, err:\n%#v", err)
|
||||||
|
errorCount += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errorCount == keyCount {
|
||||||
|
t.Fatalf("unable to validate signed token with any of the .well-known keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -856,6 +866,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
|
||||||
RotationPeriod: 1 * cyclePeriod,
|
RotationPeriod: 1 * cyclePeriod,
|
||||||
KeyRing: nil,
|
KeyRing: nil,
|
||||||
SigningKey: jwk,
|
SigningKey: jwk,
|
||||||
|
NextSigningKey: jwk,
|
||||||
NextRotation: time.Now(),
|
NextRotation: time.Now(),
|
||||||
},
|
},
|
||||||
[]struct {
|
[]struct {
|
||||||
|
@ -865,8 +876,11 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{1, 1, 1},
|
{1, 1, 1},
|
||||||
{2, 2, 2},
|
{2, 2, 2},
|
||||||
{3, 2, 2},
|
{3, 3, 3},
|
||||||
{4, 2, 2},
|
{4, 3, 3},
|
||||||
|
{5, 3, 3},
|
||||||
|
{6, 3, 3},
|
||||||
|
{7, 3, 3},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -1368,6 +1382,62 @@ func TestOIDC_CacheNamespaceNilCheck(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDC_GetKeysCacheControlHeader(t *testing.T) {
|
||||||
|
c, _, _ := TestCoreUnsealed(t)
|
||||||
|
|
||||||
|
// get default value
|
||||||
|
header, err := c.identityStore.getKeysCacheControlHeader()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected success, got error:\n%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedHeader := ""
|
||||||
|
if header != expectedHeader {
|
||||||
|
t.Fatalf("expected %s, got %s", expectedHeader, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set nextRun
|
||||||
|
nextRun := time.Now().Add(24 * time.Hour)
|
||||||
|
if err = c.identityStore.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err = c.identityStore.getKeysCacheControlHeader()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected success, got error:\n%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedNextRun := "max-age=86400"
|
||||||
|
if header != expectedNextRun {
|
||||||
|
t.Fatalf("expected %s, got %s", expectedNextRun, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set jwksCacheControlMaxAge
|
||||||
|
jwksCacheControlMaxAge := time.Duration(60) * time.Second
|
||||||
|
if err = c.identityStore.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", jwksCacheControlMaxAge); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err = c.identityStore.getKeysCacheControlHeader()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected success, got error:\n%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if header == "" {
|
||||||
|
t.Fatalf("expected header to be set, got %s", header)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAgeValue := strings.Split(header, "=")[1]
|
||||||
|
headerVal, err := strconv.Atoi(maxAgeValue)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// headerVal will be a random value between 0 and jwksCacheControlMaxAge
|
||||||
|
if headerVal > int(jwksCacheControlMaxAge) {
|
||||||
|
t.Fatalf("unexpected header value, got %d", headerVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// some helpers
|
// some helpers
|
||||||
func expectSuccess(t *testing.T, resp *logical.Response, err error) {
|
func expectSuccess(t *testing.T, resp *logical.Response, err error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
Loading…
Reference in a new issue