acl: add auth method for JWTs (#7846)

This commit is contained in:
R.B. Boyer 2020-05-11 20:59:29 -05:00 committed by GitHub
parent 24175e2925
commit 940e5ad160
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
85 changed files with 11112 additions and 81 deletions

View file

@ -8,13 +8,18 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/testrpc"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)
// NOTE: The tests contained herein are designed to test the HTTP API
@ -1591,6 +1596,168 @@ func TestACL_LoginProcedure_HTTP(t *testing.T) {
})
}
func TestACLEndpoint_LoginLogout_jwt(t *testing.T) {
t.Parallel()
a := NewTestAgent(t, TestACLConfigWithParams(nil))
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
// spin up a fake oidc server
oidcServer := startSSOTestServer(t)
pubKey, privKey := oidcServer.SigningKeys()
type mConfig = map[string]interface{}
cases := map[string]struct {
f func(config mConfig)
issuer string
expectErr string
}{
"success - jwt static keys": {func(config mConfig) {
config["BoundIssuer"] = "https://legit.issuer.internal/"
config["JWTValidationPubKeys"] = []string{pubKey}
},
"https://legit.issuer.internal/",
""},
"success - jwt jwks": {func(config mConfig) {
config["JWKSURL"] = oidcServer.Addr() + "/certs"
config["JWKSCACert"] = oidcServer.CACert()
},
"https://legit.issuer.internal/",
""},
"success - jwt oidc discovery": {func(config mConfig) {
config["OIDCDiscoveryURL"] = oidcServer.Addr()
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
},
oidcServer.Addr(),
""},
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
method, err := upsertTestCustomizedAuthMethod(a.RPC, TestDefaultMasterToken, "dc1", func(method *structs.ACLAuthMethod) {
method.Type = "jwt"
method.Config = map[string]interface{}{
"JWTSupportedAlgs": []string{"ES256"},
"ClaimMappings": map[string]string{
"first_name": "name",
"/org/primary": "primary_org",
},
"ListClaimMappings": map[string]string{
"https://consul.test/groups": "groups",
},
"BoundAudiences": []string{"https://consul.test"},
}
if tc.f != nil {
tc.f(method.Config)
}
})
require.NoError(t, err)
t.Run("invalid bearer token", func(t *testing.T) {
loginInput := &structs.ACLLoginParams{
AuthMethod: method.Name,
BearerToken: "invalid",
}
req, _ := http.NewRequest("POST", "/v1/acl/login", jsonBody(loginInput))
resp := httptest.NewRecorder()
_, err := a.srv.ACLLogin(resp, req)
require.Error(t, err)
})
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Audience: jwt.Audience{"https://consul.test"},
Issuer: tc.issuer,
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
type orgs struct {
Primary string `json:"primary"`
}
privateCl := struct {
FirstName string `json:"first_name"`
Org orgs `json:"org"`
Groups []string `json:"https://consul.test/groups"`
}{
FirstName: "jeff2",
Org: orgs{"engineering"},
Groups: []string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
require.NoError(t, err)
t.Run("valid bearer token no bindings", func(t *testing.T) {
loginInput := &structs.ACLLoginParams{
AuthMethod: method.Name,
BearerToken: jwtData,
}
req, _ := http.NewRequest("POST", "/v1/acl/login", jsonBody(loginInput))
resp := httptest.NewRecorder()
_, err := a.srv.ACLLogin(resp, req)
testutil.RequireErrorContains(t, err, "Permission denied")
})
_, err = upsertTestCustomizedBindingRule(a.RPC, TestDefaultMasterToken, "dc1", func(rule *structs.ACLBindingRule) {
rule.AuthMethod = method.Name
rule.BindType = structs.BindingRuleBindTypeService
rule.BindName = "test--${value.name}--${value.primary_org}"
rule.Selector = "value.name == jeff2 and value.primary_org == engineering and foo in list.groups"
})
require.NoError(t, err)
t.Run("valid bearer token 1 service binding", func(t *testing.T) {
loginInput := &structs.ACLLoginParams{
AuthMethod: method.Name,
BearerToken: jwtData,
}
req, _ := http.NewRequest("POST", "/v1/acl/login", jsonBody(loginInput))
resp := httptest.NewRecorder()
obj, err := a.srv.ACLLogin(resp, req)
require.NoError(t, err)
token, ok := obj.(*structs.ACLToken)
require.True(t, ok)
require.Equal(t, method.Name, token.AuthMethod)
require.Equal(t, `token created via login`, token.Description)
require.True(t, token.Local)
require.Len(t, token.Roles, 0)
require.Len(t, token.ServiceIdentities, 1)
svcid := token.ServiceIdentities[0]
require.Len(t, svcid.Datacenters, 0)
require.Equal(t, "test--jeff2--engineering", svcid.ServiceName)
// and delete it
req, _ = http.NewRequest("GET", "/v1/acl/logout", nil)
req.Header.Add("X-Consul-Token", token.SecretID)
resp = httptest.NewRecorder()
_, err = a.srv.ACLLogout(resp, req)
require.NoError(t, err)
// verify the token was deleted
req, _ = http.NewRequest("GET", "/v1/acl/token/"+token.AccessorID, nil)
req.Header.Add("X-Consul-Token", TestDefaultMasterToken)
resp = httptest.NewRecorder()
// make the request
_, err = a.srv.ACLTokenCRUD(resp, req)
require.Error(t, err)
require.Equal(t, acl.ErrNotFound, err)
})
})
}
}
func TestACL_Authorize(t *testing.T) {
t.Parallel()
a1 := NewTestAgent(t, TestACLConfigWithParams(nil))
@ -2087,3 +2254,11 @@ func upsertTestCustomizedBindingRule(rpc rpcFn, masterToken string, datacenter s
return &out, nil
}
func startSSOTestServer(t *testing.T) *oidcauthtest.Server {
ports := freeport.MustTake(1)
return oidcauthtest.Start(t, oidcauthtest.WithPort(
ports[0],
func() { freeport.Return(ports) },
))
}

View file

@ -7,8 +7,9 @@ import (
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-bexpr"
// register this as a builtin auth method
// register these as a builtin auth method
_ "github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
_ "github.com/hashicorp/consul/agent/consul/authmethod/ssoauth"
)
type authMethodValidatorEntry struct {

View file

@ -0,0 +1,7 @@
//+build !consulent
package consul
func (s *Server) enterpriseEvaluateRoleBindings() error {
return nil
}

View file

@ -1909,7 +1909,6 @@ func (a *ACL) BindingRuleSet(args *structs.ACLBindingRuleSetRequest, reply *stru
if err != nil {
return fmt.Errorf("Failed to apply binding rule upsert request: %v", err)
}
if respErr, ok := resp.(error); ok {
return fmt.Errorf("Failed to apply binding rule upsert request: %v", respErr)
}
@ -2049,6 +2048,10 @@ func (a *ACL) AuthMethodRead(args *structs.ACLAuthMethodGetRequest, reply *struc
return err
}
if method != nil {
_ = a.enterpriseAuthMethodTypeValidation(method.Type)
}
reply.Index, reply.AuthMethod = index, method
return nil
})
@ -2093,6 +2096,10 @@ func (a *ACL) AuthMethodSet(args *structs.ACLAuthMethodSetRequest, reply *struct
return fmt.Errorf("Invalid Auth Method: invalid Name. Only alphanumeric characters, '-' and '_' are allowed")
}
if err := a.enterpriseAuthMethodTypeValidation(method.Type); err != nil {
return err
}
// Check to see if the method exists first.
_, existing, err := state.ACLAuthMethodGetByName(nil, method.Name, &method.EnterpriseMeta)
if err != nil {
@ -2193,6 +2200,10 @@ func (a *ACL) AuthMethodDelete(args *structs.ACLAuthMethodDeleteRequest, reply *
return nil
}
if err := a.enterpriseAuthMethodTypeValidation(method.Type); err != nil {
return err
}
req := structs.ACLAuthMethodBatchDeleteRequest{
AuthMethodNames: []string{args.AuthMethodName},
EnterpriseMeta: args.EnterpriseMeta,
@ -2249,6 +2260,7 @@ func (a *ACL) AuthMethodList(args *structs.ACLAuthMethodListRequest, reply *stru
var stubs structs.ACLAuthMethodListStubs
for _, method := range methods {
_ = a.enterpriseAuthMethodTypeValidation(method.Type)
stubs = append(stubs, method.Stub())
}
@ -2294,6 +2306,10 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
return acl.ErrNotFound
}
if err := a.enterpriseAuthMethodTypeValidation(method.Type); err != nil {
return err
}
validator, err := a.srv.loadAuthMethodValidator(idx, method)
if err != nil {
return err
@ -2305,6 +2321,31 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
return err
}
return a.tokenSetFromAuthMethod(
method,
&auth.EnterpriseMeta,
"token created via login",
auth.Meta,
validator,
verifiedIdentity,
&structs.ACLTokenSetRequest{
Datacenter: args.Datacenter,
WriteRequest: args.WriteRequest,
},
reply,
)
}
func (a *ACL) tokenSetFromAuthMethod(
method *structs.ACLAuthMethod,
entMeta *structs.EnterpriseMeta,
tokenDescriptionPrefix string,
tokenMetadata map[string]string,
validator authmethod.Validator,
verifiedIdentity *authmethod.Identity,
createReq *structs.ACLTokenSetRequest, // this should be prepopulated with datacenter+writerequest
reply *structs.ACLToken,
) error {
// This always will return a valid pointer
targetMeta, err := computeTargetEnterpriseMeta(method, verifiedIdentity)
if err != nil {
@ -2312,7 +2353,7 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
}
// 3. send map through role bindings
serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedIdentity, &auth.EnterpriseMeta, targetMeta)
serviceIdentities, roleLinks, err := a.srv.evaluateRoleBindings(validator, verifiedIdentity, entMeta, targetMeta)
if err != nil {
return err
}
@ -2323,8 +2364,10 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
return acl.ErrPermissionDenied
}
description := "token created via login"
loginMeta, err := encodeLoginMeta(auth.Meta)
// TODO(sso): add a CapturedField to ACLAuthMethod that would pluck fields from the returned identity and stuff into `auth.Meta`.
description := tokenDescriptionPrefix
loginMeta, err := encodeLoginMeta(tokenMetadata)
if err != nil {
return err
}
@ -2333,24 +2376,20 @@ func (a *ACL) Login(args *structs.ACLLoginRequest, reply *structs.ACLToken) erro
}
// 4. create token
createReq := structs.ACLTokenSetRequest{
Datacenter: args.Datacenter,
ACLToken: structs.ACLToken{
createReq.ACLToken = structs.ACLToken{
Description: description,
Local: true,
AuthMethod: auth.AuthMethod,
AuthMethod: method.Name,
ServiceIdentities: serviceIdentities,
Roles: roleLinks,
ExpirationTTL: method.MaxTokenTTL,
EnterpriseMeta: *targetMeta,
},
WriteRequest: args.WriteRequest,
}
createReq.ACLToken.ACLAuthMethodEnterpriseMeta.FillWithEnterpriseMeta(&auth.EnterpriseMeta)
createReq.ACLToken.ACLAuthMethodEnterpriseMeta.FillWithEnterpriseMeta(entMeta)
// 5. return token information like a TokenCreate would
err = a.tokenSetInternal(&createReq, reply, true)
err = a.tokenSetInternal(createReq, reply, true)
// If we were in a slight race with a role delete operation then we may
// still end up failing to insert an unprivileged token in the state

View file

@ -7,6 +7,10 @@ import (
"github.com/hashicorp/consul/agent/structs"
)
func (a *ACL) enterpriseAuthMethodTypeValidation(authMethodType string) error {
return nil
}
func enterpriseAuthMethodValidation(method *structs.ACLAuthMethod, validator authmethod.Validator) error {
return nil
}

View file

@ -16,13 +16,16 @@ import (
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
"github.com/hashicorp/consul/agent/structs"
tokenStore "github.com/hashicorp/consul/agent/token"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/consul/testrpc"
uuid "github.com/hashicorp/go-uuid"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestACLEndpoint_Bootstrap(t *testing.T) {
@ -5233,6 +5236,167 @@ func TestACLEndpoint_Login_k8s(t *testing.T) {
})
}
func TestACLEndpoint_Login_jwt(t *testing.T) {
t.Parallel()
dir1, s1 := testServerWithConfig(t, func(c *Config) {
c.ACLDatacenter = "dc1"
c.ACLsEnabled = true
c.ACLMasterToken = "root"
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
acl := ACL{srv: s1}
// spin up a fake oidc server
oidcServer := startSSOTestServer(t)
pubKey, privKey := oidcServer.SigningKeys()
type mConfig = map[string]interface{}
cases := map[string]struct {
f func(config mConfig)
issuer string
expectErr string
}{
"success - jwt static keys": {func(config mConfig) {
config["BoundIssuer"] = "https://legit.issuer.internal/"
config["JWTValidationPubKeys"] = []string{pubKey}
},
"https://legit.issuer.internal/",
""},
"success - jwt jwks": {func(config mConfig) {
config["JWKSURL"] = oidcServer.Addr() + "/certs"
config["JWKSCACert"] = oidcServer.CACert()
},
"https://legit.issuer.internal/",
""},
"success - jwt oidc discovery": {func(config mConfig) {
config["OIDCDiscoveryURL"] = oidcServer.Addr()
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
},
oidcServer.Addr(),
""},
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
method, err := upsertTestCustomizedAuthMethod(codec, "root", "dc1", func(method *structs.ACLAuthMethod) {
method.Type = "jwt"
method.Config = map[string]interface{}{
"JWTSupportedAlgs": []string{"ES256"},
"ClaimMappings": map[string]string{
"first_name": "name",
"/org/primary": "primary_org",
},
"ListClaimMappings": map[string]string{
"https://consul.test/groups": "groups",
},
"BoundAudiences": []string{"https://consul.test"},
}
if tc.f != nil {
tc.f(method.Config)
}
})
require.NoError(t, err)
t.Run("invalid bearer token", func(t *testing.T) {
req := structs.ACLLoginRequest{
Auth: &structs.ACLLoginParams{
AuthMethod: method.Name,
BearerToken: "invalid",
},
Datacenter: "dc1",
}
resp := structs.ACLToken{}
require.Error(t, acl.Login(&req, &resp))
})
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Audience: jwt.Audience{"https://consul.test"},
Issuer: tc.issuer,
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
type orgs struct {
Primary string `json:"primary"`
}
privateCl := struct {
FirstName string `json:"first_name"`
Org orgs `json:"org"`
Groups []string `json:"https://consul.test/groups"`
}{
FirstName: "jeff2",
Org: orgs{"engineering"},
Groups: []string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
require.NoError(t, err)
t.Run("valid bearer token no bindings", func(t *testing.T) {
req := structs.ACLLoginRequest{
Auth: &structs.ACLLoginParams{
AuthMethod: method.Name,
BearerToken: jwtData,
},
Datacenter: "dc1",
}
resp := structs.ACLToken{}
testutil.RequireErrorContains(t, acl.Login(&req, &resp), "Permission denied")
})
_, err = upsertTestBindingRule(
codec, "root", "dc1", method.Name,
"value.name == jeff2 and value.primary_org == engineering and foo in list.groups",
structs.BindingRuleBindTypeService,
"test--${value.name}--${value.primary_org}",
)
require.NoError(t, err)
t.Run("valid bearer token 1 service binding", func(t *testing.T) {
req := structs.ACLLoginRequest{
Auth: &structs.ACLLoginParams{
AuthMethod: method.Name,
BearerToken: jwtData,
},
Datacenter: "dc1",
}
resp := structs.ACLToken{}
require.NoError(t, acl.Login(&req, &resp))
require.Equal(t, method.Name, resp.AuthMethod)
require.Equal(t, `token created via login`, resp.Description)
require.True(t, resp.Local)
require.Len(t, resp.Roles, 0)
require.Len(t, resp.ServiceIdentities, 1)
svcid := resp.ServiceIdentities[0]
require.Len(t, svcid.Datacenters, 0)
require.Equal(t, "test--jeff2--engineering", svcid.ServiceName)
})
})
}
}
func startSSOTestServer(t *testing.T) *oidcauthtest.Server {
ports := freeport.MustTake(1)
return oidcauthtest.Start(t, oidcauthtest.WithPort(
ports[0],
func() { freeport.Return(ports) },
))
}
func TestACLEndpoint_Logout(t *testing.T) {
t.Parallel()

View file

@ -94,6 +94,10 @@ func NewValidator(method *structs.ACLAuthMethod) (*Validator, error) {
return nil, fmt.Errorf("Config.ServiceAccountJWT is not a valid JWT: %v", err)
}
if err := enterpriseValidation(method, &config); err != nil {
return nil, err
}
transport := cleanhttp.DefaultTransport()
client, err := k8s.NewForConfig(&client_rest.Config{
Host: config.Host,

View file

@ -0,0 +1,165 @@
package ssoauth
import (
"context"
"time"
"github.com/hashicorp/consul/agent/consul/authmethod"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/go-sso/oidcauth"
"github.com/hashicorp/go-hclog"
)
func init() {
authmethod.Register("jwt", func(logger hclog.Logger, method *structs.ACLAuthMethod) (authmethod.Validator, error) {
v, err := NewValidator(logger, method)
if err != nil {
return nil, err
}
return v, nil
})
}
// Validator is the wrapper around the go-sso library that also conforms to the
// authmethod.Validator interface.
type Validator struct {
name string
methodType string
config *oidcauth.Config
logger hclog.Logger
oa *oidcauth.Authenticator
}
var _ authmethod.Validator = (*Validator)(nil)
func NewValidator(logger hclog.Logger, method *structs.ACLAuthMethod) (*Validator, error) {
if err := validateType(method.Type); err != nil {
return nil, err
}
var config Config
if err := authmethod.ParseConfig(method.Config, &config); err != nil {
return nil, err
}
ssoConfig := config.convertForLibrary(method.Type)
oa, err := oidcauth.New(ssoConfig, logger)
if err != nil {
return nil, err
}
v := &Validator{
name: method.Name,
methodType: method.Type,
config: ssoConfig,
logger: logger,
oa: oa,
}
return v, nil
}
// Name implements authmethod.Validator.
func (v *Validator) Name() string { return v.name }
// Stop implements authmethod.Validator.
func (v *Validator) Stop() { v.oa.Stop() }
// ValidateLogin implements authmethod.Validator.
func (v *Validator) ValidateLogin(ctx context.Context, loginToken string) (*authmethod.Identity, error) {
c, err := v.oa.ClaimsFromJWT(ctx, loginToken)
if err != nil {
return nil, err
}
return v.identityFromClaims(c), nil
}
func (v *Validator) identityFromClaims(c *oidcauth.Claims) *authmethod.Identity {
id := v.NewIdentity()
id.SelectableFields = &fieldDetails{
Values: c.Values,
Lists: c.Lists,
}
for k, val := range c.Values {
id.ProjectedVars["value."+k] = val
}
id.EnterpriseMeta = v.ssoEntMetaFromClaims(c)
return id
}
// NewIdentity implements authmethod.Validator.
func (v *Validator) NewIdentity() *authmethod.Identity {
// Populate selectable fields with empty values so emptystring filters
// works. Populate projectable vars with empty values so HIL works.
fd := &fieldDetails{
Values: make(map[string]string),
Lists: make(map[string][]string),
}
projectedVars := make(map[string]string)
for _, k := range v.config.ClaimMappings {
fd.Values[k] = ""
projectedVars["value."+k] = ""
}
for _, k := range v.config.ListClaimMappings {
fd.Lists[k] = nil
}
return &authmethod.Identity{
SelectableFields: fd,
ProjectedVars: projectedVars,
}
}
type fieldDetails struct {
Values map[string]string `bexpr:"value"`
Lists map[string][]string `bexpr:"list"`
}
// Config is the collection of all settings that pertain to doing OIDC-based
// authentication and direct JWT-based authentication processes.
type Config struct {
// common for type=oidc and type=jwt
JWTSupportedAlgs []string `json:",omitempty"`
BoundAudiences []string `json:",omitempty"`
ClaimMappings map[string]string `json:",omitempty"`
ListClaimMappings map[string]string `json:",omitempty"`
OIDCDiscoveryURL string `json:",omitempty"`
OIDCDiscoveryCACert string `json:",omitempty"`
// just for type=jwt
JWKSURL string `json:",omitempty"`
JWKSCACert string `json:",omitempty"`
JWTValidationPubKeys []string `json:",omitempty"`
BoundIssuer string `json:",omitempty"`
ExpirationLeeway time.Duration `json:",omitempty"`
NotBeforeLeeway time.Duration `json:",omitempty"`
ClockSkewLeeway time.Duration `json:",omitempty"`
enterpriseConfig `mapstructure:",squash"`
}
func (c *Config) convertForLibrary(methodType string) *oidcauth.Config {
ssoConfig := &oidcauth.Config{
Type: methodType,
// common for type=oidc and type=jwt
JWTSupportedAlgs: c.JWTSupportedAlgs,
BoundAudiences: c.BoundAudiences,
ClaimMappings: c.ClaimMappings,
ListClaimMappings: c.ListClaimMappings,
OIDCDiscoveryURL: c.OIDCDiscoveryURL,
OIDCDiscoveryCACert: c.OIDCDiscoveryCACert,
// just for type=jwt
JWKSURL: c.JWKSURL,
JWKSCACert: c.JWKSCACert,
JWTValidationPubKeys: c.JWTValidationPubKeys,
BoundIssuer: c.BoundIssuer,
ExpirationLeeway: c.ExpirationLeeway,
NotBeforeLeeway: c.NotBeforeLeeway,
ClockSkewLeeway: c.ClockSkewLeeway,
}
c.enterpriseConvertForLibrary(ssoConfig)
return ssoConfig
}

View file

@ -0,0 +1,25 @@
//+build !consulent
package ssoauth
import (
"fmt"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/go-sso/oidcauth"
)
func validateType(typ string) error {
if typ != "jwt" {
return fmt.Errorf("type should be %q", "jwt")
}
return nil
}
func (v *Validator) ssoEntMetaFromClaims(_ *oidcauth.Claims) *structs.EnterpriseMeta {
return nil
}
type enterpriseConfig struct{}
func (c *Config) enterpriseConvertForLibrary(_ *oidcauth.Config) {}

View file

@ -0,0 +1,270 @@
package ssoauth
import (
"context"
"testing"
"time"
"github.com/hashicorp/consul/agent/consul/authmethod"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestJWT_NewValidator(t *testing.T) {
nullLogger := hclog.NewNullLogger()
type AM = *structs.ACLAuthMethod
makeAuthMethod := func(typ string, f func(method AM)) *structs.ACLAuthMethod {
method := &structs.ACLAuthMethod{
Name: "test-" + typ,
Description: typ + " test",
Type: typ,
Config: map[string]interface{}{},
}
if f != nil {
f(method)
}
return method
}
oidcServer := startTestServer(t)
// Note that we won't test ALL of the available config variations here.
// The go-sso library has exhaustive tests.
for name, tc := range map[string]struct {
method *structs.ACLAuthMethod
expectErr string
}{
"wrong type": {makeAuthMethod("invalid", nil), `type should be`},
"extra config": {makeAuthMethod("jwt", func(method AM) {
method.Config["extra"] = "config"
}), "has invalid keys"},
"wrong type of key in config blob": {makeAuthMethod("jwt", func(method AM) {
method.Config["JWKSURL"] = []int{12345}
}), `'JWKSURL' expected type 'string', got unconvertible type '[]int'`},
"normal jwt - static keys": {makeAuthMethod("jwt", func(method AM) {
method.Config["BoundIssuer"] = "https://legit.issuer.internal/"
pubKey, _ := oidcServer.SigningKeys()
method.Config["JWTValidationPubKeys"] = []string{pubKey}
}), ""},
"normal jwt - jwks": {makeAuthMethod("jwt", func(method AM) {
method.Config["JWKSURL"] = oidcServer.Addr() + "/certs"
method.Config["JWKSCACert"] = oidcServer.CACert()
}), ""},
"normal jwt - oidc discovery": {makeAuthMethod("jwt", func(method AM) {
method.Config["OIDCDiscoveryURL"] = oidcServer.Addr()
method.Config["OIDCDiscoveryCACert"] = oidcServer.CACert()
}), ""},
} {
tc := tc
t.Run(name, func(t *testing.T) {
v, err := NewValidator(nullLogger, tc.method)
if tc.expectErr != "" {
testutil.RequireErrorContains(t, err, tc.expectErr)
require.Nil(t, v)
} else {
require.NoError(t, err)
require.NotNil(t, v)
v.Stop()
}
})
}
}
func TestJWT_ValidateLogin(t *testing.T) {
type mConfig = map[string]interface{}
setup := func(t *testing.T, f func(config mConfig)) *Validator {
t.Helper()
config := map[string]interface{}{
"JWTSupportedAlgs": []string{"ES256"},
"ClaimMappings": map[string]string{
"first_name": "name",
"/org/primary": "primary_org",
},
"ListClaimMappings": map[string]string{
"https://consul.test/groups": "groups",
},
"BoundAudiences": []string{"https://consul.test"},
}
if f != nil {
f(config)
}
method := &structs.ACLAuthMethod{
Name: "test-method",
Type: "jwt",
Config: config,
}
nullLogger := hclog.NewNullLogger()
v, err := NewValidator(nullLogger, method)
require.NoError(t, err)
return v
}
oidcServer := startTestServer(t)
pubKey, privKey := oidcServer.SigningKeys()
cases := map[string]struct {
f func(config mConfig)
issuer string
expectErr string
}{
"success - jwt static keys": {func(config mConfig) {
config["BoundIssuer"] = "https://legit.issuer.internal/"
config["JWTValidationPubKeys"] = []string{pubKey}
},
"https://legit.issuer.internal/",
""},
"success - jwt jwks": {func(config mConfig) {
config["JWKSURL"] = oidcServer.Addr() + "/certs"
config["JWKSCACert"] = oidcServer.CACert()
},
"https://legit.issuer.internal/",
""},
"success - jwt oidc discovery": {func(config mConfig) {
config["OIDCDiscoveryURL"] = oidcServer.Addr()
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
},
oidcServer.Addr(),
""},
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
v := setup(t, tc.f)
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Audience: jwt.Audience{"https://consul.test"},
Issuer: tc.issuer,
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
type orgs struct {
Primary string `json:"primary"`
}
privateCl := struct {
FirstName string `json:"first_name"`
Org orgs `json:"org"`
Groups []string `json:"https://consul.test/groups"`
}{
FirstName: "jeff2",
Org: orgs{"engineering"},
Groups: []string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
require.NoError(t, err)
id, err := v.ValidateLogin(context.Background(), jwtData)
if tc.expectErr != "" {
testutil.RequireErrorContains(t, err, tc.expectErr)
} else {
require.NoError(t, err)
authmethod.RequireIdentityMatch(t, id, map[string]string{
"value.name": "jeff2",
"value.primary_org": "engineering",
},
"value.name == jeff2",
"value.name != jeff",
"value.primary_org == engineering",
"foo in list.groups",
"bar in list.groups",
"salt not in list.groups",
)
}
})
}
}
func TestNewIdentity(t *testing.T) {
// This is only based on claim mappings, so we'll just use the JWT type
// since that's cheaper to setup.
cases := map[string]struct {
claimMappings map[string]string
listClaimMappings map[string]string
expectVars map[string]string
expectFilters []string
}{
"nil": {nil, nil, kv(), nil},
"empty": {kv(), kv(), kv(), nil},
"one value mapping": {
kv("foo1", "val1"),
kv(),
kv("value.val1", ""),
[]string{`value.val1 == ""`},
},
"one list mapping": {kv(),
kv("foo2", "val2"),
kv(),
nil,
},
"one of each": {
kv("foo1", "val1"),
kv("foo2", "val2"),
kv("value.val1", ""),
[]string{`value.val1 == ""`},
},
"two value mappings": {
kv("foo1", "val1", "foo2", "val2"),
kv(),
kv("value.val1", "", "value.val2", ""),
[]string{`value.val1 == ""`, `value.val2 == ""`},
},
}
pubKey, _ := oidcauthtest.SigningKeys()
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
method := &structs.ACLAuthMethod{
Name: "test-method",
Type: "jwt",
Config: map[string]interface{}{
"BoundIssuer": "https://legit.issuer.internal/",
"JWTValidationPubKeys": []string{pubKey},
"ClaimMappings": tc.claimMappings,
"ListClaimMappings": tc.listClaimMappings,
},
}
nullLogger := hclog.NewNullLogger()
v, err := NewValidator(nullLogger, method)
require.NoError(t, err)
id := v.NewIdentity()
authmethod.RequireIdentityMatch(t, id, tc.expectVars, tc.expectFilters...)
})
}
}
func kv(a ...string) map[string]string {
if len(a)%2 != 0 {
panic("kv() requires even numbers of arguments")
}
m := make(map[string]string)
for i := 0; i < len(a); i += 2 {
m[a[i]] = a[i+1]
}
return m
}
func startTestServer(t *testing.T) *oidcauthtest.Server {
ports := freeport.MustTake(1)
return oidcauthtest.Start(t, oidcauthtest.WithPort(
ports[0],
func() { freeport.Return(ports) },
))
}

View file

@ -305,12 +305,73 @@ func (c *KubernetesAuthMethodConfig) RenderToConfig() map[string]interface{} {
}
}
// OIDCAuthMethodConfig is the config for the built-in Consul auth method for
// OIDC and JWT.
type OIDCAuthMethodConfig struct {
// common for type=oidc and type=jwt
JWTSupportedAlgs []string `json:",omitempty"`
BoundAudiences []string `json:",omitempty"`
ClaimMappings map[string]string `json:",omitempty"`
ListClaimMappings map[string]string `json:",omitempty"`
OIDCDiscoveryURL string `json:",omitempty"`
OIDCDiscoveryCACert string `json:",omitempty"`
// just for type=oidc
OIDCClientID string `json:",omitempty"`
OIDCClientSecret string `json:",omitempty"`
OIDCScopes []string `json:",omitempty"`
AllowedRedirectURIs []string `json:",omitempty"`
VerboseOIDCLogging bool `json:",omitempty"`
// just for type=jwt
JWKSURL string `json:",omitempty"`
JWKSCACert string `json:",omitempty"`
JWTValidationPubKeys []string `json:",omitempty"`
BoundIssuer string `json:",omitempty"`
ExpirationLeeway time.Duration `json:",omitempty"`
NotBeforeLeeway time.Duration `json:",omitempty"`
ClockSkewLeeway time.Duration `json:",omitempty"`
}
// RenderToConfig converts this into a map[string]interface{} suitable for use
// in the ACLAuthMethod.Config field.
func (c *OIDCAuthMethodConfig) RenderToConfig() map[string]interface{} {
return map[string]interface{}{
// common for type=oidc and type=jwt
"JWTSupportedAlgs": c.JWTSupportedAlgs,
"BoundAudiences": c.BoundAudiences,
"ClaimMappings": c.ClaimMappings,
"ListClaimMappings": c.ListClaimMappings,
"OIDCDiscoveryURL": c.OIDCDiscoveryURL,
"OIDCDiscoveryCACert": c.OIDCDiscoveryCACert,
// just for type=oidc
"OIDCClientID": c.OIDCClientID,
"OIDCClientSecret": c.OIDCClientSecret,
"OIDCScopes": c.OIDCScopes,
"AllowedRedirectURIs": c.AllowedRedirectURIs,
"VerboseOIDCLogging": c.VerboseOIDCLogging,
// just for type=jwt
"JWKSURL": c.JWKSURL,
"JWKSCACert": c.JWKSCACert,
"JWTValidationPubKeys": c.JWTValidationPubKeys,
"BoundIssuer": c.BoundIssuer,
"ExpirationLeeway": c.ExpirationLeeway,
"NotBeforeLeeway": c.NotBeforeLeeway,
"ClockSkewLeeway": c.ClockSkewLeeway,
}
}
type ACLLoginParams struct {
AuthMethod string
BearerToken string
Meta map[string]string `json:",omitempty"`
}
type ACLOIDCAuthURLParams struct {
AuthMethod string
RedirectURI string
ClientNonce string
Meta map[string]string `json:",omitempty"`
}
// ACL can be used to query the ACL endpoints
type ACL struct {
c *Client
@ -1227,3 +1288,62 @@ func (a *ACL) Logout(q *WriteOptions) (*WriteMeta, error) {
wm := &WriteMeta{RequestTime: rtt}
return wm, nil
}
// OIDCAuthURL requests an authorization URL to start an OIDC login flow.
func (a *ACL) OIDCAuthURL(auth *ACLOIDCAuthURLParams, q *WriteOptions) (string, *WriteMeta, error) {
if auth.AuthMethod == "" {
return "", nil, fmt.Errorf("Must specify an auth method name")
}
r := a.c.newRequest("POST", "/v1/acl/oidc/auth-url")
r.setWriteOptions(q)
r.obj = auth
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return "", nil, err
}
defer resp.Body.Close()
wm := &WriteMeta{RequestTime: rtt}
var out aclOIDCAuthURLResponse
if err := decodeBody(resp, &out); err != nil {
return "", nil, err
}
return out.AuthURL, wm, nil
}
type aclOIDCAuthURLResponse struct {
AuthURL string
}
type ACLOIDCCallbackParams struct {
AuthMethod string
State string
Code string
ClientNonce string
}
// OIDCCallback is the callback endpoint to complete an OIDC login.
func (a *ACL) OIDCCallback(auth *ACLOIDCCallbackParams, q *WriteOptions) (*ACLToken, *WriteMeta, error) {
if auth.AuthMethod == "" {
return nil, nil, fmt.Errorf("Must specify an auth method name")
}
r := a.c.newRequest("POST", "/v1/acl/oidc/callback")
r.setWriteOptions(q)
r.obj = auth
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
wm := &WriteMeta{RequestTime: rtt}
var out ACLToken
if err := decodeBody(resp, &out); err != nil {
return nil, nil, err
}
return &out, wm, nil
}

View file

@ -30,9 +30,13 @@ type cmd struct {
// flags
authMethodName string
authMethodType string
bearerTokenFile string
tokenSinkFile string
meta map[string]string
enterpriseCmd
}
func (c *cmd) init() {
@ -41,6 +45,9 @@ func (c *cmd) init() {
c.flags.StringVar(&c.authMethodName, "method", "",
"Name of the auth method to login to.")
c.flags.StringVar(&c.authMethodType, "type", "",
"Type of the auth method to login to. This field is optional and defaults to no type.")
c.flags.StringVar(&c.bearerTokenFile, "bearer-token-file", "",
"Path to a file containing a secret bearer token to use with this auth method.")
@ -51,6 +58,8 @@ func (c *cmd) init() {
"Metadata to set on the token, formatted as key=value. This flag "+
"may be specified multiple times to set multiple meta fields.")
c.initEnterpriseFlags()
c.http = &flags.HTTPFlags{}
flags.Merge(c.flags, c.http.ClientFlags())
flags.Merge(c.flags, c.http.ServerFlags())
@ -76,6 +85,10 @@ func (c *cmd) Run(args []string) int {
return 1
}
return c.login()
}
func (c *cmd) bearerTokenLogin() int {
if c.bearerTokenFile == "" {
c.UI.Error(fmt.Sprintf("Missing required '-bearer-token-file' flag"))
return 1

View file

@ -0,0 +1,13 @@
//+build !consulent
package login
type enterpriseCmd struct {
}
func (c *cmd) initEnterpriseFlags() {
}
func (c *cmd) login() int {
return c.bearerTokenLogin()
}

View file

@ -6,16 +6,20 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/agent/consul/authmethod/kubeauth"
"github.com/hashicorp/consul/agent/consul/authmethod/testauth"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/command/acl"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/testrpc"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestLoginCommand_noTabs(t *testing.T) {
@ -314,3 +318,154 @@ func TestLoginCommand_k8s(t *testing.T) {
require.Len(t, token, 36, "must be a valid uid: %s", token)
})
}
func TestLoginCommand_jwt(t *testing.T) {
t.Parallel()
testDir := testutil.TempDir(t, "acl")
defer os.RemoveAll(testDir)
a := agent.NewTestAgent(t, `
primary_datacenter = "dc1"
acl {
enabled = true
tokens {
master = "root"
}
}`)
defer a.Shutdown()
testrpc.WaitForLeader(t, a.RPC, "dc1")
client := a.Client()
tokenSinkFile := filepath.Join(testDir, "test.token")
bearerTokenFile := filepath.Join(testDir, "bearer.token")
// spin up a fake oidc server
oidcServer := startSSOTestServer(t)
pubKey, privKey := oidcServer.SigningKeys()
type mConfig = map[string]interface{}
cases := map[string]struct {
f func(config mConfig)
issuer string
expectErr string
}{
"success - jwt static keys": {func(config mConfig) {
config["BoundIssuer"] = "https://legit.issuer.internal/"
config["JWTValidationPubKeys"] = []string{pubKey}
},
"https://legit.issuer.internal/",
""},
"success - jwt jwks": {func(config mConfig) {
config["JWKSURL"] = oidcServer.Addr() + "/certs"
config["JWKSCACert"] = oidcServer.CACert()
},
"https://legit.issuer.internal/",
""},
"success - jwt oidc discovery": {func(config mConfig) {
config["OIDCDiscoveryURL"] = oidcServer.Addr()
config["OIDCDiscoveryCACert"] = oidcServer.CACert()
},
oidcServer.Addr(),
""},
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
method := &api.ACLAuthMethod{
Name: "jwt",
Type: "jwt",
Config: map[string]interface{}{
"JWTSupportedAlgs": []string{"ES256"},
"ClaimMappings": map[string]string{
"first_name": "name",
"/org/primary": "primary_org",
},
"ListClaimMappings": map[string]string{
"https://consul.test/groups": "groups",
},
"BoundAudiences": []string{"https://consul.test"},
},
}
if tc.f != nil {
tc.f(method.Config)
}
_, _, err := client.ACL().AuthMethodCreate(
method,
&api.WriteOptions{Token: "root"},
)
require.NoError(t, err)
_, _, err = client.ACL().BindingRuleCreate(&api.ACLBindingRule{
AuthMethod: "jwt",
BindType: api.BindingRuleBindTypeService,
BindName: "test--${value.name}--${value.primary_org}",
Selector: "value.name == jeff2 and value.primary_org == engineering and foo in list.groups",
},
&api.WriteOptions{Token: "root"},
)
require.NoError(t, err)
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Audience: jwt.Audience{"https://consul.test"},
Issuer: tc.issuer,
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
type orgs struct {
Primary string `json:"primary"`
}
privateCl := struct {
FirstName string `json:"first_name"`
Org orgs `json:"org"`
Groups []string `json:"https://consul.test/groups"`
}{
FirstName: "jeff2",
Org: orgs{"engineering"},
Groups: []string{"foo", "bar"},
}
// Drop a JWT on disk.
jwtData, err := oidcauthtest.SignJWT(privKey, cl, privateCl)
require.NoError(t, err)
require.NoError(t, ioutil.WriteFile(bearerTokenFile, []byte(jwtData), 0600))
defer os.Remove(tokenSinkFile)
ui := cli.NewMockUi()
cmd := New(ui)
args := []string{
"-http-addr=" + a.HTTPAddr(),
"-token=root",
"-method=jwt",
"-token-sink-file", tokenSinkFile,
"-bearer-token-file", bearerTokenFile,
}
code := cmd.Run(args)
require.Equal(t, 0, code, "err: %s", ui.ErrorWriter.String())
require.Empty(t, ui.ErrorWriter.String())
require.Empty(t, ui.OutputWriter.String())
raw, err := ioutil.ReadFile(tokenSinkFile)
require.NoError(t, err)
token := strings.TrimSpace(string(raw))
require.Len(t, token, 36, "must be a valid uid: %s", token)
})
}
}
func startSSOTestServer(t *testing.T) *oidcauthtest.Server {
ports := freeport.MustTake(1)
return oidcauthtest.Start(t, oidcauthtest.WithPort(
ports[0],
func() { freeport.Return(ports) },
))
}

9
go.mod
View file

@ -17,6 +17,7 @@ require (
github.com/armon/go-radix v1.0.0
github.com/aws/aws-sdk-go v1.25.41
github.com/coredns/coredns v1.1.2
github.com/coreos/go-oidc v2.1.0+incompatible
github.com/digitalocean/godo v1.10.0 // indirect
github.com/docker/go-connections v0.3.0
github.com/elazarl/go-bindata-assetfs v0.0.0-20160803192304-e1a2a7ec64b0
@ -43,7 +44,7 @@ require (
github.com/hashicorp/go-raftchunking v0.6.1
github.com/hashicorp/go-sockaddr v1.0.2
github.com/hashicorp/go-syslog v1.0.0
github.com/hashicorp/go-uuid v1.0.1
github.com/hashicorp/go-uuid v1.0.2
github.com/hashicorp/go-version v1.2.0
github.com/hashicorp/golang-lru v0.5.1
github.com/hashicorp/hcl v1.0.0
@ -65,9 +66,12 @@ require (
github.com/mitchellh/go-testing-interface v1.14.0
github.com/mitchellh/hashstructure v0.0.0-20170609045927-2bca23e0e452
github.com/mitchellh/mapstructure v1.2.3
github.com/mitchellh/pointerstructure v1.0.0
github.com/mitchellh/reflectwalk v1.0.1
github.com/pascaldekloe/goe v0.1.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.8.1
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v1.0.0
github.com/rboyer/safeio v0.2.1
github.com/ryanuber/columnize v2.1.0+incompatible
@ -78,6 +82,7 @@ require (
go.opencensus.io v0.22.0 // indirect
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/sync v0.0.0-20190423024810-112230192c58
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
@ -85,7 +90,7 @@ require (
google.golang.org/appengine v1.6.0 // indirect
google.golang.org/genproto v0.0.0-20190530194941-fb225487d101 // indirect
google.golang.org/grpc v1.23.0
gopkg.in/square/go-jose.v2 v2.3.1
gopkg.in/square/go-jose.v2 v2.4.1
k8s.io/api v0.16.9
k8s.io/apimachinery v0.16.9
k8s.io/client-go v0.16.9

12
go.sum
View file

@ -79,6 +79,8 @@ github.com/coredns/coredns v1.1.2/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/
github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
@ -230,6 +232,8 @@ github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdv
github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.1 h1:fv1ep09latC32wFoVwnqcnKJGnMSdBanPczbHAYm1BE=
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.1.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/go-version v1.2.0 h1:3vNe/fWF5CBgRIguda1meWhsZHy3m8gCJ5wx+dIzX/E=
github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
@ -343,6 +347,8 @@ github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQz
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.2.3 h1:f/MjBEBDLttYCGfRaKBbKSRVF5aV2O6fnBpzknuE3jU=
github.com/mitchellh/mapstructure v1.2.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/pointerstructure v1.0.0 h1:ATSdz4NWrmWPOF1CeCBU4sMCno2hgqdbSrRPFWQSVZI=
github.com/mitchellh/pointerstructure v1.0.0/go.mod h1:k4XwG94++jLVsSiTxo7qdIfXA9pj9EAeo0QsNNJOLZ8=
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/mitchellh/reflectwalk v1.0.1 h1:FVzMWA5RllMAKIdUSC8mdWo3XtwoecrH79BY70sEEpE=
github.com/mitchellh/reflectwalk v1.0.1/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
@ -371,6 +377,8 @@ github.com/packethost/packngo v0.1.1-0.20180711074735-b9cb5096f54c/go.mod h1:otz
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI=
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
@ -386,6 +394,8 @@ github.com/posener/complete v1.1.1 h1:ccV59UEOTzVDnDUEFdT95ZzHVZ+5+158q8+SJb2QV5
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
github.com/posener/complete v1.2.3 h1:NP0eAhjcjImqslEwo/1hq7gpajME0fTLTezBKDqfXqo=
github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM=
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
@ -610,6 +620,8 @@ gopkg.in/resty.v1 v1.12.0 h1:CuXP0Pjfw9rOuY6EP+UvtNvt5DSqHpIxILZKT/quCZI=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/square/go-jose.v2 v2.4.1 h1:H0TmLt7/KmzlrDOpa1F+zr0Tk90PbJYBfsVUmRLrf9Y=
gopkg.in/square/go-jose.v2 v2.4.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -0,0 +1,8 @@
# go-sso
This is a Go library that is being incubated in Consul to assist in doing
opinionated OIDC-based single sign on.
The `go.mod.sample` and `go.sum.sample` files are what the overall real
`go.mod` and `go.sum` files should end up being when extracted from the Consul
codebase.

View file

@ -0,0 +1,18 @@
module github.com/hashicorp/consul/go-sso
go 1.13
require (
github.com/coreos/go-oidc v2.1.0+incompatible
github.com/hashicorp/go-cleanhttp v0.5.1
github.com/hashicorp/go-hclog v0.12.0
github.com/hashicorp/go-uuid v1.0.2
github.com/mitchellh/go-testing-interface v1.14.0
github.com/mitchellh/pointerstructure v1.0.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/stretchr/testify v1.2.2
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 // indirect
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
gopkg.in/square/go-jose.v2 v2.4.1
)

View file

@ -0,0 +1,56 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM=
github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
github.com/hashicorp/go-hclog v0.12.0 h1:d4QkX8FRTYaKaCZBoXYY8zJX2BXjWxurN/GA2tkrmZM=
github.com/hashicorp/go-hclog v0.12.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ=
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10=
github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84=
github.com/mitchellh/go-testing-interface v1.14.0 h1:/x0XQ6h+3U3nAyk1yx+bHPURrKa9sVVvYbuqZ7pIAtI=
github.com/mitchellh/go-testing-interface v1.14.0/go.mod h1:gfgS7OtZj6MA4U1UrDRp04twqAjfvlZyCfX3sDjEym8=
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/pointerstructure v1.0.0 h1:ATSdz4NWrmWPOF1CeCBU4sMCno2hgqdbSrRPFWQSVZI=
github.com/mitchellh/pointerstructure v1.0.0/go.mod h1:k4XwG94++jLVsSiTxo7qdIfXA9pj9EAeo0QsNNJOLZ8=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY=
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191008105621-543471e840be h1:QAcqgptGM8IQBC9K/RC4o+O9YmqEm0diQn9QmZw/0mU=
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/square/go-jose.v2 v2.4.1 h1:H0TmLt7/KmzlrDOpa1F+zr0Tk90PbJYBfsVUmRLrf9Y=
gopkg.in/square/go-jose.v2 v2.4.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=

View file

@ -0,0 +1,129 @@
// package oidcauth bundles up an opinionated approach to authentication using
// both the OIDC authorization code workflow and simple JWT decoding (via
// static keys, JWKS, and OIDC discovery).
//
// NOTE: This was roughly forked from hashicorp/vault-plugin-auth-jwt
// originally at commit 825c85535e3832d254a74253a8e9ae105357778b with later
// backports of behavior in 0e93b06cecb0477d6ee004e44b04832d110096cf
package oidcauth
import (
"context"
"fmt"
"net/http"
"sync"
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-hclog"
"github.com/patrickmn/go-cache"
)
// Claims represents a set of claims or assertions computed about a given
// authentication exchange.
type Claims struct {
// Values is a set of key/value string claims about the authentication
// exchange.
Values map[string]string
// Lists is a set of key/value string list claims about the authentication
// exchange.
Lists map[string][]string
}
// Authenticator allows for extracting a set of claims from either an OIDC
// authorization code exchange or a bare JWT.
type Authenticator struct {
config *Config
logger hclog.Logger
// parsedJWTPubKeys is the parsed form of config.JWTValidationPubKeys
parsedJWTPubKeys []interface{}
provider *oidc.Provider
keySet oidc.KeySet
// httpClient should be configured with all relevant root CA certs and be
// reused for all OIDC or JWKS operations. This will be nil for the static
// keys JWT configuration.
httpClient *http.Client
l sync.Mutex
oidcStates *cache.Cache
// backgroundCtx is a cancellable context primarily meant to be used for
// things that may spawn background goroutines and are not tied to a
// request/response lifecycle. Use backgroundCtxCancel to cancel this.
backgroundCtx context.Context
backgroundCtxCancel context.CancelFunc
}
// New creates an authenticator suitable for use with either an OIDC
// authorization code workflow or a bare JWT workflow depending upon the value
// of the config Type.
func New(c *Config, logger hclog.Logger) (*Authenticator, error) {
if err := c.Validate(); err != nil {
return nil, err
}
var parsedJWTPubKeys []interface{}
if c.Type == TypeJWT {
for _, v := range c.JWTValidationPubKeys {
key, err := parsePublicKeyPEM([]byte(v))
if err != nil {
// This shouldn't happen as the keys are already validated in Validate().
return nil, fmt.Errorf("error parsing public key: %v", err)
}
parsedJWTPubKeys = append(parsedJWTPubKeys, key)
}
}
a := &Authenticator{
config: c,
logger: logger,
parsedJWTPubKeys: parsedJWTPubKeys,
}
a.backgroundCtx, a.backgroundCtxCancel = context.WithCancel(context.Background())
if c.Type == TypeOIDC {
a.oidcStates = cache.New(oidcStateTimeout, oidcStateCleanupInterval)
}
var err error
switch c.authType() {
case authOIDCDiscovery, authOIDCFlow:
a.httpClient, err = createHTTPClient(a.config.OIDCDiscoveryCACert)
if err != nil {
return nil, fmt.Errorf("error parsing OIDCDiscoveryCACert: %v", err)
}
provider, err := oidc.NewProvider(
contextWithHttpClient(a.backgroundCtx, a.httpClient),
a.config.OIDCDiscoveryURL,
)
if err != nil {
return nil, fmt.Errorf("error creating provider: %v", err)
}
a.provider = provider
case authJWKS:
a.httpClient, err = createHTTPClient(a.config.JWKSCACert)
if err != nil {
return nil, fmt.Errorf("error parsing JWKSCACert: %v", err)
}
a.keySet = oidc.NewRemoteKeySet(
contextWithHttpClient(a.backgroundCtx, a.httpClient),
a.config.JWKSURL,
)
}
return a, nil
}
// Stop stops any background goroutines and does cleanup.
func (a *Authenticator) Stop() {
a.l.Lock()
defer a.l.Unlock()
if a.backgroundCtxCancel != nil {
a.backgroundCtxCancel()
a.backgroundCtxCancel = nil
}
}

View file

@ -0,0 +1,348 @@
package oidcauth
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/coreos/go-oidc"
)
const (
// TypeOIDC is the config type to specify if the OIDC authorization code
// workflow is desired. The Authenticator methods GetAuthCodeURL and
// ClaimsFromAuthCode are activated with the type.
TypeOIDC = "oidc"
// TypeJWT is the config type to specify if simple JWT decoding (via static
// keys, JWKS, and OIDC discovery) is desired. The Authenticator method
// ClaimsFromJWT is activated with this type.
TypeJWT = "jwt"
)
// Config is the collection of all settings that pertain to doing OIDC-based
// authentication and direct JWT-based authentication processes.
type Config struct {
// Type defines which kind of authentication will be happening, OIDC-based
// or JWT-based. Allowed values are either 'oidc' or 'jwt'.
//
// Defaults to 'oidc' if unset.
Type string
// -------
// common for type=oidc and type=jwt
// -------
// JWTSupportedAlgs is a list of supported signing algorithms. Defaults to
// RS256.
JWTSupportedAlgs []string
// Comma-separated list of 'aud' claims that are valid for login; any match
// is sufficient
// TODO(sso): actually just send these down as string claims?
BoundAudiences []string
// Mappings of claims (key) that will be copied to a metadata field
// (value). Use this if the claim you are capturing is singular (such as an
// attribute).
//
// When mapped, the values can be any of a number, string, or boolean and
// will all be stringified when returned.
ClaimMappings map[string]string
// Mappings of claims (key) that will be copied to a metadata field
// (value). Use this if the claim you are capturing is list-like (such as
// groups).
//
// When mapped, the values in each list can be any of a number, string, or
// boolean and will all be stringified when returned.
ListClaimMappings map[string]string
// OIDCDiscoveryURL is the OIDC Discovery URL, without any .well-known
// component (base path). Cannot be used with "JWKSURL" or
// "JWTValidationPubKeys".
OIDCDiscoveryURL string
// OIDCDiscoveryCACert is the CA certificate or chain of certificates, in
// PEM format, to use to validate connections to the OIDC Discovery URL. If
// not set, system certificates are used.
OIDCDiscoveryCACert string
// -------
// just for type=oidc
// -------
// OIDCClientID is the OAuth Client ID configured with your OIDC provider.
//
// Valid only if Type=oidc
OIDCClientID string
// The OAuth Client Secret configured with your OIDC provider.
//
// Valid only if Type=oidc
OIDCClientSecret string
// Comma-separated list of OIDC scopes
//
// Valid only if Type=oidc
OIDCScopes []string
// Comma-separated list of allowed values for redirect_uri
//
// Valid only if Type=oidc
AllowedRedirectURIs []string
// Log received OIDC tokens and claims when debug-level logging is active.
// Not recommended in production since sensitive information may be present
// in OIDC responses.
//
// Valid only if Type=oidc
VerboseOIDCLogging bool
// -------
// just for type=jwt
// -------
// JWKSURL is the JWKS URL to use to authenticate signatures. Cannot be
// used with "OIDCDiscoveryURL" or "JWTValidationPubKeys".
//
// Valid only if Type=jwt
JWKSURL string
// JWKSCACert is the CA certificate or chain of certificates, in PEM
// format, to use to validate connections to the JWKS URL. If not set,
// system certificates are used.
//
// Valid only if Type=jwt
JWKSCACert string
// JWTValidationPubKeys is a list of PEM-encoded public keys to use to
// authenticate signatures locally. Cannot be used with "JWKSURL" or
// "OIDCDiscoveryURL".
//
// Valid only if Type=jwt
JWTValidationPubKeys []string
// BoundIssuer is the value against which to match the 'iss' claim in a
// JWT. Optional.
//
// Valid only if Type=jwt
BoundIssuer string
// Duration in seconds of leeway when validating expiration of
// a token to account for clock skew.
//
// Defaults to 150 (2.5 minutes) if set to 0 and can be disabled if set to -1.`,
//
// Valid only if Type=jwt
ExpirationLeeway time.Duration
// Duration in seconds of leeway when validating not before values of a
// token to account for clock skew.
//
// Defaults to 150 (2.5 minutes) if set to 0 and can be disabled if set to
// -1.`,
//
// Valid only if Type=jwt
NotBeforeLeeway time.Duration
// Duration in seconds of leeway when validating all claims to account for
// clock skew.
//
// Defaults to 60 (1 minute) if set to 0 and can be disabled if set to
// -1.`,
//
// Valid only if Type=jwt
ClockSkewLeeway time.Duration
}
// Validate returns an error if the config is not valid.
func (c *Config) Validate() error {
validateCtx, validateCtxCancel := context.WithCancel(context.Background())
defer validateCtxCancel()
switch c.Type {
case TypeOIDC, "":
// required
switch {
case c.OIDCDiscoveryURL == "":
return fmt.Errorf("'OIDCDiscoveryURL' must be set for type %q", c.Type)
case c.OIDCClientID == "":
return fmt.Errorf("'OIDCClientID' must be set for type %q", c.Type)
case c.OIDCClientSecret == "":
return fmt.Errorf("'OIDCClientSecret' must be set for type %q", c.Type)
case len(c.AllowedRedirectURIs) == 0:
return fmt.Errorf("'AllowedRedirectURIs' must be set for type %q", c.Type)
}
// not allowed
switch {
case c.JWKSURL != "":
return fmt.Errorf("'JWKSURL' must not be set for type %q", c.Type)
case c.JWKSCACert != "":
return fmt.Errorf("'JWKSCACert' must not be set for type %q", c.Type)
case len(c.JWTValidationPubKeys) != 0:
return fmt.Errorf("'JWTValidationPubKeys' must not be set for type %q", c.Type)
case c.BoundIssuer != "":
return fmt.Errorf("'BoundIssuer' must not be set for type %q", c.Type)
case c.ExpirationLeeway != 0:
return fmt.Errorf("'ExpirationLeeway' must not be set for type %q", c.Type)
case c.NotBeforeLeeway != 0:
return fmt.Errorf("'NotBeforeLeeway' must not be set for type %q", c.Type)
case c.ClockSkewLeeway != 0:
return fmt.Errorf("'ClockSkewLeeway' must not be set for type %q", c.Type)
}
var bad []string
for _, allowed := range c.AllowedRedirectURIs {
if _, err := url.Parse(allowed); err != nil {
bad = append(bad, allowed)
}
}
if len(bad) > 0 {
return fmt.Errorf("Invalid AllowedRedirectURIs provided: %v", bad)
}
case TypeJWT:
// not allowed
switch {
case c.OIDCClientID != "":
return fmt.Errorf("'OIDCClientID' must not be set for type %q", c.Type)
case c.OIDCClientSecret != "":
return fmt.Errorf("'OIDCClientSecret' must not be set for type %q", c.Type)
case len(c.OIDCScopes) != 0:
return fmt.Errorf("'OIDCScopes' must not be set for type %q", c.Type)
case len(c.AllowedRedirectURIs) != 0:
return fmt.Errorf("'AllowedRedirectURIs' must not be set for type %q", c.Type)
case c.VerboseOIDCLogging:
return fmt.Errorf("'VerboseOIDCLogging' must not be set for type %q", c.Type)
}
methodCount := 0
if c.OIDCDiscoveryURL != "" {
methodCount++
}
if len(c.JWTValidationPubKeys) != 0 {
methodCount++
}
if c.JWKSURL != "" {
methodCount++
}
if methodCount != 1 {
return fmt.Errorf("exactly one of 'JWTValidationPubKeys', 'JWKSURL', or 'OIDCDiscoveryURL' must be set for type %q", c.Type)
}
if c.JWKSURL != "" {
httpClient, err := createHTTPClient(c.JWKSCACert)
if err != nil {
return fmt.Errorf("error checking JWKSCACert: %v", err)
}
ctx := contextWithHttpClient(validateCtx, httpClient)
keyset := oidc.NewRemoteKeySet(ctx, c.JWKSURL)
// Try to verify a correctly formatted JWT. The signature will fail
// to match, but other errors with fetching the remote keyset
// should be reported.
_, err = keyset.VerifySignature(ctx, testJWT)
if err == nil {
err = errors.New("unexpected verification of JWT")
}
if !strings.Contains(err.Error(), "failed to verify id token signature") {
return fmt.Errorf("error checking JWKSURL: %v", err)
}
} else if c.JWKSCACert != "" {
return fmt.Errorf("'JWKSCACert' should not be set unless 'JWKSURL' is set")
}
if len(c.JWTValidationPubKeys) != 0 {
for i, v := range c.JWTValidationPubKeys {
if _, err := parsePublicKeyPEM([]byte(v)); err != nil {
return fmt.Errorf("error parsing public key JWTValidationPubKeys[%d]: %v", i, err)
}
}
}
default:
return fmt.Errorf("authenticator type should be %q or %q", TypeOIDC, TypeJWT)
}
if c.OIDCDiscoveryURL != "" {
httpClient, err := createHTTPClient(c.OIDCDiscoveryCACert)
if err != nil {
return fmt.Errorf("error checking OIDCDiscoveryCACert: %v", err)
}
ctx := contextWithHttpClient(validateCtx, httpClient)
if _, err := oidc.NewProvider(ctx, c.OIDCDiscoveryURL); err != nil {
return fmt.Errorf("error checking OIDCDiscoveryURL: %v", err)
}
} else if c.OIDCDiscoveryCACert != "" {
return fmt.Errorf("'OIDCDiscoveryCACert' should not be set unless 'OIDCDiscoveryURL' is set")
}
for _, a := range c.JWTSupportedAlgs {
switch a {
case oidc.RS256, oidc.RS384, oidc.RS512,
oidc.ES256, oidc.ES384, oidc.ES512,
oidc.PS256, oidc.PS384, oidc.PS512:
default:
return fmt.Errorf("Invalid supported algorithm: %s", a)
}
}
if len(c.ClaimMappings) > 0 {
targets := make(map[string]bool)
for _, mappedKey := range c.ClaimMappings {
if targets[mappedKey] {
return fmt.Errorf("ClaimMappings contains multiple mappings for key %q", mappedKey)
}
targets[mappedKey] = true
}
}
if len(c.ListClaimMappings) > 0 {
targets := make(map[string]bool)
for _, mappedKey := range c.ListClaimMappings {
if targets[mappedKey] {
return fmt.Errorf("ListClaimMappings contains multiple mappings for key %q", mappedKey)
}
targets[mappedKey] = true
}
}
return nil
}
const (
authUnconfigured = iota
authStaticKeys
authJWKS
authOIDCDiscovery
authOIDCFlow
)
// authType classifies the authorization type/flow based on config parameters.
// It is only valid to invoke if Validate() returns a nil error.
func (c *Config) authType() int {
switch {
case len(c.JWTValidationPubKeys) > 0:
return authStaticKeys
case c.JWKSURL != "":
return authJWKS
case c.OIDCDiscoveryURL != "":
if c.OIDCClientID != "" && c.OIDCClientSecret != "" {
return authOIDCFlow
}
return authOIDCDiscovery
default:
return authUnconfigured
}
}
const testJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"

View file

@ -0,0 +1,655 @@
package oidcauth
import (
"strings"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/stretchr/testify/require"
)
func TestConfigValidate(t *testing.T) {
type testcase struct {
config Config
expectAuthType int
expectErr string
}
srv := oidcauthtest.Start(t)
oidcCases := map[string]testcase{
"all required": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectAuthType: authOIDCFlow,
},
"missing required OIDCDiscoveryURL": {
config: Config{
Type: TypeOIDC,
// OIDCDiscoveryURL: srv.Addr(),
// OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectErr: "must be set for type",
},
"missing required OIDCClientID": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
// OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectErr: "must be set for type",
},
"missing required OIDCClientSecret": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
// OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectErr: "must be set for type",
},
"missing required AllowedRedirectURIs": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{},
},
expectErr: "must be set for type",
},
"incompatible with JWKSURL": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
JWKSURL: srv.Addr() + "/certs",
},
expectErr: "must not be set for type",
},
"incompatible with JWKSCACert": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
JWKSCACert: srv.CACert(),
},
expectErr: "must not be set for type",
},
"incompatible with JWTValidationPubKeys": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
JWTValidationPubKeys: []string{testJWTPubKey},
},
expectErr: "must not be set for type",
},
"incompatible with BoundIssuer": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
BoundIssuer: "foo",
},
expectErr: "must not be set for type",
},
"incompatible with ExpirationLeeway": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
ExpirationLeeway: 1 * time.Second,
},
expectErr: "must not be set for type",
},
"incompatible with NotBeforeLeeway": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
NotBeforeLeeway: 1 * time.Second,
},
expectErr: "must not be set for type",
},
"incompatible with ClockSkewLeeway": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
ClockSkewLeeway: 1 * time.Second,
},
expectErr: "must not be set for type",
},
"bad discovery cert": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: oidcBadCACerts,
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectErr: "certificate signed by unknown authority",
},
"garbage discovery cert": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: garbageCACert,
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectErr: "could not parse CA PEM value successfully",
},
"good discovery cert": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectAuthType: authOIDCFlow,
},
"valid redirect uris": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{
"http://foo.test",
"https://example.com",
"https://evilcorp.com:8443",
},
},
expectAuthType: authOIDCFlow,
},
"invalid redirect uris": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{
"%%%%",
"http://foo.test",
"https://example.com",
"https://evilcorp.com:8443",
},
},
expectErr: "Invalid AllowedRedirectURIs provided: [%%%%]",
},
"valid algorithm": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
JWTSupportedAlgs: []string{
oidc.RS256, oidc.RS384, oidc.RS512,
oidc.ES256, oidc.ES384, oidc.ES512,
oidc.PS256, oidc.PS384, oidc.PS512,
},
},
expectAuthType: authOIDCFlow,
},
"invalid algorithm": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
JWTSupportedAlgs: []string{
oidc.RS256, oidc.RS384, oidc.RS512,
oidc.ES256, oidc.ES384, oidc.ES512,
oidc.PS256, oidc.PS384, oidc.PS512,
"foo",
},
},
expectErr: "Invalid supported algorithm",
},
"valid claim mappings": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
ClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
ListClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
},
expectAuthType: authOIDCFlow,
},
"invalid repeated value claim mappings": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
ClaimMappings: map[string]string{
"foo": "bar",
"bling": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
ListClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
},
expectErr: "ClaimMappings contains multiple mappings for key",
},
"invalid repeated list claim mappings": {
config: Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
AllowedRedirectURIs: []string{"http://foo.test"},
ClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
ListClaimMappings: map[string]string{
"foo": "bar",
"bling": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
},
expectErr: "ListClaimMappings contains multiple mappings for key",
},
}
jwtCases := map[string]testcase{
"all required for oidc discovery": {
config: Config{
Type: TypeJWT,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
},
expectAuthType: authOIDCDiscovery,
},
"all required for jwks": {
config: Config{
Type: TypeJWT,
JWKSURL: srv.Addr() + "/certs",
JWKSCACert: srv.CACert(), // needed to avoid self signed cert issue
},
expectAuthType: authJWKS,
},
"all required for public keys": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
},
expectAuthType: authStaticKeys,
},
"incompatible with OIDCClientID": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
OIDCClientID: "abc",
},
expectErr: "must not be set for type",
},
"incompatible with OIDCClientSecret": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
OIDCClientSecret: "abc",
},
expectErr: "must not be set for type",
},
"incompatible with OIDCScopes": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
OIDCScopes: []string{"blah"},
},
expectErr: "must not be set for type",
},
"incompatible with AllowedRedirectURIs": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
AllowedRedirectURIs: []string{"http://foo.test"},
},
expectErr: "must not be set for type",
},
"incompatible with VerboseOIDCLogging": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
VerboseOIDCLogging: true,
},
expectErr: "must not be set for type",
},
"too many methods (discovery + jwks)": {
config: Config{
Type: TypeJWT,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
JWKSURL: srv.Addr() + "/certs",
JWKSCACert: srv.CACert(),
// JWTValidationPubKeys: []string{testJWTPubKey},
},
expectErr: "exactly one of",
},
"too many methods (discovery + pubkeys)": {
config: Config{
Type: TypeJWT,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
// JWKSURL: srv.Addr() + "/certs",
// JWKSCACert: srv.CACert(),
JWTValidationPubKeys: []string{testJWTPubKey},
},
expectErr: "exactly one of",
},
"too many methods (jwks + pubkeys)": {
config: Config{
Type: TypeJWT,
// OIDCDiscoveryURL: srv.Addr(),
// OIDCDiscoveryCACert: srv.CACert(),
JWKSURL: srv.Addr() + "/certs",
JWKSCACert: srv.CACert(),
JWTValidationPubKeys: []string{testJWTPubKey},
},
expectErr: "exactly one of",
},
"too many methods (discovery + jwks + pubkeys)": {
config: Config{
Type: TypeJWT,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
JWKSURL: srv.Addr() + "/certs",
JWKSCACert: srv.CACert(),
JWTValidationPubKeys: []string{testJWTPubKey},
},
expectErr: "exactly one of",
},
"incompatible with JWKSCACert": {
config: Config{
Type: TypeJWT,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
JWKSCACert: srv.CACert(),
},
expectErr: "should not be set unless",
},
"invalid pubkey": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKeyBad},
},
expectErr: "error parsing public key",
},
"incompatible with OIDCDiscoveryCACert": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
OIDCDiscoveryCACert: srv.CACert(),
},
expectErr: "should not be set unless",
},
"bad discovery cert": {
config: Config{
Type: TypeJWT,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: oidcBadCACerts,
},
expectErr: "certificate signed by unknown authority",
},
"good discovery cert": {
config: Config{
Type: TypeJWT,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
},
expectAuthType: authOIDCDiscovery,
},
"jwks invalid 404": {
config: Config{
Type: TypeJWT,
JWKSURL: srv.Addr() + "/certs_missing",
JWKSCACert: srv.CACert(),
},
expectErr: "get keys failed",
},
"jwks mismatched certs": {
config: Config{
Type: TypeJWT,
JWKSURL: srv.Addr() + "/certs_invalid",
JWKSCACert: srv.CACert(),
},
expectErr: "failed to decode keys",
},
"jwks bad certs": {
config: Config{
Type: TypeJWT,
JWKSURL: srv.Addr() + "/certs_invalid",
JWKSCACert: garbageCACert,
},
expectErr: "could not parse CA PEM value successfully",
},
"valid algorithm": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
JWTSupportedAlgs: []string{
oidc.RS256, oidc.RS384, oidc.RS512,
oidc.ES256, oidc.ES384, oidc.ES512,
oidc.PS256, oidc.PS384, oidc.PS512,
},
},
expectAuthType: authStaticKeys,
},
"invalid algorithm": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
JWTSupportedAlgs: []string{
oidc.RS256, oidc.RS384, oidc.RS512,
oidc.ES256, oidc.ES384, oidc.ES512,
oidc.PS256, oidc.PS384, oidc.PS512,
"foo",
},
},
expectErr: "Invalid supported algorithm",
},
"valid claim mappings": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
ClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
ListClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
},
expectAuthType: authStaticKeys,
},
"invalid repeated value claim mappings": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
ClaimMappings: map[string]string{
"foo": "bar",
"bling": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
ListClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
},
expectErr: "ClaimMappings contains multiple mappings for key",
},
"invalid repeated list claim mappings": {
config: Config{
Type: TypeJWT,
JWTValidationPubKeys: []string{testJWTPubKey},
ClaimMappings: map[string]string{
"foo": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
ListClaimMappings: map[string]string{
"foo": "bar",
"bling": "bar",
"peanutbutter": "jelly",
"wd40": "ducttape",
},
},
expectErr: "ListClaimMappings contains multiple mappings for key",
},
}
cases := map[string]testcase{
"bad type": {
config: Config{Type: "invalid"},
expectErr: "authenticator type should be",
},
}
for k, v := range oidcCases {
cases["type=oidc/"+k] = v
v2 := v
v2.config.Type = ""
cases["type=inferred_oidc/"+k] = v2
}
for k, v := range jwtCases {
cases["type=jwt/"+k] = v
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
err := tc.config.Validate()
if tc.expectErr != "" {
require.Error(t, err)
requireErrorContains(t, err, tc.expectErr)
} else {
require.NoError(t, err)
require.Equal(t, tc.expectAuthType, tc.config.authType())
}
})
}
}
func requireErrorContains(t *testing.T, err error, expectedErrorMessage string) {
t.Helper()
if err == nil {
t.Fatal("An error is expected but got nil.")
}
if !strings.Contains(err.Error(), expectedErrorMessage) {
t.Fatalf("unexpected error: %v", err)
}
}
const (
testJWTPubKey = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEEVs/o5+uQbTjL3chynL4wXgUg2R9
q9UU8I5mEovUf86QZ7kOBIjJwqnzD1omageEHWwHdBO6B+dFabmdT9POxg==
-----END PUBLIC KEY-----`
testJWTPubKeyBad = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIrollingyourricksEVs/o5+uQbTjL3chynL4wXgUg2R9
q9UU8I5mEovUf86QZ7kOBIjJwqnzD1omageEHWwHdBO6B+dFabmdT9POxg==
-----END PUBLIC KEY-----`
garbageCACert = `this is not a key`
oidcBadCACerts = `-----BEGIN CERTIFICATE-----
MIIDYDCCAkigAwIBAgIJAK8uAVsPxWKGMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTgwNzA5MTgwODI5WhcNMjgwNzA2MTgwODI5WjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
CgKCAQEA1eaEmIHKQqDlSadCtg6YY332qIMoeSb2iZTRhBRYBXRhMIKF3HoLXlI8
/3veheMnBQM7zxIeLwtJ4VuZVZcpJlqHdsXQVj6A8+8MlAzNh3+Xnv0tjZ83QLwZ
D6FWvMEzihxATD9uTCu2qRgeKnMYQFq4EG72AGb5094zfsXTAiwCfiRPVumiNbs4
Mr75vf+2DEhqZuyP7GR2n3BKzrWo62yAmgLQQ07zfd1u1buv8R72HCYXYpFul5qx
slZHU3yR+tLiBKOYB+C/VuB7hJZfVx25InIL1HTpIwWvmdk3QzpSpAGIAxWMXSzS
oRmBYGnsgR6WTymfXuokD4ZhHOpFZQIDAQABo1MwUTAdBgNVHQ4EFgQURh/QFJBn
hMXcgB1bWbGiU9B2VBQwHwYDVR0jBBgwFoAURh/QFJBnhMXcgB1bWbGiU9B2VBQw
DwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAr8CZLA3MQjMDWweS
ax9S1fRb8ifxZ4RqDcLj3dw5KZqnjEo8ggczR66T7vVXet/2TFBKYJAM0np26Z4A
WjZfrDT7/bHXseWQAUhw/k2d39o+Um4aXkGpg1Paky9D+ddMdbx1hFkYxDq6kYGd
PlBYSEiYQvVxDx7s7H0Yj9FWKO8WIO6BRUEvLlG7k/Xpp1OI6dV3nqwJ9CbcbqKt
ff4hAtoAmN0/x6yFclFFWX8s7bRGqmnoj39/r98kzeGFb/lPKgQjSVcBJuE7UO4k
8HP6vsnr/ruSlzUMv6XvHtT68kGC1qO3MfqiPhdSa4nxf9g/1xyBmAw/Uf90BJrm
sj9DpQ==
-----END CERTIFICATE-----`
)

