diff --git a/vault/identity_store.go b/vault/identity_store.go index fd7395649..98a5350d9 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" + "github.com/patrickmn/go-cache" ) const ( @@ -56,6 +57,7 @@ func NewIdentityStore(ctx context.Context, core *Core, config *logical.BackendCo metrics: core.MetricSink(), totpPersister: core, groupUpdater: core, + tokenStorer: core, } // Create a memdb instance, which by default, operates on lower cased @@ -96,7 +98,8 @@ func NewIdentityStore(ctx context.Context, core *Core, config *logical.BackendCo }, } - iStore.oidcCache = newOIDCCache() + iStore.oidcCache = newOIDCCache(cache.NoExpiration, cache.NoExpiration) + iStore.oidcAuthCodeCache = newOIDCCache(5*time.Minute, 5*time.Minute) err = iStore.Setup(ctx, config) if err != nil { @@ -181,6 +184,10 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) { i.logger.Error("failed to load groups during invalidation", "error", err) return } + if err := i.loadOIDCClients(ctx); err != nil { + i.logger.Error("failed to load OIDC clients during invalidation", "error", err) + return + } // Check if the key is a storage entry key for an entity bucket case strings.HasPrefix(key, storagepacker.StoragePackerBucketsPrefix): // Create a MemDB transaction @@ -334,6 +341,14 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) { if err := i.oidcCache.Flush(ns); err != nil { i.logger.Error("error flushing oidc cache", "error", err) } + case strings.HasPrefix(key, clientPath): + name := strings.TrimPrefix(key, clientPath) + + // Invalidate the cached client in memdb + if err := i.memDBDeleteClientByName(ctx, name); err != nil { + i.logger.Error("error invalidating client", "error", err, "key", key) + return + } } } diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 2bf1bfd21..405685a38 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1797,9 +1797,9 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { } } -func newOIDCCache() *oidcCache { +func newOIDCCache(defaultExpiration, cleanupInterval time.Duration) *oidcCache { return &oidcCache{ - c: cache.New(cache.NoExpiration, cache.NoExpiration), + c: cache.New(defaultExpiration, cleanupInterval), } } diff --git a/vault/identity_store_oidc_provider.go b/vault/identity_store_oidc_provider.go index 39ec71cea..3b5ff1917 100644 --- a/vault/identity_store_oidc_provider.go +++ b/vault/identity_store_oidc_provider.go @@ -4,12 +4,15 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" + "net/http" "net/url" "sort" "strings" "time" + "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-secure-stdlib/base62" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/helper/namespace" @@ -19,9 +22,38 @@ import ( "gopkg.in/square/go-jose.v2" ) +const ( + // OIDC-related constants + openIDScope = "openid" + + // Storage path constants + oidcProviderPrefix = "oidc_provider/" + assignmentPath = oidcProviderPrefix + "assignment/" + scopePath = oidcProviderPrefix + "scope/" + clientPath = oidcProviderPrefix + "client/" + providerPath = oidcProviderPrefix + "provider/" + + // Error constants used in the Authorization Endpoint. See details at + // https://openid.net/specs/openid-connect-core-1_0.html#AuthError. + ErrAuthUnsupportedResponseType = "unsupported_response_type" + ErrAuthInvalidRequest = "invalid_request" + ErrAuthAccessDenied = "access_denied" + ErrAuthUnauthorizedClient = "unauthorized_client" + ErrAuthServerError = "server_error" + ErrAuthRequestNotSupported = "request_not_supported" + ErrAuthRequestURINotSupported = "request_uri_not_supported" + + // The following errors are used by the UI for specific behavior of + // the OIDC specification. Any changes to their values must come with + // a corresponding change in the UI code. + ErrAuthInvalidClientID = "invalid_client_id" + ErrAuthInvalidRedirectURI = "invalid_redirect_uri" + ErrAuthMaxAgeReAuthenticate = "max_age_violation" +) + type assignment struct { - Groups []string `json:"groups"` - Entities []string `json:"entities"` + GroupIDs []string `json:"group_ids"` + EntityIDs []string `json:"entity_ids"` } type scope struct { @@ -30,13 +62,18 @@ type scope struct { } type client struct { + // Used for indexing in memdb + Name string `json:"name"` + NamespaceID string `json:"namespace_id"` + + // User-supplied parameters RedirectURIs []string `json:"redirect_uris"` Assignments []string `json:"assignments"` Key string `json:"key"` IDTokenTTL time.Duration `json:"id_token_ttl"` AccessTokenTTL time.Duration `json:"access_token_ttl"` - // used for OIDC endpoints + // Generated values that are used in OIDC endpoints ClientID string `json:"client_id"` ClientSecret string `json:"client_secret"` } @@ -45,6 +82,7 @@ type provider struct { Issuer string `json:"issuer"` AllowedClientIDs []string `json:"allowed_client_ids"` Scopes []string `json:"scopes"` + // effectiveIssuer is a calculated field and will be either Issuer (if // that's set) or the Vault instance's api_addr. effectiveIssuer string @@ -65,13 +103,14 @@ type providerDiscovery struct { AuthMethods []string `json:"token_endpoint_auth_methods_supported"` } -const ( - oidcProviderPrefix = "oidc_provider/" - assignmentPath = oidcProviderPrefix + "assignment/" - scopePath = oidcProviderPrefix + "scope/" - clientPath = oidcProviderPrefix + "client/" - providerPath = oidcProviderPrefix + "provider/" -) +type authCodeCacheEntry struct { + clientID string + entityID string + redirectURI string + nonce string + scopes []string + authTime time.Time +} func oidcProviderPaths(i *IdentityStore) []*framework.Path { return []*framework.Path{ @@ -82,13 +121,13 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path { Type: framework.TypeString, Description: "Name of the assignment", }, - "entities": { + "entity_ids": { Type: framework.TypeCommaStringSlice, - Description: "Comma separated string or array of identity entity names", + Description: "Comma separated string or array of identity entity IDs", }, - "groups": { + "group_ids": { Type: framework.TypeCommaStringSlice, - Description: "Comma separated string or array of identity group names", + Description: "Comma separated string or array of identity group IDs", }, }, Operations: map[logical.Operation]framework.OperationHandler{ @@ -298,197 +337,64 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path { HelpSynopsis: "Retrieve public keys", HelpDescription: "Returns the public portion of keys for a named OIDC provider. Clients can use them to validate the authenticity of an ID token.", }, - } -} - -func (i *IdentityStore) listClients(ctx context.Context, s logical.Storage) ([]*client, error) { - clientNames, err := s.List(ctx, clientPath) - if err != nil { - return nil, err - } - - var clients []*client - for _, name := range clientNames { - entry, err := s.Get(ctx, clientPath+name) - if err != nil { - return nil, err - } - if entry == nil { - continue - } - - var client client - if err := entry.DecodeJSON(&client); err != nil { - return nil, err - } - clients = append(clients, &client) - } - - return clients, nil -} - -// TODO: load clients into memory (go-memdb) to look this up -func (i *IdentityStore) clientByID(ctx context.Context, s logical.Storage, id string) (*client, error) { - clients, err := i.listClients(ctx, s) - if err != nil { - return nil, err - } - - for _, client := range clients { - if client.ClientID == id { - return client, nil - } - } - - return nil, nil -} - -// keyIDsReferencedByTargetClientIDs returns a slice of key IDs that are -// referenced by the clients' targetIDs. -// If targetIDs contains "*" then the IDs for all public keys are returned. -func (i *IdentityStore) keyIDsReferencedByTargetClientIDs(ctx context.Context, s logical.Storage, targetIDs []string) ([]string, error) { - keyNames := make(map[string]bool) - - // Get all key names referenced by clients if wildcard "*" in target client IDs - if strutil.StrListContains(targetIDs, "*") { - clients, err := i.listClients(ctx, s) - if err != nil { - return nil, err - } - - for _, client := range clients { - keyNames[client.Key] = true - } - } - - // Otherwise, get the key names referenced by each target client ID - if len(keyNames) == 0 { - for _, clientID := range targetIDs { - client, err := i.clientByID(ctx, s, clientID) - if err != nil { - return nil, err - } - - if client != nil { - keyNames[client.Key] = true - } - } - } - - // Collect the key IDs - var keyIDs []string - for name, _ := range keyNames { - entry, err := s.Get(ctx, namedKeyConfigPath+name) - if err != nil { - return nil, err - } - - var key namedKey - if err := entry.DecodeJSON(&key); err != nil { - return nil, err - } - for _, expirableKey := range key.KeyRing { - keyIDs = append(keyIDs, expirableKey.KeyID) - } - } - return keyIDs, nil -} - -// pathOIDCReadProviderPublicKeys is used to retrieve all public keys for a -// named provider so that clients can verify the validity of a signed OIDC token. -func (i *IdentityStore) pathOIDCReadProviderPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - providerName := d.Get("name").(string) - - var provider provider - - providerEntry, err := req.Storage.Get(ctx, providerPath+providerName) - if err != nil { - return nil, err - } - if providerEntry == nil { - return nil, nil - } - if err := providerEntry.DecodeJSON(&provider); err != nil { - return nil, err - } - - keyIDs, err := i.keyIDsReferencedByTargetClientIDs(ctx, req.Storage, provider.AllowedClientIDs) - if err != nil { - return nil, err - } - - jwks := &jose.JSONWebKeySet{ - Keys: make([]jose.JSONWebKey, 0, len(keyIDs)), - } - - for _, keyID := range keyIDs { - key, err := loadOIDCPublicKey(ctx, req.Storage, keyID) - if err != nil { - return nil, err - } - jwks.Keys = append(jwks.Keys, *key) - } - - data, err := json.Marshal(jwks) - if err != nil { - return nil, err - } - - resp := &logical.Response{ - Data: map[string]interface{}{ - logical.HTTPStatusCode: 200, - logical.HTTPRawBody: data, - logical.HTTPContentType: "application/json", + { + Pattern: "oidc/provider/" + framework.GenericNameRegex("name") + "/authorize", + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Description: "Name of the provider", + }, + "client_id": { + Type: framework.TypeString, + Description: "The ID of the requesting client.", + Required: true, + }, + "scope": { + Type: framework.TypeString, + Description: "A space-delimited, case-sensitive list of scopes to be requested. The 'openid' scope is required.", + Required: true, + }, + "redirect_uri": { + Type: framework.TypeString, + Description: "The redirection URI to which the response will be sent.", + Required: true, + }, + "response_type": { + Type: framework.TypeString, + Description: "The OIDC authentication flow to be used. The following response types are supported: 'code'", + Required: true, + }, + "state": { + Type: framework.TypeString, + Description: "The value used to maintain state between the authentication request and client.", + Required: true, + }, + "nonce": { + Type: framework.TypeString, + Description: "The value that will be returned in the ID token nonce claim after a token exchange.", + Required: true, + }, + "max_age": { + Type: framework.TypeInt, + Description: "The allowable elapsed time in seconds since the last time the end-user was actively authenticated.", + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: i.pathOIDCAuthorize, + ForwardPerformanceStandby: true, + ForwardPerformanceSecondary: false, + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: i.pathOIDCAuthorize, + ForwardPerformanceStandby: true, + ForwardPerformanceSecondary: false, + }, + }, + HelpSynopsis: "Provides the OIDC Authorization Endpoint.", + HelpDescription: "The OIDC Authorization Endpoint performs authentication and authorization by using request parameters defined by OpenID Connect (OIDC).", }, } - - return resp, nil -} - -func (i *IdentityStore) pathOIDCProviderDiscovery(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - name := d.Get("name").(string) - - p, err := i.getOIDCProvider(ctx, req.Storage, name) - if err != nil { - return nil, err - } - if p == nil { - return nil, nil - } - - // the "openid" scope is reserved and is included for every provider - scopes := append(p.Scopes, "openid") - - disc := providerDiscovery{ - Issuer: p.effectiveIssuer, - Keys: p.effectiveIssuer + "/.well-known/keys", - AuthorizationEndpoint: strings.Replace(p.effectiveIssuer, "/v1/", "/ui/vault/", 1) + "/authorize", - TokenEndpoint: p.effectiveIssuer + "/token", - UserinfoEndpoint: p.effectiveIssuer + "/userinfo", - IDTokenAlgs: supportedAlgs, - Scopes: scopes, - RequestURIParameter: false, - ResponseTypes: []string{"code"}, - Subjects: []string{"public"}, - GrantTypes: []string{"authorization_code"}, - AuthMethods: []string{"client_secret_basic"}, - } - - data, err := json.Marshal(disc) - if err != nil { - return nil, err - } - - resp := &logical.Response{ - Data: map[string]interface{}{ - logical.HTTPStatusCode: 200, - logical.HTTPRawBody: data, - logical.HTTPContentType: "application/json", - logical.HTTPRawCacheControl: "max-age=3600", - }, - } - - return resp, nil } // clientsReferencingTargetAssignmentName returns a map of client names to @@ -628,18 +534,22 @@ func (i *IdentityStore) pathOIDCCreateUpdateAssignment(ctx context.Context, req } } - if entitiesRaw, ok := d.GetOk("entities"); ok { - assignment.Entities = entitiesRaw.([]string) + if entitiesRaw, ok := d.GetOk("entity_ids"); ok { + assignment.EntityIDs = entitiesRaw.([]string) } else if req.Operation == logical.CreateOperation { - assignment.Entities = d.GetDefaultOrZero("entities").([]string) + assignment.EntityIDs = d.GetDefaultOrZero("entity_ids").([]string) } - if groupsRaw, ok := d.GetOk("groups"); ok { - assignment.Groups = groupsRaw.([]string) + if groupsRaw, ok := d.GetOk("group_ids"); ok { + assignment.GroupIDs = groupsRaw.([]string) } else if req.Operation == logical.CreateOperation { - assignment.Groups = d.GetDefaultOrZero("groups").([]string) + assignment.GroupIDs = d.GetDefaultOrZero("group_ids").([]string) } + // remove duplicates and lowercase entities and groups + assignment.EntityIDs = strutil.RemoveDuplicates(assignment.EntityIDs, true) + assignment.GroupIDs = strutil.RemoveDuplicates(assignment.GroupIDs, true) + // store assignment entry, err := logical.StorageEntryJSON(assignmentPath+name, assignment) if err != nil { @@ -666,7 +576,24 @@ func (i *IdentityStore) pathOIDCListAssignment(ctx context.Context, req *logical func (i *IdentityStore) pathOIDCReadAssignment(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - entry, err := req.Storage.Get(ctx, assignmentPath+name) + assignment, err := i.getOIDCAssignment(ctx, req.Storage, name) + if err != nil { + return nil, err + } + if assignment == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "group_ids": assignment.GroupIDs, + "entity_ids": assignment.EntityIDs, + }, + }, nil +} + +func (i *IdentityStore) getOIDCAssignment(ctx context.Context, s logical.Storage, name string) (*assignment, error) { + entry, err := s.Get(ctx, assignmentPath+name) if err != nil { return nil, err } @@ -678,12 +605,8 @@ func (i *IdentityStore) pathOIDCReadAssignment(ctx context.Context, req *logical if err := entry.DecodeJSON(&assignment); err != nil { return nil, err } - return &logical.Response{ - Data: map[string]interface{}{ - "groups": assignment.Groups, - "entities": assignment.Entities, - }, - }, nil + + return &assignment, nil } // pathOIDCDeleteAssignment is used to delete an assignment @@ -722,8 +645,8 @@ func (i *IdentityStore) pathOIDCAssignmentExistenceCheck(ctx context.Context, re // pathOIDCCreateUpdateScope is used to create a new scope or update an existing one func (i *IdentityStore) pathOIDCCreateUpdateScope(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - if name == "openid" { - return logical.ErrorResponse("the \"openid\" scope name is reserved"), nil + if name == openIDScope { + return logical.ErrorResponse("the %q scope name is reserved", openIDScope), nil } var scope scope @@ -865,7 +788,15 @@ func (i *IdentityStore) pathOIDCScopeExistenceCheck(ctx context.Context, req *lo func (i *IdentityStore) pathOIDCCreateUpdateClient(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - var client client + ns, err := namespace.FromContext(ctx) + if err != nil { + return nil, err + } + + client := client{ + Name: name, + NamespaceID: ns.ID, + } if req.Operation == logical.UpdateOperation { entry, err := req.Storage.Get(ctx, clientPath+name) if err != nil { @@ -958,12 +889,16 @@ func (i *IdentityStore) pathOIDCCreateUpdateClient(ctx context.Context, req *log client.ClientSecret = clientSecret } + // invalidate the cached client in memdb + if err := i.memDBDeleteClientByName(ctx, name); err != nil { + return nil, err + } + // store client entry, err = logical.StorageEntryJSON(clientPath+name, client) if err != nil { return nil, err } - if err := req.Storage.Put(ctx, entry); err != nil { return nil, err } @@ -984,18 +919,14 @@ func (i *IdentityStore) pathOIDCListClient(ctx context.Context, req *logical.Req func (i *IdentityStore) pathOIDCReadClient(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - entry, err := req.Storage.Get(ctx, clientPath+name) + client, err := i.clientByName(ctx, req.Storage, name) if err != nil { return nil, err } - if entry == nil { + if client == nil { return nil, nil } - var client client - if err := entry.DecodeJSON(&client); err != nil { - return nil, err - } return &logical.Response{ Data: map[string]interface{}{ "redirect_uris": client.RedirectURIs, @@ -1012,10 +943,17 @@ func (i *IdentityStore) pathOIDCReadClient(ctx context.Context, req *logical.Req // pathOIDCDeleteClient is used to delete an client func (i *IdentityStore) pathOIDCDeleteClient(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - err := req.Storage.Delete(ctx, clientPath+name) - if err != nil { + + // Delete the client from memdb + if err := i.memDBDeleteClientByName(ctx, name); err != nil { return nil, err } + + // Delete the client from storage + if err := req.Storage.Delete(ctx, clientPath+name); err != nil { + return nil, err + } + return nil, nil } @@ -1236,3 +1174,667 @@ func (i *IdentityStore) pathOIDCProviderExistenceCheck(ctx context.Context, req return entry != nil, nil } + +func (i *IdentityStore) pathOIDCProviderDiscovery(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + p, err := i.getOIDCProvider(ctx, req.Storage, name) + if err != nil { + return nil, err + } + if p == nil { + return nil, nil + } + + // the "openid" scope is reserved and is included for every provider + scopes := append(p.Scopes, openIDScope) + + disc := providerDiscovery{ + Issuer: p.effectiveIssuer, + Keys: p.effectiveIssuer + "/.well-known/keys", + AuthorizationEndpoint: strings.Replace(p.effectiveIssuer, "/v1/", "/ui/vault/", 1) + "/authorize", + TokenEndpoint: p.effectiveIssuer + "/token", + UserinfoEndpoint: p.effectiveIssuer + "/userinfo", + IDTokenAlgs: supportedAlgs, + Scopes: scopes, + RequestURIParameter: false, + ResponseTypes: []string{"code"}, + Subjects: []string{"public"}, + GrantTypes: []string{"authorization_code"}, + AuthMethods: []string{"client_secret_basic"}, + } + + data, err := json.Marshal(disc) + if err != nil { + return nil, err + } + + resp := &logical.Response{ + Data: map[string]interface{}{ + logical.HTTPStatusCode: 200, + logical.HTTPRawBody: data, + logical.HTTPContentType: "application/json", + logical.HTTPRawCacheControl: "max-age=3600", + }, + } + + return resp, nil +} + +// pathOIDCReadProviderPublicKeys is used to retrieve all public keys for a +// named provider so that clients can verify the validity of a signed OIDC token. +func (i *IdentityStore) pathOIDCReadProviderPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + providerName := d.Get("name").(string) + + var provider provider + + providerEntry, err := req.Storage.Get(ctx, providerPath+providerName) + if err != nil { + return nil, err + } + if providerEntry == nil { + return nil, nil + } + if err := providerEntry.DecodeJSON(&provider); err != nil { + return nil, err + } + + keyIDs, err := i.keyIDsReferencedByTargetClientIDs(ctx, req.Storage, provider.AllowedClientIDs) + if err != nil { + return nil, err + } + + jwks := &jose.JSONWebKeySet{ + Keys: make([]jose.JSONWebKey, 0, len(keyIDs)), + } + + for _, keyID := range keyIDs { + key, err := loadOIDCPublicKey(ctx, req.Storage, keyID) + if err != nil { + return nil, err + } + jwks.Keys = append(jwks.Keys, *key) + } + + data, err := json.Marshal(jwks) + if err != nil { + return nil, err + } + + resp := &logical.Response{ + Data: map[string]interface{}{ + logical.HTTPStatusCode: 200, + logical.HTTPRawBody: data, + logical.HTTPContentType: "application/json", + }, + } + + return resp, nil +} + +// keyIDsReferencedByTargetClientIDs returns a slice of key IDs that are +// referenced by the clients' targetIDs. +// If targetIDs contains "*" then the IDs for all public keys are returned. +func (i *IdentityStore) keyIDsReferencedByTargetClientIDs(ctx context.Context, s logical.Storage, targetIDs []string) ([]string, error) { + keyNames := make(map[string]bool) + + // Get all key names referenced by clients if wildcard "*" in target client IDs + if strutil.StrListContains(targetIDs, "*") { + clients, err := i.listClients(ctx, s) + if err != nil { + return nil, err + } + + for _, client := range clients { + keyNames[client.Key] = true + } + } + + // Otherwise, get the key names referenced by each target client ID + if len(keyNames) == 0 { + for _, clientID := range targetIDs { + client, err := i.clientByID(ctx, s, clientID) + if err != nil { + return nil, err + } + + if client != nil { + keyNames[client.Key] = true + } + } + } + + // Collect the key IDs + var keyIDs []string + for name, _ := range keyNames { + entry, err := s.Get(ctx, namedKeyConfigPath+name) + if err != nil { + return nil, err + } + + var key namedKey + if err := entry.DecodeJSON(&key); err != nil { + return nil, err + } + for _, expirableKey := range key.KeyRing { + keyIDs = append(keyIDs, expirableKey.KeyID) + } + } + return keyIDs, nil +} + +func (i *IdentityStore) pathOIDCAuthorize(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Validate the state + state := d.Get("state").(string) + if state == "" { + return authResponse("", "", ErrAuthInvalidRequest, "state parameter is required") + } + + // Get the namespace + ns, err := namespace.FromContext(ctx) + if err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + + // Get the OIDC provider + name := d.Get("name").(string) + provider, err := i.getOIDCProvider(ctx, req.Storage, name) + if err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + if provider == nil { + return authResponse("", state, ErrAuthInvalidRequest, "provider not found") + } + + // Validate that a scope parameter is present and contains the openid scope value + scopes := strutil.ParseStringSlice(d.Get("scope").(string), " ") + if len(scopes) == 0 || !strutil.StrListContains(scopes, openIDScope) { + return authResponse("", state, ErrAuthInvalidRequest, + fmt.Sprintf("scope parameter must contain the %q value", openIDScope)) + } + + // Validate the response type + responseType := d.Get("response_type").(string) + if responseType == "" { + return authResponse("", state, ErrAuthInvalidRequest, "response_type parameter is required") + } + if responseType != "code" { + return authResponse("", state, ErrAuthUnsupportedResponseType, "unsupported response_type value") + } + + // Validate the client ID + clientID := d.Get("client_id").(string) + if clientID == "" { + return authResponse("", state, ErrAuthInvalidClientID, "client_id parameter is required") + } + client, err := i.clientByID(ctx, req.Storage, clientID) + if err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + if client == nil { + return authResponse("", state, ErrAuthInvalidClientID, "client with client_id not found") + } + if !strutil.StrListContains(provider.AllowedClientIDs, "*") && + !strutil.StrListContains(provider.AllowedClientIDs, clientID) { + return authResponse("", state, ErrAuthUnauthorizedClient, "client is not authorized to use the provider") + } + + // Validate the redirect URI + redirectURI := d.Get("redirect_uri").(string) + if redirectURI == "" { + return authResponse("", state, ErrAuthInvalidRequest, "redirect_uri parameter is required") + } + if !strutil.StrListContains(client.RedirectURIs, redirectURI) { + return authResponse("", state, ErrAuthInvalidRedirectURI, "redirect_uri is not allowed for the client") + } + + // Validate the nonce + nonce := d.Get("nonce").(string) + if nonce == "" { + return authResponse("", state, ErrAuthInvalidRequest, "nonce parameter is required") + } + + // We don't support the request or request_uri parameters. If they're provided, + // the appropriate errors must be returned. For details, see the spec at: + // https://openid.net/specs/openid-connect-core-1_0.html#RequestObject + // https://openid.net/specs/openid-connect-core-1_0.html#RequestUriParameter + if _, ok := d.Raw["request"]; ok { + return authResponse("", state, ErrAuthRequestNotSupported, "request parameter is not supported") + } + if _, ok := d.Raw["request_uri"]; ok { + return authResponse("", state, ErrAuthRequestURINotSupported, "request_uri parameter is not supported") + } + + // Validate that there is an identity entity associated with the request + entity, err := i.MemDBEntityByID(req.EntityID, false) + if err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + if entity == nil { + return authResponse("", state, ErrAuthAccessDenied, "identity entity must be associated with the request") + } + + // Validate that the identity entity associated with the request + // is a member of the client assignments' groups or entities + isMember, err := i.entityHasAssignment(ctx, req.Storage, entity.GetID(), client.Assignments) + if err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + if !isMember { + return authResponse("", state, ErrAuthAccessDenied, "identity entity not authorized by client assignment") + } + + // Create the auth code cache entry + authCodeEntry := &authCodeCacheEntry{ + clientID: clientID, + entityID: entity.GetID(), + redirectURI: redirectURI, + nonce: nonce, + scopes: scopes, + } + + // Validate the optional max_age parameter to check if an active re-authentication + // of the user should occur. Re-authentication will be requested if max_age=0 or the + // last time the token actively authenticated exceeds the given max_age requirement. + // Returning ErrAuthMaxAgeReAuthenticate will enforce the user to re-authenticate via + // the user agent. + if maxAgeRaw, ok := d.GetOk("max_age"); ok { + maxAge := maxAgeRaw.(int) + if maxAge < 1 { + return authResponse("", state, ErrAuthInvalidRequest, "max_age must be greater than zero") + } + + // Look up the token associated with the request + te, err := i.tokenStorer.LookupToken(ctx, req.ClientToken) + if err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + if te == nil { + return authResponse("", state, ErrAuthAccessDenied, "token associated with request not found") + } + + // Check if the token creation time violates the max age requirement + now := time.Now().UTC() + lastAuthTime := time.Unix(te.CreationTime, 0).UTC() + secondsSince := int(now.Sub(lastAuthTime).Seconds()) + if secondsSince > maxAge { + return authResponse("", state, ErrAuthMaxAgeReAuthenticate, "active re-authentication is required by max_age") + } + + // Set the auth time to use for the auth_time claim in the token exchange + authCodeEntry.authTime = lastAuthTime + } + + // Generate the authorization code + code, err := base62.Random(32) + if err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + + // Cache the authorization code for a subsequent token exchange + if err := i.oidcAuthCodeCache.SetDefault(ns, code, authCodeEntry); err != nil { + return authResponse("", state, ErrAuthServerError, err.Error()) + } + + return authResponse(code, state, "", "") +} + +// authResponse returns the OIDC Authentication Response. An error response is +// returned if the given error code is non-empty. For details, see spec at +// - https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse +// - https://openid.net/specs/openid-connect-core-1_0.html#AuthError +func authResponse(code, state, errorCode, errorDescription string) (*logical.Response, error) { + statusCode := http.StatusOK + response := map[string]interface{}{ + "code": code, + "state": state, + } + + // Set the error response and status code if error code isn't empty + if errorCode != "" { + statusCode = http.StatusBadRequest + if errorCode == ErrAuthServerError { + statusCode = http.StatusInternalServerError + } + + response = map[string]interface{}{ + "error": errorCode, + "error_description": errorDescription, + "state": state, + } + } + + data, err := json.Marshal(response) + if err != nil { + return nil, err + } + + return &logical.Response{ + Data: map[string]interface{}{ + logical.HTTPStatusCode: statusCode, + logical.HTTPRawBody: data, + logical.HTTPContentType: "application/json", + }, + }, nil +} + +// entityHasAssignment returns true if the entity is a member of any of the +// assignments' groups or entities. Otherwise, returns false or an error. +func (i *IdentityStore) entityHasAssignment(ctx context.Context, s logical.Storage, entityID string, assignments []string) (bool, error) { + // Get the group IDs that the entity is a member of + entityGroups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false) + if err != nil { + return false, err + } + entityGroupIDs := make(map[string]bool) + for _, group := range entityGroups { + entityGroupIDs[group.GetID()] = true + } + + for _, a := range assignments { + assignment, err := i.getOIDCAssignment(ctx, s, a) + if err != nil { + return false, err + } + if assignment == nil { + return false, fmt.Errorf("client assignment %q not found", a) + } + + // Check if the entity is a member of any groups in the assignment + for _, id := range assignment.GroupIDs { + if entityGroupIDs[id] { + return true, nil + } + } + + // Check if the entity is a member of the assignment's entities + if strutil.StrListContains(assignment.EntityIDs, entityID) { + return true, nil + } + } + + return false, nil +} + +// clientByID returns the client with the given ID. +func (i *IdentityStore) clientByID(ctx context.Context, s logical.Storage, id string) (*client, error) { + // Read the client from memdb + client, err := i.memDBClientByID(id) + if err != nil { + return nil, err + } + if client != nil { + return client, nil + } + + // Fall back to reading the client from storage + client, err = i.storageClientByID(ctx, s, id) + if err != nil { + return nil, err + } + if client == nil { + return nil, nil + } + + // Upsert the client in memdb + txn := i.db.Txn(true) + defer txn.Abort() + if err := i.memDBUpsertClientInTxn(txn, client); err != nil { + i.logger.Debug("failed to upsert client in memdb", "error", err) + return client, nil + } + txn.Commit() + + return client, nil +} + +// clientByName returns the client with the given name. +func (i *IdentityStore) clientByName(ctx context.Context, s logical.Storage, name string) (*client, error) { + // Read the client from memdb + client, err := i.memDBClientByName(ctx, name) + if err != nil { + return nil, err + } + if client != nil { + return client, nil + } + + // Fall back to reading the client from storage + client, err = i.storageClientByName(ctx, s, name) + if err != nil { + return nil, err + } + if client == nil { + return nil, nil + } + + // Upsert the client in memdb + txn := i.db.Txn(true) + defer txn.Abort() + if err := i.memDBUpsertClientInTxn(txn, client); err != nil { + i.logger.Debug("failed to upsert client in memdb", "error", err) + return client, nil + } + txn.Commit() + + return client, nil +} + +// memDBClientByID returns the client with the given ID from memdb. +func (i *IdentityStore) memDBClientByID(id string) (*client, error) { + if id == "" { + return nil, errors.New("missing client ID") + } + + txn := i.db.Txn(false) + + return i.memDBClientByIDInTxn(txn, id) +} + +// memDBClientByIDInTxn returns the client with the given ID from memdb using the given txn. +func (i *IdentityStore) memDBClientByIDInTxn(txn *memdb.Txn, id string) (*client, error) { + if id == "" { + return nil, errors.New("missing client ID") + } + + if txn == nil { + return nil, errors.New("txn is nil") + } + + clientRaw, err := txn.First(oidcClientsTable, "id", id) + if err != nil { + return nil, fmt.Errorf("failed to fetch client from memdb using ID: %w", err) + } + if clientRaw == nil { + return nil, nil + } + + client, ok := clientRaw.(*client) + if !ok { + return nil, errors.New("unexpected client type") + } + + return client, nil +} + +// memDBClientByName returns the client with the given name from memdb. +func (i *IdentityStore) memDBClientByName(ctx context.Context, name string) (*client, error) { + if name == "" { + return nil, errors.New("missing client name") + } + + txn := i.db.Txn(false) + + return i.memDBClientByNameInTxn(ctx, txn, name) +} + +// memDBClientByNameInTxn returns the client with the given ID from memdb using the given txn. +func (i *IdentityStore) memDBClientByNameInTxn(ctx context.Context, txn *memdb.Txn, name string) (*client, error) { + if name == "" { + return nil, errors.New("missing client name") + } + + if txn == nil { + return nil, errors.New("txn is nil") + } + + ns, err := namespace.FromContext(ctx) + if err != nil { + return nil, err + } + + clientRaw, err := txn.First(oidcClientsTable, "name", ns.ID, name) + if err != nil { + return nil, fmt.Errorf("failed to fetch client from memdb using name: %w", err) + } + if clientRaw == nil { + return nil, nil + } + + client, ok := clientRaw.(*client) + if !ok { + return nil, errors.New("unexpected client type") + } + + return client, nil +} + +// memDBDeleteClientByName deletes the client with the given name from memdb. +func (i *IdentityStore) memDBDeleteClientByName(ctx context.Context, name string) error { + if name == "" { + return errors.New("missing client name") + } + + txn := i.db.Txn(true) + defer txn.Abort() + + if err := i.memDBDeleteClientByNameInTxn(ctx, txn, name); err != nil { + return err + } + + txn.Commit() + + return nil +} + +// memDBDeleteClientByNameInTxn deletes the client with name from memdb using the given txn. +func (i *IdentityStore) memDBDeleteClientByNameInTxn(ctx context.Context, txn *memdb.Txn, name string) error { + if name == "" { + return errors.New("missing client name") + } + + if txn == nil { + return errors.New("txn is nil") + } + + client, err := i.memDBClientByNameInTxn(ctx, txn, name) + if err != nil { + return err + } + if client == nil { + return nil + } + + if err := txn.Delete(oidcClientsTable, client); err != nil { + return fmt.Errorf("failed to delete client from memdb: %w", err) + } + + return nil +} + +// memDBUpsertClientInTxn creates or updates the given client in memdb using the given txn. +func (i *IdentityStore) memDBUpsertClientInTxn(txn *memdb.Txn, client *client) error { + if client == nil { + return errors.New("client is nil") + } + + if txn == nil { + return errors.New("nil txn") + } + + clientRaw, err := txn.First(oidcClientsTable, "id", client.ClientID) + if err != nil { + return fmt.Errorf("failed to lookup client from memdb using ID: %w", err) + } + + if clientRaw != nil { + err = txn.Delete(oidcClientsTable, clientRaw) + if err != nil { + return fmt.Errorf("failed to delete client from memdb: %w", err) + } + } + + if err := txn.Insert(oidcClientsTable, client); err != nil { + return fmt.Errorf("failed to update client in memdb: %w", err) + } + + return nil +} + +// storageClientByName returns the client with name from the given logical storage. +func (i *IdentityStore) storageClientByName(ctx context.Context, s logical.Storage, name string) (*client, error) { + entry, err := s.Get(ctx, clientPath+name) + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var client client + if err := entry.DecodeJSON(&client); err != nil { + return nil, err + } + + return &client, nil +} + +// storageClientByID returns the client with ID from the given logical storage. +func (i *IdentityStore) storageClientByID(ctx context.Context, s logical.Storage, id string) (*client, error) { + clients, err := s.List(ctx, clientPath) + if err != nil { + return nil, err + } + + for _, name := range clients { + client, err := i.storageClientByName(ctx, s, name) + if err != nil { + return nil, err + } + if client == nil { + continue + } + + if client.ClientID == id { + return client, nil + } + } + + return nil, nil +} + +func (i *IdentityStore) listClients(ctx context.Context, s logical.Storage) ([]*client, error) { + clientNames, err := s.List(ctx, clientPath) + if err != nil { + return nil, err + } + + var clients []*client + for _, name := range clientNames { + entry, err := s.Get(ctx, clientPath+name) + if err != nil { + return nil, err + } + if entry == nil { + continue + } + + var client client + if err := entry.DecodeJSON(&client); err != nil { + return nil, err + } + clients = append(clients, &client) + } + + return clients, nil +} diff --git a/vault/identity_store_oidc_provider_test.go b/vault/identity_store_oidc_provider_test.go index 6fe0ee948..4c5baff4c 100644 --- a/vault/identity_store_oidc_provider_test.go +++ b/vault/identity_store_oidc_provider_test.go @@ -5,14 +5,786 @@ import ( "encoding/json" "fmt" "testing" + "time" "github.com/go-test/deep" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/assert" "gopkg.in/square/go-jose.v2" ) +func TestOIDC_Path_OIDC_Authorize(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + storage := new(logical.InmemStorage) + + // Create a key + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/key/test-key", + Operation: logical.CreateOperation, + Data: map[string]interface{}{}, + Storage: storage, + }) + expectSuccess(t, resp, err) + + // Create an entity + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "entity", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "name": "test-entity", + }, + }) + expectSuccess(t, resp, err) + assert.NotNil(t, resp.Data["id"]) + entityID := resp.Data["id"].(string) + + // Create a group + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "group", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "name": "test-group", + "member_entity_ids": []string{entityID}, + }, + }) + expectSuccess(t, resp, err) + assert.NotNil(t, resp.Data["id"]) + groupID := resp.Data["id"].(string) + + type args struct { + entityID string + client client + provider provider + assignment assignment + authorizeRequest *logical.Request + tokenCreationTime func() time.Time + } + tests := []struct { + name string + args args + wantErr string + }{ + { + name: "invalid authorize request with provider not found", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/non-existent-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with empty scopes", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with missing openid scope", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "groups email profile", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with missing response_type", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with unsupported response_type", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "id_token", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthUnsupportedResponseType, + }, + { + name: "invalid authorize request with client_id not found", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "non-existent-client-id", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidClientID, + }, + { + name: "invalid authorize request with client_id not allowed by provider", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + provider: provider{ + AllowedClientIDs: []string{"not-client-id"}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthUnauthorizedClient, + }, + { + name: "invalid authorize request with missing redirect_uri", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with redirect_uri not allowed by client", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://not.redirect.uri:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidRedirectURI, + }, + { + name: "invalid authorize request with missing state", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with missing nonce", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "invalid authorize request with request parameter provided", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + "request": "header.payload.signature", + }, + }, + }, + wantErr: ErrAuthRequestNotSupported, + }, + { + name: "invalid authorize request with request_uri parameter provided", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + "request_uri": "https://client.example.org/request.jwt", + }, + }, + }, + wantErr: ErrAuthRequestURINotSupported, + }, + { + name: "invalid authorize request with identity entity ID not found", + args: args{ + entityID: "non-existent-entity", + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthAccessDenied, + }, + { + name: "invalid authorize request with entity not found in client assignment", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{"not-entity-id"}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthAccessDenied, + }, + { + name: "invalid authorize request with group not found in client assignment", + args: args{ + entityID: entityID, + assignment: assignment{ + GroupIDs: []string{"not-group-id"}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + wantErr: ErrAuthAccessDenied, + }, + { + name: "invalid authorize request with negative max_age", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + "max_age": "-1", + }, + }, + }, + wantErr: ErrAuthInvalidRequest, + }, + { + name: "active re-authentication required with token creation time exceeding max_age requirement", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + "max_age": "30", + }, + }, + tokenCreationTime: func() time.Time { + return time.Now().Add(-time.Minute) + }, + }, + wantErr: ErrAuthMaxAgeReAuthenticate, + }, + { + name: "valid authorize request with token creation time within max_age requirement", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + "max_age": "30", + }, + }, + tokenCreationTime: func() time.Time { + return time.Now() + }, + }, + }, + { + name: "valid authorize request using update operation (HTTP PUT/POST)", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + }, + { + name: "valid authorize request using read operation (HTTP GET)", + args: args{ + entityID: entityID, + assignment: assignment{ + EntityIDs: []string{entityID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.ReadOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + }, + { + name: "valid authorize request using client assignment with group membership", + args: args{ + entityID: entityID, + assignment: assignment{ + GroupIDs: []string{groupID}, + }, + client: client{ + RedirectURIs: []string{"https://localhost:8251/callback"}, + Assignments: []string{"test-assignment"}, + Key: "test-key", + }, + authorizeRequest: &logical.Request{ + Path: "oidc/provider/test-provider/authorize", + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "client_id": "", + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "abcdefg", + "nonce": "hijklmn", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a token entry and associate with the authorize request + creationTime := time.Now() + if tt.args.tokenCreationTime != nil { + creationTime = tt.args.tokenCreationTime() + } + te := &logical.TokenEntry{ + Path: "test", + Policies: []string{"default"}, + TTL: time.Hour * 24, + CreationTime: creationTime.Unix(), + } + testMakeTokenDirectly(t, c.tokenStore, te) + assert.NotEmpty(t, te.ID) + tt.args.authorizeRequest.ClientToken = te.ID + + // Create an assignment + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/assignment/test-assignment", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "group_ids": tt.args.assignment.GroupIDs, + "entity_ids": tt.args.assignment.EntityIDs, + }, + Storage: storage, + }) + expectSuccess(t, resp, err) + + // Create a client + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client", + Operation: logical.CreateOperation, + Storage: storage, + Data: map[string]interface{}{ + "key": "test-key", + "redirect_uris": tt.args.client.RedirectURIs, + "assignments": tt.args.client.Assignments, + "id_token_ttl": tt.args.client.IDTokenTTL, + "access_token_ttl": tt.args.client.AccessTokenTTL, + }, + }) + expectSuccess(t, resp, err) + + // Read the client ID + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client", + Operation: logical.ReadOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) + assert.NotNil(t, resp.Data["client_id"]) + clientID := resp.Data["client_id"].(string) + + // Use allowed client IDs if set by test args + if len(tt.args.provider.AllowedClientIDs) == 0 { + tt.args.provider.AllowedClientIDs = []string{clientID} + } + + // Create a provider + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/provider/test-provider", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "issuer": tt.args.provider.Issuer, + "allowed_client_ids": tt.args.provider.AllowedClientIDs, + "scopes": tt.args.provider.Scopes, + }, + Storage: storage, + }) + expectSuccess(t, resp, err) + + // Use the client ID if set by test args + if len(tt.args.authorizeRequest.Data["client_id"].(string)) == 0 { + tt.args.authorizeRequest.Data["client_id"] = clientID + } + + // Send the request to the OIDC authorize endpoint + tt.args.authorizeRequest.Storage = storage + tt.args.authorizeRequest.EntityID = tt.args.entityID + resp, err = c.identityStore.HandleRequest(ctx, tt.args.authorizeRequest) + + // Parse the response + var res struct { + Code string `json:"code"` + State string `json:"state"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + assert.NotNil(t, resp) + assert.NotNil(t, resp.Data[logical.HTTPRawBody]) + assert.NotNil(t, resp.Data[logical.HTTPContentType]) + assert.Equal(t, "application/json", resp.Data[logical.HTTPContentType].(string)) + assert.NoError(t, json.Unmarshal(resp.Data["http_raw_body"].([]byte), &res)) + + if tt.wantErr != "" { + // Assert that we receive the expected error code + assert.Equal(t, tt.wantErr, res.Error) + assert.NotEmpty(t, res.ErrorDescription) + return + } + + // Assert that we receive an authorization code (base62) and state + expectSuccess(t, resp, err) + assert.Regexp(t, "[a-zA-Z0-9]{32}", res.Code) + assert.NotEmpty(t, res.State) + assert.Empty(t, res.Error) + assert.Empty(t, res.ErrorDescription) + }) + } +} + // TestOIDC_Path_OIDC_ProviderReadPublicKey_ProviderDoesNotExist tests that the // path can handle the read operation when the provider does not exist func TestOIDC_Path_OIDC_ProviderReadPublicKey_ProviderDoesNotExist(t *testing.T) { @@ -1091,8 +1863,8 @@ func TestOIDC_Path_OIDC_ProviderAssignment(t *testing.T) { }) expectSuccess(t, resp, err) expected := map[string]interface{}{ - "groups": []string{}, - "entities": []string{}, + "group_ids": []string{}, + "entity_ids": []string{}, } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1103,8 +1875,8 @@ func TestOIDC_Path_OIDC_ProviderAssignment(t *testing.T) { Path: "oidc/assignment/test-assignment", Operation: logical.UpdateOperation, Data: map[string]interface{}{ - "groups": "my-group", - "entities": "my-entity", + "group_ids": "my-group", + "entity_ids": "my-entity", }, Storage: storage, }) @@ -1118,8 +1890,8 @@ func TestOIDC_Path_OIDC_ProviderAssignment(t *testing.T) { }) expectSuccess(t, resp, err) expected = map[string]interface{}{ - "groups": []string{"my-group"}, - "entities": []string{"my-entity"}, + "group_ids": []string{"my-group"}, + "entity_ids": []string{"my-entity"}, } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1203,8 +1975,8 @@ func TestOIDC_Path_OIDC_ProviderAssignment_DeleteWithExistingClient(t *testing.T }) expectSuccess(t, resp, err) expected := map[string]interface{}{ - "groups": []string{}, - "entities": []string{}, + "group_ids": []string{}, + "entity_ids": []string{}, } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1223,8 +1995,8 @@ func TestOIDC_Path_OIDC_ProviderAssignment_Update(t *testing.T) { Operation: logical.CreateOperation, Storage: storage, Data: map[string]interface{}{ - "groups": "my-group", - "entities": "my-entity", + "group_ids": "my-group", + "entity_ids": "my-entity", }, }) expectSuccess(t, resp, err) @@ -1237,8 +2009,8 @@ func TestOIDC_Path_OIDC_ProviderAssignment_Update(t *testing.T) { }) expectSuccess(t, resp, err) expected := map[string]interface{}{ - "groups": []string{"my-group"}, - "entities": []string{"my-entity"}, + "group_ids": []string{"my-group"}, + "entity_ids": []string{"my-entity"}, } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) @@ -1249,7 +2021,7 @@ func TestOIDC_Path_OIDC_ProviderAssignment_Update(t *testing.T) { Path: "oidc/assignment/test-assignment", Operation: logical.UpdateOperation, Data: map[string]interface{}{ - "groups": "my-group2", + "group_ids": "my-group2", }, Storage: storage, }) @@ -1263,8 +2035,8 @@ func TestOIDC_Path_OIDC_ProviderAssignment_Update(t *testing.T) { }) expectSuccess(t, resp, err) expected = map[string]interface{}{ - "groups": []string{"my-group2"}, - "entities": []string{"my-entity"}, + "group_ids": []string{"my-group2"}, + "entity_ids": []string{"my-entity"}, } if diff := deep.Equal(expected, resp.Data); diff != nil { t.Fatal(diff) diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index ec1131074..5274ada71 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -1307,7 +1307,7 @@ func TestOIDC_isTargetNamespacedKey(t *testing.T) { } func TestOIDC_Flush(t *testing.T) { - c := newOIDCCache() + c := newOIDCCache(gocache.NoExpiration, gocache.NoExpiration) ns := []*namespace.Namespace{ noNamespace, // ns[0] is nilNamespace {ID: "ns1"}, @@ -1367,7 +1367,7 @@ func TestOIDC_Flush(t *testing.T) { } func TestOIDC_CacheNamespaceNilCheck(t *testing.T) { - cache := newOIDCCache() + cache := newOIDCCache(gocache.NoExpiration, gocache.NoExpiration) if _, _, err := cache.Get(nil, "foo"); err == nil { t.Fatal("expected error, got nil") diff --git a/vault/identity_store_schema.go b/vault/identity_store_schema.go index bd15bc3b6..76eab6664 100644 --- a/vault/identity_store_schema.go +++ b/vault/identity_store_schema.go @@ -11,6 +11,7 @@ const ( entityAliasesTable = "entity_aliases" groupsTable = "groups" groupAliasesTable = "group_aliases" + oidcClientsTable = "oidc_clients" ) func identityStoreSchema(lowerCaseName bool) *memdb.DBSchema { @@ -23,6 +24,7 @@ func identityStoreSchema(lowerCaseName bool) *memdb.DBSchema { aliasesTableSchema, groupsTableSchema, groupAliasesTableSchema, + oidcClientsTableSchema, } for _, schemaFunc := range schemas { @@ -213,3 +215,38 @@ func groupAliasesTableSchema(lowerCaseName bool) *memdb.TableSchema { }, } } + +func oidcClientsTableSchema(_ bool) *memdb.TableSchema { + return &memdb.TableSchema{ + Name: oidcClientsTable, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "ClientID", + }, + }, + "name": { + Name: "name", + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "NamespaceID", + }, + &memdb.StringFieldIndex{ + Field: "Name", + }, + }, + }, + }, + "namespace_id": { + Name: "namespace_id", + Indexer: &memdb.StringFieldIndex{ + Field: "NamespaceID", + }, + }, + }, + } +} diff --git a/vault/identity_store_structs.go b/vault/identity_store_structs.go index fe69cb1a0..144453c2d 100644 --- a/vault/identity_store_structs.go +++ b/vault/identity_store_structs.go @@ -65,6 +65,10 @@ type IdentityStore struct { // will invalidate the cache. oidcCache *oidcCache + // oidcAuthCodeCache stores OIDC authorization codes to be exchanged + // for an ID token during an authorization code flow. + oidcAuthCodeCache *oidcCache + // logger is the server logger copied over from core logger log.Logger @@ -87,6 +91,7 @@ type IdentityStore struct { metrics metricsutil.Metrics totpPersister TOTPPersister groupUpdater GroupUpdater + tokenStorer TokenStorer } type groupDiff struct { @@ -124,3 +129,9 @@ type GroupUpdater interface { } var _ GroupUpdater = &Core{} + +type TokenStorer interface { + LookupToken(ctx context.Context, token string) (*logical.TokenEntry, error) +} + +var _ TokenStorer = &Core{} diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 53b347f4e..0d1061726 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -39,7 +39,11 @@ func (c *Core) loadIdentityStoreArtifacts(ctx context.Context) error { if err != nil { return err } - return c.identityStore.loadGroups(ctx) + err = c.identityStore.loadGroups(ctx) + if err != nil { + return err + } + return c.identityStore.loadOIDCClients(ctx) } if !c.loadCaseSensitiveIdentityStore { @@ -78,6 +82,39 @@ func (i *IdentityStore) sanitizeName(name string) string { return strings.ToLower(name) } +func (i *IdentityStore) loadOIDCClients(ctx context.Context) error { + i.logger.Debug("identity loading OIDC clients") + + clients, err := i.view.List(ctx, clientPath) + if err != nil { + return err + } + + txn := i.db.Txn(true) + defer txn.Abort() + for _, name := range clients { + entry, err := i.view.Get(ctx, clientPath+name) + if err != nil { + return err + } + if entry == nil { + continue + } + + var client client + if err := entry.DecodeJSON(&client); err != nil { + return err + } + + if err := i.memDBUpsertClientInTxn(txn, &client); err != nil { + return err + } + } + txn.Commit() + + return nil +} + func (i *IdentityStore) loadGroups(ctx context.Context) error { i.logger.Debug("identity loading groups") existing, err := i.groupPacker.View().List(ctx, groupBucketsPrefix) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index b3548e2f6..8830b4eac 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -3306,15 +3306,15 @@ func TestHandlePoliciesPasswordGenerate(t *testing.T) { t.Fatalf("no error expected, got: %s", err) } - assert(t, actualResp != nil, "response is nil") - assert(t, actualResp.Data != nil, "expected data, got nil") + assertTrue(t, actualResp != nil, "response is nil") + assertTrue(t, actualResp.Data != nil, "expected data, got nil") assertHasKey(t, actualResp.Data, "password", "password key not found in data") assertIsString(t, actualResp.Data["password"], "password key should have a string value") password := actualResp.Data["password"].(string) // Delete the password so the rest of the response can be compared delete(actualResp.Data, "password") - assert(t, reflect.DeepEqual(actualResp, expectedResp), "Actual response: %#v\nExpected response: %#v", actualResp, expectedResp) + assertTrue(t, reflect.DeepEqual(actualResp, expectedResp), "Actual response: %#v\nExpected response: %#v", actualResp, expectedResp) // Check to make sure the password is correctly formatted passwordLength := len([]rune(password)) @@ -3331,7 +3331,7 @@ func TestHandlePoliciesPasswordGenerate(t *testing.T) { }) } -func assert(t *testing.T, pass bool, f string, vals ...interface{}) { +func assertTrue(t *testing.T, pass bool, f string, vals ...interface{}) { t.Helper() if !pass { t.Fatalf(f, vals...) diff --git a/vault/policy_store.go b/vault/policy_store.go index 20a17f1e7..451339814 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -150,6 +150,11 @@ path "sys/tools/hash/*" { path "sys/control-group/request" { capabilities = ["update"] } + +# Allow a token to make requests to the Authorization Endpoint for OIDC providers. +path "identity/oidc/provider/+/authorize" { + capabilities = ["read", "update"] +} ` ) diff --git a/vault/router.go b/vault/router.go index 26e63ed14..6d61d7f0e 100644 --- a/vault/router.go +++ b/vault/router.go @@ -563,6 +563,7 @@ func (r *Router) routeCommon(ctx context.Context, req *logical.Request, existenc switch { case strings.HasPrefix(originalPath, "auth/token/"): case strings.HasPrefix(originalPath, "sys/"): + case strings.HasPrefix(originalPath, "identity/"): case strings.HasPrefix(originalPath, cubbyholeMountPath): if req.Operation == logical.RollbackOperation { // Backend doesn't support this and it can't properly look up a