acl: add auth method for JWTs (#7846)
This commit is contained in:
parent
24175e2925
commit
940e5ad160
|
@ -8,13 +8,18 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/acl"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/consul/sdk/freeport"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// NOTE: The tests contained herein are designed to test the HTTP API
|
||||
|
@ -1591,6 +1596,168 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestACLEndpoint_LoginLogout_jwt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := NewTestAgent(t, TestACLConfigWithParams(nil))
|
||||
defer a.Shutdown()
|
||||
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
// spin up a fake oidc server
|
||||
oidcServer := startSSOTestServer(t)
|
||||
pubKey, privKey := oidcServer.SigningKeys()
|
||||
|
||||
type mConfig = map[string]interface{}
|
||||
cases := map[string]struct {
|
||||
f func(config mConfig)
|
||||
issuer string
|
||||
expectErr string
|
||||
}{
|
||||
"success - jwt static keys": {func(config mConfig) {
|
||||
config["BoundIssuer"] = "https://legit.issuer.internal/"
|
||||
config["JWTValidationPubKeys"] = []string{pubKey}
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt jwks": {func(config mConfig) {
|
||||
config["JWKSURL"] = oidcServer.Addr() + "/certs"
|
||||
config["JWKSCACert"] = oidcServer.CACert()
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt oidc discovery": {func(config mConfig) {
|
||||
config["OIDCDiscoveryURL"] = oidcServer.Addr()
|
||||
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
|
||||
},
|
||||
oidcServer.Addr(),
|
||||
""},
|
||||
}
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
method, err := upsertTestCustomizedAuthMethod(a.RPC, TestDefaultMasterToken, "dc1", func(method *structs.ACLAuthMethod) {
|
||||
method.Type = "jwt"
|
||||
method.Config = map[string]interface{}{
|
||||
"JWTSupportedAlgs": []string{"ES256"},
|
||||
"ClaimMappings": map[string]string{
|
||||
"first_name": "name",
|
||||
"/org/primary": "primary_org",
|
||||
},
|
||||
"ListClaimMappings": map[string]string{
|
||||
"https://consul.test/groups": "groups",
|
||||
},
|
||||
"BoundAudiences": []string{"https://consul.test"},
|
||||
}
|
||||
if tc.f != nil {
|
||||
tc.f(method.Config)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("invalid bearer token", func(t *testing.T) {
|
||||
loginInput := &structs.ACLLoginParams{
|
||||
AuthMethod: method.Name,
|
||||
BearerToken: "invalid",
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("POST", "/v1/acl/login", jsonBody(loginInput))
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLLogin(resp, req)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Audience: jwt.Audience{"https://consul.test"},
|
||||
Issuer: tc.issuer,
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
type orgs struct {
|
||||
Primary string `json:"primary"`
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
FirstName string `json:"first_name"`
|
||||
Org orgs `json:"org"`
|
||||
Groups []string `json:"https://consul.test/groups"`
|
||||
}{
|
||||
FirstName: "jeff2",
|
||||
Org: orgs{"engineering"},
|
||||
Groups: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid bearer token no bindings", func(t *testing.T) {
|
||||
loginInput := &structs.ACLLoginParams{
|
||||
AuthMethod: method.Name,
|
||||
BearerToken: jwtData,
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("POST", "/v1/acl/login", jsonBody(loginInput))
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := a.srv.ACLLogin(resp, req)
|
||||
|
||||
testutil.RequireErrorContains(t, err, "Permission denied")
|
||||
})
|
||||
|
||||
_, err = upsertTestCustomizedBindingRule(a.RPC, TestDefaultMasterToken, "dc1", func(rule *structs.ACLBindingRule) {
|
||||
rule.AuthMethod = method.Name
|
||||
rule.BindType = structs.BindingRuleBindTypeService
|
||||
rule.BindName = "test--${value.name}--${value.primary_org}"
|
||||
rule.Selector = "value.name == jeff2 and value.primary_org == engineering and foo in list.groups"
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid bearer token 1 service binding", func(t *testing.T) {
|
||||
loginInput := &structs.ACLLoginParams{
|
||||
AuthMethod: method.Name,
|
||||
BearerToken: jwtData,
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("POST", "/v1/acl/login", jsonBody(loginInput))
|
||||
resp := httptest.NewRecorder()
|
||||
obj, err := a.srv.ACLLogin(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, ok := obj.(*structs.ACLToken)
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, method.Name, token.AuthMethod)
|
||||
require.Equal(t, `token created via login`, token.Description)
|
||||
require.True(t, token.Local)
|
||||
require.Len(t, token.Roles, 0)
|
||||
require.Len(t, token.ServiceIdentities, 1)
|
||||
svcid := token.ServiceIdentities[0]
|
||||
require.Len(t, svcid.Datacenters, 0)
|
||||
require.Equal(t, "test--jeff2--engineering", svcid.ServiceName)
|
||||
|
||||
// and delete it
|
||||
req, _ = http.NewRequest("GET", "/v1/acl/logout", nil)
|
||||
req.Header.Add("X-Consul-Token", token.SecretID)
|
||||
resp = httptest.NewRecorder()
|
||||
_, err = a.srv.ACLLogout(resp, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// verify the token was deleted
|
||||
req, _ = http.NewRequest("GET", "/v1/acl/token/"+token.AccessorID, nil)
|
||||
req.Header.Add("X-Consul-Token", TestDefaultMasterToken)
|
||||
resp = httptest.NewRecorder()
|
||||
|
||||
// make the request
|
||||
_, err = a.srv.ACLTokenCRUD(resp, req)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, acl.ErrNotFound, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestACL_Authorize(t *testing.T) {
|
||||
t.Parallel()
|
||||
a1 := NewTestAgent(t, TestACLConfigWithParams(nil))
|
||||
|
@ -2087,3 +2254,11 @@ func upsertTestCustomizedBindingRule(rpc rpcFn, masterToken string, datacenter s
|
|||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func startSSOTestServer(t *testing.T) *oidcauthtest.Server {
|
||||
ports := freeport.MustTake(1)
|
||||
return oidcauthtest.Start(t, oidcauthtest.WithPort(
|
||||
ports[0],
|
||||
func() { freeport.Return(ports) },
|
||||
))
|
||||
}
|
||||
|
|
|
@ -7,8 +7,9 @@ import (
|
|||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/go-bexpr"
|
||||
|
||||
// register this as a builtin auth method
|
||||
// register these as a builtin auth method
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
|
||||
_ "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
|
||||
)
|
||||
|
||||
type authMethodValidatorEntry struct {
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
//+build !consulent
|
||||
|
||||
package consul
|
||||
|
||||
func (s *Server) enterpriseEvaluateRoleBindings() error {
|
||||
return nil
|
||||
}
|
|
@ -1909,7 +1909,6 @@ func (a *ACL) BindingRuleSet(args *structs.ACLBindingRuleSetRequest, reply *stru
|
|||
if err != nil {
|
||||
return fmt.Errorf("Failed to apply binding rule upsert request: %v", err)
|
||||
}
|
||||
|
||||
if respErr, ok := resp.(error); ok {
|
||||
return fmt.Errorf("Failed to apply binding rule upsert request: %v", respErr)
|
||||
}
|
||||
|
@ -2049,6 +2048,10 @@ func (a *ACL) AuthMethodRead(args *structs.ACLAuthMethodGetRequest, reply *struc
|
|||
return err
|
||||
}
|
||||
|
||||
if method != nil {
|
||||
_ = a.enterpriseAuthMethodTypeValidation(method.Type)
|
||||
}
|
||||
|
||||
reply.Index, reply.AuthMethod = index, method
|
||||
return nil
|
||||
})
|
||||
|
@ -2093,6 +2096,10 @@ func (a *ACL) AuthMethodSet(args *structs.ACLAuthMethodSetRequest, reply *struct
|
|||
return fmt.Errorf("Invalid Auth Method: invalid Name. Only alphanumeric characters, '-' and '_' are allowed")
|
||||
}
|
||||
|
||||
if err := a.enterpriseAuthMethodTypeValidation(method.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check to see if the method exists first.
|
||||
_, existing, err := state.ACLAuthMethodGetByName(nil, method.Name, &method.EnterpriseMeta)
|
||||
if err != nil {
|
||||
|
@ -2193,6 +2200,10 @@ func (a *ACL) AuthMethodDelete(args *structs.ACLAuthMethodDeleteRequest, reply *
|
|||
return nil
|
||||
}
|
||||
|
||||
if err := a.enterpriseAuthMethodTypeValidation(method.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := structs.ACLAuthMethodBatchDeleteRequest{
|
||||
AuthMethodNames: []string{args.AuthMethodName},
|
||||
EnterpriseMeta: args.EnterpriseMeta,
|
||||
|
@ -2249,6 +2260,7 @@ func (a *ACL) AuthMethodList(args *structs.ACLAuthMethodListRequest, reply *stru
|
|||
|
||||
var stubs structs.ACLAuthMethodListStubs
|
||||
for _, method := range methods {
|
||||
_ = a.enterpriseAuthMethodTypeValidation(method.Type)
|
||||
stubs = append(stubs, method.Stub())
|
||||
}
|
||||
|
||||
|
@ -2294,6 +2306,10 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
|
|||
return acl.ErrNotFound
|
||||
}
|
||||
|
||||
if err := a.enterpriseAuthMethodTypeValidation(method.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validator, err := a.srv.loadAuthMethodValidator(idx, method)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2305,6 +2321,31 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
|
|||
return err
|
||||
}
|
||||
|
||||
return a.tokenSetFromAuthMethod(
|
||||
method,
|
||||
&auth.EnterpriseMeta,
|
||||
"token created via login",
|
||||
auth.Meta,
|
||||
validator,
|
||||
verifiedIdentity,
|
||||
&structs.ACLTokenSetRequest{
|
||||
Datacenter: args.Datacenter,
|
||||
WriteRequest: args.WriteRequest,
|
||||
},
|
||||
reply,
|
||||
)
|
||||
}
|
||||
|
||||
func (a *ACL) tokenSetFromAuthMethod(
|
||||
method *structs.ACLAuthMethod,
|
||||
entMeta *structs.EnterpriseMeta,
|
||||
tokenDescriptionPrefix string,
|
||||
tokenMetadata map[string]string,
|
||||
validator authmethod.Validator,
|
||||
verifiedIdentity *authmethod.Identity,
|
||||
createReq *structs.ACLTokenSetRequest, // this should be prepopulated with datacenter+writerequest
|
||||
reply *structs.ACLToken,
|
||||
) error {
|
||||
// This always will return a valid pointer
|
||||
targetMeta, err := computeTargetEnterpriseMeta(method, verifiedIdentity)
|
||||
if err != nil {
|
||||
|
@ -2312,7 +2353,7 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
|
|||
}
|
||||
|
||||
// 3. send map through role bindings
|
||||
serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedIdentity, &auth.EnterpriseMeta, targetMeta)
|
||||
serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedIdentity, entMeta, targetMeta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2323,8 +2364,10 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
|
|||
return acl.ErrPermissionDenied
|
||||
}
|
||||
|
||||
description := "token created via login"
|
||||
loginMeta, err := encodeLoginMeta(auth.Meta)
|
||||
// TODO(sso): add a CapturedField to ACLAuthMethod that would pluck fields from the returned identity and stuff into `auth.Meta`.
|
||||
|
||||
description := tokenDescriptionPrefix
|
||||
loginMeta, err := encodeLoginMeta(tokenMetadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2333,24 +2376,20 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
|
|||
}
|
||||
|
||||
// 4. create token
|
||||
createReq := structs.ACLTokenSetRequest{
|
||||
Datacenter: args.Datacenter,
|
||||
ACLToken: structs.ACLToken{
|
||||
Description: description,
|
||||
Local: true,
|
||||
AuthMethod: auth.AuthMethod,
|
||||
ServiceIdentities: serviceIdentities,
|
||||
Roles: roleLinks,
|
||||
ExpirationTTL: method.MaxTokenTTL,
|
||||
EnterpriseMeta: *targetMeta,
|
||||
},
|
||||
WriteRequest: args.WriteRequest,
|
||||
createReq.ACLToken = structs.ACLToken{
|
||||
Description: description,
|
||||
Local: true,
|
||||
AuthMethod: method.Name,
|
||||
ServiceIdentities: serviceIdentities,
|
||||
Roles: roleLinks,
|
||||
ExpirationTTL: method.MaxTokenTTL,
|
||||
EnterpriseMeta: *targetMeta,
|
||||
}
|
||||
|
||||
createReq.ACLToken.ACLAuthMethodEnterpriseMeta.FillWithEnterpriseMeta(&auth.EnterpriseMeta)
|
||||
createReq.ACLToken.ACLAuthMethodEnterpriseMeta.FillWithEnterpriseMeta(entMeta)
|
||||
|
||||
// 5. return token information like a TokenCreate would
|
||||
err = a.tokenSetInternal(&createReq, reply, true)
|
||||
err = a.tokenSetInternal(createReq, reply, true)
|
||||
|
||||
// If we were in a slight race with a role delete operation then we may
|
||||
// still end up failing to insert an unprivileged token in the state
|
||||
|
|
|
@ -7,6 +7,10 @@ import (
|
|||
"github.com/hashicorp/consul/agent/structs"
|
||||
)
|
||||
|
||||
func (a *ACL) enterpriseAuthMethodTypeValidation(authMethodType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enterpriseAuthMethodValidation(method *structs.ACLAuthMethod, validator authmethod.Validator) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -16,13 +16,16 @@ import (
|
|||
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
tokenStore "github.com/hashicorp/consul/agent/token"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/consul/lib"
|
||||
"github.com/hashicorp/consul/sdk/freeport"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/sdk/testutil/retry"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestACLEndpoint_Bootstrap(t *testing.T) {
|
||||
|
@ -5233,6 +5236,167 @@ func TestACLEndpoint_Login_k8s(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestACLEndpoint_Login_jwt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir1, s1 := testServerWithConfig(t, func(c *Config) {
|
||||
c.ACLDatacenter = "dc1"
|
||||
c.ACLsEnabled = true
|
||||
c.ACLMasterToken = "root"
|
||||
})
|
||||
defer os.RemoveAll(dir1)
|
||||
defer s1.Shutdown()
|
||||
codec := rpcClient(t, s1)
|
||||
defer codec.Close()
|
||||
|
||||
testrpc.WaitForLeader(t, s1.RPC, "dc1")
|
||||
|
||||
acl := ACL{srv: s1}
|
||||
|
||||
// spin up a fake oidc server
|
||||
oidcServer := startSSOTestServer(t)
|
||||
pubKey, privKey := oidcServer.SigningKeys()
|
||||
|
||||
type mConfig = map[string]interface{}
|
||||
cases := map[string]struct {
|
||||
f func(config mConfig)
|
||||
issuer string
|
||||
expectErr string
|
||||
}{
|
||||
"success - jwt static keys": {func(config mConfig) {
|
||||
config["BoundIssuer"] = "https://legit.issuer.internal/"
|
||||
config["JWTValidationPubKeys"] = []string{pubKey}
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt jwks": {func(config mConfig) {
|
||||
config["JWKSURL"] = oidcServer.Addr() + "/certs"
|
||||
config["JWKSCACert"] = oidcServer.CACert()
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt oidc discovery": {func(config mConfig) {
|
||||
config["OIDCDiscoveryURL"] = oidcServer.Addr()
|
||||
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
|
||||
},
|
||||
oidcServer.Addr(),
|
||||
""},
|
||||
}
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
method, err := upsertTestCustomizedAuthMethod(codec, "root", "dc1", func(method *structs.ACLAuthMethod) {
|
||||
method.Type = "jwt"
|
||||
method.Config = map[string]interface{}{
|
||||
"JWTSupportedAlgs": []string{"ES256"},
|
||||
"ClaimMappings": map[string]string{
|
||||
"first_name": "name",
|
||||
"/org/primary": "primary_org",
|
||||
},
|
||||
"ListClaimMappings": map[string]string{
|
||||
"https://consul.test/groups": "groups",
|
||||
},
|
||||
"BoundAudiences": []string{"https://consul.test"},
|
||||
}
|
||||
if tc.f != nil {
|
||||
tc.f(method.Config)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("invalid bearer token", func(t *testing.T) {
|
||||
req := structs.ACLLoginRequest{
|
||||
Auth: &structs.ACLLoginParams{
|
||||
AuthMethod: method.Name,
|
||||
BearerToken: "invalid",
|
||||
},
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
resp := structs.ACLToken{}
|
||||
|
||||
require.Error(t, acl.Login(&req, &resp))
|
||||
})
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Audience: jwt.Audience{"https://consul.test"},
|
||||
Issuer: tc.issuer,
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
type orgs struct {
|
||||
Primary string `json:"primary"`
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
FirstName string `json:"first_name"`
|
||||
Org orgs `json:"org"`
|
||||
Groups []string `json:"https://consul.test/groups"`
|
||||
}{
|
||||
FirstName: "jeff2",
|
||||
Org: orgs{"engineering"},
|
||||
Groups: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid bearer token no bindings", func(t *testing.T) {
|
||||
req := structs.ACLLoginRequest{
|
||||
Auth: &structs.ACLLoginParams{
|
||||
AuthMethod: method.Name,
|
||||
BearerToken: jwtData,
|
||||
},
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
resp := structs.ACLToken{}
|
||||
|
||||
testutil.RequireErrorContains(t, acl.Login(&req, &resp), "Permission denied")
|
||||
})
|
||||
|
||||
_, err = upsertTestBindingRule(
|
||||
codec, "root", "dc1", method.Name,
|
||||
"value.name == jeff2 and value.primary_org == engineering and foo in list.groups",
|
||||
structs.BindingRuleBindTypeService,
|
||||
"test--${value.name}--${value.primary_org}",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid bearer token 1 service binding", func(t *testing.T) {
|
||||
req := structs.ACLLoginRequest{
|
||||
Auth: &structs.ACLLoginParams{
|
||||
AuthMethod: method.Name,
|
||||
BearerToken: jwtData,
|
||||
},
|
||||
Datacenter: "dc1",
|
||||
}
|
||||
resp := structs.ACLToken{}
|
||||
|
||||
require.NoError(t, acl.Login(&req, &resp))
|
||||
|
||||
require.Equal(t, method.Name, resp.AuthMethod)
|
||||
require.Equal(t, `token created via login`, resp.Description)
|
||||
require.True(t, resp.Local)
|
||||
require.Len(t, resp.Roles, 0)
|
||||
require.Len(t, resp.ServiceIdentities, 1)
|
||||
svcid := resp.ServiceIdentities[0]
|
||||
require.Len(t, svcid.Datacenters, 0)
|
||||
require.Equal(t, "test--jeff2--engineering", svcid.ServiceName)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func startSSOTestServer(t *testing.T) *oidcauthtest.Server {
|
||||
ports := freeport.MustTake(1)
|
||||
return oidcauthtest.Start(t, oidcauthtest.WithPort(
|
||||
ports[0],
|
||||
func() { freeport.Return(ports) },
|
||||
))
|
||||
}
|
||||
|
||||
func TestACLEndpoint_Logout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -94,6 +94,10 @@ func NewValidator(method *structs.ACLAuthMethod) (*Validator, error) {
|
|||
return nil, fmt.Errorf("Config.ServiceAccountJWT is not a valid JWT: %v", err)
|
||||
}
|
||||
|
||||
if err := enterpriseValidation(method, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transport := cleanhttp.DefaultTransport()
|
||||
client, err := k8s.NewForConfig(&client_rest.Config{
|
||||
Host: config.Host,
|
||||
|
|
|
@ -0,0 +1,165 @@
|
|||
package ssoauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
func init() {
|
||||
authmethod.Register("jwt", func(logger hclog.Logger, method *structs.ACLAuthMethod) (authmethod.Validator, error) {
|
||||
v, err := NewValidator(logger, method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Validator is the wrapper around the go-sso library that also conforms to the
|
||||
// authmethod.Validator interface.
|
||||
type Validator struct {
|
||||
name string
|
||||
methodType string
|
||||
config *oidcauth.Config
|
||||
logger hclog.Logger
|
||||
oa *oidcauth.Authenticator
|
||||
}
|
||||
|
||||
var _ authmethod.Validator = (*Validator)(nil)
|
||||
|
||||
func NewValidator(logger hclog.Logger, method *structs.ACLAuthMethod) (*Validator, error) {
|
||||
if err := validateType(method.Type); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := authmethod.ParseConfig(method.Config, &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ssoConfig := config.convertForLibrary(method.Type)
|
||||
|
||||
oa, err := oidcauth.New(ssoConfig, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v := &Validator{
|
||||
name: method.Name,
|
||||
methodType: method.Type,
|
||||
config: ssoConfig,
|
||||
logger: logger,
|
||||
oa: oa,
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Name implements authmethod.Validator.
|
||||
func (v *Validator) Name() string { return v.name }
|
||||
|
||||
// Stop implements authmethod.Validator.
|
||||
func (v *Validator) Stop() { v.oa.Stop() }
|
||||
|
||||
// ValidateLogin implements authmethod.Validator.
|
||||
func (v *Validator) ValidateLogin(ctx context.Context, loginToken string) (*authmethod.Identity, error) {
|
||||
c, err := v.oa.ClaimsFromJWT(ctx, loginToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return v.identityFromClaims(c), nil
|
||||
}
|
||||
|
||||
func (v *Validator) identityFromClaims(c *oidcauth.Claims) *authmethod.Identity {
|
||||
id := v.NewIdentity()
|
||||
id.SelectableFields = &fieldDetails{
|
||||
Values: c.Values,
|
||||
Lists: c.Lists,
|
||||
}
|
||||
for k, val := range c.Values {
|
||||
id.ProjectedVars["value."+k] = val
|
||||
}
|
||||
id.EnterpriseMeta = v.ssoEntMetaFromClaims(c)
|
||||
return id
|
||||
}
|
||||
|
||||
// NewIdentity implements authmethod.Validator.
|
||||
func (v *Validator) NewIdentity() *authmethod.Identity {
|
||||
// Populate selectable fields with empty values so emptystring filters
|
||||
// works. Populate projectable vars with empty values so HIL works.
|
||||
fd := &fieldDetails{
|
||||
Values: make(map[string]string),
|
||||
Lists: make(map[string][]string),
|
||||
}
|
||||
projectedVars := make(map[string]string)
|
||||
for _, k := range v.config.ClaimMappings {
|
||||
fd.Values[k] = ""
|
||||
projectedVars["value."+k] = ""
|
||||
}
|
||||
for _, k := range v.config.ListClaimMappings {
|
||||
fd.Lists[k] = nil
|
||||
}
|
||||
|
||||
return &authmethod.Identity{
|
||||
SelectableFields: fd,
|
||||
ProjectedVars: projectedVars,
|
||||
}
|
||||
}
|
||||
|
||||
type fieldDetails struct {
|
||||
Values map[string]string `bexpr:"value"`
|
||||
Lists map[string][]string `bexpr:"list"`
|
||||
}
|
||||
|
||||
// Config is the collection of all settings that pertain to doing OIDC-based
|
||||
// authentication and direct JWT-based authentication processes.
|
||||
type Config struct {
|
||||
// common for type=oidc and type=jwt
|
||||
JWTSupportedAlgs []string `json:",omitempty"`
|
||||
BoundAudiences []string `json:",omitempty"`
|
||||
ClaimMappings map[string]string `json:",omitempty"`
|
||||
ListClaimMappings map[string]string `json:",omitempty"`
|
||||
OIDCDiscoveryURL string `json:",omitempty"`
|
||||
OIDCDiscoveryCACert string `json:",omitempty"`
|
||||
|
||||
// just for type=jwt
|
||||
JWKSURL string `json:",omitempty"`
|
||||
JWKSCACert string `json:",omitempty"`
|
||||
JWTValidationPubKeys []string `json:",omitempty"`
|
||||
BoundIssuer string `json:",omitempty"`
|
||||
ExpirationLeeway time.Duration `json:",omitempty"`
|
||||
NotBeforeLeeway time.Duration `json:",omitempty"`
|
||||
ClockSkewLeeway time.Duration `json:",omitempty"`
|
||||
|
||||
enterpriseConfig `mapstructure:",squash"`
|
||||
}
|
||||
|
||||
func (c *Config) convertForLibrary(methodType string) *oidcauth.Config {
|
||||
ssoConfig := &oidcauth.Config{
|
||||
Type: methodType,
|
||||
|
||||
// common for type=oidc and type=jwt
|
||||
JWTSupportedAlgs: c.JWTSupportedAlgs,
|
||||
BoundAudiences: c.BoundAudiences,
|
||||
ClaimMappings: c.ClaimMappings,
|
||||
ListClaimMappings: c.ListClaimMappings,
|
||||
OIDCDiscoveryURL: c.OIDCDiscoveryURL,
|
||||
OIDCDiscoveryCACert: c.OIDCDiscoveryCACert,
|
||||
|
||||
// just for type=jwt
|
||||
JWKSURL: c.JWKSURL,
|
||||
JWKSCACert: c.JWKSCACert,
|
||||
JWTValidationPubKeys: c.JWTValidationPubKeys,
|
||||
BoundIssuer: c.BoundIssuer,
|
||||
ExpirationLeeway: c.ExpirationLeeway,
|
||||
NotBeforeLeeway: c.NotBeforeLeeway,
|
||||
ClockSkewLeeway: c.ClockSkewLeeway,
|
||||
}
|
||||
c.enterpriseConvertForLibrary(ssoConfig)
|
||||
return ssoConfig
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
//+build !consulent
|
||||
|
||||
package ssoauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth"
|
||||
)
|
||||
|
||||
func validateType(typ string) error {
|
||||
if typ != "jwt" {
|
||||
return fmt.Errorf("type should be %q", "jwt")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Validator) ssoEntMetaFromClaims(_ *oidcauth.Claims) *structs.EnterpriseMeta {
|
||||
return nil
|
||||
}
|
||||
|
||||
type enterpriseConfig struct{}
|
||||
|
||||
func (c *Config) enterpriseConvertForLibrary(_ *oidcauth.Config) {}
|
|
@ -0,0 +1,270 @@
|
|||
package ssoauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod"
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/consul/sdk/freeport"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestJWT_NewValidator(t *testing.T) {
|
||||
nullLogger := hclog.NewNullLogger()
|
||||
type AM = *structs.ACLAuthMethod
|
||||
|
||||
makeAuthMethod := func(typ string, f func(method AM)) *structs.ACLAuthMethod {
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "test-" + typ,
|
||||
Description: typ + " test",
|
||||
Type: typ,
|
||||
Config: map[string]interface{}{},
|
||||
}
|
||||
if f != nil {
|
||||
f(method)
|
||||
}
|
||||
return method
|
||||
}
|
||||
|
||||
oidcServer := startTestServer(t)
|
||||
|
||||
// Note that we won't test ALL of the available config variations here.
|
||||
// The go-sso library has exhaustive tests.
|
||||
for name, tc := range map[string]struct {
|
||||
method *structs.ACLAuthMethod
|
||||
expectErr string
|
||||
}{
|
||||
"wrong type": {makeAuthMethod("invalid", nil), `type should be`},
|
||||
"extra config": {makeAuthMethod("jwt", func(method AM) {
|
||||
method.Config["extra"] = "config"
|
||||
}), "has invalid keys"},
|
||||
"wrong type of key in config blob": {makeAuthMethod("jwt", func(method AM) {
|
||||
method.Config["JWKSURL"] = []int{12345}
|
||||
}), `'JWKSURL' expected type 'string', got unconvertible type '[]int'`},
|
||||
|
||||
"normal jwt - static keys": {makeAuthMethod("jwt", func(method AM) {
|
||||
method.Config["BoundIssuer"] = "https://legit.issuer.internal/"
|
||||
pubKey, _ := oidcServer.SigningKeys()
|
||||
method.Config["JWTValidationPubKeys"] = []string{pubKey}
|
||||
}), ""},
|
||||
"normal jwt - jwks": {makeAuthMethod("jwt", func(method AM) {
|
||||
method.Config["JWKSURL"] = oidcServer.Addr() + "/certs"
|
||||
method.Config["JWKSCACert"] = oidcServer.CACert()
|
||||
}), ""},
|
||||
"normal jwt - oidc discovery": {makeAuthMethod("jwt", func(method AM) {
|
||||
method.Config["OIDCDiscoveryURL"] = oidcServer.Addr()
|
||||
method.Config["OIDCDiscoveryCACert"] = oidcServer.CACert()
|
||||
}), ""},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
v, err := NewValidator(nullLogger, tc.method)
|
||||
if tc.expectErr != "" {
|
||||
testutil.RequireErrorContains(t, err, tc.expectErr)
|
||||
require.Nil(t, v)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, v)
|
||||
v.Stop()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWT_ValidateLogin(t *testing.T) {
|
||||
type mConfig = map[string]interface{}
|
||||
|
||||
setup := func(t *testing.T, f func(config mConfig)) *Validator {
|
||||
t.Helper()
|
||||
|
||||
config := map[string]interface{}{
|
||||
"JWTSupportedAlgs": []string{"ES256"},
|
||||
"ClaimMappings": map[string]string{
|
||||
"first_name": "name",
|
||||
"/org/primary": "primary_org",
|
||||
},
|
||||
"ListClaimMappings": map[string]string{
|
||||
"https://consul.test/groups": "groups",
|
||||
},
|
||||
"BoundAudiences": []string{"https://consul.test"},
|
||||
}
|
||||
if f != nil {
|
||||
f(config)
|
||||
}
|
||||
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "test-method",
|
||||
Type: "jwt",
|
||||
Config: config,
|
||||
}
|
||||
|
||||
nullLogger := hclog.NewNullLogger()
|
||||
v, err := NewValidator(nullLogger, method)
|
||||
require.NoError(t, err)
|
||||
return v
|
||||
}
|
||||
|
||||
oidcServer := startTestServer(t)
|
||||
pubKey, privKey := oidcServer.SigningKeys()
|
||||
|
||||
cases := map[string]struct {
|
||||
f func(config mConfig)
|
||||
issuer string
|
||||
expectErr string
|
||||
}{
|
||||
"success - jwt static keys": {func(config mConfig) {
|
||||
config["BoundIssuer"] = "https://legit.issuer.internal/"
|
||||
config["JWTValidationPubKeys"] = []string{pubKey}
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt jwks": {func(config mConfig) {
|
||||
config["JWKSURL"] = oidcServer.Addr() + "/certs"
|
||||
config["JWKSCACert"] = oidcServer.CACert()
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt oidc discovery": {func(config mConfig) {
|
||||
config["OIDCDiscoveryURL"] = oidcServer.Addr()
|
||||
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
|
||||
},
|
||||
oidcServer.Addr(),
|
||||
""},
|
||||
}
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
v := setup(t, tc.f)
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Audience: jwt.Audience{"https://consul.test"},
|
||||
Issuer: tc.issuer,
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
type orgs struct {
|
||||
Primary string `json:"primary"`
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
FirstName string `json:"first_name"`
|
||||
Org orgs `json:"org"`
|
||||
Groups []string `json:"https://consul.test/groups"`
|
||||
}{
|
||||
FirstName: "jeff2",
|
||||
Org: orgs{"engineering"},
|
||||
Groups: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
id, err := v.ValidateLogin(context.Background(), jwtData)
|
||||
if tc.expectErr != "" {
|
||||
testutil.RequireErrorContains(t, err, tc.expectErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
authmethod.RequireIdentityMatch(t, id, map[string]string{
|
||||
"value.name": "jeff2",
|
||||
"value.primary_org": "engineering",
|
||||
},
|
||||
"value.name == jeff2",
|
||||
"value.name != jeff",
|
||||
"value.primary_org == engineering",
|
||||
"foo in list.groups",
|
||||
"bar in list.groups",
|
||||
"salt not in list.groups",
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewIdentity(t *testing.T) {
|
||||
// This is only based on claim mappings, so we'll just use the JWT type
|
||||
// since that's cheaper to setup.
|
||||
cases := map[string]struct {
|
||||
claimMappings map[string]string
|
||||
listClaimMappings map[string]string
|
||||
expectVars map[string]string
|
||||
expectFilters []string
|
||||
}{
|
||||
"nil": {nil, nil, kv(), nil},
|
||||
"empty": {kv(), kv(), kv(), nil},
|
||||
"one value mapping": {
|
||||
kv("foo1", "val1"),
|
||||
kv(),
|
||||
kv("value.val1", ""),
|
||||
[]string{`value.val1 == ""`},
|
||||
},
|
||||
"one list mapping": {kv(),
|
||||
kv("foo2", "val2"),
|
||||
kv(),
|
||||
nil,
|
||||
},
|
||||
"one of each": {
|
||||
kv("foo1", "val1"),
|
||||
kv("foo2", "val2"),
|
||||
kv("value.val1", ""),
|
||||
[]string{`value.val1 == ""`},
|
||||
},
|
||||
"two value mappings": {
|
||||
kv("foo1", "val1", "foo2", "val2"),
|
||||
kv(),
|
||||
kv("value.val1", "", "value.val2", ""),
|
||||
[]string{`value.val1 == ""`, `value.val2 == ""`},
|
||||
},
|
||||
}
|
||||
pubKey, _ := oidcauthtest.SigningKeys()
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
method := &structs.ACLAuthMethod{
|
||||
Name: "test-method",
|
||||
Type: "jwt",
|
||||
Config: map[string]interface{}{
|
||||
"BoundIssuer": "https://legit.issuer.internal/",
|
||||
"JWTValidationPubKeys": []string{pubKey},
|
||||
"ClaimMappings": tc.claimMappings,
|
||||
"ListClaimMappings": tc.listClaimMappings,
|
||||
},
|
||||
}
|
||||
nullLogger := hclog.NewNullLogger()
|
||||
v, err := NewValidator(nullLogger, method)
|
||||
require.NoError(t, err)
|
||||
|
||||
id := v.NewIdentity()
|
||||
authmethod.RequireIdentityMatch(t, id, tc.expectVars, tc.expectFilters...)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func kv(a ...string) map[string]string {
|
||||
if len(a)%2 != 0 {
|
||||
panic("kv() requires even numbers of arguments")
|
||||
}
|
||||
m := make(map[string]string)
|
||||
for i := 0; i < len(a); i += 2 {
|
||||
m[a[i]] = a[i+1]
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T) *oidcauthtest.Server {
|
||||
ports := freeport.MustTake(1)
|
||||
return oidcauthtest.Start(t, oidcauthtest.WithPort(
|
||||
ports[0],
|
||||
func() { freeport.Return(ports) },
|
||||
))
|
||||
}
|
120
api/acl.go
120
api/acl.go
|
@ -305,12 +305,73 @@ func (c *KubernetesAuthMethodConfig) RenderToConfig() map[string]interface{} {
|
|||
}
|
||||
}
|
||||
|
||||
// OIDCAuthMethodConfig is the config for the built-in Consul auth method for
|
||||
// OIDC and JWT.
|
||||
type OIDCAuthMethodConfig struct {
|
||||
// common for type=oidc and type=jwt
|
||||
JWTSupportedAlgs []string `json:",omitempty"`
|
||||
BoundAudiences []string `json:",omitempty"`
|
||||
ClaimMappings map[string]string `json:",omitempty"`
|
||||
ListClaimMappings map[string]string `json:",omitempty"`
|
||||
OIDCDiscoveryURL string `json:",omitempty"`
|
||||
OIDCDiscoveryCACert string `json:",omitempty"`
|
||||
// just for type=oidc
|
||||
OIDCClientID string `json:",omitempty"`
|
||||
OIDCClientSecret string `json:",omitempty"`
|
||||
OIDCScopes []string `json:",omitempty"`
|
||||
AllowedRedirectURIs []string `json:",omitempty"`
|
||||
VerboseOIDCLogging bool `json:",omitempty"`
|
||||
// just for type=jwt
|
||||
JWKSURL string `json:",omitempty"`
|
||||
JWKSCACert string `json:",omitempty"`
|
||||
JWTValidationPubKeys []string `json:",omitempty"`
|
||||
BoundIssuer string `json:",omitempty"`
|
||||
ExpirationLeeway time.Duration `json:",omitempty"`
|
||||
NotBeforeLeeway time.Duration `json:",omitempty"`
|
||||
ClockSkewLeeway time.Duration `json:",omitempty"`
|
||||
}
|
||||
|
||||
// RenderToConfig converts this into a map[string]interface{} suitable for use
|
||||
// in the ACLAuthMethod.Config field.
|
||||
func (c *OIDCAuthMethodConfig) RenderToConfig() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
// common for type=oidc and type=jwt
|
||||
"JWTSupportedAlgs": c.JWTSupportedAlgs,
|
||||
"BoundAudiences": c.BoundAudiences,
|
||||
"ClaimMappings": c.ClaimMappings,
|
||||
"ListClaimMappings": c.ListClaimMappings,
|
||||
"OIDCDiscoveryURL": c.OIDCDiscoveryURL,
|
||||
"OIDCDiscoveryCACert": c.OIDCDiscoveryCACert,
|
||||
// just for type=oidc
|
||||
"OIDCClientID": c.OIDCClientID,
|
||||
"OIDCClientSecret": c.OIDCClientSecret,
|
||||
"OIDCScopes": c.OIDCScopes,
|
||||
"AllowedRedirectURIs": c.AllowedRedirectURIs,
|
||||
"VerboseOIDCLogging": c.VerboseOIDCLogging,
|
||||
// just for type=jwt
|
||||
"JWKSURL": c.JWKSURL,
|
||||
"JWKSCACert": c.JWKSCACert,
|
||||
"JWTValidationPubKeys": c.JWTValidationPubKeys,
|
||||
"BoundIssuer": c.BoundIssuer,
|
||||
"ExpirationLeeway": c.ExpirationLeeway,
|
||||
"NotBeforeLeeway": c.NotBeforeLeeway,
|
||||
"ClockSkewLeeway": c.ClockSkewLeeway,
|
||||
}
|
||||
}
|
||||
|
||||
type ACLLoginParams struct {
|
||||
AuthMethod string
|
||||
BearerToken string
|
||||
Meta map[string]string `json:",omitempty"`
|
||||
}
|
||||
|
||||
type ACLOIDCAuthURLParams struct {
|
||||
AuthMethod string
|
||||
RedirectURI string
|
||||
ClientNonce string
|
||||
Meta map[string]string `json:",omitempty"`
|
||||
}
|
||||
|
||||
// ACL can be used to query the ACL endpoints
|
||||
type ACL struct {
|
||||
c *Client
|
||||
|
@ -1227,3 +1288,62 @@ func (a *ACL) Logout(q *WriteOptions) (*WriteMeta, error) {
|
|||
wm := &WriteMeta{RequestTime: rtt}
|
||||
return wm, nil
|
||||
}
|
||||
|
||||
// OIDCAuthURL requests an authorization URL to start an OIDC login flow.
|
||||
func (a *ACL) OIDCAuthURL(auth *ACLOIDCAuthURLParams, q *WriteOptions) (string, *WriteMeta, error) {
|
||||
if auth.AuthMethod == "" {
|
||||
return "", nil, fmt.Errorf("Must specify an auth method name")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("POST", "/v1/acl/oidc/auth-url")
|
||||
r.setWriteOptions(q)
|
||||
r.obj = auth
|
||||
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
var out aclOIDCAuthURLResponse
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return out.AuthURL, wm, nil
|
||||
}
|
||||
|
||||
type aclOIDCAuthURLResponse struct {
|
||||
AuthURL string
|
||||
}
|
||||
|
||||
type ACLOIDCCallbackParams struct {
|
||||
AuthMethod string
|
||||
State string
|
||||
Code string
|
||||
ClientNonce string
|
||||
}
|
||||
|
||||
// OIDCCallback is the callback endpoint to complete an OIDC login.
|
||||
func (a *ACL) OIDCCallback(auth *ACLOIDCCallbackParams, q *WriteOptions) (*ACLToken, *WriteMeta, error) {
|
||||
if auth.AuthMethod == "" {
|
||||
return nil, nil, fmt.Errorf("Must specify an auth method name")
|
||||
}
|
||||
|
||||
r := a.c.newRequest("POST", "/v1/acl/oidc/callback")
|
||||
r.setWriteOptions(q)
|
||||
r.obj = auth
|
||||
|
||||
rtt, resp, err := requireOK(a.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
wm := &WriteMeta{RequestTime: rtt}
|
||||
var out ACLToken
|
||||
if err := decodeBody(resp, &out); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &out, wm, nil
|
||||
}
|
||||
|
|
|
@ -29,10 +29,14 @@ type cmd struct {
|
|||
bearerToken string
|
||||
|
||||
// flags
|
||||
authMethodName string
|
||||
authMethodName string
|
||||
authMethodType string
|
||||
|
||||
bearerTokenFile string
|
||||
tokenSinkFile string
|
||||
meta map[string]string
|
||||
|
||||
enterpriseCmd
|
||||
}
|
||||
|
||||
func (c *cmd) init() {
|
||||
|
@ -41,6 +45,9 @@ func (c *cmd) init() {
|
|||
c.flags.StringVar(&c.authMethodName, "method", "",
|
||||
"Name of the auth method to login to.")
|
||||
|
||||
c.flags.StringVar(&c.authMethodType, "type", "",
|
||||
"Type of the auth method to login to. This field is optional and defaults to no type.")
|
||||
|
||||
c.flags.StringVar(&c.bearerTokenFile, "bearer-token-file", "",
|
||||
"Path to a file containing a secret bearer token to use with this auth method.")
|
||||
|
||||
|
@ -51,6 +58,8 @@ func (c *cmd) init() {
|
|||
"Metadata to set on the token, formatted as key=value. This flag "+
|
||||
"may be specified multiple times to set multiple meta fields.")
|
||||
|
||||
c.initEnterpriseFlags()
|
||||
|
||||
c.http = &flags.HTTPFlags{}
|
||||
flags.Merge(c.flags, c.http.ClientFlags())
|
||||
flags.Merge(c.flags, c.http.ServerFlags())
|
||||
|
@ -76,6 +85,10 @@ func (c *cmd) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
return c.login()
|
||||
}
|
||||
|
||||
func (c *cmd) bearerTokenLogin() int {
|
||||
if c.bearerTokenFile == "" {
|
||||
c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag"))
|
||||
return 1
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
//+build !consulent
|
||||
|
||||
package login
|
||||
|
||||
type enterpriseCmd struct {
|
||||
}
|
||||
|
||||
func (c *cmd) initEnterpriseFlags() {
|
||||
}
|
||||
|
||||
func (c *cmd) login() int {
|
||||
return c.bearerTokenLogin()
|
||||
}
|
|
@ -6,16 +6,20 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
|
||||
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/consul/command/acl"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/consul/sdk/freeport"
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/consul/testrpc"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestLoginCommand_noTabs(t *testing.T) {
|
||||
|
@ -314,3 +318,154 @@ func TestLoginCommand_k8s(t *testing.T) {
|
|||
require.Len(t, token, 36, "must be a valid uid: %s", token)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginCommand_jwt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testDir := testutil.TempDir(t, "acl")
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
a := agent.NewTestAgent(t, `
|
||||
primary_datacenter = "dc1"
|
||||
acl {
|
||||
enabled = true
|
||||
tokens {
|
||||
master = "root"
|
||||
}
|
||||
}`)
|
||||
|
||||
defer a.Shutdown()
|
||||
testrpc.WaitForLeader(t, a.RPC, "dc1")
|
||||
|
||||
client := a.Client()
|
||||
|
||||
tokenSinkFile := filepath.Join(testDir, "test.token")
|
||||
bearerTokenFile := filepath.Join(testDir, "bearer.token")
|
||||
|
||||
// spin up a fake oidc server
|
||||
oidcServer := startSSOTestServer(t)
|
||||
pubKey, privKey := oidcServer.SigningKeys()
|
||||
|
||||
type mConfig = map[string]interface{}
|
||||
cases := map[string]struct {
|
||||
f func(config mConfig)
|
||||
issuer string
|
||||
expectErr string
|
||||
}{
|
||||
"success - jwt static keys": {func(config mConfig) {
|
||||
config["BoundIssuer"] = "https://legit.issuer.internal/"
|
||||
config["JWTValidationPubKeys"] = []string{pubKey}
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt jwks": {func(config mConfig) {
|
||||
config["JWKSURL"] = oidcServer.Addr() + "/certs"
|
||||
config["JWKSCACert"] = oidcServer.CACert()
|
||||
},
|
||||
"https://legit.issuer.internal/",
|
||||
""},
|
||||
"success - jwt oidc discovery": {func(config mConfig) {
|
||||
config["OIDCDiscoveryURL"] = oidcServer.Addr()
|
||||
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
|
||||
},
|
||||
oidcServer.Addr(),
|
||||
""},
|
||||
}
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
method := &api.ACLAuthMethod{
|
||||
Name: "jwt",
|
||||
Type: "jwt",
|
||||
Config: map[string]interface{}{
|
||||
"JWTSupportedAlgs": []string{"ES256"},
|
||||
"ClaimMappings": map[string]string{
|
||||
"first_name": "name",
|
||||
"/org/primary": "primary_org",
|
||||
},
|
||||
"ListClaimMappings": map[string]string{
|
||||
"https://consul.test/groups": "groups",
|
||||
},
|
||||
"BoundAudiences": []string{"https://consul.test"},
|
||||
},
|
||||
}
|
||||
if tc.f != nil {
|
||||
tc.f(method.Config)
|
||||
}
|
||||
_, _, err := client.ACL().AuthMethodCreate(
|
||||
method,
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = client.ACL().BindingRuleCreate(&api.ACLBindingRule{
|
||||
AuthMethod: "jwt",
|
||||
BindType: api.BindingRuleBindTypeService,
|
||||
BindName: "test--${value.name}--${value.primary_org}",
|
||||
Selector: "value.name == jeff2 and value.primary_org == engineering and foo in list.groups",
|
||||
},
|
||||
&api.WriteOptions{Token: "root"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Audience: jwt.Audience{"https://consul.test"},
|
||||
Issuer: tc.issuer,
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
type orgs struct {
|
||||
Primary string `json:"primary"`
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
FirstName string `json:"first_name"`
|
||||
Org orgs `json:"org"`
|
||||
Groups []string `json:"https://consul.test/groups"`
|
||||
}{
|
||||
FirstName: "jeff2",
|
||||
Org: orgs{"engineering"},
|
||||
Groups: []string{"foo", "bar"},
|
||||
}
|
||||
|
||||
// Drop a JWT on disk.
|
||||
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(jwtData), 0600))
|
||||
|
||||
defer os.Remove(tokenSinkFile)
|
||||
ui := cli.NewMockUi()
|
||||
cmd := New(ui)
|
||||
|
||||
args := []string{
|
||||
"-http-addr=" + a.HTTPAddr(),
|
||||
"-token=root",
|
||||
"-method=jwt",
|
||||
"-token-sink-file", tokenSinkFile,
|
||||
"-bearer-token-file", bearerTokenFile,
|
||||
}
|
||||
|
||||
code := cmd.Run(args)
|
||||
require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.ErrorWriter.String())
|
||||
require.Empty(t, ui.OutputWriter.String())
|
||||
|
||||
raw, err := ioutil.ReadFile(tokenSinkFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := strings.TrimSpace(string(raw))
|
||||
require.Len(t, token, 36, "must be a valid uid: %s", token)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func startSSOTestServer(t *testing.T) *oidcauthtest.Server {
|
||||
ports := freeport.MustTake(1)
|
||||
return oidcauthtest.Start(t, oidcauthtest.WithPort(
|
||||
ports[0],
|
||||
func() { freeport.Return(ports) },
|
||||
))
|
||||
}
|
||||
|
|
9
go.mod
9
go.mod
|
@ -17,6 +17,7 @@ require (
|
|||
github.com/armon/go-radix v1.0.0
|
||||
github.com/aws/aws-sdk-go v1.25.41
|
||||
github.com/coredns/coredns v1.1.2
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible
|
||||
github.com/digitalocean/godo v1.10.0 // indirect
|
||||
github.com/docker/go-connections v0.3.0
|
||||
github.com/elazarl/go-bindata-assetfs v0.0.0-20160803192304-e1a2a7ec64b0
|
||||
|
@ -43,7 +44,7 @@ require (
|
|||
github.com/hashicorp/go-raftchunking v0.6.1
|
||||
github.com/hashicorp/go-sockaddr v1.0.2
|
||||
github.com/hashicorp/go-syslog v1.0.0
|
||||
github.com/hashicorp/go-uuid v1.0.1
|
||||
github.com/hashicorp/go-uuid v1.0.2
|
||||
github.com/hashicorp/go-version v1.2.0
|
||||
github.com/hashicorp/golang-lru v0.5.1
|
||||
github.com/hashicorp/hcl v1.0.0
|
||||
|
@ -65,9 +66,12 @@ require (
|
|||
github.com/mitchellh/go-testing-interface v1.14.0
|
||||
github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452
|
||||
github.com/mitchellh/mapstructure v1.2.3
|
||||
github.com/mitchellh/pointerstructure v1.0.0
|
||||
github.com/mitchellh/reflectwalk v1.0.1
|
||||
github.com/pascaldekloe/goe v0.1.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/prometheus/client_golang v1.0.0
|
||||
github.com/rboyer/safeio v0.2.1
|
||||
github.com/ryanuber/columnize v2.1.0+incompatible
|
||||
|
@ -78,6 +82,7 @@ require (
|
|||
go.opencensus.io v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
|
||||
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58
|
||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
|
||||
|
@ -85,7 +90,7 @@ require (
|
|||
google.golang.org/appengine v1.6.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20190530194941-fb225487d101 // indirect
|
||||
google.golang.org/grpc v1.23.0
|
||||
gopkg.in/square/go-jose.v2 v2.3.1
|
||||
gopkg.in/square/go-jose.v2 v2.4.1
|
||||
k8s.io/api v0.16.9
|
||||
k8s.io/apimachinery v0.16.9
|
||||
k8s.io/client-go v0.16.9
|
||||
|
|
12
go.sum
12
go.sum
|
@ -79,6 +79,8 @@ github.com/coredns/coredns v1.1.2/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/
|
|||
github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
|
||||
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
|
||||
|
@ -230,6 +232,8 @@ github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdv
|
|||
github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.1 h1:fv1ep09latC32wFoVwnqcnKJGnMSdBanPczbHAYm1BE=
|
||||
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-version v1.1.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hashicorp/go-version v1.2.0 h1:3vNe/fWF5CBgRIguda1meWhsZHy3m8gCJ5wx+dIzX/E=
|
||||
github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
|
@ -343,6 +347,8 @@ github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQz
|
|||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/mitchellh/mapstructure v1.2.3 h1:f/MjBEBDLttYCGfRaKBbKSRVF5aV2O6fnBpzknuE3jU=
|
||||
github.com/mitchellh/mapstructure v1.2.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/mitchellh/pointerstructure v1.0.0 h1:ATSdz4NWrmWPOF1CeCBU4sMCno2hgqdbSrRPFWQSVZI=
|
||||
github.com/mitchellh/pointerstructure v1.0.0/go.mod h1:k4XwG94++jLVsSiTxo7qdIfXA9pj9EAeo0QsNNJOLZ8=
|
||||
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||
github.com/mitchellh/reflectwalk v1.0.1 h1:FVzMWA5RllMAKIdUSC8mdWo3XtwoecrH79BY70sEEpE=
|
||||
github.com/mitchellh/reflectwalk v1.0.1/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||
|
@ -371,6 +377,8 @@ github.com/packethost/packngo v0.1.1-0.20180711074735-b9cb5096f54c/go.mod h1:otz
|
|||
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
|
||||
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI=
|
||||
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
|
||||
|
@ -386,6 +394,8 @@ github.com/posener/complete v1.1.1 h1:ccV59UEOTzVDnDUEFdT95ZzHVZ+5+158q8+SJb2QV5
|
|||
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
|
||||
github.com/posener/complete v1.2.3 h1:NP0eAhjcjImqslEwo/1hq7gpajME0fTLTezBKDqfXqo=
|
||||
github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||
github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM=
|
||||
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
|
||||
|
@ -610,6 +620,8 @@ gopkg.in/resty.v1 v1.12.0 h1:CuXP0Pjfw9rOuY6EP+UvtNvt5DSqHpIxILZKT/quCZI=
|
|||
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/square/go-jose.v2 v2.4.1 h1:H0TmLt7/KmzlrDOpa1F+zr0Tk90PbJYBfsVUmRLrf9Y=
|
||||
gopkg.in/square/go-jose.v2 v2.4.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
# go-sso
|
||||
|
||||
This is a Go library that is being incubated in Consul to assist in doing
|
||||
opinionated OIDC-based single sign on.
|
||||
|
||||
The `go.mod.sample` and `go.sum.sample` files are what the overall real
|
||||
`go.mod` and `go.sum` files should end up being when extracted from the Consul
|
||||
codebase.
|
|
@ -0,0 +1,18 @@
|
|||
module github.com/hashicorp/consul/go-sso
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible
|
||||
github.com/hashicorp/go-cleanhttp v0.5.1
|
||||
github.com/hashicorp/go-hclog v0.12.0
|
||||
github.com/hashicorp/go-uuid v1.0.2
|
||||
github.com/mitchellh/go-testing-interface v1.14.0
|
||||
github.com/mitchellh/pointerstructure v1.0.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/stretchr/testify v1.2.2
|
||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
|
||||
gopkg.in/square/go-jose.v2 v2.4.1
|
||||
)
|
|
@ -0,0 +1,56 @@
|
|||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
|
||||
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
|
||||
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
|
||||
github.com/hashicorp/go-hclog v0.12.0 h1:d4QkX8FRTYaKaCZBoXYY8zJX2BXjWxurN/GA2tkrmZM=
|
||||
github.com/hashicorp/go-hclog v0.12.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
|
||||
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
|
||||
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10=
|
||||
github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84=
|
||||
github.com/mitchellh/go-testing-interface v1.14.0 h1:/x0XQ6h+3U3nAyk1yx+bHPURrKa9sVVvYbuqZ7pIAtI=
|
||||
github.com/mitchellh/go-testing-interface v1.14.0/go.mod h1:gfgS7OtZj6MA4U1UrDRp04twqAjfvlZyCfX3sDjEym8=
|
||||
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/mitchellh/pointerstructure v1.0.0 h1:ATSdz4NWrmWPOF1CeCBU4sMCno2hgqdbSrRPFWQSVZI=
|
||||
github.com/mitchellh/pointerstructure v1.0.0/go.mod h1:k4XwG94++jLVsSiTxo7qdIfXA9pj9EAeo0QsNNJOLZ8=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY=
|
||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191008105621-543471e840be h1:QAcqgptGM8IQBC9K/RC4o+O9YmqEm0diQn9QmZw/0mU=
|
||||
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
gopkg.in/square/go-jose.v2 v2.4.1 h1:H0TmLt7/KmzlrDOpa1F+zr0Tk90PbJYBfsVUmRLrf9Y=
|
||||
gopkg.in/square/go-jose.v2 v2.4.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
|
@ -0,0 +1,129 @@
|
|||
// package oidcauth bundles up an opinionated approach to authentication using
|
||||
// both the OIDC authorization code workflow and simple JWT decoding (via
|
||||
// static keys, JWKS, and OIDC discovery).
|
||||
//
|
||||
// NOTE: This was roughly forked from hashicorp/vault-plugin-auth-jwt
|
||||
// originally at commit 825c85535e3832d254a74253a8e9ae105357778b with later
|
||||
// backports of behavior in 0e93b06cecb0477d6ee004e44b04832d110096cf
|
||||
package oidcauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// Claims represents a set of claims or assertions computed about a given
|
||||
// authentication exchange.
|
||||
type Claims struct {
|
||||
// Values is a set of key/value string claims about the authentication
|
||||
// exchange.
|
||||
Values map[string]string
|
||||
|
||||
// Lists is a set of key/value string list claims about the authentication
|
||||
// exchange.
|
||||
Lists map[string][]string
|
||||
}
|
||||
|
||||
// Authenticator allows for extracting a set of claims from either an OIDC
|
||||
// authorization code exchange or a bare JWT.
|
||||
type Authenticator struct {
|
||||
config *Config
|
||||
logger hclog.Logger
|
||||
|
||||
// parsedJWTPubKeys is the parsed form of config.JWTValidationPubKeys
|
||||
parsedJWTPubKeys []interface{}
|
||||
provider *oidc.Provider
|
||||
keySet oidc.KeySet
|
||||
|
||||
// httpClient should be configured with all relevant root CA certs and be
|
||||
// reused for all OIDC or JWKS operations. This will be nil for the static
|
||||
// keys JWT configuration.
|
||||
httpClient *http.Client
|
||||
|
||||
l sync.Mutex
|
||||
oidcStates *cache.Cache
|
||||
|
||||
// backgroundCtx is a cancellable context primarily meant to be used for
|
||||
// things that may spawn background goroutines and are not tied to a
|
||||
// request/response lifecycle. Use backgroundCtxCancel to cancel this.
|
||||
backgroundCtx context.Context
|
||||
backgroundCtxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// New creates an authenticator suitable for use with either an OIDC
|
||||
// authorization code workflow or a bare JWT workflow depending upon the value
|
||||
// of the config Type.
|
||||
func New(c *Config, logger hclog.Logger) (*Authenticator, error) {
|
||||
if err := c.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var parsedJWTPubKeys []interface{}
|
||||
if c.Type == TypeJWT {
|
||||
for _, v := range c.JWTValidationPubKeys {
|
||||
key, err := parsePublicKeyPEM([]byte(v))
|
||||
if err != nil {
|
||||
// This shouldn't happen as the keys are already validated in Validate().
|
||||
return nil, fmt.Errorf("error parsing public key: %v", err)
|
||||
}
|
||||
parsedJWTPubKeys = append(parsedJWTPubKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
a := &Authenticator{
|
||||
config: c,
|
||||
logger: logger,
|
||||
parsedJWTPubKeys: parsedJWTPubKeys,
|
||||
}
|
||||
a.backgroundCtx, a.backgroundCtxCancel = context.WithCancel(context.Background())
|
||||
|
||||
if c.Type == TypeOIDC {
|
||||
a.oidcStates = cache.New(oidcStateTimeout, oidcStateCleanupInterval)
|
||||
}
|
||||
|
||||
var err error
|
||||
switch c.authType() {
|
||||
case authOIDCDiscovery, authOIDCFlow:
|
||||
a.httpClient, err = createHTTPClient(a.config.OIDCDiscoveryCACert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing OIDCDiscoveryCACert: %v", err)
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(
|
||||
contextWithHttpClient(a.backgroundCtx, a.httpClient),
|
||||
a.config.OIDCDiscoveryURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating provider: %v", err)
|
||||
}
|
||||
a.provider = provider
|
||||
case authJWKS:
|
||||
a.httpClient, err = createHTTPClient(a.config.JWKSCACert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing JWKSCACert: %v", err)
|
||||
}
|
||||
|
||||
a.keySet = oidc.NewRemoteKeySet(
|
||||
contextWithHttpClient(a.backgroundCtx, a.httpClient),
|
||||
a.config.JWKSURL,
|
||||
)
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Stop stops any background goroutines and does cleanup.
|
||||
func (a *Authenticator) Stop() {
|
||||
a.l.Lock()
|
||||
defer a.l.Unlock()
|
||||
if a.backgroundCtxCancel != nil {
|
||||
a.backgroundCtxCancel()
|
||||
a.backgroundCtxCancel = nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,348 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
// TypeOIDC is the config type to specify if the OIDC authorization code
|
||||
// workflow is desired. The Authenticator methods GetAuthCodeURL and
|
||||
// ClaimsFromAuthCode are activated with the type.
|
||||
TypeOIDC = "oidc"
|
||||
|
||||
// TypeJWT is the config type to specify if simple JWT decoding (via static
|
||||
// keys, JWKS, and OIDC discovery) is desired. The Authenticator method
|
||||
// ClaimsFromJWT is activated with this type.
|
||||
TypeJWT = "jwt"
|
||||
)
|
||||
|
||||
// Config is the collection of all settings that pertain to doing OIDC-based
|
||||
// authentication and direct JWT-based authentication processes.
|
||||
type Config struct {
|
||||
// Type defines which kind of authentication will be happening, OIDC-based
|
||||
// or JWT-based. Allowed values are either 'oidc' or 'jwt'.
|
||||
//
|
||||
// Defaults to 'oidc' if unset.
|
||||
Type string
|
||||
|
||||
// -------
|
||||
// common for type=oidc and type=jwt
|
||||
// -------
|
||||
|
||||
// JWTSupportedAlgs is a list of supported signing algorithms. Defaults to
|
||||
// RS256.
|
||||
JWTSupportedAlgs []string
|
||||
|
||||
// Comma-separated list of 'aud' claims that are valid for login; any match
|
||||
// is sufficient
|
||||
// TODO(sso): actually just send these down as string claims?
|
||||
BoundAudiences []string
|
||||
|
||||
// Mappings of claims (key) that will be copied to a metadata field
|
||||
// (value). Use this if the claim you are capturing is singular (such as an
|
||||
// attribute).
|
||||
//
|
||||
// When mapped, the values can be any of a number, string, or boolean and
|
||||
// will all be stringified when returned.
|
||||
ClaimMappings map[string]string
|
||||
|
||||
// Mappings of claims (key) that will be copied to a metadata field
|
||||
// (value). Use this if the claim you are capturing is list-like (such as
|
||||
// groups).
|
||||
//
|
||||
// When mapped, the values in each list can be any of a number, string, or
|
||||
// boolean and will all be stringified when returned.
|
||||
ListClaimMappings map[string]string
|
||||
|
||||
// OIDCDiscoveryURL is the OIDC Discovery URL, without any .well-known
|
||||
// component (base path). Cannot be used with "JWKSURL" or
|
||||
// "JWTValidationPubKeys".
|
||||
OIDCDiscoveryURL string
|
||||
|
||||
// OIDCDiscoveryCACert is the CA certificate or chain of certificates, in
|
||||
// PEM format, to use to validate connections to the OIDC Discovery URL. If
|
||||
// not set, system certificates are used.
|
||||
OIDCDiscoveryCACert string
|
||||
|
||||
// -------
|
||||
// just for type=oidc
|
||||
// -------
|
||||
|
||||
// OIDCClientID is the OAuth Client ID configured with your OIDC provider.
|
||||
//
|
||||
// Valid only if Type=oidc
|
||||
OIDCClientID string
|
||||
|
||||
// The OAuth Client Secret configured with your OIDC provider.
|
||||
//
|
||||
// Valid only if Type=oidc
|
||||
OIDCClientSecret string
|
||||
|
||||
// Comma-separated list of OIDC scopes
|
||||
//
|
||||
// Valid only if Type=oidc
|
||||
OIDCScopes []string
|
||||
|
||||
// Comma-separated list of allowed values for redirect_uri
|
||||
//
|
||||
// Valid only if Type=oidc
|
||||
AllowedRedirectURIs []string
|
||||
|
||||
// Log received OIDC tokens and claims when debug-level logging is active.
|
||||
// Not recommended in production since sensitive information may be present
|
||||
// in OIDC responses.
|
||||
//
|
||||
// Valid only if Type=oidc
|
||||
VerboseOIDCLogging bool
|
||||
|
||||
// -------
|
||||
// just for type=jwt
|
||||
// -------
|
||||
|
||||
// JWKSURL is the JWKS URL to use to authenticate signatures. Cannot be
|
||||
// used with "OIDCDiscoveryURL" or "JWTValidationPubKeys".
|
||||
//
|
||||
// Valid only if Type=jwt
|
||||
JWKSURL string
|
||||
|
||||
// JWKSCACert is the CA certificate or chain of certificates, in PEM
|
||||
// format, to use to validate connections to the JWKS URL. If not set,
|
||||
// system certificates are used.
|
||||
//
|
||||
// Valid only if Type=jwt
|
||||
JWKSCACert string
|
||||
|
||||
// JWTValidationPubKeys is a list of PEM-encoded public keys to use to
|
||||
// authenticate signatures locally. Cannot be used with "JWKSURL" or
|
||||
// "OIDCDiscoveryURL".
|
||||
//
|
||||
// Valid only if Type=jwt
|
||||
JWTValidationPubKeys []string
|
||||
|
||||
// BoundIssuer is the value against which to match the 'iss' claim in a
|
||||
// JWT. Optional.
|
||||
//
|
||||
// Valid only if Type=jwt
|
||||
BoundIssuer string
|
||||
|
||||
// Duration in seconds of leeway when validating expiration of
|
||||
// a token to account for clock skew.
|
||||
//
|
||||
// Defaults to 150 (2.5 minutes) if set to 0 and can be disabled if set to -1.`,
|
||||
//
|
||||
// Valid only if Type=jwt
|
||||
ExpirationLeeway time.Duration
|
||||
|
||||
// Duration in seconds of leeway when validating not before values of a
|
||||
// token to account for clock skew.
|
||||
//
|
||||
// Defaults to 150 (2.5 minutes) if set to 0 and can be disabled if set to
|
||||
// -1.`,
|
||||
//
|
||||
// Valid only if Type=jwt
|
||||
NotBeforeLeeway time.Duration
|
||||
|
||||
// Duration in seconds of leeway when validating all claims to account for
|
||||
// clock skew.
|
||||
//
|
||||
// Defaults to 60 (1 minute) if set to 0 and can be disabled if set to
|
||||
// -1.`,
|
||||
//
|
||||
// Valid only if Type=jwt
|
||||
ClockSkewLeeway time.Duration
|
||||
}
|
||||
|
||||
// Validate returns an error if the config is not valid.
|
||||
func (c *Config) Validate() error {
|
||||
validateCtx, validateCtxCancel := context.WithCancel(context.Background())
|
||||
defer validateCtxCancel()
|
||||
|
||||
switch c.Type {
|
||||
case TypeOIDC, "":
|
||||
// required
|
||||
switch {
|
||||
case c.OIDCDiscoveryURL == "":
|
||||
return fmt.Errorf("'OIDCDiscoveryURL' must be set for type %q", c.Type)
|
||||
case c.OIDCClientID == "":
|
||||
return fmt.Errorf("'OIDCClientID' must be set for type %q", c.Type)
|
||||
case c.OIDCClientSecret == "":
|
||||
return fmt.Errorf("'OIDCClientSecret' must be set for type %q", c.Type)
|
||||
case len(c.AllowedRedirectURIs) == 0:
|
||||
return fmt.Errorf("'AllowedRedirectURIs' must be set for type %q", c.Type)
|
||||
}
|
||||
|
||||
// not allowed
|
||||
switch {
|
||||
case c.JWKSURL != "":
|
||||
return fmt.Errorf("'JWKSURL' must not be set for type %q", c.Type)
|
||||
case c.JWKSCACert != "":
|
||||
return fmt.Errorf("'JWKSCACert' must not be set for type %q", c.Type)
|
||||
case len(c.JWTValidationPubKeys) != 0:
|
||||
return fmt.Errorf("'JWTValidationPubKeys' must not be set for type %q", c.Type)
|
||||
case c.BoundIssuer != "":
|
||||
return fmt.Errorf("'BoundIssuer' must not be set for type %q", c.Type)
|
||||
case c.ExpirationLeeway != 0:
|
||||
return fmt.Errorf("'ExpirationLeeway' must not be set for type %q", c.Type)
|
||||
case c.NotBeforeLeeway != 0:
|
||||
return fmt.Errorf("'NotBeforeLeeway' must not be set for type %q", c.Type)
|
||||
case c.ClockSkewLeeway != 0:
|
||||
return fmt.Errorf("'ClockSkewLeeway' must not be set for type %q", c.Type)
|
||||
}
|
||||
|
||||
var bad []string
|
||||
for _, allowed := range c.AllowedRedirectURIs {
|
||||
if _, err := url.Parse(allowed); err != nil {
|
||||
bad = append(bad, allowed)
|
||||
}
|
||||
}
|
||||
if len(bad) > 0 {
|
||||
return fmt.Errorf("Invalid AllowedRedirectURIs provided: %v", bad)
|
||||
}
|
||||
|
||||
case TypeJWT:
|
||||
// not allowed
|
||||
switch {
|
||||
case c.OIDCClientID != "":
|
||||
return fmt.Errorf("'OIDCClientID' must not be set for type %q", c.Type)
|
||||
case c.OIDCClientSecret != "":
|
||||
return fmt.Errorf("'OIDCClientSecret' must not be set for type %q", c.Type)
|
||||
case len(c.OIDCScopes) != 0:
|
||||
return fmt.Errorf("'OIDCScopes' must not be set for type %q", c.Type)
|
||||
case len(c.AllowedRedirectURIs) != 0:
|
||||
return fmt.Errorf("'AllowedRedirectURIs' must not be set for type %q", c.Type)
|
||||
case c.VerboseOIDCLogging:
|
||||
return fmt.Errorf("'VerboseOIDCLogging' must not be set for type %q", c.Type)
|
||||
}
|
||||
|
||||
methodCount := 0
|
||||
if c.OIDCDiscoveryURL != "" {
|
||||
methodCount++
|
||||
}
|
||||
if len(c.JWTValidationPubKeys) != 0 {
|
||||
methodCount++
|
||||
}
|
||||
if c.JWKSURL != "" {
|
||||
methodCount++
|
||||
}
|
||||
|
||||
if methodCount != 1 {
|
||||
return fmt.Errorf("exactly one of 'JWTValidationPubKeys', 'JWKSURL', or 'OIDCDiscoveryURL' must be set for type %q", c.Type)
|
||||
}
|
||||
|
||||
if c.JWKSURL != "" {
|
||||
httpClient, err := createHTTPClient(c.JWKSCACert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking JWKSCACert: %v", err)
|
||||
}
|
||||
|
||||
ctx := contextWithHttpClient(validateCtx, httpClient)
|
||||
keyset := oidc.NewRemoteKeySet(ctx, c.JWKSURL)
|
||||
|
||||
// Try to verify a correctly formatted JWT. The signature will fail
|
||||
// to match, but other errors with fetching the remote keyset
|
||||
// should be reported.
|
||||
_, err = keyset.VerifySignature(ctx, testJWT)
|
||||
if err == nil {
|
||||
err = errors.New("unexpected verification of JWT")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "failed to verify id token signature") {
|
||||
return fmt.Errorf("error checking JWKSURL: %v", err)
|
||||
}
|
||||
} else if c.JWKSCACert != "" {
|
||||
return fmt.Errorf("'JWKSCACert' should not be set unless 'JWKSURL' is set")
|
||||
}
|
||||
|
||||
if len(c.JWTValidationPubKeys) != 0 {
|
||||
for i, v := range c.JWTValidationPubKeys {
|
||||
if _, err := parsePublicKeyPEM([]byte(v)); err != nil {
|
||||
return fmt.Errorf("error parsing public key JWTValidationPubKeys[%d]: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("authenticator type should be %q or %q", TypeOIDC, TypeJWT)
|
||||
}
|
||||
|
||||
if c.OIDCDiscoveryURL != "" {
|
||||
httpClient, err := createHTTPClient(c.OIDCDiscoveryCACert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking OIDCDiscoveryCACert: %v", err)
|
||||
}
|
||||
|
||||
ctx := contextWithHttpClient(validateCtx, httpClient)
|
||||
if _, err := oidc.NewProvider(ctx, c.OIDCDiscoveryURL); err != nil {
|
||||
return fmt.Errorf("error checking OIDCDiscoveryURL: %v", err)
|
||||
}
|
||||
} else if c.OIDCDiscoveryCACert != "" {
|
||||
return fmt.Errorf("'OIDCDiscoveryCACert' should not be set unless 'OIDCDiscoveryURL' is set")
|
||||
}
|
||||
|
||||
for _, a := range c.JWTSupportedAlgs {
|
||||
switch a {
|
||||
case oidc.RS256, oidc.RS384, oidc.RS512,
|
||||
oidc.ES256, oidc.ES384, oidc.ES512,
|
||||
oidc.PS256, oidc.PS384, oidc.PS512:
|
||||
default:
|
||||
return fmt.Errorf("Invalid supported algorithm: %s", a)
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.ClaimMappings) > 0 {
|
||||
targets := make(map[string]bool)
|
||||
for _, mappedKey := range c.ClaimMappings {
|
||||
if targets[mappedKey] {
|
||||
return fmt.Errorf("ClaimMappings contains multiple mappings for key %q", mappedKey)
|
||||
}
|
||||
targets[mappedKey] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.ListClaimMappings) > 0 {
|
||||
targets := make(map[string]bool)
|
||||
for _, mappedKey := range c.ListClaimMappings {
|
||||
if targets[mappedKey] {
|
||||
return fmt.Errorf("ListClaimMappings contains multiple mappings for key %q", mappedKey)
|
||||
}
|
||||
targets[mappedKey] = true
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
authUnconfigured = iota
|
||||
authStaticKeys
|
||||
authJWKS
|
||||
authOIDCDiscovery
|
||||
authOIDCFlow
|
||||
)
|
||||
|
||||
// authType classifies the authorization type/flow based on config parameters.
|
||||
// It is only valid to invoke if Validate() returns a nil error.
|
||||
func (c *Config) authType() int {
|
||||
switch {
|
||||
case len(c.JWTValidationPubKeys) > 0:
|
||||
return authStaticKeys
|
||||
case c.JWKSURL != "":
|
||||
return authJWKS
|
||||
case c.OIDCDiscoveryURL != "":
|
||||
if c.OIDCClientID != "" && c.OIDCClientSecret != "" {
|
||||
return authOIDCFlow
|
||||
}
|
||||
return authOIDCDiscovery
|
||||
default:
|
||||
return authUnconfigured
|
||||
}
|
||||
}
|
||||
|
||||
const testJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
|
|
@ -0,0 +1,655 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
type testcase struct {
|
||||
config Config
|
||||
expectAuthType int
|
||||
expectErr string
|
||||
}
|
||||
|
||||
srv := oidcauthtest.Start(t)
|
||||
|
||||
oidcCases := map[string]testcase{
|
||||
"all required": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectAuthType: authOIDCFlow,
|
||||
},
|
||||
"missing required OIDCDiscoveryURL": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
// OIDCDiscoveryURL: srv.Addr(),
|
||||
// OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectErr: "must be set for type",
|
||||
},
|
||||
"missing required OIDCClientID": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
// OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectErr: "must be set for type",
|
||||
},
|
||||
"missing required OIDCClientSecret": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
// OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectErr: "must be set for type",
|
||||
},
|
||||
"missing required AllowedRedirectURIs": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{},
|
||||
},
|
||||
expectErr: "must be set for type",
|
||||
},
|
||||
"incompatible with JWKSURL": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
JWKSURL: srv.Addr() + "/certs",
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with JWKSCACert": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
JWKSCACert: srv.CACert(),
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with JWTValidationPubKeys": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with BoundIssuer": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
BoundIssuer: "foo",
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with ExpirationLeeway": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
ExpirationLeeway: 1 * time.Second,
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with NotBeforeLeeway": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
NotBeforeLeeway: 1 * time.Second,
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with ClockSkewLeeway": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
ClockSkewLeeway: 1 * time.Second,
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"bad discovery cert": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: oidcBadCACerts,
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectErr: "certificate signed by unknown authority",
|
||||
},
|
||||
"garbage discovery cert": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: garbageCACert,
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectErr: "could not parse CA PEM value successfully",
|
||||
},
|
||||
"good discovery cert": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectAuthType: authOIDCFlow,
|
||||
},
|
||||
"valid redirect uris": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{
|
||||
"http://foo.test",
|
||||
"https://example.com",
|
||||
"https://evilcorp.com:8443",
|
||||
},
|
||||
},
|
||||
expectAuthType: authOIDCFlow,
|
||||
},
|
||||
"invalid redirect uris": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{
|
||||
"%%%%",
|
||||
"http://foo.test",
|
||||
"https://example.com",
|
||||
"https://evilcorp.com:8443",
|
||||
},
|
||||
},
|
||||
expectErr: "Invalid AllowedRedirectURIs provided: [%%%%]",
|
||||
},
|
||||
"valid algorithm": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
JWTSupportedAlgs: []string{
|
||||
oidc.RS256, oidc.RS384, oidc.RS512,
|
||||
oidc.ES256, oidc.ES384, oidc.ES512,
|
||||
oidc.PS256, oidc.PS384, oidc.PS512,
|
||||
},
|
||||
},
|
||||
expectAuthType: authOIDCFlow,
|
||||
},
|
||||
"invalid algorithm": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
JWTSupportedAlgs: []string{
|
||||
oidc.RS256, oidc.RS384, oidc.RS512,
|
||||
oidc.ES256, oidc.ES384, oidc.ES512,
|
||||
oidc.PS256, oidc.PS384, oidc.PS512,
|
||||
"foo",
|
||||
},
|
||||
},
|
||||
expectErr: "Invalid supported algorithm",
|
||||
},
|
||||
"valid claim mappings": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
ClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
},
|
||||
expectAuthType: authOIDCFlow,
|
||||
},
|
||||
"invalid repeated value claim mappings": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
ClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"bling": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
},
|
||||
expectErr: "ClaimMappings contains multiple mappings for key",
|
||||
},
|
||||
"invalid repeated list claim mappings": {
|
||||
config: Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
ClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"bling": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
},
|
||||
expectErr: "ListClaimMappings contains multiple mappings for key",
|
||||
},
|
||||
}
|
||||
|
||||
jwtCases := map[string]testcase{
|
||||
"all required for oidc discovery": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
},
|
||||
expectAuthType: authOIDCDiscovery,
|
||||
},
|
||||
"all required for jwks": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWKSURL: srv.Addr() + "/certs",
|
||||
JWKSCACert: srv.CACert(), // needed to avoid self signed cert issue
|
||||
},
|
||||
expectAuthType: authJWKS,
|
||||
},
|
||||
"all required for public keys": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
},
|
||||
expectAuthType: authStaticKeys,
|
||||
},
|
||||
"incompatible with OIDCClientID": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
OIDCClientID: "abc",
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with OIDCClientSecret": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
OIDCClientSecret: "abc",
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with OIDCScopes": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
OIDCScopes: []string{"blah"},
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with AllowedRedirectURIs": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
AllowedRedirectURIs: []string{"http://foo.test"},
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"incompatible with VerboseOIDCLogging": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
VerboseOIDCLogging: true,
|
||||
},
|
||||
expectErr: "must not be set for type",
|
||||
},
|
||||
"too many methods (discovery + jwks)": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
JWKSURL: srv.Addr() + "/certs",
|
||||
JWKSCACert: srv.CACert(),
|
||||
// JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
},
|
||||
expectErr: "exactly one of",
|
||||
},
|
||||
"too many methods (discovery + pubkeys)": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
// JWKSURL: srv.Addr() + "/certs",
|
||||
// JWKSCACert: srv.CACert(),
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
},
|
||||
expectErr: "exactly one of",
|
||||
},
|
||||
"too many methods (jwks + pubkeys)": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
// OIDCDiscoveryURL: srv.Addr(),
|
||||
// OIDCDiscoveryCACert: srv.CACert(),
|
||||
JWKSURL: srv.Addr() + "/certs",
|
||||
JWKSCACert: srv.CACert(),
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
},
|
||||
expectErr: "exactly one of",
|
||||
},
|
||||
"too many methods (discovery + jwks + pubkeys)": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
JWKSURL: srv.Addr() + "/certs",
|
||||
JWKSCACert: srv.CACert(),
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
},
|
||||
expectErr: "exactly one of",
|
||||
},
|
||||
"incompatible with JWKSCACert": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
JWKSCACert: srv.CACert(),
|
||||
},
|
||||
expectErr: "should not be set unless",
|
||||
},
|
||||
"invalid pubkey": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKeyBad},
|
||||
},
|
||||
expectErr: "error parsing public key",
|
||||
},
|
||||
"incompatible with OIDCDiscoveryCACert": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
},
|
||||
expectErr: "should not be set unless",
|
||||
},
|
||||
"bad discovery cert": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: oidcBadCACerts,
|
||||
},
|
||||
expectErr: "certificate signed by unknown authority",
|
||||
},
|
||||
"good discovery cert": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
},
|
||||
expectAuthType: authOIDCDiscovery,
|
||||
},
|
||||
"jwks invalid 404": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWKSURL: srv.Addr() + "/certs_missing",
|
||||
JWKSCACert: srv.CACert(),
|
||||
},
|
||||
expectErr: "get keys failed",
|
||||
},
|
||||
"jwks mismatched certs": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWKSURL: srv.Addr() + "/certs_invalid",
|
||||
JWKSCACert: srv.CACert(),
|
||||
},
|
||||
expectErr: "failed to decode keys",
|
||||
},
|
||||
"jwks bad certs": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWKSURL: srv.Addr() + "/certs_invalid",
|
||||
JWKSCACert: garbageCACert,
|
||||
},
|
||||
expectErr: "could not parse CA PEM value successfully",
|
||||
},
|
||||
"valid algorithm": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
JWTSupportedAlgs: []string{
|
||||
oidc.RS256, oidc.RS384, oidc.RS512,
|
||||
oidc.ES256, oidc.ES384, oidc.ES512,
|
||||
oidc.PS256, oidc.PS384, oidc.PS512,
|
||||
},
|
||||
},
|
||||
expectAuthType: authStaticKeys,
|
||||
},
|
||||
"invalid algorithm": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
JWTSupportedAlgs: []string{
|
||||
oidc.RS256, oidc.RS384, oidc.RS512,
|
||||
oidc.ES256, oidc.ES384, oidc.ES512,
|
||||
oidc.PS256, oidc.PS384, oidc.PS512,
|
||||
"foo",
|
||||
},
|
||||
},
|
||||
expectErr: "Invalid supported algorithm",
|
||||
},
|
||||
"valid claim mappings": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
ClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
},
|
||||
expectAuthType: authStaticKeys,
|
||||
},
|
||||
"invalid repeated value claim mappings": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
ClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"bling": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
},
|
||||
expectErr: "ClaimMappings contains multiple mappings for key",
|
||||
},
|
||||
"invalid repeated list claim mappings": {
|
||||
config: Config{
|
||||
Type: TypeJWT,
|
||||
JWTValidationPubKeys: []string{testJWTPubKey},
|
||||
ClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"foo": "bar",
|
||||
"bling": "bar",
|
||||
"peanutbutter": "jelly",
|
||||
"wd40": "ducttape",
|
||||
},
|
||||
},
|
||||
expectErr: "ListClaimMappings contains multiple mappings for key",
|
||||
},
|
||||
}
|
||||
|
||||
cases := map[string]testcase{
|
||||
"bad type": {
|
||||
config: Config{Type: "invalid"},
|
||||
expectErr: "authenticator type should be",
|
||||
},
|
||||
}
|
||||
|
||||
for k, v := range oidcCases {
|
||||
cases["type=oidc/"+k] = v
|
||||
|
||||
v2 := v
|
||||
v2.config.Type = ""
|
||||
cases["type=inferred_oidc/"+k] = v2
|
||||
}
|
||||
for k, v := range jwtCases {
|
||||
cases["type=jwt/"+k] = v
|
||||
}
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := tc.config.Validate()
|
||||
if tc.expectErr != "" {
|
||||
require.Error(t, err)
|
||||
requireErrorContains(t, err, tc.expectErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expectAuthType, tc.config.authType())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func requireErrorContains(t *testing.T, err error, expectedErrorMessage string) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
t.Fatal("An error is expected but got nil.")
|
||||
}
|
||||
if !strings.Contains(err.Error(), expectedErrorMessage) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
testJWTPubKey = `-----BEGIN PUBLIC KEY-----
|
||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEEVs/o5+uQbTjL3chynL4wXgUg2R9
|
||||
q9UU8I5mEovUf86QZ7kOBIjJwqnzD1omageEHWwHdBO6B+dFabmdT9POxg==
|
||||
-----END PUBLIC KEY-----`
|
||||
|
||||
testJWTPubKeyBad = `-----BEGIN PUBLIC KEY-----
|
||||
MFkwEwYHKoZIzj0CAQYIrollingyourricksEVs/o5+uQbTjL3chynL4wXgUg2R9
|
||||
q9UU8I5mEovUf86QZ7kOBIjJwqnzD1omageEHWwHdBO6B+dFabmdT9POxg==
|
||||
-----END PUBLIC KEY-----`
|
||||
|
||||
garbageCACert = `this is not a key`
|
||||
|
||||
oidcBadCACerts = `-----BEGIN CERTIFICATE-----
|
||||
MIIDYDCCAkigAwIBAgIJAK8uAVsPxWKGMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV
|
||||
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
|
||||
aWRnaXRzIFB0eSBMdGQwHhcNMTgwNzA5MTgwODI5WhcNMjgwNzA2MTgwODI5WjBF
|
||||
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
|
||||
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
|
||||
CgKCAQEA1eaEmIHKQqDlSadCtg6YY332qIMoeSb2iZTRhBRYBXRhMIKF3HoLXlI8
|
||||
/3veheMnBQM7zxIeLwtJ4VuZVZcpJlqHdsXQVj6A8+8MlAzNh3+Xnv0tjZ83QLwZ
|
||||
D6FWvMEzihxATD9uTCu2qRgeKnMYQFq4EG72AGb5094zfsXTAiwCfiRPVumiNbs4
|
||||
Mr75vf+2DEhqZuyP7GR2n3BKzrWo62yAmgLQQ07zfd1u1buv8R72HCYXYpFul5qx
|
||||
slZHU3yR+tLiBKOYB+C/VuB7hJZfVx25InIL1HTpIwWvmdk3QzpSpAGIAxWMXSzS
|
||||
oRmBYGnsgR6WTymfXuokD4ZhHOpFZQIDAQABo1MwUTAdBgNVHQ4EFgQURh/QFJBn
|
||||
hMXcgB1bWbGiU9B2VBQwHwYDVR0jBBgwFoAURh/QFJBnhMXcgB1bWbGiU9B2VBQw
|
||||
DwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAr8CZLA3MQjMDWweS
|
||||
ax9S1fRb8ifxZ4RqDcLj3dw5KZqnjEo8ggczR66T7vVXet/2TFBKYJAM0np26Z4A
|
||||
WjZfrDT7/bHXseWQAUhw/k2d39o+Um4aXkGpg1Paky9D+ddMdbx1hFkYxDq6kYGd
|
||||
PlBYSEiYQvVxDx7s7H0Yj9FWKO8WIO6BRUEvLlG7k/Xpp1OI6dV3nqwJ9CbcbqKt
|
||||
ff4hAtoAmN0/x6yFclFFWX8s7bRGqmnoj39/r98kzeGFb/lPKgQjSVcBJuE7UO4k
|
||||
8HP6vsnr/ruSlzUMv6XvHtT68kGC1qO3MfqiPhdSa4nxf9g/1xyBmAw/Uf90BJrm
|
||||
sj9DpQ==
|
||||
-----END CERTIFICATE-----`
|
||||
)
|
|
@ -0,0 +1,11 @@
|
|||
package strutil
|
||||
|
||||
// StrListContains looks for a string in a list of strings.
|
||||
func StrListContains(haystack []string, needle string) bool {
|
||||
for _, item := range haystack {
|
||||
if item == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
package strutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStrListContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
haystack []string
|
||||
needle string
|
||||
expected bool
|
||||
}{
|
||||
// found
|
||||
{[]string{"a"}, "a", true},
|
||||
{[]string{"a", "b", "c"}, "a", true},
|
||||
{[]string{"a", "b", "c"}, "b", true},
|
||||
{[]string{"a", "b", "c"}, "c", true},
|
||||
|
||||
// not found
|
||||
{nil, "", false},
|
||||
{[]string{}, "", false},
|
||||
{[]string{"a"}, "", false},
|
||||
{[]string{"a"}, "b", false},
|
||||
{[]string{"a", "b", "c"}, "x", false},
|
||||
}
|
||||
for _, test := range tests {
|
||||
ok := StrListContains(test.haystack, test.needle)
|
||||
assert.Equal(t, test.expected, ok, "failed on %s/%v", test.needle, test.haystack)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
const claimDefaultLeeway = 150
|
||||
|
||||
// ClaimsFromJWT is unrelated to the OIDC authorization code workflow. This
|
||||
// allows for a JWT to be directly validated and decoded into a set of claims.
|
||||
//
|
||||
// Requires the authenticator's config type be set to 'jwt'.
|
||||
func (a *Authenticator) ClaimsFromJWT(ctx context.Context, jwt string) (*Claims, error) {
|
||||
if a.config.authType() == authOIDCFlow {
|
||||
return nil, fmt.Errorf("ClaimsFromJWT is incompatible with type %q", TypeOIDC)
|
||||
}
|
||||
if jwt == "" {
|
||||
return nil, errors.New("missing jwt")
|
||||
}
|
||||
|
||||
// Here is where things diverge. If it is using OIDC Discovery, validate that way;
|
||||
// otherwise validate against the locally configured or JWKS keys. Once things are
|
||||
// validated, we re-unify the request path when evaluating the claims.
|
||||
var (
|
||||
allClaims map[string]interface{}
|
||||
err error
|
||||
)
|
||||
switch a.config.authType() {
|
||||
case authStaticKeys, authJWKS:
|
||||
allClaims, err = a.verifyVanillaJWT(ctx, jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case authOIDCDiscovery:
|
||||
allClaims, err = a.verifyOIDCToken(ctx, jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, errors.New("unhandled case during login")
|
||||
}
|
||||
|
||||
c, err := a.extractClaims(allClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if a.config.VerboseOIDCLogging && a.logger != nil {
|
||||
a.logger.Debug("OIDC provider response", "extracted_claims", c)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) verifyVanillaJWT(ctx context.Context, loginToken string) (map[string]interface{}, error) {
|
||||
var (
|
||||
allClaims = map[string]interface{}{}
|
||||
claims = jwt.Claims{}
|
||||
)
|
||||
// TODO(sso): handle JWTSupportedAlgs
|
||||
switch a.config.authType() {
|
||||
case authJWKS:
|
||||
// Verify signature (and only signature... other elements are checked later)
|
||||
payload, err := a.keySet.VerifySignature(ctx, loginToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying token: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal payload into two copies: public claims for library verification, and a set
|
||||
// of all received claims.
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(payload, &allClaims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
|
||||
}
|
||||
case authStaticKeys:
|
||||
parsedJWT, err := jwt.ParseSigned(loginToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing token: %v", err)
|
||||
}
|
||||
|
||||
var valid bool
|
||||
for _, key := range a.parsedJWTPubKeys {
|
||||
if err := parsedJWT.Claims(key, &claims, &allClaims); err == nil {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("no known key successfully validated the token signature")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported auth type for this verifyVanillaJWT: %d", a.config.authType())
|
||||
}
|
||||
|
||||
// We require notbefore or expiry; if only one is provided, we allow 5 minutes of leeway by default.
|
||||
// Configurable by ExpirationLeeway and NotBeforeLeeway
|
||||
if claims.IssuedAt == nil {
|
||||
claims.IssuedAt = new(jwt.NumericDate)
|
||||
}
|
||||
if claims.Expiry == nil {
|
||||
claims.Expiry = new(jwt.NumericDate)
|
||||
}
|
||||
if claims.NotBefore == nil {
|
||||
claims.NotBefore = new(jwt.NumericDate)
|
||||
}
|
||||
if *claims.IssuedAt == 0 && *claims.Expiry == 0 && *claims.NotBefore == 0 {
|
||||
return nil, errors.New("no issue time, notbefore, or expiration time encoded in token")
|
||||
}
|
||||
|
||||
if *claims.Expiry == 0 {
|
||||
latestStart := *claims.IssuedAt
|
||||
if *claims.NotBefore > *claims.IssuedAt {
|
||||
latestStart = *claims.NotBefore
|
||||
}
|
||||
leeway := a.config.ExpirationLeeway.Seconds()
|
||||
if a.config.ExpirationLeeway.Seconds() < 0 {
|
||||
leeway = 0
|
||||
} else if a.config.ExpirationLeeway.Seconds() == 0 {
|
||||
leeway = claimDefaultLeeway
|
||||
}
|
||||
*claims.Expiry = jwt.NumericDate(int64(latestStart) + int64(leeway))
|
||||
}
|
||||
|
||||
if *claims.NotBefore == 0 {
|
||||
if *claims.IssuedAt != 0 {
|
||||
*claims.NotBefore = *claims.IssuedAt
|
||||
} else {
|
||||
leeway := a.config.NotBeforeLeeway.Seconds()
|
||||
if a.config.NotBeforeLeeway.Seconds() < 0 {
|
||||
leeway = 0
|
||||
} else if a.config.NotBeforeLeeway.Seconds() == 0 {
|
||||
leeway = claimDefaultLeeway
|
||||
}
|
||||
*claims.NotBefore = jwt.NumericDate(int64(*claims.Expiry) - int64(leeway))
|
||||
}
|
||||
}
|
||||
|
||||
expected := jwt.Expected{
|
||||
Issuer: a.config.BoundIssuer,
|
||||
// Subject: a.config.BoundSubject,
|
||||
Time: time.Now(),
|
||||
}
|
||||
|
||||
cksLeeway := a.config.ClockSkewLeeway
|
||||
if a.config.ClockSkewLeeway.Seconds() < 0 {
|
||||
cksLeeway = 0
|
||||
} else if a.config.ClockSkewLeeway.Seconds() == 0 {
|
||||
cksLeeway = jwt.DefaultLeeway
|
||||
}
|
||||
|
||||
if err := claims.ValidateWithLeeway(expected, cksLeeway); err != nil {
|
||||
return nil, fmt.Errorf("error validating claims: %v", err)
|
||||
}
|
||||
|
||||
if err := validateAudience(a.config.BoundAudiences, claims.Audience, true); err != nil {
|
||||
return nil, fmt.Errorf("error validating claims: %v", err)
|
||||
}
|
||||
|
||||
return allClaims, nil
|
||||
}
|
||||
|
||||
// parsePublicKeyPEM is used to parse RSA, ECDSA, and Ed25519 public keys from PEMs
|
||||
//
|
||||
// Extracted from "github.com/hashicorp/vault/sdk/helper/certutil"
|
||||
//
|
||||
// go-sso added support for ed25519 (EdDSA)
|
||||
func parsePublicKeyPEM(data []byte) (interface{}, error) {
|
||||
block, data := pem.Decode(data)
|
||||
if block != nil {
|
||||
var rawKey interface{}
|
||||
var err error
|
||||
if rawKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||
rawKey = cert.PublicKey
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if rsaPublicKey, ok := rawKey.(*rsa.PublicKey); ok {
|
||||
return rsaPublicKey, nil
|
||||
}
|
||||
if ecPublicKey, ok := rawKey.(*ecdsa.PublicKey); ok {
|
||||
return ecPublicKey, nil
|
||||
}
|
||||
if edPublicKey, ok := rawKey.(ed25519.PublicKey); ok {
|
||||
return edPublicKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("data does not contain any valid RSA, ECDSA, or ED25519 public keys")
|
||||
}
|
|
@ -0,0 +1,695 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func setupForJWT(t *testing.T, authType int, f func(c *Config)) (*Authenticator, string) {
|
||||
t.Helper()
|
||||
|
||||
config := &Config{
|
||||
Type: TypeJWT,
|
||||
JWTSupportedAlgs: []string{oidc.ES256},
|
||||
ClaimMappings: map[string]string{
|
||||
"first_name": "name",
|
||||
"/org/primary": "primary_org",
|
||||
"/nested/Size": "size",
|
||||
"Age": "age",
|
||||
"Admin": "is_admin",
|
||||
"/nested/division": "division",
|
||||
"/nested/remote": "is_remote",
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"https://go-sso/groups": "groups",
|
||||
},
|
||||
}
|
||||
|
||||
var issuer string
|
||||
switch authType {
|
||||
case authOIDCDiscovery:
|
||||
srv := oidcauthtest.Start(t)
|
||||
config.OIDCDiscoveryURL = srv.Addr()
|
||||
config.OIDCDiscoveryCACert = srv.CACert()
|
||||
|
||||
issuer = config.OIDCDiscoveryURL
|
||||
|
||||
// TODO(sso): is this a bug in vault?
|
||||
// config.BoundIssuer = issuer
|
||||
case authStaticKeys:
|
||||
pubKey, _ := oidcauthtest.SigningKeys()
|
||||
config.BoundIssuer = "https://legit.issuer.internal/"
|
||||
config.JWTValidationPubKeys = []string{pubKey}
|
||||
issuer = config.BoundIssuer
|
||||
case authJWKS:
|
||||
srv := oidcauthtest.Start(t)
|
||||
config.JWKSURL = srv.Addr() + "/certs"
|
||||
config.JWKSCACert = srv.CACert()
|
||||
|
||||
issuer = "https://legit.issuer.internal/"
|
||||
|
||||
// TODO(sso): is this a bug in vault?
|
||||
// config.BoundIssuer = issuer
|
||||
default:
|
||||
require.Fail(t, "inappropriate authType: %d", authType)
|
||||
}
|
||||
|
||||
if f != nil {
|
||||
f(config)
|
||||
}
|
||||
|
||||
require.NoError(t, config.Validate())
|
||||
|
||||
oa, err := New(config, hclog.NewNullLogger())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(oa.Stop)
|
||||
|
||||
return oa, issuer
|
||||
}
|
||||
|
||||
func TestJWT_OIDC_Functions_Fail(t *testing.T) {
|
||||
t.Run("static", func(t *testing.T) {
|
||||
testJWT_OIDC_Functions_Fail(t, authStaticKeys)
|
||||
})
|
||||
t.Run("JWKS", func(t *testing.T) {
|
||||
testJWT_OIDC_Functions_Fail(t, authJWKS)
|
||||
})
|
||||
t.Run("oidc discovery", func(t *testing.T) {
|
||||
testJWT_OIDC_Functions_Fail(t, authOIDCDiscovery)
|
||||
})
|
||||
}
|
||||
|
||||
func testJWT_OIDC_Functions_Fail(t *testing.T, authType int) {
|
||||
t.Helper()
|
||||
|
||||
t.Run("GetAuthCodeURL", func(t *testing.T) {
|
||||
oa, _ := setupForJWT(t, authType, nil)
|
||||
|
||||
_, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
map[string]string{"foo": "bar"},
|
||||
)
|
||||
requireErrorContains(t, err, `GetAuthCodeURL is incompatible with type "jwt"`)
|
||||
})
|
||||
|
||||
t.Run("ClaimsFromAuthCode", func(t *testing.T) {
|
||||
oa, _ := setupForJWT(t, authType, nil)
|
||||
|
||||
_, _, err := oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
"abc", "def",
|
||||
)
|
||||
requireErrorContains(t, err, `ClaimsFromAuthCode is incompatible with type "jwt"`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWT_ClaimsFromJWT(t *testing.T) {
|
||||
t.Run("static", func(t *testing.T) {
|
||||
testJWT_ClaimsFromJWT(t, authStaticKeys)
|
||||
})
|
||||
t.Run("JWKS", func(t *testing.T) {
|
||||
testJWT_ClaimsFromJWT(t, authJWKS)
|
||||
})
|
||||
t.Run("oidc discovery", func(t *testing.T) {
|
||||
// TODO(sso): the vault versions of these tests did not run oidc-discovery
|
||||
testJWT_ClaimsFromJWT(t, authOIDCDiscovery)
|
||||
})
|
||||
}
|
||||
|
||||
func testJWT_ClaimsFromJWT(t *testing.T, authType int) {
|
||||
t.Helper()
|
||||
|
||||
t.Run("missing audience", func(t *testing.T) {
|
||||
if authType == authOIDCDiscovery {
|
||||
// TODO(sso): why isn't this strict?
|
||||
t.Skip("why?")
|
||||
return
|
||||
}
|
||||
oa, issuer := setupForJWT(t, authType, nil)
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Issuer: issuer,
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Audience: jwt.Audience{"https://go-sso.test"},
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
}{
|
||||
"jeff",
|
||||
[]string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
requireErrorContains(t, err, "audience claim found in JWT but no audiences are bound")
|
||||
})
|
||||
|
||||
t.Run("valid inputs", func(t *testing.T) {
|
||||
oa, issuer := setupForJWT(t, authType, func(c *Config) {
|
||||
c.BoundAudiences = []string{
|
||||
"https://go-sso.test",
|
||||
"another_audience",
|
||||
}
|
||||
})
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Issuer: issuer,
|
||||
Audience: jwt.Audience{"https://go-sso.test"},
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
type orgs struct {
|
||||
Primary string `json:"primary"`
|
||||
}
|
||||
|
||||
type nested struct {
|
||||
Division int64 `json:"division"`
|
||||
Remote bool `json:"remote"`
|
||||
Size string `json:"Size"`
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
FirstName string `json:"first_name"`
|
||||
Org orgs `json:"org"`
|
||||
Color string `json:"color"`
|
||||
Age int64 `json:"Age"`
|
||||
Admin bool `json:"Admin"`
|
||||
Nested nested `json:"nested"`
|
||||
}{
|
||||
User: "jeff",
|
||||
Groups: []string{"foo", "bar"},
|
||||
FirstName: "jeff2",
|
||||
Org: orgs{"engineering"},
|
||||
Color: "green",
|
||||
Age: 85,
|
||||
Admin: true,
|
||||
Nested: nested{
|
||||
Division: 3,
|
||||
Remote: true,
|
||||
Size: "medium",
|
||||
},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedClaims := &Claims{
|
||||
Values: map[string]string{
|
||||
"name": "jeff2",
|
||||
"primary_org": "engineering",
|
||||
"size": "medium",
|
||||
"age": "85",
|
||||
"is_admin": "true",
|
||||
"division": "3",
|
||||
"is_remote": "true",
|
||||
},
|
||||
Lists: map[string][]string{
|
||||
"groups": []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, expectedClaims, claims)
|
||||
})
|
||||
|
||||
t.Run("unusable claims", func(t *testing.T) {
|
||||
oa, issuer := setupForJWT(t, authType, func(c *Config) {
|
||||
c.BoundAudiences = []string{
|
||||
"https://go-sso.test",
|
||||
"another_audience",
|
||||
}
|
||||
})
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Issuer: issuer,
|
||||
Audience: jwt.Audience{"https://go-sso.test"},
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
type orgs struct {
|
||||
Primary string `json:"primary"`
|
||||
}
|
||||
|
||||
type nested struct {
|
||||
Division int64 `json:"division"`
|
||||
Remote bool `json:"remote"`
|
||||
Size []string `json:"Size"`
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
FirstName string `json:"first_name"`
|
||||
Org orgs `json:"org"`
|
||||
Color string `json:"color"`
|
||||
Age int64 `json:"Age"`
|
||||
Admin bool `json:"Admin"`
|
||||
Nested nested `json:"nested"`
|
||||
}{
|
||||
User: "jeff",
|
||||
Groups: []string{"foo", "bar"},
|
||||
FirstName: "jeff2",
|
||||
Org: orgs{"engineering"},
|
||||
Color: "green",
|
||||
Age: 85,
|
||||
Admin: true,
|
||||
Nested: nested{
|
||||
Division: 3,
|
||||
Remote: true,
|
||||
Size: []string{"medium"},
|
||||
},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
requireErrorContains(t, err, "error converting claim '/nested/Size' to string from unknown type []interface {}")
|
||||
})
|
||||
|
||||
t.Run("bad signature", func(t *testing.T) {
|
||||
oa, issuer := setupForJWT(t, authType, func(c *Config) {
|
||||
c.BoundAudiences = []string{
|
||||
"https://go-sso.test",
|
||||
"another_audience",
|
||||
}
|
||||
})
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Issuer: issuer,
|
||||
Audience: jwt.Audience{"https://go-sso.test"},
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
}{
|
||||
"jeff",
|
||||
[]string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT(badPrivKey, cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
|
||||
switch authType {
|
||||
case authOIDCDiscovery, authJWKS:
|
||||
requireErrorContains(t, err, "failed to verify id token signature")
|
||||
case authStaticKeys:
|
||||
requireErrorContains(t, err, "no known key successfully validated the token signature")
|
||||
default:
|
||||
require.Fail(t, "unexpected type: %d", authType)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bad issuer", func(t *testing.T) {
|
||||
oa, _ := setupForJWT(t, authType, func(c *Config) {
|
||||
c.BoundAudiences = []string{
|
||||
"https://go-sso.test",
|
||||
"another_audience",
|
||||
}
|
||||
})
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Issuer: "https://not.real.issuer.internal/",
|
||||
Audience: jwt.Audience{"https://go-sso.test"},
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
}{
|
||||
"jeff",
|
||||
[]string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
switch authType {
|
||||
case authOIDCDiscovery:
|
||||
requireErrorContains(t, err, "error validating signature: oidc: id token issued by a different provider")
|
||||
case authStaticKeys:
|
||||
requireErrorContains(t, err, "validation failed, invalid issuer claim (iss)")
|
||||
case authJWKS:
|
||||
// requireErrorContains(t, err, "validation failed, invalid issuer claim (iss)")
|
||||
// TODO(sso) The original vault test doesn't care about bound issuer.
|
||||
require.NoError(t, err)
|
||||
expectedClaims := &Claims{
|
||||
Values: map[string]string{},
|
||||
Lists: map[string][]string{
|
||||
"groups": []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedClaims, claims)
|
||||
default:
|
||||
require.Fail(t, "unexpected type: %d", authType)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bad audience", func(t *testing.T) {
|
||||
oa, issuer := setupForJWT(t, authType, func(c *Config) {
|
||||
c.BoundAudiences = []string{
|
||||
"https://go-sso.test",
|
||||
"another_audience",
|
||||
}
|
||||
})
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Issuer: issuer,
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Audience: jwt.Audience{"https://fault.plugin.auth.jwt.test"},
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
}{
|
||||
"jeff",
|
||||
[]string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
requireErrorContains(t, err, "error validating claims: aud claim does not match any bound audience")
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWT_ClaimsFromJWT_ExpiryClaims(t *testing.T) {
|
||||
t.Run("static", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testJWT_ClaimsFromJWT_ExpiryClaims(t, authStaticKeys)
|
||||
})
|
||||
t.Run("JWKS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testJWT_ClaimsFromJWT_ExpiryClaims(t, authJWKS)
|
||||
})
|
||||
// TODO(sso): the vault versions of these tests did not run oidc-discovery
|
||||
// t.Run("oidc discovery", func(t *testing.T) {
|
||||
// t.Parallel()
|
||||
// testJWT_ClaimsFromJWT_ExpiryClaims(t, authOIDCDiscovery)
|
||||
// })
|
||||
}
|
||||
|
||||
func testJWT_ClaimsFromJWT_ExpiryClaims(t *testing.T, authType int) {
|
||||
t.Helper()
|
||||
|
||||
tests := map[string]struct {
|
||||
Valid bool
|
||||
IssuedAt time.Time
|
||||
NotBefore time.Time
|
||||
Expiration time.Time
|
||||
DefaultLeeway int
|
||||
ExpLeeway int
|
||||
}{
|
||||
// iat, auto clock_skew_leeway (60s), auto expiration leeway (150s)
|
||||
"auto expire leeway using iat with auto clock_skew_leeway": {true, time.Now().Add(-205 * time.Second), time.Time{}, time.Time{}, 0, 0},
|
||||
"expired auto expire leeway using iat with auto clock_skew_leeway": {false, time.Now().Add(-215 * time.Second), time.Time{}, time.Time{}, 0, 0},
|
||||
|
||||
// iat, clock_skew_leeway (10s), auto expiration leeway (150s)
|
||||
"auto expire leeway using iat with custom clock_skew_leeway": {true, time.Now().Add(-150 * time.Second), time.Time{}, time.Time{}, 10, 0},
|
||||
"expired auto expire leeway using iat with custom clock_skew_leeway": {false, time.Now().Add(-165 * time.Second), time.Time{}, time.Time{}, 10, 0},
|
||||
|
||||
// iat, no clock_skew_leeway (0s), auto expiration leeway (150s)
|
||||
"auto expire leeway using iat with no clock_skew_leeway": {true, time.Now().Add(-145 * time.Second), time.Time{}, time.Time{}, -1, 0},
|
||||
"expired auto expire leeway using iat with no clock_skew_leeway": {false, time.Now().Add(-155 * time.Second), time.Time{}, time.Time{}, -1, 0},
|
||||
|
||||
// nbf, auto clock_skew_leeway (60s), auto expiration leeway (150s)
|
||||
"auto expire leeway using nbf with auto clock_skew_leeway": {true, time.Time{}, time.Now().Add(-205 * time.Second), time.Time{}, 0, 0},
|
||||
"expired auto expire leeway using nbf with auto clock_skew_leeway": {false, time.Time{}, time.Now().Add(-215 * time.Second), time.Time{}, 0, 0},
|
||||
|
||||
// nbf, clock_skew_leeway (10s), auto expiration leeway (150s)
|
||||
"auto expire leeway using nbf with custom clock_skew_leeway": {true, time.Time{}, time.Now().Add(-145 * time.Second), time.Time{}, 10, 0},
|
||||
"expired auto expire leeway using nbf with custom clock_skew_leeway": {false, time.Time{}, time.Now().Add(-165 * time.Second), time.Time{}, 10, 0},
|
||||
|
||||
// nbf, no clock_skew_leeway (0s), auto expiration leeway (150s)
|
||||
"auto expire leeway using nbf with no clock_skew_leeway": {true, time.Time{}, time.Now().Add(-145 * time.Second), time.Time{}, -1, 0},
|
||||
"expired auto expire leeway using nbf with no clock_skew_leeway": {false, time.Time{}, time.Now().Add(-155 * time.Second), time.Time{}, -1, 0},
|
||||
|
||||
// iat, auto clock_skew_leeway (60s), custom expiration leeway (10s)
|
||||
"custom expire leeway using iat with clock_skew_leeway": {true, time.Now().Add(-65 * time.Second), time.Time{}, time.Time{}, 0, 10},
|
||||
"expired custom expire leeway using iat with clock_skew_leeway": {false, time.Now().Add(-75 * time.Second), time.Time{}, time.Time{}, 0, 10},
|
||||
|
||||
// iat, clock_skew_leeway (10s), custom expiration leeway (10s)
|
||||
"custom expire leeway using iat with clock_skew_leeway with default leeway": {true, time.Now().Add(-5 * time.Second), time.Time{}, time.Time{}, 10, 10},
|
||||
"expired custom expire leeway using iat with clock_skew_leeway with default leeway": {false, time.Now().Add(-25 * time.Second), time.Time{}, time.Time{}, 10, 10},
|
||||
|
||||
// iat, clock_skew_leeway (10s), no expiration leeway (10s)
|
||||
"no expire leeway using iat with clock_skew_leeway": {true, time.Now().Add(-5 * time.Second), time.Time{}, time.Time{}, 10, -1},
|
||||
"expired no expire leeway using iat with clock_skew_leeway": {false, time.Now().Add(-15 * time.Second), time.Time{}, time.Time{}, 10, -1},
|
||||
|
||||
// nbf, default clock_skew_leeway (60s), custom expiration leeway (10s)
|
||||
"custom expire leeway using nbf with clock_skew_leeway": {true, time.Time{}, time.Now().Add(-65 * time.Second), time.Time{}, 0, 10},
|
||||
"expired custom expire leeway using nbf with clock_skew_leeway": {false, time.Time{}, time.Now().Add(-75 * time.Second), time.Time{}, 0, 10},
|
||||
|
||||
// nbf, clock_skew_leeway (10s), custom expiration leeway (0s)
|
||||
"custom expire leeway using nbf with clock_skew_leeway with default leeway": {true, time.Time{}, time.Now().Add(-5 * time.Second), time.Time{}, 10, 10},
|
||||
"expired custom expire leeway using nbf with clock_skew_leeway with default leeway": {false, time.Time{}, time.Now().Add(-25 * time.Second), time.Time{}, 10, 10},
|
||||
|
||||
// nbf, clock_skew_leeway (10s), no expiration leeway (0s)
|
||||
"no expire leeway using nbf with clock_skew_leeway with default leeway": {true, time.Time{}, time.Now().Add(-5 * time.Second), time.Time{}, 10, -1},
|
||||
"no expire leeway using nbf with clock_skew_leeway with default leeway and nbf": {true, time.Time{}, time.Now().Add(-5 * time.Second), time.Time{}, 10, -100},
|
||||
"expired no expire leeway using nbf with clock_skew_leeway": {false, time.Time{}, time.Now().Add(-15 * time.Second), time.Time{}, 10, -1},
|
||||
"expired no expire leeway using nbf with clock_skew_leeway with default leeway and nbf": {false, time.Time{}, time.Now().Add(-15 * time.Second), time.Time{}, 10, -100},
|
||||
}
|
||||
|
||||
for name, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
oa, issuer := setupForJWT(t, authType, func(c *Config) {
|
||||
c.BoundAudiences = []string{
|
||||
"https://go-sso.test",
|
||||
"another_audience",
|
||||
}
|
||||
c.ClockSkewLeeway = time.Duration(tt.DefaultLeeway) * time.Second
|
||||
c.ExpirationLeeway = time.Duration(tt.ExpLeeway) * time.Second
|
||||
c.NotBeforeLeeway = 0
|
||||
})
|
||||
|
||||
jwtData := setupLogin(t, tt.IssuedAt, tt.Expiration, tt.NotBefore, issuer)
|
||||
|
||||
_, err := oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
if tt.Valid {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWT_ClaimsFromJWT_NotBeforeClaims(t *testing.T) {
|
||||
t.Run("static", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testJWT_ClaimsFromJWT_NotBeforeClaims(t, authStaticKeys)
|
||||
})
|
||||
t.Run("JWKS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testJWT_ClaimsFromJWT_NotBeforeClaims(t, authJWKS)
|
||||
})
|
||||
// TODO(sso): the vault versions of these tests did not run oidc-discovery
|
||||
// t.Run("oidc discovery", func(t *testing.T) {
|
||||
// t.Parallel()
|
||||
// testJWT_ClaimsFromJWT_NotBeforeClaims(t, authOIDCDiscovery)
|
||||
// })
|
||||
}
|
||||
|
||||
func testJWT_ClaimsFromJWT_NotBeforeClaims(t *testing.T, authType int) {
|
||||
t.Helper()
|
||||
|
||||
tests := map[string]struct {
|
||||
Valid bool
|
||||
IssuedAt time.Time
|
||||
NotBefore time.Time
|
||||
Expiration time.Time
|
||||
DefaultLeeway int
|
||||
NBFLeeway int
|
||||
}{
|
||||
// iat, auto clock_skew_leeway (60s), no nbf leeway (0)
|
||||
"no nbf leeway using iat with auto clock_skew_leeway": {true, time.Now().Add(55 * time.Second), time.Time{}, time.Now(), 0, -1},
|
||||
"not yet valid no nbf leeway using iat with auto clock_skew_leeway": {false, time.Now().Add(65 * time.Second), time.Time{}, time.Now(), 0, -1},
|
||||
|
||||
// iat, clock_skew_leeway (10s), no nbf leeway (0s)
|
||||
"no nbf leeway using iat with custom clock_skew_leeway": {true, time.Now().Add(5 * time.Second), time.Time{}, time.Time{}, 10, -1},
|
||||
"not yet valid no nbf leeway using iat with custom clock_skew_leeway": {false, time.Now().Add(15 * time.Second), time.Time{}, time.Time{}, 10, -1},
|
||||
|
||||
// iat, no clock_skew_leeway (0s), nbf leeway (5s)
|
||||
"nbf leeway using iat with no clock_skew_leeway": {true, time.Now(), time.Time{}, time.Time{}, -1, 5},
|
||||
"not yet valid nbf leeway using iat with no clock_skew_leeway": {false, time.Now().Add(6 * time.Second), time.Time{}, time.Time{}, -1, 5},
|
||||
|
||||
// exp, auto clock_skew_leeway (60s), auto nbf leeway (150s)
|
||||
"auto nbf leeway using exp with auto clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(205 * time.Second), 0, 0},
|
||||
"not yet valid auto nbf leeway using exp with auto clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(215 * time.Second), 0, 0},
|
||||
|
||||
// exp, clock_skew_leeway (10s), auto nbf leeway (150s)
|
||||
"auto nbf leeway using exp with custom clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(150 * time.Second), 10, 0},
|
||||
"not yet valid auto nbf leeway using exp with custom clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(165 * time.Second), 10, 0},
|
||||
|
||||
// exp, no clock_skew_leeway (0s), auto nbf leeway (150s)
|
||||
"auto nbf leeway using exp with no clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(145 * time.Second), -1, 0},
|
||||
"not yet valid auto nbf leeway using exp with no clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(152 * time.Second), -1, 0},
|
||||
|
||||
// exp, auto clock_skew_leeway (60s), custom nbf leeway (10s)
|
||||
"custom nbf leeway using exp with auto clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(65 * time.Second), 0, 10},
|
||||
"not yet valid custom nbf leeway using exp with auto clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(75 * time.Second), 0, 10},
|
||||
|
||||
// exp, clock_skew_leeway (10s), custom nbf leeway (10s)
|
||||
"custom nbf leeway using exp with custom clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(15 * time.Second), 10, 10},
|
||||
"not yet valid custom nbf leeway using exp with custom clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(25 * time.Second), 10, 10},
|
||||
|
||||
// exp, no clock_skew_leeway (0s), custom nbf leeway (5s)
|
||||
"custom nbf leeway using exp with no clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(3 * time.Second), -1, 5},
|
||||
"custom nbf leeway using exp with no clock_skew_leeway with default leeway": {true, time.Time{}, time.Time{}, time.Now().Add(3 * time.Second), -100, 5},
|
||||
"not yet valid custom nbf leeway using exp with no clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(7 * time.Second), -1, 5},
|
||||
"not yet valid custom nbf leeway using exp with no clock_skew_leeway with default leeway": {false, time.Time{}, time.Time{}, time.Now().Add(7 * time.Second), -100, 5},
|
||||
}
|
||||
|
||||
for name, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oa, issuer := setupForJWT(t, authType, func(c *Config) {
|
||||
c.BoundAudiences = []string{
|
||||
"https://go-sso.test",
|
||||
"another_audience",
|
||||
}
|
||||
c.ClockSkewLeeway = time.Duration(tt.DefaultLeeway) * time.Second
|
||||
c.ExpirationLeeway = 0
|
||||
c.NotBeforeLeeway = time.Duration(tt.NBFLeeway) * time.Second
|
||||
})
|
||||
|
||||
jwtData := setupLogin(t, tt.IssuedAt, tt.Expiration, tt.NotBefore, issuer)
|
||||
|
||||
_, err := oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
if tt.Valid {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupLogin(t *testing.T, iat, exp, nbf time.Time, issuer string) string {
|
||||
cl := jwt.Claims{
|
||||
Audience: jwt.Audience{"https://go-sso.test"},
|
||||
Issuer: issuer,
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
IssuedAt: jwt.NewNumericDate(iat),
|
||||
Expiry: jwt.NewNumericDate(exp),
|
||||
NotBefore: jwt.NewNumericDate(nbf),
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
Color string `json:"color"`
|
||||
}{
|
||||
"foobar",
|
||||
[]string{"foo", "bar"},
|
||||
"green",
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
return jwtData
|
||||
}
|
||||
|
||||
func TestParsePublicKeyPEM(t *testing.T) {
|
||||
getPublicPEM := func(t *testing.T, pub interface{}) string {
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(pub)
|
||||
require.NoError(t, err)
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: derBytes,
|
||||
}
|
||||
return string(pem.EncodeToMemory(pemBlock))
|
||||
}
|
||||
|
||||
t.Run("rsa", func(t *testing.T) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
pub := privateKey.Public()
|
||||
pubPEM := getPublicPEM(t, pub)
|
||||
|
||||
got, err := parsePublicKeyPEM([]byte(pubPEM))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pub, got)
|
||||
})
|
||||
|
||||
t.Run("ecdsa", func(t *testing.T) {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pub := privateKey.Public()
|
||||
pubPEM := getPublicPEM(t, pub)
|
||||
|
||||
got, err := parsePublicKeyPEM([]byte(pubPEM))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pub, got)
|
||||
})
|
||||
|
||||
t.Run("ed25519", func(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubPEM := getPublicPEM(t, pub)
|
||||
|
||||
got, err := parsePublicKeyPEM([]byte(pubPEM))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pub, got)
|
||||
})
|
||||
}
|
||||
|
||||
const (
|
||||
badPrivKey string = `-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEILTAHJm+clBKYCrRDc74Pt7uF7kH+2x2TdL5cH23FEcsoAoGCCqGSM49
|
||||
AwEHoUQDQgAE+C3CyjVWdeYtIqgluFJlwZmoonphsQbj9Nfo5wrEutv+3RTFnDQh
|
||||
vttUajcFAcl4beR+jHFYC00vSO4i5jZ64g==
|
||||
-----END EC PRIVATE KEY-----`
|
||||
)
|
|
@ -0,0 +1,282 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var (
|
||||
oidcStateTimeout = 10 * time.Minute
|
||||
oidcStateCleanupInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
// GetAuthCodeURL is the first part of the OIDC authorization code workflow.
|
||||
// The statePayload field is stored in the Authenticator instance keyed by the
|
||||
// "state" key so it can be returned during a future call to
|
||||
// ClaimsFromAuthCode.
|
||||
//
|
||||
// Requires the authenticator's config type be set to 'oidc'.
|
||||
func (a *Authenticator) GetAuthCodeURL(ctx context.Context, redirectURI string, statePayload interface{}) (string, error) {
|
||||
if a.config.authType() != authOIDCFlow {
|
||||
return "", fmt.Errorf("GetAuthCodeURL is incompatible with type %q", TypeJWT)
|
||||
}
|
||||
if redirectURI == "" {
|
||||
return "", errors.New("missing redirect_uri")
|
||||
}
|
||||
|
||||
if !validRedirect(redirectURI, a.config.AllowedRedirectURIs) {
|
||||
return "", fmt.Errorf("unauthorized redirect_uri: %s", redirectURI)
|
||||
}
|
||||
|
||||
// "openid" is a required scope for OpenID Connect flows
|
||||
scopes := append([]string{oidc.ScopeOpenID}, a.config.OIDCScopes...)
|
||||
|
||||
// Configure an OpenID Connect aware OAuth2 client
|
||||
oauth2Config := oauth2.Config{
|
||||
ClientID: a.config.OIDCClientID,
|
||||
ClientSecret: a.config.OIDCClientSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Endpoint: a.provider.Endpoint(),
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
stateID, nonce, err := a.createOIDCState(redirectURI, statePayload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error generating OAuth state: %v", err)
|
||||
}
|
||||
|
||||
authCodeOpts := []oauth2.AuthCodeOption{
|
||||
oidc.Nonce(nonce),
|
||||
}
|
||||
|
||||
return oauth2Config.AuthCodeURL(stateID, authCodeOpts...), nil
|
||||
}
|
||||
|
||||
// ClaimsFromAuthCode is the second part of the OIDC authorization code
|
||||
// workflow. The interface{} return value is the statePayload previously passed
|
||||
// via GetAuthCodeURL.
|
||||
//
|
||||
// The error may be of type *ProviderLoginFailedError or
|
||||
// *TokenVerificationFailedError which can be detected via errors.As().
|
||||
//
|
||||
// Requires the authenticator's config type be set to 'oidc'.
|
||||
func (a *Authenticator) ClaimsFromAuthCode(ctx context.Context, stateParam, code string) (*Claims, interface{}, error) {
|
||||
if a.config.authType() != authOIDCFlow {
|
||||
return nil, nil, fmt.Errorf("ClaimsFromAuthCode is incompatible with type %q", TypeJWT)
|
||||
}
|
||||
|
||||
// TODO(sso): this could be because we ACTUALLY are getting OIDC error responses and
|
||||
// should handle them elsewhere!
|
||||
if code == "" {
|
||||
return nil, nil, &ProviderLoginFailedError{
|
||||
Err: fmt.Errorf("OAuth code parameter not provided"),
|
||||
}
|
||||
}
|
||||
|
||||
state := a.verifyOIDCState(stateParam)
|
||||
if state == nil {
|
||||
return nil, nil, &ProviderLoginFailedError{
|
||||
Err: fmt.Errorf("Expired or missing OAuth state."),
|
||||
}
|
||||
}
|
||||
|
||||
oidcCtx := contextWithHttpClient(ctx, a.httpClient)
|
||||
|
||||
var oauth2Config = oauth2.Config{
|
||||
ClientID: a.config.OIDCClientID,
|
||||
ClientSecret: a.config.OIDCClientSecret,
|
||||
RedirectURL: state.redirectURI,
|
||||
Endpoint: a.provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID},
|
||||
}
|
||||
|
||||
oauth2Token, err := oauth2Config.Exchange(oidcCtx, code)
|
||||
if err != nil {
|
||||
return nil, nil, &ProviderLoginFailedError{
|
||||
Err: fmt.Errorf("Error exchanging oidc code: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the ID Token from OAuth2 token.
|
||||
rawToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, nil, &TokenVerificationFailedError{
|
||||
Err: errors.New("No id_token found in response."),
|
||||
}
|
||||
}
|
||||
|
||||
if a.config.VerboseOIDCLogging && a.logger != nil {
|
||||
a.logger.Debug("OIDC provider response", "ID token", rawToken)
|
||||
}
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
allClaims, err := a.verifyOIDCToken(ctx, rawToken) // TODO(sso): should this use oidcCtx?
|
||||
if err != nil {
|
||||
return nil, nil, &TokenVerificationFailedError{
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if allClaims["nonce"] != state.nonce { // TODO(sso): does this need a cast?
|
||||
return nil, nil, &TokenVerificationFailedError{
|
||||
Err: errors.New("Invalid ID token nonce."),
|
||||
}
|
||||
}
|
||||
delete(allClaims, "nonce")
|
||||
|
||||
// Attempt to fetch information from the /userinfo endpoint and merge it with
|
||||
// the existing claims data. A failure to fetch additional information from this
|
||||
// endpoint will not invalidate the authorization flow.
|
||||
if userinfo, err := a.provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil {
|
||||
_ = userinfo.Claims(&allClaims)
|
||||
} else {
|
||||
if a.logger != nil {
|
||||
logFunc := a.logger.Warn
|
||||
if strings.Contains(err.Error(), "user info endpoint is not supported") {
|
||||
logFunc = a.logger.Info
|
||||
}
|
||||
logFunc("error reading /userinfo endpoint", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if a.config.VerboseOIDCLogging && a.logger != nil {
|
||||
if c, err := json.Marshal(allClaims); err == nil {
|
||||
a.logger.Debug("OIDC provider response", "claims", string(c))
|
||||
} else {
|
||||
a.logger.Debug("OIDC provider response", "marshalling error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
c, err := a.extractClaims(allClaims)
|
||||
if err != nil {
|
||||
return nil, nil, &TokenVerificationFailedError{
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if a.config.VerboseOIDCLogging && a.logger != nil {
|
||||
a.logger.Debug("OIDC provider response", "extracted_claims", c)
|
||||
}
|
||||
|
||||
return c, state.payload, nil
|
||||
}
|
||||
|
||||
// ProviderLoginFailedError is an error type sometimes returned from
|
||||
// ClaimsFromAuthCode().
|
||||
//
|
||||
// It represents a failure to complete the authorization code workflow with the
|
||||
// provider such as losing important OIDC parameters or a failure to fetch an
|
||||
// id_token.
|
||||
//
|
||||
// You can check for it with errors.As().
|
||||
type ProviderLoginFailedError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ProviderLoginFailedError) Error() string {
|
||||
return fmt.Sprintf("Provider login failed: %v", e.Err)
|
||||
}
|
||||
|
||||
func (e *ProviderLoginFailedError) Unwrap() error { return e.Err }
|
||||
|
||||
// TokenVerificationFailedError is an error type sometimes returned from
|
||||
// ClaimsFromAuthCode().
|
||||
//
|
||||
// It represents a failure to vet the returned OIDC credentials for validity
|
||||
// such as the id_token not passing verification or using an mismatched nonce.
|
||||
//
|
||||
// You can check for it with errors.As().
|
||||
type TokenVerificationFailedError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *TokenVerificationFailedError) Error() string {
|
||||
return fmt.Sprintf("Token verification failed: %v", e.Err)
|
||||
}
|
||||
|
||||
func (e *TokenVerificationFailedError) Unwrap() error { return e.Err }
|
||||
|
||||
func (a *Authenticator) verifyOIDCToken(ctx context.Context, rawToken string) (map[string]interface{}, error) {
|
||||
allClaims := make(map[string]interface{})
|
||||
|
||||
oidcConfig := &oidc.Config{
|
||||
SupportedSigningAlgs: a.config.JWTSupportedAlgs,
|
||||
}
|
||||
switch a.config.authType() {
|
||||
case authOIDCFlow:
|
||||
oidcConfig.ClientID = a.config.OIDCClientID
|
||||
case authOIDCDiscovery:
|
||||
oidcConfig.SkipClientIDCheck = true
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported auth type for this verifyOIDCToken: %d", a.config.authType())
|
||||
}
|
||||
|
||||
verifier := a.provider.Verifier(oidcConfig)
|
||||
|
||||
idToken, err := verifier.Verify(ctx, rawToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error validating signature: %v", err)
|
||||
}
|
||||
|
||||
if err := idToken.Claims(&allClaims); err != nil {
|
||||
return nil, fmt.Errorf("unable to successfully parse all claims from token: %v", err)
|
||||
}
|
||||
// TODO(sso): why isn't this strict for OIDC?
|
||||
if err := validateAudience(a.config.BoundAudiences, idToken.Audience, false); err != nil {
|
||||
return nil, fmt.Errorf("error validating claims: %v", err)
|
||||
}
|
||||
|
||||
return allClaims, nil
|
||||
}
|
||||
|
||||
// verifyOIDCState tests whether the provided state ID is valid and returns the
|
||||
// associated state object if so. A nil state is returned if the ID is not found
|
||||
// or expired. The state should only ever be retrieved once and is deleted as
|
||||
// part of this request.
|
||||
func (a *Authenticator) verifyOIDCState(stateID string) *oidcState {
|
||||
defer a.oidcStates.Delete(stateID)
|
||||
|
||||
if stateRaw, ok := a.oidcStates.Get(stateID); ok {
|
||||
return stateRaw.(*oidcState)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createOIDCState make an expiring state object, associated with a random state ID
|
||||
// that is passed throughout the OAuth process. A nonce is also included in the
|
||||
// auth process, and for simplicity will be identical in length/format as the state ID.
|
||||
func (a *Authenticator) createOIDCState(redirectURI string, payload interface{}) (string, string, error) {
|
||||
// Get enough bytes for 2 160-bit IDs (per rfc6749#section-10.10)
|
||||
bytes, err := uuid.GenerateRandomBytes(2 * 20)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
stateID := fmt.Sprintf("%x", bytes[:20])
|
||||
nonce := fmt.Sprintf("%x", bytes[20:])
|
||||
|
||||
a.oidcStates.SetDefault(stateID, &oidcState{
|
||||
nonce: nonce,
|
||||
redirectURI: redirectURI,
|
||||
payload: payload,
|
||||
})
|
||||
|
||||
return stateID, nonce, nil
|
||||
}
|
||||
|
||||
// oidcState is created when an authURL is requested. The state
|
||||
// identifier is passed throughout the OAuth process.
|
||||
type oidcState struct {
|
||||
nonce string
|
||||
redirectURI string
|
||||
payload interface{}
|
||||
}
|
|
@ -0,0 +1,507 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func setupForOIDC(t *testing.T) (*Authenticator, *oidcauthtest.Server) {
|
||||
t.Helper()
|
||||
|
||||
srv := oidcauthtest.Start(t)
|
||||
srv.SetClientCreds("abc", "def")
|
||||
|
||||
config := &Config{
|
||||
Type: TypeOIDC,
|
||||
OIDCDiscoveryURL: srv.Addr(),
|
||||
OIDCDiscoveryCACert: srv.CACert(),
|
||||
OIDCClientID: "abc",
|
||||
OIDCClientSecret: "def",
|
||||
JWTSupportedAlgs: []string{"ES256"},
|
||||
BoundAudiences: []string{"abc"},
|
||||
AllowedRedirectURIs: []string{"https://example.com"},
|
||||
ClaimMappings: map[string]string{
|
||||
"COLOR": "color",
|
||||
"/nested/Size": "size",
|
||||
"Age": "age",
|
||||
"Admin": "is_admin",
|
||||
"/nested/division": "division",
|
||||
"/nested/remote": "is_remote",
|
||||
"flavor": "flavor", // userinfo
|
||||
},
|
||||
ListClaimMappings: map[string]string{
|
||||
"/nested/Groups": "groups",
|
||||
},
|
||||
}
|
||||
require.NoError(t, config.Validate())
|
||||
|
||||
oa, err := New(config, hclog.NewNullLogger())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(oa.Stop)
|
||||
|
||||
return oa, srv
|
||||
}
|
||||
|
||||
func TestOIDC_AuthURL(t *testing.T) {
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oa, _ := setupForOIDC(t)
|
||||
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
map[string]string{"foo": "bar"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, strings.HasPrefix(authURL, oa.config.OIDCDiscoveryURL+"/auth?"))
|
||||
|
||||
expected := map[string]string{
|
||||
"client_id": "abc",
|
||||
"redirect_uri": "https://example.com",
|
||||
"response_type": "code",
|
||||
"scope": "openid",
|
||||
}
|
||||
|
||||
au, err := url.Parse(authURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
for k, v := range expected {
|
||||
assert.Equal(t, v, au.Query().Get(k), "key %q is incorrect", k)
|
||||
}
|
||||
|
||||
assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("nonce"))
|
||||
assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("state"))
|
||||
|
||||
})
|
||||
|
||||
t.Run("invalid RedirectURI", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oa, _ := setupForOIDC(t)
|
||||
|
||||
_, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"http://bitc0in-4-less.cx",
|
||||
map[string]string{"foo": "bar"},
|
||||
)
|
||||
requireErrorContains(t, err, "unauthorized redirect_uri: http://bitc0in-4-less.cx")
|
||||
})
|
||||
|
||||
t.Run("missing RedirectURI", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oa, _ := setupForOIDC(t)
|
||||
|
||||
_, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"",
|
||||
map[string]string{"foo": "bar"},
|
||||
)
|
||||
requireErrorContains(t, err, "missing redirect_uri")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOIDC_JWT_Functions_Fail(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
Issuer: srv.Addr(),
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Audience: jwt.Audience{"https://go-sso.test"},
|
||||
}
|
||||
|
||||
privateCl := struct {
|
||||
User string `json:"https://go-sso/user"`
|
||||
Groups []string `json:"https://go-sso/groups"`
|
||||
}{
|
||||
"jeff",
|
||||
[]string{"foo", "bar"},
|
||||
}
|
||||
|
||||
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
|
||||
requireErrorContains(t, err, `ClaimsFromJWT is incompatible with type "oidc"`)
|
||||
}
|
||||
|
||||
func TestOIDC_ClaimsFromAuthCode(t *testing.T) {
|
||||
requireProviderError := func(t *testing.T, err error) {
|
||||
var provErr *ProviderLoginFailedError
|
||||
if !errors.As(err, &provErr) {
|
||||
t.Fatalf("error was not a *ProviderLoginFailedError")
|
||||
}
|
||||
}
|
||||
requireTokenVerificationError := func(t *testing.T, err error) {
|
||||
var tokErr *TokenVerificationFailedError
|
||||
if !errors.As(err, &tokErr) {
|
||||
t.Fatalf("error was not a *TokenVerificationFailedError")
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("successful login", func(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
nonce := getQueryParam(t, authURL, "nonce")
|
||||
|
||||
// set provider claims that will be returned by the mock server
|
||||
srv.SetCustomClaims(sampleClaims(nonce))
|
||||
|
||||
// set mock provider's expected code
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
claims, payload, err := oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "abc",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, origPayload, payload)
|
||||
|
||||
expectedClaims := &Claims{
|
||||
Values: map[string]string{
|
||||
"color": "green",
|
||||
"size": "medium",
|
||||
"age": "85",
|
||||
"is_admin": "true",
|
||||
"division": "3",
|
||||
"is_remote": "true",
|
||||
"flavor": "umami", // from userinfo
|
||||
},
|
||||
Lists: map[string][]string{
|
||||
"groups": []string{"a", "b"},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, expectedClaims, claims)
|
||||
})
|
||||
|
||||
t.Run("failed login unusable claims", func(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
nonce := getQueryParam(t, authURL, "nonce")
|
||||
|
||||
// set provider claims that will be returned by the mock server
|
||||
customClaims := sampleClaims(nonce)
|
||||
customClaims["COLOR"] = []interface{}{"yellow"}
|
||||
srv.SetCustomClaims(customClaims)
|
||||
|
||||
// set mock provider's expected code
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "abc",
|
||||
)
|
||||
requireErrorContains(t, err, "error converting claim 'COLOR' to string from unknown type []interface {}")
|
||||
requireTokenVerificationError(t, err)
|
||||
})
|
||||
|
||||
t.Run("successful login - no userinfo", func(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
srv.DisableUserInfo()
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
nonce := getQueryParam(t, authURL, "nonce")
|
||||
|
||||
// set provider claims that will be returned by the mock server
|
||||
srv.SetCustomClaims(sampleClaims(nonce))
|
||||
|
||||
// set mock provider's expected code
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
claims, payload, err := oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "abc",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, origPayload, payload)
|
||||
|
||||
expectedClaims := &Claims{
|
||||
Values: map[string]string{
|
||||
"color": "green",
|
||||
"size": "medium",
|
||||
"age": "85",
|
||||
"is_admin": "true",
|
||||
"division": "3",
|
||||
"is_remote": "true",
|
||||
// "flavor": "umami", // from userinfo
|
||||
},
|
||||
Lists: map[string][]string{
|
||||
"groups": []string{"a", "b"},
|
||||
},
|
||||
}
|
||||
|
||||
require.Equal(t, expectedClaims, claims)
|
||||
})
|
||||
|
||||
t.Run("failed login - bad nonce", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
|
||||
srv.SetCustomClaims(sampleClaims("bad nonce"))
|
||||
|
||||
// set mock provider's expected code
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "abc",
|
||||
)
|
||||
requireErrorContains(t, err, "Invalid ID token nonce")
|
||||
requireTokenVerificationError(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing state", func(t *testing.T) {
|
||||
oa, _ := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
_, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
"", "abc",
|
||||
)
|
||||
requireErrorContains(t, err, "Expired or missing OAuth state")
|
||||
requireProviderError(t, err)
|
||||
})
|
||||
|
||||
t.Run("unknown state", func(t *testing.T) {
|
||||
oa, _ := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
_, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
"not_a_state", "abc",
|
||||
)
|
||||
requireErrorContains(t, err, "Expired or missing OAuth state")
|
||||
requireProviderError(t, err)
|
||||
})
|
||||
|
||||
t.Run("valid state, missing code", func(t *testing.T) {
|
||||
oa, _ := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "",
|
||||
)
|
||||
requireErrorContains(t, err, "OAuth code parameter not provided")
|
||||
requireProviderError(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed code exchange", func(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
|
||||
// set mock provider's expected code
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "wrong_code",
|
||||
)
|
||||
requireErrorContains(t, err, "cannot fetch token")
|
||||
requireProviderError(t, err)
|
||||
})
|
||||
|
||||
t.Run("no id_token returned", func(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
nonce := getQueryParam(t, authURL, "nonce")
|
||||
|
||||
// set provider claims that will be returned by the mock server
|
||||
srv.SetCustomClaims(sampleClaims(nonce))
|
||||
|
||||
// set mock provider's expected code
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
srv.OmitIDTokens()
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "abc",
|
||||
)
|
||||
requireErrorContains(t, err, "No id_token found in response")
|
||||
requireTokenVerificationError(t, err)
|
||||
})
|
||||
|
||||
t.Run("no response from provider", func(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
|
||||
// close the server prematurely
|
||||
srv.Stop()
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "abc",
|
||||
)
|
||||
requireErrorContains(t, err, "connection refused")
|
||||
requireProviderError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid bound audience", func(t *testing.T) {
|
||||
oa, srv := setupForOIDC(t)
|
||||
|
||||
srv.SetClientCreds("not_gonna_match", "def")
|
||||
|
||||
origPayload := map[string]string{"foo": "bar"}
|
||||
authURL, err := oa.GetAuthCodeURL(
|
||||
context.Background(),
|
||||
"https://example.com",
|
||||
origPayload,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
state := getQueryParam(t, authURL, "state")
|
||||
nonce := getQueryParam(t, authURL, "nonce")
|
||||
|
||||
// set provider claims that will be returned by the mock server
|
||||
srv.SetCustomClaims(sampleClaims(nonce))
|
||||
|
||||
// set mock provider's expected code
|
||||
srv.SetExpectedAuthCode("abc")
|
||||
|
||||
_, _, err = oa.ClaimsFromAuthCode(
|
||||
context.Background(),
|
||||
state, "abc",
|
||||
)
|
||||
requireErrorContains(t, err, `error validating signature: oidc: expected audience "abc" got ["not_gonna_match"]`)
|
||||
requireTokenVerificationError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func sampleClaims(nonce string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"nonce": nonce,
|
||||
"email": "bob@example.com",
|
||||
"COLOR": "green",
|
||||
"sk": "42",
|
||||
"Age": 85,
|
||||
"Admin": true,
|
||||
"nested": map[string]interface{}{
|
||||
"Size": "medium",
|
||||
"division": 3,
|
||||
"remote": true,
|
||||
"Groups": []string{"a", "b"},
|
||||
"secret_code": "bar",
|
||||
},
|
||||
"password": "foo",
|
||||
}
|
||||
}
|
||||
|
||||
func getQueryParam(t *testing.T, inputURL, param string) string {
|
||||
t.Helper()
|
||||
|
||||
m, err := url.ParseQuery(inputURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v, ok := m[param]
|
||||
if !ok {
|
||||
t.Fatalf("query param %q not found", param)
|
||||
}
|
||||
return v[0]
|
||||
}
|
|
@ -0,0 +1,529 @@
|
|||
// package oidcauthtest exposes tools to assist in writing unit tests of OIDC
|
||||
// and JWT authentication workflows.
|
||||
//
|
||||
// When the package is loaded it will randomly generate an ECDSA signing
|
||||
// keypair used to sign JWTs both via the Server and the SignJWT method.
|
||||
package oidcauthtest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
|
||||
"github.com/mitchellh/go-testing-interface"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// Server is local server the mocks the endpoints used by the OIDC and
|
||||
// JWKS process.
|
||||
type Server struct {
|
||||
httpServer *httptest.Server
|
||||
caCert string
|
||||
returnFunc func()
|
||||
|
||||
jwks *jose.JSONWebKeySet
|
||||
allowedRedirectURIs []string
|
||||
replySubject string
|
||||
replyUserinfo map[string]interface{}
|
||||
|
||||
mu sync.Mutex
|
||||
clientID string
|
||||
clientSecret string
|
||||
expectedAuthCode string
|
||||
expectedAuthNonce string
|
||||
customClaims map[string]interface{}
|
||||
customAudience string
|
||||
omitIDToken bool
|
||||
disableUserInfo bool
|
||||
}
|
||||
|
||||
type startOption struct {
|
||||
port int
|
||||
returnFunc func()
|
||||
}
|
||||
|
||||
// WithPort is a option for Start that lets the caller control the port
|
||||
// allocation. The returnFunc parameter is used when the provider is stopped to
|
||||
// return the port in whatever bookkeeping system the caller wants to use.
|
||||
func WithPort(port int, returnFunc func()) startOption {
|
||||
return startOption{
|
||||
port: port,
|
||||
returnFunc: returnFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// Start creates a disposable Server. If the port provided is
|
||||
// zero it will bind to a random free port, otherwise the provided port is
|
||||
// used.
|
||||
func Start(t testing.T, options ...startOption) *Server {
|
||||
s := &Server{
|
||||
allowedRedirectURIs: []string{
|
||||
"https://example.com",
|
||||
},
|
||||
replySubject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
||||
replyUserinfo: map[string]interface{}{
|
||||
"color": "red",
|
||||
"temperature": "76",
|
||||
"flavor": "umami",
|
||||
},
|
||||
}
|
||||
|
||||
jwks, err := newJWKS(ecdsaPublicKey)
|
||||
require.NoError(t, err)
|
||||
s.jwks = jwks
|
||||
|
||||
var (
|
||||
port int
|
||||
returnFunc func()
|
||||
)
|
||||
for _, option := range options {
|
||||
if option.port > 0 {
|
||||
port = option.port
|
||||
returnFunc = option.returnFunc
|
||||
}
|
||||
}
|
||||
|
||||
s.httpServer = httptestNewUnstartedServerWithPort(s, port)
|
||||
s.httpServer.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
|
||||
s.httpServer.StartTLS()
|
||||
if returnFunc != nil {
|
||||
t.Cleanup(returnFunc)
|
||||
}
|
||||
t.Cleanup(s.httpServer.Close)
|
||||
|
||||
cert := s.httpServer.Certificate()
|
||||
|
||||
var buf bytes.Buffer
|
||||
require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
|
||||
s.caCert = buf.String()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetClientCreds is for configuring the client information required for the
|
||||
// OIDC workflows.
|
||||
func (s *Server) SetClientCreds(clientID, clientSecret string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.clientID = clientID
|
||||
s.clientSecret = clientSecret
|
||||
}
|
||||
|
||||
// SetExpectedAuthCode configures the auth code to return from /auth and the
|
||||
// allowed auth code for /token.
|
||||
func (s *Server) SetExpectedAuthCode(code string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.expectedAuthCode = code
|
||||
}
|
||||
|
||||
// SetExpectedAuthNonce configures the nonce value required for /auth.
|
||||
func (s *Server) SetExpectedAuthNonce(nonce string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.expectedAuthNonce = nonce
|
||||
}
|
||||
|
||||
// SetAllowedRedirectURIs allows you to configure the allowed redirect URIs for
|
||||
// the OIDC workflow. If not configured a sample of "https://example.com" is
|
||||
// used.
|
||||
func (s *Server) SetAllowedRedirectURIs(uris []string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowedRedirectURIs = uris
|
||||
}
|
||||
|
||||
// SetCustomClaims lets you set claims to return in the JWT issued by the OIDC
|
||||
// workflow.
|
||||
func (s *Server) SetCustomClaims(customClaims map[string]interface{}) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.customClaims = customClaims
|
||||
}
|
||||
|
||||
// SetCustomAudience configures what audience value to embed in the JWT issued
|
||||
// by the OIDC workflow.
|
||||
func (s *Server) SetCustomAudience(customAudience string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.customAudience = customAudience
|
||||
}
|
||||
|
||||
// OmitIDTokens forces an error state where the /token endpoint does not return
|
||||
// id_token.
|
||||
func (s *Server) OmitIDTokens() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.omitIDToken = true
|
||||
}
|
||||
|
||||
// DisableUserInfo makes the userinfo endpoint return 404 and omits it from the
|
||||
// discovery config.
|
||||
func (s *Server) DisableUserInfo() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.disableUserInfo = true
|
||||
}
|
||||
|
||||
// Stop stops the running Server.
|
||||
func (s *Server) Stop() {
|
||||
s.httpServer.Close()
|
||||
}
|
||||
|
||||
// Addr returns the current base URL for the running webserver.
|
||||
func (s *Server) Addr() string { return s.httpServer.URL }
|
||||
|
||||
// CACert returns the pem-encoded CA certificate used by the HTTPS server.
|
||||
func (s *Server) CACert() string { return s.caCert }
|
||||
|
||||
// SigningKeys returns the pem-encoded keys used to sign JWTs.
|
||||
func (s *Server) SigningKeys() (pub, priv string) {
|
||||
return SigningKeys()
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch req.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
reply := struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
JWKSURI string `json:"jwks_uri"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||
}{
|
||||
Issuer: s.Addr(),
|
||||
AuthEndpoint: s.Addr() + "/auth",
|
||||
TokenEndpoint: s.Addr() + "/token",
|
||||
JWKSURI: s.Addr() + "/certs",
|
||||
UserinfoEndpoint: s.Addr() + "/userinfo",
|
||||
}
|
||||
if s.disableUserInfo {
|
||||
reply.UserinfoEndpoint = ""
|
||||
}
|
||||
|
||||
if err := writeJSON(w, &reply); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
case "/auth":
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
qv := req.URL.Query()
|
||||
|
||||
if qv.Get("response_type") != "code" {
|
||||
writeAuthErrorResponse(w, req, "unsupported_response_type", "")
|
||||
return
|
||||
}
|
||||
if qv.Get("scope") != "openid" {
|
||||
writeAuthErrorResponse(w, req, "invalid_scope", "")
|
||||
return
|
||||
}
|
||||
|
||||
if s.expectedAuthCode == "" {
|
||||
writeAuthErrorResponse(w, req, "access_denied", "")
|
||||
return
|
||||
}
|
||||
|
||||
nonce := qv.Get("nonce")
|
||||
if s.expectedAuthNonce != "" && s.expectedAuthNonce != nonce {
|
||||
writeAuthErrorResponse(w, req, "access_denied", "")
|
||||
return
|
||||
}
|
||||
|
||||
state := qv.Get("state")
|
||||
if state == "" {
|
||||
writeAuthErrorResponse(w, req, "invalid_request", "missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := qv.Get("redirect_uri")
|
||||
if redirectURI == "" {
|
||||
writeAuthErrorResponse(w, req, "invalid_request", "missing redirect_uri parameter")
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI += "?state=" + url.QueryEscape(state) +
|
||||
"&code=" + url.QueryEscape(s.expectedAuthCode)
|
||||
|
||||
http.Redirect(w, req, redirectURI, http.StatusFound)
|
||||
|
||||
return
|
||||
|
||||
case "/certs":
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := writeJSON(w, s.jwks); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
case "/certs_missing":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
||||
case "/certs_invalid":
|
||||
w.Write([]byte("It's not a keyset!"))
|
||||
|
||||
case "/token":
|
||||
if req.Method != "POST" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case req.FormValue("grant_type") != "authorization_code":
|
||||
_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "bad grant_type")
|
||||
return
|
||||
case !strutil.StrListContains(s.allowedRedirectURIs, req.FormValue("redirect_uri")):
|
||||
_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "redirect_uri is not allowed")
|
||||
return
|
||||
case req.FormValue("code") != s.expectedAuthCode:
|
||||
_ = writeTokenErrorResponse(w, req, http.StatusUnauthorized, "invalid_grant", "unexpected auth code")
|
||||
return
|
||||
}
|
||||
|
||||
stdClaims := jwt.Claims{
|
||||
Subject: s.replySubject,
|
||||
Issuer: s.Addr(),
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
||||
Audience: jwt.Audience{s.clientID},
|
||||
}
|
||||
if s.customAudience != "" {
|
||||
stdClaims.Audience = jwt.Audience{s.customAudience}
|
||||
}
|
||||
|
||||
jwtData, err := SignJWT("", stdClaims, s.customClaims)
|
||||
if err != nil {
|
||||
_ = writeTokenErrorResponse(w, req, http.StatusInternalServerError, "server_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
reply := struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}{
|
||||
AccessToken: jwtData,
|
||||
IDToken: jwtData,
|
||||
}
|
||||
if s.omitIDToken {
|
||||
reply.IDToken = ""
|
||||
}
|
||||
if err := writeJSON(w, &reply); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
case "/userinfo":
|
||||
if s.disableUserInfo {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := writeJSON(w, s.replyUserinfo); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func writeAuthErrorResponse(w http.ResponseWriter, req *http.Request, errorCode, errorMessage string) {
|
||||
qv := req.URL.Query()
|
||||
|
||||
redirectURI := qv.Get("redirect_uri") +
|
||||
"?state=" + url.QueryEscape(qv.Get("state")) +
|
||||
"&error=" + url.QueryEscape(errorCode)
|
||||
|
||||
if errorMessage != "" {
|
||||
redirectURI += "&error_description=" + url.QueryEscape(errorMessage)
|
||||
}
|
||||
|
||||
http.Redirect(w, req, redirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
func writeTokenErrorResponse(w http.ResponseWriter, req *http.Request, statusCode int, errorCode, errorMessage string) error {
|
||||
body := struct {
|
||||
Code string `json:"error"`
|
||||
Desc string `json:"error_description,omitempty"`
|
||||
}{
|
||||
Code: errorCode,
|
||||
Desc: errorMessage,
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
return writeJSON(w, &body)
|
||||
}
|
||||
|
||||
// newJWKS converts a pem-encoded public key into JWKS data suitable for a
|
||||
// verification endpoint response
|
||||
func newJWKS(pubKey string) (*jose.JSONWebKeySet, error) {
|
||||
block, _ := pem.Decode([]byte(pubKey))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("unable to decode public key")
|
||||
}
|
||||
input := block.Bytes
|
||||
|
||||
pub, err := x509.ParsePKIXPublicKey(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
jose.JSONWebKey{
|
||||
Key: pub,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, out interface{}) error {
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(out)
|
||||
}
|
||||
|
||||
// SignJWT will bundle the provided claims into a signed JWT. The provided key
|
||||
// is assumed to be ECDSA.
|
||||
//
|
||||
// If no private key is provided, the default package keys are used. These can
|
||||
// be retrieved via the SigningKeys() method.
|
||||
func SignJWT(privKey string, claims jwt.Claims, privateClaims interface{}) (string, error) {
|
||||
if privKey == "" {
|
||||
privKey = ecdsaPrivateKey
|
||||
}
|
||||
var key *ecdsa.PrivateKey
|
||||
block, _ := pem.Decode([]byte(privKey))
|
||||
if block != nil {
|
||||
var err error
|
||||
key, err = x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: key},
|
||||
(&jose.SignerOptions{}).WithType("JWT"),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
raw, err := jwt.Signed(sig).
|
||||
Claims(claims).
|
||||
Claims(privateClaims).
|
||||
CompactSerialize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// httptestNewUnstartedServerWithPort is roughly the same as
|
||||
// httptest.NewUnstartedServer() but allows the caller to explicitly choose the
|
||||
// port if desired.
|
||||
func httptestNewUnstartedServerWithPort(handler http.Handler, port int) *httptest.Server {
|
||||
if port == 0 {
|
||||
return httptest.NewUnstartedServer(handler)
|
||||
}
|
||||
addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
|
||||
l, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
|
||||
}
|
||||
|
||||
return &httptest.Server{
|
||||
Listener: l,
|
||||
Config: &http.Server{Handler: handler},
|
||||
}
|
||||
}
|
||||
|
||||
// SigningKeys returns the pem-encoded keys used to sign JWTs by default.
|
||||
func SigningKeys() (pub, priv string) {
|
||||
return ecdsaPublicKey, ecdsaPrivateKey
|
||||
}
|
||||
|
||||
var (
|
||||
ecdsaPublicKey string
|
||||
ecdsaPrivateKey string
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Each time we run tests we generate a unique set of keys for use in the
|
||||
// test. These are cached between runs but do not persist between restarts
|
||||
// of the test binary.
|
||||
var err error
|
||||
ecdsaPublicKey, ecdsaPrivateKey, err = generateKey()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func generateKey() (pub, priv string, err error) {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error generating private key: %v", err)
|
||||
}
|
||||
|
||||
{
|
||||
derBytes, err := x509.MarshalECPrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error marshaling private key: %v", err)
|
||||
}
|
||||
pemBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: derBytes,
|
||||
}
|
||||
priv = string(pem.EncodeToMemory(pemBlock))
|
||||
}
|
||||
{
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(privateKey.Public())
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error marshaling public key: %v", err)
|
||||
}
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: derBytes,
|
||||
}
|
||||
pub = string(pem.EncodeToMemory(pemBlock))
|
||||
}
|
||||
|
||||
return pub, priv, nil
|
||||
}
|
|
@ -0,0 +1,252 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/mitchellh/pointerstructure"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func contextWithHttpClient(ctx context.Context, client *http.Client) context.Context {
|
||||
return context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
}
|
||||
|
||||
func createHTTPClient(caCert string) (*http.Client, error) {
|
||||
tr := cleanhttp.DefaultPooledTransport()
|
||||
|
||||
if caCert != "" {
|
||||
certPool := x509.NewCertPool()
|
||||
if ok := certPool.AppendCertsFromPEM([]byte(caCert)); !ok {
|
||||
return nil, errors.New("could not parse CA PEM value successfully")
|
||||
}
|
||||
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
RootCAs: certPool,
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractClaims extracts all configured claims from the received claims.
|
||||
func (a *Authenticator) extractClaims(allClaims map[string]interface{}) (*Claims, error) {
|
||||
metadata, err := extractStringMetadata(a.logger, allClaims, a.config.ClaimMappings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listMetadata, err := extractListMetadata(a.logger, allClaims, a.config.ListClaimMappings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Claims{
|
||||
Values: metadata,
|
||||
Lists: listMetadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractStringMetadata builds a metadata map of string values from a set of
|
||||
// claims and claims mappings. The referenced claims must be strings and the
|
||||
// claims mappings must be of the structure:
|
||||
//
|
||||
// {
|
||||
// "/some/claim/pointer": "metadata_key1",
|
||||
// "another_claim": "metadata_key2",
|
||||
// ...
|
||||
// }
|
||||
func extractStringMetadata(logger hclog.Logger, allClaims map[string]interface{}, claimMappings map[string]string) (map[string]string, error) {
|
||||
metadata := make(map[string]string)
|
||||
for source, target := range claimMappings {
|
||||
rawValue := getClaim(logger, allClaims, source)
|
||||
if rawValue == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
strValue, ok := stringifyMetadataValue(rawValue)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error converting claim '%s' to string from unknown type %T", source, rawValue)
|
||||
}
|
||||
|
||||
metadata[target] = strValue
|
||||
}
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// extractListMetadata builds a metadata map of string list values from a set
|
||||
// of claims and claims mappings. The referenced claims must be strings and
|
||||
// the claims mappings must be of the structure:
|
||||
//
|
||||
// {
|
||||
// "/some/claim/pointer": "metadata_key1",
|
||||
// "another_claim": "metadata_key2",
|
||||
// ...
|
||||
// }
|
||||
func extractListMetadata(logger hclog.Logger, allClaims map[string]interface{}, listClaimMappings map[string]string) (map[string][]string, error) {
|
||||
out := make(map[string][]string)
|
||||
for source, target := range listClaimMappings {
|
||||
if rawValue := getClaim(logger, allClaims, source); rawValue != nil {
|
||||
rawList, ok := normalizeList(rawValue)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%q list claim could not be converted to string list", source)
|
||||
}
|
||||
|
||||
list := make([]string, 0, len(rawList))
|
||||
for _, raw := range rawList {
|
||||
value, ok := stringifyMetadataValue(raw)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value %v in %q list claim could not be parsed as string", raw, source)
|
||||
}
|
||||
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
list = append(list, value)
|
||||
}
|
||||
|
||||
out[target] = list
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// getClaim returns a claim value from allClaims given a provided claim string.
|
||||
// If this string is a valid JSONPointer, it will be interpreted as such to
|
||||
// locate the claim. Otherwise, the claim string will be used directly.
|
||||
//
|
||||
// There is no fixup done to the returned data type here. That happens a layer
|
||||
// up in the caller.
|
||||
func getClaim(logger hclog.Logger, allClaims map[string]interface{}, claim string) interface{} {
|
||||
if !strings.HasPrefix(claim, "/") {
|
||||
return allClaims[claim]
|
||||
}
|
||||
|
||||
val, err := pointerstructure.Get(allClaims, claim)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("unable to locate claim", "claim", claim, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// normalizeList takes an item or a slice and returns a slice. This is useful
|
||||
// when providers are expected to return a list (typically of strings) but
|
||||
// reduce it to a non-slice type when the list count is 1.
|
||||
//
|
||||
// There is no fixup done to elements of the returned slice here. That happens
|
||||
// a layer up in the caller.
|
||||
func normalizeList(raw interface{}) ([]interface{}, bool) {
|
||||
switch v := raw.(type) {
|
||||
case []interface{}:
|
||||
return v, true
|
||||
case string, // note: this list should be the same as stringifyMetadataValue
|
||||
bool,
|
||||
json.Number,
|
||||
float64,
|
||||
float32,
|
||||
int8,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
int,
|
||||
uint8,
|
||||
uint16,
|
||||
uint32,
|
||||
uint64,
|
||||
uint:
|
||||
return []interface{}{v}, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// stringifyMetadataValue will try to convert the provided raw value into a
|
||||
// faithful string representation of that value per these rules:
|
||||
//
|
||||
// - strings => unchanged
|
||||
// - bool => "true" / "false"
|
||||
// - json.Number => String()
|
||||
// - float32/64 => truncated to int64 and then formatted as an ascii string
|
||||
// - intXX/uintXX => casted to int64 and then formatted as an ascii string
|
||||
//
|
||||
// If successful the string value and true are returned. otherwise an empty
|
||||
// string and false are returned.
|
||||
func stringifyMetadataValue(rawValue interface{}) (string, bool) {
|
||||
switch v := rawValue.(type) {
|
||||
case string:
|
||||
return v, true
|
||||
case bool:
|
||||
return strconv.FormatBool(v), true
|
||||
case json.Number:
|
||||
return v.String(), true
|
||||
case float64:
|
||||
// The claims unmarshalled by go-oidc don't use UseNumber, so
|
||||
// they'll come in as float64 instead of an integer or json.Number.
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
|
||||
// The numerical type cases following here are only here for the sake
|
||||
// of numerical type completion. Everything is truncated to an integer
|
||||
// before being stringified.
|
||||
case float32:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case int:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case uint8:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case uint16:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case uint32:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case uint64:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
case uint:
|
||||
return strconv.FormatInt(int64(v), 10), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// validateAudience checks whether any of the audiences in audClaim match those
|
||||
// in boundAudiences. If strict is true and there are no bound audiences, then
|
||||
// the presence of any audience in the received claim is considered an error.
|
||||
func validateAudience(boundAudiences, audClaim []string, strict bool) error {
|
||||
if strict && len(boundAudiences) == 0 && len(audClaim) > 0 {
|
||||
return errors.New("audience claim found in JWT but no audiences are bound")
|
||||
}
|
||||
|
||||
if len(boundAudiences) > 0 {
|
||||
for _, v := range boundAudiences {
|
||||
if strutil.StrListContains(audClaim, v) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("aud claim does not match any bound audience")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,614 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractStringMetadata(t *testing.T) {
|
||||
emptyMap := make(map[string]string)
|
||||
|
||||
tests := map[string]struct {
|
||||
allClaims map[string]interface{}
|
||||
claimMappings map[string]string
|
||||
expected map[string]string
|
||||
errExpected bool
|
||||
}{
|
||||
"empty": {nil, nil, emptyMap, false},
|
||||
"all": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": "bar",
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"data2": "val2",
|
||||
},
|
||||
map[string]string{
|
||||
"val1": "foo",
|
||||
"val2": "bar",
|
||||
},
|
||||
false,
|
||||
},
|
||||
"some": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": "bar",
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"data3": "val2",
|
||||
},
|
||||
map[string]string{
|
||||
"val1": "foo",
|
||||
},
|
||||
false,
|
||||
},
|
||||
"none": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": "bar",
|
||||
},
|
||||
map[string]string{
|
||||
"data8": "val1",
|
||||
"data9": "val2",
|
||||
},
|
||||
emptyMap,
|
||||
false,
|
||||
},
|
||||
|
||||
"nested data": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": map[string]interface{}{
|
||||
"child": "bar",
|
||||
},
|
||||
"data3": true,
|
||||
"data4": false,
|
||||
"data5": float64(7.9),
|
||||
"data6": json.Number("-12345"),
|
||||
"data7": int(42),
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"/data2/child": "val2",
|
||||
"data3": "val3",
|
||||
"data4": "val4",
|
||||
"data5": "val5",
|
||||
"data6": "val6",
|
||||
"data7": "val7",
|
||||
},
|
||||
map[string]string{
|
||||
"val1": "foo",
|
||||
"val2": "bar",
|
||||
"val3": "true",
|
||||
"val4": "false",
|
||||
"val5": "7",
|
||||
"val6": "-12345",
|
||||
"val7": "42",
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
"error: a struct isn't stringifiable": {
|
||||
map[string]interface{}{
|
||||
"data1": map[string]interface{}{
|
||||
"child": "bar",
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
"error: a slice isn't stringifiable": {
|
||||
map[string]interface{}{
|
||||
"data1": []interface{}{
|
||||
"child", "bar",
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
test := test
|
||||
t.Run(name, func(t *testing.T) {
|
||||
actual, err := extractStringMetadata(nil, test.allClaims, test.claimMappings)
|
||||
if test.errExpected {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractListMetadata(t *testing.T) {
|
||||
emptyMap := make(map[string][]string)
|
||||
|
||||
tests := map[string]struct {
|
||||
allClaims map[string]interface{}
|
||||
claimMappings map[string]string
|
||||
expected map[string][]string
|
||||
errExpected bool
|
||||
}{
|
||||
"empty": {nil, nil, emptyMap, false},
|
||||
"all - singular": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": "bar",
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"data2": "val2",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo"},
|
||||
"val2": []string{"bar"},
|
||||
},
|
||||
false,
|
||||
},
|
||||
"some - singular": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": "bar",
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"data3": "val2",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo"},
|
||||
},
|
||||
false,
|
||||
},
|
||||
"none - singular": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": "bar",
|
||||
},
|
||||
map[string]string{
|
||||
"data8": "val1",
|
||||
"data9": "val2",
|
||||
},
|
||||
emptyMap,
|
||||
false,
|
||||
},
|
||||
|
||||
"nested data - singular": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
"data2": map[string]interface{}{
|
||||
"child": "bar",
|
||||
},
|
||||
"data3": true,
|
||||
"data4": false,
|
||||
"data5": float64(7.9),
|
||||
"data6": json.Number("-12345"),
|
||||
"data7": int(42),
|
||||
"data8": []interface{}{ // mixed
|
||||
"foo", true, float64(7.9), json.Number("-12345"), int(42),
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"/data2/child": "val2",
|
||||
"data3": "val3",
|
||||
"data4": "val4",
|
||||
"data5": "val5",
|
||||
"data6": "val6",
|
||||
"data7": "val7",
|
||||
"data8": "val8",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo"},
|
||||
"val2": []string{"bar"},
|
||||
"val3": []string{"true"},
|
||||
"val4": []string{"false"},
|
||||
"val5": []string{"7"},
|
||||
"val6": []string{"-12345"},
|
||||
"val7": []string{"42"},
|
||||
"val8": []string{
|
||||
"foo", "true", "7", "-12345", "42",
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
"error: a struct isn't stringifiable (singular)": {
|
||||
map[string]interface{}{
|
||||
"data1": map[string]interface{}{
|
||||
"child": map[string]interface{}{
|
||||
"inner": "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
"error: a slice isn't stringifiable (singular)": {
|
||||
map[string]interface{}{
|
||||
"data1": []interface{}{
|
||||
"child", []interface{}{"bar"},
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
|
||||
"non-string-slice data (string)": {
|
||||
map[string]interface{}{
|
||||
"data1": "foo",
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo"}, // singular values become lists
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
"all - list": {
|
||||
map[string]interface{}{
|
||||
"data1": []interface{}{"foo", "otherFoo"},
|
||||
"data2": []interface{}{"bar", "otherBar"},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"data2": "val2",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo", "otherFoo"},
|
||||
"val2": []string{"bar", "otherBar"},
|
||||
},
|
||||
false,
|
||||
},
|
||||
"some - list": {
|
||||
map[string]interface{}{
|
||||
"data1": []interface{}{"foo", "otherFoo"},
|
||||
"data2": map[string]interface{}{
|
||||
"child": []interface{}{"bar", "otherBar"},
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"/data2/child": "val2",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo", "otherFoo"},
|
||||
"val2": []string{"bar", "otherBar"},
|
||||
},
|
||||
false,
|
||||
},
|
||||
"none - list": {
|
||||
map[string]interface{}{
|
||||
"data1": []interface{}{"foo"},
|
||||
"data2": []interface{}{"bar"},
|
||||
},
|
||||
map[string]string{
|
||||
"data8": "val1",
|
||||
"data9": "val2",
|
||||
},
|
||||
emptyMap,
|
||||
false,
|
||||
},
|
||||
"list omits empty strings": {
|
||||
map[string]interface{}{
|
||||
"data1": []interface{}{"foo", "", "otherFoo", ""},
|
||||
"data2": "",
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"data2": "val2",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo", "otherFoo"},
|
||||
"val2": []string{},
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
"nested data - list": {
|
||||
map[string]interface{}{
|
||||
"data1": []interface{}{"foo"},
|
||||
"data2": map[string]interface{}{
|
||||
"child": []interface{}{"bar"},
|
||||
},
|
||||
"data3": []interface{}{true},
|
||||
"data4": []interface{}{false},
|
||||
"data5": []interface{}{float64(7.9)},
|
||||
"data6": []interface{}{json.Number("-12345")},
|
||||
"data7": []interface{}{int(42)},
|
||||
"data8": []interface{}{ // mixed
|
||||
"foo", true, float64(7.9), json.Number("-12345"), int(42),
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"data1": "val1",
|
||||
"/data2/child": "val2",
|
||||
"data3": "val3",
|
||||
"data4": "val4",
|
||||
"data5": "val5",
|
||||
"data6": "val6",
|
||||
"data7": "val7",
|
||||
"data8": "val8",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"foo"},
|
||||
"val2": []string{"bar"},
|
||||
"val3": []string{"true"},
|
||||
"val4": []string{"false"},
|
||||
"val5": []string{"7"},
|
||||
"val6": []string{"-12345"},
|
||||
"val7": []string{"42"},
|
||||
"val8": []string{
|
||||
"foo", "true", "7", "-12345", "42",
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
"JSONPointer": {
|
||||
map[string]interface{}{
|
||||
"foo": "a",
|
||||
"bar": map[string]interface{}{
|
||||
"baz": []string{"x", "y", "z"},
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"foo": "val1",
|
||||
"/bar/baz/1": "val2",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"a"},
|
||||
"val2": []string{"y"},
|
||||
},
|
||||
false,
|
||||
},
|
||||
"JSONPointer not found": {
|
||||
map[string]interface{}{
|
||||
"foo": "a",
|
||||
"bar": map[string]interface{}{
|
||||
"baz": []string{"x", "y", "z"},
|
||||
},
|
||||
},
|
||||
map[string]string{
|
||||
"foo": "val1",
|
||||
"/bar/XXX/1243": "val2",
|
||||
},
|
||||
map[string][]string{
|
||||
"val1": []string{"a"},
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
test := test
|
||||
t.Run(name, func(t *testing.T) {
|
||||
actual, err := extractListMetadata(nil, test.allClaims, test.claimMappings)
|
||||
if test.errExpected {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClaim(t *testing.T) {
|
||||
data := `{
|
||||
"a": 42,
|
||||
"b": "bar",
|
||||
"c": {
|
||||
"d": 95,
|
||||
"e": [
|
||||
"dog",
|
||||
"cat",
|
||||
"bird"
|
||||
],
|
||||
"f": {
|
||||
"g": "zebra"
|
||||
}
|
||||
},
|
||||
"h": true,
|
||||
"i": false
|
||||
}`
|
||||
var claims map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &claims))
|
||||
|
||||
tests := []struct {
|
||||
claim string
|
||||
value interface{}
|
||||
}{
|
||||
{"a", float64(42)},
|
||||
{"/a", float64(42)},
|
||||
{"b", "bar"},
|
||||
{"/c/d", float64(95)},
|
||||
{"/c/e/1", "cat"},
|
||||
{"/c/f/g", "zebra"},
|
||||
{"nope", nil},
|
||||
{"/c/f/h", nil},
|
||||
{"", nil},
|
||||
{"\\", nil},
|
||||
{"h", true},
|
||||
{"i", false},
|
||||
{"/c/e", []interface{}{"dog", "cat", "bird"}},
|
||||
{"/c/f", map[string]interface{}{"g": "zebra"}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.claim, func(t *testing.T) {
|
||||
v := getClaim(nil, claims, test.claim)
|
||||
require.Equal(t, test.value, v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeList(t *testing.T) {
|
||||
tests := []struct {
|
||||
raw interface{}
|
||||
normalized []interface{}
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
raw: []interface{}{"green", 42},
|
||||
normalized: []interface{}{"green", 42},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
raw: []interface{}{"green"},
|
||||
normalized: []interface{}{"green"},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
raw: []interface{}{},
|
||||
normalized: []interface{}{},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
raw: "green",
|
||||
normalized: []interface{}{"green"},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
raw: "",
|
||||
normalized: []interface{}{""},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
raw: 42,
|
||||
normalized: []interface{}{42},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
raw: struct{ A int }{A: 5},
|
||||
normalized: nil,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
raw: nil,
|
||||
normalized: nil,
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(fmt.Sprintf("%#v", tc.raw), func(t *testing.T) {
|
||||
normalized, ok := normalizeList(tc.raw)
|
||||
assert.Equal(t, tc.normalized, normalized)
|
||||
assert.Equal(t, tc.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringifyMetadataValue(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
value interface{}
|
||||
expect string
|
||||
expectFailure bool
|
||||
}{
|
||||
"empty string": {"", "", false},
|
||||
"string": {"foo", "foo", false},
|
||||
"true": {true, "true", false},
|
||||
"false": {false, "false", false},
|
||||
"json number": {json.Number("-12345"), "-12345", false},
|
||||
"float64": {float64(7.9), "7", false},
|
||||
//
|
||||
"float32": {float32(7.9), "7", false},
|
||||
"int8": {int8(42), "42", false},
|
||||
"int16": {int16(42), "42", false},
|
||||
"int32": {int32(42), "42", false},
|
||||
"int64": {int64(42), "42", false},
|
||||
"int": {int(42), "42", false},
|
||||
"uint8": {uint8(42), "42", false},
|
||||
"uint16": {uint16(42), "42", false},
|
||||
"uint32": {uint32(42), "42", false},
|
||||
"uint64": {uint64(42), "42", false},
|
||||
"uint": {uint(42), "42", false},
|
||||
// fail
|
||||
"string slice": {[]string{"a"}, "", true},
|
||||
"int slice": {[]int64{99}, "", true},
|
||||
"map": {map[string]int{"a": 99}, "", true},
|
||||
"nil": {nil, "", true},
|
||||
"struct": {struct{ A int }{A: 5}, "", true},
|
||||
}
|
||||
|
||||
for name, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
got, ok := stringifyMetadataValue(tc.value)
|
||||
if tc.expectFailure {
|
||||
require.False(t, ok)
|
||||
} else {
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tc.expect, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAudience(t *testing.T) {
|
||||
tests := []struct {
|
||||
boundAudiences []string
|
||||
audience []string
|
||||
errExpectedLax bool
|
||||
errExpectedStrict bool
|
||||
}{
|
||||
{[]string{"a"}, []string{"a"}, false, false},
|
||||
{[]string{"a"}, []string{"b"}, true, true},
|
||||
{[]string{"a"}, []string{""}, true, true},
|
||||
{[]string{}, []string{"a"}, false, true},
|
||||
{[]string{"a", "b"}, []string{"a"}, false, false},
|
||||
{[]string{"a", "b"}, []string{"b"}, false, false},
|
||||
{[]string{"a", "b"}, []string{"a", "b", "c"}, false, false},
|
||||
{[]string{"a", "b"}, []string{"c", "d"}, true, true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
|
||||
t.Run(fmt.Sprintf(
|
||||
"boundAudiences=%#v audience=%#v strict=false",
|
||||
tc.boundAudiences, tc.audience,
|
||||
), func(t *testing.T) {
|
||||
err := validateAudience(tc.boundAudiences, tc.audience, false)
|
||||
if tc.errExpectedLax {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf(
|
||||
"boundAudiences=%#v audience=%#v strict=true",
|
||||
tc.boundAudiences, tc.audience,
|
||||
), func(t *testing.T) {
|
||||
err := validateAudience(tc.boundAudiences, tc.audience, true)
|
||||
if tc.errExpectedStrict {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
|
||||
)
|
||||
|
||||
// validRedirect checks whether uri is in allowed using special handling for loopback uris.
|
||||
// Ref: https://tools.ietf.org/html/rfc8252#section-7.3
|
||||
func validRedirect(uri string, allowed []string) bool {
|
||||
inputURI, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// if uri isn't a loopback, just string search the allowed list
|
||||
if !strutil.StrListContains([]string{"localhost", "127.0.0.1", "::1"}, inputURI.Hostname()) {
|
||||
return strutil.StrListContains(allowed, uri)
|
||||
}
|
||||
|
||||
// otherwise, search for a match in a port-agnostic manner, per the OAuth RFC.
|
||||
inputURI.Host = inputURI.Hostname()
|
||||
|
||||
for _, a := range allowed {
|
||||
allowedURI, err := url.Parse(a)
|
||||
if err != nil {
|
||||
return false // shouldn't happen due to (*Config).Validate checks
|
||||
}
|
||||
allowedURI.Host = allowedURI.Hostname()
|
||||
|
||||
if inputURI.String() == allowedURI.String() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
package oidcauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidRedirect(t *testing.T) {
|
||||
tests := []struct {
|
||||
uri string
|
||||
allowed []string
|
||||
expected bool
|
||||
}{
|
||||
// valid
|
||||
{"https://example.com", []string{"https://example.com"}, true},
|
||||
{"https://example.com:5000", []string{"a", "b", "https://example.com:5000"}, true},
|
||||
{"https://example.com/a/b/c", []string{"a", "b", "https://example.com/a/b/c"}, true},
|
||||
{"https://localhost:9000", []string{"a", "b", "https://localhost:5000"}, true},
|
||||
{"https://127.0.0.1:9000", []string{"a", "b", "https://127.0.0.1:5000"}, true},
|
||||
{"https://[::1]:9000", []string{"a", "b", "https://[::1]:5000"}, true},
|
||||
{"https://[::1]:9000/x/y?r=42", []string{"a", "b", "https://[::1]:5000/x/y?r=42"}, true},
|
||||
|
||||
// invalid
|
||||
{"https://example.com", []string{}, false},
|
||||
{"http://example.com", []string{"a", "b", "https://example.com"}, false},
|
||||
{"https://example.com:9000", []string{"a", "b", "https://example.com:5000"}, false},
|
||||
{"https://[::2]:9000", []string{"a", "b", "https://[::2]:5000"}, false},
|
||||
{"https://localhost:5000", []string{"a", "b", "https://127.0.0.1:5000"}, false},
|
||||
{"https://localhost:5000", []string{"a", "b", "https://127.0.0.1:5000"}, false},
|
||||
{"https://localhost:5000", []string{"a", "b", "http://localhost:5000"}, false},
|
||||
{"https://[::1]:5000/x/y?r=42", []string{"a", "b", "https://[::1]:5000/x/y?r=43"}, false},
|
||||
|
||||
// extra invalid
|
||||
{"%%%%%%%%%%%", []string{}, false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(fmt.Sprintf("uri=%q allowed=%#v", tc.uri, tc.allowed), func(t *testing.T) {
|
||||
require.Equal(t, tc.expected, validRedirect(tc.uri, tc.allowed))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
/bin
|
||||
/gopath
|
|
@ -0,0 +1,16 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- "1.9"
|
||||
- "1.10"
|
||||
|
||||
install:
|
||||
- go get -v -t github.com/coreos/go-oidc/...
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/golang/lint/golint
|
||||
|
||||
script:
|
||||
- ./test
|
||||
|
||||
notifications:
|
||||
email: false
|
|
@ -0,0 +1,71 @@
|
|||
# How to Contribute
|
||||
|
||||
CoreOS projects are [Apache 2.0 licensed](LICENSE) and accept contributions via
|
||||
GitHub pull requests. This document outlines some of the conventions on
|
||||
development workflow, commit message formatting, contact points and other
|
||||
resources to make it easier to get your contribution accepted.
|
||||
|
||||
# Certificate of Origin
|
||||
|
||||
By contributing to this project you agree to the Developer Certificate of
|
||||
Origin (DCO). This document was created by the Linux Kernel community and is a
|
||||
simple statement that you, as a contributor, have the legal right to make the
|
||||
contribution. See the [DCO](DCO) file for details.
|
||||
|
||||
# Email and Chat
|
||||
|
||||
The project currently uses the general CoreOS email list and IRC channel:
|
||||
- Email: [coreos-dev](https://groups.google.com/forum/#!forum/coreos-dev)
|
||||
- IRC: #[coreos](irc://irc.freenode.org:6667/#coreos) IRC channel on freenode.org
|
||||
|
||||
Please avoid emailing maintainers found in the MAINTAINERS file directly. They
|
||||
are very busy and read the mailing lists.
|
||||
|
||||
## Getting Started
|
||||
|
||||
- Fork the repository on GitHub
|
||||
- Read the [README](README.md) for build and test instructions
|
||||
- Play with the project, submit bugs, submit patches!
|
||||
|
||||
## Contribution Flow
|
||||
|
||||
This is a rough outline of what a contributor's workflow looks like:
|
||||
|
||||
- Create a topic branch from where you want to base your work (usually master).
|
||||
- Make commits of logical units.
|
||||
- Make sure your commit messages are in the proper format (see below).
|
||||
- Push your changes to a topic branch in your fork of the repository.
|
||||
- Make sure the tests pass, and add any new tests as appropriate.
|
||||
- Submit a pull request to the original repository.
|
||||
|
||||
Thanks for your contributions!
|
||||
|
||||
### Format of the Commit Message
|
||||
|
||||
We follow a rough convention for commit messages that is designed to answer two
|
||||
questions: what changed and why. The subject line should feature the what and
|
||||
the body of the commit should describe the why.
|
||||
|
||||
```
|
||||
scripts: add the test-cluster command
|
||||
|
||||
this uses tmux to setup a test cluster that you can easily kill and
|
||||
start for debugging.
|
||||
|
||||
Fixes #38
|
||||
```
|
||||
|
||||
The format can be described more formally as follows:
|
||||
|
||||
```
|
||||
<subsystem>: <what changed>
|
||||
<BLANK LINE>
|
||||
<why this change was made>
|
||||
<BLANK LINE>
|
||||
<footer>
|
||||
```
|
||||
|
||||
The first line is the subject and should be no longer than 70 characters, the
|
||||
second line is always blank, and other lines should be wrapped at 80 characters.
|
||||
This allows the message to be easier to read on GitHub as well as in various
|
||||
git tools.
|
|
@ -0,0 +1,36 @@
|
|||
Developer Certificate of Origin
|
||||
Version 1.1
|
||||
|
||||
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
||||
660 York Street, Suite 102,
|
||||
San Francisco, CA 94110 USA
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim copies of this
|
||||
license document, but changing it is not allowed.
|
||||
|
||||
|
||||
Developer's Certificate of Origin 1.1
|
||||
|
||||
By making a contribution to this project, I certify that:
|
||||
|
||||
(a) The contribution was created in whole or in part by me and I
|
||||
have the right to submit it under the open source license
|
||||
indicated in the file; or
|
||||
|
||||
(b) The contribution is based upon previous work that, to the best
|
||||
of my knowledge, is covered under an appropriate open source
|
||||
license and I have the right under that license to submit that
|
||||
work with modifications, whether created in whole or in part
|
||||
by me, under the same open source license (unless I am
|
||||
permitted to submit under a different license), as indicated
|
||||
in the file; or
|
||||
|
||||
(c) The contribution was provided directly to me by some other
|
||||
person who certified (a), (b) or (c) and I have not modified
|
||||
it.
|
||||
|
||||
(d) I understand and agree that this project and the contribution
|
||||
are public and that a record of the contribution (including all
|
||||
personal information I submit with it, including my sign-off) is
|
||||
maintained indefinitely and may be redistributed consistent with
|
||||
this project or the open source license(s) involved.
|
|
@ -0,0 +1,202 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
Eric Chiang <ericchiang@google.com> (@ericchiang)
|
||||
Mike Danese <mikedanese@google.com> (@mikedanese)
|
||||
Rithu Leena John <rjohn@redhat.com> (@rithujohn191)
|
|
@ -0,0 +1,5 @@
|
|||
CoreOS Project
|
||||
Copyright 2014 CoreOS, Inc
|
||||
|
||||
This product includes software developed at CoreOS, Inc.
|
||||
(http://www.coreos.com/).
|
|
@ -0,0 +1,72 @@
|
|||
# go-oidc
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/coreos/go-oidc?status.svg)](https://godoc.org/github.com/coreos/go-oidc)
|
||||
[![Build Status](https://travis-ci.org/coreos/go-oidc.png?branch=master)](https://travis-ci.org/coreos/go-oidc)
|
||||
|
||||
## OpenID Connect support for Go
|
||||
|
||||
This package enables OpenID Connect support for the [golang.org/x/oauth2](https://godoc.org/golang.org/x/oauth2) package.
|
||||
|
||||
```go
|
||||
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Configure an OpenID Connect aware OAuth2 client.
|
||||
oauth2Config := oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
|
||||
// Discovery returns the OAuth2 endpoints.
|
||||
Endpoint: provider.Endpoint(),
|
||||
|
||||
// "openid" is a required scope for OpenID Connect flows.
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
```
|
||||
|
||||
OAuth2 redirects are unchanged.
|
||||
|
||||
```go
|
||||
func handleRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusFound)
|
||||
}
|
||||
```
|
||||
|
||||
The on responses, the provider can be used to verify ID Tokens.
|
||||
|
||||
```go
|
||||
var verifier = provider.Verifier(&oidc.Config{ClientID: clientID})
|
||||
|
||||
func handleOAuth2Callback(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify state and errors.
|
||||
|
||||
oauth2Token, err := oauth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Extract the ID Token from OAuth2 token.
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
// handle missing token
|
||||
}
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Extract custom claims
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"email_verified"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
// handle error
|
||||
}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,61 @@
|
|||
## CoreOS Community Code of Conduct
|
||||
|
||||
### Contributor Code of Conduct
|
||||
|
||||
As contributors and maintainers of this project, and in the interest of
|
||||
fostering an open and welcoming community, we pledge to respect all people who
|
||||
contribute through reporting issues, posting feature requests, updating
|
||||
documentation, submitting pull requests or patches, and other activities.
|
||||
|
||||
We are committed to making participation in this project a harassment-free
|
||||
experience for everyone, regardless of level of experience, gender, gender
|
||||
identity and expression, sexual orientation, disability, personal appearance,
|
||||
body size, race, ethnicity, age, religion, or nationality.
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery
|
||||
* Personal attacks
|
||||
* Trolling or insulting/derogatory comments
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as physical or electronic addresses, without explicit permission
|
||||
* Other unethical or unprofessional conduct.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct. By adopting this Code of Conduct,
|
||||
project maintainers commit themselves to fairly and consistently applying these
|
||||
principles to every aspect of managing this project. Project maintainers who do
|
||||
not follow or enforce the Code of Conduct may be permanently removed from the
|
||||
project team.
|
||||
|
||||
This code of conduct applies both within project spaces and in public spaces
|
||||
when an individual is representing the project or its community.
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting a project maintainer, Brandon Philips
|
||||
<brandon.philips@coreos.com>, and/or Rithu John <rithu.john@coreos.com>.
|
||||
|
||||
This Code of Conduct is adapted from the Contributor Covenant
|
||||
(http://contributor-covenant.org), version 1.2.0, available at
|
||||
http://contributor-covenant.org/version/1/2/0/
|
||||
|
||||
### CoreOS Events Code of Conduct
|
||||
|
||||
CoreOS events are working conferences intended for professional networking and
|
||||
collaboration in the CoreOS community. Attendees are expected to behave
|
||||
according to professional standards and in accordance with their employer’s
|
||||
policies on appropriate workplace behavior.
|
||||
|
||||
While at CoreOS events or related social networking opportunities, attendees
|
||||
should not engage in discriminatory or offensive speech or actions including
|
||||
but not limited to gender, sexuality, race, age, disability, or religion.
|
||||
Speakers should be especially aware of these concerns.
|
||||
|
||||
CoreOS does not condone any statements by speakers contrary to these standards.
|
||||
CoreOS reserves the right to deny entrance and/or eject from an event (without
|
||||
refund) any individual found to be engaging in discriminatory or offensive
|
||||
speech or actions.
|
||||
|
||||
Please bring any concerns to the immediate attention of designated on-site
|
||||
staff, Brandon Philips <brandon.philips@coreos.com>, and/or Rithu John <rithu.john@coreos.com>.
|
|
@ -0,0 +1,20 @@
|
|||
// +build !golint
|
||||
|
||||
// Don't lint this file. We don't want to have to add a comment to each constant.
|
||||
|
||||
package oidc
|
||||
|
||||
const (
|
||||
// JOSE asymmetric signing algorithm values as defined by RFC 7518
|
||||
//
|
||||
// see: https://tools.ietf.org/html/rfc7518#section-3.1
|
||||
RS256 = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
|
||||
RS384 = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
|
||||
RS512 = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
|
||||
ES256 = "ES256" // ECDSA using P-256 and SHA-256
|
||||
ES384 = "ES384" // ECDSA using P-384 and SHA-384
|
||||
ES512 = "ES512" // ECDSA using P-521 and SHA-512
|
||||
PS256 = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
|
||||
PS384 = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
|
||||
PS512 = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
|
||||
)
|
|
@ -0,0 +1,228 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pquerna/cachecontrol"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// keysExpiryDelta is the allowed clock skew between a client and the OpenID Connect
|
||||
// server.
|
||||
//
|
||||
// When keys expire, they are valid for this amount of time after.
|
||||
//
|
||||
// If the keys have not expired, and an ID Token claims it was signed by a key not in
|
||||
// the cache, if and only if the keys expire in this amount of time, the keys will be
|
||||
// updated.
|
||||
const keysExpiryDelta = 30 * time.Second
|
||||
|
||||
// NewRemoteKeySet returns a KeySet that can validate JSON web tokens by using HTTP
|
||||
// GETs to fetch JSON web token sets hosted at a remote URL. This is automatically
|
||||
// used by NewProvider using the URLs returned by OpenID Connect discovery, but is
|
||||
// exposed for providers that don't support discovery or to prevent round trips to the
|
||||
// discovery URL.
|
||||
//
|
||||
// The returned KeySet is a long lived verifier that caches keys based on cache-control
|
||||
// headers. Reuse a common remote key set instead of creating new ones as needed.
|
||||
//
|
||||
// The behavior of the returned KeySet is undefined once the context is canceled.
|
||||
func NewRemoteKeySet(ctx context.Context, jwksURL string) KeySet {
|
||||
return newRemoteKeySet(ctx, jwksURL, time.Now)
|
||||
}
|
||||
|
||||
func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
|
||||
}
|
||||
|
||||
type remoteKeySet struct {
|
||||
jwksURL string
|
||||
ctx context.Context
|
||||
now func() time.Time
|
||||
|
||||
// guard all other fields
|
||||
mu sync.Mutex
|
||||
|
||||
// inflight suppresses parallel execution of updateKeys and allows
|
||||
// multiple goroutines to wait for its result.
|
||||
inflight *inflight
|
||||
|
||||
// A set of cached keys and their expiry.
|
||||
cachedKeys []jose.JSONWebKey
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
// inflight is used to wait on some in-flight request from multiple goroutines.
|
||||
type inflight struct {
|
||||
doneCh chan struct{}
|
||||
|
||||
keys []jose.JSONWebKey
|
||||
err error
|
||||
}
|
||||
|
||||
func newInflight() *inflight {
|
||||
return &inflight{doneCh: make(chan struct{})}
|
||||
}
|
||||
|
||||
// wait returns a channel that multiple goroutines can receive on. Once it returns
|
||||
// a value, the inflight request is done and result() can be inspected.
|
||||
func (i *inflight) wait() <-chan struct{} {
|
||||
return i.doneCh
|
||||
}
|
||||
|
||||
// done can only be called by a single goroutine. It records the result of the
|
||||
// inflight request and signals other goroutines that the result is safe to
|
||||
// inspect.
|
||||
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
|
||||
i.keys = keys
|
||||
i.err = err
|
||||
close(i.doneCh)
|
||||
}
|
||||
|
||||
// result cannot be called until the wait() channel has returned a value.
|
||||
func (i *inflight) result() ([]jose.JSONWebKey, error) {
|
||||
return i.keys, i.err
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
|
||||
jws, err := jose.ParseSigned(jwt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
|
||||
}
|
||||
return r.verify(ctx, jws)
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
// We don't support JWTs signed with multiple signatures.
|
||||
keyID := ""
|
||||
for _, sig := range jws.Signatures {
|
||||
keyID = sig.Header.KeyID
|
||||
break
|
||||
}
|
||||
|
||||
keys, expiry := r.keysFromCache()
|
||||
|
||||
// Don't check expiry yet. This optimizes for when the provider is unavailable.
|
||||
for _, key := range keys {
|
||||
if keyID == "" || key.KeyID == keyID {
|
||||
if payload, err := jws.Verify(&key); err == nil {
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !r.now().Add(keysExpiryDelta).After(expiry) {
|
||||
// Keys haven't expired, don't refresh.
|
||||
return nil, errors.New("failed to verify id token signature")
|
||||
}
|
||||
|
||||
keys, err := r.keysFromRemote(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching keys %v", err)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if keyID == "" || key.KeyID == keyID {
|
||||
if payload, err := jws.Verify(&key); err == nil {
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, errors.New("failed to verify id token signature")
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.cachedKeys, r.expiry
|
||||
}
|
||||
|
||||
// keysFromRemote syncs the key set from the remote set, records the values in the
|
||||
// cache, and returns the key set.
|
||||
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
|
||||
// Need to lock to inspect the inflight request field.
|
||||
r.mu.Lock()
|
||||
// If there's not a current inflight request, create one.
|
||||
if r.inflight == nil {
|
||||
r.inflight = newInflight()
|
||||
|
||||
// This goroutine has exclusive ownership over the current inflight
|
||||
// request. It releases the resource by nil'ing the inflight field
|
||||
// once the goroutine is done.
|
||||
go func() {
|
||||
// Sync keys and finish inflight when that's done.
|
||||
keys, expiry, err := r.updateKeys()
|
||||
|
||||
r.inflight.done(keys, err)
|
||||
|
||||
// Lock to update the keys and indicate that there is no longer an
|
||||
// inflight request.
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
r.cachedKeys = keys
|
||||
r.expiry = expiry
|
||||
}
|
||||
|
||||
// Free inflight so a different request can run.
|
||||
r.inflight = nil
|
||||
}()
|
||||
}
|
||||
inflight := r.inflight
|
||||
r.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-inflight.wait():
|
||||
return inflight.result()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) {
|
||||
req, err := http.NewRequest("GET", r.jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("oidc: can't create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := doRequest(r.ctx, req)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("oidc: get keys failed %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("unable to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, time.Time{}, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body)
|
||||
}
|
||||
|
||||
var keySet jose.JSONWebKeySet
|
||||
err = unmarshalResp(resp, body, &keySet)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
|
||||
}
|
||||
|
||||
// If the server doesn't provide cache control headers, assume the
|
||||
// keys expire immediately.
|
||||
expiry := r.now()
|
||||
|
||||
_, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
|
||||
if err == nil && e.After(expiry) {
|
||||
expiry = e
|
||||
}
|
||||
return keySet.Keys, expiry, nil
|
||||
}
|
|
@ -0,0 +1,385 @@
|
|||
// Package oidc implements OpenID Connect client logic for the golang.org/x/oauth2 package.
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// ScopeOpenID is the mandatory scope for all OpenID Connect OAuth2 requests.
|
||||
ScopeOpenID = "openid"
|
||||
|
||||
// ScopeOfflineAccess is an optional scope defined by OpenID Connect for requesting
|
||||
// OAuth2 refresh tokens.
|
||||
//
|
||||
// Support for this scope differs between OpenID Connect providers. For instance
|
||||
// Google rejects it, favoring appending "access_type=offline" as part of the
|
||||
// authorization request instead.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
|
||||
ScopeOfflineAccess = "offline_access"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoAtHash = errors.New("id token did not have an access token hash")
|
||||
errInvalidAtHash = errors.New("access token hash does not match value in ID token")
|
||||
)
|
||||
|
||||
// ClientContext returns a new Context that carries the provided HTTP client.
|
||||
//
|
||||
// This method sets the same context key used by the golang.org/x/oauth2 package,
|
||||
// so the returned context works for that package too.
|
||||
//
|
||||
// myClient := &http.Client{}
|
||||
// ctx := oidc.ClientContext(parentContext, myClient)
|
||||
//
|
||||
// // This will use the custom client
|
||||
// provider, err := oidc.NewProvider(ctx, "https://accounts.example.com")
|
||||
//
|
||||
func ClientContext(ctx context.Context, client *http.Client) context.Context {
|
||||
return context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
}
|
||||
|
||||
func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
client := http.DefaultClient
|
||||
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
client = c
|
||||
}
|
||||
return client.Do(req.WithContext(ctx))
|
||||
}
|
||||
|
||||
// Provider represents an OpenID Connect server's configuration.
|
||||
type Provider struct {
|
||||
issuer string
|
||||
authURL string
|
||||
tokenURL string
|
||||
userInfoURL string
|
||||
|
||||
// Raw claims returned by the server.
|
||||
rawClaims []byte
|
||||
|
||||
remoteKeySet KeySet
|
||||
}
|
||||
|
||||
type cachedKeys struct {
|
||||
keys []jose.JSONWebKey
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
type providerJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
}
|
||||
|
||||
// NewProvider uses the OpenID Connect discovery mechanism to construct a Provider.
|
||||
//
|
||||
// The issuer is the URL identifier for the service. For example: "https://accounts.google.com"
|
||||
// or "https://login.salesforce.com".
|
||||
func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := doRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Status, body)
|
||||
}
|
||||
|
||||
var p providerJSON
|
||||
err = unmarshalResp(resp, body, &p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
|
||||
}
|
||||
|
||||
if p.Issuer != issuer {
|
||||
return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
|
||||
}
|
||||
return &Provider{
|
||||
issuer: p.Issuer,
|
||||
authURL: p.AuthURL,
|
||||
tokenURL: p.TokenURL,
|
||||
userInfoURL: p.UserInfoURL,
|
||||
rawClaims: body,
|
||||
remoteKeySet: NewRemoteKeySet(ctx, p.JWKSURL),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Claims unmarshals raw fields returned by the server during discovery.
|
||||
//
|
||||
// var claims struct {
|
||||
// ScopesSupported []string `json:"scopes_supported"`
|
||||
// ClaimsSupported []string `json:"claims_supported"`
|
||||
// }
|
||||
//
|
||||
// if err := provider.Claims(&claims); err != nil {
|
||||
// // handle unmarshaling error
|
||||
// }
|
||||
//
|
||||
// For a list of fields defined by the OpenID Connect spec see:
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
func (p *Provider) Claims(v interface{}) error {
|
||||
if p.rawClaims == nil {
|
||||
return errors.New("oidc: claims not set")
|
||||
}
|
||||
return json.Unmarshal(p.rawClaims, v)
|
||||
}
|
||||
|
||||
// Endpoint returns the OAuth2 auth and token endpoints for the given provider.
|
||||
func (p *Provider) Endpoint() oauth2.Endpoint {
|
||||
return oauth2.Endpoint{AuthURL: p.authURL, TokenURL: p.tokenURL}
|
||||
}
|
||||
|
||||
// UserInfo represents the OpenID Connect userinfo claims.
|
||||
type UserInfo struct {
|
||||
Subject string `json:"sub"`
|
||||
Profile string `json:"profile"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
|
||||
claims []byte
|
||||
}
|
||||
|
||||
// Claims unmarshals the raw JSON object claims into the provided object.
|
||||
func (u *UserInfo) Claims(v interface{}) error {
|
||||
if u.claims == nil {
|
||||
return errors.New("oidc: claims not set")
|
||||
}
|
||||
return json.Unmarshal(u.claims, v)
|
||||
}
|
||||
|
||||
// UserInfo uses the token source to query the provider's user info endpoint.
|
||||
func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*UserInfo, error) {
|
||||
if p.userInfoURL == "" {
|
||||
return nil, errors.New("oidc: user info endpoint is not supported by this provider")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", p.userInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: create GET request: %v", err)
|
||||
}
|
||||
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: get access token: %v", err)
|
||||
}
|
||||
token.SetAuthHeader(req)
|
||||
|
||||
resp, err := doRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Status, body)
|
||||
}
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("oidc: failed to decode userinfo: %v", err)
|
||||
}
|
||||
userInfo.claims = body
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
// IDToken is an OpenID Connect extension that provides a predictable representation
|
||||
// of an authorization event.
|
||||
//
|
||||
// The ID Token only holds fields OpenID Connect requires. To access additional
|
||||
// claims returned by the server, use the Claims method.
|
||||
type IDToken struct {
|
||||
// The URL of the server which issued this token. OpenID Connect
|
||||
// requires this value always be identical to the URL used for
|
||||
// initial discovery.
|
||||
//
|
||||
// Note: Because of a known issue with Google Accounts' implementation
|
||||
// this value may differ when using Google.
|
||||
//
|
||||
// See: https://developers.google.com/identity/protocols/OpenIDConnect#obtainuserinfo
|
||||
Issuer string
|
||||
|
||||
// The client ID, or set of client IDs, that this token is issued for. For
|
||||
// common uses, this is the client that initialized the auth flow.
|
||||
//
|
||||
// This package ensures the audience contains an expected value.
|
||||
Audience []string
|
||||
|
||||
// A unique string which identifies the end user.
|
||||
Subject string
|
||||
|
||||
// Expiry of the token. Ths package will not process tokens that have
|
||||
// expired unless that validation is explicitly turned off.
|
||||
Expiry time.Time
|
||||
// When the token was issued by the provider.
|
||||
IssuedAt time.Time
|
||||
|
||||
// Initial nonce provided during the authentication redirect.
|
||||
//
|
||||
// This package does NOT provided verification on the value of this field
|
||||
// and it's the user's responsibility to ensure it contains a valid value.
|
||||
Nonce string
|
||||
|
||||
// at_hash claim, if set in the ID token. Callers can verify an access token
|
||||
// that corresponds to the ID token using the VerifyAccessToken method.
|
||||
AccessTokenHash string
|
||||
|
||||
// signature algorithm used for ID token, needed to compute a verification hash of an
|
||||
// access token
|
||||
sigAlgorithm string
|
||||
|
||||
// Raw payload of the id_token.
|
||||
claims []byte
|
||||
|
||||
// Map of distributed claim names to claim sources
|
||||
distributedClaims map[string]claimSource
|
||||
}
|
||||
|
||||
// Claims unmarshals the raw JSON payload of the ID Token into a provided struct.
|
||||
//
|
||||
// idToken, err := idTokenVerifier.Verify(rawIDToken)
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
// var claims struct {
|
||||
// Email string `json:"email"`
|
||||
// EmailVerified bool `json:"email_verified"`
|
||||
// }
|
||||
// if err := idToken.Claims(&claims); err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
//
|
||||
func (i *IDToken) Claims(v interface{}) error {
|
||||
if i.claims == nil {
|
||||
return errors.New("oidc: claims not set")
|
||||
}
|
||||
return json.Unmarshal(i.claims, v)
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies that the hash of the access token that corresponds to the iD token
|
||||
// matches the hash in the id token. It returns an error if the hashes don't match.
|
||||
// It is the caller's responsibility to ensure that the optional access token hash is present for the ID token
|
||||
// before calling this method. See https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
|
||||
func (i *IDToken) VerifyAccessToken(accessToken string) error {
|
||||
if i.AccessTokenHash == "" {
|
||||
return errNoAtHash
|
||||
}
|
||||
var h hash.Hash
|
||||
switch i.sigAlgorithm {
|
||||
case RS256, ES256, PS256:
|
||||
h = sha256.New()
|
||||
case RS384, ES384, PS384:
|
||||
h = sha512.New384()
|
||||
case RS512, ES512, PS512:
|
||||
h = sha512.New()
|
||||
default:
|
||||
return fmt.Errorf("oidc: unsupported signing algorithm %q", i.sigAlgorithm)
|
||||
}
|
||||
h.Write([]byte(accessToken)) // hash documents that Write will never return an error
|
||||
sum := h.Sum(nil)[:h.Size()/2]
|
||||
actual := base64.RawURLEncoding.EncodeToString(sum)
|
||||
if actual != i.AccessTokenHash {
|
||||
return errInvalidAtHash
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type idToken struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Audience audience `json:"aud"`
|
||||
Expiry jsonTime `json:"exp"`
|
||||
IssuedAt jsonTime `json:"iat"`
|
||||
NotBefore *jsonTime `json:"nbf"`
|
||||
Nonce string `json:"nonce"`
|
||||
AtHash string `json:"at_hash"`
|
||||
ClaimNames map[string]string `json:"_claim_names"`
|
||||
ClaimSources map[string]claimSource `json:"_claim_sources"`
|
||||
}
|
||||
|
||||
type claimSource struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type audience []string
|
||||
|
||||
func (a *audience) UnmarshalJSON(b []byte) error {
|
||||
var s string
|
||||
if json.Unmarshal(b, &s) == nil {
|
||||
*a = audience{s}
|
||||
return nil
|
||||
}
|
||||
var auds []string
|
||||
if err := json.Unmarshal(b, &auds); err != nil {
|
||||
return err
|
||||
}
|
||||
*a = audience(auds)
|
||||
return nil
|
||||
}
|
||||
|
||||
type jsonTime time.Time
|
||||
|
||||
func (j *jsonTime) UnmarshalJSON(b []byte) error {
|
||||
var n json.Number
|
||||
if err := json.Unmarshal(b, &n); err != nil {
|
||||
return err
|
||||
}
|
||||
var unix int64
|
||||
|
||||
if t, err := n.Int64(); err == nil {
|
||||
unix = t
|
||||
} else {
|
||||
f, err := n.Float64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
unix = int64(f)
|
||||
}
|
||||
*j = jsonTime(time.Unix(unix, 0))
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
|
||||
err := json.Unmarshal(body, &v)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
ct := r.Header.Get("Content-Type")
|
||||
mediaType, _, parseErr := mime.ParseMediaType(ct)
|
||||
if parseErr == nil && mediaType == "application/json" {
|
||||
return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err)
|
||||
}
|
||||
return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err)
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Filter out any files with a !golint build tag.
|
||||
LINTABLE=$( go list -tags=golint -f '
|
||||
{{- range $i, $file := .GoFiles -}}
|
||||
{{ $file }} {{ end }}
|
||||
{{ range $i, $file := .TestGoFiles -}}
|
||||
{{ $file }} {{ end }}' github.com/coreos/go-oidc )
|
||||
|
||||
go test -v -i -race github.com/coreos/go-oidc/...
|
||||
go test -v -race github.com/coreos/go-oidc/...
|
||||
golint -set_exit_status $LINTABLE
|
||||
go vet github.com/coreos/go-oidc/...
|
||||
go build -v ./example/...
|
|
@ -0,0 +1,327 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
issuerGoogleAccounts = "https://accounts.google.com"
|
||||
issuerGoogleAccountsNoScheme = "accounts.google.com"
|
||||
)
|
||||
|
||||
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
|
||||
// of JSON web tokens. This is expected to be backed by a remote key set through
|
||||
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
|
||||
type KeySet interface {
|
||||
// VerifySignature parses the JSON web token, verifies the signature, and returns
|
||||
// the raw payload. Header and claim fields are validated by other parts of the
|
||||
// package. For example, the KeySet does not need to check values such as signature
|
||||
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
|
||||
// independently.
|
||||
//
|
||||
// If VerifySignature makes HTTP requests to verify the token, it's expected to
|
||||
// use any HTTP client associated with the context through ClientContext.
|
||||
VerifySignature(ctx context.Context, jwt string) (payload []byte, err error)
|
||||
}
|
||||
|
||||
// IDTokenVerifier provides verification for ID Tokens.
|
||||
type IDTokenVerifier struct {
|
||||
keySet KeySet
|
||||
config *Config
|
||||
issuer string
|
||||
}
|
||||
|
||||
// NewVerifier returns a verifier manually constructed from a key set and issuer URL.
|
||||
//
|
||||
// It's easier to use provider discovery to construct an IDTokenVerifier than creating
|
||||
// one directly. This method is intended to be used with provider that don't support
|
||||
// metadata discovery, or avoiding round trips when the key set URL is already known.
|
||||
//
|
||||
// This constructor can be used to create a verifier directly using the issuer URL and
|
||||
// JSON Web Key Set URL without using discovery:
|
||||
//
|
||||
// keySet := oidc.NewRemoteKeySet(ctx, "https://www.googleapis.com/oauth2/v3/certs")
|
||||
// verifier := oidc.NewVerifier("https://accounts.google.com", keySet, config)
|
||||
//
|
||||
// Since KeySet is an interface, this constructor can also be used to supply custom
|
||||
// public key sources. For example, if a user wanted to supply public keys out-of-band
|
||||
// and hold them statically in-memory:
|
||||
//
|
||||
// // Custom KeySet implementation.
|
||||
// keySet := newStatisKeySet(publicKeys...)
|
||||
//
|
||||
// // Verifier uses the custom KeySet implementation.
|
||||
// verifier := oidc.NewVerifier("https://auth.example.com", keySet, config)
|
||||
//
|
||||
func NewVerifier(issuerURL string, keySet KeySet, config *Config) *IDTokenVerifier {
|
||||
return &IDTokenVerifier{keySet: keySet, config: config, issuer: issuerURL}
|
||||
}
|
||||
|
||||
// Config is the configuration for an IDTokenVerifier.
|
||||
type Config struct {
|
||||
// Expected audience of the token. For a majority of the cases this is expected to be
|
||||
// the ID of the client that initialized the login flow. It may occasionally differ if
|
||||
// the provider supports the authorizing party (azp) claim.
|
||||
//
|
||||
// If not provided, users must explicitly set SkipClientIDCheck.
|
||||
ClientID string
|
||||
// If specified, only this set of algorithms may be used to sign the JWT.
|
||||
//
|
||||
// Since many providers only support RS256, SupportedSigningAlgs defaults to this value.
|
||||
SupportedSigningAlgs []string
|
||||
|
||||
// If true, no ClientID check performed. Must be true if ClientID field is empty.
|
||||
SkipClientIDCheck bool
|
||||
// If true, token expiry is not checked.
|
||||
SkipExpiryCheck bool
|
||||
|
||||
// SkipIssuerCheck is intended for specialized cases where the the caller wishes to
|
||||
// defer issuer validation. When enabled, callers MUST independently verify the Token's
|
||||
// Issuer is a known good value.
|
||||
//
|
||||
// Mismatched issuers often indicate client mis-configuration. If mismatches are
|
||||
// unexpected, evaluate if the provided issuer URL is incorrect instead of enabling
|
||||
// this option.
|
||||
SkipIssuerCheck bool
|
||||
|
||||
// Time function to check Token expiry. Defaults to time.Now
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
// Verifier returns an IDTokenVerifier that uses the provider's key set to verify JWTs.
|
||||
//
|
||||
// The returned IDTokenVerifier is tied to the Provider's context and its behavior is
|
||||
// undefined once the Provider's context is canceled.
|
||||
func (p *Provider) Verifier(config *Config) *IDTokenVerifier {
|
||||
return NewVerifier(p.issuer, p.remoteKeySet, config)
|
||||
}
|
||||
|
||||
func parseJWT(p string) ([]byte, error) {
|
||||
parts := strings.Split(p, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func contains(sli []string, ele string) bool {
|
||||
for _, s := range sli {
|
||||
if s == ele {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Returns the Claims from the distributed JWT token
|
||||
func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src claimSource) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", src.Endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("malformed request: %v", err)
|
||||
}
|
||||
if src.AccessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+src.AccessToken)
|
||||
}
|
||||
|
||||
resp, err := doRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: Request to endpoint failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("oidc: request failed: %v", resp.StatusCode)
|
||||
}
|
||||
|
||||
token, err := verifier.Verify(ctx, string(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("malformed response body: %v", err)
|
||||
}
|
||||
|
||||
return token.claims, nil
|
||||
}
|
||||
|
||||
func parseClaim(raw []byte, name string, v interface{}) error {
|
||||
var parsed map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, ok := parsed[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("claim doesn't exist: %s", name)
|
||||
}
|
||||
|
||||
return json.Unmarshal([]byte(val), v)
|
||||
}
|
||||
|
||||
// Verify parses a raw ID Token, verifies it's been signed by the provider, preforms
|
||||
// any additional checks depending on the Config, and returns the payload.
|
||||
//
|
||||
// Verify does NOT do nonce validation, which is the callers responsibility.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
//
|
||||
// oauth2Token, err := oauth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
//
|
||||
// // Extract the ID Token from oauth2 token.
|
||||
// rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
// if !ok {
|
||||
// // handle error
|
||||
// }
|
||||
//
|
||||
// token, err := verifier.Verify(ctx, rawIDToken)
|
||||
//
|
||||
func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDToken, error) {
|
||||
jws, err := jose.ParseSigned(rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
|
||||
}
|
||||
|
||||
// Throw out tokens with invalid claims before trying to verify the token. This lets
|
||||
// us do cheap checks before possibly re-syncing keys.
|
||||
payload, err := parseJWT(rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
|
||||
}
|
||||
var token idToken
|
||||
if err := json.Unmarshal(payload, &token); err != nil {
|
||||
return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
distributedClaims := make(map[string]claimSource)
|
||||
|
||||
//step through the token to map claim names to claim sources"
|
||||
for cn, src := range token.ClaimNames {
|
||||
if src == "" {
|
||||
return nil, fmt.Errorf("oidc: failed to obtain source from claim name")
|
||||
}
|
||||
s, ok := token.ClaimSources[src]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("oidc: source does not exist")
|
||||
}
|
||||
distributedClaims[cn] = s
|
||||
}
|
||||
|
||||
t := &IDToken{
|
||||
Issuer: token.Issuer,
|
||||
Subject: token.Subject,
|
||||
Audience: []string(token.Audience),
|
||||
Expiry: time.Time(token.Expiry),
|
||||
IssuedAt: time.Time(token.IssuedAt),
|
||||
Nonce: token.Nonce,
|
||||
AccessTokenHash: token.AtHash,
|
||||
claims: payload,
|
||||
distributedClaims: distributedClaims,
|
||||
}
|
||||
|
||||
// Check issuer.
|
||||
if !v.config.SkipIssuerCheck && t.Issuer != v.issuer {
|
||||
// Google sometimes returns "accounts.google.com" as the issuer claim instead of
|
||||
// the required "https://accounts.google.com". Detect this case and allow it only
|
||||
// for Google.
|
||||
//
|
||||
// We will not add hooks to let other providers go off spec like this.
|
||||
if !(v.issuer == issuerGoogleAccounts && t.Issuer == issuerGoogleAccountsNoScheme) {
|
||||
return nil, fmt.Errorf("oidc: id token issued by a different provider, expected %q got %q", v.issuer, t.Issuer)
|
||||
}
|
||||
}
|
||||
|
||||
// If a client ID has been provided, make sure it's part of the audience. SkipClientIDCheck must be true if ClientID is empty.
|
||||
//
|
||||
// This check DOES NOT ensure that the ClientID is the party to which the ID Token was issued (i.e. Authorized party).
|
||||
if !v.config.SkipClientIDCheck {
|
||||
if v.config.ClientID != "" {
|
||||
if !contains(t.Audience, v.config.ClientID) {
|
||||
return nil, fmt.Errorf("oidc: expected audience %q got %q", v.config.ClientID, t.Audience)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("oidc: invalid configuration, clientID must be provided or SkipClientIDCheck must be set")
|
||||
}
|
||||
}
|
||||
|
||||
// If a SkipExpiryCheck is false, make sure token is not expired.
|
||||
if !v.config.SkipExpiryCheck {
|
||||
now := time.Now
|
||||
if v.config.Now != nil {
|
||||
now = v.config.Now
|
||||
}
|
||||
nowTime := now()
|
||||
|
||||
if t.Expiry.Before(nowTime) {
|
||||
return nil, fmt.Errorf("oidc: token is expired (Token Expiry: %v)", t.Expiry)
|
||||
}
|
||||
|
||||
// If nbf claim is provided in token, ensure that it is indeed in the past.
|
||||
if token.NotBefore != nil {
|
||||
nbfTime := time.Time(*token.NotBefore)
|
||||
leeway := 1 * time.Minute
|
||||
|
||||
if nowTime.Add(leeway).Before(nbfTime) {
|
||||
return nil, fmt.Errorf("oidc: current time %v before the nbf (not before) time: %v", nowTime, nbfTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch len(jws.Signatures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("oidc: id token not signed")
|
||||
case 1:
|
||||
default:
|
||||
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
|
||||
}
|
||||
|
||||
sig := jws.Signatures[0]
|
||||
supportedSigAlgs := v.config.SupportedSigningAlgs
|
||||
if len(supportedSigAlgs) == 0 {
|
||||
supportedSigAlgs = []string{RS256}
|
||||
}
|
||||
|
||||
if !contains(supportedSigAlgs, sig.Header.Algorithm) {
|
||||
return nil, fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
|
||||
}
|
||||
|
||||
t.sigAlgorithm = sig.Header.Algorithm
|
||||
|
||||
gotPayload, err := v.keySet.VerifySignature(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify signature: %v", err)
|
||||
}
|
||||
|
||||
// Ensure that the payload returned by the square actually matches the payload parsed earlier.
|
||||
if !bytes.Equal(gotPayload, payload) {
|
||||
return nil, errors.New("oidc: internal error, payload parsed did not match previous payload")
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Nonce returns an auth code option which requires the ID Token created by the
|
||||
// OpenID Connect provider to contain the specified nonce.
|
||||
func Nonce(nonce string) oauth2.AuthCodeOption {
|
||||
return oauth2.SetAuthURLParam("nonce", nonce)
|
||||
}
|
|
@ -4,22 +4,40 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// GenerateRandomBytes is used to generate random bytes of given size.
|
||||
func GenerateRandomBytes(size int) ([]byte, error) {
|
||||
return GenerateRandomBytesWithReader(size, rand.Reader)
|
||||
}
|
||||
|
||||
// GenerateRandomBytesWithReader is used to generate random bytes of given size read from a given reader.
|
||||
func GenerateRandomBytesWithReader(size int, reader io.Reader) ([]byte, error) {
|
||||
if reader == nil {
|
||||
return nil, fmt.Errorf("provided reader is nil")
|
||||
}
|
||||
buf := make([]byte, size)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return nil, fmt.Errorf("failed to read random bytes: %v", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
|
||||
const uuidLen = 16
|
||||
|
||||
// GenerateUUID is used to generate a random UUID
|
||||
func GenerateUUID() (string, error) {
|
||||
buf, err := GenerateRandomBytes(uuidLen)
|
||||
return GenerateUUIDWithReader(rand.Reader)
|
||||
}
|
||||
|
||||
// GenerateUUIDWithReader is used to generate a random UUID with a given Reader
|
||||
func GenerateUUIDWithReader(reader io.Reader) (string, error) {
|
||||
if reader == nil {
|
||||
return "", fmt.Errorf("provided reader is nil")
|
||||
}
|
||||
buf, err := GenerateRandomBytesWithReader(uuidLen, reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.7
|
||||
- tip
|
||||
|
||||
script:
|
||||
- go test
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2019 Mitchell Hashimoto
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,74 @@
|
|||
# pointerstructure [![GoDoc](https://godoc.org/github.com/mitchellh/pointerstructure?status.svg)](https://godoc.org/github.com/mitchellh/pointerstructure)
|
||||
|
||||
pointerstructure is a Go library for identifying a specific value within
|
||||
any Go structure using a string syntax.
|
||||
|
||||
pointerstructure is based on
|
||||
[JSON Pointer (RFC 6901)](https://tools.ietf.org/html/rfc6901), but
|
||||
reimplemented for Go.
|
||||
|
||||
The goal of pointerstructure is to provide a single, well-known format
|
||||
for addressing a specific value. This can be useful for user provided
|
||||
input on structures, diffs of structures, etc.
|
||||
|
||||
## Features
|
||||
|
||||
* Get the value for an address
|
||||
|
||||
* Set the value for an address within an existing structure
|
||||
|
||||
* Delete the value at an address
|
||||
|
||||
* Sorting a list of addresses
|
||||
|
||||
## Installation
|
||||
|
||||
Standard `go get`:
|
||||
|
||||
```
|
||||
$ go get github.com/mitchellh/pointerstructure
|
||||
```
|
||||
|
||||
## Usage & Example
|
||||
|
||||
For usage and examples see the [Godoc](http://godoc.org/github.com/mitchellh/pointerstructure).
|
||||
|
||||
A quick code example is shown below:
|
||||
|
||||
```go
|
||||
complex := map[string]interface{}{
|
||||
"alice": 42,
|
||||
"bob": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "Bob",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
value, err := pointerstructure.Get(complex, "/bob/0/name")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%s", value)
|
||||
// Output:
|
||||
// Bob
|
||||
```
|
||||
|
||||
Continuing the example above, you can also set values:
|
||||
|
||||
```go
|
||||
value, err = pointerstructure.Set(complex, "/bob/0/name", "Alice")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
value, err = pointerstructure.Get(complex, "/bob/0/name")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%s", value)
|
||||
// Output:
|
||||
// Alice
|
||||
```
|
|
@ -0,0 +1,112 @@
|
|||
package pointerstructure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Delete deletes the value specified by the pointer p in structure s.
|
||||
//
|
||||
// When deleting a slice index, all other elements will be shifted to
|
||||
// the left. This is specified in RFC6902 (JSON Patch) and not RFC6901 since
|
||||
// RFC6901 doesn't specify operations on pointers. If you don't want to
|
||||
// shift elements, you should use Set to set the slice index to the zero value.
|
||||
//
|
||||
// The structures s must have non-zero values set up to this pointer.
|
||||
// For example, if deleting "/bob/0/name", then "/bob/0" must be set already.
|
||||
//
|
||||
// The returned value is potentially a new value if this pointer represents
|
||||
// the root document. Otherwise, the returned value will always be s.
|
||||
func (p *Pointer) Delete(s interface{}) (interface{}, error) {
|
||||
// if we represent the root doc, we've deleted everything
|
||||
if len(p.Parts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Save the original since this is going to be our return value
|
||||
originalS := s
|
||||
|
||||
// Get the parent value
|
||||
var err error
|
||||
s, err = p.Parent().Get(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Map for lookup of getter to call for type
|
||||
funcMap := map[reflect.Kind]deleteFunc{
|
||||
reflect.Array: p.deleteSlice,
|
||||
reflect.Map: p.deleteMap,
|
||||
reflect.Slice: p.deleteSlice,
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(s)
|
||||
for val.Kind() == reflect.Interface {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
for val.Kind() == reflect.Ptr {
|
||||
val = reflect.Indirect(val)
|
||||
}
|
||||
|
||||
f, ok := funcMap[val.Kind()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("delete %s: invalid value kind: %s", p, val.Kind())
|
||||
}
|
||||
|
||||
result, err := f(originalS, val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete %s: %s", p, err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type deleteFunc func(interface{}, reflect.Value) (interface{}, error)
|
||||
|
||||
func (p *Pointer) deleteMap(root interface{}, m reflect.Value) (interface{}, error) {
|
||||
part := p.Parts[len(p.Parts)-1]
|
||||
key, err := coerce(reflect.ValueOf(part), m.Type().Key())
|
||||
if err != nil {
|
||||
return root, err
|
||||
}
|
||||
|
||||
// Delete the key
|
||||
var elem reflect.Value
|
||||
m.SetMapIndex(key, elem)
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func (p *Pointer) deleteSlice(root interface{}, s reflect.Value) (interface{}, error) {
|
||||
// Coerce the key to an int
|
||||
part := p.Parts[len(p.Parts)-1]
|
||||
idxVal, err := coerce(reflect.ValueOf(part), reflect.TypeOf(42))
|
||||
if err != nil {
|
||||
return root, err
|
||||
}
|
||||
idx := int(idxVal.Int())
|
||||
|
||||
// Verify we're within bounds
|
||||
if idx < 0 || idx >= s.Len() {
|
||||
return root, fmt.Errorf(
|
||||
"index %d is out of range (length = %d)", idx, s.Len())
|
||||
}
|
||||
|
||||
// Mimicing the following with reflection to do this:
|
||||
//
|
||||
// copy(a[i:], a[i+1:])
|
||||
// a[len(a)-1] = nil // or the zero value of T
|
||||
// a = a[:len(a)-1]
|
||||
|
||||
// copy(a[i:], a[i+1:])
|
||||
reflect.Copy(s.Slice(idx, s.Len()), s.Slice(idx+1, s.Len()))
|
||||
|
||||
// a[len(a)-1] = nil // or the zero value of T
|
||||
s.Index(s.Len() - 1).Set(reflect.Zero(s.Type().Elem()))
|
||||
|
||||
// a = a[:len(a)-1]
|
||||
s = s.Slice(0, s.Len()-1)
|
||||
|
||||
// set the slice back on the parent
|
||||
return p.Parent().Set(root, s.Interface())
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package pointerstructure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Get reads the value out of the total value v.
|
||||
func (p *Pointer) Get(v interface{}) (interface{}, error) {
|
||||
// fast-path the empty address case to avoid reflect.ValueOf below
|
||||
if len(p.Parts) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Map for lookup of getter to call for type
|
||||
funcMap := map[reflect.Kind]func(string, reflect.Value) (reflect.Value, error){
|
||||
reflect.Array: p.getSlice,
|
||||
reflect.Map: p.getMap,
|
||||
reflect.Slice: p.getSlice,
|
||||
reflect.Struct: p.getStruct,
|
||||
}
|
||||
|
||||
currentVal := reflect.ValueOf(v)
|
||||
for i, part := range p.Parts {
|
||||
for currentVal.Kind() == reflect.Interface {
|
||||
currentVal = currentVal.Elem()
|
||||
}
|
||||
|
||||
for currentVal.Kind() == reflect.Ptr {
|
||||
currentVal = reflect.Indirect(currentVal)
|
||||
}
|
||||
|
||||
f, ok := funcMap[currentVal.Kind()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(
|
||||
"%s: at part %d, invalid value kind: %s", p, i, currentVal.Kind())
|
||||
}
|
||||
|
||||
var err error
|
||||
currentVal, err = f(part, currentVal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s at part %d: %s", p, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return currentVal.Interface(), nil
|
||||
}
|
||||
|
||||
func (p *Pointer) getMap(part string, m reflect.Value) (reflect.Value, error) {
|
||||
var zeroValue reflect.Value
|
||||
|
||||
// Coerce the string part to the correct key type
|
||||
key, err := coerce(reflect.ValueOf(part), m.Type().Key())
|
||||
if err != nil {
|
||||
return zeroValue, err
|
||||
}
|
||||
|
||||
// Verify that the key exists
|
||||
found := false
|
||||
for _, k := range m.MapKeys() {
|
||||
if k.Interface() == key.Interface() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return zeroValue, fmt.Errorf("couldn't find key %#v", key.Interface())
|
||||
}
|
||||
|
||||
// Get the key
|
||||
return m.MapIndex(key), nil
|
||||
}
|
||||
|
||||
func (p *Pointer) getSlice(part string, v reflect.Value) (reflect.Value, error) {
|
||||
var zeroValue reflect.Value
|
||||
|
||||
// Coerce the key to an int
|
||||
idxVal, err := coerce(reflect.ValueOf(part), reflect.TypeOf(42))
|
||||
if err != nil {
|
||||
return zeroValue, err
|
||||
}
|
||||
idx := int(idxVal.Int())
|
||||
|
||||
// Verify we're within bounds
|
||||
if idx < 0 || idx >= v.Len() {
|
||||
return zeroValue, fmt.Errorf(
|
||||
"index %d is out of range (length = %d)", idx, v.Len())
|
||||
}
|
||||
|
||||
// Get the key
|
||||
return v.Index(idx), nil
|
||||
}
|
||||
|
||||
func (p *Pointer) getStruct(part string, m reflect.Value) (reflect.Value, error) {
|
||||
return m.FieldByName(part), nil
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
module github.com/mitchellh/pointerstructure
|
||||
|
||||
go 1.12
|
||||
|
||||
require github.com/mitchellh/mapstructure v1.1.2
|
|
@ -0,0 +1,2 @@
|
|||
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
|
@ -0,0 +1,57 @@
|
|||
package pointerstructure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Parse parses a pointer from the input string. The input string
|
||||
// is expected to follow the format specified by RFC 6901: '/'-separated
|
||||
// parts. Each part can contain escape codes to contain '/' or '~'.
|
||||
func Parse(input string) (*Pointer, error) {
|
||||
// Special case the empty case
|
||||
if input == "" {
|
||||
return &Pointer{}, nil
|
||||
}
|
||||
|
||||
// We expect the first character to be "/"
|
||||
if input[0] != '/' {
|
||||
return nil, fmt.Errorf(
|
||||
"parse Go pointer %q: first char must be '/'", input)
|
||||
}
|
||||
|
||||
// Trim out the first slash so we don't have to +1 every index
|
||||
input = input[1:]
|
||||
|
||||
// Parse out all the parts
|
||||
var parts []string
|
||||
lastSlash := -1
|
||||
for i, r := range input {
|
||||
if r == '/' {
|
||||
parts = append(parts, input[lastSlash+1:i])
|
||||
lastSlash = i
|
||||
}
|
||||
}
|
||||
|
||||
// Add last part
|
||||
parts = append(parts, input[lastSlash+1:])
|
||||
|
||||
// Process each part for string replacement
|
||||
for i, p := range parts {
|
||||
// Replace ~1 followed by ~0 as specified by the RFC
|
||||
parts[i] = strings.Replace(
|
||||
strings.Replace(p, "~1", "/", -1), "~0", "~", -1)
|
||||
}
|
||||
|
||||
return &Pointer{Parts: parts}, nil
|
||||
}
|
||||
|
||||
// MustParse is like Parse but panics if the input cannot be parsed.
|
||||
func MustParse(input string) *Pointer {
|
||||
p, err := Parse(input)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
// Package pointerstructure provides functions for identifying a specific
|
||||
// value within any Go structure using a string syntax.
|
||||
//
|
||||
// The syntax used is based on JSON Pointer (RFC 6901).
|
||||
package pointerstructure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// Pointer represents a pointer to a specific value. You can construct
|
||||
// a pointer manually or use Parse.
|
||||
type Pointer struct {
|
||||
// Parts are the pointer parts. No escape codes are processed here.
|
||||
// The values are expected to be exact. If you have escape codes, use
|
||||
// the Parse functions.
|
||||
Parts []string
|
||||
}
|
||||
|
||||
// Get reads the value at the given pointer.
|
||||
//
|
||||
// This is a shorthand for calling Parse on the pointer and then calling Get
|
||||
// on that result. An error will be returned if the value cannot be found or
|
||||
// there is an error with the format of pointer.
|
||||
func Get(value interface{}, pointer string) (interface{}, error) {
|
||||
p, err := Parse(pointer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.Get(value)
|
||||
}
|
||||
|
||||
// Set sets the value at the given pointer.
|
||||
//
|
||||
// This is a shorthand for calling Parse on the pointer and then calling Set
|
||||
// on that result. An error will be returned if the value cannot be found or
|
||||
// there is an error with the format of pointer.
|
||||
//
|
||||
// Set returns the complete document, which might change if the pointer value
|
||||
// points to the root ("").
|
||||
func Set(doc interface{}, pointer string, value interface{}) (interface{}, error) {
|
||||
p, err := Parse(pointer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.Set(doc, value)
|
||||
}
|
||||
|
||||
// String returns the string value that can be sent back to Parse to get
|
||||
// the same Pointer result.
|
||||
func (p *Pointer) String() string {
|
||||
if len(p.Parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Copy the parts so we can convert back the escapes
|
||||
result := make([]string, len(p.Parts))
|
||||
copy(result, p.Parts)
|
||||
for i, p := range p.Parts {
|
||||
result[i] = strings.Replace(
|
||||
strings.Replace(p, "~", "~0", -1), "/", "~1", -1)
|
||||
|
||||
}
|
||||
|
||||
return "/" + strings.Join(result, "/")
|
||||
}
|
||||
|
||||
// Parent returns a pointer to the parent element of this pointer.
|
||||
//
|
||||
// If Pointer represents the root (empty parts), a pointer representing
|
||||
// the root is returned. Therefore, to check for the root, IsRoot() should be
|
||||
// called.
|
||||
func (p *Pointer) Parent() *Pointer {
|
||||
// If this is root, then we just return a new root pointer. We allocate
|
||||
// a new one though so this can still be modified.
|
||||
if p.IsRoot() {
|
||||
return &Pointer{}
|
||||
}
|
||||
|
||||
parts := make([]string, len(p.Parts)-1)
|
||||
copy(parts, p.Parts[:len(p.Parts)-1])
|
||||
return &Pointer{
|
||||
Parts: parts,
|
||||
}
|
||||
}
|
||||
|
||||
// IsRoot returns true if this pointer represents the root document.
|
||||
func (p *Pointer) IsRoot() bool {
|
||||
return len(p.Parts) == 0
|
||||
}
|
||||
|
||||
// coerce is a helper to coerce a value to a specific type if it must
|
||||
// and if its possible. If it isn't possible, an error is returned.
|
||||
func coerce(value reflect.Value, to reflect.Type) (reflect.Value, error) {
|
||||
// If the value is already assignable to the type, then let it go
|
||||
if value.Type().AssignableTo(to) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// If a direct conversion is possible, do that
|
||||
if value.Type().ConvertibleTo(to) {
|
||||
return value.Convert(to), nil
|
||||
}
|
||||
|
||||
// Create a new value to hold our result
|
||||
result := reflect.New(to)
|
||||
|
||||
// Decode
|
||||
if err := mapstructure.WeakDecode(value.Interface(), result.Interface()); err != nil {
|
||||
return result, fmt.Errorf(
|
||||
"couldn't convert value %#v to type %s",
|
||||
value.Interface(), to.String())
|
||||
}
|
||||
|
||||
// We need to indirect the value since reflect.New always creates a pointer
|
||||
return reflect.Indirect(result), nil
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
package pointerstructure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Set writes a value v to the pointer p in structure s.
|
||||
//
|
||||
// The structures s must have non-zero values set up to this pointer.
|
||||
// For example, if setting "/bob/0/name", then "/bob/0" must be set already.
|
||||
//
|
||||
// The returned value is potentially a new value if this pointer represents
|
||||
// the root document. Otherwise, the returned value will always be s.
|
||||
func (p *Pointer) Set(s, v interface{}) (interface{}, error) {
|
||||
// if we represent the root doc, return that
|
||||
if len(p.Parts) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Save the original since this is going to be our return value
|
||||
originalS := s
|
||||
|
||||
// Get the parent value
|
||||
var err error
|
||||
s, err = p.Parent().Get(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Map for lookup of getter to call for type
|
||||
funcMap := map[reflect.Kind]setFunc{
|
||||
reflect.Array: p.setSlice,
|
||||
reflect.Map: p.setMap,
|
||||
reflect.Slice: p.setSlice,
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(s)
|
||||
for val.Kind() == reflect.Interface {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
for val.Kind() == reflect.Ptr {
|
||||
val = reflect.Indirect(val)
|
||||
}
|
||||
|
||||
f, ok := funcMap[val.Kind()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("set %s: invalid value kind: %s", p, val.Kind())
|
||||
}
|
||||
|
||||
result, err := f(originalS, val, reflect.ValueOf(v))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set %s: %s", p, err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type setFunc func(interface{}, reflect.Value, reflect.Value) (interface{}, error)
|
||||
|
||||
func (p *Pointer) setMap(root interface{}, m, value reflect.Value) (interface{}, error) {
|
||||
part := p.Parts[len(p.Parts)-1]
|
||||
key, err := coerce(reflect.ValueOf(part), m.Type().Key())
|
||||
if err != nil {
|
||||
return root, err
|
||||
}
|
||||
|
||||
elem, err := coerce(value, m.Type().Elem())
|
||||
if err != nil {
|
||||
return root, err
|
||||
}
|
||||
|
||||
// Set the key
|
||||
m.SetMapIndex(key, elem)
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func (p *Pointer) setSlice(root interface{}, s, value reflect.Value) (interface{}, error) {
|
||||
// Coerce the value, we'll need that no matter what
|
||||
value, err := coerce(value, s.Type().Elem())
|
||||
if err != nil {
|
||||
return root, err
|
||||
}
|
||||
|
||||
// If the part is the special "-", that means to append it (RFC6901 4.)
|
||||
part := p.Parts[len(p.Parts)-1]
|
||||
if part == "-" {
|
||||
return p.setSliceAppend(root, s, value)
|
||||
}
|
||||
|
||||
// Coerce the key to an int
|
||||
idxVal, err := coerce(reflect.ValueOf(part), reflect.TypeOf(42))
|
||||
if err != nil {
|
||||
return root, err
|
||||
}
|
||||
idx := int(idxVal.Int())
|
||||
|
||||
// Verify we're within bounds
|
||||
if idx < 0 || idx >= s.Len() {
|
||||
return root, fmt.Errorf(
|
||||
"index %d is out of range (length = %d)", idx, s.Len())
|
||||
}
|
||||
|
||||
// Set the key
|
||||
s.Index(idx).Set(value)
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func (p *Pointer) setSliceAppend(root interface{}, s, value reflect.Value) (interface{}, error) {
|
||||
// Coerce the value, we'll need that no matter what. This should
|
||||
// be a no-op since we expect it to be done already, but there is
|
||||
// a fast-path check for that in coerce so do it anyways.
|
||||
value, err := coerce(value, s.Type().Elem())
|
||||
if err != nil {
|
||||
return root, err
|
||||
}
|
||||
|
||||
// We can assume "s" is the parent of pointer value. We need to actually
|
||||
// write s back because Append can return a new slice.
|
||||
return p.Parent().Set(root, reflect.Append(s, value).Interface())
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package pointerstructure
|
||||
|
||||
import (
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Sort does an in-place sort of the pointers so that they are in order
|
||||
// of least specific to most specific alphabetized. For example:
|
||||
// "/foo", "/foo/0", "/qux"
|
||||
//
|
||||
// This ordering is ideal for applying the changes in a way that ensures
|
||||
// that parents are set first.
|
||||
func Sort(p []*Pointer) { sort.Sort(PointerSlice(p)) }
|
||||
|
||||
// PointerSlice is a slice of pointers that adheres to sort.Interface
|
||||
type PointerSlice []*Pointer
|
||||
|
||||
func (p PointerSlice) Len() int { return len(p) }
|
||||
func (p PointerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
func (p PointerSlice) Less(i, j int) bool {
|
||||
// Equal number of parts, do a string compare per part
|
||||
for idx, ival := range p[i].Parts {
|
||||
// If we're passed the length of p[j] parts, then we're done
|
||||
if idx >= len(p[j].Parts) {
|
||||
break
|
||||
}
|
||||
|
||||
// Compare the values if they're not equal
|
||||
jval := p[j].Parts[idx]
|
||||
if ival != jval {
|
||||
return ival < jval
|
||||
}
|
||||
}
|
||||
|
||||
// Equal prefix, take the shorter
|
||||
if len(p[i].Parts) != len(p[j].Parts) {
|
||||
return len(p[i].Parts) < len(p[j].Parts)
|
||||
}
|
||||
|
||||
// Equal, it doesn't matter
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
This is a list of people who have contributed code to go-cache. They, or their
|
||||
employers, are the copyright holders of the contributed code. Contributed code
|
||||
is subject to the license restrictions listed in LICENSE (as they were when the
|
||||
code was contributed.)
|
||||
|
||||
Dustin Sallings <dustin@spy.net>
|
||||
Jason Mooberry <jasonmoo@me.com>
|
||||
Sergey Shepelev <temotor@gmail.com>
|
||||
Alex Edwards <ajmedwards@gmail.com>
|
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2012-2017 Patrick Mylund Nielsen and the go-cache contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,83 @@
|
|||
# go-cache
|
||||
|
||||
go-cache is an in-memory key:value store/cache similar to memcached that is
|
||||
suitable for applications running on a single machine. Its major advantage is
|
||||
that, being essentially a thread-safe `map[string]interface{}` with expiration
|
||||
times, it doesn't need to serialize or transmit its contents over the network.
|
||||
|
||||
Any object can be stored, for a given duration or forever, and the cache can be
|
||||
safely used by multiple goroutines.
|
||||
|
||||
Although go-cache isn't meant to be used as a persistent datastore, the entire
|
||||
cache can be saved to and loaded from a file (using `c.Items()` to retrieve the
|
||||
items map to serialize, and `NewFrom()` to create a cache from a deserialized
|
||||
one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.)
|
||||
|
||||
### Installation
|
||||
|
||||
`go get github.com/patrickmn/go-cache`
|
||||
|
||||
### Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create a cache with a default expiration time of 5 minutes, and which
|
||||
// purges expired items every 10 minutes
|
||||
c := cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
// Set the value of the key "foo" to "bar", with the default expiration time
|
||||
c.Set("foo", "bar", cache.DefaultExpiration)
|
||||
|
||||
// Set the value of the key "baz" to 42, with no expiration time
|
||||
// (the item won't be removed until it is re-set, or removed using
|
||||
// c.Delete("baz")
|
||||
c.Set("baz", 42, cache.NoExpiration)
|
||||
|
||||
// Get the string associated with the key "foo" from the cache
|
||||
foo, found := c.Get("foo")
|
||||
if found {
|
||||
fmt.Println(foo)
|
||||
}
|
||||
|
||||
// Since Go is statically typed, and cache values can be anything, type
|
||||
// assertion is needed when values are being passed to functions that don't
|
||||
// take arbitrary types, (i.e. interface{}). The simplest way to do this for
|
||||
// values which will only be used once--e.g. for passing to another
|
||||
// function--is:
|
||||
foo, found := c.Get("foo")
|
||||
if found {
|
||||
MyFunction(foo.(string))
|
||||
}
|
||||
|
||||
// This gets tedious if the value is used several times in the same function.
|
||||
// You might do either of the following instead:
|
||||
if x, found := c.Get("foo"); found {
|
||||
foo := x.(string)
|
||||
// ...
|
||||
}
|
||||
// or
|
||||
var foo string
|
||||
if x, found := c.Get("foo"); found {
|
||||
foo = x.(string)
|
||||
}
|
||||
// ...
|
||||
// foo can then be passed around freely as a string
|
||||
|
||||
// Want performance? Store pointers!
|
||||
c.Set("foo", &MyStruct, cache.DefaultExpiration)
|
||||
if x, found := c.Get("foo"); found {
|
||||
foo := x.(*MyStruct)
|
||||
// ...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Reference
|
||||
|
||||
`godoc` or [http://godoc.org/github.com/patrickmn/go-cache](http://godoc.org/github.com/patrickmn/go-cache)
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,192 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math"
|
||||
"math/big"
|
||||
insecurerand "math/rand"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// This is an experimental and unexported (for now) attempt at making a cache
|
||||
// with better algorithmic complexity than the standard one, namely by
|
||||
// preventing write locks of the entire cache when an item is added. As of the
|
||||
// time of writing, the overhead of selecting buckets results in cache
|
||||
// operations being about twice as slow as for the standard cache with small
|
||||
// total cache sizes, and faster for larger ones.
|
||||
//
|
||||
// See cache_test.go for a few benchmarks.
|
||||
|
||||
type unexportedShardedCache struct {
|
||||
*shardedCache
|
||||
}
|
||||
|
||||
type shardedCache struct {
|
||||
seed uint32
|
||||
m uint32
|
||||
cs []*cache
|
||||
janitor *shardedJanitor
|
||||
}
|
||||
|
||||
// djb2 with better shuffling. 5x faster than FNV with the hash.Hash overhead.
|
||||
func djb33(seed uint32, k string) uint32 {
|
||||
var (
|
||||
l = uint32(len(k))
|
||||
d = 5381 + seed + l
|
||||
i = uint32(0)
|
||||
)
|
||||
// Why is all this 5x faster than a for loop?
|
||||
if l >= 4 {
|
||||
for i < l-4 {
|
||||
d = (d * 33) ^ uint32(k[i])
|
||||
d = (d * 33) ^ uint32(k[i+1])
|
||||
d = (d * 33) ^ uint32(k[i+2])
|
||||
d = (d * 33) ^ uint32(k[i+3])
|
||||
i += 4
|
||||
}
|
||||
}
|
||||
switch l - i {
|
||||
case 1:
|
||||
case 2:
|
||||
d = (d * 33) ^ uint32(k[i])
|
||||
case 3:
|
||||
d = (d * 33) ^ uint32(k[i])
|
||||
d = (d * 33) ^ uint32(k[i+1])
|
||||
case 4:
|
||||
d = (d * 33) ^ uint32(k[i])
|
||||
d = (d * 33) ^ uint32(k[i+1])
|
||||
d = (d * 33) ^ uint32(k[i+2])
|
||||
}
|
||||
return d ^ (d >> 16)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) bucket(k string) *cache {
|
||||
return sc.cs[djb33(sc.seed, k)%sc.m]
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Set(k string, x interface{}, d time.Duration) {
|
||||
sc.bucket(k).Set(k, x, d)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Add(k string, x interface{}, d time.Duration) error {
|
||||
return sc.bucket(k).Add(k, x, d)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Replace(k string, x interface{}, d time.Duration) error {
|
||||
return sc.bucket(k).Replace(k, x, d)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Get(k string) (interface{}, bool) {
|
||||
return sc.bucket(k).Get(k)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Increment(k string, n int64) error {
|
||||
return sc.bucket(k).Increment(k, n)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) IncrementFloat(k string, n float64) error {
|
||||
return sc.bucket(k).IncrementFloat(k, n)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Decrement(k string, n int64) error {
|
||||
return sc.bucket(k).Decrement(k, n)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Delete(k string) {
|
||||
sc.bucket(k).Delete(k)
|
||||
}
|
||||
|
||||
func (sc *shardedCache) DeleteExpired() {
|
||||
for _, v := range sc.cs {
|
||||
v.DeleteExpired()
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the items in the cache. This may include items that have expired,
|
||||
// but have not yet been cleaned up. If this is significant, the Expiration
|
||||
// fields of the items should be checked. Note that explicit synchronization
|
||||
// is needed to use a cache and its corresponding Items() return values at
|
||||
// the same time, as the maps are shared.
|
||||
func (sc *shardedCache) Items() []map[string]Item {
|
||||
res := make([]map[string]Item, len(sc.cs))
|
||||
for i, v := range sc.cs {
|
||||
res[i] = v.Items()
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (sc *shardedCache) Flush() {
|
||||
for _, v := range sc.cs {
|
||||
v.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
type shardedJanitor struct {
|
||||
Interval time.Duration
|
||||
stop chan bool
|
||||
}
|
||||
|
||||
func (j *shardedJanitor) Run(sc *shardedCache) {
|
||||
j.stop = make(chan bool)
|
||||
tick := time.Tick(j.Interval)
|
||||
for {
|
||||
select {
|
||||
case <-tick:
|
||||
sc.DeleteExpired()
|
||||
case <-j.stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stopShardedJanitor(sc *unexportedShardedCache) {
|
||||
sc.janitor.stop <- true
|
||||
}
|
||||
|
||||
func runShardedJanitor(sc *shardedCache, ci time.Duration) {
|
||||
j := &shardedJanitor{
|
||||
Interval: ci,
|
||||
}
|
||||
sc.janitor = j
|
||||
go j.Run(sc)
|
||||
}
|
||||
|
||||
func newShardedCache(n int, de time.Duration) *shardedCache {
|
||||
max := big.NewInt(0).SetUint64(uint64(math.MaxUint32))
|
||||
rnd, err := rand.Int(rand.Reader, max)
|
||||
var seed uint32
|
||||
if err != nil {
|
||||
os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent.) Your system's security may be compromised. Continuing with an insecure seed.\n"))
|
||||
seed = insecurerand.Uint32()
|
||||
} else {
|
||||
seed = uint32(rnd.Uint64())
|
||||
}
|
||||
sc := &shardedCache{
|
||||
seed: seed,
|
||||
m: uint32(n),
|
||||
cs: make([]*cache, n),
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
c := &cache{
|
||||
defaultExpiration: de,
|
||||
items: map[string]Item{},
|
||||
}
|
||||
sc.cs[i] = c
|
||||
}
|
||||
return sc
|
||||
}
|
||||
|
||||
func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int) *unexportedShardedCache {
|
||||
if defaultExpiration == 0 {
|
||||
defaultExpiration = -1
|
||||
}
|
||||
sc := newShardedCache(shards, defaultExpiration)
|
||||
SC := &unexportedShardedCache{sc}
|
||||
if cleanupInterval > 0 {
|
||||
runShardedJanitor(sc, cleanupInterval)
|
||||
runtime.SetFinalizer(SC, stopShardedJanitor)
|
||||
}
|
||||
return SC
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
language: go
|
||||
|
||||
install:
|
||||
- go get -d -v ./...
|
||||
- go get -u github.com/stretchr/testify/require
|
||||
|
||||
go:
|
||||
- 1.7
|
||||
- 1.8
|
||||
- tip
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,107 @@
|
|||
# cachecontrol: HTTP Caching Parser and Interpretation
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/pquerna/cachecontrol?status.svg)](https://godoc.org/github.com/pquerna/cachecontrol)[![Build Status](https://travis-ci.org/pquerna/cachecontrol.svg?branch=master)](https://travis-ci.org/pquerna/cachecontrol)
|
||||
|
||||
|
||||
|
||||
`cachecontrol` implements [RFC 7234](http://tools.ietf.org/html/rfc7234) __Hypertext Transfer Protocol (HTTP/1.1): Caching__. It does this by parsing the `Cache-Control` and other headers, providing information about requests and responses -- but `cachecontrol` does not implement an actual cache backend, just the control plane to make decisions about if a particular response is cachable.
|
||||
|
||||
# Usage
|
||||
|
||||
`cachecontrol.CachableResponse` returns an array of [reasons](https://godoc.org/github.com/pquerna/cachecontrol/cacheobject#Reason) why a response should not be cached and when it expires. In the case that `len(reasons) == 0`, the response is cachable according to the RFC. However, some people want non-compliant caches for various business use cases, so each reason is specifically named, so if your cache wants to cache `POST` requests, it can easily do that, but still be RFC compliant in other situations.
|
||||
|
||||
# Examples
|
||||
|
||||
## Can you cache Example.com?
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/pquerna/cachecontrol"
|
||||
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func main() {
|
||||
req, _ := http.NewRequest("GET", "http://www.example.com/", nil)
|
||||
|
||||
res, _ := http.DefaultClient.Do(req)
|
||||
_, _ = ioutil.ReadAll(res.Body)
|
||||
|
||||
reasons, expires, _ := cachecontrol.CachableResponse(req, res, cachecontrol.Options{})
|
||||
|
||||
fmt.Println("Reasons to not cache: ", reasons)
|
||||
fmt.Println("Expiration: ", expires.String())
|
||||
}
|
||||
```
|
||||
|
||||
## Can I use this in a high performance caching server?
|
||||
|
||||
`cachecontrol` is divided into two packages: `cachecontrol` with a high level API, and a lower level `cacheobject` package. Use [Object](https://godoc.org/github.com/pquerna/cachecontrol/cacheobject#Object) in a high performance use case where you have previously parsed headers containing dates or would like to avoid memory allocations.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/pquerna/cachecontrol/cacheobject"
|
||||
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func main() {
|
||||
req, _ := http.NewRequest("GET", "http://www.example.com/", nil)
|
||||
|
||||
res, _ := http.DefaultClient.Do(req)
|
||||
_, _ = ioutil.ReadAll(res.Body)
|
||||
|
||||
reqDir, _ := cacheobject.ParseRequestCacheControl(req.Header.Get("Cache-Control"))
|
||||
|
||||
resDir, _ := cacheobject.ParseResponseCacheControl(res.Header.Get("Cache-Control"))
|
||||
expiresHeader, _ := http.ParseTime(res.Header.Get("Expires"))
|
||||
dateHeader, _ := http.ParseTime(res.Header.Get("Date"))
|
||||
lastModifiedHeader, _ := http.ParseTime(res.Header.Get("Last-Modified"))
|
||||
|
||||
obj := cacheobject.Object{
|
||||
RespDirectives: resDir,
|
||||
RespHeaders: res.Header,
|
||||
RespStatusCode: res.StatusCode,
|
||||
RespExpiresHeader: expiresHeader,
|
||||
RespDateHeader: dateHeader,
|
||||
RespLastModifiedHeader: lastModifiedHeader,
|
||||
|
||||
ReqDirectives: reqDir,
|
||||
ReqHeaders: req.Header,
|
||||
ReqMethod: req.Method,
|
||||
|
||||
NowUTC: time.Now().UTC(),
|
||||
}
|
||||
rv := cacheobject.ObjectResults{}
|
||||
|
||||
cacheobject.CachableObject(&obj, &rv)
|
||||
cacheobject.ExpirationObject(&obj, &rv)
|
||||
|
||||
fmt.Println("Errors: ", rv.OutErr)
|
||||
fmt.Println("Reasons to not cache: ", rv.OutReasons)
|
||||
fmt.Println("Warning headers to add: ", rv.OutWarnings)
|
||||
fmt.Println("Expiration: ", rv.OutExpirationTime.String())
|
||||
}
|
||||
```
|
||||
|
||||
## Improvements, bugs, adding features, and taking cachecontrol new directions!
|
||||
|
||||
Please [open issues in Github](https://github.com/pquerna/cachecontrol/issues) for ideas, bugs, and general thoughts. Pull requests are of course preferred :)
|
||||
|
||||
# Credits
|
||||
|
||||
`cachecontrol` has recieved significant contributions from:
|
||||
|
||||
* [Paul Querna](https://github.com/pquerna)
|
||||
|
||||
## License
|
||||
|
||||
`cachecontrol` is licensed under the [Apache License, Version 2.0](./LICENSE)
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2015 Paul Querna
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package cachecontrol
|
||||
|
||||
import (
|
||||
"github.com/pquerna/cachecontrol/cacheobject"
|
||||
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
// Set to True for a prviate cache, which is not shared amoung users (eg, in a browser)
|
||||
// Set to False for a "shared" cache, which is more common in a server context.
|
||||
PrivateCache bool
|
||||
}
|
||||
|
||||
// Given an HTTP Request, the future Status Code, and an ResponseWriter,
|
||||
// determine the possible reasons a response SHOULD NOT be cached.
|
||||
func CachableResponseWriter(req *http.Request,
|
||||
statusCode int,
|
||||
resp http.ResponseWriter,
|
||||
opts Options) ([]cacheobject.Reason, time.Time, error) {
|
||||
return cacheobject.UsingRequestResponse(req, statusCode, resp.Header(), opts.PrivateCache)
|
||||
}
|
||||
|
||||
// Given an HTTP Request and Response, determine the possible reasons a response SHOULD NOT
|
||||
// be cached.
|
||||
func CachableResponse(req *http.Request,
|
||||
resp *http.Response,
|
||||
opts Options) ([]cacheobject.Reason, time.Time, error) {
|
||||
return cacheobject.UsingRequestResponse(req, resp.StatusCode, resp.Header, opts.PrivateCache)
|
||||
}
|
|
@ -0,0 +1,546 @@
|
|||
/**
|
||||
* Copyright 2015 Paul Querna
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package cacheobject
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TODO(pquerna): add extensions from here: http://www.iana.org/assignments/http-cache-directives/http-cache-directives.xhtml
|
||||
|
||||
var (
|
||||
ErrQuoteMismatch = errors.New("Missing closing quote")
|
||||
ErrMaxAgeDeltaSeconds = errors.New("Failed to parse delta-seconds in `max-age`")
|
||||
ErrSMaxAgeDeltaSeconds = errors.New("Failed to parse delta-seconds in `s-maxage`")
|
||||
ErrMaxStaleDeltaSeconds = errors.New("Failed to parse delta-seconds in `min-fresh`")
|
||||
ErrMinFreshDeltaSeconds = errors.New("Failed to parse delta-seconds in `min-fresh`")
|
||||
ErrNoCacheNoArgs = errors.New("Unexpected argument to `no-cache`")
|
||||
ErrNoStoreNoArgs = errors.New("Unexpected argument to `no-store`")
|
||||
ErrNoTransformNoArgs = errors.New("Unexpected argument to `no-transform`")
|
||||
ErrOnlyIfCachedNoArgs = errors.New("Unexpected argument to `only-if-cached`")
|
||||
ErrMustRevalidateNoArgs = errors.New("Unexpected argument to `must-revalidate`")
|
||||
ErrPublicNoArgs = errors.New("Unexpected argument to `public`")
|
||||
ErrProxyRevalidateNoArgs = errors.New("Unexpected argument to `proxy-revalidate`")
|
||||
// Experimental
|
||||
ErrImmutableNoArgs = errors.New("Unexpected argument to `immutable`")
|
||||
ErrStaleIfErrorDeltaSeconds = errors.New("Failed to parse delta-seconds in `stale-if-error`")
|
||||
ErrStaleWhileRevalidateDeltaSeconds = errors.New("Failed to parse delta-seconds in `stale-while-revalidate`")
|
||||
)
|
||||
|
||||
func whitespace(b byte) bool {
|
||||
if b == '\t' || b == ' ' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parse(value string, cd cacheDirective) error {
|
||||
var err error = nil
|
||||
i := 0
|
||||
|
||||
for i < len(value) && err == nil {
|
||||
// eat leading whitespace or commas
|
||||
if whitespace(value[i]) || value[i] == ',' {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
j := i + 1
|
||||
|
||||
for j < len(value) {
|
||||
if !isToken(value[j]) {
|
||||
break
|
||||
}
|
||||
j++
|
||||
}
|
||||
|
||||
token := strings.ToLower(value[i:j])
|
||||
tokenHasFields := hasFieldNames(token)
|
||||
/*
|
||||
println("GOT TOKEN:")
|
||||
println(" i -> ", i)
|
||||
println(" j -> ", j)
|
||||
println(" token -> ", token)
|
||||
*/
|
||||
|
||||
if j+1 < len(value) && value[j] == '=' {
|
||||
k := j + 1
|
||||
// minimum size two bytes of "", but we let httpUnquote handle it.
|
||||
if k < len(value) && value[k] == '"' {
|
||||
eaten, result := httpUnquote(value[k:])
|
||||
if eaten == -1 {
|
||||
return ErrQuoteMismatch
|
||||
}
|
||||
i = k + eaten
|
||||
|
||||
err = cd.addPair(token, result)
|
||||
} else {
|
||||
z := k
|
||||
for z < len(value) {
|
||||
if tokenHasFields {
|
||||
if whitespace(value[z]) {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if whitespace(value[z]) || value[z] == ',' {
|
||||
break
|
||||
}
|
||||
}
|
||||
z++
|
||||
}
|
||||
i = z
|
||||
|
||||
result := value[k:z]
|
||||
if result != "" && result[len(result)-1] == ',' {
|
||||
result = result[:len(result)-1]
|
||||
}
|
||||
|
||||
err = cd.addPair(token, result)
|
||||
}
|
||||
} else {
|
||||
if token != "," {
|
||||
err = cd.addToken(token)
|
||||
}
|
||||
i = j
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeltaSeconds specifies a non-negative integer, representing
|
||||
// time in seconds: http://tools.ietf.org/html/rfc7234#section-1.2.1
|
||||
//
|
||||
// When set to -1, this means unset.
|
||||
//
|
||||
type DeltaSeconds int32
|
||||
|
||||
// Parser for delta-seconds, a uint31, more or less:
|
||||
// http://tools.ietf.org/html/rfc7234#section-1.2.1
|
||||
func parseDeltaSeconds(v string) (DeltaSeconds, error) {
|
||||
n, err := strconv.ParseUint(v, 10, 32)
|
||||
if err != nil {
|
||||
if numError, ok := err.(*strconv.NumError); ok {
|
||||
if numError.Err == strconv.ErrRange {
|
||||
return DeltaSeconds(math.MaxInt32), nil
|
||||
}
|
||||
}
|
||||
return DeltaSeconds(-1), err
|
||||
} else {
|
||||
if n > math.MaxInt32 {
|
||||
return DeltaSeconds(math.MaxInt32), nil
|
||||
} else {
|
||||
return DeltaSeconds(n), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fields present in a header.
|
||||
type FieldNames map[string]bool
|
||||
|
||||
// internal interface for shared methods of RequestCacheDirectives and ResponseCacheDirectives
|
||||
type cacheDirective interface {
|
||||
addToken(s string) error
|
||||
addPair(s string, v string) error
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Repersentation of possible request directives in a `Cache-Control` header: http://tools.ietf.org/html/rfc7234#section-5.2.1
|
||||
//
|
||||
// Note: Many fields will be `nil` in practice.
|
||||
//
|
||||
type RequestCacheDirectives struct {
|
||||
|
||||
// max-age(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.1.1
|
||||
//
|
||||
// The "max-age" request directive indicates that the client is
|
||||
// unwilling to accept a response whose age is greater than the
|
||||
// specified number of seconds. Unless the max-stale request directive
|
||||
// is also present, the client is not willing to accept a stale
|
||||
// response.
|
||||
MaxAge DeltaSeconds
|
||||
|
||||
// max-stale(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.1.2
|
||||
//
|
||||
// The "max-stale" request directive indicates that the client is
|
||||
// willing to accept a response that has exceeded its freshness
|
||||
// lifetime. If max-stale is assigned a value, then the client is
|
||||
// willing to accept a response that has exceeded its freshness lifetime
|
||||
// by no more than the specified number of seconds. If no value is
|
||||
// assigned to max-stale, then the client is willing to accept a stale
|
||||
// response of any age.
|
||||
MaxStale DeltaSeconds
|
||||
|
||||
// min-fresh(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.1.3
|
||||
//
|
||||
// The "min-fresh" request directive indicates that the client is
|
||||
// willing to accept a response whose freshness lifetime is no less than
|
||||
// its current age plus the specified time in seconds. That is, the
|
||||
// client wants a response that will still be fresh for at least the
|
||||
// specified number of seconds.
|
||||
MinFresh DeltaSeconds
|
||||
|
||||
// no-cache(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.4
|
||||
//
|
||||
// The "no-cache" request directive indicates that a cache MUST NOT use
|
||||
// a stored response to satisfy the request without successful
|
||||
// validation on the origin server.
|
||||
NoCache bool
|
||||
|
||||
// no-store(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.5
|
||||
//
|
||||
// The "no-store" request directive indicates that a cache MUST NOT
|
||||
// store any part of either this request or any response to it. This
|
||||
// directive applies to both private and shared caches.
|
||||
NoStore bool
|
||||
|
||||
// no-transform(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.6
|
||||
//
|
||||
// The "no-transform" request directive indicates that an intermediary
|
||||
// (whether or not it implements a cache) MUST NOT transform the
|
||||
// payload, as defined in Section 5.7.2 of RFC7230.
|
||||
NoTransform bool
|
||||
|
||||
// only-if-cached(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.7
|
||||
//
|
||||
// The "only-if-cached" request directive indicates that the client only
|
||||
// wishes to obtain a stored response.
|
||||
OnlyIfCached bool
|
||||
|
||||
// Extensions: http://tools.ietf.org/html/rfc7234#section-5.2.3
|
||||
//
|
||||
// The Cache-Control header field can be extended through the use of one
|
||||
// or more cache-extension tokens, each with an optional value. A cache
|
||||
// MUST ignore unrecognized cache directives.
|
||||
Extensions []string
|
||||
}
|
||||
|
||||
func (cd *RequestCacheDirectives) addToken(token string) error {
|
||||
var err error = nil
|
||||
|
||||
switch token {
|
||||
case "max-age":
|
||||
err = ErrMaxAgeDeltaSeconds
|
||||
case "max-stale":
|
||||
err = ErrMaxStaleDeltaSeconds
|
||||
case "min-fresh":
|
||||
err = ErrMinFreshDeltaSeconds
|
||||
case "no-cache":
|
||||
cd.NoCache = true
|
||||
case "no-store":
|
||||
cd.NoStore = true
|
||||
case "no-transform":
|
||||
cd.NoTransform = true
|
||||
case "only-if-cached":
|
||||
cd.OnlyIfCached = true
|
||||
default:
|
||||
cd.Extensions = append(cd.Extensions, token)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (cd *RequestCacheDirectives) addPair(token string, v string) error {
|
||||
var err error = nil
|
||||
|
||||
switch token {
|
||||
case "max-age":
|
||||
cd.MaxAge, err = parseDeltaSeconds(v)
|
||||
if err != nil {
|
||||
err = ErrMaxAgeDeltaSeconds
|
||||
}
|
||||
case "max-stale":
|
||||
cd.MaxStale, err = parseDeltaSeconds(v)
|
||||
if err != nil {
|
||||
err = ErrMaxStaleDeltaSeconds
|
||||
}
|
||||
case "min-fresh":
|
||||
cd.MinFresh, err = parseDeltaSeconds(v)
|
||||
if err != nil {
|
||||
err = ErrMinFreshDeltaSeconds
|
||||
}
|
||||
case "no-cache":
|
||||
err = ErrNoCacheNoArgs
|
||||
case "no-store":
|
||||
err = ErrNoStoreNoArgs
|
||||
case "no-transform":
|
||||
err = ErrNoTransformNoArgs
|
||||
case "only-if-cached":
|
||||
err = ErrOnlyIfCachedNoArgs
|
||||
default:
|
||||
// TODO(pquerna): this sucks, making user re-parse
|
||||
cd.Extensions = append(cd.Extensions, token+"="+v)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Parses a Cache Control Header from a Request into a set of directives.
|
||||
func ParseRequestCacheControl(value string) (*RequestCacheDirectives, error) {
|
||||
cd := &RequestCacheDirectives{
|
||||
MaxAge: -1,
|
||||
MaxStale: -1,
|
||||
MinFresh: -1,
|
||||
}
|
||||
|
||||
err := parse(value, cd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cd, nil
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Repersentation of possible response directives in a `Cache-Control` header: http://tools.ietf.org/html/rfc7234#section-5.2.2
|
||||
//
|
||||
// Note: Many fields will be `nil` in practice.
|
||||
//
|
||||
type ResponseCacheDirectives struct {
|
||||
|
||||
// must-revalidate(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.1
|
||||
//
|
||||
// The "must-revalidate" response directive indicates that once it has
|
||||
// become stale, a cache MUST NOT use the response to satisfy subsequent
|
||||
// requests without successful validation on the origin server.
|
||||
MustRevalidate bool
|
||||
|
||||
// no-cache(FieldName): http://tools.ietf.org/html/rfc7234#section-5.2.2.2
|
||||
//
|
||||
// The "no-cache" response directive indicates that the response MUST
|
||||
// NOT be used to satisfy a subsequent request without successful
|
||||
// validation on the origin server.
|
||||
//
|
||||
// If the no-cache response directive specifies one or more field-names,
|
||||
// then a cache MAY use the response to satisfy a subsequent request,
|
||||
// subject to any other restrictions on caching. However, any header
|
||||
// fields in the response that have the field-name(s) listed MUST NOT be
|
||||
// sent in the response to a subsequent request without successful
|
||||
// revalidation with the origin server.
|
||||
NoCache FieldNames
|
||||
|
||||
// no-cache(cast-to-bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.2
|
||||
//
|
||||
// While the RFC defines optional field-names on a no-cache directive,
|
||||
// many applications only want to know if any no-cache directives were
|
||||
// present at all.
|
||||
NoCachePresent bool
|
||||
|
||||
// no-store(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.3
|
||||
//
|
||||
// The "no-store" request directive indicates that a cache MUST NOT
|
||||
// store any part of either this request or any response to it. This
|
||||
// directive applies to both private and shared caches.
|
||||
NoStore bool
|
||||
|
||||
// no-transform(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.4
|
||||
//
|
||||
// The "no-transform" response directive indicates that an intermediary
|
||||
// (regardless of whether it implements a cache) MUST NOT transform the
|
||||
// payload, as defined in Section 5.7.2 of RFC7230.
|
||||
NoTransform bool
|
||||
|
||||
// public(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.5
|
||||
//
|
||||
// The "public" response directive indicates that any cache MAY store
|
||||
// the response, even if the response would normally be non-cacheable or
|
||||
// cacheable only within a private cache.
|
||||
Public bool
|
||||
|
||||
// private(FieldName): http://tools.ietf.org/html/rfc7234#section-5.2.2.6
|
||||
//
|
||||
// The "private" response directive indicates that the response message
|
||||
// is intended for a single user and MUST NOT be stored by a shared
|
||||
// cache. A private cache MAY store the response and reuse it for later
|
||||
// requests, even if the response would normally be non-cacheable.
|
||||
//
|
||||
// If the private response directive specifies one or more field-names,
|
||||
// this requirement is limited to the field-values associated with the
|
||||
// listed response header fields. That is, a shared cache MUST NOT
|
||||
// store the specified field-names(s), whereas it MAY store the
|
||||
// remainder of the response message.
|
||||
Private FieldNames
|
||||
|
||||
// private(cast-to-bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.6
|
||||
//
|
||||
// While the RFC defines optional field-names on a private directive,
|
||||
// many applications only want to know if any private directives were
|
||||
// present at all.
|
||||
PrivatePresent bool
|
||||
|
||||
// proxy-revalidate(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.7
|
||||
//
|
||||
// The "proxy-revalidate" response directive has the same meaning as the
|
||||
// must-revalidate response directive, except that it does not apply to
|
||||
// private caches.
|
||||
ProxyRevalidate bool
|
||||
|
||||
// max-age(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.2.8
|
||||
//
|
||||
// The "max-age" response directive indicates that the response is to be
|
||||
// considered stale after its age is greater than the specified number
|
||||
// of seconds.
|
||||
MaxAge DeltaSeconds
|
||||
|
||||
// s-maxage(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.2.9
|
||||
//
|
||||
// The "s-maxage" response directive indicates that, in shared caches,
|
||||
// the maximum age specified by this directive overrides the maximum age
|
||||
// specified by either the max-age directive or the Expires header
|
||||
// field. The s-maxage directive also implies the semantics of the
|
||||
// proxy-revalidate response directive.
|
||||
SMaxAge DeltaSeconds
|
||||
|
||||
////
|
||||
// Experimental features
|
||||
// - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#Extension_Cache-Control_directives
|
||||
// - https://www.fastly.com/blog/stale-while-revalidate-stale-if-error-available-today
|
||||
////
|
||||
|
||||
// immutable(cast-to-bool): experimental feature
|
||||
Immutable bool
|
||||
|
||||
// stale-if-error(delta seconds): experimental feature
|
||||
StaleIfError DeltaSeconds
|
||||
|
||||
// stale-while-revalidate(delta seconds): experimental feature
|
||||
StaleWhileRevalidate DeltaSeconds
|
||||
|
||||
// Extensions: http://tools.ietf.org/html/rfc7234#section-5.2.3
|
||||
//
|
||||
// The Cache-Control header field can be extended through the use of one
|
||||
// or more cache-extension tokens, each with an optional value. A cache
|
||||
// MUST ignore unrecognized cache directives.
|
||||
Extensions []string
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Parses a Cache Control Header from a Response into a set of directives.
|
||||
func ParseResponseCacheControl(value string) (*ResponseCacheDirectives, error) {
|
||||
cd := &ResponseCacheDirectives{
|
||||
MaxAge: -1,
|
||||
SMaxAge: -1,
|
||||
// Exerimantal stale timeouts
|
||||
StaleIfError: -1,
|
||||
StaleWhileRevalidate: -1,
|
||||
}
|
||||
|
||||
err := parse(value, cd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cd, nil
|
||||
}
|
||||
|
||||
func (cd *ResponseCacheDirectives) addToken(token string) error {
|
||||
var err error = nil
|
||||
switch token {
|
||||
case "must-revalidate":
|
||||
cd.MustRevalidate = true
|
||||
case "no-cache":
|
||||
cd.NoCachePresent = true
|
||||
case "no-store":
|
||||
cd.NoStore = true
|
||||
case "no-transform":
|
||||
cd.NoTransform = true
|
||||
case "public":
|
||||
cd.Public = true
|
||||
case "private":
|
||||
cd.PrivatePresent = true
|
||||
case "proxy-revalidate":
|
||||
cd.ProxyRevalidate = true
|
||||
case "max-age":
|
||||
err = ErrMaxAgeDeltaSeconds
|
||||
case "s-maxage":
|
||||
err = ErrSMaxAgeDeltaSeconds
|
||||
// Experimental
|
||||
case "immutable":
|
||||
cd.Immutable = true
|
||||
case "stale-if-error":
|
||||
err = ErrMaxAgeDeltaSeconds
|
||||
case "stale-while-revalidate":
|
||||
err = ErrMaxAgeDeltaSeconds
|
||||
default:
|
||||
cd.Extensions = append(cd.Extensions, token)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func hasFieldNames(token string) bool {
|
||||
switch token {
|
||||
case "no-cache":
|
||||
return true
|
||||
case "private":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cd *ResponseCacheDirectives) addPair(token string, v string) error {
|
||||
var err error = nil
|
||||
|
||||
switch token {
|
||||
case "must-revalidate":
|
||||
err = ErrMustRevalidateNoArgs
|
||||
case "no-cache":
|
||||
cd.NoCachePresent = true
|
||||
tokens := strings.Split(v, ",")
|
||||
if cd.NoCache == nil {
|
||||
cd.NoCache = make(FieldNames)
|
||||
}
|
||||
for _, t := range tokens {
|
||||
k := http.CanonicalHeaderKey(textproto.TrimString(t))
|
||||
cd.NoCache[k] = true
|
||||
}
|
||||
case "no-store":
|
||||
err = ErrNoStoreNoArgs
|
||||
case "no-transform":
|
||||
err = ErrNoTransformNoArgs
|
||||
case "public":
|
||||
err = ErrPublicNoArgs
|
||||
case "private":
|
||||
cd.PrivatePresent = true
|
||||
tokens := strings.Split(v, ",")
|
||||
if cd.Private == nil {
|
||||
cd.Private = make(FieldNames)
|
||||
}
|
||||
for _, t := range tokens {
|
||||
k := http.CanonicalHeaderKey(textproto.TrimString(t))
|
||||
cd.Private[k] = true
|
||||
}
|
||||
case "proxy-revalidate":
|
||||
err = ErrProxyRevalidateNoArgs
|
||||
case "max-age":
|
||||
cd.MaxAge, err = parseDeltaSeconds(v)
|
||||
case "s-maxage":
|
||||
cd.SMaxAge, err = parseDeltaSeconds(v)
|
||||
// Experimental
|
||||
case "immutable":
|
||||
err = ErrImmutableNoArgs
|
||||
case "stale-if-error":
|
||||
cd.StaleIfError, err = parseDeltaSeconds(v)
|
||||
case "stale-while-revalidate":
|
||||
cd.StaleWhileRevalidate, err = parseDeltaSeconds(v)
|
||||
default:
|
||||
// TODO(pquerna): this sucks, making user re-parse, and its technically not 'quoted' like the original,
|
||||
// but this is still easier, just a SplitN on "="
|
||||
cd.Extensions = append(cd.Extensions, token+"="+v)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cacheobject
|
||||
|
||||
// This file deals with lexical matters of HTTP
|
||||
|
||||
func isSeparator(c byte) bool {
|
||||
switch c {
|
||||
case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 }
|
||||
|
||||
func isChar(c byte) bool { return 0 <= c && c <= 127 }
|
||||
|
||||
func isAnyText(c byte) bool { return !isCtl(c) }
|
||||
|
||||
func isQdText(c byte) bool { return isAnyText(c) && c != '"' }
|
||||
|
||||
func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) }
|
||||
|
||||
// Valid escaped sequences are not specified in RFC 2616, so for now, we assume
|
||||
// that they coincide with the common sense ones used by GO. Malformed
|
||||
// characters should probably not be treated as errors by a robust (forgiving)
|
||||
// parser, so we replace them with the '?' character.
|
||||
func httpUnquotePair(b byte) byte {
|
||||
// skip the first byte, which should always be '\'
|
||||
switch b {
|
||||
case 'a':
|
||||
return '\a'
|
||||
case 'b':
|
||||
return '\b'
|
||||
case 'f':
|
||||
return '\f'
|
||||
case 'n':
|
||||
return '\n'
|
||||
case 'r':
|
||||
return '\r'
|
||||
case 't':
|
||||
return '\t'
|
||||
case 'v':
|
||||
return '\v'
|
||||
case '\\':
|
||||
return '\\'
|
||||
case '\'':
|
||||
return '\''
|
||||
case '"':
|
||||
return '"'
|
||||
}
|
||||
return '?'
|
||||
}
|
||||
|
||||
// raw must begin with a valid quoted string. Only the first quoted string is
|
||||
// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1
|
||||
// upon failure.
|
||||
func httpUnquote(raw string) (eaten int, result string) {
|
||||
buf := make([]byte, len(raw))
|
||||
if raw[0] != '"' {
|
||||
return -1, ""
|
||||
}
|
||||
eaten = 1
|
||||
j := 0 // # of bytes written in buf
|
||||
for i := 1; i < len(raw); i++ {
|
||||
switch b := raw[i]; b {
|
||||
case '"':
|
||||
eaten++
|
||||
buf = buf[0:j]
|
||||
return i + 1, string(buf)
|
||||
case '\\':
|
||||
if len(raw) < i+2 {
|
||||
return -1, ""
|
||||
}
|
||||
buf[j] = httpUnquotePair(raw[i+1])
|
||||
eaten += 2
|
||||
j++
|
||||
i++
|
||||
default:
|
||||
if isQdText(b) {
|
||||
buf[j] = b
|
||||
} else {
|
||||
buf[j] = '?'
|
||||
}
|
||||
eaten++
|
||||
j++
|
||||
}
|
||||
}
|
||||
return -1, ""
|
||||
}
|
|
@ -0,0 +1,387 @@
|
|||
/**
|
||||
* Copyright 2015 Paul Querna
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package cacheobject
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LOW LEVEL API: Repersents a potentially cachable HTTP object.
|
||||
//
|
||||
// This struct is designed to be serialized efficiently, so in a high
|
||||
// performance caching server, things like Date-Strings don't need to be
|
||||
// parsed for every use of a cached object.
|
||||
type Object struct {
|
||||
CacheIsPrivate bool
|
||||
|
||||
RespDirectives *ResponseCacheDirectives
|
||||
RespHeaders http.Header
|
||||
RespStatusCode int
|
||||
RespExpiresHeader time.Time
|
||||
RespDateHeader time.Time
|
||||
RespLastModifiedHeader time.Time
|
||||
|
||||
ReqDirectives *RequestCacheDirectives
|
||||
ReqHeaders http.Header
|
||||
ReqMethod string
|
||||
|
||||
NowUTC time.Time
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Repersents the results of examinig an Object with
|
||||
// CachableObject and ExpirationObject.
|
||||
//
|
||||
// TODO(pquerna): decide if this is a good idea or bad
|
||||
type ObjectResults struct {
|
||||
OutReasons []Reason
|
||||
OutWarnings []Warning
|
||||
OutExpirationTime time.Time
|
||||
OutErr error
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Check if a object is cachable.
|
||||
func CachableObject(obj *Object, rv *ObjectResults) {
|
||||
rv.OutReasons = nil
|
||||
rv.OutWarnings = nil
|
||||
rv.OutErr = nil
|
||||
|
||||
switch obj.ReqMethod {
|
||||
case "GET":
|
||||
break
|
||||
case "HEAD":
|
||||
break
|
||||
case "POST":
|
||||
/**
|
||||
POST: http://tools.ietf.org/html/rfc7231#section-4.3.3
|
||||
|
||||
Responses to POST requests are only cacheable when they include
|
||||
explicit freshness information (see Section 4.2.1 of [RFC7234]).
|
||||
However, POST caching is not widely implemented. For cases where an
|
||||
origin server wishes the client to be able to cache the result of a
|
||||
POST in a way that can be reused by a later GET, the origin server
|
||||
MAY send a 200 (OK) response containing the result and a
|
||||
Content-Location header field that has the same value as the POST's
|
||||
effective request URI (Section 3.1.4.2).
|
||||
*/
|
||||
if !hasFreshness(obj.ReqDirectives, obj.RespDirectives, obj.RespHeaders, obj.RespExpiresHeader, obj.CacheIsPrivate) {
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodPOST)
|
||||
}
|
||||
|
||||
case "PUT":
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodPUT)
|
||||
|
||||
case "DELETE":
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodDELETE)
|
||||
|
||||
case "CONNECT":
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodCONNECT)
|
||||
|
||||
case "OPTIONS":
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodOPTIONS)
|
||||
|
||||
case "TRACE":
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodTRACE)
|
||||
|
||||
// HTTP Extension Methods: http://www.iana.org/assignments/http-methods/http-methods.xhtml
|
||||
//
|
||||
// To my knowledge, none of them are cachable. Please open a ticket if this is not the case!
|
||||
//
|
||||
default:
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodUnkown)
|
||||
}
|
||||
|
||||
if obj.ReqDirectives.NoStore {
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestNoStore)
|
||||
}
|
||||
|
||||
// Storing Responses to Authenticated Requests: http://tools.ietf.org/html/rfc7234#section-3.2
|
||||
authz := obj.ReqHeaders.Get("Authorization")
|
||||
if authz != "" {
|
||||
if obj.RespDirectives.MustRevalidate ||
|
||||
obj.RespDirectives.Public ||
|
||||
obj.RespDirectives.SMaxAge != -1 {
|
||||
// Expires of some kind present, this is potentially OK.
|
||||
} else {
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestAuthorizationHeader)
|
||||
}
|
||||
}
|
||||
|
||||
if obj.RespDirectives.PrivatePresent && !obj.CacheIsPrivate {
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonResponsePrivate)
|
||||
}
|
||||
|
||||
if obj.RespDirectives.NoStore {
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonResponseNoStore)
|
||||
}
|
||||
|
||||
/*
|
||||
the response either:
|
||||
|
||||
* contains an Expires header field (see Section 5.3), or
|
||||
|
||||
* contains a max-age response directive (see Section 5.2.2.8), or
|
||||
|
||||
* contains a s-maxage response directive (see Section 5.2.2.9)
|
||||
and the cache is shared, or
|
||||
|
||||
* contains a Cache Control Extension (see Section 5.2.3) that
|
||||
allows it to be cached, or
|
||||
|
||||
* has a status code that is defined as cacheable by default (see
|
||||
Section 4.2.2), or
|
||||
|
||||
* contains a public response directive (see Section 5.2.2.5).
|
||||
*/
|
||||
|
||||
expires := obj.RespHeaders.Get("Expires") != ""
|
||||
statusCachable := cachableStatusCode(obj.RespStatusCode)
|
||||
|
||||
if expires ||
|
||||
obj.RespDirectives.MaxAge != -1 ||
|
||||
(obj.RespDirectives.SMaxAge != -1 && !obj.CacheIsPrivate) ||
|
||||
statusCachable ||
|
||||
obj.RespDirectives.Public {
|
||||
/* cachable by default, at least one of the above conditions was true */
|
||||
} else {
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonResponseUncachableByDefault)
|
||||
}
|
||||
}
|
||||
|
||||
var twentyFourHours = time.Duration(24 * time.Hour)
|
||||
|
||||
const debug = false
|
||||
|
||||
// LOW LEVEL API: Update an objects expiration time.
|
||||
func ExpirationObject(obj *Object, rv *ObjectResults) {
|
||||
/**
|
||||
* Okay, lets calculate Freshness/Expiration now. woo:
|
||||
* http://tools.ietf.org/html/rfc7234#section-4.2
|
||||
*/
|
||||
|
||||
/*
|
||||
o If the cache is shared and the s-maxage response directive
|
||||
(Section 5.2.2.9) is present, use its value, or
|
||||
|
||||
o If the max-age response directive (Section 5.2.2.8) is present,
|
||||
use its value, or
|
||||
|
||||
o If the Expires response header field (Section 5.3) is present, use
|
||||
its value minus the value of the Date response header field, or
|
||||
|
||||
o Otherwise, no explicit expiration time is present in the response.
|
||||
A heuristic freshness lifetime might be applicable; see
|
||||
Section 4.2.2.
|
||||
*/
|
||||
|
||||
var expiresTime time.Time
|
||||
|
||||
if obj.RespDirectives.SMaxAge != -1 && !obj.CacheIsPrivate {
|
||||
expiresTime = obj.NowUTC.Add(time.Second * time.Duration(obj.RespDirectives.SMaxAge))
|
||||
} else if obj.RespDirectives.MaxAge != -1 {
|
||||
expiresTime = obj.NowUTC.UTC().Add(time.Second * time.Duration(obj.RespDirectives.MaxAge))
|
||||
} else if !obj.RespExpiresHeader.IsZero() {
|
||||
serverDate := obj.RespDateHeader
|
||||
if serverDate.IsZero() {
|
||||
// common enough case when a Date: header has not yet been added to an
|
||||
// active response.
|
||||
serverDate = obj.NowUTC
|
||||
}
|
||||
expiresTime = obj.NowUTC.Add(obj.RespExpiresHeader.Sub(serverDate))
|
||||
} else if !obj.RespLastModifiedHeader.IsZero() {
|
||||
// heuristic freshness lifetime
|
||||
rv.OutWarnings = append(rv.OutWarnings, WarningHeuristicExpiration)
|
||||
|
||||
// http://httpd.apache.org/docs/2.4/mod/mod_cache.html#cachelastmodifiedfactor
|
||||
// CacheMaxExpire defaults to 24 hours
|
||||
// CacheLastModifiedFactor: is 0.1
|
||||
//
|
||||
// expiry-period = MIN(time-since-last-modified-date * factor, 24 hours)
|
||||
//
|
||||
// obj.NowUTC
|
||||
|
||||
since := obj.RespLastModifiedHeader.Sub(obj.NowUTC)
|
||||
since = time.Duration(float64(since) * -0.1)
|
||||
|
||||
if since > twentyFourHours {
|
||||
expiresTime = obj.NowUTC.Add(twentyFourHours)
|
||||
} else {
|
||||
expiresTime = obj.NowUTC.Add(since)
|
||||
}
|
||||
|
||||
if debug {
|
||||
println("Now UTC: ", obj.NowUTC.String())
|
||||
println("Last-Modified: ", obj.RespLastModifiedHeader.String())
|
||||
println("Since: ", since.String())
|
||||
println("TwentyFourHours: ", twentyFourHours.String())
|
||||
println("Expiration: ", expiresTime.String())
|
||||
}
|
||||
} else {
|
||||
// TODO(pquerna): what should the default behavoir be for expiration time?
|
||||
}
|
||||
|
||||
rv.OutExpirationTime = expiresTime
|
||||
}
|
||||
|
||||
// Evaluate cachability based on an HTTP request, and parts of the response.
|
||||
func UsingRequestResponse(req *http.Request,
|
||||
statusCode int,
|
||||
respHeaders http.Header,
|
||||
privateCache bool) ([]Reason, time.Time, error) {
|
||||
reasons, time, _, _, err := UsingRequestResponseWithObject(req, statusCode, respHeaders, privateCache)
|
||||
return reasons, time, err
|
||||
}
|
||||
|
||||
// Evaluate cachability based on an HTTP request, and parts of the response.
|
||||
// Returns the parsed Object as well.
|
||||
func UsingRequestResponseWithObject(req *http.Request,
|
||||
statusCode int,
|
||||
respHeaders http.Header,
|
||||
privateCache bool) ([]Reason, time.Time, []Warning, *Object, error) {
|
||||
var reqHeaders http.Header
|
||||
var reqMethod string
|
||||
|
||||
var reqDir *RequestCacheDirectives = nil
|
||||
respDir, err := ParseResponseCacheControl(respHeaders.Get("Cache-Control"))
|
||||
if err != nil {
|
||||
return nil, time.Time{}, nil, nil, err
|
||||
}
|
||||
|
||||
if req != nil {
|
||||
reqDir, err = ParseRequestCacheControl(req.Header.Get("Cache-Control"))
|
||||
if err != nil {
|
||||
return nil, time.Time{}, nil, nil, err
|
||||
}
|
||||
reqHeaders = req.Header
|
||||
reqMethod = req.Method
|
||||
}
|
||||
|
||||
var expiresHeader time.Time
|
||||
var dateHeader time.Time
|
||||
var lastModifiedHeader time.Time
|
||||
|
||||
if respHeaders.Get("Expires") != "" {
|
||||
expiresHeader, err = http.ParseTime(respHeaders.Get("Expires"))
|
||||
if err != nil {
|
||||
// sometimes servers will return `Expires: 0` or `Expires: -1` to
|
||||
// indicate expired content
|
||||
expiresHeader = time.Time{}
|
||||
}
|
||||
expiresHeader = expiresHeader.UTC()
|
||||
}
|
||||
|
||||
if respHeaders.Get("Date") != "" {
|
||||
dateHeader, err = http.ParseTime(respHeaders.Get("Date"))
|
||||
if err != nil {
|
||||
return nil, time.Time{}, nil, nil, err
|
||||
}
|
||||
dateHeader = dateHeader.UTC()
|
||||
}
|
||||
|
||||
if respHeaders.Get("Last-Modified") != "" {
|
||||
lastModifiedHeader, err = http.ParseTime(respHeaders.Get("Last-Modified"))
|
||||
if err != nil {
|
||||
return nil, time.Time{}, nil, nil, err
|
||||
}
|
||||
lastModifiedHeader = lastModifiedHeader.UTC()
|
||||
}
|
||||
|
||||
obj := Object{
|
||||
CacheIsPrivate: privateCache,
|
||||
|
||||
RespDirectives: respDir,
|
||||
RespHeaders: respHeaders,
|
||||
RespStatusCode: statusCode,
|
||||
RespExpiresHeader: expiresHeader,
|
||||
RespDateHeader: dateHeader,
|
||||
RespLastModifiedHeader: lastModifiedHeader,
|
||||
|
||||
ReqDirectives: reqDir,
|
||||
ReqHeaders: reqHeaders,
|
||||
ReqMethod: reqMethod,
|
||||
|
||||
NowUTC: time.Now().UTC(),
|
||||
}
|
||||
rv := ObjectResults{}
|
||||
|
||||
CachableObject(&obj, &rv)
|
||||
if rv.OutErr != nil {
|
||||
return nil, time.Time{}, nil, nil, rv.OutErr
|
||||
}
|
||||
|
||||
ExpirationObject(&obj, &rv)
|
||||
if rv.OutErr != nil {
|
||||
return nil, time.Time{}, nil, nil, rv.OutErr
|
||||
}
|
||||
|
||||
return rv.OutReasons, rv.OutExpirationTime, rv.OutWarnings, &obj, nil
|
||||
}
|
||||
|
||||
// calculate if a freshness directive is present: http://tools.ietf.org/html/rfc7234#section-4.2.1
|
||||
func hasFreshness(reqDir *RequestCacheDirectives, respDir *ResponseCacheDirectives, respHeaders http.Header, respExpires time.Time, privateCache bool) bool {
|
||||
if !privateCache && respDir.SMaxAge != -1 {
|
||||
return true
|
||||
}
|
||||
|
||||
if respDir.MaxAge != -1 {
|
||||
return true
|
||||
}
|
||||
|
||||
if !respExpires.IsZero() || respHeaders.Get("Expires") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func cachableStatusCode(statusCode int) bool {
|
||||
/*
|
||||
Responses with status codes that are defined as cacheable by default
|
||||
(e.g., 200, 203, 204, 206, 300, 301, 404, 405, 410, 414, and 501 in
|
||||
this specification) can be reused by a cache with heuristic
|
||||
expiration unless otherwise indicated by the method definition or
|
||||
explicit cache controls [RFC7234]; all other status codes are not
|
||||
cacheable by default.
|
||||
*/
|
||||
switch statusCode {
|
||||
case 200:
|
||||
return true
|
||||
case 203:
|
||||
return true
|
||||
case 204:
|
||||
return true
|
||||
case 206:
|
||||
return true
|
||||
case 300:
|
||||
return true
|
||||
case 301:
|
||||
return true
|
||||
case 404:
|
||||
return true
|
||||
case 405:
|
||||
return true
|
||||
case 410:
|
||||
return true
|
||||
case 414:
|
||||
return true
|
||||
case 501:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2015 Paul Querna
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package cacheobject
|
||||
|
||||
// Repersents a potential Reason to not cache an object.
|
||||
//
|
||||
// Applications may wish to ignore specific reasons, which will make them non-RFC
|
||||
// compliant, but this type gives them specific cases they can choose to ignore,
|
||||
// making them compliant in as many cases as they can.
|
||||
type Reason int
|
||||
|
||||
const (
|
||||
|
||||
// The request method was POST and an Expiration header was not supplied.
|
||||
ReasonRequestMethodPOST Reason = iota
|
||||
|
||||
// The request method was PUT and PUTs are not cachable.
|
||||
ReasonRequestMethodPUT
|
||||
|
||||
// The request method was DELETE and DELETEs are not cachable.
|
||||
ReasonRequestMethodDELETE
|
||||
|
||||
// The request method was CONNECT and CONNECTs are not cachable.
|
||||
ReasonRequestMethodCONNECT
|
||||
|
||||
// The request method was OPTIONS and OPTIONS are not cachable.
|
||||
ReasonRequestMethodOPTIONS
|
||||
|
||||
// The request method was TRACE and TRACEs are not cachable.
|
||||
ReasonRequestMethodTRACE
|
||||
|
||||
// The request method was not recognized by cachecontrol, and should not be cached.
|
||||
ReasonRequestMethodUnkown
|
||||
|
||||
// The request included an Cache-Control: no-store header
|
||||
ReasonRequestNoStore
|
||||
|
||||
// The request included an Authorization header without an explicit Public or Expiration time: http://tools.ietf.org/html/rfc7234#section-3.2
|
||||
ReasonRequestAuthorizationHeader
|
||||
|
||||
// The response included an Cache-Control: no-store header
|
||||
ReasonResponseNoStore
|
||||
|
||||
// The response included an Cache-Control: private header and this is not a Private cache
|
||||
ReasonResponsePrivate
|
||||
|
||||
// The response failed to meet at least one of the conditions specified in RFC 7234 section 3: http://tools.ietf.org/html/rfc7234#section-3
|
||||
ReasonResponseUncachableByDefault
|
||||
)
|
||||
|
||||
func (r Reason) String() string {
|
||||
switch r {
|
||||
case ReasonRequestMethodPOST:
|
||||
return "ReasonRequestMethodPOST"
|
||||
case ReasonRequestMethodPUT:
|
||||
return "ReasonRequestMethodPUT"
|
||||
case ReasonRequestMethodDELETE:
|
||||
return "ReasonRequestMethodDELETE"
|
||||
case ReasonRequestMethodCONNECT:
|
||||
return "ReasonRequestMethodCONNECT"
|
||||
case ReasonRequestMethodOPTIONS:
|
||||
return "ReasonRequestMethodOPTIONS"
|
||||
case ReasonRequestMethodTRACE:
|
||||
return "ReasonRequestMethodTRACE"
|
||||
case ReasonRequestMethodUnkown:
|
||||
return "ReasonRequestMethodUnkown"
|
||||
case ReasonRequestNoStore:
|
||||
return "ReasonRequestNoStore"
|
||||
case ReasonRequestAuthorizationHeader:
|
||||
return "ReasonRequestAuthorizationHeader"
|
||||
case ReasonResponseNoStore:
|
||||
return "ReasonResponseNoStore"
|
||||
case ReasonResponsePrivate:
|
||||
return "ReasonResponsePrivate"
|
||||
case ReasonResponseUncachableByDefault:
|
||||
return "ReasonResponseUncachableByDefault"
|
||||
}
|
||||
|
||||
panic(r)
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2015 Paul Querna
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package cacheobject
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Repersents an HTTP Warning: http://tools.ietf.org/html/rfc7234#section-5.5
|
||||
type Warning int
|
||||
|
||||
const (
|
||||
// Response is Stale
|
||||
// A cache SHOULD generate this whenever the sent response is stale.
|
||||
WarningResponseIsStale Warning = 110
|
||||
|
||||
// Revalidation Failed
|
||||
// A cache SHOULD generate this when sending a stale
|
||||
// response because an attempt to validate the response failed, due to an
|
||||
// inability to reach the server.
|
||||
WarningRevalidationFailed Warning = 111
|
||||
|
||||
// Disconnected Operation
|
||||
// A cache SHOULD generate this if it is intentionally disconnected from
|
||||
// the rest of the network for a period of time.
|
||||
WarningDisconnectedOperation Warning = 112
|
||||
|
||||
// Heuristic Expiration
|
||||
//
|
||||
// A cache SHOULD generate this if it heuristically chose a freshness
|
||||
// lifetime greater than 24 hours and the response's age is greater than
|
||||
// 24 hours.
|
||||
WarningHeuristicExpiration Warning = 113
|
||||
|
||||
// Miscellaneous Warning
|
||||
//
|
||||
// The warning text can include arbitrary information to be presented to
|
||||
// a human user or logged. A system receiving this warning MUST NOT
|
||||
// take any automated action, besides presenting the warning to the
|
||||
// user.
|
||||
WarningMiscellaneousWarning Warning = 199
|
||||
|
||||
// Transformation Applied
|
||||
//
|
||||
// This Warning code MUST be added by a proxy if it applies any
|
||||
// transformation to the representation, such as changing the
|
||||
// content-coding, media-type, or modifying the representation data,
|
||||
// unless this Warning code already appears in the response.
|
||||
WarningTransformationApplied Warning = 214
|
||||
|
||||
// Miscellaneous Persistent Warning
|
||||
//
|
||||
// The warning text can include arbitrary information to be presented to
|
||||
// a human user or logged. A system receiving this warning MUST NOT
|
||||
// take any automated action.
|
||||
WarningMiscellaneousPersistentWarning Warning = 299
|
||||
)
|
||||
|
||||
func (w Warning) HeaderString(agent string, date time.Time) string {
|
||||
if agent == "" {
|
||||
agent = "-"
|
||||
} else {
|
||||
// TODO(pquerna): this doesn't escape agent if it contains bad things.
|
||||
agent = `"` + agent + `"`
|
||||
}
|
||||
return fmt.Sprintf(`%d %s "%s" %s`, w, agent, w.String(), date.Format(http.TimeFormat))
|
||||
}
|
||||
|
||||
func (w Warning) String() string {
|
||||
switch w {
|
||||
case WarningResponseIsStale:
|
||||
return "Response is Stale"
|
||||
case WarningRevalidationFailed:
|
||||
return "Revalidation Failed"
|
||||
case WarningDisconnectedOperation:
|
||||
return "Disconnected Operation"
|
||||
case WarningHeuristicExpiration:
|
||||
return "Heuristic Expiration"
|
||||
case WarningMiscellaneousWarning:
|
||||
// TODO(pquerna): ideally had a better way to override this one code.
|
||||
return "Miscellaneous Warning"
|
||||
case WarningTransformationApplied:
|
||||
return "Transformation Applied"
|
||||
case WarningMiscellaneousPersistentWarning:
|
||||
// TODO(pquerna): same as WarningMiscellaneousWarning
|
||||
return "Miscellaneous Persistent Warning"
|
||||
}
|
||||
|
||||
panic(w)
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2015 Paul Querna
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
// Package cachecontrol implements the logic for HTTP Caching
|
||||
//
|
||||
// Deciding if an HTTP Response can be cached is often harder
|
||||
// and more bug prone than an actual cache storage backend.
|
||||
// cachecontrol provides a simple interface to determine if
|
||||
// request and response pairs are cachable as defined under
|
||||
// RFC 7234 http://tools.ietf.org/html/rfc7234
|
||||
package cachecontrol
|
|
@ -8,11 +8,9 @@ matrix:
|
|||
- go: tip
|
||||
|
||||
go:
|
||||
- '1.7.x'
|
||||
- '1.8.x'
|
||||
- '1.9.x'
|
||||
- '1.10.x'
|
||||
- '1.11.x'
|
||||
- '1.12.x'
|
||||
- tip
|
||||
|
||||
go_import_path: gopkg.in/square/go-jose.v2
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ import (
|
|||
"math/big"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"gopkg.in/square/go-jose.v2/cipher"
|
||||
josecipher "gopkg.in/square/go-jose.v2/cipher"
|
||||
"gopkg.in/square/go-jose.v2/json"
|
||||
)
|
||||
|
||||
|
@ -288,7 +288,7 @@ func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm
|
|||
out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed)
|
||||
case PS256, PS384, PS512:
|
||||
out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
SaltLength: rsa.PSSSaltLengthEqualsHash,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -17,8 +17,10 @@
|
|||
package josecipher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
|
@ -44,16 +46,38 @@ func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, p
|
|||
panic("public key not on same curve as private key")
|
||||
}
|
||||
|
||||
z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
|
||||
reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
|
||||
z, _ := priv.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
|
||||
zBytes := z.Bytes()
|
||||
|
||||
// Note that calling z.Bytes() on a big.Int may strip leading zero bytes from
|
||||
// the returned byte array. This can lead to a problem where zBytes will be
|
||||
// shorter than expected which breaks the key derivation. Therefore we must pad
|
||||
// to the full length of the expected coordinate here before calling the KDF.
|
||||
octSize := dSize(priv.Curve)
|
||||
if len(zBytes) != octSize {
|
||||
zBytes = append(bytes.Repeat([]byte{0}, octSize-len(zBytes)), zBytes...)
|
||||
}
|
||||
|
||||
reader := NewConcatKDF(crypto.SHA256, zBytes, algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
|
||||
key := make([]byte, size)
|
||||
|
||||
// Read on the KDF will never fail
|
||||
_, _ = reader.Read(key)
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// dSize returns the size in octets for a coordinate on a elliptic curve.
|
||||
func dSize(curve elliptic.Curve) int {
|
||||
order := curve.Params().P
|
||||
bitLen := order.BitLen()
|
||||
size := bitLen / 8
|
||||
if bitLen%8 != 0 {
|
||||
size++
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func lengthPrefixed(data []byte) []byte {
|
||||
out := make([]byte, len(data)+4)
|
||||
binary.BigEndian.PutUint32(out, uint32(len(data)))
|
||||
|
|
|
@ -141,6 +141,8 @@ func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions)
|
|||
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
|
||||
case *JSONWebKey:
|
||||
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
|
||||
case OpaqueKeyEncrypter:
|
||||
keyID, rawKey = encryptionKey.KeyID(), encryptionKey
|
||||
default:
|
||||
rawKey = encryptionKey
|
||||
}
|
||||
|
@ -267,9 +269,11 @@ func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKey
|
|||
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
|
||||
recipient.keyID = encryptionKey.KeyID
|
||||
return recipient, err
|
||||
default:
|
||||
return recipientKeyInfo{}, ErrUnsupportedKeyType
|
||||
}
|
||||
if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
|
||||
return newOpaqueKeyEncrypter(alg, encrypter)
|
||||
}
|
||||
return recipientKeyInfo{}, ErrUnsupportedKeyType
|
||||
}
|
||||
|
||||
// newDecrypter creates an appropriate decrypter based on the key type
|
||||
|
@ -295,9 +299,11 @@ func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
|
|||
return newDecrypter(decryptionKey.Key)
|
||||
case *JSONWebKey:
|
||||
return newDecrypter(decryptionKey.Key)
|
||||
default:
|
||||
return nil, ErrUnsupportedKeyType
|
||||
}
|
||||
if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
|
||||
return &opaqueKeyDecrypter{decrypter: okd}, nil
|
||||
}
|
||||
return nil, ErrUnsupportedKeyType
|
||||
}
|
||||
|
||||
// Implementation of encrypt method producing a JWE object.
|
||||
|
|
|
@ -23,13 +23,12 @@ import (
|
|||
"encoding/binary"
|
||||
"io"
|
||||
"math/big"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"gopkg.in/square/go-jose.v2/json"
|
||||
)
|
||||
|
||||
var stripWhitespaceRegex = regexp.MustCompile("\\s")
|
||||
|
||||
// Helper function to serialize known-good objects.
|
||||
// Precondition: value is not a nil pointer.
|
||||
func mustSerializeJSON(value interface{}) []byte {
|
||||
|
@ -56,7 +55,14 @@ func mustSerializeJSON(value interface{}) []byte {
|
|||
|
||||
// Strip all newlines and whitespace
|
||||
func stripWhitespace(data string) string {
|
||||
return stripWhitespaceRegex.ReplaceAllString(data, "")
|
||||
buf := strings.Builder{}
|
||||
buf.Grow(len(data))
|
||||
for _, r := range data {
|
||||
if !unicode.IsSpace(r) {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Perform compression based on algorithm
|
||||
|
|
|
@ -357,11 +357,11 @@ func (key rawJSONWebKey) ecPublicKey() (*ecdsa.PublicKey, error) {
|
|||
// the curve specified in the "crv" parameter.
|
||||
// https://tools.ietf.org/html/rfc7518#section-6.2.1.2
|
||||
if curveSize(curve) != len(key.X.data) {
|
||||
return nil, fmt.Errorf("square/go-jose: invalid EC private key, wrong length for x")
|
||||
return nil, fmt.Errorf("square/go-jose: invalid EC public key, wrong length for x")
|
||||
}
|
||||
|
||||
if curveSize(curve) != len(key.Y.data) {
|
||||
return nil, fmt.Errorf("square/go-jose: invalid EC private key, wrong length for y")
|
||||
return nil, fmt.Errorf("square/go-jose: invalid EC public key, wrong length for y")
|
||||
}
|
||||
|
||||
x := key.X.bigInt()
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package jose
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -75,13 +76,21 @@ type Signature struct {
|
|||
}
|
||||
|
||||
// ParseSigned parses a signed message in compact or full serialization format.
|
||||
func ParseSigned(input string) (*JSONWebSignature, error) {
|
||||
input = stripWhitespace(input)
|
||||
if strings.HasPrefix(input, "{") {
|
||||
return parseSignedFull(input)
|
||||
func ParseSigned(signature string) (*JSONWebSignature, error) {
|
||||
signature = stripWhitespace(signature)
|
||||
if strings.HasPrefix(signature, "{") {
|
||||
return parseSignedFull(signature)
|
||||
}
|
||||
|
||||
return parseSignedCompact(input)
|
||||
return parseSignedCompact(signature, nil)
|
||||
}
|
||||
|
||||
// ParseDetached parses a signed message in compact serialization format with detached payload.
|
||||
func ParseDetached(signature string, payload []byte) (*JSONWebSignature, error) {
|
||||
if payload == nil {
|
||||
return nil, errors.New("square/go-jose: nil payload")
|
||||
}
|
||||
return parseSignedCompact(stripWhitespace(signature), payload)
|
||||
}
|
||||
|
||||
// Get a header value
|
||||
|
@ -93,20 +102,39 @@ func (sig Signature) mergedHeaders() rawHeader {
|
|||
}
|
||||
|
||||
// Compute data to be signed
|
||||
func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) []byte {
|
||||
var serializedProtected string
|
||||
func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) ([]byte, error) {
|
||||
var authData bytes.Buffer
|
||||
|
||||
protectedHeader := new(rawHeader)
|
||||
|
||||
if signature.original != nil && signature.original.Protected != nil {
|
||||
serializedProtected = signature.original.Protected.base64()
|
||||
if err := json.Unmarshal(signature.original.Protected.bytes(), protectedHeader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authData.WriteString(signature.original.Protected.base64())
|
||||
} else if signature.protected != nil {
|
||||
serializedProtected = base64.RawURLEncoding.EncodeToString(mustSerializeJSON(signature.protected))
|
||||
} else {
|
||||
serializedProtected = ""
|
||||
protectedHeader = signature.protected
|
||||
authData.WriteString(base64.RawURLEncoding.EncodeToString(mustSerializeJSON(protectedHeader)))
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf("%s.%s",
|
||||
serializedProtected,
|
||||
base64.RawURLEncoding.EncodeToString(payload)))
|
||||
needsBase64 := true
|
||||
|
||||
if protectedHeader != nil {
|
||||
var err error
|
||||
if needsBase64, err = protectedHeader.getB64(); err != nil {
|
||||
needsBase64 = true
|
||||
}
|
||||
}
|
||||
|
||||
authData.WriteByte('.')
|
||||
|
||||
if needsBase64 {
|
||||
authData.WriteString(base64.RawURLEncoding.EncodeToString(payload))
|
||||
} else {
|
||||
authData.Write(payload)
|
||||
}
|
||||
|
||||
return authData.Bytes(), nil
|
||||
}
|
||||
|
||||
// parseSignedFull parses a message in full format.
|
||||
|
@ -246,20 +274,26 @@ func (parsed *rawJSONWebSignature) sanitized() (*JSONWebSignature, error) {
|
|||
}
|
||||
|
||||
// parseSignedCompact parses a message in compact format.
|
||||
func parseSignedCompact(input string) (*JSONWebSignature, error) {
|
||||
func parseSignedCompact(input string, payload []byte) (*JSONWebSignature, error) {
|
||||
parts := strings.Split(input, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
|
||||
}
|
||||
|
||||
if parts[1] != "" && payload != nil {
|
||||
return nil, fmt.Errorf("square/go-jose: payload is not detached")
|
||||
}
|
||||
|
||||
rawProtected, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if payload == nil {
|
||||
payload, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
|
@ -275,19 +309,30 @@ func parseSignedCompact(input string) (*JSONWebSignature, error) {
|
|||
return raw.sanitized()
|
||||
}
|
||||
|
||||
// CompactSerialize serializes an object using the compact serialization format.
|
||||
func (obj JSONWebSignature) CompactSerialize() (string, error) {
|
||||
func (obj JSONWebSignature) compactSerialize(detached bool) (string, error) {
|
||||
if len(obj.Signatures) != 1 || obj.Signatures[0].header != nil || obj.Signatures[0].protected == nil {
|
||||
return "", ErrNotSupported
|
||||
}
|
||||
|
||||
serializedProtected := mustSerializeJSON(obj.Signatures[0].protected)
|
||||
serializedProtected := base64.RawURLEncoding.EncodeToString(mustSerializeJSON(obj.Signatures[0].protected))
|
||||
payload := ""
|
||||
signature := base64.RawURLEncoding.EncodeToString(obj.Signatures[0].Signature)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s.%s.%s",
|
||||
base64.RawURLEncoding.EncodeToString(serializedProtected),
|
||||
base64.RawURLEncoding.EncodeToString(obj.payload),
|
||||
base64.RawURLEncoding.EncodeToString(obj.Signatures[0].Signature)), nil
|
||||
if !detached {
|
||||
payload = base64.RawURLEncoding.EncodeToString(obj.payload)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", serializedProtected, payload, signature), nil
|
||||
}
|
||||
|
||||
// CompactSerialize serializes an object using the compact serialization format.
|
||||
func (obj JSONWebSignature) CompactSerialize() (string, error) {
|
||||
return obj.compactSerialize(false)
|
||||
}
|
||||
|
||||
// DetachedCompactSerialize serializes an object using the compact serialization format with detached payload.
|
||||
func (obj JSONWebSignature) DetachedCompactSerialize() (string, error) {
|
||||
return obj.compactSerialize(true)
|
||||
}
|
||||
|
||||
// FullSerialize serializes an object using the full JSON serialization format.
|
||||
|
|
|
@ -81,3 +81,64 @@ type opaqueVerifier struct {
|
|||
func (o *opaqueVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
|
||||
return o.verifier.VerifyPayload(payload, signature, alg)
|
||||
}
|
||||
|
||||
// OpaqueKeyEncrypter is an interface that supports encrypting keys with an opaque key.
|
||||
type OpaqueKeyEncrypter interface {
|
||||
// KeyID returns the kid
|
||||
KeyID() string
|
||||
// Algs returns a list of supported key encryption algorithms.
|
||||
Algs() []KeyAlgorithm
|
||||
// encryptKey encrypts the CEK using the given algorithm.
|
||||
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error)
|
||||
}
|
||||
|
||||
type opaqueKeyEncrypter struct {
|
||||
encrypter OpaqueKeyEncrypter
|
||||
}
|
||||
|
||||
func newOpaqueKeyEncrypter(alg KeyAlgorithm, encrypter OpaqueKeyEncrypter) (recipientKeyInfo, error) {
|
||||
var algSupported bool
|
||||
for _, salg := range encrypter.Algs() {
|
||||
if alg == salg {
|
||||
algSupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !algSupported {
|
||||
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
|
||||
}
|
||||
|
||||
return recipientKeyInfo{
|
||||
keyID: encrypter.KeyID(),
|
||||
keyAlg: alg,
|
||||
keyEncrypter: &opaqueKeyEncrypter{
|
||||
encrypter: encrypter,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (oke *opaqueKeyEncrypter) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
|
||||
return oke.encrypter.encryptKey(cek, alg)
|
||||
}
|
||||
|
||||
//OpaqueKeyDecrypter is an interface that supports decrypting keys with an opaque key.
|
||||
type OpaqueKeyDecrypter interface {
|
||||
DecryptKey(encryptedKey []byte, header Header) ([]byte, error)
|
||||
}
|
||||
|
||||
type opaqueKeyDecrypter struct {
|
||||
decrypter OpaqueKeyDecrypter
|
||||
}
|
||||
|
||||
func (okd *opaqueKeyDecrypter) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
|
||||
mergedHeaders := rawHeader{}
|
||||
mergedHeaders.merge(&headers)
|
||||
mergedHeaders.merge(recipient.header)
|
||||
|
||||
header, err := mergedHeaders.sanitized()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return okd.decrypter.DecryptKey(recipient.encryptedKey, header)
|
||||
}
|
||||
|
|
|
@ -153,12 +153,18 @@ const (
|
|||
headerJWK = "jwk" // *JSONWebKey
|
||||
headerKeyID = "kid" // string
|
||||
headerNonce = "nonce" // string
|
||||
headerB64 = "b64" // bool
|
||||
|
||||
headerP2C = "p2c" // *byteBuffer (int)
|
||||
headerP2S = "p2s" // *byteBuffer ([]byte)
|
||||
|
||||
)
|
||||
|
||||
// supportedCritical is the set of supported extensions that are understood and processed.
|
||||
var supportedCritical = map[string]bool{
|
||||
headerB64: true,
|
||||
}
|
||||
|
||||
// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
|
||||
//
|
||||
// The decoding of the constituent items is deferred because we want to marshal
|
||||
|
@ -349,6 +355,21 @@ func (parsed rawHeader) getP2S() (*byteBuffer, error) {
|
|||
return parsed.getByteBuffer(headerP2S)
|
||||
}
|
||||
|
||||
// getB64 extracts parsed "b64" from the raw JSON, defaulting to true.
|
||||
func (parsed rawHeader) getB64() (bool, error) {
|
||||
v := parsed[headerB64]
|
||||
if v == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var b64 bool
|
||||
err := json.Unmarshal(*v, &b64)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
return b64, nil
|
||||
}
|
||||
|
||||
// sanitized produces a cleaned-up header object from the raw JSON.
|
||||
func (parsed rawHeader) sanitized() (h Header, err error) {
|
||||
for k, v := range parsed {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package jose
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
|
@ -77,6 +78,27 @@ func (so *SignerOptions) WithType(typ ContentType) *SignerOptions {
|
|||
return so.WithHeader(HeaderType, typ)
|
||||
}
|
||||
|
||||
// WithCritical adds the given names to the critical ("crit") header and returns
|
||||
// the updated SignerOptions.
|
||||
func (so *SignerOptions) WithCritical(names ...string) *SignerOptions {
|
||||
if so.ExtraHeaders[headerCritical] == nil {
|
||||
so.WithHeader(headerCritical, make([]string, 0, len(names)))
|
||||
}
|
||||
crit := so.ExtraHeaders[headerCritical].([]string)
|
||||
so.ExtraHeaders[headerCritical] = append(crit, names...)
|
||||
return so
|
||||
}
|
||||
|
||||
// WithBase64 adds a base64url-encode payload ("b64") header and returns the updated
|
||||
// SignerOptions. When the "b64" value is "false", the payload is not base64 encoded.
|
||||
func (so *SignerOptions) WithBase64(b64 bool) *SignerOptions {
|
||||
if !b64 {
|
||||
so.WithHeader(headerB64, b64)
|
||||
so.WithCritical(headerB64)
|
||||
}
|
||||
return so
|
||||
}
|
||||
|
||||
type payloadSigner interface {
|
||||
signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error)
|
||||
}
|
||||
|
@ -233,7 +255,10 @@ func (ctx *genericSigner) Sign(payload []byte) (*JSONWebSignature, error) {
|
|||
if ctx.embedJWK {
|
||||
protected[headerJWK] = recipient.publicKey()
|
||||
} else {
|
||||
protected[headerKeyID] = recipient.publicKey().KeyID
|
||||
keyID := recipient.publicKey().KeyID
|
||||
if keyID != "" {
|
||||
protected[headerKeyID] = keyID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,12 +275,26 @@ func (ctx *genericSigner) Sign(payload []byte) (*JSONWebSignature, error) {
|
|||
}
|
||||
|
||||
serializedProtected := mustSerializeJSON(protected)
|
||||
needsBase64 := true
|
||||
|
||||
input := []byte(fmt.Sprintf("%s.%s",
|
||||
base64.RawURLEncoding.EncodeToString(serializedProtected),
|
||||
base64.RawURLEncoding.EncodeToString(payload)))
|
||||
if b64, ok := protected[headerB64]; ok {
|
||||
if needsBase64, ok = b64.(bool); !ok {
|
||||
return nil, errors.New("square/go-jose: Invalid b64 header parameter")
|
||||
}
|
||||
}
|
||||
|
||||
signatureInfo, err := recipient.signer.signPayload(input, recipient.sigAlg)
|
||||
var input bytes.Buffer
|
||||
|
||||
input.WriteString(base64.RawURLEncoding.EncodeToString(serializedProtected))
|
||||
input.WriteByte('.')
|
||||
|
||||
if needsBase64 {
|
||||
input.WriteString(base64.RawURLEncoding.EncodeToString(payload))
|
||||
} else {
|
||||
input.Write(payload)
|
||||
}
|
||||
|
||||
signatureInfo, err := recipient.signer.signPayload(input.Bytes(), recipient.sigAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -324,12 +363,18 @@ func (obj JSONWebSignature) DetachedVerify(payload []byte, verificationKey inter
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(critical) > 0 {
|
||||
// Unsupported crit header
|
||||
|
||||
for _, name := range critical {
|
||||
if !supportedCritical[name] {
|
||||
return ErrCryptoFailure
|
||||
}
|
||||
}
|
||||
|
||||
input, err := obj.computeAuthData(payload, &signature)
|
||||
if err != nil {
|
||||
return ErrCryptoFailure
|
||||
}
|
||||
|
||||
input := obj.computeAuthData(payload, &signature)
|
||||
alg := headers.getSignatureAlgorithm()
|
||||
err = verifier.verifyPayload(input, signature.Signature, alg)
|
||||
if err == nil {
|
||||
|
@ -366,18 +411,25 @@ func (obj JSONWebSignature) DetachedVerifyMulti(payload []byte, verificationKey
|
|||
return -1, Signature{}, err
|
||||
}
|
||||
|
||||
outer:
|
||||
for i, signature := range obj.Signatures {
|
||||
headers := signature.mergedHeaders()
|
||||
critical, err := headers.getCritical()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(critical) > 0 {
|
||||
// Unsupported crit header
|
||||
|
||||
for _, name := range critical {
|
||||
if !supportedCritical[name] {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
|
||||
input, err := obj.computeAuthData(payload, &signature)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
input := obj.computeAuthData(payload, &signature)
|
||||
alg := headers.getSignatureAlgorithm()
|
||||
err = verifier.verifyPayload(input, signature.Signature, alg)
|
||||
if err == nil {
|
||||
|
|
|
@ -91,6 +91,8 @@ github.com/circonus-labs/circonus-gometrics/checkmgr
|
|||
github.com/circonus-labs/circonusllhist
|
||||
# github.com/coredns/coredns v1.1.2
|
||||
github.com/coredns/coredns/plugin/pkg/dnsutil
|
||||
# github.com/coreos/go-oidc v2.1.0+incompatible
|
||||
github.com/coreos/go-oidc
|
||||
# github.com/davecgh/go-spew v1.1.1
|
||||
github.com/davecgh/go-spew/spew
|
||||
# github.com/denverdino/aliyungo v0.0.0-20170926055100-d3308649c661
|
||||
|
@ -222,7 +224,7 @@ github.com/hashicorp/go-sockaddr
|
|||
github.com/hashicorp/go-sockaddr/template
|
||||
# github.com/hashicorp/go-syslog v1.0.0
|
||||
github.com/hashicorp/go-syslog
|
||||
# github.com/hashicorp/go-uuid v1.0.1
|
||||
# github.com/hashicorp/go-uuid v1.0.2
|
||||
github.com/hashicorp/go-uuid
|
||||
# github.com/hashicorp/go-version v1.2.0
|
||||
github.com/hashicorp/go-version
|
||||
|
@ -307,6 +309,8 @@ github.com/mitchellh/go-testing-interface
|
|||
github.com/mitchellh/hashstructure
|
||||
# github.com/mitchellh/mapstructure v1.2.3
|
||||
github.com/mitchellh/mapstructure
|
||||
# github.com/mitchellh/pointerstructure v1.0.0
|
||||
github.com/mitchellh/pointerstructure
|
||||
# github.com/mitchellh/reflectwalk v1.0.1
|
||||
github.com/mitchellh/reflectwalk
|
||||
# github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd
|
||||
|
@ -319,6 +323,8 @@ github.com/nicolai86/scaleway-sdk
|
|||
github.com/packethost/packngo
|
||||
# github.com/pascaldekloe/goe v0.1.0
|
||||
github.com/pascaldekloe/goe/verify
|
||||
# github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/patrickmn/go-cache
|
||||
# github.com/pierrec/lz4 v2.0.5+incompatible
|
||||
github.com/pierrec/lz4
|
||||
github.com/pierrec/lz4/internal/xxh32
|
||||
|
@ -330,6 +336,9 @@ github.com/pmezard/go-difflib/difflib
|
|||
github.com/posener/complete
|
||||
github.com/posener/complete/cmd
|
||||
github.com/posener/complete/cmd/install
|
||||
# github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35
|
||||
github.com/pquerna/cachecontrol
|
||||
github.com/pquerna/cachecontrol/cacheobject
|
||||
# github.com/prometheus/client_golang v1.0.0
|
||||
github.com/prometheus/client_golang/prometheus
|
||||
github.com/prometheus/client_golang/prometheus/internal
|
||||
|
@ -543,7 +552,7 @@ google.golang.org/grpc/tap
|
|||
gopkg.in/inf.v0
|
||||
# gopkg.in/resty.v1 v1.12.0
|
||||
gopkg.in/resty.v1
|
||||
# gopkg.in/square/go-jose.v2 v2.3.1
|
||||
# gopkg.in/square/go-jose.v2 v2.4.1
|
||||
gopkg.in/square/go-jose.v2
|
||||
gopkg.in/square/go-jose.v2/cipher
|
||||
gopkg.in/square/go-jose.v2/json
|
||||
|
|
Loading…
Reference in New Issue