identity/oidc: Adds proof key for code exchange (PKCE) support (#13917)

This commit is contained in:
Austin Gebauer 2022-02-15 12:02:22 -08:00 committed by GitHub
parent 42bdcf0657
commit 34d295e28f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 868 additions and 38 deletions

3
changelog/13917.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
identity/oidc: Adds proof key for code exchange (PKCE) support to OIDC providers.
```

View File

@ -39,11 +39,12 @@ const (
`
)
// TestOIDC_Auth_Code_Flow_CAP_Client tests the authorization code flow
// using a Vault OIDC provider. The test uses the CAP OIDC client to verify
// that the Vault OIDC provider's responses pass the various client-side
// validation requirements of the OIDC spec.
func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) {
// TestOIDC_Auth_Code_Flow_Confidential_CAP_Client tests the authorization code
// flow using a Vault OIDC provider. The test uses the CAP OIDC client to verify
// that the Vault OIDC provider's responses pass the various client-side validation
// requirements of the OIDC spec. This test uses a confidential client which has
// a client secret and authenticates to the token endpoint.
func TestOIDC_Auth_Code_Flow_Confidential_CAP_Client(t *testing.T) {
cluster := setupOIDCTestCluster(t, 2)
defer cluster.Cleanup()
active := cluster.Cores[0].Client
@ -131,8 +132,8 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) {
})
require.NoError(t, err)
// Create a client
_, err = active.Logical().Write("identity/oidc/client/test-client", map[string]interface{}{
// Create a confidential client
_, err = active.Logical().Write("identity/oidc/client/confidential", map[string]interface{}{
"key": "test-key",
"redirect_uris": []string{testRedirectURI},
"assignments": []string{"test-assignment"},
@ -142,7 +143,7 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) {
require.NoError(t, err)
// Read the client ID and secret in order to configure the OIDC client
resp, err = active.Logical().Read("identity/oidc/client/test-client")
resp, err = active.Logical().Read("identity/oidc/client/confidential")
require.NoError(t, err)
clientID := resp.Data["client_id"].(string)
clientSecret := resp.Data["client_secret"].(string)
@ -191,6 +192,10 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) {
require.NoError(t, err)
defer p.Done()
// Create the client-side PKCE code verifier
v, err := oidc.NewCodeVerifier()
require.NoError(t, err)
type args struct {
useStandby bool
options []oidc.Option
@ -255,6 +260,21 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) {
"auth_time": %d
}`, discovery.Issuer, clientID, entityID, expectedAuthTime),
},
{
name: "active: authorization code flow with Proof Key for Code Exchange (PKCE)",
args: args{
options: []oidc.Option{
oidc.WithScopes("openid"),
oidc.WithPKCE(v),
},
},
expected: fmt.Sprintf(`{
"iss": "%s",
"aud": "%s",
"sub": "%s",
"namespace": "root"
}`, discovery.Issuer, clientID, entityID),
},
{
name: "standby: authorization code flow with additional scopes",
args: args{
@ -369,6 +389,342 @@ func TestOIDC_Auth_Code_Flow_CAP_Client(t *testing.T) {
}
}
// TestOIDC_Auth_Code_Flow_Public_CAP_Client tests the authorization code flow using
// a Vault OIDC provider. The test uses the CAP OIDC client to verify that the Vault
// OIDC provider's responses pass the various client-side validation requirements of
// the OIDC spec. This test uses a public client which does not have a client secret
// and always uses proof key for code exchange (PKCE).
func TestOIDC_Auth_Code_Flow_Public_CAP_Client(t *testing.T) {
cluster := setupOIDCTestCluster(t, 2)
defer cluster.Cleanup()
active := cluster.Cores[0].Client
standby := cluster.Cores[1].Client
// Create an entity with some metadata
resp, err := active.Logical().Write("identity/entity", map[string]interface{}{
"name": "test-entity",
"metadata": map[string]string{
"email": "test@hashicorp.com",
"phone_number": "123-456-7890",
},
})
require.NoError(t, err)
entityID := resp.Data["id"].(string)
// Create a group
resp, err = active.Logical().Write("identity/group", map[string]interface{}{
"name": "engineering",
"member_entity_ids": []string{entityID},
})
require.NoError(t, err)
groupID := resp.Data["id"].(string)
// Create a policy that allows updating the provider
err = active.Sys().PutPolicy("test-policy", `
path "identity/oidc/provider/test-provider" {
capabilities = ["update"]
}
`)
require.NoError(t, err)
// Enable userpass auth and create a user
err = active.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
Type: "userpass",
})
require.NoError(t, err)
_, err = active.Logical().Write("auth/userpass/users/end-user", map[string]interface{}{
"password": testPassword,
"token_policies": "test-policy",
})
require.NoError(t, err)
// Get the userpass mount accessor
mounts, err := active.Sys().ListAuth()
require.NoError(t, err)
var mountAccessor string
for k, v := range mounts {
if k == "userpass/" {
mountAccessor = v.Accessor
break
}
}
require.NotEmpty(t, mountAccessor)
// Create an entity alias
_, err = active.Logical().Write("identity/entity-alias", map[string]interface{}{
"name": "end-user",
"canonical_id": entityID,
"mount_accessor": mountAccessor,
})
require.NoError(t, err)
// Create some custom scopes
_, err = active.Logical().Write("identity/oidc/scope/groups", map[string]interface{}{
"template": testGroupScopeTemplate,
})
require.NoError(t, err)
_, err = active.Logical().Write("identity/oidc/scope/user", map[string]interface{}{
"template": fmt.Sprintf(testUserScopeTemplate, mountAccessor),
})
require.NoError(t, err)
// Create a key
_, err = active.Logical().Write("identity/oidc/key/test-key", map[string]interface{}{
"allowed_client_ids": []string{"*"},
"algorithm": "RS256",
})
require.NoError(t, err)
// Create an assignment
_, err = active.Logical().Write("identity/oidc/assignment/test-assignment", map[string]interface{}{
"entity_ids": []string{entityID},
"group_ids": []string{groupID},
})
require.NoError(t, err)
// Create a public client
_, err = active.Logical().Write("identity/oidc/client/public", map[string]interface{}{
"key": "test-key",
"redirect_uris": []string{testRedirectURI},
"assignments": []string{"test-assignment"},
"id_token_ttl": "1h",
"access_token_ttl": "30m",
"client_type": "public",
})
require.NoError(t, err)
// Read the client ID in order to configure the OIDC client
resp, err = active.Logical().Read("identity/oidc/client/public")
require.NoError(t, err)
clientID := resp.Data["client_id"].(string)
// Create the OIDC provider
_, err = active.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{
"allowed_client_ids": []string{clientID},
"scopes_supported": []string{"user", "groups"},
})
require.NoError(t, err)
// We aren't going to open up a browser to facilitate the login and redirect
// from this test, so we'll log in via userpass and set the client's token as
// the token that results from the authentication.
resp, err = active.Logical().Write("auth/userpass/login/end-user", map[string]interface{}{
"password": testPassword,
})
require.NoError(t, err)
clientToken := resp.Auth.ClientToken
// Look up the token to get its creation time. This will be used for test
// cases that make assertions on the max_age parameter and auth_time claim.
resp, err = active.Logical().Write("auth/token/lookup", map[string]interface{}{
"token": clientToken,
})
require.NoError(t, err)
expectedAuthTime, err := strconv.Atoi(string(resp.Data["creation_time"].(json.Number)))
require.NoError(t, err)
// Read the issuer from the OIDC provider's discovery document
var discovery struct {
Issuer string `json:"issuer"`
}
decodeRawRequest(t, active, http.MethodGet,
"/v1/identity/oidc/provider/test-provider/.well-known/openid-configuration",
nil, &discovery)
// Create the client-side OIDC provider config with client secret intentionally empty
clientSecret := oidc.ClientSecret("")
pc, err := oidc.NewConfig(discovery.Issuer, clientID, clientSecret, []oidc.Alg{oidc.RS256},
[]string{testRedirectURI}, oidc.WithProviderCA(string(cluster.CACertPEM)))
require.NoError(t, err)
// Create the client-side OIDC provider
p, err := oidc.NewProvider(pc)
require.NoError(t, err)
defer p.Done()
type args struct {
useStandby bool
options []oidc.Option
}
tests := []struct {
name string
args args
expected string
}{
{
name: "active: authorization code flow",
args: args{
options: []oidc.Option{
oidc.WithScopes("openid user"),
},
},
expected: fmt.Sprintf(`{
"iss": "%s",
"aud": "%s",
"sub": "%s",
"namespace": "root",
"username": "end-user",
"contact": {
"email": "test@hashicorp.com",
"phone_number": "123-456-7890"
}
}`, discovery.Issuer, clientID, entityID),
},
{
name: "active: authorization code flow with additional scopes",
args: args{
options: []oidc.Option{
oidc.WithScopes("openid user groups"),
},
},
expected: fmt.Sprintf(`{
"iss": "%s",
"aud": "%s",
"sub": "%s",
"namespace": "root",
"username": "end-user",
"contact": {
"email": "test@hashicorp.com",
"phone_number": "123-456-7890"
},
"groups": ["engineering"]
}`, discovery.Issuer, clientID, entityID),
},
{
name: "active: authorization code flow with max_age parameter",
args: args{
options: []oidc.Option{
oidc.WithScopes("openid"),
oidc.WithMaxAge(60),
},
},
expected: fmt.Sprintf(`{
"iss": "%s",
"aud": "%s",
"sub": "%s",
"namespace": "root",
"auth_time": %d
}`, discovery.Issuer, clientID, entityID, expectedAuthTime),
},
{
name: "standby: authorization code flow with additional scopes",
args: args{
useStandby: true,
options: []oidc.Option{
oidc.WithScopes("openid user groups"),
},
},
expected: fmt.Sprintf(`{
"iss": "%s",
"aud": "%s",
"sub": "%s",
"namespace": "root",
"username": "end-user",
"contact": {
"email": "test@hashicorp.com",
"phone_number": "123-456-7890"
},
"groups": ["engineering"]
}`, discovery.Issuer, clientID, entityID),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := active
if tt.args.useStandby {
client = standby
}
client.SetToken(clientToken)
// Update allowed client IDs before the authentication flow
_, err = client.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{
"allowed_client_ids": []string{clientID},
})
require.NoError(t, err)
// Create the required client-side PKCE code verifier.
v, err := oidc.NewCodeVerifier()
require.NoError(t, err)
options := append([]oidc.Option{oidc.WithPKCE(v)}, tt.args.options...)
// Create the client-side OIDC request state
oidcRequest, err := oidc.NewRequest(10*time.Minute, testRedirectURI, options...)
require.NoError(t, err)
// Get the URL for the authorization endpoint from the OIDC client
authURL, err := p.AuthURL(context.Background(), oidcRequest)
require.NoError(t, err)
parsedAuthURL, err := url.Parse(authURL)
require.NoError(t, err)
// This replace only occurs because we're not using the browser in this test
authURLPath := strings.Replace(parsedAuthURL.Path, "/ui/vault/", "/v1/", 1)
// Kick off the authorization code flow
var authResp struct {
Code string `json:"code"`
State string `json:"state"`
}
decodeRawRequest(t, client, http.MethodGet, authURLPath, parsedAuthURL.Query(), &authResp)
// The returned state must match the OIDC client state
require.Equal(t, oidcRequest.State(), authResp.State)
// Exchange the authorization code for an ID token and access token.
// The ID token signature is verified using the provider's public keys after
// the exchange takes place. The ID token is also validated according to the
// client-side requirements of the OIDC spec. See the validation code at:
// - https://github.com/hashicorp/cap/blob/main/oidc/provider.go#L240
// - https://github.com/hashicorp/cap/blob/main/oidc/provider.go#L441
token, err := p.Exchange(context.Background(), oidcRequest, authResp.State, authResp.Code)
require.NoError(t, err)
require.NotNil(t, token)
idToken := token.IDToken()
accessToken := token.StaticTokenSource()
// Get the ID token claims
allClaims := make(map[string]interface{})
require.NoError(t, idToken.Claims(&allClaims))
// Get the sub claim for userinfo validation
require.NotEmpty(t, allClaims["sub"])
subject := allClaims["sub"].(string)
// Request userinfo using the access token
err = p.UserInfo(context.Background(), accessToken, subject, &allClaims)
require.NoError(t, err)
// Assert that claims computed during the flow (i.e., not known
// ahead of time in this test) are present as top-level keys
for _, claim := range []string{"iat", "exp", "nonce", "at_hash", "c_hash"} {
_, ok := allClaims[claim]
require.True(t, ok)
}
// Assert that all other expected claims are populated
expectedClaims := make(map[string]interface{})
require.NoError(t, json.Unmarshal([]byte(tt.expected), &expectedClaims))
for k, expectedVal := range expectedClaims {
actualVal, ok := allClaims[k]
require.True(t, ok)
require.EqualValues(t, expectedVal, actualVal)
}
// Assert that the access token is no longer able to obtain user info
// after removing the client from the provider's allowed client ids
_, err = client.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{
"allowed_client_ids": []string{},
})
require.NoError(t, err)
err = p.UserInfo(context.Background(), accessToken, subject, &allClaims)
require.Error(t, err)
require.Equal(t, `Provider.UserInfo: provider UserInfo request failed: 403 Forbidden: {"error":"access_denied","error_description":"client is not authorized to use the provider"}`,
err.Error())
})
}
}
func setupOIDCTestCluster(t *testing.T, numCores int) *vault.TestCluster {
t.Helper()

View File

@ -26,13 +26,15 @@ import (
const (
// OIDC-related constants
openIDScope = "openid"
scopesDelimiter = " "
accessTokenScopesMeta = "scopes"
accessTokenClientIDMeta = "client_id"
clientIDLength = 32
clientSecretLength = 64
clientSecretPrefix = "hvo_secret_"
openIDScope = "openid"
scopesDelimiter = " "
accessTokenScopesMeta = "scopes"
accessTokenClientIDMeta = "client_id"
clientIDLength = 32
clientSecretLength = 64
clientSecretPrefix = "hvo_secret_"
codeChallengeMethodPlain = "plain"
codeChallengeMethodS256 = "S256"
// Storage path constants
oidcProviderPrefix = "oidc_provider/"
@ -95,12 +97,31 @@ type client struct {
Key string `json:"key"`
IDTokenTTL time.Duration `json:"id_token_ttl"`
AccessTokenTTL time.Duration `json:"access_token_ttl"`
Type clientType `json:"type"`
// Generated values that are used in OIDC endpoints
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
type clientType int
const (
confidential clientType = iota
public
)
func (k clientType) String() string {
switch k {
case confidential:
return "confidential"
case public:
return "public"
default:
return "unknown"
}
}
type provider struct {
Issuer string `json:"issuer"`
AllowedClientIDs []string `json:"allowed_client_ids"`
@ -127,13 +148,15 @@ type providerDiscovery struct {
}
type authCodeCacheEntry struct {
provider string
clientID string
entityID string
redirectURI string
nonce string
scopes []string
authTime time.Time
provider string
clientID string
entityID string
redirectURI string
nonce string
scopes []string
authTime time.Time
codeChallenge string
codeChallengeMethod string
}
func oidcProviderPaths(i *IdentityStore) []*framework.Path {
@ -256,6 +279,11 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path {
Description: "The time-to-live for access tokens obtained by the client.",
Default: "24h",
},
"client_type": {
Type: framework.TypeString,
Description: "The client type based on its ability to maintain confidentiality of credentials. The following client types are supported: 'confidential', 'public'. Defaults to 'confidential'.",
Default: "confidential",
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
@ -405,6 +433,15 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path {
Type: framework.TypeInt,
Description: "The allowable elapsed time in seconds since the last time the end-user was actively authenticated.",
},
"code_challenge": {
Type: framework.TypeString,
Description: "The code challenge derived from the code verifier.",
},
"code_challenge_method": {
Type: framework.TypeString,
Description: "The method that was used to derive the code challenge. The following methods are supported: 'S256', 'plain'. Defaults to 'plain'.",
Default: codeChallengeMethodPlain,
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
@ -443,10 +480,23 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path {
Description: "The callback location where the authentication response was sent.",
Required: true,
},
// The client_id and client_secret are provided to the token endpoint via
// the client_secret_basic authentication method, which uses the HTTP Basic
// authentication scheme. See the OIDC spec for details at:
"code_verifier": {
Type: framework.TypeString,
Description: "The code verifier associated with the authorization code.",
},
// For confidential clients, the client_id and client_secret are provided to
// the token endpoint via the 'client_secret_basic' authentication method, which
// uses the HTTP Basic authentication scheme. See the OIDC spec for details at:
// https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication
// For public clients, the client_id is required and a client_secret does
// not exist. This means that public clients use the 'none' authentication
// method. However, public clients are required to use Proof Key for Code
// Exchange (PKCE) when using the authorization code flow.
"client_id": {
Type: framework.TypeString,
Description: "The ID of the requesting client.",
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
@ -984,6 +1034,22 @@ func (i *IdentityStore) pathOIDCCreateUpdateClient(ctx context.Context, req *log
client.AccessTokenTTL = time.Duration(d.Get("access_token_ttl").(int)) * time.Second
}
if clientTypeRaw, ok := d.GetOk("client_type"); ok {
clientType := clientTypeRaw.(string)
if req.Operation == logical.UpdateOperation && client.Type.String() != clientType {
return logical.ErrorResponse("client_type modification is not allowed"), nil
}
switch clientType {
case confidential.String():
client.Type = confidential
case public.String():
client.Type = public
default:
return logical.ErrorResponse("invalid client_type %q", clientType), nil
}
}
if client.ClientID == "" {
// generate client_id
clientID, err := base62.Random(clientIDLength)
@ -993,7 +1059,8 @@ func (i *IdentityStore) pathOIDCCreateUpdateClient(ctx context.Context, req *log
client.ClientID = clientID
}
if client.ClientSecret == "" {
// client secrets are only generated for confidential clients
if client.Type == confidential && client.ClientSecret == "" {
// generate client_secret
clientSecret, err := base62.Random(clientSecretLength)
if err != nil {
@ -1040,7 +1107,7 @@ func (i *IdentityStore) pathOIDCReadClient(ctx context.Context, req *logical.Req
return nil, nil
}
return &logical.Response{
resp := &logical.Response{
Data: map[string]interface{}{
"redirect_uris": client.RedirectURIs,
"assignments": client.Assignments,
@ -1048,9 +1115,15 @@ func (i *IdentityStore) pathOIDCReadClient(ctx context.Context, req *logical.Req
"id_token_ttl": int64(client.IDTokenTTL.Seconds()),
"access_token_ttl": int64(client.AccessTokenTTL.Seconds()),
"client_id": client.ClientID,
"client_secret": client.ClientSecret,
"client_type": client.Type.String(),
},
}, nil
}
if client.Type == confidential {
resp.Data["client_secret"] = client.ClientSecret
}
return resp, nil
}
// pathOIDCDeleteClient is used to delete a client
@ -1561,6 +1634,37 @@ func (i *IdentityStore) pathOIDCAuthorize(ctx context.Context, req *logical.Requ
scopes: scopes,
}
// Validate the Proof Key for Code Exchange (PKCE) code challenge and code challenge
// method. PKCE is required for public clients and optional for confidential clients.
// See details at https://datatracker.ietf.org/doc/html/rfc7636.
codeChallengeRaw, okCodeChallenge := d.GetOk("code_challenge")
if !okCodeChallenge && client.Type == public {
return authResponse("", state, ErrAuthInvalidRequest, "PKCE is required for public clients")
}
if okCodeChallenge {
codeChallenge := codeChallengeRaw.(string)
// Validate the code challenge method
codeChallengeMethod := d.Get("code_challenge_method").(string)
switch codeChallengeMethod {
case codeChallengeMethodPlain, codeChallengeMethodS256:
case "":
codeChallengeMethod = codeChallengeMethodPlain
default:
return authResponse("", state, ErrAuthInvalidRequest, "invalid code_challenge_method")
}
// Validate the code challenge
if len(codeChallenge) < 43 || len(codeChallenge) > 128 {
return authResponse("", state, ErrAuthInvalidRequest, "invalid code_challenge")
}
// Associate the code challenge and method with the authorization code.
// This will be used to verify the code verifier in the token exchange.
authCodeEntry.codeChallenge = codeChallenge
authCodeEntry.codeChallengeMethod = codeChallengeMethod
}
// Validate the optional max_age parameter to check if an active re-authentication
// of the user should occur. Re-authentication will be requested if the last time
// the token actively authenticated exceeds the given max_age requirement. Returning
@ -1662,13 +1766,13 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request,
return tokenResponse(nil, ErrTokenInvalidRequest, "provider not found")
}
// Authenticate the client using the client_secret_basic authentication method.
// The authentication method uses the HTTP Basic authentication scheme. Details at
// https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication
headerReq := &http.Request{Header: req.Headers}
clientID, clientSecret, ok := headerReq.BasicAuth()
if !ok {
return tokenResponse(nil, ErrTokenInvalidRequest, "client failed to authenticate")
// Get the client ID
clientID, clientSecret, okBasicAuth := basicAuth(req)
if !okBasicAuth {
clientID = d.Get("client_id").(string)
if clientID == "" {
return tokenResponse(nil, ErrTokenInvalidRequest, "client_id parameter is required")
}
}
client, err := i.clientByID(ctx, req.Storage, clientID)
if err != nil {
@ -1678,7 +1782,12 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request,
i.Logger().Debug("client failed to authenticate with client not found", "client_id", clientID)
return tokenResponse(nil, ErrTokenInvalidClient, "client failed to authenticate")
}
if subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(clientSecret)) == 0 {
// Authenticate the client using the client_secret_basic authentication method if it's a
// confidential client. The authentication method uses the HTTP Basic authentication scheme.
// Details at https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication
if client.Type == confidential &&
subtle.ConstantTimeCompare([]byte(client.ClientSecret), []byte(clientSecret)) == 0 {
i.Logger().Debug("client failed to authenticate with invalid client secret", "client_id", clientID)
return tokenResponse(nil, ErrTokenInvalidClient, "client failed to authenticate")
}
@ -1771,6 +1880,28 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request,
return tokenResponse(nil, ErrTokenInvalidRequest, "identity entity not authorized by client assignment")
}
// Validate the PKCE code verifier. See details at
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.6.
usedPKCE := authCodeUsedPKCE(authCodeEntry)
codeVerifier := d.Get("code_verifier").(string)
switch {
case !usedPKCE && client.Type == public:
return tokenResponse(nil, ErrTokenInvalidRequest, "PKCE is required for public clients")
case !usedPKCE && codeVerifier != "":
return tokenResponse(nil, ErrTokenInvalidRequest, "unexpected code_verifier for token exchange")
case usedPKCE && codeVerifier == "":
return tokenResponse(nil, ErrTokenInvalidRequest, "expected code_verifier for token exchange")
case usedPKCE:
codeChallenge, err := computeCodeChallenge(codeVerifier, authCodeEntry.codeChallengeMethod)
if err != nil {
return tokenResponse(nil, ErrTokenServerError, err.Error())
}
if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(authCodeEntry.codeChallenge)) == 0 {
return tokenResponse(nil, ErrTokenInvalidGrant, "invalid code_verifier for token exchange")
}
}
// The access token is a Vault batch token with a policy that only
// provides access to the issuing provider's userinfo endpoint.
accessTokenIssuedAt := time.Now()

View File

@ -270,6 +270,138 @@ func TestOIDC_Path_OIDC_Token(t *testing.T) {
},
wantErr: ErrTokenInvalidRequest,
},
{
name: "invalid token request with empty code_verifier",
args: args{
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "plain"
req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
tokenReq: func() *logical.Request {
req := testTokenReq(s, "", clientID, clientSecret)
req.Data["code_verifier"] = ""
return req
}(),
},
wantErr: ErrTokenInvalidRequest,
},
{
name: "invalid token request with code_verifier provided for non-PKCE flow",
args: args{
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: testAuthorizeReq(s, clientID),
tokenReq: func() *logical.Request {
req := testTokenReq(s, "", clientID, clientSecret)
req.Data["code_verifier"] = "pkce_not_used_in_authorize_request"
return req
}(),
},
wantErr: ErrTokenInvalidRequest,
},
{
name: "invalid token request with incorrect plain code_verifier",
args: args{
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "plain"
req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
tokenReq: func() *logical.Request {
req := testTokenReq(s, "", clientID, clientSecret)
req.Data["code_verifier"] = "wont_match_challenge"
return req
}(),
},
wantErr: ErrTokenInvalidGrant,
},
{
name: "invalid token request with incorrect S256 code_verifier",
args: args{
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "S256"
req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
tokenReq: func() *logical.Request {
req := testTokenReq(s, "", clientID, clientSecret)
req.Data["code_verifier"] = "wont_hash_to_challenge"
return req
}(),
},
wantErr: ErrTokenInvalidGrant,
},
{
name: "valid token request with plain code_challenge_method",
args: args{
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "plain"
req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
tokenReq: func() *logical.Request {
req := testTokenReq(s, "", clientID, clientSecret)
req.Data["code_verifier"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
},
},
{
name: "valid token request with default plain code_challenge_method",
args: args{
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
// code_challenge_method intentionally not provided
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
tokenReq: func() *logical.Request {
req := testTokenReq(s, "", clientID, clientSecret)
req.Data["code_verifier"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
},
},
{
name: "valid token request with S256 code_challenge_method",
args: args{
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "S256"
req.Data["code_challenge"] = "hMn-5TBH-t3uN00FEaGsQtYPhyC4Otbx-9vDcPTYHmc"
return req
}(),
tokenReq: func() *logical.Request {
req := testTokenReq(s, "", clientID, clientSecret)
req.Data["code_verifier"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
},
},
{
name: "valid token request with max_age and auth_time claim",
args: args{
@ -712,6 +844,58 @@ func TestOIDC_Path_OIDC_Authorize(t *testing.T) {
},
wantErr: ErrAuthInvalidRequest,
},
{
name: "invalid authorize request with invalid code_challenge_method",
args: args{
entityID: entityID,
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "S512"
req.Data["code_challenge"] = "43_char_min_abcdefghijklmnopqrstuvwxyzabcde"
return req
}(),
},
wantErr: ErrAuthInvalidRequest,
},
{
name: "invalid authorize request with code_challenge length < 43 characters",
args: args{
entityID: entityID,
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "S256"
req.Data["code_challenge"] = ""
return req
}(),
},
wantErr: ErrAuthInvalidRequest,
},
{
name: "invalid authorize request with code_challenge length > 128 characters",
args: args{
entityID: entityID,
clientReq: testClientReq(s),
providerReq: testProviderReq(s, clientID),
assignmentReq: testAssignmentReq(s, entityID, groupID),
authorizeReq: func() *logical.Request {
req := testAuthorizeReq(s, clientID)
req.Data["code_challenge_method"] = "S256"
req.Data["code_challenge"] = `
129_char_abcdefghijklmnopqrstuvwxyzabcd
129_char_abcdefghijklmnopqrstuvwxyzabcd
129_char_abcdefghijklmnopqrstuvwxyzabcd
`
return req
}(),
},
wantErr: ErrAuthInvalidRequest,
},
{
name: "valid authorize request with empty nonce",
args: args{
@ -1342,6 +1526,127 @@ func TestOIDC_Path_OIDC_ProviderReadPublicKey(t *testing.T) {
}
}
func TestOIDC_Path_OIDC_Client_Type(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil)
storage := &logical.InmemStorage{}
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/key/test-key",
Operation: logical.CreateOperation,
Storage: storage,
})
expectSuccess(t, resp, err)
tests := []struct {
name string
createClientType clientType
updateClientType clientType
wantCreateErr bool
wantUpdateErr bool
}{
{
name: "create confidential client and update to public client",
createClientType: confidential,
updateClientType: public,
wantUpdateErr: true,
},
{
name: "create confidential client and update to confidential client",
createClientType: confidential,
updateClientType: confidential,
},
{
name: "create public client and update to confidential client",
createClientType: public,
updateClientType: confidential,
wantUpdateErr: true,
},
{
name: "create public client and update to public client",
createClientType: public,
updateClientType: public,
},
{
name: "create an invalid client type",
createClientType: clientType(300),
wantCreateErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a client with the given client type
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client",
Operation: logical.CreateOperation,
Storage: storage,
Data: map[string]interface{}{
"key": "test-key",
"client_type": tt.createClientType.String(),
},
})
if tt.wantCreateErr {
expectError(t, resp, err)
return
}
expectSuccess(t, resp, err)
// Read the client
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client",
Operation: logical.ReadOperation,
Storage: storage,
})
expectSuccess(t, resp, err)
// Assert that the client type is properly set
clientType := resp.Data["client_type"].(string)
require.Equal(t, tt.createClientType.String(), clientType)
// Assert that all client types have a client ID
clientID := resp.Data["client_id"].(string)
require.Len(t, clientID, clientIDLength)
// Assert that confidential clients have a client secret
if tt.createClientType == confidential {
clientSecret := resp.Data["client_secret"].(string)
require.Contains(t, clientSecret, clientSecretPrefix)
}
// Assert that public clients do not have a client secret
if tt.createClientType == public {
_, ok := resp.Data["client_secret"]
require.False(t, ok)
}
// Update the client and expect error if the type is different
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"key": "test-key",
"client_type": tt.updateClientType.String(),
},
})
if tt.wantUpdateErr {
expectError(t, resp, err)
} else {
expectSuccess(t, resp, err)
}
// Delete the client
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client",
Operation: logical.DeleteOperation,
Storage: storage,
})
expectSuccess(t, resp, err)
})
}
}
// TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter tests that a client cannot
// be created without a key parameter
func TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter(t *testing.T) {
@ -1573,6 +1878,7 @@ func TestOIDC_Path_OIDC_ProviderClient(t *testing.T) {
"access_token_ttl": int64(86400),
"client_id": resp.Data["client_id"],
"client_secret": resp.Data["client_secret"],
"client_type": confidential.String(),
}
if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
@ -1627,6 +1933,7 @@ func TestOIDC_Path_OIDC_ProviderClient(t *testing.T) {
"access_token_ttl": int64(60),
"client_id": resp.Data["client_id"],
"client_secret": resp.Data["client_secret"],
"client_type": confidential.String(),
}
if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
@ -1686,6 +1993,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Deduplication(t *testing.T) {
"id_token_ttl": "1m",
"assignments": []string{"test-assignment1", "test-assignment1"},
"redirect_uris": []string{"http://example.com", "http://notduplicate.com", "http://example.com"},
"client_type": public.String(),
},
})
expectSuccess(t, resp, err)
@ -1704,7 +2012,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Deduplication(t *testing.T) {
"id_token_ttl": int64(60),
"access_token_ttl": int64(86400),
"client_id": resp.Data["client_id"],
"client_secret": resp.Data["client_secret"],
"client_type": public.String(),
}
if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
@ -1766,6 +2074,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Update(t *testing.T) {
"access_token_ttl": int64(3600),
"client_id": resp.Data["client_id"],
"client_secret": resp.Data["client_secret"],
"client_type": confidential.String(),
}
if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)
@ -1799,6 +2108,7 @@ func TestOIDC_Path_OIDC_ProviderClient_Update(t *testing.T) {
"access_token_ttl": int64(60),
"client_id": resp.Data["client_id"],
"client_secret": resp.Data["client_secret"],
"client_type": confidential.String(),
}
if diff := deep.Equal(expected, resp.Data); diff != nil {
t.Fatal(diff)

View File

@ -6,9 +6,11 @@ import (
"encoding/base64"
"fmt"
"hash"
"net/http"
"net/url"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/logical"
"gopkg.in/square/go-jose.v2"
)
@ -75,3 +77,31 @@ func computeHashClaim(alg string, input string) (string, error) {
sum := h.Sum(nil)
return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2]), nil
}
// computeCodeChallenge computes a Proof Key for Code Exchange (PKCE)
// code challenge given a code verifier and code challenge method.
func computeCodeChallenge(verifier string, method string) (string, error) {
switch method {
case codeChallengeMethodPlain:
return verifier, nil
case codeChallengeMethodS256:
hf := sha256.New()
hf.Write([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(hf.Sum(nil)), nil
default:
return "", fmt.Errorf("invalid code challenge method %q", method)
}
}
// authCodeUsedPKCE returns true if the given entry was granted using PKCE.
func authCodeUsedPKCE(entry *authCodeCacheEntry) bool {
return entry.codeChallenge != "" && entry.codeChallengeMethod != ""
}
// basicAuth returns the username/password provided in the logical.Request's
// authorization header and a bool indicating if the request used basic
// authentication.
func basicAuth(req *logical.Request) (string, string, bool) {
headerReq := &http.Request{Header: req.Headers}
return headerReq.BasicAuth()
}