530 lines
13 KiB
Go
530 lines
13 KiB
Go
// Package oidcauthtest exposes tools to assist in writing unit tests of OIDC
// and JWT authentication workflows.
//
// When the package is loaded it randomly generates an ECDSA signing keypair
// used to sign JWTs both via the Server and the SignJWT function.
|
|
package oidcauthtest
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil"
|
|
"github.com/mitchellh/go-testing-interface"
|
|
"github.com/stretchr/testify/require"
|
|
"gopkg.in/square/go-jose.v2"
|
|
"gopkg.in/square/go-jose.v2/jwt"
|
|
)
|
|
|
|
// Server is a local HTTPS server that mocks the endpoints used by the OIDC and
// JWKS process: discovery, /auth, /token, /certs and /userinfo.
//
// Create one with Start and configure it with the Set*, OmitIDTokens and
// DisableUserInfo methods.
type Server struct {
	// httpServer is the underlying TLS test server started by Start.
	httpServer *httptest.Server
	// caCert is the PEM-encoded certificate presented by httpServer.
	caCert string
	// NOTE(review): returnFunc appears unused in this file; Start registers the
	// option's returnFunc with t.Cleanup directly instead of storing it here.
	returnFunc func()

	// jwks is the key set served from /certs, built from the package's default
	// public signing key.
	jwks *jose.JSONWebKeySet
	// allowedRedirectURIs lists the redirect_uri values that /token accepts.
	allowedRedirectURIs []string
	// replySubject is the "sub" claim of tokens issued by /token.
	replySubject string
	// replyUserinfo is the JSON document served from /userinfo.
	replyUserinfo map[string]interface{}

	// mu is held by ServeHTTP for the whole request and by every setter, so
	// configuration changes (including allowedRedirectURIs above) do not race
	// with request handling.
	mu sync.Mutex
	// clientID is the default "aud" claim of tokens issued by /token.
	clientID string
	// NOTE(review): clientSecret is stored by SetClientCreds but /token does
	// not verify it.
	clientSecret string
	// expectedAuthCode is the code /auth hands out and /token requires; when
	// empty, /auth denies every request.
	expectedAuthCode string
	// expectedAuthNonce, when non-empty, is the nonce /auth requires.
	expectedAuthNonce string
	// customClaims are extra claims merged into tokens issued by /token.
	customClaims map[string]interface{}
	// customAudience, when non-empty, replaces clientID as the "aud" claim.
	customAudience string
	// omitIDToken drops id_token from /token responses.
	omitIDToken bool
	// disableUserInfo makes /userinfo return 404 and hides it from discovery.
	disableUserInfo bool
}
|
|
|
|
// startOption configures Start; construct one with WithPort.
type startOption struct {
	// port is the port to bind; Start ignores it unless it is positive.
	port int
	// returnFunc, if non-nil, is registered with t.Cleanup by Start so the
	// caller can release the port once the test finishes.
	returnFunc func()
}
|
|
|
|
// WithPort is a option for Start that lets the caller control the port
|
|
// allocation. The returnFunc parameter is used when the provider is stopped to
|
|
// return the port in whatever bookkeeping system the caller wants to use.
|
|
func WithPort(port int, returnFunc func()) startOption {
|
|
return startOption{
|
|
port: port,
|
|
returnFunc: returnFunc,
|
|
}
|
|
}
|
|
|
|
// Start creates a disposable Server. If the port provided is
|
|
// zero it will bind to a random free port, otherwise the provided port is
|
|
// used.
|
|
func Start(t testing.T, options ...startOption) *Server {
|
|
s := &Server{
|
|
allowedRedirectURIs: []string{
|
|
"https://example.com",
|
|
},
|
|
replySubject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
|
replyUserinfo: map[string]interface{}{
|
|
"color": "red",
|
|
"temperature": "76",
|
|
"flavor": "umami",
|
|
},
|
|
}
|
|
|
|
jwks, err := newJWKS(ecdsaPublicKey)
|
|
require.NoError(t, err)
|
|
s.jwks = jwks
|
|
|
|
var (
|
|
port int
|
|
returnFunc func()
|
|
)
|
|
for _, option := range options {
|
|
if option.port > 0 {
|
|
port = option.port
|
|
returnFunc = option.returnFunc
|
|
}
|
|
}
|
|
|
|
s.httpServer = httptestNewUnstartedServerWithPort(s, port)
|
|
s.httpServer.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
|
|
s.httpServer.StartTLS()
|
|
if returnFunc != nil {
|
|
t.Cleanup(returnFunc)
|
|
}
|
|
t.Cleanup(s.httpServer.Close)
|
|
|
|
cert := s.httpServer.Certificate()
|
|
|
|
var buf bytes.Buffer
|
|
require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
|
|
s.caCert = buf.String()
|
|
|
|
return s
|
|
}
|
|
|
|
// SetClientCreds is for configuring the client information required for the
|
|
// OIDC workflows.
|
|
func (s *Server) SetClientCreds(clientID, clientSecret string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.clientID = clientID
|
|
s.clientSecret = clientSecret
|
|
}
|
|
|
|
// SetExpectedAuthCode configures the auth code to return from /auth and the
|
|
// allowed auth code for /token.
|
|
func (s *Server) SetExpectedAuthCode(code string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.expectedAuthCode = code
|
|
}
|
|
|
|
// SetExpectedAuthNonce configures the nonce value required for /auth.
|
|
func (s *Server) SetExpectedAuthNonce(nonce string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.expectedAuthNonce = nonce
|
|
}
|
|
|
|
// SetAllowedRedirectURIs allows you to configure the allowed redirect URIs for
|
|
// the OIDC workflow. If not configured a sample of "https://example.com" is
|
|
// used.
|
|
func (s *Server) SetAllowedRedirectURIs(uris []string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.allowedRedirectURIs = uris
|
|
}
|
|
|
|
// SetCustomClaims lets you set claims to return in the JWT issued by the OIDC
|
|
// workflow.
|
|
func (s *Server) SetCustomClaims(customClaims map[string]interface{}) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.customClaims = customClaims
|
|
}
|
|
|
|
// SetCustomAudience configures what audience value to embed in the JWT issued
|
|
// by the OIDC workflow.
|
|
func (s *Server) SetCustomAudience(customAudience string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.customAudience = customAudience
|
|
}
|
|
|
|
// OmitIDTokens forces an error state where the /token endpoint does not return
|
|
// id_token.
|
|
func (s *Server) OmitIDTokens() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.omitIDToken = true
|
|
}
|
|
|
|
// DisableUserInfo makes the userinfo endpoint return 404 and omits it from the
|
|
// discovery config.
|
|
func (s *Server) DisableUserInfo() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.disableUserInfo = true
|
|
}
|
|
|
|
// Stop stops the running Server.
|
|
func (s *Server) Stop() {
|
|
s.httpServer.Close()
|
|
}
|
|
|
|
// Addr returns the current base URL for the running webserver.
|
|
func (s *Server) Addr() string { return s.httpServer.URL }
|
|
|
|
// CACert returns the pem-encoded CA certificate used by the HTTPS server.
|
|
func (s *Server) CACert() string { return s.caCert }
|
|
|
|
// SigningKeys returns the pem-encoded keys used to sign JWTs.
|
|
func (s *Server) SigningKeys() (pub, priv string) {
|
|
return SigningKeys()
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler.
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
switch req.URL.Path {
|
|
case "/.well-known/openid-configuration":
|
|
if req.Method != "GET" {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
reply := struct {
|
|
Issuer string `json:"issuer"`
|
|
AuthEndpoint string `json:"authorization_endpoint"`
|
|
TokenEndpoint string `json:"token_endpoint"`
|
|
JWKSURI string `json:"jwks_uri"`
|
|
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
|
}{
|
|
Issuer: s.Addr(),
|
|
AuthEndpoint: s.Addr() + "/auth",
|
|
TokenEndpoint: s.Addr() + "/token",
|
|
JWKSURI: s.Addr() + "/certs",
|
|
UserinfoEndpoint: s.Addr() + "/userinfo",
|
|
}
|
|
if s.disableUserInfo {
|
|
reply.UserinfoEndpoint = ""
|
|
}
|
|
|
|
if err := writeJSON(w, &reply); err != nil {
|
|
return
|
|
}
|
|
|
|
case "/auth":
|
|
if req.Method != "GET" {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
qv := req.URL.Query()
|
|
|
|
if qv.Get("response_type") != "code" {
|
|
writeAuthErrorResponse(w, req, "unsupported_response_type", "")
|
|
return
|
|
}
|
|
if qv.Get("scope") != "openid" {
|
|
writeAuthErrorResponse(w, req, "invalid_scope", "")
|
|
return
|
|
}
|
|
|
|
if s.expectedAuthCode == "" {
|
|
writeAuthErrorResponse(w, req, "access_denied", "")
|
|
return
|
|
}
|
|
|
|
nonce := qv.Get("nonce")
|
|
if s.expectedAuthNonce != "" && s.expectedAuthNonce != nonce {
|
|
writeAuthErrorResponse(w, req, "access_denied", "")
|
|
return
|
|
}
|
|
|
|
state := qv.Get("state")
|
|
if state == "" {
|
|
writeAuthErrorResponse(w, req, "invalid_request", "missing state parameter")
|
|
return
|
|
}
|
|
|
|
redirectURI := qv.Get("redirect_uri")
|
|
if redirectURI == "" {
|
|
writeAuthErrorResponse(w, req, "invalid_request", "missing redirect_uri parameter")
|
|
return
|
|
}
|
|
|
|
redirectURI += "?state=" + url.QueryEscape(state) +
|
|
"&code=" + url.QueryEscape(s.expectedAuthCode)
|
|
|
|
http.Redirect(w, req, redirectURI, http.StatusFound)
|
|
|
|
return
|
|
|
|
case "/certs":
|
|
if req.Method != "GET" {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
if err := writeJSON(w, s.jwks); err != nil {
|
|
return
|
|
}
|
|
|
|
case "/certs_missing":
|
|
w.WriteHeader(http.StatusNotFound)
|
|
|
|
case "/certs_invalid":
|
|
w.Write([]byte("It's not a keyset!"))
|
|
|
|
case "/token":
|
|
if req.Method != "POST" {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
switch {
|
|
case req.FormValue("grant_type") != "authorization_code":
|
|
_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "bad grant_type")
|
|
return
|
|
case !strutil.StrListContains(s.allowedRedirectURIs, req.FormValue("redirect_uri")):
|
|
_ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "redirect_uri is not allowed")
|
|
return
|
|
case req.FormValue("code") != s.expectedAuthCode:
|
|
_ = writeTokenErrorResponse(w, req, http.StatusUnauthorized, "invalid_grant", "unexpected auth code")
|
|
return
|
|
}
|
|
|
|
stdClaims := jwt.Claims{
|
|
Subject: s.replySubject,
|
|
Issuer: s.Addr(),
|
|
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
|
|
Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)),
|
|
Audience: jwt.Audience{s.clientID},
|
|
}
|
|
if s.customAudience != "" {
|
|
stdClaims.Audience = jwt.Audience{s.customAudience}
|
|
}
|
|
|
|
jwtData, err := SignJWT("", stdClaims, s.customClaims)
|
|
if err != nil {
|
|
_ = writeTokenErrorResponse(w, req, http.StatusInternalServerError, "server_error", err.Error())
|
|
return
|
|
}
|
|
|
|
reply := struct {
|
|
AccessToken string `json:"access_token"`
|
|
IDToken string `json:"id_token,omitempty"`
|
|
}{
|
|
AccessToken: jwtData,
|
|
IDToken: jwtData,
|
|
}
|
|
if s.omitIDToken {
|
|
reply.IDToken = ""
|
|
}
|
|
if err := writeJSON(w, &reply); err != nil {
|
|
return
|
|
}
|
|
|
|
case "/userinfo":
|
|
if s.disableUserInfo {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
if req.Method != "GET" {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
if err := writeJSON(w, s.replyUserinfo); err != nil {
|
|
return
|
|
}
|
|
|
|
default:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}
|
|
|
|
func writeAuthErrorResponse(w http.ResponseWriter, req *http.Request, errorCode, errorMessage string) {
|
|
qv := req.URL.Query()
|
|
|
|
redirectURI := qv.Get("redirect_uri") +
|
|
"?state=" + url.QueryEscape(qv.Get("state")) +
|
|
"&error=" + url.QueryEscape(errorCode)
|
|
|
|
if errorMessage != "" {
|
|
redirectURI += "&error_description=" + url.QueryEscape(errorMessage)
|
|
}
|
|
|
|
http.Redirect(w, req, redirectURI, http.StatusFound)
|
|
}
|
|
|
|
func writeTokenErrorResponse(w http.ResponseWriter, req *http.Request, statusCode int, errorCode, errorMessage string) error {
|
|
body := struct {
|
|
Code string `json:"error"`
|
|
Desc string `json:"error_description,omitempty"`
|
|
}{
|
|
Code: errorCode,
|
|
Desc: errorMessage,
|
|
}
|
|
|
|
w.WriteHeader(statusCode)
|
|
return writeJSON(w, &body)
|
|
}
|
|
|
|
// newJWKS converts a pem-encoded public key into JWKS data suitable for a
|
|
// verification endpoint response
|
|
func newJWKS(pubKey string) (*jose.JSONWebKeySet, error) {
|
|
block, _ := pem.Decode([]byte(pubKey))
|
|
if block == nil {
|
|
return nil, fmt.Errorf("unable to decode public key")
|
|
}
|
|
input := block.Bytes
|
|
|
|
pub, err := x509.ParsePKIXPublicKey(input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &jose.JSONWebKeySet{
|
|
Keys: []jose.JSONWebKey{
|
|
jose.JSONWebKey{
|
|
Key: pub,
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// writeJSON encodes out as JSON (followed by a newline) into the response body
// and returns any encoding or write error.
func writeJSON(w http.ResponseWriter, out interface{}) error {
	return json.NewEncoder(w).Encode(out)
}
|
|
|
|
// SignJWT will bundle the provided claims into a signed JWT. The provided key
|
|
// is assumed to be ECDSA.
|
|
//
|
|
// If no private key is provided, the default package keys are used. These can
|
|
// be retrieved via the SigningKeys() method.
|
|
func SignJWT(privKey string, claims jwt.Claims, privateClaims interface{}) (string, error) {
|
|
if privKey == "" {
|
|
privKey = ecdsaPrivateKey
|
|
}
|
|
var key *ecdsa.PrivateKey
|
|
block, _ := pem.Decode([]byte(privKey))
|
|
if block != nil {
|
|
var err error
|
|
key, err = x509.ParseECPrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
sig, err := jose.NewSigner(
|
|
jose.SigningKey{Algorithm: jose.ES256, Key: key},
|
|
(&jose.SignerOptions{}).WithType("JWT"),
|
|
)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
raw, err := jwt.Signed(sig).
|
|
Claims(claims).
|
|
Claims(privateClaims).
|
|
CompactSerialize()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return raw, nil
|
|
}
|
|
|
|
// httptestNewUnstartedServerWithPort is roughly the same as
// httptest.NewUnstartedServer(), but lets the caller pick the port.
//
// With port 0 it simply defers to httptest.NewUnstartedServer, which binds to a
// random free port. Otherwise it listens on 127.0.0.1:port and panics if that
// fails, mirroring httptest's own behavior. The returned server is not started.
func httptestNewUnstartedServerWithPort(handler http.Handler, port int) *httptest.Server {
	if port != 0 {
		listener, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
		if err != nil {
			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
		}
		return &httptest.Server{
			Listener: listener,
			Config:   &http.Server{Handler: handler},
		}
	}

	return httptest.NewUnstartedServer(handler)
}
|
|
|
|
// SigningKeys returns the pem-encoded keys used to sign JWTs by default.
|
|
func SigningKeys() (pub, priv string) {
|
|
return ecdsaPublicKey, ecdsaPrivateKey
|
|
}
|
|
|
|
// ecdsaPublicKey and ecdsaPrivateKey hold the PEM-encoded default signing
// keypair. They are populated by the package init function and never modified
// afterwards in this file.
var ecdsaPublicKey, ecdsaPrivateKey string
|
|
|
|
func init() {
|
|
// Each time we run tests we generate a unique set of keys for use in the
|
|
// test. These are cached between runs but do not persist between restarts
|
|
// of the test binary.
|
|
var err error
|
|
ecdsaPublicKey, ecdsaPrivateKey, err = generateKey()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// generateKey creates a new random ECDSA P-256 keypair and returns it
// PEM-encoded: the public key as a PKIX "PUBLIC KEY" block and the private key
// as a SEC 1 "EC PRIVATE KEY" block.
func generateKey() (pub, priv string, err error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return "", "", fmt.Errorf("error generating private key: %v", err)
	}

	privDER, err := x509.MarshalECPrivateKey(privateKey)
	if err != nil {
		return "", "", fmt.Errorf("error marshaling private key: %v", err)
	}
	pubDER, err := x509.MarshalPKIXPublicKey(privateKey.Public())
	if err != nil {
		return "", "", fmt.Errorf("error marshaling public key: %v", err)
	}

	toPEM := func(blockType string, der []byte) string {
		return string(pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der}))
	}
	return toPEM("PUBLIC KEY", pubDER), toPEM("EC PRIVATE KEY", privDER), nil
}
|