// package oidcauthtest exposes tools to assist in writing unit tests of OIDC // and JWT authentication workflows. // // When the package is loaded it will randomly generate an ECDSA signing // keypair used to sign JWTs both via the Server and the SignJWT method. package oidcauthtest import ( "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "io/ioutil" "log" "net" "net/http" "net/http/httptest" "net/url" "strconv" "sync" "time" "github.com/hashicorp/consul/internal/go-sso/oidcauth/internal/strutil" "github.com/mitchellh/go-testing-interface" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" ) // Server is local server the mocks the endpoints used by the OIDC and // JWKS process. type Server struct { httpServer *httptest.Server caCert string returnFunc func() jwks *jose.JSONWebKeySet allowedRedirectURIs []string replySubject string replyUserinfo map[string]interface{} mu sync.Mutex clientID string clientSecret string expectedAuthCode string expectedAuthNonce string customClaims map[string]interface{} customAudience string omitIDToken bool disableUserInfo bool } type startOption struct { port int returnFunc func() } // WithPort is a option for Start that lets the caller control the port // allocation. The returnFunc parameter is used when the provider is stopped to // return the port in whatever bookkeeping system the caller wants to use. func WithPort(port int, returnFunc func()) startOption { return startOption{ port: port, returnFunc: returnFunc, } } // Start creates a disposable Server. If the port provided is // zero it will bind to a random free port, otherwise the provided port is // used. func Start(t testing.T, options ...startOption) *Server { s := &Server{ allowedRedirectURIs: []string{ "https://example.com", }, replySubject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", replyUserinfo: map[string]interface{}{ "color": "red", "temperature": "76", "flavor": "umami", }, } jwks, err := newJWKS(ecdsaPublicKey) require.NoError(t, err) s.jwks = jwks var ( port int returnFunc func() ) for _, option := range options { if option.port > 0 { port = option.port returnFunc = option.returnFunc } } s.httpServer = httptestNewUnstartedServerWithPort(s, port) s.httpServer.Config.ErrorLog = log.New(ioutil.Discard, "", 0) s.httpServer.StartTLS() if returnFunc != nil { t.Cleanup(returnFunc) } t.Cleanup(s.httpServer.Close) cert := s.httpServer.Certificate() var buf bytes.Buffer require.NoError(t, pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})) s.caCert = buf.String() return s } // SetClientCreds is for configuring the client information required for the // OIDC workflows. func (s *Server) SetClientCreds(clientID, clientSecret string) { s.mu.Lock() defer s.mu.Unlock() s.clientID = clientID s.clientSecret = clientSecret } // SetExpectedAuthCode configures the auth code to return from /auth and the // allowed auth code for /token. func (s *Server) SetExpectedAuthCode(code string) { s.mu.Lock() defer s.mu.Unlock() s.expectedAuthCode = code } // SetExpectedAuthNonce configures the nonce value required for /auth. func (s *Server) SetExpectedAuthNonce(nonce string) { s.mu.Lock() defer s.mu.Unlock() s.expectedAuthNonce = nonce } // SetAllowedRedirectURIs allows you to configure the allowed redirect URIs for // the OIDC workflow. If not configured a sample of "https://example.com" is // used. func (s *Server) SetAllowedRedirectURIs(uris []string) { s.mu.Lock() defer s.mu.Unlock() s.allowedRedirectURIs = uris } // SetCustomClaims lets you set claims to return in the JWT issued by the OIDC // workflow. func (s *Server) SetCustomClaims(customClaims map[string]interface{}) { s.mu.Lock() defer s.mu.Unlock() s.customClaims = customClaims } // SetCustomAudience configures what audience value to embed in the JWT issued // by the OIDC workflow. func (s *Server) SetCustomAudience(customAudience string) { s.mu.Lock() defer s.mu.Unlock() s.customAudience = customAudience } // OmitIDTokens forces an error state where the /token endpoint does not return // id_token. func (s *Server) OmitIDTokens() { s.mu.Lock() defer s.mu.Unlock() s.omitIDToken = true } // DisableUserInfo makes the userinfo endpoint return 404 and omits it from the // discovery config. func (s *Server) DisableUserInfo() { s.mu.Lock() defer s.mu.Unlock() s.disableUserInfo = true } // Stop stops the running Server. func (s *Server) Stop() { s.httpServer.Close() } // Addr returns the current base URL for the running webserver. func (s *Server) Addr() string { return s.httpServer.URL } // CACert returns the pem-encoded CA certificate used by the HTTPS server. func (s *Server) CACert() string { return s.caCert } // SigningKeys returns the pem-encoded keys used to sign JWTs. func (s *Server) SigningKeys() (pub, priv string) { return SigningKeys() } // ServeHTTP implements http.Handler. func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { s.mu.Lock() defer s.mu.Unlock() w.Header().Set("Content-Type", "application/json") switch req.URL.Path { case "/.well-known/openid-configuration": if req.Method != "GET" { w.WriteHeader(http.StatusMethodNotAllowed) return } reply := struct { Issuer string `json:"issuer"` AuthEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` JWKSURI string `json:"jwks_uri"` UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` }{ Issuer: s.Addr(), AuthEndpoint: s.Addr() + "/auth", TokenEndpoint: s.Addr() + "/token", JWKSURI: s.Addr() + "/certs", UserinfoEndpoint: s.Addr() + "/userinfo", } if s.disableUserInfo { reply.UserinfoEndpoint = "" } if err := writeJSON(w, &reply); err != nil { return } case "/auth": if req.Method != "GET" { w.WriteHeader(http.StatusMethodNotAllowed) return } qv := req.URL.Query() if qv.Get("response_type") != "code" { writeAuthErrorResponse(w, req, "unsupported_response_type", "") return } if qv.Get("scope") != "openid" { writeAuthErrorResponse(w, req, "invalid_scope", "") return } if s.expectedAuthCode == "" { writeAuthErrorResponse(w, req, "access_denied", "") return } nonce := qv.Get("nonce") if s.expectedAuthNonce != "" && s.expectedAuthNonce != nonce { writeAuthErrorResponse(w, req, "access_denied", "") return } state := qv.Get("state") if state == "" { writeAuthErrorResponse(w, req, "invalid_request", "missing state parameter") return } redirectURI := qv.Get("redirect_uri") if redirectURI == "" { writeAuthErrorResponse(w, req, "invalid_request", "missing redirect_uri parameter") return } redirectURI += "?state=" + url.QueryEscape(state) + "&code=" + url.QueryEscape(s.expectedAuthCode) http.Redirect(w, req, redirectURI, http.StatusFound) return case "/certs": if req.Method != "GET" { w.WriteHeader(http.StatusMethodNotAllowed) return } if err := writeJSON(w, s.jwks); err != nil { return } case "/certs_missing": w.WriteHeader(http.StatusNotFound) case "/certs_invalid": w.Write([]byte("It's not a keyset!")) case "/token": if req.Method != "POST" { w.WriteHeader(http.StatusMethodNotAllowed) return } switch { case req.FormValue("grant_type") != "authorization_code": _ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "bad grant_type") return case !strutil.StrListContains(s.allowedRedirectURIs, req.FormValue("redirect_uri")): _ = writeTokenErrorResponse(w, req, http.StatusBadRequest, "invalid_request", "redirect_uri is not allowed") return case req.FormValue("code") != s.expectedAuthCode: _ = writeTokenErrorResponse(w, req, http.StatusUnauthorized, "invalid_grant", "unexpected auth code") return } stdClaims := jwt.Claims{ Subject: s.replySubject, Issuer: s.Addr(), NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)), Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Second)), Audience: jwt.Audience{s.clientID}, } if s.customAudience != "" { stdClaims.Audience = jwt.Audience{s.customAudience} } jwtData, err := SignJWT("", stdClaims, s.customClaims) if err != nil { _ = writeTokenErrorResponse(w, req, http.StatusInternalServerError, "server_error", err.Error()) return } reply := struct { AccessToken string `json:"access_token"` IDToken string `json:"id_token,omitempty"` }{ AccessToken: jwtData, IDToken: jwtData, } if s.omitIDToken { reply.IDToken = "" } if err := writeJSON(w, &reply); err != nil { return } case "/userinfo": if s.disableUserInfo { w.WriteHeader(http.StatusNotFound) return } if req.Method != "GET" { w.WriteHeader(http.StatusMethodNotAllowed) return } if err := writeJSON(w, s.replyUserinfo); err != nil { return } default: w.WriteHeader(http.StatusNotFound) } } func writeAuthErrorResponse(w http.ResponseWriter, req *http.Request, errorCode, errorMessage string) { qv := req.URL.Query() redirectURI := qv.Get("redirect_uri") + "?state=" + url.QueryEscape(qv.Get("state")) + "&error=" + url.QueryEscape(errorCode) if errorMessage != "" { redirectURI += "&error_description=" + url.QueryEscape(errorMessage) } http.Redirect(w, req, redirectURI, http.StatusFound) } func writeTokenErrorResponse(w http.ResponseWriter, req *http.Request, statusCode int, errorCode, errorMessage string) error { body := struct { Code string `json:"error"` Desc string `json:"error_description,omitempty"` }{ Code: errorCode, Desc: errorMessage, } w.WriteHeader(statusCode) return writeJSON(w, &body) } // newJWKS converts a pem-encoded public key into JWKS data suitable for a // verification endpoint response func newJWKS(pubKey string) (*jose.JSONWebKeySet, error) { block, _ := pem.Decode([]byte(pubKey)) if block == nil { return nil, fmt.Errorf("unable to decode public key") } input := block.Bytes pub, err := x509.ParsePKIXPublicKey(input) if err != nil { return nil, err } return &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { Key: pub, }, }, }, nil } func writeJSON(w http.ResponseWriter, out interface{}) error { enc := json.NewEncoder(w) return enc.Encode(out) } // SignJWT will bundle the provided claims into a signed JWT. The provided key // is assumed to be ECDSA. // // If no private key is provided, the default package keys are used. These can // be retrieved via the SigningKeys() method. func SignJWT(privKey string, claims jwt.Claims, privateClaims interface{}) (string, error) { if privKey == "" { privKey = ecdsaPrivateKey } var key *ecdsa.PrivateKey block, _ := pem.Decode([]byte(privKey)) if block != nil { var err error key, err = x509.ParseECPrivateKey(block.Bytes) if err != nil { return "", err } } sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"), ) if err != nil { return "", err } raw, err := jwt.Signed(sig). Claims(claims). Claims(privateClaims). CompactSerialize() if err != nil { return "", err } return raw, nil } // httptestNewUnstartedServerWithPort is roughly the same as // httptest.NewUnstartedServer() but allows the caller to explicitly choose the // port if desired. func httptestNewUnstartedServerWithPort(handler http.Handler, port int) *httptest.Server { if port == 0 { return httptest.NewUnstartedServer(handler) } addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) l, err := net.Listen("tcp", addr) if err != nil { panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) } return &httptest.Server{ Listener: l, Config: &http.Server{Handler: handler}, } } // SigningKeys returns the pem-encoded keys used to sign JWTs by default. func SigningKeys() (pub, priv string) { return ecdsaPublicKey, ecdsaPrivateKey } var ( ecdsaPublicKey string ecdsaPrivateKey string ) func init() { // Each time we run tests we generate a unique set of keys for use in the // test. These are cached between runs but do not persist between restarts // of the test binary. var err error ecdsaPublicKey, ecdsaPrivateKey, err = generateKey() if err != nil { panic(err) } } func generateKey() (pub, priv string, err error) { privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return "", "", fmt.Errorf("error generating private key: %v", err) } { derBytes, err := x509.MarshalECPrivateKey(privateKey) if err != nil { return "", "", fmt.Errorf("error marshaling private key: %v", err) } pemBlock := &pem.Block{ Type: "EC PRIVATE KEY", Bytes: derBytes, } priv = string(pem.EncodeToMemory(pemBlock)) } { derBytes, err := x509.MarshalPKIXPublicKey(privateKey.Public()) if err != nil { return "", "", fmt.Errorf("error marshaling public key: %v", err) } pemBlock := &pem.Block{ Type: "PUBLIC KEY", Bytes: derBytes, } pub = string(pem.EncodeToMemory(pemBlock)) } return pub, priv, nil }