View file

@ -0,0 +1,11 @@
package strutil
// StrListContains looks for a string in a list of strings.
func StrListContains(haystack []string, needle string) bool {
for _, item := range haystack {
if item == needle {
return true
}
}
return false
}

View file

@ -0,0 +1,32 @@
package strutil
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStrListContains(t *testing.T) {
tests := []struct {
haystack []string
needle string
expected bool
}{
// found
{[]string{"a"}, "a", true},
{[]string{"a", "b", "c"}, "a", true},
{[]string{"a", "b", "c"}, "b", true},
{[]string{"a", "b", "c"}, "c", true},
// not found
{nil, "", false},
{[]string{}, "", false},
{[]string{"a"}, "", false},
{[]string{"a"}, "b", false},
{[]string{"a", "b", "c"}, "x", false},
}
for _, test := range tests {
ok := StrListContains(test.haystack, test.needle)
assert.Equal(t, test.expected, ok, "failed on %s/%v", test.needle, test.haystack)
}
}

View file

@ -0,0 +1,207 @@
package oidcauth
import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"time"
"gopkg.in/square/go-jose.v2/jwt"
)
const claimDefaultLeeway = 150
// ClaimsFromJWT is unrelated to the OIDC authorization code workflow. This
// allows for a JWT to be directly validated and decoded into a set of claims.
//
// Requires the authenticator's config type be set to 'jwt'.
func (a *Authenticator) ClaimsFromJWT(ctx context.Context, jwt string) (*Claims, error) {
if a.config.authType() == authOIDCFlow {
return nil, fmt.Errorf("ClaimsFromJWT is incompatible with type %q", TypeOIDC)
}
if jwt == "" {
return nil, errors.New("missing jwt")
}
// Here is where things diverge. If it is using OIDC Discovery, validate that way;
// otherwise validate against the locally configured or JWKS keys. Once things are
// validated, we re-unify the request path when evaluating the claims.
var (
allClaims map[string]interface{}
err error
)
switch a.config.authType() {
case authStaticKeys, authJWKS:
allClaims, err = a.verifyVanillaJWT(ctx, jwt)
if err != nil {
return nil, err
}
case authOIDCDiscovery:
allClaims, err = a.verifyOIDCToken(ctx, jwt)
if err != nil {
return nil, err
}
default:
return nil, errors.New("unhandled case during login")
}
c, err := a.extractClaims(allClaims)
if err != nil {
return nil, err
}
if a.config.VerboseOIDCLogging && a.logger != nil {
a.logger.Debug("OIDC provider response", "extracted_claims", c)
}
return c, nil
}
func (a *Authenticator) verifyVanillaJWT(ctx context.Context, loginToken string) (map[string]interface{}, error) {
var (
allClaims = map[string]interface{}{}
claims = jwt.Claims{}
)
// TODO(sso): handle JWTSupportedAlgs
switch a.config.authType() {
case authJWKS:
// Verify signature (and only signature... other elements are checked later)
payload, err := a.keySet.VerifySignature(ctx, loginToken)
if err != nil {
return nil, fmt.Errorf("error verifying token: %v", err)
}
// Unmarshal payload into two copies: public claims for library verification, and a set
// of all received claims.
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
}
if err := json.Unmarshal(payload, &allClaims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
}
case authStaticKeys:
parsedJWT, err := jwt.ParseSigned(loginToken)
if err != nil {
return nil, fmt.Errorf("error parsing token: %v", err)
}
var valid bool
for _, key := range a.parsedJWTPubKeys {
if err := parsedJWT.Claims(key, &claims, &allClaims); err == nil {
valid = true
break
}
}
if !valid {
return nil, errors.New("no known key successfully validated the token signature")
}
default:
return nil, fmt.Errorf("unsupported auth type for this verifyVanillaJWT: %d", a.config.authType())
}
// We require notbefore or expiry; if only one is provided, we allow 5 minutes of leeway by default.
// Configurable by ExpirationLeeway and NotBeforeLeeway
if claims.IssuedAt == nil {
claims.IssuedAt = new(jwt.NumericDate)
}
if claims.Expiry == nil {
claims.Expiry = new(jwt.NumericDate)
}
if claims.NotBefore == nil {
claims.NotBefore = new(jwt.NumericDate)
}
if *claims.IssuedAt == 0 && *claims.Expiry == 0 && *claims.NotBefore == 0 {
return nil, errors.New("no issue time, notbefore, or expiration time encoded in token")
}
if *claims.Expiry == 0 {
latestStart := *claims.IssuedAt
if *claims.NotBefore > *claims.IssuedAt {
latestStart = *claims.NotBefore
}
leeway := a.config.ExpirationLeeway.Seconds()
if a.config.ExpirationLeeway.Seconds() < 0 {
leeway = 0
} else if a.config.ExpirationLeeway.Seconds() == 0 {
leeway = claimDefaultLeeway
}
*claims.Expiry = jwt.NumericDate(int64(latestStart) + int64(leeway))
}
if *claims.NotBefore == 0 {
if *claims.IssuedAt != 0 {
*claims.NotBefore = *claims.IssuedAt
} else {
leeway := a.config.NotBeforeLeeway.Seconds()
if a.config.NotBeforeLeeway.Seconds() < 0 {
leeway = 0
} else if a.config.NotBeforeLeeway.Seconds() == 0 {
leeway = claimDefaultLeeway
}
*claims.NotBefore = jwt.NumericDate(int64(*claims.Expiry) - int64(leeway))
}
}
expected := jwt.Expected{
Issuer: a.config.BoundIssuer,
// Subject: a.config.BoundSubject,
Time: time.Now(),
}
cksLeeway := a.config.ClockSkewLeeway
if a.config.ClockSkewLeeway.Seconds() < 0 {
cksLeeway = 0
} else if a.config.ClockSkewLeeway.Seconds() == 0 {
cksLeeway = jwt.DefaultLeeway
}
if err := claims.ValidateWithLeeway(expected, cksLeeway); err != nil {
return nil, fmt.Errorf("error validating claims: %v", err)
}
if err := validateAudience(a.config.BoundAudiences, claims.Audience, true); err != nil {
return nil, fmt.Errorf("error validating claims: %v", err)
}
return allClaims, nil
}
// parsePublicKeyPEM is used to parse RSA, ECDSA, and Ed25519 public keys from PEMs
//
// Extracted from "github.com/hashicorp/vault/sdk/helper/certutil"
//
// go-sso added support for ed25519 (EdDSA)
func parsePublicKeyPEM(data []byte) (interface{}, error) {
block, data := pem.Decode(data)
if block != nil {
var rawKey interface{}
var err error
if rawKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
rawKey = cert.PublicKey
} else {
return nil, err
}
}
if rsaPublicKey, ok := rawKey.(*rsa.PublicKey); ok {
return rsaPublicKey, nil
}
if ecPublicKey, ok := rawKey.(*ecdsa.PublicKey); ok {
return ecPublicKey, nil
}
if edPublicKey, ok := rawKey.(ed25519.PublicKey); ok {
return edPublicKey, nil
}
}
return nil, errors.New("data does not contain any valid RSA, ECDSA, or ED25519 public keys")
}

