From f30c3ac62136fb3fddb57b772b870b7acec329dc Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Tue, 12 Oct 2021 11:14:03 -0500 Subject: [PATCH] Filter identity token keys (#12780) * filter identity token keys * Update test cases to associate keys with roles * use getOIDCRole helper * add func comment and test assertion * add changelog * remove unnecessary code * build list of keys to return by starting with a list of roles * move comment * update changelog --- changelog/12780.txt | 3 + vault/identity_store_oidc.go | 53 +++++++++++-- vault/identity_store_oidc_test.go | 119 ++++++++++++++++++++---------- 3 files changed, 127 insertions(+), 48 deletions(-) create mode 100644 changelog/12780.txt diff --git a/changelog/12780.txt b/changelog/12780.txt new file mode 100644 index 000000000..61a2c5d4f --- /dev/null +++ b/changelog/12780.txt @@ -0,0 +1,3 @@ +```release-note:improvement +identity/token: Only return keys from the `.well-known/keys` endpoint that are being used by roles to sign/verify tokens. +``` diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 8dd8be260..ce4d628b8 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -613,6 +613,27 @@ func (i *IdentityStore) pathOIDCReadKey(ctx context.Context, req *logical.Reques }, nil } +// keyIDsByName will return a slice of key IDs for the given key name +func (i *IdentityStore) keyIDsByName(ctx context.Context, s logical.Storage, name string) ([]string, error) { + var keyIDs []string + entry, err := s.Get(ctx, namedKeyConfigPath+name) + if err != nil { + return keyIDs, err + } + if entry == nil { + return keyIDs, nil + } + + var key namedKey + if err := entry.DecodeJSON(&key); err != nil { + return keyIDs, err + } + for _, k := range key.KeyRing { + keyIDs = append(keyIDs, k.KeyID) + } + return keyIDs, nil +} + // rolesReferencingTargetKeyName returns a map of role names to roles // referencing targetKeyName. // @@ -1538,21 +1559,37 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag return nil, err } - keyIDs, err := listOIDCPublicKeys(ctx, s) + jwks := &jose.JSONWebKeySet{ + Keys: make([]jose.JSONWebKey, 0), + } + + // only return keys that are associated with a role + roleNames, err := s.List(ctx, roleConfigPath) if err != nil { return nil, err } - jwks := &jose.JSONWebKeySet{ - Keys: make([]jose.JSONWebKey, 0, len(keyIDs)), - } - - for _, keyID := range keyIDs { - key, err := loadOIDCPublicKey(ctx, s, keyID) + for _, roleName := range roleNames { + role, err := i.getOIDCRole(ctx, s, roleName) if err != nil { return nil, err } - jwks.Keys = append(jwks.Keys, *key) + if role == nil { + continue + } + + keyIDs, err := i.keyIDsByName(ctx, s, role.Key) + if err != nil { + return nil, err + } + + for _, keyID := range keyIDs { + key, err := loadOIDCPublicKey(ctx, s, keyID) + if err != nil { + return nil, err + } + jwks.Keys = append(jwks.Keys, *key) + } } if err := i.oidcCache.SetDefault(ns, "jwks", jwks); err != nil { diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 45a5da3ee..ba63a940f 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -1,6 +1,7 @@ package vault import ( + "context" "crypto/rand" "crypto/rsa" "encoding/json" @@ -637,6 +638,43 @@ func TestOIDC_Path_OIDCKey_DeleteWithExistingClient(t *testing.T) { expectError(t, resp, err) } +// TestOIDC_PublicKeys_NoRole tests that public keys are not returned by the +// oidc/.well-known/keys endpoint when they are not associated with a role +func TestOIDC_PublicKeys_NoRole(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + s := &logical.InmemStorage{} + + // Create a test key "test-key" + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/key/test-key", + Operation: logical.CreateOperation, + Storage: s, + }) + expectSuccess(t, resp, err) + + // .well-known/keys should contain 0 public keys + assertPublicKeyCount(t, ctx, s, c, 0) +} + +func assertPublicKeyCount(t *testing.T, ctx context.Context, s logical.Storage, c *Core, keyCount int) { + t.Helper() + + // .well-known/keys should contain keyCount public keys + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/.well-known/keys", + Operation: logical.ReadOperation, + Storage: s, + }) + expectSuccess(t, resp, err) + // parse response + responseJWKS := &jose.JSONWebKeySet{} + json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) + if len(responseJWKS.Keys) != keyCount { + t.Fatalf("expected %d public keys but instead got %d", keyCount, len(responseJWKS.Keys)) + } +} + // TestOIDC_PublicKeys tests that public keys are updated by // key creation, rotation, and deletion func TestOIDC_PublicKeys(t *testing.T) { @@ -651,23 +689,22 @@ func TestOIDC_PublicKeys(t *testing.T) { Storage: storage, }) - // .well-known/keys should contain 2 public keys - resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ - Path: "oidc/.well-known/keys", - Operation: logical.ReadOperation, - Storage: storage, + // Create a test role "test-role" + c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "key": "test-key", + }, + Storage: storage, }) - expectSuccess(t, resp, err) - // parse response - responseJWKS := &jose.JSONWebKeySet{} - json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 2 { - t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) - } + + // .well-known/keys should contain 2 public keys + assertPublicKeyCount(t, ctx, storage, c, 2) // rotate test-key a few times, each rotate should increase the length of public keys returned // by the .well-known endpoint - resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/key/test-key/rotate", Operation: logical.UpdateOperation, Storage: storage, @@ -681,45 +718,47 @@ func TestOIDC_PublicKeys(t *testing.T) { expectSuccess(t, resp, err) // .well-known/keys should contain 4 public keys - resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ - Path: "oidc/.well-known/keys", - Operation: logical.ReadOperation, - Storage: storage, - }) - expectSuccess(t, resp, err) - // parse response - json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 4 { - t.Fatalf("expected 4 public keys but instead got %d", len(responseJWKS.Keys)) - } + assertPublicKeyCount(t, ctx, storage, c, 4) - // create another named key - c.identityStore.HandleRequest(ctx, &logical.Request{ + // create another named key "test-key2" + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/key/test-key2", Operation: logical.CreateOperation, Storage: storage, }) + expectSuccess(t, resp, err) + // Create a test role "test-role2" + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role2", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "key": "test-key2", + }, + Storage: storage, + }) + expectSuccess(t, resp, err) + // .well-known/keys should contain 6 public keys + assertPublicKeyCount(t, ctx, storage, c, 6) + + // delete test role that references "test-key" + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role", + Operation: logical.DeleteOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) // delete test key - c.identityStore.HandleRequest(ctx, &logical.Request{ + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/key/test-key", Operation: logical.DeleteOperation, Storage: storage, }) - - // .well-known/keys should contain 2 public key, all of the public keys - // from named key "test-key" should have been deleted - resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ - Path: "oidc/.well-known/keys", - Operation: logical.ReadOperation, - Storage: storage, - }) expectSuccess(t, resp, err) - // parse response - json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 2 { - t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) - } + + // .well-known/keys should contain 2 public keys, all of the public keys + // from named key "test-key" should have been deleted + assertPublicKeyCount(t, ctx, storage, c, 2) } // TestOIDC_SignIDToken tests acquiring a signed token and verifying the public portion