Identity: prepublish jwt signing keys (#12414)

* pre-publish new signing keys for `rotation_period` of time before using

* Work In Progress: Prepublish JWKS and even cache control

* remove comments

* use math/rand instead of math/big

* update tests

* remove debug comment

* refactor cache control logic into func

* don't set expiry when create/update key

* update cachecontrol name in oidccache for test

* fix bug in periodicfunc test case

* add changelog

* remove confusing comment

* add logging and comments

* update change log from bug to improvement

Co-authored-by: Ian Ferguson <ian.ferguson@datadoghq.com>
This commit is contained in:
John-Michael Faircloth 2021-09-09 13:47:42 -05:00 committed by GitHub
parent d4656971b1
commit c42bbb369c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 204 additions and 47 deletions

3
changelog/12414.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
identity: fix issue where Cache-Control header causes stampede of requests for JWKS keys
```

View File

@ -10,6 +10,8 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
mathrand "math/rand"
"net/url"
"sort"
"strings"
@ -50,6 +52,7 @@ type namedKey struct {
RotationPeriod time.Duration `json:"rotation_period"`
KeyRing []*expireableKey `json:"key_ring"`
SigningKey *jose.JSONWebKey `json:"signing_key"`
NextSigningKey *jose.JSONWebKey `json:"next_signing_key"`
NextRotation time.Time `json:"next_rotation"`
AllowedClientIDs []string `json:"allowed_client_ids"`
}
@ -510,13 +513,15 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil
}
now := time.Now()
// Update next rotation time if it is unset or now earlier than previously set.
nextRotation := time.Now().Add(key.RotationPeriod)
nextRotation := now.Add(key.RotationPeriod)
if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) {
key.NextRotation = nextRotation
}
// generate keys if creating a new key or changing algorithms
// generate current and next keys if creating a new key or changing algorithms
if key.Algorithm != prevAlgorithm {
signingKey, err := generateKeys(key.Algorithm)
if err != nil {
@ -529,6 +534,20 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)
nextSigningKey, err := generateKeys(key.Algorithm)
if err != nil {
return nil, err
}
key.NextSigningKey = nextSigningKey
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID})
if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)
}
if err := i.oidcCache.Flush(ns); err != nil {
@ -727,7 +746,7 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second
}
if err := storedNamedKey.rotate(ctx, req.Storage, verificationTTLOverride); err != nil {
if err := storedNamedKey.rotate(ctx, i.Logger(), req.Storage, verificationTTLOverride); err != nil {
return nil, err
}
@ -1168,6 +1187,40 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return resp, nil
}
// getKeysCacheControlHeader returns the cache control header for all public
// keys at the .well-known/keys endpoint
func (i *IdentityStore) getKeysCacheControlHeader() (string, error) {
// if jwksCacheControlMaxAge is set use that, otherwise fall back on the
// more conservative nextRun values
jwksCacheControlMaxAge, ok, err := i.oidcCache.Get(noNamespace, "jwksCacheControlMaxAge")
if err != nil {
return "", err
}
if ok {
maxDuration := int64(jwksCacheControlMaxAge.(time.Duration))
randDuration := mathrand.Int63n(maxDuration)
durationInSeconds := time.Duration(randDuration).Seconds()
return fmt.Sprintf("max-age=%.0f", durationInSeconds), nil
}
nextRun, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
return "", err
}
if ok {
now := time.Now()
expireAt := nextRun.(time.Time)
if expireAt.After(now) {
i.Logger().Debug("use nextRun value for Cache Control header", "nextRun", nextRun)
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
return fmt.Sprintf("max-age=%.0f", expireInSeconds), nil
}
}
return "", nil
}
// pathOIDCReadPublicKeys is used to retrieve all public keys so that clients can
// verify the validity of a signed OIDC token.
func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
@ -1209,27 +1262,19 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
},
}
// set a Cache-Control header only if there are keys, if there aren't keys
// then nextRun should not be used to set Cache-Control header because it chooses
// a time in the future that isn't based on key rotation/expiration values
// set a Cache-Control header only if there are keys
keys, err := listOIDCPublicKeys(ctx, req.Storage)
if err != nil {
return nil, err
}
if len(keys) > 0 {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
header, err := i.getKeysCacheControlHeader()
if err != nil {
return nil, err
}
if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds)
resp.Data[logical.HTTPRawCacheControl] = expireInString
}
if header != "" {
resp.Data[logical.HTTPRawCacheControl] = header
}
}
@ -1326,10 +1371,9 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
return introspectionResp("")
}
// namedKey.rotate(overrides) performs a key rotation on a namedKey and returns the
// verification_ttl that was applied. verification_ttl can be overridden with an
// overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerificationTTL time.Duration) error {
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
verificationTTL := k.VerificationTTL
if overrideVerificationTTL >= 0 {
@ -1337,16 +1381,16 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
}
// generate new key
signingKey, err := generateKeys(k.Algorithm)
nextSigningKey, err := generateKeys(k.Algorithm)
if err != nil {
return err
}
if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil {
return err
}
logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)
now := time.Now()
// set the previous public key's expiry time
for _, key := range k.KeyRing {
if key.KeyID == k.SigningKey.KeyID {
@ -1354,8 +1398,10 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
break
}
}
k.SigningKey = signingKey
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.KeyID})
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID})
k.SigningKey = k.NextSigningKey
k.NextSigningKey = nextSigningKey
k.NextRotation = now.Add(k.RotationPeriod)
// store named key (it was modified when rotate was called on it)
@ -1367,6 +1413,7 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
return err
}
logger.Debug("rotated OIDC public key, now using", "key_id", k.SigningKey.Public().KeyID)
return nil
}
@ -1599,24 +1646,30 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
return nextExpiration, nil
}
func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, error) {
// oidcKeyRotation will rotate any keys that are due to be rotated.
//
// It will return the time of the soonest rotation and the minimum
// verificationTTL or minimum rotationPeriod out of all the current keys.
func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, time.Duration, error) {
// soonestRotation will be the soonest rotation time of all keys. Initialize
// here to a relatively distant time.
now := time.Now()
soonestRotation := now.Add(24 * time.Hour)
jwksClientCacheDuration := time.Duration(math.MaxInt64)
i.oidcLock.Lock()
defer i.oidcLock.Unlock()
keys, err := s.List(ctx, namedKeyConfigPath)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}
for _, k := range keys {
entry, err := s.Get(ctx, namedKeyConfigPath+k)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}
if entry == nil {
@ -1625,10 +1678,18 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
var key namedKey
if err := entry.DecodeJSON(&key); err != nil {
return now, err
return now, jwksClientCacheDuration, err
}
key.name = k
if key.VerificationTTL < jwksClientCacheDuration {
jwksClientCacheDuration = key.VerificationTTL
}
if key.RotationPeriod < jwksClientCacheDuration {
jwksClientCacheDuration = key.RotationPeriod
}
// Future key rotation that is the earliest we've seen.
if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) {
soonestRotation = key.NextRotation
@ -1637,8 +1698,8 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
// Key that is due to be rotated.
if now.After(key.NextRotation) {
i.Logger().Debug("rotating OIDC key", "key", key.name)
if err := key.rotate(ctx, s, -1); err != nil {
return now, err
if err := key.rotate(ctx, i.Logger(), s, -1); err != nil {
return now, jwksClientCacheDuration, err
}
// Possibly save the new rotation time
@ -1648,12 +1709,13 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
}
}
return soonestRotation, nil
return soonestRotation, jwksClientCacheDuration, nil
}
// oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key
// rotations and expiration actions.
func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
i.Logger().Debug("begin oidcPeriodicFunc")
var nextRun time.Time
now := time.Now()
@ -1675,6 +1737,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
// Initialize to a fairly distant next run time. This will be brought in
// based on key rotation times.
nextRun = now.Add(24 * time.Hour)
minJwksClientCacheDuration := time.Duration(math.MaxInt64)
for _, ns := range i.namespacer.ListNamespaces() {
nsPath := ns.Path
@ -1685,7 +1748,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
continue
}
nextRotation, err := i.oidcKeyRotation(ctx, s)
nextRotation, jwksClientCacheDuration, err := i.oidcKeyRotation(ctx, s)
if err != nil {
i.Logger().Warn("error rotating OIDC keys", "err", err)
}
@ -1707,10 +1770,31 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
if nextExpiration.Before(nextRun) {
nextRun = nextExpiration
}
if jwksClientCacheDuration < minJwksClientCacheDuration {
minJwksClientCacheDuration = jwksClientCacheDuration
}
}
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
i.Logger().Error("error setting oidc cache", "err", err)
}
if minJwksClientCacheDuration < math.MaxInt64 {
// the OIDC JWKS endpoint returns a Cache-Control HTTP header time between
// 0 and the minimum verificationTTL or minimum rotationPeriod out of all
// keys, whichever value is lower.
//
// This smooths calls from services validating JWTs to Vault, while
// ensuring that operators can assert that servers honoring the
// Cache-Control header will always have a superset of all valid keys, and
// not trust any keys longer than a jwksCacheControlMaxAge duration after a
// key is rotated out of signing use
if err := i.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", minJwksClientCacheDuration); err != nil {
i.Logger().Error("error setting jwksCacheControlMaxAge in oidc cache", "err", err)
}
}
}
}

View File

@ -4,6 +4,8 @@ import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"strconv"
"strings"
"testing"
"time"
@ -626,7 +628,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
Storage: storage,
})
// .well-known/keys should contain 1 public key
// .well-known/keys should contain 2 public keys
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/.well-known/keys",
Operation: logical.ReadOperation,
@ -636,8 +638,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
// parse response
responseJWKS := &jose.JSONWebKeySet{}
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 1 {
t.Fatalf("expected 1 public key but instead got %d", len(responseJWKS.Keys))
if len(responseJWKS.Keys) != 2 {
t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys))
}
// rotate test-key a few times, each rotate should increase the length of public keys returned
@ -655,7 +657,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
})
expectSuccess(t, resp, err)
// .well-known/keys should contain 3 public keys
// .well-known/keys should contain 4 public keys
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/.well-known/keys",
Operation: logical.ReadOperation,
@ -664,8 +666,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
expectSuccess(t, resp, err)
// parse response
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 3 {
t.Fatalf("expected 3 public keys but instead got %d", len(responseJWKS.Keys))
if len(responseJWKS.Keys) != 4 {
t.Fatalf("expected 4 public keys but instead got %d", len(responseJWKS.Keys))
}
// create another named key
@ -682,7 +684,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
Storage: storage,
})
// .well-known/keys should contain 1 public key, all of the public keys
// .well-known/keys should contain 2 public key, all of the public keys
// from named key "test-key" should have been deleted
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/.well-known/keys",
@ -692,8 +694,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
expectSuccess(t, resp, err)
// parse response
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 1 {
t.Fatalf("expected 1 public keys but instead got %d", len(responseJWKS.Keys))
if len(responseJWKS.Keys) != 2 {
t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys))
}
}
@ -814,10 +816,18 @@ func TestOIDC_SignIDToken(t *testing.T) {
responseJWKS := &jose.JSONWebKeySet{}
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
// Validate the signature
claims := &jwt.Claims{}
if err := parsedToken.Claims(responseJWKS.Keys[0], claims); err != nil {
t.Fatalf("unable to validate signed token, err:\n%#v", err)
keyCount := len(responseJWKS.Keys)
errorCount := 0
for _, key := range responseJWKS.Keys {
// Validate the signature
claims := &jwt.Claims{}
if err := parsedToken.Claims(key, claims); err != nil {
t.Logf("unable to validate signed token, err:\n%#v", err)
errorCount += 1
}
}
if errorCount == keyCount {
t.Fatalf("unable to validate signed token with any of the .well-known keys")
}
}
@ -856,6 +866,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: jwk,
NextSigningKey: jwk,
NextRotation: time.Now(),
},
[]struct {
@ -865,8 +876,11 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
}{
{1, 1, 1},
{2, 2, 2},
{3, 2, 2},
{4, 2, 2},
{3, 3, 3},
{4, 3, 3},
{5, 3, 3},
{6, 3, 3},
{7, 3, 3},
},
},
}
@ -1368,6 +1382,62 @@ func TestOIDC_CacheNamespaceNilCheck(t *testing.T) {
}
}
func TestOIDC_GetKeysCacheControlHeader(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
// get default value
header, err := c.identityStore.getKeysCacheControlHeader()
if err != nil {
t.Fatalf("expected success, got error:\n%v", err)
}
expectedHeader := ""
if header != expectedHeader {
t.Fatalf("expected %s, got %s", expectedHeader, header)
}
// set nextRun
nextRun := time.Now().Add(24 * time.Hour)
if err = c.identityStore.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
t.Fatal(err)
}
header, err = c.identityStore.getKeysCacheControlHeader()
if err != nil {
t.Fatalf("expected success, got error:\n%v", err)
}
expectedNextRun := "max-age=86400"
if header != expectedNextRun {
t.Fatalf("expected %s, got %s", expectedNextRun, header)
}
// set jwksCacheControlMaxAge
jwksCacheControlMaxAge := time.Duration(60) * time.Second
if err = c.identityStore.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", jwksCacheControlMaxAge); err != nil {
t.Fatal(err)
}
header, err = c.identityStore.getKeysCacheControlHeader()
if err != nil {
t.Fatalf("expected success, got error:\n%v", err)
}
if header == "" {
t.Fatalf("expected header to be set, got %s", header)
}
maxAgeValue := strings.Split(header, "=")[1]
headerVal, err := strconv.Atoi(maxAgeValue)
if err != nil {
t.Fatal(err)
}
// headerVal will be a random value between 0 and jwksCacheControlMaxAge
if headerVal > int(jwksCacheControlMaxAge) {
t.Fatalf("unexpected header value, got %d", headerVal)
}
}
// some helpers
func expectSuccess(t *testing.T, resp *logical.Response, err error) {
t.Helper()