View file

@ -0,0 +1,695 @@
package oidcauth
import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)
func setupForJWT(t *testing.T, authType int, f func(c *Config)) (*Authenticator, string) {
t.Helper()
config := &Config{
Type: TypeJWT,
JWTSupportedAlgs: []string{oidc.ES256},
ClaimMappings: map[string]string{
"first_name": "name",
"/org/primary": "primary_org",
"/nested/Size": "size",
"Age": "age",
"Admin": "is_admin",
"/nested/division": "division",
"/nested/remote": "is_remote",
},
ListClaimMappings: map[string]string{
"https://go-sso/groups": "groups",
},
}
var issuer string
switch authType {
case authOIDCDiscovery:
srv := oidcauthtest.Start(t)
config.OIDCDiscoveryURL = srv.Addr()
config.OIDCDiscoveryCACert = srv.CACert()
issuer = config.OIDCDiscoveryURL
// TODO(sso): is this a bug in vault?
// config.BoundIssuer = issuer
case authStaticKeys:
pubKey, _ := oidcauthtest.SigningKeys()
config.BoundIssuer = "https://legit.issuer.internal/"
config.JWTValidationPubKeys = []string{pubKey}
issuer = config.BoundIssuer
case authJWKS:
srv := oidcauthtest.Start(t)
config.JWKSURL = srv.Addr() + "/certs"
config.JWKSCACert = srv.CACert()
issuer = "https://legit.issuer.internal/"
// TODO(sso): is this a bug in vault?
// config.BoundIssuer = issuer
default:
require.Fail(t, "inappropriate authType: %d", authType)
}
if f != nil {
f(config)
}
require.NoError(t, config.Validate())
oa, err := New(config, hclog.NewNullLogger())
require.NoError(t, err)
t.Cleanup(oa.Stop)
return oa, issuer
}
func TestJWT_OIDC_Functions_Fail(t *testing.T) {
t.Run("static", func(t *testing.T) {
testJWT_OIDC_Functions_Fail(t, authStaticKeys)
})
t.Run("JWKS", func(t *testing.T) {
testJWT_OIDC_Functions_Fail(t, authJWKS)
})
t.Run("oidc discovery", func(t *testing.T) {
testJWT_OIDC_Functions_Fail(t, authOIDCDiscovery)
})
}
func testJWT_OIDC_Functions_Fail(t *testing.T, authType int) {
t.Helper()
t.Run("GetAuthCodeURL", func(t *testing.T) {
oa, _ := setupForJWT(t, authType, nil)
_, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
map[string]string{"foo": "bar"},
)
requireErrorContains(t, err, `GetAuthCodeURL is incompatible with type "jwt"`)
})
t.Run("ClaimsFromAuthCode", func(t *testing.T) {
oa, _ := setupForJWT(t, authType, nil)
_, _, err := oa.ClaimsFromAuthCode(
context.Background(),
"abc", "def",
)
requireErrorContains(t, err, `ClaimsFromAuthCode is incompatible with type "jwt"`)
})
}
func TestJWT_ClaimsFromJWT(t *testing.T) {
t.Run("static", func(t *testing.T) {
testJWT_ClaimsFromJWT(t, authStaticKeys)
})
t.Run("JWKS", func(t *testing.T) {
testJWT_ClaimsFromJWT(t, authJWKS)
})
t.Run("oidc discovery", func(t *testing.T) {
// TODO(sso): the vault versions of these tests did not run oidc-discovery
testJWT_ClaimsFromJWT(t, authOIDCDiscovery)
})
}
func testJWT_ClaimsFromJWT(t *testing.T, authType int) {
t.Helper()
t.Run("missing audience", func(t *testing.T) {
if authType == authOIDCDiscovery {
// TODO(sso): why isn't this strict?
t.Skip("why?")
return
}
oa, issuer := setupForJWT(t, authType, nil)
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: issuer,
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://go-sso.test"},
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
require.NoError(t, err)
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
requireErrorContains(t, err, "audience claim found in JWT but no audiences are bound")
})
t.Run("valid inputs", func(t *testing.T) {
oa, issuer := setupForJWT(t, authType, func(c *Config) {
c.BoundAudiences = []string{
"https://go-sso.test",
"another_audience",
}
})
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: issuer,
Audience: jwt.Audience{"https://go-sso.test"},
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
type orgs struct {
Primary string `json:"primary"`
}
type nested struct {
Division int64 `json:"division"`
Remote bool `json:"remote"`
Size string `json:"Size"`
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
FirstName string `json:"first_name"`
Org orgs `json:"org"`
Color string `json:"color"`
Age int64 `json:"Age"`
Admin bool `json:"Admin"`
Nested nested `json:"nested"`
}{
User: "jeff",
Groups: []string{"foo", "bar"},
FirstName: "jeff2",
Org: orgs{"engineering"},
Color: "green",
Age: 85,
Admin: true,
Nested: nested{
Division: 3,
Remote: true,
Size: "medium",
},
}
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
require.NoError(t, err)
claims, err := oa.ClaimsFromJWT(context.Background(), jwtData)
require.NoError(t, err)
expectedClaims := &Claims{
Values: map[string]string{
"name": "jeff2",
"primary_org": "engineering",
"size": "medium",
"age": "85",
"is_admin": "true",
"division": "3",
"is_remote": "true",
},
Lists: map[string][]string{
"groups": []string{"foo", "bar"},
},
}
require.Equal(t, expectedClaims, claims)
})
t.Run("unusable claims", func(t *testing.T) {
oa, issuer := setupForJWT(t, authType, func(c *Config) {
c.BoundAudiences = []string{
"https://go-sso.test",
"another_audience",
}
})
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: issuer,
Audience: jwt.Audience{"https://go-sso.test"},
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
type orgs struct {
Primary string `json:"primary"`
}
type nested struct {
Division int64 `json:"division"`
Remote bool `json:"remote"`
Size []string `json:"Size"`
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
FirstName string `json:"first_name"`
Org orgs `json:"org"`
Color string `json:"color"`
Age int64 `json:"Age"`
Admin bool `json:"Admin"`
Nested nested `json:"nested"`
}{
User: "jeff",
Groups: []string{"foo", "bar"},
FirstName: "jeff2",
Org: orgs{"engineering"},
Color: "green",
Age: 85,
Admin: true,
Nested: nested{
Division: 3,
Remote: true,
Size: []string{"medium"},
},
}
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
require.NoError(t, err)
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
requireErrorContains(t, err, "error converting claim '/nested/Size' to string from unknown type []interface {}")
})
t.Run("bad signature", func(t *testing.T) {
oa, issuer := setupForJWT(t, authType, func(c *Config) {
c.BoundAudiences = []string{
"https://go-sso.test",
"another_audience",
}
})
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: issuer,
Audience: jwt.Audience{"https://go-sso.test"},
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT(badPrivKey, cl, privateCl)
require.NoError(t, err)
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
switch authType {
case authOIDCDiscovery, authJWKS:
requireErrorContains(t, err, "failed to verify id token signature")
case authStaticKeys:
requireErrorContains(t, err, "no known key successfully validated the token signature")
default:
require.Fail(t, "unexpected type: %d", authType)
}
})
t.Run("bad issuer", func(t *testing.T) {
oa, _ := setupForJWT(t, authType, func(c *Config) {
c.BoundAudiences = []string{
"https://go-sso.test",
"another_audience",
}
})
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: "https://not.real.issuer.internal/",
Audience: jwt.Audience{"https://go-sso.test"},
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
require.NoError(t, err)
claims, err := oa.ClaimsFromJWT(context.Background(), jwtData)
switch authType {
case authOIDCDiscovery:
requireErrorContains(t, err, "error validating signature: oidc: id token issued by a different provider")
case authStaticKeys:
requireErrorContains(t, err, "validation failed, invalid issuer claim (iss)")
case authJWKS:
// requireErrorContains(t, err, "validation failed, invalid issuer claim (iss)")
// TODO(sso) The original vault test doesn't care about bound issuer.
require.NoError(t, err)
expectedClaims := &Claims{
Values: map[string]string{},
Lists: map[string][]string{
"groups": []string{"foo", "bar"},
},
}
require.Equal(t, expectedClaims, claims)
default:
require.Fail(t, "unexpected type: %d", authType)
}
})
t.Run("bad audience", func(t *testing.T) {
oa, issuer := setupForJWT(t, authType, func(c *Config) {
c.BoundAudiences = []string{
"https://go-sso.test",
"another_audience",
}
})
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: issuer,
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://fault.plugin.auth.jwt.test"},
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
require.NoError(t, err)
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
requireErrorContains(t, err, "error validating claims: aud claim does not match any bound audience")
})
}
func TestJWT_ClaimsFromJWT_ExpiryClaims(t *testing.T) {
t.Run("static", func(t *testing.T) {
t.Parallel()
testJWT_ClaimsFromJWT_ExpiryClaims(t, authStaticKeys)
})
t.Run("JWKS", func(t *testing.T) {
t.Parallel()
testJWT_ClaimsFromJWT_ExpiryClaims(t, authJWKS)
})
// TODO(sso): the vault versions of these tests did not run oidc-discovery
// t.Run("oidc discovery", func(t *testing.T) {
// t.Parallel()
// testJWT_ClaimsFromJWT_ExpiryClaims(t, authOIDCDiscovery)
// })
}
func testJWT_ClaimsFromJWT_ExpiryClaims(t *testing.T, authType int) {
t.Helper()
tests := map[string]struct {
Valid bool
IssuedAt time.Time
NotBefore time.Time
Expiration time.Time
DefaultLeeway int
ExpLeeway int
}{
// iat, auto clock_skew_leeway (60s), auto expiration leeway (150s)
"auto expire leeway using iat with auto clock_skew_leeway": {true, time.Now().Add(-205 * time.Second), time.Time{}, time.Time{}, 0, 0},
"expired auto expire leeway using iat with auto clock_skew_leeway": {false, time.Now().Add(-215 * time.Second), time.Time{}, time.Time{}, 0, 0},
// iat, clock_skew_leeway (10s), auto expiration leeway (150s)
"auto expire leeway using iat with custom clock_skew_leeway": {true, time.Now().Add(-150 * time.Second), time.Time{}, time.Time{}, 10, 0},
"expired auto expire leeway using iat with custom clock_skew_leeway": {false, time.Now().Add(-165 * time.Second), time.Time{}, time.Time{}, 10, 0},
// iat, no clock_skew_leeway (0s), auto expiration leeway (150s)
"auto expire leeway using iat with no clock_skew_leeway": {true, time.Now().Add(-145 * time.Second), time.Time{}, time.Time{}, -1, 0},
"expired auto expire leeway using iat with no clock_skew_leeway": {false, time.Now().Add(-155 * time.Second), time.Time{}, time.Time{}, -1, 0},
// nbf, auto clock_skew_leeway (60s), auto expiration leeway (150s)
"auto expire leeway using nbf with auto clock_skew_leeway": {true, time.Time{}, time.Now().Add(-205 * time.Second), time.Time{}, 0, 0},
"expired auto expire leeway using nbf with auto clock_skew_leeway": {false, time.Time{}, time.Now().Add(-215 * time.Second), time.Time{}, 0, 0},
// nbf, clock_skew_leeway (10s), auto expiration leeway (150s)
"auto expire leeway using nbf with custom clock_skew_leeway": {true, time.Time{}, time.Now().Add(-145 * time.Second), time.Time{}, 10, 0},
"expired auto expire leeway using nbf with custom clock_skew_leeway": {false, time.Time{}, time.Now().Add(-165 * time.Second), time.Time{}, 10, 0},
// nbf, no clock_skew_leeway (0s), auto expiration leeway (150s)
"auto expire leeway using nbf with no clock_skew_leeway": {true, time.Time{}, time.Now().Add(-145 * time.Second), time.Time{}, -1, 0},
"expired auto expire leeway using nbf with no clock_skew_leeway": {false, time.Time{}, time.Now().Add(-155 * time.Second), time.Time{}, -1, 0},
// iat, auto clock_skew_leeway (60s), custom expiration leeway (10s)
"custom expire leeway using iat with clock_skew_leeway": {true, time.Now().Add(-65 * time.Second), time.Time{}, time.Time{}, 0, 10},
"expired custom expire leeway using iat with clock_skew_leeway": {false, time.Now().Add(-75 * time.Second), time.Time{}, time.Time{}, 0, 10},
// iat, clock_skew_leeway (10s), custom expiration leeway (10s)
"custom expire leeway using iat with clock_skew_leeway with default leeway": {true, time.Now().Add(-5 * time.Second), time.Time{}, time.Time{}, 10, 10},
"expired custom expire leeway using iat with clock_skew_leeway with default leeway": {false, time.Now().Add(-25 * time.Second), time.Time{}, time.Time{}, 10, 10},
// iat, clock_skew_leeway (10s), no expiration leeway (10s)
"no expire leeway using iat with clock_skew_leeway": {true, time.Now().Add(-5 * time.Second), time.Time{}, time.Time{}, 10, -1},
"expired no expire leeway using iat with clock_skew_leeway": {false, time.Now().Add(-15 * time.Second), time.Time{}, time.Time{}, 10, -1},
// nbf, default clock_skew_leeway (60s), custom expiration leeway (10s)
"custom expire leeway using nbf with clock_skew_leeway": {true, time.Time{}, time.Now().Add(-65 * time.Second), time.Time{}, 0, 10},
"expired custom expire leeway using nbf with clock_skew_leeway": {false, time.Time{}, time.Now().Add(-75 * time.Second), time.Time{}, 0, 10},
// nbf, clock_skew_leeway (10s), custom expiration leeway (0s)
"custom expire leeway using nbf with clock_skew_leeway with default leeway": {true, time.Time{}, time.Now().Add(-5 * time.Second), time.Time{}, 10, 10},
"expired custom expire leeway using nbf with clock_skew_leeway with default leeway": {false, time.Time{}, time.Now().Add(-25 * time.Second), time.Time{}, 10, 10},
// nbf, clock_skew_leeway (10s), no expiration leeway (0s)
"no expire leeway using nbf with clock_skew_leeway with default leeway": {true, time.Time{}, time.Now().Add(-5 * time.Second), time.Time{}, 10, -1},
"no expire leeway using nbf with clock_skew_leeway with default leeway and nbf": {true, time.Time{}, time.Now().Add(-5 * time.Second), time.Time{}, 10, -100},
"expired no expire leeway using nbf with clock_skew_leeway": {false, time.Time{}, time.Now().Add(-15 * time.Second), time.Time{}, 10, -1},
"expired no expire leeway using nbf with clock_skew_leeway with default leeway and nbf": {false, time.Time{}, time.Now().Add(-15 * time.Second), time.Time{}, 10, -100},
}
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()
oa, issuer := setupForJWT(t, authType, func(c *Config) {
c.BoundAudiences = []string{
"https://go-sso.test",
"another_audience",
}
c.ClockSkewLeeway = time.Duration(tt.DefaultLeeway) * time.Second
c.ExpirationLeeway = time.Duration(tt.ExpLeeway) * time.Second
c.NotBeforeLeeway = 0
})
jwtData := setupLogin(t, tt.IssuedAt, tt.Expiration, tt.NotBefore, issuer)
_, err := oa.ClaimsFromJWT(context.Background(), jwtData)
if tt.Valid {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
func TestJWT_ClaimsFromJWT_NotBeforeClaims(t *testing.T) {
t.Run("static", func(t *testing.T) {
t.Parallel()
testJWT_ClaimsFromJWT_NotBeforeClaims(t, authStaticKeys)
})
t.Run("JWKS", func(t *testing.T) {
t.Parallel()
testJWT_ClaimsFromJWT_NotBeforeClaims(t, authJWKS)
})
// TODO(sso): the vault versions of these tests did not run oidc-discovery
// t.Run("oidc discovery", func(t *testing.T) {
// t.Parallel()
// testJWT_ClaimsFromJWT_NotBeforeClaims(t, authOIDCDiscovery)
// })
}
func testJWT_ClaimsFromJWT_NotBeforeClaims(t *testing.T, authType int) {
t.Helper()
tests := map[string]struct {
Valid bool
IssuedAt time.Time
NotBefore time.Time
Expiration time.Time
DefaultLeeway int
NBFLeeway int
}{
// iat, auto clock_skew_leeway (60s), no nbf leeway (0)
"no nbf leeway using iat with auto clock_skew_leeway": {true, time.Now().Add(55 * time.Second), time.Time{}, time.Now(), 0, -1},
"not yet valid no nbf leeway using iat with auto clock_skew_leeway": {false, time.Now().Add(65 * time.Second), time.Time{}, time.Now(), 0, -1},
// iat, clock_skew_leeway (10s), no nbf leeway (0s)
"no nbf leeway using iat with custom clock_skew_leeway": {true, time.Now().Add(5 * time.Second), time.Time{}, time.Time{}, 10, -1},
"not yet valid no nbf leeway using iat with custom clock_skew_leeway": {false, time.Now().Add(15 * time.Second), time.Time{}, time.Time{}, 10, -1},
// iat, no clock_skew_leeway (0s), nbf leeway (5s)
"nbf leeway using iat with no clock_skew_leeway": {true, time.Now(), time.Time{}, time.Time{}, -1, 5},
"not yet valid nbf leeway using iat with no clock_skew_leeway": {false, time.Now().Add(6 * time.Second), time.Time{}, time.Time{}, -1, 5},
// exp, auto clock_skew_leeway (60s), auto nbf leeway (150s)
"auto nbf leeway using exp with auto clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(205 * time.Second), 0, 0},
"not yet valid auto nbf leeway using exp with auto clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(215 * time.Second), 0, 0},
// exp, clock_skew_leeway (10s), auto nbf leeway (150s)
"auto nbf leeway using exp with custom clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(150 * time.Second), 10, 0},
"not yet valid auto nbf leeway using exp with custom clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(165 * time.Second), 10, 0},
// exp, no clock_skew_leeway (0s), auto nbf leeway (150s)
"auto nbf leeway using exp with no clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(145 * time.Second), -1, 0},
"not yet valid auto nbf leeway using exp with no clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(152 * time.Second), -1, 0},
// exp, auto clock_skew_leeway (60s), custom nbf leeway (10s)
"custom nbf leeway using exp with auto clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(65 * time.Second), 0, 10},
"not yet valid custom nbf leeway using exp with auto clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(75 * time.Second), 0, 10},
// exp, clock_skew_leeway (10s), custom nbf leeway (10s)
"custom nbf leeway using exp with custom clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(15 * time.Second), 10, 10},
"not yet valid custom nbf leeway using exp with custom clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(25 * time.Second), 10, 10},
// exp, no clock_skew_leeway (0s), custom nbf leeway (5s)
"custom nbf leeway using exp with no clock_skew_leeway": {true, time.Time{}, time.Time{}, time.Now().Add(3 * time.Second), -1, 5},
"custom nbf leeway using exp with no clock_skew_leeway with default leeway": {true, time.Time{}, time.Time{}, time.Now().Add(3 * time.Second), -100, 5},
"not yet valid custom nbf leeway using exp with no clock_skew_leeway": {false, time.Time{}, time.Time{}, time.Now().Add(7 * time.Second), -1, 5},
"not yet valid custom nbf leeway using exp with no clock_skew_leeway with default leeway": {false, time.Time{}, time.Time{}, time.Now().Add(7 * time.Second), -100, 5},
}
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()
oa, issuer := setupForJWT(t, authType, func(c *Config) {
c.BoundAudiences = []string{
"https://go-sso.test",
"another_audience",
}
c.ClockSkewLeeway = time.Duration(tt.DefaultLeeway) * time.Second
c.ExpirationLeeway = 0
c.NotBeforeLeeway = time.Duration(tt.NBFLeeway) * time.Second
})
jwtData := setupLogin(t, tt.IssuedAt, tt.Expiration, tt.NotBefore, issuer)
_, err := oa.ClaimsFromJWT(context.Background(), jwtData)
if tt.Valid {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
func setupLogin(t *testing.T, iat, exp, nbf time.Time, issuer string) string {
cl := jwt.Claims{
Audience: jwt.Audience{"https://go-sso.test"},
Issuer: issuer,
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
IssuedAt: jwt.NewNumericDate(iat),
Expiry: jwt.NewNumericDate(exp),
NotBefore: jwt.NewNumericDate(nbf),
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
Color string `json:"color"`
}{
"foobar",
[]string{"foo", "bar"},
"green",
}
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
require.NoError(t, err)
return jwtData
}
func TestParsePublicKeyPEM(t *testing.T) {
getPublicPEM := func(t *testing.T, pub interface{}) string {
derBytes, err := x509.MarshalPKIXPublicKey(pub)
require.NoError(t, err)
pemBlock := &pem.Block{
Type: "PUBLIC KEY",
Bytes: derBytes,
}
return string(pem.EncodeToMemory(pemBlock))
}
t.Run("rsa", func(t *testing.T) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
pub := privateKey.Public()
pubPEM := getPublicPEM(t, pub)
got, err := parsePublicKeyPEM([]byte(pubPEM))
require.NoError(t, err)
require.Equal(t, pub, got)
})
t.Run("ecdsa", func(t *testing.T) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pub := privateKey.Public()
pubPEM := getPublicPEM(t, pub)
got, err := parsePublicKeyPEM([]byte(pubPEM))
require.NoError(t, err)
require.Equal(t, pub, got)
})
t.Run("ed25519", func(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubPEM := getPublicPEM(t, pub)
got, err := parsePublicKeyPEM([]byte(pubPEM))
require.NoError(t, err)
require.Equal(t, pub, got)
})
}
const (
badPrivKey string = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILTAHJm+clBKYCrRDc74Pt7uF7kH+2x2TdL5cH23FEcsoAoGCCqGSM49
AwEHoUQDQgAE+C3CyjVWdeYtIqgluFJlwZmoonphsQbj9Nfo5wrEutv+3RTFnDQh
vttUajcFAcl4beR+jHFYC00vSO4i5jZ64g==
-----END EC PRIVATE KEY-----`
)

View file

@ -0,0 +1,282 @@
package oidcauth
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-uuid"
"golang.org/x/oauth2"
)
var (
oidcStateTimeout = 10 * time.Minute
oidcStateCleanupInterval = 1 * time.Minute
)
// GetAuthCodeURL is the first part of the OIDC authorization code workflow.
// The statePayload field is stored in the Authenticator instance keyed by the
// "state" key so it can be returned during a future call to
// ClaimsFromAuthCode.
//
// Requires the authenticator's config type be set to 'oidc'.
func (a *Authenticator) GetAuthCodeURL(ctx context.Context, redirectURI string, statePayload interface{}) (string, error) {
if a.config.authType() != authOIDCFlow {
return "", fmt.Errorf("GetAuthCodeURL is incompatible with type %q", TypeJWT)
}
if redirectURI == "" {
return "", errors.New("missing redirect_uri")
}
if !validRedirect(redirectURI, a.config.AllowedRedirectURIs) {
return "", fmt.Errorf("unauthorized redirect_uri: %s", redirectURI)
}
// "openid" is a required scope for OpenID Connect flows
scopes := append([]string{oidc.ScopeOpenID}, a.config.OIDCScopes...)
// Configure an OpenID Connect aware OAuth2 client
oauth2Config := oauth2.Config{
ClientID: a.config.OIDCClientID,
ClientSecret: a.config.OIDCClientSecret,
RedirectURL: redirectURI,
Endpoint: a.provider.Endpoint(),
Scopes: scopes,
}
stateID, nonce, err := a.createOIDCState(redirectURI, statePayload)
if err != nil {
return "", fmt.Errorf("error generating OAuth state: %v", err)
}
authCodeOpts := []oauth2.AuthCodeOption{
oidc.Nonce(nonce),
}
return oauth2Config.AuthCodeURL(stateID, authCodeOpts...), nil
}
// ClaimsFromAuthCode is the second part of the OIDC authorization code
// workflow. The interface{} return value is the statePayload previously passed
// via GetAuthCodeURL.
//
// The error may be of type *ProviderLoginFailedError or
// *TokenVerificationFailedError which can be detected via errors.As().
//
// Requires the authenticator's config type be set to 'oidc'.
func (a *Authenticator) ClaimsFromAuthCode(ctx context.Context, stateParam, code string) (*Claims, interface{}, error) {
if a.config.authType() != authOIDCFlow {
return nil, nil, fmt.Errorf("ClaimsFromAuthCode is incompatible with type %q", TypeJWT)
}
// TODO(sso): this could be because we ACTUALLY are getting OIDC error responses and
// should handle them elsewhere!
if code == "" {
return nil, nil, &ProviderLoginFailedError{
Err: fmt.Errorf("OAuth code parameter not provided"),
}
}
state := a.verifyOIDCState(stateParam)
if state == nil {
return nil, nil, &ProviderLoginFailedError{
Err: fmt.Errorf("Expired or missing OAuth state."),
}
}
oidcCtx := contextWithHttpClient(ctx, a.httpClient)
var oauth2Config = oauth2.Config{
ClientID: a.config.OIDCClientID,
ClientSecret: a.config.OIDCClientSecret,
RedirectURL: state.redirectURI,
Endpoint: a.provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID},
}
oauth2Token, err := oauth2Config.Exchange(oidcCtx, code)
if err != nil {
return nil, nil, &ProviderLoginFailedError{
Err: fmt.Errorf("Error exchanging oidc code: %w", err),
}
}
// Extract the ID Token from OAuth2 token.
rawToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, nil, &TokenVerificationFailedError{
Err: errors.New("No id_token found in response."),
}
}
if a.config.VerboseOIDCLogging && a.logger != nil {
a.logger.Debug("OIDC provider response", "ID token", rawToken)
}
// Parse and verify ID Token payload.
allClaims, err := a.verifyOIDCToken(ctx, rawToken) // TODO(sso): should this use oidcCtx?
if err != nil {
return nil, nil, &TokenVerificationFailedError{
Err: err,
}
}
if allClaims["nonce"] != state.nonce { // TODO(sso): does this need a cast?
return nil, nil, &TokenVerificationFailedError{
Err: errors.New("Invalid ID token nonce."),
}
}
delete(allClaims, "nonce")
// Attempt to fetch information from the /userinfo endpoint and merge it with
// the existing claims data. A failure to fetch additional information from this
// endpoint will not invalidate the authorization flow.
if userinfo, err := a.provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil {
_ = userinfo.Claims(&allClaims)
} else {
if a.logger != nil {
logFunc := a.logger.Warn
if strings.Contains(err.Error(), "user info endpoint is not supported") {
logFunc = a.logger.Info
}
logFunc("error reading /userinfo endpoint", "error", err)
}
}
if a.config.VerboseOIDCLogging && a.logger != nil {
if c, err := json.Marshal(allClaims); err == nil {
a.logger.Debug("OIDC provider response", "claims", string(c))
} else {
a.logger.Debug("OIDC provider response", "marshalling error", err.Error())
}
}
c, err := a.extractClaims(allClaims)
if err != nil {
return nil, nil, &TokenVerificationFailedError{
Err: err,
}
}
if a.config.VerboseOIDCLogging && a.logger != nil {
a.logger.Debug("OIDC provider response", "extracted_claims", c)
}
return c, state.payload, nil
}
// ProviderLoginFailedError is an error type sometimes returned from
// ClaimsFromAuthCode().
//
// It represents a failure to complete the authorization code workflow with the
// provider such as losing important OIDC parameters or a failure to fetch an
// id_token.
//
// You can check for it with errors.As().
type ProviderLoginFailedError struct {
Err error
}
func (e *ProviderLoginFailedError) Error() string {
return fmt.Sprintf("Provider login failed: %v", e.Err)
}
func (e *ProviderLoginFailedError) Unwrap() error { return e.Err }
// TokenVerificationFailedError is an error type sometimes returned from
// ClaimsFromAuthCode().
//
// It represents a failure to vet the returned OIDC credentials for validity
// such as the id_token not passing verification or using an mismatched nonce.
//
// You can check for it with errors.As().
type TokenVerificationFailedError struct {
Err error
}
func (e *TokenVerificationFailedError) Error() string {
return fmt.Sprintf("Token verification failed: %v", e.Err)
}
func (e *TokenVerificationFailedError) Unwrap() error { return e.Err }
func (a *Authenticator) verifyOIDCToken(ctx context.Context, rawToken string) (map[string]interface{}, error) {
allClaims := make(map[string]interface{})
oidcConfig := &oidc.Config{
SupportedSigningAlgs: a.config.JWTSupportedAlgs,
}
switch a.config.authType() {
case authOIDCFlow:
oidcConfig.ClientID = a.config.OIDCClientID
case authOIDCDiscovery:
oidcConfig.SkipClientIDCheck = true
default:
return nil, fmt.Errorf("unsupported auth type for this verifyOIDCToken: %d", a.config.authType())
}
verifier := a.provider.Verifier(oidcConfig)
idToken, err := verifier.Verify(ctx, rawToken)
if err != nil {
return nil, fmt.Errorf("error validating signature: %v", err)
}
if err := idToken.Claims(&allClaims); err != nil {
return nil, fmt.Errorf("unable to successfully parse all claims from token: %v", err)
}
// TODO(sso): why isn't this strict for OIDC?
if err := validateAudience(a.config.BoundAudiences, idToken.Audience, false); err != nil {
return nil, fmt.Errorf("error validating claims: %v", err)
}
return allClaims, nil
}
// verifyOIDCState tests whether the provided state ID is valid and returns the
// associated state object if so. A nil state is returned if the ID is not found
// or expired. The state should only ever be retrieved once and is deleted as
// part of this request.
func (a *Authenticator) verifyOIDCState(stateID string) *oidcState {
defer a.oidcStates.Delete(stateID)
if stateRaw, ok := a.oidcStates.Get(stateID); ok {
return stateRaw.(*oidcState)
}
return nil
}
// createOIDCState make an expiring state object, associated with a random state ID
// that is passed throughout the OAuth process. A nonce is also included in the
// auth process, and for simplicity will be identical in length/format as the state ID.
func (a *Authenticator) createOIDCState(redirectURI string, payload interface{}) (string, string, error) {
// Get enough bytes for 2 160-bit IDs (per rfc6749#section-10.10)
bytes, err := uuid.GenerateRandomBytes(2 * 20)
if err != nil {
return "", "", err
}
stateID := fmt.Sprintf("%x", bytes[:20])
nonce := fmt.Sprintf("%x", bytes[20:])
a.oidcStates.SetDefault(stateID, &oidcState{
nonce: nonce,
redirectURI: redirectURI,
payload: payload,
})
return stateID, nonce, nil
}
// oidcState is created when an authURL is requested. The state
// identifier is passed throughout the OAuth process.
type oidcState struct {
nonce string
redirectURI string
payload interface{}
}

View file

@ -0,0 +1,507 @@
package oidcauth
import (
"context"
"errors"
"net/url"
"strings"
"testing"
"time"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2/jwt"
)
func setupForOIDC(t *testing.T) (*Authenticator, *oidcauthtest.Server) {
t.Helper()
srv := oidcauthtest.Start(t)
srv.SetClientCreds("abc", "def")
config := &Config{
Type: TypeOIDC,
OIDCDiscoveryURL: srv.Addr(),
OIDCDiscoveryCACert: srv.CACert(),
OIDCClientID: "abc",
OIDCClientSecret: "def",
JWTSupportedAlgs: []string{"ES256"},
BoundAudiences: []string{"abc"},
AllowedRedirectURIs: []string{"https://example.com"},
ClaimMappings: map[string]string{
"COLOR": "color",
"/nested/Size": "size",
"Age": "age",
"Admin": "is_admin",
"/nested/division": "division",
"/nested/remote": "is_remote",
"flavor": "flavor", // userinfo
},
ListClaimMappings: map[string]string{
"/nested/Groups": "groups",
},
}
require.NoError(t, config.Validate())
oa, err := New(config, hclog.NewNullLogger())
require.NoError(t, err)
t.Cleanup(oa.Stop)
return oa, srv
}
func TestOIDC_AuthURL(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
t.Parallel()
oa, _ := setupForOIDC(t)
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
map[string]string{"foo": "bar"},
)
require.NoError(t, err)
require.True(t, strings.HasPrefix(authURL, oa.config.OIDCDiscoveryURL+"/auth?"))
expected := map[string]string{
"client_id": "abc",
"redirect_uri": "https://example.com",
"response_type": "code",
"scope": "openid",
}
au, err := url.Parse(authURL)
require.NoError(t, err)
for k, v := range expected {
assert.Equal(t, v, au.Query().Get(k), "key %q is incorrect", k)
}
assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("nonce"))
assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("state"))
})
t.Run("invalid RedirectURI", func(t *testing.T) {
t.Parallel()
oa, _ := setupForOIDC(t)
_, err := oa.GetAuthCodeURL(
context.Background(),
"http://bitc0in-4-less.cx",
map[string]string{"foo": "bar"},
)
requireErrorContains(t, err, "unauthorized redirect_uri: http://bitc0in-4-less.cx")
})
t.Run("missing RedirectURI", func(t *testing.T) {
t.Parallel()
oa, _ := setupForOIDC(t)
_, err := oa.GetAuthCodeURL(
context.Background(),
"",
map[string]string{"foo": "bar"},
)
requireErrorContains(t, err, "missing redirect_uri")
})
}
func TestOIDC_JWT_Functions_Fail(t *testing.T) {
oa, srv := setupForOIDC(t)
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: srv.Addr(),
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://go-sso.test"},
}
privateCl := struct {
User string `json:"https://go-sso/user"`
Groups []string `json:"https://go-sso/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}
jwtData, err := oidcauthtest.SignJWT("", cl, privateCl)
require.NoError(t, err)
_, err = oa.ClaimsFromJWT(context.Background(), jwtData)
requireErrorContains(t, err, `ClaimsFromJWT is incompatible with type "oidc"`)
}
func TestOIDC_ClaimsFromAuthCode(t *testing.T) {
requireProviderError := func(t *testing.T, err error) {
var provErr *ProviderLoginFailedError
if !errors.As(err, &provErr) {
t.Fatalf("error was not a *ProviderLoginFailedError")
}
}
requireTokenVerificationError := func(t *testing.T, err error) {
var tokErr *TokenVerificationFailedError
if !errors.As(err, &tokErr) {
t.Fatalf("error was not a *TokenVerificationFailedError")
}
}
t.Run("successful login", func(t *testing.T) {
oa, srv := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
nonce := getQueryParam(t, authURL, "nonce")
// set provider claims that will be returned by the mock server
srv.SetCustomClaims(sampleClaims(nonce))
// set mock provider's expected code
srv.SetExpectedAuthCode("abc")
claims, payload, err := oa.ClaimsFromAuthCode(
context.Background(),
state, "abc",
)
require.NoError(t, err)
require.Equal(t, origPayload, payload)
expectedClaims := &Claims{
Values: map[string]string{
"color": "green",
"size": "medium",
"age": "85",
"is_admin": "true",
"division": "3",
"is_remote": "true",
"flavor": "umami", // from userinfo
},
Lists: map[string][]string{
"groups": []string{"a", "b"},
},
}
require.Equal(t, expectedClaims, claims)
})
t.Run("failed login unusable claims", func(t *testing.T) {
oa, srv := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
nonce := getQueryParam(t, authURL, "nonce")
// set provider claims that will be returned by the mock server
customClaims := sampleClaims(nonce)
customClaims["COLOR"] = []interface{}{"yellow"}
srv.SetCustomClaims(customClaims)
// set mock provider's expected code
srv.SetExpectedAuthCode("abc")
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
state, "abc",
)
requireErrorContains(t, err, "error converting claim 'COLOR' to string from unknown type []interface {}")
requireTokenVerificationError(t, err)
})
t.Run("successful login - no userinfo", func(t *testing.T) {
oa, srv := setupForOIDC(t)
srv.DisableUserInfo()
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
nonce := getQueryParam(t, authURL, "nonce")
// set provider claims that will be returned by the mock server
srv.SetCustomClaims(sampleClaims(nonce))
// set mock provider's expected code
srv.SetExpectedAuthCode("abc")
claims, payload, err := oa.ClaimsFromAuthCode(
context.Background(),
state, "abc",
)
require.NoError(t, err)
require.Equal(t, origPayload, payload)
expectedClaims := &Claims{
Values: map[string]string{
"color": "green",
"size": "medium",
"age": "85",
"is_admin": "true",
"division": "3",
"is_remote": "true",
// "flavor": "umami", // from userinfo
},
Lists: map[string][]string{
"groups": []string{"a", "b"},
},
}
require.Equal(t, expectedClaims, claims)
})
t.Run("failed login - bad nonce", func(t *testing.T) {
t.Parallel()
oa, srv := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
srv.SetCustomClaims(sampleClaims("bad nonce"))
// set mock provider's expected code
srv.SetExpectedAuthCode("abc")
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
state, "abc",
)
requireErrorContains(t, err, "Invalid ID token nonce")
requireTokenVerificationError(t, err)
})
t.Run("missing state", func(t *testing.T) {
oa, _ := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
_, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
"", "abc",
)
requireErrorContains(t, err, "Expired or missing OAuth state")
requireProviderError(t, err)
})
t.Run("unknown state", func(t *testing.T) {
oa, _ := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
_, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
"not_a_state", "abc",
)
requireErrorContains(t, err, "Expired or missing OAuth state")
requireProviderError(t, err)
})
t.Run("valid state, missing code", func(t *testing.T) {
oa, _ := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
state, "",
)
requireErrorContains(t, err, "OAuth code parameter not provided")
requireProviderError(t, err)
})
t.Run("failed code exchange", func(t *testing.T) {
oa, srv := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
// set mock provider's expected code
srv.SetExpectedAuthCode("abc")
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
state, "wrong_code",
)
requireErrorContains(t, err, "cannot fetch token")
requireProviderError(t, err)
})
t.Run("no id_token returned", func(t *testing.T) {
oa, srv := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
nonce := getQueryParam(t, authURL, "nonce")
// set provider claims that will be returned by the mock server
srv.SetCustomClaims(sampleClaims(nonce))
// set mock provider's expected code
srv.SetExpectedAuthCode("abc")
srv.OmitIDTokens()
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
state, "abc",
)
requireErrorContains(t, err, "No id_token found in response")
requireTokenVerificationError(t, err)
})
t.Run("no response from provider", func(t *testing.T) {
oa, srv := setupForOIDC(t)
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
// close the server prematurely
srv.Stop()
srv.SetExpectedAuthCode("abc")
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
state, "abc",
)
requireErrorContains(t, err, "connection refused")
requireProviderError(t, err)
})
t.Run("invalid bound audience", func(t *testing.T) {
oa, srv := setupForOIDC(t)
srv.SetClientCreds("not_gonna_match", "def")
origPayload := map[string]string{"foo": "bar"}
authURL, err := oa.GetAuthCodeURL(
context.Background(),
"https://example.com",
origPayload,
)
require.NoError(t, err)
state := getQueryParam(t, authURL, "state")
nonce := getQueryParam(t, authURL, "nonce")
// set provider claims that will be returned by the mock server
srv.SetCustomClaims(sampleClaims(nonce))
// set mock provider's expected code
srv.SetExpectedAuthCode("abc")
_, _, err = oa.ClaimsFromAuthCode(
context.Background(),
state, "abc",
)
requireErrorContains(t, err, `error validating signature: oidc: expected audience "abc" got ["not_gonna_match"]`)
requireTokenVerificationError(t, err)
})
}
func sampleClaims(nonce string) map[string]interface{} {
return map[string]interface{}{
"nonce": nonce,
"email": "bob@example.com",
"COLOR": "green",
"sk": "42",
"Age": 85,
"Admin": true,
"nested": map[string]interface{}{
"Size": "medium",
"division": 3,
"remote": true,
"Groups": []string{"a", "b"},
"secret_code": "bar",
},
"password": "foo",
}
}
func getQueryParam(t *testing.T, inputURL, param string) string {
t.Helper()
m, err := url.ParseQuery(inputURL)
if err != nil {
t.Fatal(err)
}
v, ok := m[param]
if !ok {
t.Fatalf("query param %q not found", param)
}
return v[0]
}

View file

@ -0,0 +1,529 @@
// package oidcauthtest exposes tools to assist in writing unit tests of OIDC
// and JWT authentication workflows.
//
// When the package is loaded it will randomly generate an ECDSA signing
// keypair used to sign JWTs both via the Server and the SignJWT method.
package oidcauthtest
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"sync"
"time"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)
// Server is local server the mocks the endpoints used by the OIDC and
// JWKS process.
type Server struct {
httpServer *httptest.Server
caCert string
returnFunc func()
jwks *jose.JSONWebKeySet
allowedRedirectURIs []string
replySubject string
replyUserinfo map[string]interface{}
mu sync.Mutex
clientID string
clientSecret string
expectedAuthCode string
expectedAuthNonce string
customClaims map[string]interface{}
customAudience string
omitIDToken bool
disableUserInfo bool
}
type startOption struct {
port int
returnFunc func()
}
// WithPort is a option for Start that lets the caller control the port
// allocation. The returnFunc parameter is used when the provider is stopped to
// return the port in whatever bookkeeping system the caller wants to use.
func WithPort(port int, returnFunc func()) startOption {
return startOption{
port: port,
returnFunc: returnFunc,
}
}
// Start creates a disposable Server. If the port provided is
// zero it will bind to a random free port, otherwise the provided port is
// used.
func Start(t testing.T, options ...startOption) *Server {
s := &Server{
allowedRedirectURIs: []string{
"https://example.com",
},
replySubject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
replyUserinfo: map[string]interface{}{
"color": "red",
"temperature": "76",
"flavor": "umami",
},
}
jwks, err := newJWKS(ecdsaPublicKey)
require.NoError(t, err)
s.jwks = jwks
var (
port int
returnFunc func()
)
for _, option := range options {
if option.port > 0 {
port = option.port
returnFunc = option.returnFunc
}
}
s.httpServer = httptestNewUnstartedServerWithPort(s, port)
s.httpServer.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
s.httpServer.StartTLS()
if returnFunc != nil {
t.Cleanup(returnFunc)
}
t.Cleanup(s.httpServer.Close)
cert := s.httpServer.Certificate()
var buf bytes.Buffer
require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
s.caCert = buf.String()
return s
}
// SetClientCreds is for configuring the client information required for the
// OIDC workflows.
func (s *Server) SetClientCreds(clientID, clientSecret string) {
s.mu.Lock()
defer s.mu.Unlock()
s.clientID = clientID
s.clientSecret = clientSecret
}
// SetExpectedAuthCode configures the auth code to return from /auth and the
// allowed auth code for /token.
func (s *Server) SetExpectedAuthCode(code string) {
s.mu.Lock()
defer s.mu.Unlock()
s.expectedAuthCode = code
}
// SetExpectedAuthNonce configures the nonce value required for /auth.
func (s *Server) SetExpectedAuthNonce(nonce string) {
s.mu.Lock()
defer s.mu.Unlock()
s.expectedAuthNonce = nonce
}
// SetAllowedRedirectURIs allows you to configure the allowed redirect URIs for
// the OIDC workflow. If not configured a sample of "https://example.com" is
// used.
func (s *Server) SetAllowedRedirectURIs(uris []string) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowedRedirectURIs = uris
}
// SetCustomClaims lets you set claims to return in the JWT issued by the OIDC
// workflow.
func (s *Server) SetCustomClaims(customClaims map[string]interface{}) {
s.mu.Lock()
defer s.mu.Unlock()
s.customClaims = customClaims
}
// SetCustomAudience configures what audience value to embed in the JWT issued
// by the OIDC workflow.
func (s *Server) SetCustomAudience(customAudience string) {
s.mu.Lock()
defer s.mu.Unlock()
s.customAudience = customAudience
}
// OmitIDTokens forces an error state where the /token endpoint does not return
// id_token.
func (s *Server) OmitIDTokens() {
s.mu.Lock()
defer s.mu.Unlock()
s.omitIDToken = true
}
// DisableUserInfo makes the userinfo endpoint return 404 and omits it from the
// discovery config.
func (s *Server) DisableUserInfo() {
s.mu.Lock()
defer s.mu.Unlock()
s.disableUserInfo = true
}
// Stop stops the running Server.
func (s *Server) Stop() {
s.httpServer.Close()
}
// Addr returns the current base URL for the running webserver.
func (s *Server) Addr() string { return s.httpServer.URL }
// CACert returns the pem-encoded CA certificate used by the HTTPS server.
func (s *Server) CACert() string { return s.caCert }
// SigningKeys returns the pem-encoded keys used to sign JWTs.
func (s *Server) SigningKeys() (pub, priv string) {
return SigningKeys()
}
// ServeHTTP implements http.Handler.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.mu.Lock()
defer s.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
switch req.URL.Path {
case "/.well-known/openid-configuration":
if req.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
reply := struct {
Issuer string `json:"issuer"`
AuthEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
JWKSURI string `json:"jwks_uri"`
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
}{
Issuer: s.Addr(),
AuthEndpoint: s.Addr() + "/auth",
TokenEndpoint: s.Addr() + "/token",
JWKSURI: s.Addr() + "/certs",
UserinfoEndpoint: s.Addr() + "/userinfo",
}
if s.disableUserInfo {
reply.UserinfoEndpoint = ""
}
if err := writeJSON(w, &reply); err != nil {
return
}
case "/auth":
if req.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
qv := req.URL.Query()
if qv.Get("response_type") != "code" {
writeAuthErrorResponse(w, req, "unsupported_response_type", "")
return
}
if qv.Get("scope") != "openid" {
writeAuthErrorResponse(w, req, "invalid_scope", "")
return
}
if s.expectedAuthCode == "" {
writeAuthErrorResponse(w, req, "access_denied", "")
return
}
nonce := qv.Get("nonce")
if s.expectedAuthNonce != "" && s.expectedAuthNonce != nonce {
writeAuthErrorResponse(w, req, "access_denied", "")
return
}
state := qv.Get("state")
if state == "" {
writeAuthErrorResponse(w, req, "invalid_request", "missing state parameter")
return
}
redirectURI := qv.Get("redirect_uri")
if redirectURI == "" {
writeAuthErrorResponse(w, req, "invalid_request", "missing redirect_uri parameter")
return
}
redirectURI += "?state=" + url.QueryEscape(state) +
"&code=" + url.QueryEscape(s.expectedAuthCode)
http.Redirect(w, req, redirectURI, http.StatusFound)
return
case "/certs":
if req.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if err := writeJSON(w, s.jwks); err != nil {
return
}
case "/certs_missing":
w.WriteHeader(http.StatusNotFound)
case "/certs_invalid":
w.Write([]byte("It's not a keyset!"))
case "/token":
if req.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
switch {
case req.FormValue("grant_type") != "authorization_code":
_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "bad grant_type")
return
case !strutil.StrListContains(s.allowedRedirectURIs, req.FormValue("redirect_uri")):
_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "redirect_uri is not allowed")
return
case req.FormValue("code") != s.expectedAuthCode:
_ = writeTokenErrorResponse(w, req, http.StatusUnauthorized, "invalid_grant", "unexpected auth code")
return
}
stdClaims := jwt.Claims{
Subject: s.replySubject,
Issuer: s.Addr(),
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
Audience: jwt.Audience{s.clientID},
}
if s.customAudience != "" {
stdClaims.Audience = jwt.Audience{s.customAudience}
}
jwtData, err := SignJWT("", stdClaims, s.customClaims)
if err != nil {
_ = writeTokenErrorResponse(w, req, http.StatusInternalServerError, "server_error", err.Error())
return
}
reply := struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token,omitempty"`
}{
AccessToken: jwtData,
IDToken: jwtData,
}
if s.omitIDToken {
reply.IDToken = ""
}
if err := writeJSON(w, &reply); err != nil {
return
}
case "/userinfo":
if s.disableUserInfo {
w.WriteHeader(http.StatusNotFound)
return
}
if req.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if err := writeJSON(w, s.replyUserinfo); err != nil {
return
}
default:
w.WriteHeader(http.StatusNotFound)
}
}
func writeAuthErrorResponse(w http.ResponseWriter, req *http.Request, errorCode, errorMessage string) {
qv := req.URL.Query()
redirectURI := qv.Get("redirect_uri") +
"?state=" + url.QueryEscape(qv.Get("state")) +
"&error=" + url.QueryEscape(errorCode)
if errorMessage != "" {
redirectURI += "&error_description=" + url.QueryEscape(errorMessage)
}
http.Redirect(w, req, redirectURI, http.StatusFound)
}
func writeTokenErrorResponse(w http.ResponseWriter, req *http.Request, statusCode int, errorCode, errorMessage string) error {
body := struct {
Code string `json:"error"`
Desc string `json:"error_description,omitempty"`
}{
Code: errorCode,
Desc: errorMessage,
}
w.WriteHeader(statusCode)
return writeJSON(w, &body)
}
// newJWKS converts a pem-encoded public key into JWKS data suitable for a
// verification endpoint response
func newJWKS(pubKey string) (*jose.JSONWebKeySet, error) {
block, _ := pem.Decode([]byte(pubKey))
if block == nil {
return nil, fmt.Errorf("unable to decode public key")
}
input := block.Bytes
pub, err := x509.ParsePKIXPublicKey(input)
if err != nil {
return nil, err
}
return &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
jose.JSONWebKey{
Key: pub,
},
},
}, nil
}
func writeJSON(w http.ResponseWriter, out interface{}) error {
enc := json.NewEncoder(w)
return enc.Encode(out)
}
// SignJWT will bundle the provided claims into a signed JWT. The provided key
// is assumed to be ECDSA.
//
// If no private key is provided, the default package keys are used. These can
// be retrieved via the SigningKeys() method.
func SignJWT(privKey string, claims jwt.Claims, privateClaims interface{}) (string, error) {
if privKey == "" {
privKey = ecdsaPrivateKey
}
var key *ecdsa.PrivateKey
block, _ := pem.Decode([]byte(privKey))
if block != nil {
var err error
key, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return "", err
}
}
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: key},
(&jose.SignerOptions{}).WithType("JWT"),
)
if err != nil {
return "", err
}
raw, err := jwt.Signed(sig).
Claims(claims).
Claims(privateClaims).
CompactSerialize()
if err != nil {
return "", err
}
return raw, nil
}
// httptestNewUnstartedServerWithPort is roughly the same as
// httptest.NewUnstartedServer() but allows the caller to explicitly choose the
// port if desired.
func httptestNewUnstartedServerWithPort(handler http.Handler, port int) *httptest.Server {
if port == 0 {
return httptest.NewUnstartedServer(handler)
}
addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
l, err := net.Listen("tcp", addr)
if err != nil {
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
}
return &httptest.Server{
Listener: l,
Config: &http.Server{Handler: handler},
}
}
// SigningKeys returns the pem-encoded keys used to sign JWTs by default.
func SigningKeys() (pub, priv string) {
return ecdsaPublicKey, ecdsaPrivateKey
}
var (
ecdsaPublicKey string
ecdsaPrivateKey string
)
func init() {
// Each time we run tests we generate a unique set of keys for use in the
// test. These are cached between runs but do not persist between restarts
// of the test binary.
var err error
ecdsaPublicKey, ecdsaPrivateKey, err = generateKey()
if err != nil {
panic(err)
}
}
func generateKey() (pub, priv string, err error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return "", "", fmt.Errorf("error generating private key: %v", err)
}
{
derBytes, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return "", "", fmt.Errorf("error marshaling private key: %v", err)
}
pemBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: derBytes,
}
priv = string(pem.EncodeToMemory(pemBlock))
}
{
derBytes, err := x509.MarshalPKIXPublicKey(privateKey.Public())
if err != nil {
return "", "", fmt.Errorf("error marshaling public key: %v", err)
}
pemBlock := &pem.Block{
Type: "PUBLIC KEY",
Bytes: derBytes,
}
pub = string(pem.EncodeToMemory(pemBlock))
}
return pub, priv, nil
}

View file

@ -0,0 +1,252 @@
package oidcauth
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/mitchellh/pointerstructure"
"golang.org/x/oauth2"
)
func contextWithHttpClient(ctx context.Context, client *http.Client) context.Context {
return context.WithValue(ctx, oauth2.HTTPClient, client)
}
func createHTTPClient(caCert string) (*http.Client, error) {
tr := cleanhttp.DefaultPooledTransport()
if caCert != "" {
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(caCert)); !ok {
return nil, errors.New("could not parse CA PEM value successfully")
}
tr.TLSClientConfig = &tls.Config{
RootCAs: certPool,
}
}
return &http.Client{
Transport: tr,
}, nil
}
// extractClaims extracts all configured claims from the received claims.
func (a *Authenticator) extractClaims(allClaims map[string]interface{}) (*Claims, error) {
metadata, err := extractStringMetadata(a.logger, allClaims, a.config.ClaimMappings)
if err != nil {
return nil, err
}
listMetadata, err := extractListMetadata(a.logger, allClaims, a.config.ListClaimMappings)
if err != nil {
return nil, err
}
return &Claims{
Values: metadata,
Lists: listMetadata,
}, nil
}
// extractStringMetadata builds a metadata map of string values from a set of
// claims and claims mappings. The referenced claims must be strings and the
// claims mappings must be of the structure:
//
// {
// "/some/claim/pointer": "metadata_key1",
// "another_claim": "metadata_key2",
// ...
// }
func extractStringMetadata(logger hclog.Logger, allClaims map[string]interface{}, claimMappings map[string]string) (map[string]string, error) {
metadata := make(map[string]string)
for source, target := range claimMappings {
rawValue := getClaim(logger, allClaims, source)
if rawValue == nil {
continue
}
strValue, ok := stringifyMetadataValue(rawValue)
if !ok {
return nil, fmt.Errorf("error converting claim '%s' to string from unknown type %T", source, rawValue)
}
metadata[target] = strValue
}
return metadata, nil
}
// extractListMetadata builds a metadata map of string list values from a set
// of claims and claims mappings. The referenced claims must be strings and
// the claims mappings must be of the structure:
//
// {
// "/some/claim/pointer": "metadata_key1",
// "another_claim": "metadata_key2",
// ...
// }
func extractListMetadata(logger hclog.Logger, allClaims map[string]interface{}, listClaimMappings map[string]string) (map[string][]string, error) {
out := make(map[string][]string)
for source, target := range listClaimMappings {
if rawValue := getClaim(logger, allClaims, source); rawValue != nil {
rawList, ok := normalizeList(rawValue)
if !ok {
return nil, fmt.Errorf("%q list claim could not be converted to string list", source)
}
list := make([]string, 0, len(rawList))
for _, raw := range rawList {
value, ok := stringifyMetadataValue(raw)
if !ok {
return nil, fmt.Errorf("value %v in %q list claim could not be parsed as string", raw, source)
}
if value == "" {
continue
}
list = append(list, value)
}
out[target] = list
}
}
return out, nil
}
// getClaim returns a claim value from allClaims given a provided claim string.
// If this string is a valid JSONPointer, it will be interpreted as such to
// locate the claim. Otherwise, the claim string will be used directly.
//
// There is no fixup done to the returned data type here. That happens a layer
// up in the caller.
func getClaim(logger hclog.Logger, allClaims map[string]interface{}, claim string) interface{} {
if !strings.HasPrefix(claim, "/") {
return allClaims[claim]
}
val, err := pointerstructure.Get(allClaims, claim)
if err != nil {
if logger != nil {
logger.Warn("unable to locate claim", "claim", claim, "error", err)
}
return nil
}
return val
}
// normalizeList takes an item or a slice and returns a slice. This is useful
// when providers are expected to return a list (typically of strings) but
// reduce it to a non-slice type when the list count is 1.
//
// There is no fixup done to elements of the returned slice here. That happens
// a layer up in the caller.
func normalizeList(raw interface{}) ([]interface{}, bool) {
switch v := raw.(type) {
case []interface{}:
return v, true
case string, // note: this list should be the same as stringifyMetadataValue
bool,
json.Number,
float64,
float32,
int8,
int16,
int32,
int64,
int,
uint8,
uint16,
uint32,
uint64,
uint:
return []interface{}{v}, true
default:
return nil, false
}
}
// stringifyMetadataValue will try to convert the provided raw value into a
// faithful string representation of that value per these rules:
//
// - strings => unchanged
// - bool => "true" / "false"
// - json.Number => String()
// - float32/64 => truncated to int64 and then formatted as an ascii string
// - intXX/uintXX => casted to int64 and then formatted as an ascii string
//
// If successful the string value and true are returned. otherwise an empty
// string and false are returned.
func stringifyMetadataValue(rawValue interface{}) (string, bool) {
switch v := rawValue.(type) {
case string:
return v, true
case bool:
return strconv.FormatBool(v), true
case json.Number:
return v.String(), true
case float64:
// The claims unmarshalled by go-oidc don't use UseNumber, so
// they'll come in as float64 instead of an integer or json.Number.
return strconv.FormatInt(int64(v), 10), true
// The numerical type cases following here are only here for the sake
// of numerical type completion. Everything is truncated to an integer
// before being stringified.
case float32:
return strconv.FormatInt(int64(v), 10), true
case int8:
return strconv.FormatInt(int64(v), 10), true
case int16:
return strconv.FormatInt(int64(v), 10), true
case int32:
return strconv.FormatInt(int64(v), 10), true
case int64:
return strconv.FormatInt(int64(v), 10), true
case int:
return strconv.FormatInt(int64(v), 10), true
case uint8:
return strconv.FormatInt(int64(v), 10), true
case uint16:
return strconv.FormatInt(int64(v), 10), true
case uint32:
return strconv.FormatInt(int64(v), 10), true
case uint64:
return strconv.FormatInt(int64(v), 10), true
case uint:
return strconv.FormatInt(int64(v), 10), true
default:
return "", false
}
}
// validateAudience checks whether any of the audiences in audClaim match those
// in boundAudiences. If strict is true and there are no bound audiences, then
// the presence of any audience in the received claim is considered an error.
func validateAudience(boundAudiences, audClaim []string, strict bool) error {
if strict && len(boundAudiences) == 0 && len(audClaim) > 0 {
return errors.New("audience claim found in JWT but no audiences are bound")
}
if len(boundAudiences) > 0 {
for _, v := range boundAudiences {
if strutil.StrListContains(audClaim, v) {
return nil
}
}
return errors.New("aud claim does not match any bound audience")
}
return nil
}

View file

@ -0,0 +1,614 @@
package oidcauth
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractStringMetadata(t *testing.T) {
emptyMap := make(map[string]string)
tests := map[string]struct {
allClaims map[string]interface{}
claimMappings map[string]string
expected map[string]string
errExpected bool
}{
"empty": {nil, nil, emptyMap, false},
"all": {
map[string]interface{}{
"data1": "foo",
"data2": "bar",
},
map[string]string{
"data1": "val1",
"data2": "val2",
},
map[string]string{
"val1": "foo",
"val2": "bar",
},
false,
},
"some": {
map[string]interface{}{
"data1": "foo",
"data2": "bar",
},
map[string]string{
"data1": "val1",
"data3": "val2",
},
map[string]string{
"val1": "foo",
},
false,
},
"none": {
map[string]interface{}{
"data1": "foo",
"data2": "bar",
},
map[string]string{
"data8": "val1",
"data9": "val2",
},
emptyMap,
false,
},
"nested data": {
map[string]interface{}{
"data1": "foo",
"data2": map[string]interface{}{
"child": "bar",
},
"data3": true,
"data4": false,
"data5": float64(7.9),
"data6": json.Number("-12345"),
"data7": int(42),
},
map[string]string{
"data1": "val1",
"/data2/child": "val2",
"data3": "val3",
"data4": "val4",
"data5": "val5",
"data6": "val6",
"data7": "val7",
},
map[string]string{
"val1": "foo",
"val2": "bar",
"val3": "true",
"val4": "false",
"val5": "7",
"val6": "-12345",
"val7": "42",
},
false,
},
"error: a struct isn't stringifiable": {
map[string]interface{}{
"data1": map[string]interface{}{
"child": "bar",
},
},
map[string]string{
"data1": "val1",
},
nil,
true,
},
"error: a slice isn't stringifiable": {
map[string]interface{}{
"data1": []interface{}{
"child", "bar",
},
},
map[string]string{
"data1": "val1",
},
nil,
true,
},
}
for name, test := range tests {
test := test
t.Run(name, func(t *testing.T) {
actual, err := extractStringMetadata(nil, test.allClaims, test.claimMappings)
if test.errExpected {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, actual)
}
})
}
}
func TestExtractListMetadata(t *testing.T) {
emptyMap := make(map[string][]string)
tests := map[string]struct {
allClaims map[string]interface{}
claimMappings map[string]string
expected map[string][]string
errExpected bool
}{
"empty": {nil, nil, emptyMap, false},
"all - singular": {
map[string]interface{}{
"data1": "foo",
"data2": "bar",
},
map[string]string{
"data1": "val1",
"data2": "val2",
},
map[string][]string{
"val1": []string{"foo"},
"val2": []string{"bar"},
},
false,
},
"some - singular": {
map[string]interface{}{
"data1": "foo",
"data2": "bar",
},
map[string]string{
"data1": "val1",
"data3": "val2",
},
map[string][]string{
"val1": []string{"foo"},
},
false,
},
"none - singular": {
map[string]interface{}{
"data1": "foo",
"data2": "bar",
},
map[string]string{
"data8": "val1",
"data9": "val2",
},
emptyMap,
false,
},
"nested data - singular": {
map[string]interface{}{
"data1": "foo",
"data2": map[string]interface{}{
"child": "bar",
},
"data3": true,
"data4": false,
"data5": float64(7.9),
"data6": json.Number("-12345"),
"data7": int(42),
"data8": []interface{}{ // mixed
"foo", true, float64(7.9), json.Number("-12345"), int(42),
},
},
map[string]string{
"data1": "val1",
"/data2/child": "val2",
"data3": "val3",
"data4": "val4",
"data5": "val5",
"data6": "val6",
"data7": "val7",
"data8": "val8",
},
map[string][]string{
"val1": []string{"foo"},
"val2": []string{"bar"},
"val3": []string{"true"},
"val4": []string{"false"},
"val5": []string{"7"},
"val6": []string{"-12345"},
"val7": []string{"42"},
"val8": []string{
"foo", "true", "7", "-12345", "42",
},
},
false,
},
"error: a struct isn't stringifiable (singular)": {
map[string]interface{}{
"data1": map[string]interface{}{
"child": map[string]interface{}{
"inner": "bar",
},
},
},
map[string]string{
"data1": "val1",
},
nil,
true,
},
"error: a slice isn't stringifiable (singular)": {
map[string]interface{}{
"data1": []interface{}{
"child", []interface{}{"bar"},
},
},
map[string]string{
"data1": "val1",
},
nil,
true,
},
"non-string-slice data (string)": {
map[string]interface{}{
"data1": "foo",
},
map[string]string{
"data1": "val1",
},
map[string][]string{
"val1": []string{"foo"}, // singular values become lists
},
false,
},
"all - list": {
map[string]interface{}{
"data1": []interface{}{"foo", "otherFoo"},
"data2": []interface{}{"bar", "otherBar"},
},
map[string]string{
"data1": "val1",
"data2": "val2",
},
map[string][]string{
"val1": []string{"foo", "otherFoo"},
"val2": []string{"bar", "otherBar"},
},
false,
},
"some - list": {
map[string]interface{}{
"data1": []interface{}{"foo", "otherFoo"},
"data2": map[string]interface{}{
"child": []interface{}{"bar", "otherBar"},
},
},
map[string]string{
"data1": "val1",
"/data2/child": "val2",
},
map[string][]string{
"val1": []string{"foo", "otherFoo"},
"val2": []string{"bar", "otherBar"},
},
false,
},
"none - list": {
map[string]interface{}{
"data1": []interface{}{"foo"},
"data2": []interface{}{"bar"},
},
map[string]string{
"data8": "val1",
"data9": "val2",
},
emptyMap,
false,
},
"list omits empty strings": {
map[string]interface{}{
"data1": []interface{}{"foo", "", "otherFoo", ""},
"data2": "",
},
map[string]string{
"data1": "val1",
"data2": "val2",
},
map[string][]string{
"val1": []string{"foo", "otherFoo"},
"val2": []string{},
},
false,
},
"nested data - list": {
map[string]interface{}{
"data1": []interface{}{"foo"},
"data2": map[string]interface{}{
"child": []interface{}{"bar"},
},
"data3": []interface{}{true},
"data4": []interface{}{false},
"data5": []interface{}{float64(7.9)},
"data6": []interface{}{json.Number("-12345")},
"data7": []interface{}{int(42)},
"data8": []interface{}{ // mixed
"foo", true, float64(7.9), json.Number("-12345"), int(42),
},
},
map[string]string{
"data1": "val1",
"/data2/child": "val2",
"data3": "val3",
"data4": "val4",
"data5": "val5",
"data6": "val6",
"data7": "val7",
"data8": "val8",
},
map[string][]string{
"val1": []string{"foo"},
"val2": []string{"bar"},
"val3": []string{"true"},
"val4": []string{"false"},
"val5": []string{"7"},
"val6": []string{"-12345"},
"val7": []string{"42"},
"val8": []string{
"foo", "true", "7", "-12345", "42",
},
},
false,
},
"JSONPointer": {
map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
map[string]string{
"foo": "val1",
"/bar/baz/1": "val2",
},
map[string][]string{
"val1": []string{"a"},
"val2": []string{"y"},
},
false,
},
"JSONPointer not found": {
map[string]interface{}{
"foo": "a",
"bar": map[string]interface{}{
"baz": []string{"x", "y", "z"},
},
},
map[string]string{
"foo": "val1",
"/bar/XXX/1243": "val2",
},
map[string][]string{
"val1": []string{"a"},
},
false,
},
}
for name, test := range tests {
test := test
t.Run(name, func(t *testing.T) {
actual, err := extractListMetadata(nil, test.allClaims, test.claimMappings)
if test.errExpected {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, actual)
}
})
}
}
func TestGetClaim(t *testing.T) {
data := `{
"a": 42,
"b": "bar",
"c": {
"d": 95,
"e": [
"dog",
"cat",
"bird"
],
"f": {
"g": "zebra"
}
},
"h": true,
"i": false
}`
var claims map[string]interface{}
require.NoError(t, json.Unmarshal([]byte(data), &claims))
tests := []struct {
claim string
value interface{}
}{
{"a", float64(42)},
{"/a", float64(42)},
{"b", "bar"},
{"/c/d", float64(95)},
{"/c/e/1", "cat"},
{"/c/f/g", "zebra"},
{"nope", nil},
{"/c/f/h", nil},
{"", nil},
{"\\", nil},
{"h", true},
{"i", false},
{"/c/e", []interface{}{"dog", "cat", "bird"}},
{"/c/f", map[string]interface{}{"g": "zebra"}},
}
for _, test := range tests {
t.Run(test.claim, func(t *testing.T) {
v := getClaim(nil, claims, test.claim)
require.Equal(t, test.value, v)
})
}
}
func TestNormalizeList(t *testing.T) {
tests := []struct {
raw interface{}
normalized []interface{}
ok bool
}{
{
raw: []interface{}{"green", 42},
normalized: []interface{}{"green", 42},
ok: true,
},
{
raw: []interface{}{"green"},
normalized: []interface{}{"green"},
ok: true,
},
{
raw: []interface{}{},
normalized: []interface{}{},
ok: true,
},
{
raw: "green",
normalized: []interface{}{"green"},
ok: true,
},
{
raw: "",
normalized: []interface{}{""},
ok: true,
},
{
raw: 42,
normalized: []interface{}{42},
ok: true,
},
{
raw: struct{ A int }{A: 5},
normalized: nil,
ok: false,
},
{
raw: nil,
normalized: nil,
ok: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("%#v", tc.raw), func(t *testing.T) {
normalized, ok := normalizeList(tc.raw)
assert.Equal(t, tc.normalized, normalized)
assert.Equal(t, tc.ok, ok)
})
}
}
func TestStringifyMetadataValue(t *testing.T) {
cases := map[string]struct {
value interface{}
expect string
expectFailure bool
}{
"empty string": {"", "", false},
"string": {"foo", "foo", false},
"true": {true, "true", false},
"false": {false, "false", false},
"json number": {json.Number("-12345"), "-12345", false},
"float64": {float64(7.9), "7", false},
//
"float32": {float32(7.9), "7", false},
"int8": {int8(42), "42", false},
"int16": {int16(42), "42", false},
"int32": {int32(42), "42", false},
"int64": {int64(42), "42", false},
"int": {int(42), "42", false},
"uint8": {uint8(42), "42", false},
"uint16": {uint16(42), "42", false},
"uint32": {uint32(42), "42", false},
"uint64": {uint64(42), "42", false},
"uint": {uint(42), "42", false},
// fail
"string slice": {[]string{"a"}, "", true},
"int slice": {[]int64{99}, "", true},
"map": {map[string]int{"a": 99}, "", true},
"nil": {nil, "", true},
"struct": {struct{ A int }{A: 5}, "", true},
}
for name, tc := range cases {
tc := tc
t.Run(name, func(t *testing.T) {
got, ok := stringifyMetadataValue(tc.value)
if tc.expectFailure {
require.False(t, ok)
} else {
require.True(t, ok)
require.Equal(t, tc.expect, got)
}
})
}
}
func TestValidateAudience(t *testing.T) {
tests := []struct {
boundAudiences []string
audience []string
errExpectedLax bool
errExpectedStrict bool
}{
{[]string{"a"}, []string{"a"}, false, false},
{[]string{"a"}, []string{"b"}, true, true},
{[]string{"a"}, []string{""}, true, true},
{[]string{}, []string{"a"}, false, true},
{[]string{"a", "b"}, []string{"a"}, false, false},
{[]string{"a", "b"}, []string{"b"}, false, false},
{[]string{"a", "b"}, []string{"a", "b", "c"}, false, false},
{[]string{"a", "b"}, []string{"c", "d"}, true, true},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf(
"boundAudiences=%#v audience=%#v strict=false",
tc.boundAudiences, tc.audience,
), func(t *testing.T) {
err := validateAudience(tc.boundAudiences, tc.audience, false)
if tc.errExpectedLax {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
t.Run(fmt.Sprintf(
"boundAudiences=%#v audience=%#v strict=true",
tc.boundAudiences, tc.audience,
), func(t *testing.T) {
err := validateAudience(tc.boundAudiences, tc.audience, true)
if tc.errExpectedStrict {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

View file

@ -0,0 +1,38 @@
package oidcauth
import (
"net/url"
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
)
// validRedirect checks whether uri is in allowed using special handling for loopback uris.
// Ref: https://tools.ietf.org/html/rfc8252#section-7.3
func validRedirect(uri string, allowed []string) bool {
inputURI, err := url.Parse(uri)
if err != nil {
return false
}
// if uri isn't a loopback, just string search the allowed list
if !strutil.StrListContains([]string{"localhost", "127.0.0.1", "::1"}, inputURI.Hostname()) {
return strutil.StrListContains(allowed, uri)
}
// otherwise, search for a match in a port-agnostic manner, per the OAuth RFC.
inputURI.Host = inputURI.Hostname()
for _, a := range allowed {
allowedURI, err := url.Parse(a)
if err != nil {
return false // shouldn't happen due to (*Config).Validate checks
}
allowedURI.Host = allowedURI.Hostname()
if inputURI.String() == allowedURI.String() {
return true
}
}
return false
}

View file

@ -0,0 +1,44 @@
package oidcauth
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestValidRedirect(t *testing.T) {
tests := []struct {
uri string
allowed []string
expected bool
}{
// valid
{"https://example.com", []string{"https://example.com"}, true},
{"https://example.com:5000", []string{"a", "b", "https://example.com:5000"}, true},
{"https://example.com/a/b/c", []string{"a", "b", "https://example.com/a/b/c"}, true},
{"https://localhost:9000", []string{"a", "b", "https://localhost:5000"}, true},
{"https://127.0.0.1:9000", []string{"a", "b", "https://127.0.0.1:5000"}, true},
{"https://[::1]:9000", []string{"a", "b", "https://[::1]:5000"}, true},
{"https://[::1]:9000/x/y?r=42", []string{"a", "b", "https://[::1]:5000/x/y?r=42"}, true},
// invalid
{"https://example.com", []string{}, false},
{"http://example.com", []string{"a", "b", "https://example.com"}, false},
{"https://example.com:9000", []string{"a", "b", "https://example.com:5000"}, false},
{"https://[::2]:9000", []string{"a", "b", "https://[::2]:5000"}, false},
{"https://localhost:5000", []string{"a", "b", "https://127.0.0.1:5000"}, false},
{"https://localhost:5000", []string{"a", "b", "https://127.0.0.1:5000"}, false},
{"https://localhost:5000", []string{"a", "b", "http://localhost:5000"}, false},
{"https://[::1]:5000/x/y?r=42", []string{"a", "b", "https://[::1]:5000/x/y?r=43"}, false},
// extra invalid
{"%%%%%%%%%%%", []string{}, false},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("uri=%q allowed=%#v", tc.uri, tc.allowed), func(t *testing.T) {
require.Equal(t, tc.expected, validRedirect(tc.uri, tc.allowed))
})
}
}

2
vendor/github.com/coreos/go-oidc/.gitignore generated vendored Normal file
View file

@ -0,0 +1,2 @@
/bin
/gopath

16
vendor/github.com/coreos/go-oidc/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,16 @@
language: go
go:
- "1.9"
- "1.10"
install:
- go get -v -t github.com/coreos/go-oidc/...
- go get golang.org/x/tools/cmd/cover
- go get github.com/golang/lint/golint
script:
- ./test
notifications:
email: false

71
vendor/github.com/coreos/go-oidc/CONTRIBUTING.md generated vendored Normal file
View file

@ -0,0 +1,71 @@
# How to Contribute
CoreOS projects are [Apache 2.0 licensed](LICENSE) and accept contributions via
GitHub pull requests. This document outlines some of the conventions on
development workflow, commit message formatting, contact points and other
resources to make it easier to get your contribution accepted.
# Certificate of Origin
By contributing to this project you agree to the Developer Certificate of
Origin (DCO). This document was created by the Linux Kernel community and is a
simple statement that you, as a contributor, have the legal right to make the
contribution. See the [DCO](DCO) file for details.
# Email and Chat
The project currently uses the general CoreOS email list and IRC channel:
- Email: [coreos-dev](https://groups.google.com/forum/#!forum/coreos-dev)
- IRC: #[coreos](irc://irc.freenode.org:6667/#coreos) IRC channel on freenode.org
Please avoid emailing maintainers found in the MAINTAINERS file directly. They
are very busy and read the mailing lists.
## Getting Started
- Fork the repository on GitHub
- Read the [README](README.md) for build and test instructions
- Play with the project, submit bugs, submit patches!
## Contribution Flow
This is a rough outline of what a contributor's workflow looks like:
- Create a topic branch from where you want to base your work (usually master).
- Make commits of logical units.
- Make sure your commit messages are in the proper format (see below).
- Push your changes to a topic branch in your fork of the repository.
- Make sure the tests pass, and add any new tests as appropriate.
- Submit a pull request to the original repository.
Thanks for your contributions!
### Format of the Commit Message
We follow a rough convention for commit messages that is designed to answer two
questions: what changed and why. The subject line should feature the what and
the body of the commit should describe the why.
```
scripts: add the test-cluster command
this uses tmux to setup a test cluster that you can easily kill and
start for debugging.
Fixes #38
```
The format can be described more formally as follows:
```
<subsystem>: <what changed>
<BLANK LINE>
<why this change was made>
<BLANK LINE>
<footer>
```
The first line is the subject and should be no longer than 70 characters, the
second line is always blank, and other lines should be wrapped at 80 characters.
This allows the message to be easier to read on GitHub as well as in various
git tools.

36
vendor/github.com/coreos/go-oidc/DCO generated vendored Normal file
View file

@ -0,0 +1,36 @@
Developer Certificate of Origin
Version 1.1
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
660 York Street, Suite 102,
San Francisco, CA 94110 USA
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
Developer's Certificate of Origin 1.1
By making a contribution to this project, I certify that:
(a) The contribution was created in whole or in part by me and I
have the right to submit it under the open source license
indicated in the file; or
(b) The contribution is based upon previous work that, to the best
of my knowledge, is covered under an appropriate open source
license and I have the right under that license to submit that
work with modifications, whether created in whole or in part
by me, under the same open source license (unless I am
permitted to submit under a different license), as indicated
in the file; or
(c) The contribution was provided directly to me by some other
person who certified (a), (b) or (c) and I have not modified
it.
(d) I understand and agree that this project and the contribution
are public and that a record of the contribution (including all
personal information I submit with it, including my sign-off) is
maintained indefinitely and may be redistributed consistent with
this project or the open source license(s) involved.

202
vendor/github.com/coreos/go-oidc/LICENSE generated vendored Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

3
vendor/github.com/coreos/go-oidc/MAINTAINERS generated vendored Normal file
View file

@ -0,0 +1,3 @@
Eric Chiang <ericchiang@google.com> (@ericchiang)
Mike Danese <mikedanese@google.com> (@mikedanese)
Rithu Leena John <rjohn@redhat.com> (@rithujohn191)

5
vendor/github.com/coreos/go-oidc/NOTICE generated vendored Normal file
View file

@ -0,0 +1,5 @@
CoreOS Project
Copyright 2014 CoreOS, Inc
This product includes software developed at CoreOS, Inc.
(http://www.coreos.com/).

72
vendor/github.com/coreos/go-oidc/README.md generated vendored Normal file
View file

@ -0,0 +1,72 @@
# go-oidc
[![GoDoc](https://godoc.org/github.com/coreos/go-oidc?status.svg)](https://godoc.org/github.com/coreos/go-oidc)
[![Build Status](https://travis-ci.org/coreos/go-oidc.png?branch=master)](https://travis-ci.org/coreos/go-oidc)
## OpenID Connect support for Go
This package enables OpenID Connect support for the [golang.org/x/oauth2](https://godoc.org/golang.org/x/oauth2) package.
```go
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
if err != nil {
// handle error
}
// Configure an OpenID Connect aware OAuth2 client.
oauth2Config := oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
// Discovery returns the OAuth2 endpoints.
Endpoint: provider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
```
OAuth2 redirects are unchanged.
```go
func handleRedirect(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusFound)
}
```
The on responses, the provider can be used to verify ID Tokens.
```go
var verifier = provider.Verifier(&oidc.Config{ClientID: clientID})
func handleOAuth2Callback(w http.ResponseWriter, r *http.Request) {
// Verify state and errors.
oauth2Token, err := oauth2Config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
// handle error
}
// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
// handle missing token
}
// Parse and verify ID Token payload.
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
// handle error
}
// Extract custom claims
var claims struct {
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
// handle error
}
}
```

61
vendor/github.com/coreos/go-oidc/code-of-conduct.md generated vendored Normal file
View file

@ -0,0 +1,61 @@
## CoreOS Community Code of Conduct
### Contributor Code of Conduct
As contributors and maintainers of this project, and in the interest of
fostering an open and welcoming community, we pledge to respect all people who
contribute through reporting issues, posting feature requests, updating
documentation, submitting pull requests or patches, and other activities.
We are committed to making participation in this project a harassment-free
experience for everyone, regardless of level of experience, gender, gender
identity and expression, sexual orientation, disability, personal appearance,
body size, race, ethnicity, age, religion, or nationality.
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery
* Personal attacks
* Trolling or insulting/derogatory comments
* Public or private harassment
* Publishing others' private information, such as physical or electronic addresses, without explicit permission
* Other unethical or unprofessional conduct.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct. By adopting this Code of Conduct,
project maintainers commit themselves to fairly and consistently applying these
principles to every aspect of managing this project. Project maintainers who do
not follow or enforce the Code of Conduct may be permanently removed from the
project team.
This code of conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community.
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting a project maintainer, Brandon Philips
<brandon.philips@coreos.com>, and/or Rithu John <rithu.john@coreos.com>.
This Code of Conduct is adapted from the Contributor Covenant
(http://contributor-covenant.org), version 1.2.0, available at
http://contributor-covenant.org/version/1/2/0/
### CoreOS Events Code of Conduct
CoreOS events are working conferences intended for professional networking and
collaboration in the CoreOS community. Attendees are expected to behave
according to professional standards and in accordance with their employers
policies on appropriate workplace behavior.
While at CoreOS events or related social networking opportunities, attendees
should not engage in discriminatory or offensive speech or actions including
but not limited to gender, sexuality, race, age, disability, or religion.
Speakers should be especially aware of these concerns.
CoreOS does not condone any statements by speakers contrary to these standards.
CoreOS reserves the right to deny entrance and/or eject from an event (without
refund) any individual found to be engaging in discriminatory or offensive
speech or actions.
Please bring any concerns to the immediate attention of designated on-site
staff, Brandon Philips <brandon.philips@coreos.com>, and/or Rithu John <rithu.john@coreos.com>.

20
vendor/github.com/coreos/go-oidc/jose.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
// +build !golint
// Don't lint this file. We don't want to have to add a comment to each constant.
package oidc
const (
// JOSE asymmetric signing algorithm values as defined by RFC 7518
//
// see: https://tools.ietf.org/html/rfc7518#section-3.1
RS256 = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
RS384 = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
RS512 = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
ES256 = "ES256" // ECDSA using P-256 and SHA-256
ES384 = "ES384" // ECDSA using P-384 and SHA-384
ES512 = "ES512" // ECDSA using P-521 and SHA-512
PS256 = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
PS384 = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
PS512 = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
)

228
vendor/github.com/coreos/go-oidc/jwks.go generated vendored Normal file
View file

@ -0,0 +1,228 @@
package oidc
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/pquerna/cachecontrol"
jose "gopkg.in/square/go-jose.v2"
)
// keysExpiryDelta is the allowed clock skew between a client and the OpenID Connect
// server.
//
// When keys expire, they are valid for this amount of time after.
//
// If the keys have not expired, and an ID Token claims it was signed by a key not in
// the cache, if and only if the keys expire in this amount of time, the keys will be
// updated.
const keysExpiryDelta = 30 * time.Second
// NewRemoteKeySet returns a KeySet that can validate JSON web tokens by using HTTP
// GETs to fetch JSON web token sets hosted at a remote URL. This is automatically
// used by NewProvider using the URLs returned by OpenID Connect discovery, but is
// exposed for providers that don't support discovery or to prevent round trips to the
// discovery URL.
//
// The returned KeySet is a long lived verifier that caches keys based on cache-control
// headers. Reuse a common remote key set instead of creating new ones as needed.
//
// The behavior of the returned KeySet is undefined once the context is canceled.
func NewRemoteKeySet(ctx context.Context, jwksURL string) KeySet {
return newRemoteKeySet(ctx, jwksURL, time.Now)
}
func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet {
if now == nil {
now = time.Now
}
return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
}
type remoteKeySet struct {
jwksURL string
ctx context.Context
now func() time.Time
// guard all other fields
mu sync.Mutex
// inflight suppresses parallel execution of updateKeys and allows
// multiple goroutines to wait for its result.
inflight *inflight
// A set of cached keys and their expiry.
cachedKeys []jose.JSONWebKey
expiry time.Time
}
// inflight is used to wait on some in-flight request from multiple goroutines.
type inflight struct {
doneCh chan struct{}
keys []jose.JSONWebKey
err error
}
func newInflight() *inflight {
return &inflight{doneCh: make(chan struct{})}
}
// wait returns a channel that multiple goroutines can receive on. Once it returns
// a value, the inflight request is done and result() can be inspected.
func (i *inflight) wait() <-chan struct{} {
return i.doneCh
}
// done can only be called by a single goroutine. It records the result of the
// inflight request and signals other goroutines that the result is safe to
// inspect.
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
i.keys = keys
i.err = err
close(i.doneCh)
}
// result cannot be called until the wait() channel has returned a value.
func (i *inflight) result() ([]jose.JSONWebKey, error) {
return i.keys, i.err
}
func (r *remoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
return r.verify(ctx, jws)
}
func (r *remoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
// We don't support JWTs signed with multiple signatures.
keyID := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
break
}
keys, expiry := r.keysFromCache()
// Don't check expiry yet. This optimizes for when the provider is unavailable.
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(&key); err == nil {
return payload, nil
}
}
}
if !r.now().Add(keysExpiryDelta).After(expiry) {
// Keys haven't expired, don't refresh.
return nil, errors.New("failed to verify id token signature")
}
keys, err := r.keysFromRemote(ctx)
if err != nil {
return nil, fmt.Errorf("fetching keys %v", err)
}
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(&key); err == nil {
return payload, nil
}
}
}
return nil, errors.New("failed to verify id token signature")
}
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) {
r.mu.Lock()
defer r.mu.Unlock()
return r.cachedKeys, r.expiry
}
// keysFromRemote syncs the key set from the remote set, records the values in the
// cache, and returns the key set.
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
// Need to lock to inspect the inflight request field.
r.mu.Lock()
// If there's not a current inflight request, create one.
if r.inflight == nil {
r.inflight = newInflight()
// This goroutine has exclusive ownership over the current inflight
// request. It releases the resource by nil'ing the inflight field
// once the goroutine is done.
go func() {
// Sync keys and finish inflight when that's done.
keys, expiry, err := r.updateKeys()
r.inflight.done(keys, err)
// Lock to update the keys and indicate that there is no longer an
// inflight request.
r.mu.Lock()
defer r.mu.Unlock()
if err == nil {
r.cachedKeys = keys
r.expiry = expiry
}
// Free inflight so a different request can run.
r.inflight = nil
}()
}
inflight := r.inflight
r.mu.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-inflight.wait():
return inflight.result()
}
}
func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) {
req, err := http.NewRequest("GET", r.jwksURL, nil)
if err != nil {
return nil, time.Time{}, fmt.Errorf("oidc: can't create request: %v", err)
}
resp, err := doRequest(r.ctx, req)
if err != nil {
return nil, time.Time{}, fmt.Errorf("oidc: get keys failed %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, time.Time{}, fmt.Errorf("unable to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body)
}
var keySet jose.JSONWebKeySet
err = unmarshalResp(resp, body, &keySet)
if err != nil {
return nil, time.Time{}, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
}
// If the server doesn't provide cache control headers, assume the
// keys expire immediately.
expiry := r.now()
_, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
if err == nil && e.After(expiry) {
expiry = e
}
return keySet.Keys, expiry, nil
}

385
vendor/github.com/coreos/go-oidc/oidc.go generated vendored Normal file
View file

@ -0,0 +1,385 @@
// Package oidc implements OpenID Connect client logic for the golang.org/x/oauth2 package.
package oidc
import (
"context"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"io/ioutil"
"mime"
"net/http"
"strings"
"time"
"golang.org/x/oauth2"
jose "gopkg.in/square/go-jose.v2"
)
const (
// ScopeOpenID is the mandatory scope for all OpenID Connect OAuth2 requests.
ScopeOpenID = "openid"
// ScopeOfflineAccess is an optional scope defined by OpenID Connect for requesting
// OAuth2 refresh tokens.
//
// Support for this scope differs between OpenID Connect providers. For instance
// Google rejects it, favoring appending "access_type=offline" as part of the
// authorization request instead.
//
// See: https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
ScopeOfflineAccess = "offline_access"
)
var (
errNoAtHash = errors.New("id token did not have an access token hash")
errInvalidAtHash = errors.New("access token hash does not match value in ID token")
)
// ClientContext returns a new Context that carries the provided HTTP client.
//
// This method sets the same context key used by the golang.org/x/oauth2 package,
// so the returned context works for that package too.
//
// myClient := &http.Client{}
// ctx := oidc.ClientContext(parentContext, myClient)
//
// // This will use the custom client
// provider, err := oidc.NewProvider(ctx, "https://accounts.example.com")
//
func ClientContext(ctx context.Context, client *http.Client) context.Context {
return context.WithValue(ctx, oauth2.HTTPClient, client)
}
func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
client := http.DefaultClient
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
client = c
}
return client.Do(req.WithContext(ctx))
}
// Provider represents an OpenID Connect server's configuration.
type Provider struct {
issuer string
authURL string
tokenURL string
userInfoURL string
// Raw claims returned by the server.
rawClaims []byte
remoteKeySet KeySet
}
type cachedKeys struct {
keys []jose.JSONWebKey
expiry time.Time
}
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
}
// NewProvider uses the OpenID Connect discovery mechanism to construct a Provider.
//
// The issuer is the URL identifier for the service. For example: "https://accounts.google.com"
// or "https://login.salesforce.com".
func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return nil, err
}
resp, err := doRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s: %s", resp.Status, body)
}
var p providerJSON
err = unmarshalResp(resp, body, &p)
if err != nil {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
}
if p.Issuer != issuer {
return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
}
return &Provider{
issuer: p.Issuer,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
rawClaims: body,
remoteKeySet: NewRemoteKeySet(ctx, p.JWKSURL),
}, nil
}
// Claims unmarshals raw fields returned by the server during discovery.
//
// var claims struct {
// ScopesSupported []string `json:"scopes_supported"`
// ClaimsSupported []string `json:"claims_supported"`
// }
//
// if err := provider.Claims(&claims); err != nil {
// // handle unmarshaling error
// }
//
// For a list of fields defined by the OpenID Connect spec see:
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
func (p *Provider) Claims(v interface{}) error {
if p.rawClaims == nil {
return errors.New("oidc: claims not set")
}
return json.Unmarshal(p.rawClaims, v)
}
// Endpoint returns the OAuth2 auth and token endpoints for the given provider.
func (p *Provider) Endpoint() oauth2.Endpoint {
return oauth2.Endpoint{AuthURL: p.authURL, TokenURL: p.tokenURL}
}
// UserInfo represents the OpenID Connect userinfo claims.
type UserInfo struct {
Subject string `json:"sub"`
Profile string `json:"profile"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
claims []byte
}
// Claims unmarshals the raw JSON object claims into the provided object.
func (u *UserInfo) Claims(v interface{}) error {
if u.claims == nil {
return errors.New("oidc: claims not set")
}
return json.Unmarshal(u.claims, v)
}
// UserInfo uses the token source to query the provider's user info endpoint.
func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*UserInfo, error) {
if p.userInfoURL == "" {
return nil, errors.New("oidc: user info endpoint is not supported by this provider")
}
req, err := http.NewRequest("GET", p.userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("oidc: create GET request: %v", err)
}
token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("oidc: get access token: %v", err)
}
token.SetAuthHeader(req)
resp, err := doRequest(ctx, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s: %s", resp.Status, body)
}
var userInfo UserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("oidc: failed to decode userinfo: %v", err)
}
userInfo.claims = body
return &userInfo, nil
}
// IDToken is an OpenID Connect extension that provides a predictable representation
// of an authorization event.
//
// The ID Token only holds fields OpenID Connect requires. To access additional
// claims returned by the server, use the Claims method.
type IDToken struct {
// The URL of the server which issued this token. OpenID Connect
// requires this value always be identical to the URL used for
// initial discovery.
//
// Note: Because of a known issue with Google Accounts' implementation
// this value may differ when using Google.
//
// See: https://developers.google.com/identity/protocols/OpenIDConnect#obtainuserinfo
Issuer string
// The client ID, or set of client IDs, that this token is issued for. For
// common uses, this is the client that initialized the auth flow.
//
// This package ensures the audience contains an expected value.
Audience []string
// A unique string which identifies the end user.
Subject string
// Expiry of the token. Ths package will not process tokens that have
// expired unless that validation is explicitly turned off.
Expiry time.Time
// When the token was issued by the provider.
IssuedAt time.Time
// Initial nonce provided during the authentication redirect.
//
// This package does NOT provided verification on the value of this field
// and it's the user's responsibility to ensure it contains a valid value.
Nonce string
// at_hash claim, if set in the ID token. Callers can verify an access token
// that corresponds to the ID token using the VerifyAccessToken method.
AccessTokenHash string
// signature algorithm used for ID token, needed to compute a verification hash of an
// access token
sigAlgorithm string
// Raw payload of the id_token.
claims []byte
// Map of distributed claim names to claim sources
distributedClaims map[string]claimSource
}
// Claims unmarshals the raw JSON payload of the ID Token into a provided struct.
//
// idToken, err := idTokenVerifier.Verify(rawIDToken)
// if err != nil {
// // handle error
// }
// var claims struct {
// Email string `json:"email"`
// EmailVerified bool `json:"email_verified"`
// }
// if err := idToken.Claims(&claims); err != nil {
// // handle error
// }
//
func (i *IDToken) Claims(v interface{}) error {
if i.claims == nil {
return errors.New("oidc: claims not set")
}
return json.Unmarshal(i.claims, v)
}
// VerifyAccessToken verifies that the hash of the access token that corresponds to the iD token
// matches the hash in the id token. It returns an error if the hashes don't match.
// It is the caller's responsibility to ensure that the optional access token hash is present for the ID token
// before calling this method. See https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
func (i *IDToken) VerifyAccessToken(accessToken string) error {
if i.AccessTokenHash == "" {
return errNoAtHash
}
var h hash.Hash
switch i.sigAlgorithm {
case RS256, ES256, PS256:
h = sha256.New()
case RS384, ES384, PS384:
h = sha512.New384()
case RS512, ES512, PS512:
h = sha512.New()
default:
return fmt.Errorf("oidc: unsupported signing algorithm %q", i.sigAlgorithm)
}
h.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := h.Sum(nil)[:h.Size()/2]
actual := base64.RawURLEncoding.EncodeToString(sum)
if actual != i.AccessTokenHash {
return errInvalidAtHash
}
return nil
}
type idToken struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience audience `json:"aud"`
Expiry jsonTime `json:"exp"`
IssuedAt jsonTime `json:"iat"`
NotBefore *jsonTime `json:"nbf"`
Nonce string `json:"nonce"`
AtHash string `json:"at_hash"`
ClaimNames map[string]string `json:"_claim_names"`
ClaimSources map[string]claimSource `json:"_claim_sources"`
}
type claimSource struct {
Endpoint string `json:"endpoint"`
AccessToken string `json:"access_token"`
}
type audience []string
func (a *audience) UnmarshalJSON(b []byte) error {
var s string
if json.Unmarshal(b, &s) == nil {
*a = audience{s}
return nil
}
var auds []string
if err := json.Unmarshal(b, &auds); err != nil {
return err
}
*a = audience(auds)
return nil
}
type jsonTime time.Time
func (j *jsonTime) UnmarshalJSON(b []byte) error {
var n json.Number
if err := json.Unmarshal(b, &n); err != nil {
return err
}
var unix int64
if t, err := n.Int64(); err == nil {
unix = t
} else {
f, err := n.Float64()
if err != nil {
return err
}
unix = int64(f)
}
*j = jsonTime(time.Unix(unix, 0))
return nil
}
func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
err := json.Unmarshal(body, &v)
if err == nil {
return nil
}
ct := r.Header.Get("Content-Type")
mediaType, _, parseErr := mime.ParseMediaType(ct)
if parseErr == nil && mediaType == "application/json" {
return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err)
}
return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err)
}

16
vendor/github.com/coreos/go-oidc/test generated vendored Normal file
View file

@ -0,0 +1,16 @@
#!/bin/bash
set -e
# Filter out any files with a !golint build tag.
LINTABLE=$( go list -tags=golint -f '
{{- range $i, $file := .GoFiles -}}
{{ $file }} {{ end }}
{{ range $i, $file := .TestGoFiles -}}
{{ $file }} {{ end }}' github.com/coreos/go-oidc )
go test -v -i -race github.com/coreos/go-oidc/...
go test -v -race github.com/coreos/go-oidc/...
golint -set_exit_status $LINTABLE
go vet github.com/coreos/go-oidc/...
go build -v ./example/...

327
vendor/github.com/coreos/go-oidc/verify.go generated vendored Normal file
View file

@ -0,0 +1,327 @@
package oidc
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
"golang.org/x/oauth2"
jose "gopkg.in/square/go-jose.v2"
)
const (
issuerGoogleAccounts = "https://accounts.google.com"
issuerGoogleAccountsNoScheme = "accounts.google.com"
)
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
// of JSON web tokens. This is expected to be backed by a remote key set through
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
type KeySet interface {
// VerifySignature parses the JSON web token, verifies the signature, and returns
// the raw payload. Header and claim fields are validated by other parts of the
// package. For example, the KeySet does not need to check values such as signature
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
// independently.
//
// If VerifySignature makes HTTP requests to verify the token, it's expected to
// use any HTTP client associated with the context through ClientContext.
VerifySignature(ctx context.Context, jwt string) (payload []byte, err error)
}
// IDTokenVerifier provides verification for ID Tokens.
type IDTokenVerifier struct {
keySet KeySet
config *Config
issuer string
}
// NewVerifier returns a verifier manually constructed from a key set and issuer URL.
//
// It's easier to use provider discovery to construct an IDTokenVerifier than creating
// one directly. This method is intended to be used with provider that don't support
// metadata discovery, or avoiding round trips when the key set URL is already known.
//
// This constructor can be used to create a verifier directly using the issuer URL and
// JSON Web Key Set URL without using discovery:
//
// keySet := oidc.NewRemoteKeySet(ctx, "https://www.googleapis.com/oauth2/v3/certs")
// verifier := oidc.NewVerifier("https://accounts.google.com", keySet, config)
//
// Since KeySet is an interface, this constructor can also be used to supply custom
// public key sources. For example, if a user wanted to supply public keys out-of-band
// and hold them statically in-memory:
//
// // Custom KeySet implementation.
// keySet := newStatisKeySet(publicKeys...)
//
// // Verifier uses the custom KeySet implementation.
// verifier := oidc.NewVerifier("https://auth.example.com", keySet, config)
//
func NewVerifier(issuerURL string, keySet KeySet, config *Config) *IDTokenVerifier {
return &IDTokenVerifier{keySet: keySet, config: config, issuer: issuerURL}
}
// Config is the configuration for an IDTokenVerifier.
type Config struct {
// Expected audience of the token. For a majority of the cases this is expected to be
// the ID of the client that initialized the login flow. It may occasionally differ if
// the provider supports the authorizing party (azp) claim.
//
// If not provided, users must explicitly set SkipClientIDCheck.
ClientID string
// If specified, only this set of algorithms may be used to sign the JWT.
//
// Since many providers only support RS256, SupportedSigningAlgs defaults to this value.
SupportedSigningAlgs []string
// If true, no ClientID check performed. Must be true if ClientID field is empty.
SkipClientIDCheck bool
// If true, token expiry is not checked.
SkipExpiryCheck bool
// SkipIssuerCheck is intended for specialized cases where the the caller wishes to
// defer issuer validation. When enabled, callers MUST independently verify the Token's
// Issuer is a known good value.
//
// Mismatched issuers often indicate client mis-configuration. If mismatches are
// unexpected, evaluate if the provided issuer URL is incorrect instead of enabling
// this option.
SkipIssuerCheck bool
// Time function to check Token expiry. Defaults to time.Now
Now func() time.Time
}
// Verifier returns an IDTokenVerifier that uses the provider's key set to verify JWTs.
//
// The returned IDTokenVerifier is tied to the Provider's context and its behavior is
// undefined once the Provider's context is canceled.
func (p *Provider) Verifier(config *Config) *IDTokenVerifier {
return NewVerifier(p.issuer, p.remoteKeySet, config)
}
func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
}
func contains(sli []string, ele string) bool {
for _, s := range sli {
if s == ele {
return true
}
}
return false
}
// Returns the Claims from the distributed JWT token
func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src claimSource) ([]byte, error) {
req, err := http.NewRequest("GET", src.Endpoint, nil)
if err != nil {
return nil, fmt.Errorf("malformed request: %v", err)
}
if src.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+src.AccessToken)
}
resp, err := doRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("oidc: Request to endpoint failed: %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("oidc: request failed: %v", resp.StatusCode)
}
token, err := verifier.Verify(ctx, string(body))
if err != nil {
return nil, fmt.Errorf("malformed response body: %v", err)
}
return token.claims, nil
}
func parseClaim(raw []byte, name string, v interface{}) error {
var parsed map[string]json.RawMessage
if err := json.Unmarshal(raw, &parsed); err != nil {
return err
}
val, ok := parsed[name]
if !ok {
return fmt.Errorf("claim doesn't exist: %s", name)
}
return json.Unmarshal([]byte(val), v)
}
// Verify parses a raw ID Token, verifies it's been signed by the provider, preforms
// any additional checks depending on the Config, and returns the payload.
//
// Verify does NOT do nonce validation, which is the callers responsibility.
//
// See: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// oauth2Token, err := oauth2Config.Exchange(ctx, r.URL.Query().Get("code"))
// if err != nil {
// // handle error
// }
//
// // Extract the ID Token from oauth2 token.
// rawIDToken, ok := oauth2Token.Extra("id_token").(string)
// if !ok {
// // handle error
// }
//
// token, err := verifier.Verify(ctx, rawIDToken)
//
func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDToken, error) {
jws, err := jose.ParseSigned(rawIDToken)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
// Throw out tokens with invalid claims before trying to verify the token. This lets
// us do cheap checks before possibly re-syncing keys.
payload, err := parseJWT(rawIDToken)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
var token idToken
if err := json.Unmarshal(payload, &token); err != nil {
return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err)
}
distributedClaims := make(map[string]claimSource)
//step through the token to map claim names to claim sources"
for cn, src := range token.ClaimNames {
if src == "" {
return nil, fmt.Errorf("oidc: failed to obtain source from claim name")
}
s, ok := token.ClaimSources[src]
if !ok {
return nil, fmt.Errorf("oidc: source does not exist")
}
distributedClaims[cn] = s
}
t := &IDToken{
Issuer: token.Issuer,
Subject: token.Subject,
Audience: []string(token.Audience),
Expiry: time.Time(token.Expiry),
IssuedAt: time.Time(token.IssuedAt),
Nonce: token.Nonce,
AccessTokenHash: token.AtHash,
claims: payload,
distributedClaims: distributedClaims,
}
// Check issuer.
if !v.config.SkipIssuerCheck && t.Issuer != v.issuer {
// Google sometimes returns "accounts.google.com" as the issuer claim instead of
// the required "https://accounts.google.com". Detect this case and allow it only
// for Google.
//
// We will not add hooks to let other providers go off spec like this.
if !(v.issuer == issuerGoogleAccounts && t.Issuer == issuerGoogleAccountsNoScheme) {
return nil, fmt.Errorf("oidc: id token issued by a different provider, expected %q got %q", v.issuer, t.Issuer)
}
}
// If a client ID has been provided, make sure it's part of the audience. SkipClientIDCheck must be true if ClientID is empty.
//
// This check DOES NOT ensure that the ClientID is the party to which the ID Token was issued (i.e. Authorized party).
if !v.config.SkipClientIDCheck {
if v.config.ClientID != "" {
if !contains(t.Audience, v.config.ClientID) {
return nil, fmt.Errorf("oidc: expected audience %q got %q", v.config.ClientID, t.Audience)
}
} else {
return nil, fmt.Errorf("oidc: invalid configuration, clientID must be provided or SkipClientIDCheck must be set")
}
}
// If a SkipExpiryCheck is false, make sure token is not expired.
if !v.config.SkipExpiryCheck {
now := time.Now
if v.config.Now != nil {
now = v.config.Now
}
nowTime := now()
if t.Expiry.Before(nowTime) {
return nil, fmt.Errorf("oidc: token is expired (Token Expiry: %v)", t.Expiry)
}
// If nbf claim is provided in token, ensure that it is indeed in the past.
if token.NotBefore != nil {
nbfTime := time.Time(*token.NotBefore)
leeway := 1 * time.Minute
if nowTime.Add(leeway).Before(nbfTime) {
return nil, fmt.Errorf("oidc: current time %v before the nbf (not before) time: %v", nowTime, nbfTime)
}
}
}
switch len(jws.Signatures) {
case 0:
return nil, fmt.Errorf("oidc: id token not signed")
case 1:
default:
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
}
sig := jws.Signatures[0]
supportedSigAlgs := v.config.SupportedSigningAlgs
if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{RS256}
}
if !contains(supportedSigAlgs, sig.Header.Algorithm) {
return nil, fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
}
t.sigAlgorithm = sig.Header.Algorithm
gotPayload, err := v.keySet.VerifySignature(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("failed to verify signature: %v", err)
}
// Ensure that the payload returned by the square actually matches the payload parsed earlier.
if !bytes.Equal(gotPayload, payload) {
return nil, errors.New("oidc: internal error, payload parsed did not match previous payload")
}
return t, nil
}
// Nonce returns an auth code option which requires the ID Token created by the
// OpenID Connect provider to contain the specified nonce.
func Nonce(nonce string) oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("nonce", nonce)
}

View file

@ -4,22 +4,40 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
// GenerateRandomBytes is used to generate random bytes of given size.
func GenerateRandomBytes(size int) ([]byte, error) {
return GenerateRandomBytesWithReader(size, rand.Reader)
}
// GenerateRandomBytesWithReader is used to generate random bytes of given size read from a given reader.
func GenerateRandomBytesWithReader(size int, reader io.Reader) ([]byte, error) {
if reader == nil {
return nil, fmt.Errorf("provided reader is nil")
}
buf := make([]byte, size)
if _, err := rand.Read(buf); err != nil {
if _, err := io.ReadFull(reader, buf); err != nil {
return nil, fmt.Errorf("failed to read random bytes: %v", err)
}
return buf, nil
}
const uuidLen = 16
// GenerateUUID is used to generate a random UUID
func GenerateUUID() (string, error) {
buf, err := GenerateRandomBytes(uuidLen)
return GenerateUUIDWithReader(rand.Reader)
}
// GenerateUUIDWithReader is used to generate a random UUID with a given Reader
func GenerateUUIDWithReader(reader io.Reader) (string, error) {
if reader == nil {
return "", fmt.Errorf("provided reader is nil")
}
buf, err := GenerateRandomBytesWithReader(uuidLen, reader)
if err != nil {
return "", err
}

View file

@ -0,0 +1,12 @@
language: go
go:
- 1.7
- tip
script:
- go test
matrix:
allow_failures:
- go: tip

21
vendor/github.com/mitchellh/pointerstructure/LICENSE generated vendored Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 Mitchell Hashimoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

74
vendor/github.com/mitchellh/pointerstructure/README.md generated vendored Normal file
View file

@ -0,0 +1,74 @@
# pointerstructure [![GoDoc](https://godoc.org/github.com/mitchellh/pointerstructure?status.svg)](https://godoc.org/github.com/mitchellh/pointerstructure)
pointerstructure is a Go library for identifying a specific value within
any Go structure using a string syntax.
pointerstructure is based on
[JSON Pointer (RFC 6901)](https://tools.ietf.org/html/rfc6901), but
reimplemented for Go.
The goal of pointerstructure is to provide a single, well-known format
for addressing a specific value. This can be useful for user provided
input on structures, diffs of structures, etc.
## Features
* Get the value for an address
* Set the value for an address within an existing structure
* Delete the value at an address
* Sorting a list of addresses
## Installation
Standard `go get`:
```
$ go get github.com/mitchellh/pointerstructure
```
## Usage & Example
For usage and examples see the [Godoc](http://godoc.org/github.com/mitchellh/pointerstructure).
A quick code example is shown below:
```go
complex := map[string]interface{}{
"alice": 42,
"bob": []interface{}{
map[string]interface{}{
"name": "Bob",
},
},
}
value, err := pointerstructure.Get(complex, "/bob/0/name")
if err != nil {
panic(err)
}
fmt.Printf("%s", value)
// Output:
// Bob
```
Continuing the example above, you can also set values:
```go
value, err = pointerstructure.Set(complex, "/bob/0/name", "Alice")
if err != nil {
panic(err)
}
value, err = pointerstructure.Get(complex, "/bob/0/name")
if err != nil {
panic(err)
}
fmt.Printf("%s", value)
// Output:
// Alice
```

112
vendor/github.com/mitchellh/pointerstructure/delete.go generated vendored Normal file
View file

@ -0,0 +1,112 @@
package pointerstructure
import (
"fmt"
"reflect"
)
// Delete deletes the value specified by the pointer p in structure s.
//
// When deleting a slice index, all other elements will be shifted to
// the left. This is specified in RFC6902 (JSON Patch) and not RFC6901 since
// RFC6901 doesn't specify operations on pointers. If you don't want to
// shift elements, you should use Set to set the slice index to the zero value.
//
// The structures s must have non-zero values set up to this pointer.
// For example, if deleting "/bob/0/name", then "/bob/0" must be set already.
//
// The returned value is potentially a new value if this pointer represents
// the root document. Otherwise, the returned value will always be s.
func (p *Pointer) Delete(s interface{}) (interface{}, error) {
// if we represent the root doc, we've deleted everything
if len(p.Parts) == 0 {
return nil, nil
}
// Save the original since this is going to be our return value
originalS := s
// Get the parent value
var err error
s, err = p.Parent().Get(s)
if err != nil {
return nil, err
}
// Map for lookup of getter to call for type
funcMap := map[reflect.Kind]deleteFunc{
reflect.Array: p.deleteSlice,
reflect.Map: p.deleteMap,
reflect.Slice: p.deleteSlice,
}
val := reflect.ValueOf(s)
for val.Kind() == reflect.Interface {
val = val.Elem()
}
for val.Kind() == reflect.Ptr {
val = reflect.Indirect(val)
}
f, ok := funcMap[val.Kind()]
if !ok {
return nil, fmt.Errorf("delete %s: invalid value kind: %s", p, val.Kind())
}
result, err := f(originalS, val)
if err != nil {
return nil, fmt.Errorf("delete %s: %s", p, err)
}
return result, nil
}
type deleteFunc func(interface{}, reflect.Value) (interface{}, error)
func (p *Pointer) deleteMap(root interface{}, m reflect.Value) (interface{}, error) {
part := p.Parts[len(p.Parts)-1]
key, err := coerce(reflect.ValueOf(part), m.Type().Key())
if err != nil {
return root, err
}
// Delete the key
var elem reflect.Value
m.SetMapIndex(key, elem)
return root, nil
}
func (p *Pointer) deleteSlice(root interface{}, s reflect.Value) (interface{}, error) {
// Coerce the key to an int
part := p.Parts[len(p.Parts)-1]
idxVal, err := coerce(reflect.ValueOf(part), reflect.TypeOf(42))
if err != nil {
return root, err
}
idx := int(idxVal.Int())
// Verify we're within bounds
if idx < 0 || idx >= s.Len() {
return root, fmt.Errorf(
"index %d is out of range (length = %d)", idx, s.Len())
}
// Mimicing the following with reflection to do this:
//
// copy(a[i:], a[i+1:])
// a[len(a)-1] = nil // or the zero value of T
// a = a[:len(a)-1]
// copy(a[i:], a[i+1:])
reflect.Copy(s.Slice(idx, s.Len()), s.Slice(idx+1, s.Len()))
// a[len(a)-1] = nil // or the zero value of T
s.Index(s.Len() - 1).Set(reflect.Zero(s.Type().Elem()))
// a = a[:len(a)-1]
s = s.Slice(0, s.Len()-1)
// set the slice back on the parent
return p.Parent().Set(root, s.Interface())
}

96
vendor/github.com/mitchellh/pointerstructure/get.go generated vendored Normal file
View file

@ -0,0 +1,96 @@
package pointerstructure
import (
"fmt"
"reflect"
)
// Get reads the value out of the total value v.
func (p *Pointer) Get(v interface{}) (interface{}, error) {
// fast-path the empty address case to avoid reflect.ValueOf below
if len(p.Parts) == 0 {
return v, nil
}
// Map for lookup of getter to call for type
funcMap := map[reflect.Kind]func(string, reflect.Value) (reflect.Value, error){
reflect.Array: p.getSlice,
reflect.Map: p.getMap,
reflect.Slice: p.getSlice,
reflect.Struct: p.getStruct,
}
currentVal := reflect.ValueOf(v)
for i, part := range p.Parts {
for currentVal.Kind() == reflect.Interface {
currentVal = currentVal.Elem()
}
for currentVal.Kind() == reflect.Ptr {
currentVal = reflect.Indirect(currentVal)
}
f, ok := funcMap[currentVal.Kind()]
if !ok {
return nil, fmt.Errorf(
"%s: at part %d, invalid value kind: %s", p, i, currentVal.Kind())
}
var err error
currentVal, err = f(part, currentVal)
if err != nil {
return nil, fmt.Errorf("%s at part %d: %s", p, i, err)
}
}
return currentVal.Interface(), nil
}
func (p *Pointer) getMap(part string, m reflect.Value) (reflect.Value, error) {
var zeroValue reflect.Value
// Coerce the string part to the correct key type
key, err := coerce(reflect.ValueOf(part), m.Type().Key())
if err != nil {
return zeroValue, err
}
// Verify that the key exists
found := false
for _, k := range m.MapKeys() {
if k.Interface() == key.Interface() {
found = true
break
}
}
if !found {
return zeroValue, fmt.Errorf("couldn't find key %#v", key.Interface())
}
// Get the key
return m.MapIndex(key), nil
}
func (p *Pointer) getSlice(part string, v reflect.Value) (reflect.Value, error) {
var zeroValue reflect.Value
// Coerce the key to an int
idxVal, err := coerce(reflect.ValueOf(part), reflect.TypeOf(42))
if err != nil {
return zeroValue, err
}
idx := int(idxVal.Int())
// Verify we're within bounds
if idx < 0 || idx >= v.Len() {
return zeroValue, fmt.Errorf(
"index %d is out of range (length = %d)", idx, v.Len())
}
// Get the key
return v.Index(idx), nil
}
func (p *Pointer) getStruct(part string, m reflect.Value) (reflect.Value, error) {
return m.FieldByName(part), nil
}

5
vendor/github.com/mitchellh/pointerstructure/go.mod generated vendored Normal file
View file

@ -0,0 +1,5 @@
module github.com/mitchellh/pointerstructure
go 1.12
require github.com/mitchellh/mapstructure v1.1.2

2
vendor/github.com/mitchellh/pointerstructure/go.sum generated vendored Normal file
View file

@ -0,0 +1,2 @@
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=

57
vendor/github.com/mitchellh/pointerstructure/parse.go generated vendored Normal file
View file

@ -0,0 +1,57 @@
package pointerstructure
import (
"fmt"
"strings"
)
// Parse parses a pointer from the input string. The input string
// is expected to follow the format specified by RFC 6901: '/'-separated
// parts. Each part can contain escape codes to contain '/' or '~'.
func Parse(input string) (*Pointer, error) {
// Special case the empty case
if input == "" {
return &Pointer{}, nil
}
// We expect the first character to be "/"
if input[0] != '/' {
return nil, fmt.Errorf(
"parse Go pointer %q: first char must be '/'", input)
}
// Trim out the first slash so we don't have to +1 every index
input = input[1:]
// Parse out all the parts
var parts []string
lastSlash := -1
for i, r := range input {
if r == '/' {
parts = append(parts, input[lastSlash+1:i])
lastSlash = i
}
}
// Add last part
parts = append(parts, input[lastSlash+1:])
// Process each part for string replacement
for i, p := range parts {
// Replace ~1 followed by ~0 as specified by the RFC
parts[i] = strings.Replace(
strings.Replace(p, "~1", "/", -1), "~0", "~", -1)
}
return &Pointer{Parts: parts}, nil
}
// MustParse is like Parse but panics if the input cannot be parsed.
func MustParse(input string) *Pointer {
p, err := Parse(input)
if err != nil {
panic(err)
}
return p
}

123
vendor/github.com/mitchellh/pointerstructure/pointer.go generated vendored Normal file
View file

@ -0,0 +1,123 @@
// Package pointerstructure provides functions for identifying a specific
// value within any Go structure using a string syntax.
//
// The syntax used is based on JSON Pointer (RFC 6901).
package pointerstructure
import (
"fmt"
"reflect"
"strings"
"github.com/mitchellh/mapstructure"
)
// Pointer represents a pointer to a specific value. You can construct
// a pointer manually or use Parse.
type Pointer struct {
// Parts are the pointer parts. No escape codes are processed here.
// The values are expected to be exact. If you have escape codes, use
// the Parse functions.
Parts []string
}
// Get reads the value at the given pointer.
//
// This is a shorthand for calling Parse on the pointer and then calling Get
// on that result. An error will be returned if the value cannot be found or
// there is an error with the format of pointer.
func Get(value interface{}, pointer string) (interface{}, error) {
p, err := Parse(pointer)
if err != nil {
return nil, err
}
return p.Get(value)
}
// Set sets the value at the given pointer.
//
// This is a shorthand for calling Parse on the pointer and then calling Set
// on that result. An error will be returned if the value cannot be found or
// there is an error with the format of pointer.
//
// Set returns the complete document, which might change if the pointer value
// points to the root ("").
func Set(doc interface{}, pointer string, value interface{}) (interface{}, error) {
p, err := Parse(pointer)
if err != nil {
return nil, err
}
return p.Set(doc, value)
}
// String returns the string value that can be sent back to Parse to get
// the same Pointer result.
func (p *Pointer) String() string {
if len(p.Parts) == 0 {
return ""
}
// Copy the parts so we can convert back the escapes
result := make([]string, len(p.Parts))
copy(result, p.Parts)
for i, p := range p.Parts {
result[i] = strings.Replace(
strings.Replace(p, "~", "~0", -1), "/", "~1", -1)
}
return "/" + strings.Join(result, "/")
}
// Parent returns a pointer to the parent element of this pointer.
//
// If Pointer represents the root (empty parts), a pointer representing
// the root is returned. Therefore, to check for the root, IsRoot() should be
// called.
func (p *Pointer) Parent() *Pointer {
// If this is root, then we just return a new root pointer. We allocate
// a new one though so this can still be modified.
if p.IsRoot() {
return &Pointer{}
}
parts := make([]string, len(p.Parts)-1)
copy(parts, p.Parts[:len(p.Parts)-1])
return &Pointer{
Parts: parts,
}
}
// IsRoot returns true if this pointer represents the root document.
func (p *Pointer) IsRoot() bool {
return len(p.Parts) == 0
}
// coerce is a helper to coerce a value to a specific type if it must
// and if its possible. If it isn't possible, an error is returned.
func coerce(value reflect.Value, to reflect.Type) (reflect.Value, error) {
// If the value is already assignable to the type, then let it go
if value.Type().AssignableTo(to) {
return value, nil
}
// If a direct conversion is possible, do that
if value.Type().ConvertibleTo(to) {
return value.Convert(to), nil
}
// Create a new value to hold our result
result := reflect.New(to)
// Decode
if err := mapstructure.WeakDecode(value.Interface(), result.Interface()); err != nil {
return result, fmt.Errorf(
"couldn't convert value %#v to type %s",
value.Interface(), to.String())
}
// We need to indirect the value since reflect.New always creates a pointer
return reflect.Indirect(result), nil
}

122
vendor/github.com/mitchellh/pointerstructure/set.go generated vendored Normal file
View file

@ -0,0 +1,122 @@
package pointerstructure
import (
"fmt"
"reflect"
)
// Set writes a value v to the pointer p in structure s.
//
// The structures s must have non-zero values set up to this pointer.
// For example, if setting "/bob/0/name", then "/bob/0" must be set already.
//
// The returned value is potentially a new value if this pointer represents
// the root document. Otherwise, the returned value will always be s.
func (p *Pointer) Set(s, v interface{}) (interface{}, error) {
// if we represent the root doc, return that
if len(p.Parts) == 0 {
return v, nil
}
// Save the original since this is going to be our return value
originalS := s
// Get the parent value
var err error
s, err = p.Parent().Get(s)
if err != nil {
return nil, err
}
// Map for lookup of getter to call for type
funcMap := map[reflect.Kind]setFunc{
reflect.Array: p.setSlice,
reflect.Map: p.setMap,
reflect.Slice: p.setSlice,
}
val := reflect.ValueOf(s)
for val.Kind() == reflect.Interface {
val = val.Elem()
}
for val.Kind() == reflect.Ptr {
val = reflect.Indirect(val)
}
f, ok := funcMap[val.Kind()]
if !ok {
return nil, fmt.Errorf("set %s: invalid value kind: %s", p, val.Kind())
}
result, err := f(originalS, val, reflect.ValueOf(v))
if err != nil {
return nil, fmt.Errorf("set %s: %s", p, err)
}
return result, nil
}
type setFunc func(interface{}, reflect.Value, reflect.Value) (interface{}, error)
func (p *Pointer) setMap(root interface{}, m, value reflect.Value) (interface{}, error) {
part := p.Parts[len(p.Parts)-1]
key, err := coerce(reflect.ValueOf(part), m.Type().Key())
if err != nil {
return root, err
}
elem, err := coerce(value, m.Type().Elem())
if err != nil {
return root, err
}
// Set the key
m.SetMapIndex(key, elem)
return root, nil
}
func (p *Pointer) setSlice(root interface{}, s, value reflect.Value) (interface{}, error) {
// Coerce the value, we'll need that no matter what
value, err := coerce(value, s.Type().Elem())
if err != nil {
return root, err
}
// If the part is the special "-", that means to append it (RFC6901 4.)
part := p.Parts[len(p.Parts)-1]
if part == "-" {
return p.setSliceAppend(root, s, value)
}
// Coerce the key to an int
idxVal, err := coerce(reflect.ValueOf(part), reflect.TypeOf(42))
if err != nil {
return root, err
}
idx := int(idxVal.Int())
// Verify we're within bounds
if idx < 0 || idx >= s.Len() {
return root, fmt.Errorf(
"index %d is out of range (length = %d)", idx, s.Len())
}
// Set the key
s.Index(idx).Set(value)
return root, nil
}
func (p *Pointer) setSliceAppend(root interface{}, s, value reflect.Value) (interface{}, error) {
// Coerce the value, we'll need that no matter what. This should
// be a no-op since we expect it to be done already, but there is
// a fast-path check for that in coerce so do it anyways.
value, err := coerce(value, s.Type().Elem())
if err != nil {
return root, err
}
// We can assume "s" is the parent of pointer value. We need to actually
// write s back because Append can return a new slice.
return p.Parent().Set(root, reflect.Append(s, value).Interface())
}

42
vendor/github.com/mitchellh/pointerstructure/sort.go generated vendored Normal file
View file

@ -0,0 +1,42 @@
package pointerstructure
import (
"sort"
)
// Sort does an in-place sort of the pointers so that they are in order
// of least specific to most specific alphabetized. For example:
// "/foo", "/foo/0", "/qux"
//
// This ordering is ideal for applying the changes in a way that ensures
// that parents are set first.
func Sort(p []*Pointer) { sort.Sort(PointerSlice(p)) }
// PointerSlice is a slice of pointers that adheres to sort.Interface
type PointerSlice []*Pointer
func (p PointerSlice) Len() int { return len(p) }
func (p PointerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p PointerSlice) Less(i, j int) bool {
// Equal number of parts, do a string compare per part
for idx, ival := range p[i].Parts {
// If we're passed the length of p[j] parts, then we're done
if idx >= len(p[j].Parts) {
break
}
// Compare the values if they're not equal
jval := p[j].Parts[idx]
if ival != jval {
return ival < jval
}
}
// Equal prefix, take the shorter
if len(p[i].Parts) != len(p[j].Parts) {
return len(p[i].Parts) < len(p[j].Parts)
}
// Equal, it doesn't matter
return false
}

9
vendor/github.com/patrickmn/go-cache/CONTRIBUTORS generated vendored Normal file
View file

@ -0,0 +1,9 @@
This is a list of people who have contributed code to go-cache. They, or their
employers, are the copyright holders of the contributed code. Contributed code
is subject to the license restrictions listed in LICENSE (as they were when the
code was contributed.)
Dustin Sallings <dustin@spy.net>
Jason Mooberry <jasonmoo@me.com>
Sergey Shepelev <temotor@gmail.com>
Alex Edwards <ajmedwards@gmail.com>

19
vendor/github.com/patrickmn/go-cache/LICENSE generated vendored Normal file
View file

@ -0,0 +1,19 @@
Copyright (c) 2012-2017 Patrick Mylund Nielsen and the go-cache contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

83
vendor/github.com/patrickmn/go-cache/README.md generated vendored Normal file
View file

@ -0,0 +1,83 @@
# go-cache
go-cache is an in-memory key:value store/cache similar to memcached that is
suitable for applications running on a single machine. Its major advantage is
that, being essentially a thread-safe `map[string]interface{}` with expiration
times, it doesn't need to serialize or transmit its contents over the network.
Any object can be stored, for a given duration or forever, and the cache can be
safely used by multiple goroutines.
Although go-cache isn't meant to be used as a persistent datastore, the entire
cache can be saved to and loaded from a file (using `c.Items()` to retrieve the
items map to serialize, and `NewFrom()` to create a cache from a deserialized
one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.)
### Installation
`go get github.com/patrickmn/go-cache`
### Usage
```go
import (
"fmt"
"github.com/patrickmn/go-cache"
"time"
)
func main() {
// Create a cache with a default expiration time of 5 minutes, and which
// purges expired items every 10 minutes
c := cache.New(5*time.Minute, 10*time.Minute)
// Set the value of the key "foo" to "bar", with the default expiration time
c.Set("foo", "bar", cache.DefaultExpiration)
// Set the value of the key "baz" to 42, with no expiration time
// (the item won't be removed until it is re-set, or removed using
// c.Delete("baz")
c.Set("baz", 42, cache.NoExpiration)
// Get the string associated with the key "foo" from the cache
foo, found := c.Get("foo")
if found {
fmt.Println(foo)
}
// Since Go is statically typed, and cache values can be anything, type
// assertion is needed when values are being passed to functions that don't
// take arbitrary types, (i.e. interface{}). The simplest way to do this for
// values which will only be used once--e.g. for passing to another
// function--is:
foo, found := c.Get("foo")
if found {
MyFunction(foo.(string))
}
// This gets tedious if the value is used several times in the same function.
// You might do either of the following instead:
if x, found := c.Get("foo"); found {
foo := x.(string)
// ...
}
// or
var foo string
if x, found := c.Get("foo"); found {
foo = x.(string)
}
// ...
// foo can then be passed around freely as a string
// Want performance? Store pointers!
c.Set("foo", &MyStruct, cache.DefaultExpiration)
if x, found := c.Get("foo"); found {
foo := x.(*MyStruct)
// ...
}
}
```
### Reference
`godoc` or [http://godoc.org/github.com/patrickmn/go-cache](http://godoc.org/github.com/patrickmn/go-cache)

1161
vendor/github.com/patrickmn/go-cache/cache.go generated vendored Normal file

File diff suppressed because it is too large Load diff

192
vendor/github.com/patrickmn/go-cache/sharded.go generated vendored Normal file
View file

@ -0,0 +1,192 @@
package cache
import (
"crypto/rand"
"math"
"math/big"
insecurerand "math/rand"
"os"
"runtime"
"time"
)
// This is an experimental and unexported (for now) attempt at making a cache
// with better algorithmic complexity than the standard one, namely by
// preventing write locks of the entire cache when an item is added. As of the
// time of writing, the overhead of selecting buckets results in cache
// operations being about twice as slow as for the standard cache with small
// total cache sizes, and faster for larger ones.
//
// See cache_test.go for a few benchmarks.
type unexportedShardedCache struct {
*shardedCache
}
type shardedCache struct {
seed uint32
m uint32
cs []*cache
janitor *shardedJanitor
}
// djb2 with better shuffling. 5x faster than FNV with the hash.Hash overhead.
func djb33(seed uint32, k string) uint32 {
var (
l = uint32(len(k))
d = 5381 + seed + l
i = uint32(0)
)
// Why is all this 5x faster than a for loop?
if l >= 4 {
for i < l-4 {
d = (d * 33) ^ uint32(k[i])
d = (d * 33) ^ uint32(k[i+1])
d = (d * 33) ^ uint32(k[i+2])
d = (d * 33) ^ uint32(k[i+3])
i += 4
}
}
switch l - i {
case 1:
case 2:
d = (d * 33) ^ uint32(k[i])
case 3:
d = (d * 33) ^ uint32(k[i])
d = (d * 33) ^ uint32(k[i+1])
case 4:
d = (d * 33) ^ uint32(k[i])
d = (d * 33) ^ uint32(k[i+1])
d = (d * 33) ^ uint32(k[i+2])
}
return d ^ (d >> 16)
}
func (sc *shardedCache) bucket(k string) *cache {
return sc.cs[djb33(sc.seed, k)%sc.m]
}
func (sc *shardedCache) Set(k string, x interface{}, d time.Duration) {
sc.bucket(k).Set(k, x, d)
}
func (sc *shardedCache) Add(k string, x interface{}, d time.Duration) error {
return sc.bucket(k).Add(k, x, d)
}
func (sc *shardedCache) Replace(k string, x interface{}, d time.Duration) error {
return sc.bucket(k).Replace(k, x, d)
}
func (sc *shardedCache) Get(k string) (interface{}, bool) {
return sc.bucket(k).Get(k)
}
func (sc *shardedCache) Increment(k string, n int64) error {
return sc.bucket(k).Increment(k, n)
}
func (sc *shardedCache) IncrementFloat(k string, n float64) error {
return sc.bucket(k).IncrementFloat(k, n)
}
func (sc *shardedCache) Decrement(k string, n int64) error {
return sc.bucket(k).Decrement(k, n)
}
func (sc *shardedCache) Delete(k string) {
sc.bucket(k).Delete(k)
}
func (sc *shardedCache) DeleteExpired() {
for _, v := range sc.cs {
v.DeleteExpired()
}
}
// Returns the items in the cache. This may include items that have expired,
// but have not yet been cleaned up. If this is significant, the Expiration
// fields of the items should be checked. Note that explicit synchronization
// is needed to use a cache and its corresponding Items() return values at
// the same time, as the maps are shared.
func (sc *shardedCache) Items() []map[string]Item {
res := make([]map[string]Item, len(sc.cs))
for i, v := range sc.cs {
res[i] = v.Items()
}
return res
}
func (sc *shardedCache) Flush() {
for _, v := range sc.cs {
v.Flush()
}
}
type shardedJanitor struct {
Interval time.Duration
stop chan bool
}
func (j *shardedJanitor) Run(sc *shardedCache) {
j.stop = make(chan bool)
tick := time.Tick(j.Interval)
for {
select {
case <-tick:
sc.DeleteExpired()
case <-j.stop:
return
}
}
}
func stopShardedJanitor(sc *unexportedShardedCache) {
sc.janitor.stop <- true
}
func runShardedJanitor(sc *shardedCache, ci time.Duration) {
j := &shardedJanitor{
Interval: ci,
}
sc.janitor = j
go j.Run(sc)
}
func newShardedCache(n int, de time.Duration) *shardedCache {
max := big.NewInt(0).SetUint64(uint64(math.MaxUint32))
rnd, err := rand.Int(rand.Reader, max)
var seed uint32
if err != nil {
os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent.) Your system's security may be compromised. Continuing with an insecure seed.\n"))
seed = insecurerand.Uint32()
} else {
seed = uint32(rnd.Uint64())
}
sc := &shardedCache{
seed: seed,
m: uint32(n),
cs: make([]*cache, n),
}
for i := 0; i < n; i++ {
c := &cache{
defaultExpiration: de,
items: map[string]Item{},
}
sc.cs[i] = c
}
return sc
}
func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int) *unexportedShardedCache {
if defaultExpiration == 0 {
defaultExpiration = -1
}
sc := newShardedCache(shards, defaultExpiration)
SC := &unexportedShardedCache{sc}
if cleanupInterval > 0 {
runShardedJanitor(sc, cleanupInterval)
runtime.SetFinalizer(SC, stopShardedJanitor)
}
return SC
}

10
vendor/github.com/pquerna/cachecontrol/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,10 @@
language: go
install:
- go get -d -v ./...
- go get -u github.com/stretchr/testify/require
go:
- 1.7
- 1.8
- tip

202
vendor/github.com/pquerna/cachecontrol/LICENSE generated vendored Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

107
vendor/github.com/pquerna/cachecontrol/README.md generated vendored Normal file
View file

@ -0,0 +1,107 @@
# cachecontrol: HTTP Caching Parser and Interpretation
[![GoDoc](https://godoc.org/github.com/pquerna/cachecontrol?status.svg)](https://godoc.org/github.com/pquerna/cachecontrol)[![Build Status](https://travis-ci.org/pquerna/cachecontrol.svg?branch=master)](https://travis-ci.org/pquerna/cachecontrol)
`cachecontrol` implements [RFC 7234](http://tools.ietf.org/html/rfc7234) __Hypertext Transfer Protocol (HTTP/1.1): Caching__. It does this by parsing the `Cache-Control` and other headers, providing information about requests and responses -- but `cachecontrol` does not implement an actual cache backend, just the control plane to make decisions about if a particular response is cachable.
# Usage
`cachecontrol.CachableResponse` returns an array of [reasons](https://godoc.org/github.com/pquerna/cachecontrol/cacheobject#Reason) why a response should not be cached and when it expires. In the case that `len(reasons) == 0`, the response is cachable according to the RFC. However, some people want non-compliant caches for various business use cases, so each reason is specifically named, so if your cache wants to cache `POST` requests, it can easily do that, but still be RFC compliant in other situations.
# Examples
## Can you cache Example.com?
```go
package main
import (
"github.com/pquerna/cachecontrol"
"fmt"
"io/ioutil"
"net/http"
)
func main() {
req, _ := http.NewRequest("GET", "http://www.example.com/", nil)
res, _ := http.DefaultClient.Do(req)
_, _ = ioutil.ReadAll(res.Body)
reasons, expires, _ := cachecontrol.CachableResponse(req, res, cachecontrol.Options{})
fmt.Println("Reasons to not cache: ", reasons)
fmt.Println("Expiration: ", expires.String())
}
```
## Can I use this in a high performance caching server?
`cachecontrol` is divided into two packages: `cachecontrol` with a high level API, and a lower level `cacheobject` package. Use [Object](https://godoc.org/github.com/pquerna/cachecontrol/cacheobject#Object) in a high performance use case where you have previously parsed headers containing dates or would like to avoid memory allocations.
```go
package main
import (
"github.com/pquerna/cachecontrol/cacheobject"
"fmt"
"io/ioutil"
"net/http"
)
func main() {
req, _ := http.NewRequest("GET", "http://www.example.com/", nil)
res, _ := http.DefaultClient.Do(req)
_, _ = ioutil.ReadAll(res.Body)
reqDir, _ := cacheobject.ParseRequestCacheControl(req.Header.Get("Cache-Control"))
resDir, _ := cacheobject.ParseResponseCacheControl(res.Header.Get("Cache-Control"))
expiresHeader, _ := http.ParseTime(res.Header.Get("Expires"))
dateHeader, _ := http.ParseTime(res.Header.Get("Date"))
lastModifiedHeader, _ := http.ParseTime(res.Header.Get("Last-Modified"))
obj := cacheobject.Object{
RespDirectives: resDir,
RespHeaders: res.Header,
RespStatusCode: res.StatusCode,
RespExpiresHeader: expiresHeader,
RespDateHeader: dateHeader,
RespLastModifiedHeader: lastModifiedHeader,
ReqDirectives: reqDir,
ReqHeaders: req.Header,
ReqMethod: req.Method,
NowUTC: time.Now().UTC(),
}
rv := cacheobject.ObjectResults{}
cacheobject.CachableObject(&obj, &rv)
cacheobject.ExpirationObject(&obj, &rv)
fmt.Println("Errors: ", rv.OutErr)
fmt.Println("Reasons to not cache: ", rv.OutReasons)
fmt.Println("Warning headers to add: ", rv.OutWarnings)
fmt.Println("Expiration: ", rv.OutExpirationTime.String())
}
```
## Improvements, bugs, adding features, and taking cachecontrol new directions!
Please [open issues in Github](https://github.com/pquerna/cachecontrol/issues) for ideas, bugs, and general thoughts. Pull requests are of course preferred :)
# Credits
`cachecontrol` has recieved significant contributions from:
* [Paul Querna](https://github.com/pquerna)
## License
`cachecontrol` is licensed under the [Apache License, Version 2.0](./LICENSE)

48
vendor/github.com/pquerna/cachecontrol/api.go generated vendored Normal file
View file

@ -0,0 +1,48 @@
/**
* Copyright 2015 Paul Querna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package cachecontrol
import (
"github.com/pquerna/cachecontrol/cacheobject"
"net/http"
"time"
)
type Options struct {
// Set to True for a prviate cache, which is not shared amoung users (eg, in a browser)
// Set to False for a "shared" cache, which is more common in a server context.
PrivateCache bool
}
// Given an HTTP Request, the future Status Code, and an ResponseWriter,
// determine the possible reasons a response SHOULD NOT be cached.
func CachableResponseWriter(req *http.Request,
statusCode int,
resp http.ResponseWriter,
opts Options) ([]cacheobject.Reason, time.Time, error) {
return cacheobject.UsingRequestResponse(req, statusCode, resp.Header(), opts.PrivateCache)
}
// Given an HTTP Request and Response, determine the possible reasons a response SHOULD NOT
// be cached.
func CachableResponse(req *http.Request,
resp *http.Response,
opts Options) ([]cacheobject.Reason, time.Time, error) {
return cacheobject.UsingRequestResponse(req, resp.StatusCode, resp.Header, opts.PrivateCache)
}

View file

@ -0,0 +1,546 @@
/**
* Copyright 2015 Paul Querna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package cacheobject
import (
"errors"
"math"
"net/http"
"net/textproto"
"strconv"
"strings"
)
// TODO(pquerna): add extensions from here: http://www.iana.org/assignments/http-cache-directives/http-cache-directives.xhtml
var (
ErrQuoteMismatch = errors.New("Missing closing quote")
ErrMaxAgeDeltaSeconds = errors.New("Failed to parse delta-seconds in `max-age`")
ErrSMaxAgeDeltaSeconds = errors.New("Failed to parse delta-seconds in `s-maxage`")
ErrMaxStaleDeltaSeconds = errors.New("Failed to parse delta-seconds in `min-fresh`")
ErrMinFreshDeltaSeconds = errors.New("Failed to parse delta-seconds in `min-fresh`")
ErrNoCacheNoArgs = errors.New("Unexpected argument to `no-cache`")
ErrNoStoreNoArgs = errors.New("Unexpected argument to `no-store`")
ErrNoTransformNoArgs = errors.New("Unexpected argument to `no-transform`")
ErrOnlyIfCachedNoArgs = errors.New("Unexpected argument to `only-if-cached`")
ErrMustRevalidateNoArgs = errors.New("Unexpected argument to `must-revalidate`")
ErrPublicNoArgs = errors.New("Unexpected argument to `public`")
ErrProxyRevalidateNoArgs = errors.New("Unexpected argument to `proxy-revalidate`")
// Experimental
ErrImmutableNoArgs = errors.New("Unexpected argument to `immutable`")
ErrStaleIfErrorDeltaSeconds = errors.New("Failed to parse delta-seconds in `stale-if-error`")
ErrStaleWhileRevalidateDeltaSeconds = errors.New("Failed to parse delta-seconds in `stale-while-revalidate`")
)
func whitespace(b byte) bool {
if b == '\t' || b == ' ' {
return true
}
return false
}
func parse(value string, cd cacheDirective) error {
var err error = nil
i := 0
for i < len(value) && err == nil {
// eat leading whitespace or commas
if whitespace(value[i]) || value[i] == ',' {
i++
continue
}
j := i + 1
for j < len(value) {
if !isToken(value[j]) {
break
}
j++
}
token := strings.ToLower(value[i:j])
tokenHasFields := hasFieldNames(token)
/*
println("GOT TOKEN:")
println(" i -> ", i)
println(" j -> ", j)
println(" token -> ", token)
*/
if j+1 < len(value) && value[j] == '=' {
k := j + 1
// minimum size two bytes of "", but we let httpUnquote handle it.
if k < len(value) && value[k] == '"' {
eaten, result := httpUnquote(value[k:])
if eaten == -1 {
return ErrQuoteMismatch
}
i = k + eaten
err = cd.addPair(token, result)
} else {
z := k
for z < len(value) {
if tokenHasFields {
if whitespace(value[z]) {
break
}
} else {
if whitespace(value[z]) || value[z] == ',' {
break
}
}
z++
}
i = z
result := value[k:z]
if result != "" && result[len(result)-1] == ',' {
result = result[:len(result)-1]
}
err = cd.addPair(token, result)
}
} else {
if token != "," {
err = cd.addToken(token)
}
i = j
}
}
return err
}
// DeltaSeconds specifies a non-negative integer, representing
// time in seconds: http://tools.ietf.org/html/rfc7234#section-1.2.1
//
// When set to -1, this means unset.
//
type DeltaSeconds int32
// Parser for delta-seconds, a uint31, more or less:
// http://tools.ietf.org/html/rfc7234#section-1.2.1
func parseDeltaSeconds(v string) (DeltaSeconds, error) {
n, err := strconv.ParseUint(v, 10, 32)
if err != nil {
if numError, ok := err.(*strconv.NumError); ok {
if numError.Err == strconv.ErrRange {
return DeltaSeconds(math.MaxInt32), nil
}
}
return DeltaSeconds(-1), err
} else {
if n > math.MaxInt32 {
return DeltaSeconds(math.MaxInt32), nil
} else {
return DeltaSeconds(n), nil
}
}
}
// Fields present in a header.
type FieldNames map[string]bool
// internal interface for shared methods of RequestCacheDirectives and ResponseCacheDirectives
type cacheDirective interface {
addToken(s string) error
addPair(s string, v string) error
}
// LOW LEVEL API: Repersentation of possible request directives in a `Cache-Control` header: http://tools.ietf.org/html/rfc7234#section-5.2.1
//
// Note: Many fields will be `nil` in practice.
//
type RequestCacheDirectives struct {
// max-age(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.1.1
//
// The "max-age" request directive indicates that the client is
// unwilling to accept a response whose age is greater than the
// specified number of seconds. Unless the max-stale request directive
// is also present, the client is not willing to accept a stale
// response.
MaxAge DeltaSeconds
// max-stale(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.1.2
//
// The "max-stale" request directive indicates that the client is
// willing to accept a response that has exceeded its freshness
// lifetime. If max-stale is assigned a value, then the client is
// willing to accept a response that has exceeded its freshness lifetime
// by no more than the specified number of seconds. If no value is
// assigned to max-stale, then the client is willing to accept a stale
// response of any age.
MaxStale DeltaSeconds
// min-fresh(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.1.3
//
// The "min-fresh" request directive indicates that the client is
// willing to accept a response whose freshness lifetime is no less than
// its current age plus the specified time in seconds. That is, the
// client wants a response that will still be fresh for at least the
// specified number of seconds.
MinFresh DeltaSeconds
// no-cache(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.4
//
// The "no-cache" request directive indicates that a cache MUST NOT use
// a stored response to satisfy the request without successful
// validation on the origin server.
NoCache bool
// no-store(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.5
//
// The "no-store" request directive indicates that a cache MUST NOT
// store any part of either this request or any response to it. This
// directive applies to both private and shared caches.
NoStore bool
// no-transform(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.6
//
// The "no-transform" request directive indicates that an intermediary
// (whether or not it implements a cache) MUST NOT transform the
// payload, as defined in Section 5.7.2 of RFC7230.
NoTransform bool
// only-if-cached(bool): http://tools.ietf.org/html/rfc7234#section-5.2.1.7
//
// The "only-if-cached" request directive indicates that the client only
// wishes to obtain a stored response.
OnlyIfCached bool
// Extensions: http://tools.ietf.org/html/rfc7234#section-5.2.3
//
// The Cache-Control header field can be extended through the use of one
// or more cache-extension tokens, each with an optional value. A cache
// MUST ignore unrecognized cache directives.
Extensions []string
}
func (cd *RequestCacheDirectives) addToken(token string) error {
var err error = nil
switch token {
case "max-age":
err = ErrMaxAgeDeltaSeconds
case "max-stale":
err = ErrMaxStaleDeltaSeconds
case "min-fresh":
err = ErrMinFreshDeltaSeconds
case "no-cache":
cd.NoCache = true
case "no-store":
cd.NoStore = true
case "no-transform":
cd.NoTransform = true
case "only-if-cached":
cd.OnlyIfCached = true
default:
cd.Extensions = append(cd.Extensions, token)
}
return err
}
func (cd *RequestCacheDirectives) addPair(token string, v string) error {
var err error = nil
switch token {
case "max-age":
cd.MaxAge, err = parseDeltaSeconds(v)
if err != nil {
err = ErrMaxAgeDeltaSeconds
}
case "max-stale":
cd.MaxStale, err = parseDeltaSeconds(v)
if err != nil {
err = ErrMaxStaleDeltaSeconds
}
case "min-fresh":
cd.MinFresh, err = parseDeltaSeconds(v)
if err != nil {
err = ErrMinFreshDeltaSeconds
}
case "no-cache":
err = ErrNoCacheNoArgs
case "no-store":
err = ErrNoStoreNoArgs
case "no-transform":
err = ErrNoTransformNoArgs
case "only-if-cached":
err = ErrOnlyIfCachedNoArgs
default:
// TODO(pquerna): this sucks, making user re-parse
cd.Extensions = append(cd.Extensions, token+"="+v)
}
return err
}
// LOW LEVEL API: Parses a Cache Control Header from a Request into a set of directives.
func ParseRequestCacheControl(value string) (*RequestCacheDirectives, error) {
cd := &RequestCacheDirectives{
MaxAge: -1,
MaxStale: -1,
MinFresh: -1,
}
err := parse(value, cd)
if err != nil {
return nil, err
}
return cd, nil
}
// LOW LEVEL API: Repersentation of possible response directives in a `Cache-Control` header: http://tools.ietf.org/html/rfc7234#section-5.2.2
//
// Note: Many fields will be `nil` in practice.
//
type ResponseCacheDirectives struct {
// must-revalidate(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.1
//
// The "must-revalidate" response directive indicates that once it has
// become stale, a cache MUST NOT use the response to satisfy subsequent
// requests without successful validation on the origin server.
MustRevalidate bool
// no-cache(FieldName): http://tools.ietf.org/html/rfc7234#section-5.2.2.2
//
// The "no-cache" response directive indicates that the response MUST
// NOT be used to satisfy a subsequent request without successful
// validation on the origin server.
//
// If the no-cache response directive specifies one or more field-names,
// then a cache MAY use the response to satisfy a subsequent request,
// subject to any other restrictions on caching. However, any header
// fields in the response that have the field-name(s) listed MUST NOT be
// sent in the response to a subsequent request without successful
// revalidation with the origin server.
NoCache FieldNames
// no-cache(cast-to-bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.2
//
// While the RFC defines optional field-names on a no-cache directive,
// many applications only want to know if any no-cache directives were
// present at all.
NoCachePresent bool
// no-store(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.3
//
// The "no-store" request directive indicates that a cache MUST NOT
// store any part of either this request or any response to it. This
// directive applies to both private and shared caches.
NoStore bool
// no-transform(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.4
//
// The "no-transform" response directive indicates that an intermediary
// (regardless of whether it implements a cache) MUST NOT transform the
// payload, as defined in Section 5.7.2 of RFC7230.
NoTransform bool
// public(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.5
//
// The "public" response directive indicates that any cache MAY store
// the response, even if the response would normally be non-cacheable or
// cacheable only within a private cache.
Public bool
// private(FieldName): http://tools.ietf.org/html/rfc7234#section-5.2.2.6
//
// The "private" response directive indicates that the response message
// is intended for a single user and MUST NOT be stored by a shared
// cache. A private cache MAY store the response and reuse it for later
// requests, even if the response would normally be non-cacheable.
//
// If the private response directive specifies one or more field-names,
// this requirement is limited to the field-values associated with the
// listed response header fields. That is, a shared cache MUST NOT
// store the specified field-names(s), whereas it MAY store the
// remainder of the response message.
Private FieldNames
// private(cast-to-bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.6
//
// While the RFC defines optional field-names on a private directive,
// many applications only want to know if any private directives were
// present at all.
PrivatePresent bool
// proxy-revalidate(bool): http://tools.ietf.org/html/rfc7234#section-5.2.2.7
//
// The "proxy-revalidate" response directive has the same meaning as the
// must-revalidate response directive, except that it does not apply to
// private caches.
ProxyRevalidate bool
// max-age(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.2.8
//
// The "max-age" response directive indicates that the response is to be
// considered stale after its age is greater than the specified number
// of seconds.
MaxAge DeltaSeconds
// s-maxage(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.2.9
//
// The "s-maxage" response directive indicates that, in shared caches,
// the maximum age specified by this directive overrides the maximum age
// specified by either the max-age directive or the Expires header
// field. The s-maxage directive also implies the semantics of the
// proxy-revalidate response directive.
SMaxAge DeltaSeconds
////
// Experimental features
// - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#Extension_Cache-Control_directives
// - https://www.fastly.com/blog/stale-while-revalidate-stale-if-error-available-today
////
// immutable(cast-to-bool): experimental feature
Immutable bool
// stale-if-error(delta seconds): experimental feature
StaleIfError DeltaSeconds
// stale-while-revalidate(delta seconds): experimental feature
StaleWhileRevalidate DeltaSeconds
// Extensions: http://tools.ietf.org/html/rfc7234#section-5.2.3
//
// The Cache-Control header field can be extended through the use of one
// or more cache-extension tokens, each with an optional value. A cache
// MUST ignore unrecognized cache directives.
Extensions []string
}
// LOW LEVEL API: Parses a Cache Control Header from a Response into a set of directives.
func ParseResponseCacheControl(value string) (*ResponseCacheDirectives, error) {
cd := &ResponseCacheDirectives{
MaxAge: -1,
SMaxAge: -1,
// Exerimantal stale timeouts
StaleIfError: -1,
StaleWhileRevalidate: -1,
}
err := parse(value, cd)
if err != nil {
return nil, err
}
return cd, nil
}
func (cd *ResponseCacheDirectives) addToken(token string) error {
var err error = nil
switch token {
case "must-revalidate":
cd.MustRevalidate = true
case "no-cache":
cd.NoCachePresent = true
case "no-store":
cd.NoStore = true
case "no-transform":
cd.NoTransform = true
case "public":
cd.Public = true
case "private":
cd.PrivatePresent = true
case "proxy-revalidate":
cd.ProxyRevalidate = true
case "max-age":
err = ErrMaxAgeDeltaSeconds
case "s-maxage":
err = ErrSMaxAgeDeltaSeconds
// Experimental
case "immutable":
cd.Immutable = true
case "stale-if-error":
err = ErrMaxAgeDeltaSeconds
case "stale-while-revalidate":
err = ErrMaxAgeDeltaSeconds
default:
cd.Extensions = append(cd.Extensions, token)
}
return err
}
func hasFieldNames(token string) bool {
switch token {
case "no-cache":
return true
case "private":
return true
}
return false
}
func (cd *ResponseCacheDirectives) addPair(token string, v string) error {
var err error = nil
switch token {
case "must-revalidate":
err = ErrMustRevalidateNoArgs
case "no-cache":
cd.NoCachePresent = true
tokens := strings.Split(v, ",")
if cd.NoCache == nil {
cd.NoCache = make(FieldNames)
}
for _, t := range tokens {
k := http.CanonicalHeaderKey(textproto.TrimString(t))
cd.NoCache[k] = true
}
case "no-store":
err = ErrNoStoreNoArgs
case "no-transform":
err = ErrNoTransformNoArgs
case "public":
err = ErrPublicNoArgs
case "private":
cd.PrivatePresent = true
tokens := strings.Split(v, ",")
if cd.Private == nil {
cd.Private = make(FieldNames)
}
for _, t := range tokens {
k := http.CanonicalHeaderKey(textproto.TrimString(t))
cd.Private[k] = true
}
case "proxy-revalidate":
err = ErrProxyRevalidateNoArgs
case "max-age":
cd.MaxAge, err = parseDeltaSeconds(v)
case "s-maxage":
cd.SMaxAge, err = parseDeltaSeconds(v)
// Experimental
case "immutable":
err = ErrImmutableNoArgs
case "stale-if-error":
cd.StaleIfError, err = parseDeltaSeconds(v)
case "stale-while-revalidate":
cd.StaleWhileRevalidate, err = parseDeltaSeconds(v)
default:
// TODO(pquerna): this sucks, making user re-parse, and its technically not 'quoted' like the original,
// but this is still easier, just a SplitN on "="
cd.Extensions = append(cd.Extensions, token+"="+v)
}
return err
}

View file

@ -0,0 +1,93 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cacheobject
// This file deals with lexical matters of HTTP
func isSeparator(c byte) bool {
switch c {
case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t':
return true
}
return false
}
func isCtl(c byte) bool { return (0 <= c && c <= 31) || c == 127 }
func isChar(c byte) bool { return 0 <= c && c <= 127 }
func isAnyText(c byte) bool { return !isCtl(c) }
func isQdText(c byte) bool { return isAnyText(c) && c != '"' }
func isToken(c byte) bool { return isChar(c) && !isCtl(c) && !isSeparator(c) }
// Valid escaped sequences are not specified in RFC 2616, so for now, we assume
// that they coincide with the common sense ones used by GO. Malformed
// characters should probably not be treated as errors by a robust (forgiving)
// parser, so we replace them with the '?' character.
func httpUnquotePair(b byte) byte {
// skip the first byte, which should always be '\'
switch b {
case 'a':
return '\a'
case 'b':
return '\b'
case 'f':
return '\f'
case 'n':
return '\n'
case 'r':
return '\r'
case 't':
return '\t'
case 'v':
return '\v'
case '\\':
return '\\'
case '\'':
return '\''
case '"':
return '"'
}
return '?'
}
// raw must begin with a valid quoted string. Only the first quoted string is
// parsed and is unquoted in result. eaten is the number of bytes parsed, or -1
// upon failure.
func httpUnquote(raw string) (eaten int, result string) {
buf := make([]byte, len(raw))
if raw[0] != '"' {
return -1, ""
}
eaten = 1
j := 0 // # of bytes written in buf
for i := 1; i < len(raw); i++ {
switch b := raw[i]; b {
case '"':
eaten++
buf = buf[0:j]
return i + 1, string(buf)
case '\\':
if len(raw) < i+2 {
return -1, ""
}
buf[j] = httpUnquotePair(raw[i+1])
eaten += 2
j++
i++
default:
if isQdText(b) {
buf[j] = b
} else {
buf[j] = '?'
}
eaten++
j++
}
}
return -1, ""
}

View file

@ -0,0 +1,387 @@
/**
* Copyright 2015 Paul Querna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package cacheobject
import (
"net/http"
"time"
)
// LOW LEVEL API: Repersents a potentially cachable HTTP object.
//
// This struct is designed to be serialized efficiently, so in a high
// performance caching server, things like Date-Strings don't need to be
// parsed for every use of a cached object.
type Object struct {
CacheIsPrivate bool
RespDirectives *ResponseCacheDirectives
RespHeaders http.Header
RespStatusCode int
RespExpiresHeader time.Time
RespDateHeader time.Time
RespLastModifiedHeader time.Time
ReqDirectives *RequestCacheDirectives
ReqHeaders http.Header
ReqMethod string
NowUTC time.Time
}
// LOW LEVEL API: Repersents the results of examinig an Object with
// CachableObject and ExpirationObject.
//
// TODO(pquerna): decide if this is a good idea or bad
type ObjectResults struct {
OutReasons []Reason
OutWarnings []Warning
OutExpirationTime time.Time
OutErr error
}
// LOW LEVEL API: Check if a object is cachable.
func CachableObject(obj *Object, rv *ObjectResults) {
rv.OutReasons = nil
rv.OutWarnings = nil
rv.OutErr = nil
switch obj.ReqMethod {
case "GET":
break
case "HEAD":
break
case "POST":
/**
POST: http://tools.ietf.org/html/rfc7231#section-4.3.3
Responses to POST requests are only cacheable when they include
explicit freshness information (see Section 4.2.1 of [RFC7234]).
However, POST caching is not widely implemented. For cases where an
origin server wishes the client to be able to cache the result of a
POST in a way that can be reused by a later GET, the origin server
MAY send a 200 (OK) response containing the result and a
Content-Location header field that has the same value as the POST's
effective request URI (Section 3.1.4.2).
*/
if !hasFreshness(obj.ReqDirectives, obj.RespDirectives, obj.RespHeaders, obj.RespExpiresHeader, obj.CacheIsPrivate) {
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodPOST)
}
case "PUT":
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodPUT)
case "DELETE":
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodDELETE)
case "CONNECT":
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodCONNECT)
case "OPTIONS":
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodOPTIONS)
case "TRACE":
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodTRACE)
// HTTP Extension Methods: http://www.iana.org/assignments/http-methods/http-methods.xhtml
//
// To my knowledge, none of them are cachable. Please open a ticket if this is not the case!
//
default:
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodUnkown)
}
if obj.ReqDirectives.NoStore {
rv.OutReasons = append(rv.OutReasons, ReasonRequestNoStore)
}
// Storing Responses to Authenticated Requests: http://tools.ietf.org/html/rfc7234#section-3.2
authz := obj.ReqHeaders.Get("Authorization")
if authz != "" {
if obj.RespDirectives.MustRevalidate ||
obj.RespDirectives.Public ||
obj.RespDirectives.SMaxAge != -1 {
// Expires of some kind present, this is potentially OK.
} else {
rv.OutReasons = append(rv.OutReasons, ReasonRequestAuthorizationHeader)
}
}
if obj.RespDirectives.PrivatePresent && !obj.CacheIsPrivate {
rv.OutReasons = append(rv.OutReasons, ReasonResponsePrivate)
}
if obj.RespDirectives.NoStore {
rv.OutReasons = append(rv.OutReasons, ReasonResponseNoStore)
}
/*
the response either:
* contains an Expires header field (see Section 5.3), or
* contains a max-age response directive (see Section 5.2.2.8), or
* contains a s-maxage response directive (see Section 5.2.2.9)
and the cache is shared, or
* contains a Cache Control Extension (see Section 5.2.3) that
allows it to be cached, or
* has a status code that is defined as cacheable by default (see
Section 4.2.2), or
* contains a public response directive (see Section 5.2.2.5).
*/
expires := obj.RespHeaders.Get("Expires") != ""
statusCachable := cachableStatusCode(obj.RespStatusCode)
if expires ||
obj.RespDirectives.MaxAge != -1 ||
(obj.RespDirectives.SMaxAge != -1 && !obj.CacheIsPrivate) ||
statusCachable ||
obj.RespDirectives.Public {
/* cachable by default, at least one of the above conditions was true */
} else {
rv.OutReasons = append(rv.OutReasons, ReasonResponseUncachableByDefault)
}
}
var twentyFourHours = time.Duration(24 * time.Hour)
const debug = false
// LOW LEVEL API: Update an objects expiration time.
func ExpirationObject(obj *Object, rv *ObjectResults) {
/**
* Okay, lets calculate Freshness/Expiration now. woo:
* http://tools.ietf.org/html/rfc7234#section-4.2
*/
/*
o If the cache is shared and the s-maxage response directive
(Section 5.2.2.9) is present, use its value, or
o If the max-age response directive (Section 5.2.2.8) is present,
use its value, or
o If the Expires response header field (Section 5.3) is present, use
its value minus the value of the Date response header field, or
o Otherwise, no explicit expiration time is present in the response.
A heuristic freshness lifetime might be applicable; see
Section 4.2.2.
*/
var expiresTime time.Time
if obj.RespDirectives.SMaxAge != -1 && !obj.CacheIsPrivate {
expiresTime = obj.NowUTC.Add(time.Second * time.Duration(obj.RespDirectives.SMaxAge))
} else if obj.RespDirectives.MaxAge != -1 {
expiresTime = obj.NowUTC.UTC().Add(time.Second * time.Duration(obj.RespDirectives.MaxAge))
} else if !obj.RespExpiresHeader.IsZero() {
serverDate := obj.RespDateHeader
if serverDate.IsZero() {
// common enough case when a Date: header has not yet been added to an
// active response.
serverDate = obj.NowUTC
}
expiresTime = obj.NowUTC.Add(obj.RespExpiresHeader.Sub(serverDate))
} else if !obj.RespLastModifiedHeader.IsZero() {
// heuristic freshness lifetime
rv.OutWarnings = append(rv.OutWarnings, WarningHeuristicExpiration)
// http://httpd.apache.org/docs/2.4/mod/mod_cache.html#cachelastmodifiedfactor
// CacheMaxExpire defaults to 24 hours
// CacheLastModifiedFactor: is 0.1
//
// expiry-period = MIN(time-since-last-modified-date * factor, 24 hours)
//
// obj.NowUTC
since := obj.RespLastModifiedHeader.Sub(obj.NowUTC)
since = time.Duration(float64(since) * -0.1)
if since > twentyFourHours {
expiresTime = obj.NowUTC.Add(twentyFourHours)
} else {
expiresTime = obj.NowUTC.Add(since)
}
if debug {
println("Now UTC: ", obj.NowUTC.String())
println("Last-Modified: ", obj.RespLastModifiedHeader.String())
println("Since: ", since.String())
println("TwentyFourHours: ", twentyFourHours.String())
println("Expiration: ", expiresTime.String())
}
} else {
// TODO(pquerna): what should the default behavoir be for expiration time?
}
rv.OutExpirationTime = expiresTime
}
// Evaluate cachability based on an HTTP request, and parts of the response.
func UsingRequestResponse(req *http.Request,
statusCode int,
respHeaders http.Header,
privateCache bool) ([]Reason, time.Time, error) {
reasons, time, _, _, err := UsingRequestResponseWithObject(req, statusCode, respHeaders, privateCache)
return reasons, time, err
}
// Evaluate cachability based on an HTTP request, and parts of the response.
// Returns the parsed Object as well.
func UsingRequestResponseWithObject(req *http.Request,
statusCode int,
respHeaders http.Header,
privateCache bool) ([]Reason, time.Time, []Warning, *Object, error) {
var reqHeaders http.Header
var reqMethod string
var reqDir *RequestCacheDirectives = nil
respDir, err := ParseResponseCacheControl(respHeaders.Get("Cache-Control"))
if err != nil {
return nil, time.Time{}, nil, nil, err
}
if req != nil {
reqDir, err = ParseRequestCacheControl(req.Header.Get("Cache-Control"))
if err != nil {
return nil, time.Time{}, nil, nil, err
}
reqHeaders = req.Header
reqMethod = req.Method
}
var expiresHeader time.Time
var dateHeader time.Time
var lastModifiedHeader time.Time
if respHeaders.Get("Expires") != "" {
expiresHeader, err = http.ParseTime(respHeaders.Get("Expires"))
if err != nil {
// sometimes servers will return `Expires: 0` or `Expires: -1` to
// indicate expired content
expiresHeader = time.Time{}
}
expiresHeader = expiresHeader.UTC()
}
if respHeaders.Get("Date") != "" {
dateHeader, err = http.ParseTime(respHeaders.Get("Date"))
if err != nil {
return nil, time.Time{}, nil, nil, err
}
dateHeader = dateHeader.UTC()
}
if respHeaders.Get("Last-Modified") != "" {
lastModifiedHeader, err = http.ParseTime(respHeaders.Get("Last-Modified"))
if err != nil {
return nil, time.Time{}, nil, nil, err
}
lastModifiedHeader = lastModifiedHeader.UTC()
}
obj := Object{
CacheIsPrivate: privateCache,
RespDirectives: respDir,
RespHeaders: respHeaders,
RespStatusCode: statusCode,
RespExpiresHeader: expiresHeader,
RespDateHeader: dateHeader,
RespLastModifiedHeader: lastModifiedHeader,
ReqDirectives: reqDir,
ReqHeaders: reqHeaders,
ReqMethod: reqMethod,
NowUTC: time.Now().UTC(),
}
rv := ObjectResults{}
CachableObject(&obj, &rv)
if rv.OutErr != nil {
return nil, time.Time{}, nil, nil, rv.OutErr
}
ExpirationObject(&obj, &rv)
if rv.OutErr != nil {
return nil, time.Time{}, nil, nil, rv.OutErr
}
return rv.OutReasons, rv.OutExpirationTime, rv.OutWarnings, &obj, nil
}
// calculate if a freshness directive is present: http://tools.ietf.org/html/rfc7234#section-4.2.1
func hasFreshness(reqDir *RequestCacheDirectives, respDir *ResponseCacheDirectives, respHeaders http.Header, respExpires time.Time, privateCache bool) bool {
if !privateCache && respDir.SMaxAge != -1 {
return true
}
if respDir.MaxAge != -1 {
return true
}
if !respExpires.IsZero() || respHeaders.Get("Expires") != "" {
return true
}
return false
}
func cachableStatusCode(statusCode int) bool {
/*
Responses with status codes that are defined as cacheable by default
(e.g., 200, 203, 204, 206, 300, 301, 404, 405, 410, 414, and 501 in
this specification) can be reused by a cache with heuristic
expiration unless otherwise indicated by the method definition or
explicit cache controls [RFC7234]; all other status codes are not
cacheable by default.
*/
switch statusCode {
case 200:
return true
case 203:
return true
case 204:
return true
case 206:
return true
case 300:
return true
case 301:
return true
case 404:
return true
case 405:
return true
case 410:
return true
case 414:
return true
case 501:
return true
default:
return false
}
}

View file

@ -0,0 +1,95 @@
/**
* Copyright 2015 Paul Querna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package cacheobject
// Repersents a potential Reason to not cache an object.
//
// Applications may wish to ignore specific reasons, which will make them non-RFC
// compliant, but this type gives them specific cases they can choose to ignore,
// making them compliant in as many cases as they can.
type Reason int
const (
// The request method was POST and an Expiration header was not supplied.
ReasonRequestMethodPOST Reason = iota
// The request method was PUT and PUTs are not cachable.
ReasonRequestMethodPUT
// The request method was DELETE and DELETEs are not cachable.
ReasonRequestMethodDELETE
// The request method was CONNECT and CONNECTs are not cachable.
ReasonRequestMethodCONNECT
// The request method was OPTIONS and OPTIONS are not cachable.
ReasonRequestMethodOPTIONS
// The request method was TRACE and TRACEs are not cachable.
ReasonRequestMethodTRACE
// The request method was not recognized by cachecontrol, and should not be cached.
ReasonRequestMethodUnkown
// The request included an Cache-Control: no-store header
ReasonRequestNoStore
// The request included an Authorization header without an explicit Public or Expiration time: http://tools.ietf.org/html/rfc7234#section-3.2
ReasonRequestAuthorizationHeader
// The response included an Cache-Control: no-store header
ReasonResponseNoStore
// The response included an Cache-Control: private header and this is not a Private cache
ReasonResponsePrivate
// The response failed to meet at least one of the conditions specified in RFC 7234 section 3: http://tools.ietf.org/html/rfc7234#section-3
ReasonResponseUncachableByDefault
)
func (r Reason) String() string {
switch r {
case ReasonRequestMethodPOST:
return "ReasonRequestMethodPOST"
case ReasonRequestMethodPUT:
return "ReasonRequestMethodPUT"
case ReasonRequestMethodDELETE:
return "ReasonRequestMethodDELETE"
case ReasonRequestMethodCONNECT:
return "ReasonRequestMethodCONNECT"
case ReasonRequestMethodOPTIONS:
return "ReasonRequestMethodOPTIONS"
case ReasonRequestMethodTRACE:
return "ReasonRequestMethodTRACE"
case ReasonRequestMethodUnkown:
return "ReasonRequestMethodUnkown"
case ReasonRequestNoStore:
return "ReasonRequestNoStore"
case ReasonRequestAuthorizationHeader:
return "ReasonRequestAuthorizationHeader"
case ReasonResponseNoStore:
return "ReasonResponseNoStore"
case ReasonResponsePrivate:
return "ReasonResponsePrivate"
case ReasonResponseUncachableByDefault:
return "ReasonResponseUncachableByDefault"
}
panic(r)
}

View file

@ -0,0 +1,107 @@
/**
* Copyright 2015 Paul Querna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package cacheobject
import (
"fmt"
"net/http"
"time"
)
// Repersents an HTTP Warning: http://tools.ietf.org/html/rfc7234#section-5.5
type Warning int
const (
// Response is Stale
// A cache SHOULD generate this whenever the sent response is stale.
WarningResponseIsStale Warning = 110
// Revalidation Failed
// A cache SHOULD generate this when sending a stale
// response because an attempt to validate the response failed, due to an
// inability to reach the server.
WarningRevalidationFailed Warning = 111
// Disconnected Operation
// A cache SHOULD generate this if it is intentionally disconnected from
// the rest of the network for a period of time.
WarningDisconnectedOperation Warning = 112
// Heuristic Expiration
//
// A cache SHOULD generate this if it heuristically chose a freshness
// lifetime greater than 24 hours and the response's age is greater than
// 24 hours.
WarningHeuristicExpiration Warning = 113
// Miscellaneous Warning
//
// The warning text can include arbitrary information to be presented to
// a human user or logged. A system receiving this warning MUST NOT
// take any automated action, besides presenting the warning to the
// user.
WarningMiscellaneousWarning Warning = 199
// Transformation Applied
//
// This Warning code MUST be added by a proxy if it applies any
// transformation to the representation, such as changing the
// content-coding, media-type, or modifying the representation data,
// unless this Warning code already appears in the response.
WarningTransformationApplied Warning = 214
// Miscellaneous Persistent Warning
//
// The warning text can include arbitrary information to be presented to
// a human user or logged. A system receiving this warning MUST NOT
// take any automated action.
WarningMiscellaneousPersistentWarning Warning = 299
)
func (w Warning) HeaderString(agent string, date time.Time) string {
if agent == "" {
agent = "-"
} else {
// TODO(pquerna): this doesn't escape agent if it contains bad things.
agent = `"` + agent + `"`
}
return fmt.Sprintf(`%d %s "%s" %s`, w, agent, w.String(), date.Format(http.TimeFormat))
}
func (w Warning) String() string {
switch w {
case WarningResponseIsStale:
return "Response is Stale"
case WarningRevalidationFailed:
return "Revalidation Failed"
case WarningDisconnectedOperation:
return "Disconnected Operation"
case WarningHeuristicExpiration:
return "Heuristic Expiration"
case WarningMiscellaneousWarning:
// TODO(pquerna): ideally had a better way to override this one code.
return "Miscellaneous Warning"
case WarningTransformationApplied:
return "Transformation Applied"
case WarningMiscellaneousPersistentWarning:
// TODO(pquerna): same as WarningMiscellaneousWarning
return "Miscellaneous Persistent Warning"
}
panic(w)
}

25
vendor/github.com/pquerna/cachecontrol/doc.go generated vendored Normal file
View file

@ -0,0 +1,25 @@
/**
* Copyright 2015 Paul Querna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package cachecontrol implements the logic for HTTP Caching
//
// Deciding if an HTTP Response can be cached is often harder
// and more bug prone than an actual cache storage backend.
// cachecontrol provides a simple interface to determine if
// request and response pairs are cachable as defined under
// RFC 7234 http://tools.ietf.org/html/rfc7234
package cachecontrol

View file

@ -8,11 +8,9 @@ matrix:
- go: tip
go:
- '1.7.x'
- '1.8.x'
- '1.9.x'
- '1.10.x'
- '1.11.x'
- '1.12.x'
- tip
go_import_path: gopkg.in/square/go-jose.v2

View file

@ -29,7 +29,7 @@ import (
"math/big"
"golang.org/x/crypto/ed25519"
"gopkg.in/square/go-jose.v2/cipher"
josecipher "gopkg.in/square/go-jose.v2/cipher"
"gopkg.in/square/go-jose.v2/json"
)
@ -288,7 +288,7 @@ func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm
out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed)
case PS256, PS384, PS512:
out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
SaltLength: rsa.PSSSaltLengthEqualsHash,
})
}

View file

@ -17,8 +17,10 @@
package josecipher
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"encoding/binary"
)
@ -44,16 +46,38 @@ func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, p
panic("public key not on same curve as private key")
}
z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
z, _ := priv.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
zBytes := z.Bytes()
// Note that calling z.Bytes() on a big.Int may strip leading zero bytes from
// the returned byte array. This can lead to a problem where zBytes will be
// shorter than expected which breaks the key derivation. Therefore we must pad
// to the full length of the expected coordinate here before calling the KDF.
octSize := dSize(priv.Curve)
if len(zBytes) != octSize {
zBytes = append(bytes.Repeat([]byte{0}, octSize-len(zBytes)), zBytes...)
}
reader := NewConcatKDF(crypto.SHA256, zBytes, algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
key := make([]byte, size)
// Read on the KDF will never fail
_, _ = reader.Read(key)
return key
}
// dSize returns the size in octets for a coordinate on a elliptic curve.
func dSize(curve elliptic.Curve) int {
order := curve.Params().P
bitLen := order.BitLen()
size := bitLen / 8
if bitLen%8 != 0 {
size++
}
return size
}
func lengthPrefixed(data []byte) []byte {
out := make([]byte, len(data)+4)
binary.BigEndian.PutUint32(out, uint32(len(data)))

View file

@ -141,6 +141,8 @@ func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions)
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
case *JSONWebKey:
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
case OpaqueKeyEncrypter:
keyID, rawKey = encryptionKey.KeyID(), encryptionKey
default:
rawKey = encryptionKey
}
@ -267,9 +269,11 @@ func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKey
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
recipient.keyID = encryptionKey.KeyID
return recipient, err
default:
return recipientKeyInfo{}, ErrUnsupportedKeyType
}
if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
return newOpaqueKeyEncrypter(alg, encrypter)
}
return recipientKeyInfo{}, ErrUnsupportedKeyType
}
// newDecrypter creates an appropriate decrypter based on the key type
@ -295,9 +299,11 @@ func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
return newDecrypter(decryptionKey.Key)
case *JSONWebKey:
return newDecrypter(decryptionKey.Key)
default:
return nil, ErrUnsupportedKeyType
}
if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
return &opaqueKeyDecrypter{decrypter: okd}, nil
}
return nil, ErrUnsupportedKeyType
}
// Implementation of encrypt method producing a JWE object.

View file

@ -23,13 +23,12 @@ import (
"encoding/binary"
"io"
"math/big"
"regexp"
"strings"
"unicode"
"gopkg.in/square/go-jose.v2/json"
)
var stripWhitespaceRegex = regexp.MustCompile("\\s")
// Helper function to serialize known-good objects.
// Precondition: value is not a nil pointer.
func mustSerializeJSON(value interface{}) []byte {
@ -56,7 +55,14 @@ func mustSerializeJSON(value interface{}) []byte {
// Strip all newlines and whitespace
func stripWhitespace(data string) string {
return stripWhitespaceRegex.ReplaceAllString(data, "")
buf := strings.Builder{}
buf.Grow(len(data))
for _, r := range data {
if !unicode.IsSpace(r) {
buf.WriteRune(r)
}
}
return buf.String()
}
// Perform compression based on algorithm

View file

@ -357,11 +357,11 @@ func (key rawJSONWebKey) ecPublicKey() (*ecdsa.PublicKey, error) {
// the curve specified in the "crv" parameter.
// https://tools.ietf.org/html/rfc7518#section-6.2.1.2
if curveSize(curve) != len(key.X.data) {
return nil, fmt.Errorf("square/go-jose: invalid EC private key, wrong length for x")
return nil, fmt.Errorf("square/go-jose: invalid EC public key, wrong length for x")
}
if curveSize(curve) != len(key.Y.data) {
return nil, fmt.Errorf("square/go-jose: invalid EC private key, wrong length for y")
return nil, fmt.Errorf("square/go-jose: invalid EC public key, wrong length for y")
}
x := key.X.bigInt()

View file

@ -17,6 +17,7 @@
package jose
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
@ -75,13 +76,21 @@ type Signature struct {
}
// ParseSigned parses a signed message in compact or full serialization format.
func ParseSigned(input string) (*JSONWebSignature, error) {
input = stripWhitespace(input)
if strings.HasPrefix(input, "{") {
return parseSignedFull(input)
func ParseSigned(signature string) (*JSONWebSignature, error) {
signature = stripWhitespace(signature)
if strings.HasPrefix(signature, "{") {
return parseSignedFull(signature)
}
return parseSignedCompact(input)
return parseSignedCompact(signature, nil)
}
// ParseDetached parses a signed message in compact serialization format with detached payload.
func ParseDetached(signature string, payload []byte) (*JSONWebSignature, error) {
if payload == nil {
return nil, errors.New("square/go-jose: nil payload")
}
return parseSignedCompact(stripWhitespace(signature), payload)
}
// Get a header value
@ -93,20 +102,39 @@ func (sig Signature) mergedHeaders() rawHeader {
}
// Compute data to be signed
func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) []byte {
var serializedProtected string
func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) ([]byte, error) {
var authData bytes.Buffer
protectedHeader := new(rawHeader)
if signature.original != nil && signature.original.Protected != nil {
serializedProtected = signature.original.Protected.base64()
if err := json.Unmarshal(signature.original.Protected.bytes(), protectedHeader); err != nil {
return nil, err
}
authData.WriteString(signature.original.Protected.base64())
} else if signature.protected != nil {
serializedProtected = base64.RawURLEncoding.EncodeToString(mustSerializeJSON(signature.protected))
} else {
serializedProtected = ""
protectedHeader = signature.protected
authData.WriteString(base64.RawURLEncoding.EncodeToString(mustSerializeJSON(protectedHeader)))
}
return []byte(fmt.Sprintf("%s.%s",
serializedProtected,
base64.RawURLEncoding.EncodeToString(payload)))
needsBase64 := true
if protectedHeader != nil {
var err error
if needsBase64, err = protectedHeader.getB64(); err != nil {
needsBase64 = true
}
}
authData.WriteByte('.')
if needsBase64 {
authData.WriteString(base64.RawURLEncoding.EncodeToString(payload))
} else {
authData.Write(payload)
}
return authData.Bytes(), nil
}
// parseSignedFull parses a message in full format.
@ -246,21 +274,27 @@ func (parsed *rawJSONWebSignature) sanitized() (*JSONWebSignature, error) {
}
// parseSignedCompact parses a message in compact format.
func parseSignedCompact(input string) (*JSONWebSignature, error) {
func parseSignedCompact(input string, payload []byte) (*JSONWebSignature, error) {
parts := strings.Split(input, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
}
if parts[1] != "" && payload != nil {
return nil, fmt.Errorf("square/go-jose: payload is not detached")
}
rawProtected, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, err
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if payload == nil {
payload, err = base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}
}
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
@ -275,19 +309,30 @@ func parseSignedCompact(input string) (*JSONWebSignature, error) {
return raw.sanitized()
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JSONWebSignature) CompactSerialize() (string, error) {
func (obj JSONWebSignature) compactSerialize(detached bool) (string, error) {
if len(obj.Signatures) != 1 || obj.Signatures[0].header != nil || obj.Signatures[0].protected == nil {
return "", ErrNotSupported
}
serializedProtected := mustSerializeJSON(obj.Signatures[0].protected)
serializedProtected := base64.RawURLEncoding.EncodeToString(mustSerializeJSON(obj.Signatures[0].protected))
payload := ""
signature := base64.RawURLEncoding.EncodeToString(obj.Signatures[0].Signature)
return fmt.Sprintf(
"%s.%s.%s",
base64.RawURLEncoding.EncodeToString(serializedProtected),
base64.RawURLEncoding.EncodeToString(obj.payload),
base64.RawURLEncoding.EncodeToString(obj.Signatures[0].Signature)), nil
if !detached {
payload = base64.RawURLEncoding.EncodeToString(obj.payload)
}
return fmt.Sprintf("%s.%s.%s", serializedProtected, payload, signature), nil
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JSONWebSignature) CompactSerialize() (string, error) {
return obj.compactSerialize(false)
}
// DetachedCompactSerialize serializes an object using the compact serialization format with detached payload.
func (obj JSONWebSignature) DetachedCompactSerialize() (string, error) {
return obj.compactSerialize(true)
}
// FullSerialize serializes an object using the full JSON serialization format.

View file

@ -81,3 +81,64 @@ type opaqueVerifier struct {
func (o *opaqueVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
return o.verifier.VerifyPayload(payload, signature, alg)
}
// OpaqueKeyEncrypter is an interface that supports encrypting keys with an opaque key.
type OpaqueKeyEncrypter interface {
// KeyID returns the kid
KeyID() string
// Algs returns a list of supported key encryption algorithms.
Algs() []KeyAlgorithm
// encryptKey encrypts the CEK using the given algorithm.
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error)
}
type opaqueKeyEncrypter struct {
encrypter OpaqueKeyEncrypter
}
func newOpaqueKeyEncrypter(alg KeyAlgorithm, encrypter OpaqueKeyEncrypter) (recipientKeyInfo, error) {
var algSupported bool
for _, salg := range encrypter.Algs() {
if alg == salg {
algSupported = true
break
}
}
if !algSupported {
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
return recipientKeyInfo{
keyID: encrypter.KeyID(),
keyAlg: alg,
keyEncrypter: &opaqueKeyEncrypter{
encrypter: encrypter,
},
}, nil
}
func (oke *opaqueKeyEncrypter) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
return oke.encrypter.encryptKey(cek, alg)
}
//OpaqueKeyDecrypter is an interface that supports decrypting keys with an opaque key.
type OpaqueKeyDecrypter interface {
DecryptKey(encryptedKey []byte, header Header) ([]byte, error)
}
type opaqueKeyDecrypter struct {
decrypter OpaqueKeyDecrypter
}
func (okd *opaqueKeyDecrypter) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
mergedHeaders := rawHeader{}
mergedHeaders.merge(&headers)
mergedHeaders.merge(recipient.header)
header, err := mergedHeaders.sanitized()
if err != nil {
return nil, err
}
return okd.decrypter.DecryptKey(recipient.encryptedKey, header)
}

View file

@ -153,12 +153,18 @@ const (
headerJWK = "jwk" // *JSONWebKey
headerKeyID = "kid" // string
headerNonce = "nonce" // string
headerB64 = "b64" // bool
headerP2C = "p2c" // *byteBuffer (int)
headerP2S = "p2s" // *byteBuffer ([]byte)
)
// supportedCritical is the set of supported extensions that are understood and processed.
var supportedCritical = map[string]bool{
headerB64: true,
}
// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
//
// The decoding of the constituent items is deferred because we want to marshal
@ -349,6 +355,21 @@ func (parsed rawHeader) getP2S() (*byteBuffer, error) {
return parsed.getByteBuffer(headerP2S)
}
// getB64 extracts parsed "b64" from the raw JSON, defaulting to true.
func (parsed rawHeader) getB64() (bool, error) {
v := parsed[headerB64]
if v == nil {
return true, nil
}
var b64 bool
err := json.Unmarshal(*v, &b64)
if err != nil {
return true, err
}
return b64, nil
}
// sanitized produces a cleaned-up header object from the raw JSON.
func (parsed rawHeader) sanitized() (h Header, err error) {
for k, v := range parsed {

View file

@ -17,6 +17,7 @@
package jose
import (
"bytes"
"crypto/ecdsa"
"crypto/rsa"
"encoding/base64"
@ -77,6 +78,27 @@ func (so *SignerOptions) WithType(typ ContentType) *SignerOptions {
return so.WithHeader(HeaderType, typ)
}
// WithCritical adds the given names to the critical ("crit") header and returns
// the updated SignerOptions.
func (so *SignerOptions) WithCritical(names ...string) *SignerOptions {
if so.ExtraHeaders[headerCritical] == nil {
so.WithHeader(headerCritical, make([]string, 0, len(names)))
}
crit := so.ExtraHeaders[headerCritical].([]string)
so.ExtraHeaders[headerCritical] = append(crit, names...)
return so
}
// WithBase64 adds a base64url-encode payload ("b64") header and returns the updated
// SignerOptions. When the "b64" value is "false", the payload is not base64 encoded.
func (so *SignerOptions) WithBase64(b64 bool) *SignerOptions {
if !b64 {
so.WithHeader(headerB64, b64)
so.WithCritical(headerB64)
}
return so
}
type payloadSigner interface {
signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error)
}
@ -233,7 +255,10 @@ func (ctx *genericSigner) Sign(payload []byte) (*JSONWebSignature, error) {
if ctx.embedJWK {
protected[headerJWK] = recipient.publicKey()
} else {
protected[headerKeyID] = recipient.publicKey().KeyID
keyID := recipient.publicKey().KeyID
if keyID != "" {
protected[headerKeyID] = keyID
}
}
}
@ -250,12 +275,26 @@ func (ctx *genericSigner) Sign(payload []byte) (*JSONWebSignature, error) {
}
serializedProtected := mustSerializeJSON(protected)
needsBase64 := true
input := []byte(fmt.Sprintf("%s.%s",
base64.RawURLEncoding.EncodeToString(serializedProtected),
base64.RawURLEncoding.EncodeToString(payload)))
if b64, ok := protected[headerB64]; ok {
if needsBase64, ok = b64.(bool); !ok {
return nil, errors.New("square/go-jose: Invalid b64 header parameter")
}
}
signatureInfo, err := recipient.signer.signPayload(input, recipient.sigAlg)
var input bytes.Buffer
input.WriteString(base64.RawURLEncoding.EncodeToString(serializedProtected))
input.WriteByte('.')
if needsBase64 {
input.WriteString(base64.RawURLEncoding.EncodeToString(payload))
} else {
input.Write(payload)
}
signatureInfo, err := recipient.signer.signPayload(input.Bytes(), recipient.sigAlg)
if err != nil {
return nil, err
}
@ -324,12 +363,18 @@ func (obj JSONWebSignature) DetachedVerify(payload []byte, verificationKey inter
if err != nil {
return err
}
if len(critical) > 0 {
// Unsupported crit header
for _, name := range critical {
if !supportedCritical[name] {
return ErrCryptoFailure
}
}
input, err := obj.computeAuthData(payload, &signature)
if err != nil {
return ErrCryptoFailure
}
input := obj.computeAuthData(payload, &signature)
alg := headers.getSignatureAlgorithm()
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
@ -366,18 +411,25 @@ func (obj JSONWebSignature) DetachedVerifyMulti(payload []byte, verificationKey
return -1, Signature{}, err
}
outer:
for i, signature := range obj.Signatures {
headers := signature.mergedHeaders()
critical, err := headers.getCritical()
if err != nil {
continue
}
if len(critical) > 0 {
// Unsupported crit header
for _, name := range critical {
if !supportedCritical[name] {
continue outer
}
}
input, err := obj.computeAuthData(payload, &signature)
if err != nil {
continue
}
input := obj.computeAuthData(payload, &signature)
alg := headers.getSignatureAlgorithm()
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {

13
vendor/modules.txt vendored
View file

@ -91,6 +91,8 @@ github.com/circonus-labs/circonus-gometrics/checkmgr
github.com/circonus-labs/circonusllhist
# github.com/coredns/coredns v1.1.2
github.com/coredns/coredns/plugin/pkg/dnsutil
# github.com/coreos/go-oidc v2.1.0+incompatible
github.com/coreos/go-oidc
# github.com/davecgh/go-spew v1.1.1
github.com/davecgh/go-spew/spew
# github.com/denverdino/aliyungo v0.0.0-20170926055100-d3308649c661
@ -222,7 +224,7 @@ github.com/hashicorp/go-sockaddr
github.com/hashicorp/go-sockaddr/template
# github.com/hashicorp/go-syslog v1.0.0
github.com/hashicorp/go-syslog
# github.com/hashicorp/go-uuid v1.0.1
# github.com/hashicorp/go-uuid v1.0.2
github.com/hashicorp/go-uuid
# github.com/hashicorp/go-version v1.2.0
github.com/hashicorp/go-version
@ -307,6 +309,8 @@ github.com/mitchellh/go-testing-interface
github.com/mitchellh/hashstructure
# github.com/mitchellh/mapstructure v1.2.3
github.com/mitchellh/mapstructure
# github.com/mitchellh/pointerstructure v1.0.0
github.com/mitchellh/pointerstructure
# github.com/mitchellh/reflectwalk v1.0.1
github.com/mitchellh/reflectwalk
# github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd
@ -319,6 +323,8 @@ github.com/nicolai86/scaleway-sdk
github.com/packethost/packngo
# github.com/pascaldekloe/goe v0.1.0
github.com/pascaldekloe/goe/verify
# github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/patrickmn/go-cache
# github.com/pierrec/lz4 v2.0.5+incompatible
github.com/pierrec/lz4
github.com/pierrec/lz4/internal/xxh32
@ -330,6 +336,9 @@ github.com/pmezard/go-difflib/difflib
github.com/posener/complete
github.com/posener/complete/cmd
github.com/posener/complete/cmd/install
# github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35
github.com/pquerna/cachecontrol
github.com/pquerna/cachecontrol/cacheobject
# github.com/prometheus/client_golang v1.0.0
github.com/prometheus/client_golang/prometheus
github.com/prometheus/client_golang/prometheus/internal
@ -543,7 +552,7 @@ google.golang.org/grpc/tap
gopkg.in/inf.v0
# gopkg.in/resty.v1 v1.12.0
gopkg.in/resty.v1
# gopkg.in/square/go-jose.v2 v2.3.1
# gopkg.in/square/go-jose.v2 v2.4.1
gopkg.in/square/go-jose.v2
gopkg.in/square/go-jose.v2/cipher
gopkg.in/square/go-jose.v2/json