Updates the JWT/OIDC auth plugin (#10919)
parent e494e8a141
commit a7531a11ea
@@ -0,0 +1,10 @@
```release-note:feature
auth/jwt: Adds `max_age` role parameter and `auth_time` claim validation.
```
```release-note:bug
auth/jwt: Fixes an issue where JWT verification keys weren't updated after a `jwks_url` change.
```
```release-note:bug
auth/jwt: Fixes an issue where `jwt_supported_algs` were not being validated for JWT auth using
`jwks_url` and `jwt_validation_pubkeys`.
```
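The new `max_age` role parameter drives the `auth_time` claim validation mentioned above: a login fails if the token's `auth_time` is older than `max_age`. A minimal sketch of how an operator might set it through Vault's Go API client follows; the mount path, role name, and the other role fields are illustrative assumptions, not taken from this commit.

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	// Assumes VAULT_ADDR and VAULT_TOKEN are set in the environment.
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Hypothetical role: require the provider to have authenticated the user
	// within the last 15 minutes (checked against the auth_time claim).
	_, err = client.Logical().Write("auth/jwt/role/my-role", map[string]interface{}{
		"role_type":       "jwt",
		"user_claim":      "sub",
		"bound_audiences": "https://vault.example.com",
		"max_age":         "15m",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("role configured")
}
```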
@@ -56,6 +56,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
	_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
		"bound_issuer":           "https://team-vault.auth0.com/",
		"jwt_validation_pubkeys": TestECDSAPubKey,
		"jwt_supported_algs":     "ES256",
	})
	if err != nil {
		t.Fatal(err)

@@ -263,6 +263,7 @@ func testAgentExitAfterAuth(t *testing.T, viaFlag bool) {
	_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
		"bound_issuer":           "https://team-vault.auth0.com/",
		"jwt_validation_pubkeys": agent.TestECDSAPubKey,
		"jwt_supported_algs":     "ES256",
	})
	if err != nil {
		t.Fatal(err)
go.mod
@@ -79,7 +79,7 @@ require (
	github.com/hashicorp/vault-plugin-auth-centrify v0.7.0
	github.com/hashicorp/vault-plugin-auth-cf v0.7.0
	github.com/hashicorp/vault-plugin-auth-gcp v0.8.0
	github.com/hashicorp/vault-plugin-auth-jwt v0.7.2-0.20201203001230-e35700fcc0d5
	github.com/hashicorp/vault-plugin-auth-jwt v0.7.2-0.20210212182451-0d0819f8e5e3
	github.com/hashicorp/vault-plugin-auth-kerberos v0.2.0
	github.com/hashicorp/vault-plugin-auth-kubernetes v0.8.0
	github.com/hashicorp/vault-plugin-auth-oci v0.6.0
go.sum
@@ -246,8 +246,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-oidc v2.0.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk=
github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-semver v0.2.0 h1:3Jm3tLmsgAYcjC+4Up7hJrFBPr+n7rAqYeSw/SZazuY=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
@@ -469,6 +469,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-metrics-stackdriver v0.2.0 h1:rbs2sxHAPn2OtUj9JdR/Gij1YKGl0BTVD0augB+HEjE=
@@ -532,6 +534,8 @@ github.com/grpc-ecosystem/grpc-gateway v1.9.5 h1:UImYN5qQ8tuGpGE16ZmjvcTtTw24zw1
github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7 h1:6OHvaQs9ys66bR1yqHuoI231JAoalgGgxeqzQuVOfX0=
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7/go.mod h1:tIk5rB1nihW5+9bZjI7xlc8LGw8FYfiFMKOpHPbWgug=
github.com/hashicorp/consul-template v0.25.2-0.20210123001810-166043f8559d h1:DSrhJ8Nqyr3oleIu0rCjRV4j6f4CJSPUp5DljXKKu4w=
github.com/hashicorp/consul-template v0.25.2-0.20210123001810-166043f8559d/go.mod h1:kNLSN13aPJz/P0we1XNU+ZDsjkbzX+iHJ+dJOqFZck0=
github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE=
@@ -645,8 +649,8 @@ github.com/hashicorp/vault-plugin-auth-cf v0.7.0/go.mod h1:exPUMj8yNohKM7yRiHa7O
github.com/hashicorp/vault-plugin-auth-gcp v0.5.1/go.mod h1:eLj92eX8MPI4vY1jaazVLF2sVbSAJ3LRHLRhF/pUmlI=
github.com/hashicorp/vault-plugin-auth-gcp v0.8.0 h1:E9EHvC9jCDNix/pB9NKYYLMUkpfv65TSDk2rVvtkdzU=
github.com/hashicorp/vault-plugin-auth-gcp v0.8.0/go.mod h1:sHDguHmyGScoalGLEjuxvDCrMPVlw2c3f+ieeiHcv6w=
github.com/hashicorp/vault-plugin-auth-jwt v0.7.2-0.20201203001230-e35700fcc0d5 h1:BEsc9LNqgCNMhRVVOzS2v1Czioqod5Lln+Zol7zFmak=
github.com/hashicorp/vault-plugin-auth-jwt v0.7.2-0.20201203001230-e35700fcc0d5/go.mod h1:pyR4z5f2Vuz9TXucuN0rivUJTtSdlOtDdZ16IqBjZVo=
github.com/hashicorp/vault-plugin-auth-jwt v0.7.2-0.20210212182451-0d0819f8e5e3 h1:Lc2wDPfAiiiFRAkQlu1aXrBRHn/BFvjAXZXKrmtY7zs=
github.com/hashicorp/vault-plugin-auth-jwt v0.7.2-0.20210212182451-0d0819f8e5e3/go.mod h1:Gn6ELc1X5nmZ/pxoXf0nA4lG2gwuGnY6SNyW40tR/ws=
github.com/hashicorp/vault-plugin-auth-kerberos v0.2.0 h1:7ct50ngVFTeO7EJ3N9PvPHeHc+2cANTHi2+9RwIUIHM=
github.com/hashicorp/vault-plugin-auth-kerberos v0.2.0/go.mod h1:IM/n7LY1rIM4MVzOfSH6cRmY/C2rGkrjGrEr0B/yO9c=
github.com/hashicorp/vault-plugin-auth-kubernetes v0.8.0 h1:v1jOqR70chxRxONey7g/v0/57MneP05z2dfw6qmlE+8=
@@ -982,6 +986,8 @@ github.com/posener/complete v1.2.3 h1:NP0eAhjcjImqslEwo/1hq7gpajME0fTLTezBKDqfXq
github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac h1:jWKYCNlX4J5s8M0nHYkh7Y7c9gRVDEb3mq51j5J0F5M=
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac/go.mod h1:hoLfEwdY11HjRfKFH6KqnPsfxlo3BP6bJehpDv8t6sQ=
github.com/pquerna/otp v1.2.1-0.20191009055518-468c2dd2b58d h1:PinQItctnaL2LtkaSM678+ZLLy5TajwOeXzWvYC7tII=
github.com/pquerna/otp v1.2.1-0.20191009055518-468c2dd2b58d/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
@@ -1145,6 +1151,7 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yandex-cloud/go-genproto v0.0.0-20200722140432-762fe965ce77/go.mod h1:HEUYX/p8966tMUHHT+TsS0hF/Ca/NYwqprC5WXSDMfE=
github.com/yandex-cloud/go-sdk v0.0.0-20200722140627-2194e5077f13/go.mod h1:LEdAMqa1v/7KYe4b13ALLkonuDxLph57ibUb50ctvJk=
github.com/yhat/scrape v0.0.0-20161128144610-24b7890b0945/go.mod h1:4vRFPPNYllgCacoj+0FoKOjTW68rUhEfqPLiEJaK2w8=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da h1:NimzV1aGyq29m5ukMK0AMWEhFaL/lrEOaephfuoiARg=
@@ -1283,6 +1290,7 @@ golang.org/x/net v0.0.0-20200320220750-118fecf932d8/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201002202402-0a1ea396d57c h1:dk0ukUIHmGHqASjP0iue2261isepFCC6XRCSd1nHgDw=
golang.org/x/net v0.0.0-20201002202402-0a1ea396d57c/go.mod h1:iQL9McJNjoIa5mjH6nYTCTZXUN6RP+XW3eib7Ya3XcI=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -1,13 +1,13 @@
language: go

go:
  - "1.9"
  - "1.10"
  - "1.12"
  - "1.13"

install:
  - go get -v -t github.com/coreos/go-oidc/...
  - go get golang.org/x/tools/cmd/cover
  - go get github.com/golang/lint/golint
  - go get golang.org/x/lint/golint

script:
  - ./test
@@ -69,6 +69,7 @@ type Provider struct {
	authURL      string
	tokenURL     string
	userInfoURL  string
	algorithms   []string

	// Raw claims returned by the server.
	rawClaims []byte

@@ -82,11 +83,27 @@ type cachedKeys struct {
}

type providerJSON struct {
	Issuer      string `json:"issuer"`
	AuthURL     string `json:"authorization_endpoint"`
	TokenURL    string `json:"token_endpoint"`
	JWKSURL     string `json:"jwks_uri"`
	UserInfoURL string `json:"userinfo_endpoint"`
	Issuer      string   `json:"issuer"`
	AuthURL     string   `json:"authorization_endpoint"`
	TokenURL    string   `json:"token_endpoint"`
	JWKSURL     string   `json:"jwks_uri"`
	UserInfoURL string   `json:"userinfo_endpoint"`
	Algorithms  []string `json:"id_token_signing_alg_values_supported"`
}

// supportedAlgorithms is a list of algorithms explicitly supported by this
// package. If a provider supports other algorithms, such as HS256 or none,
// those values won't be passed to the IDTokenVerifier.
var supportedAlgorithms = map[string]bool{
	RS256: true,
	RS384: true,
	RS512: true,
	ES256: true,
	ES384: true,
	ES512: true,
	PS256: true,
	PS384: true,
	PS512: true,
}

// NewProvider uses the OpenID Connect discovery mechanism to construct a Provider.

@@ -123,11 +140,18 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
	if p.Issuer != issuer {
		return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
	}
	var algs []string
	for _, a := range p.Algorithms {
		if supportedAlgorithms[a] {
			algs = append(algs, a)
		}
	}
	return &Provider{
		issuer:       p.Issuer,
		authURL:      p.AuthURL,
		tokenURL:     p.TokenURL,
		userInfoURL:  p.UserInfoURL,
		algorithms:   algs,
		rawClaims:    body,
		remoteKeySet: NewRemoteKeySet(ctx, p.JWKSURL),
	}, nil
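For reference, the new `Algorithms` field is filled from the `id_token_signing_alg_values_supported` entry of the issuer's discovery document. Below is a small illustrative sketch of that JSON shape, parsed with a local struct that mirrors `providerJSON`; the endpoint URLs and algorithm list are made-up sample values, not taken from this commit.

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"
)

// discoveryDoc mirrors the fields go-oidc's providerJSON reads from the
// issuer's /.well-known/openid-configuration document.
type discoveryDoc struct {
	Issuer      string   `json:"issuer"`
	AuthURL     string   `json:"authorization_endpoint"`
	TokenURL    string   `json:"token_endpoint"`
	JWKSURL     string   `json:"jwks_uri"`
	UserInfoURL string   `json:"userinfo_endpoint"`
	Algorithms  []string `json:"id_token_signing_alg_values_supported"`
}

func main() {
	// Example discovery response (values are illustrative).
	raw := []byte(`{
		"issuer": "https://team-vault.auth0.com/",
		"authorization_endpoint": "https://team-vault.auth0.com/authorize",
		"token_endpoint": "https://team-vault.auth0.com/oauth/token",
		"jwks_uri": "https://team-vault.auth0.com/.well-known/jwks.json",
		"userinfo_endpoint": "https://team-vault.auth0.com/userinfo",
		"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"]
	}`)

	var doc discoveryDoc
	if err := json.Unmarshal(raw, &doc); err != nil {
		log.Fatal(err)
	}
	// go-oidc keeps only the algorithms it supports, so HS256 would be dropped.
	fmt.Println(doc.Algorithms)
}
```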
@@ -79,7 +79,9 @@ type Config struct {
	ClientID string
	// If specified, only this set of algorithms may be used to sign the JWT.
	//
	// Since many providers only support RS256, SupportedSigningAlgs defaults to this value.
	// If the IDTokenVerifier is created from a provider with (*Provider).Verifier, this
	// defaults to the set of algorithms the provider supports. Otherwise this value
	// defaults to RS256.
	SupportedSigningAlgs []string

	// If true, no ClientID check performed. Must be true if ClientID field is empty.

@@ -105,6 +107,13 @@ type Config struct {
// The returned IDTokenVerifier is tied to the Provider's context and its behavior is
// undefined once the Provider's context is canceled.
func (p *Provider) Verifier(config *Config) *IDTokenVerifier {
	if len(config.SupportedSigningAlgs) == 0 && len(p.algorithms) > 0 {
		// Make a copy so we don't modify the config values.
		cp := &Config{}
		*cp = *config
		cp.SupportedSigningAlgs = p.algorithms
		config = cp
	}
	return NewVerifier(p.issuer, p.remoteKeySet, config)
}
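With this change, `(*Provider).Verifier` seeds `SupportedSigningAlgs` from the discovery document whenever the caller leaves it empty, instead of the previous RS256-only default. A hedged sketch of typical go-oidc usage that relies on that defaulting; the issuer, client ID, and token are placeholders.

```go
package main

import (
	"context"
	"fmt"
	"log"

	oidc "github.com/coreos/go-oidc"
)

func main() {
	ctx := context.Background()

	// Placeholder issuer.
	provider, err := oidc.NewProvider(ctx, "https://team-vault.auth0.com/")
	if err != nil {
		log.Fatal(err)
	}

	// No SupportedSigningAlgs given: the verifier now defaults to the
	// algorithms advertised by the provider rather than RS256 alone.
	verifier := provider.Verifier(&oidc.Config{ClientID: "my-client-id"})

	rawIDToken := "eyJ..." // placeholder compact JWS
	idToken, err := verifier.Verify(ctx, rawIDToken)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(idToken.Subject)
}
```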
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package cmp determines equality of values.
|
||||
//
|
||||
|
@ -100,8 +100,8 @@ func Equal(x, y interface{}, opts ...Option) bool {
|
|||
// same input values and options.
|
||||
//
|
||||
// The output is displayed as a literal in pseudo-Go syntax.
|
||||
// At the start of each line, a "-" prefix indicates an element removed from y,
|
||||
// a "+" prefix to indicates an element added to y, and the lack of a prefix
|
||||
// At the start of each line, a "-" prefix indicates an element removed from x,
|
||||
// a "+" prefix to indicates an element added from y, and the lack of a prefix
|
||||
// indicates an element common to both x and y. If possible, the output
|
||||
// uses fmt.Stringer.String or error.Error methods to produce more humanly
|
||||
// readable outputs. In such cases, the string is prefixed with either an
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build purego
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !purego
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !cmp_debug
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build cmp_debug
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package diff implements an algorithm for producing edit-scripts.
|
||||
// The edit-script is a sequence of operations needed to transform one list
|
||||
|
@ -119,7 +119,7 @@ func (r Result) Similar() bool {
|
|||
return r.NumSame+1 >= r.NumDiff
|
||||
}
|
||||
|
||||
var randInt = rand.New(rand.NewSource(time.Now().Unix())).Intn(2)
|
||||
var randBool = rand.New(rand.NewSource(time.Now().Unix())).Intn(2) == 0
|
||||
|
||||
// Difference reports whether two lists of lengths nx and ny are equal
|
||||
// given the definition of equality provided as f.
|
||||
|
@ -168,17 +168,6 @@ func Difference(nx, ny int, f EqualFunc) (es EditScript) {
|
|||
// A vertical edge is equivalent to inserting a symbol from list Y.
|
||||
// A diagonal edge is equivalent to a matching symbol between both X and Y.
|
||||
|
||||
// To ensure flexibility in changing the algorithm in the future,
|
||||
// introduce some degree of deliberate instability.
|
||||
// This is achieved by fiddling the zigzag iterator to start searching
|
||||
// the graph starting from the bottom-right versus than the top-left.
|
||||
// The result may differ depending on the starting search location,
|
||||
// but still produces a valid edit script.
|
||||
zigzagInit := randInt // either 0 or 1
|
||||
if flags.Deterministic {
|
||||
zigzagInit = 0
|
||||
}
|
||||
|
||||
// Invariants:
|
||||
// • 0 ≤ fwdPath.X ≤ (fwdFrontier.X, revFrontier.X) ≤ revPath.X ≤ nx
|
||||
// • 0 ≤ fwdPath.Y ≤ (fwdFrontier.Y, revFrontier.Y) ≤ revPath.Y ≤ ny
|
||||
|
@ -197,6 +186,11 @@ func Difference(nx, ny int, f EqualFunc) (es EditScript) {
|
|||
// approximately the square-root of the search budget.
|
||||
searchBudget := 4 * (nx + ny) // O(n)
|
||||
|
||||
// Running the tests with the "cmp_debug" build tag prints a visualization
|
||||
// of the algorithm running in real-time. This is educational for
|
||||
// understanding how the algorithm works. See debug_enable.go.
|
||||
f = debug.Begin(nx, ny, f, &fwdPath.es, &revPath.es)
|
||||
|
||||
// The algorithm below is a greedy, meet-in-the-middle algorithm for
|
||||
// computing sub-optimal edit-scripts between two lists.
|
||||
//
|
||||
|
@ -214,22 +208,28 @@ func Difference(nx, ny int, f EqualFunc) (es EditScript) {
|
|||
// frontier towards the opposite corner.
|
||||
// • This algorithm terminates when either the X coordinates or the
|
||||
// Y coordinates of the forward and reverse frontier points ever intersect.
|
||||
//
|
||||
|
||||
// This algorithm is correct even if searching only in the forward direction
|
||||
// or in the reverse direction. We do both because it is commonly observed
|
||||
// that two lists commonly differ because elements were added to the front
|
||||
// or end of the other list.
|
||||
//
|
||||
// Running the tests with the "cmp_debug" build tag prints a visualization
|
||||
// of the algorithm running in real-time. This is educational for
|
||||
// understanding how the algorithm works. See debug_enable.go.
|
||||
f = debug.Begin(nx, ny, f, &fwdPath.es, &revPath.es)
|
||||
for {
|
||||
// Non-deterministically start with either the forward or reverse direction
|
||||
// to introduce some deliberate instability so that we have the flexibility
|
||||
// to change this algorithm in the future.
|
||||
if flags.Deterministic || randBool {
|
||||
goto forwardSearch
|
||||
} else {
|
||||
goto reverseSearch
|
||||
}
|
||||
|
||||
forwardSearch:
|
||||
{
|
||||
// Forward search from the beginning.
|
||||
if fwdFrontier.X >= revFrontier.X || fwdFrontier.Y >= revFrontier.Y || searchBudget == 0 {
|
||||
break
|
||||
goto finishSearch
|
||||
}
|
||||
for stop1, stop2, i := false, false, zigzagInit; !(stop1 && stop2) && searchBudget > 0; i++ {
|
||||
for stop1, stop2, i := false, false, 0; !(stop1 && stop2) && searchBudget > 0; i++ {
|
||||
// Search in a diagonal pattern for a match.
|
||||
z := zigzag(i)
|
||||
p := point{fwdFrontier.X + z, fwdFrontier.Y - z}
|
||||
|
@ -262,10 +262,14 @@ func Difference(nx, ny int, f EqualFunc) (es EditScript) {
|
|||
} else {
|
||||
fwdFrontier.Y++
|
||||
}
|
||||
goto reverseSearch
|
||||
}
|
||||
|
||||
reverseSearch:
|
||||
{
|
||||
// Reverse search from the end.
|
||||
if fwdFrontier.X >= revFrontier.X || fwdFrontier.Y >= revFrontier.Y || searchBudget == 0 {
|
||||
break
|
||||
goto finishSearch
|
||||
}
|
||||
for stop1, stop2, i := false, false, 0; !(stop1 && stop2) && searchBudget > 0; i++ {
|
||||
// Search in a diagonal pattern for a match.
|
||||
|
@ -300,8 +304,10 @@ func Difference(nx, ny int, f EqualFunc) (es EditScript) {
|
|||
} else {
|
||||
revFrontier.Y--
|
||||
}
|
||||
goto forwardSearch
|
||||
}
|
||||
|
||||
finishSearch:
|
||||
// Join the forward and reverse paths and then append the reverse path.
|
||||
fwdPath.connect(revPath.point, f)
|
||||
for i := len(revPath.es) - 1; i >= 0; i-- {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flags
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.10
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.10
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package function provides functionality for identifying function types.
|
||||
package function
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2020, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package value
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2018, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build purego
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2018, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !purego
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package value
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package value
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2017, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2020, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
@ -351,6 +351,8 @@ func formatMapKey(v reflect.Value, disambiguate bool, ptrs *pointerReferences) s
|
|||
opts.PrintAddresses = disambiguate
|
||||
opts.AvoidStringer = disambiguate
|
||||
opts.QualifiedNames = disambiguate
|
||||
opts.VerbosityLevel = maxVerbosityPreset
|
||||
opts.LimitVerbosity = true
|
||||
s := opts.FormatValue(v, reflect.Map, ptrs).String()
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright 2019, The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE.md file.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package cmp
|
||||
|
||||
|
|
|
@ -0,0 +1,362 @@
|
|||
Mozilla Public License, version 2.0
|
||||
|
||||
1. Definitions
|
||||
|
||||
1.1. "Contributor"
|
||||
|
||||
means each individual or legal entity that creates, contributes to the
|
||||
creation of, or owns Covered Software.
|
||||
|
||||
1.2. "Contributor Version"
|
||||
|
||||
means the combination of the Contributions of others (if any) used by a
|
||||
Contributor and that particular Contributor's Contribution.
|
||||
|
||||
1.3. "Contribution"
|
||||
|
||||
means Covered Software of a particular Contributor.
|
||||
|
||||
1.4. "Covered Software"
|
||||
|
||||
means Source Code Form to which the initial Contributor has attached the
|
||||
notice in Exhibit A, the Executable Form of such Source Code Form, and
|
||||
Modifications of such Source Code Form, in each case including portions
|
||||
thereof.
|
||||
|
||||
1.5. "Incompatible With Secondary Licenses"
|
||||
means
|
||||
|
||||
a. that the initial Contributor has attached the notice described in
|
||||
Exhibit B to the Covered Software; or
|
||||
|
||||
b. that the Covered Software was made available under the terms of
|
||||
version 1.1 or earlier of the License, but not also under the terms of
|
||||
a Secondary License.
|
||||
|
||||
1.6. "Executable Form"
|
||||
|
||||
means any form of the work other than Source Code Form.
|
||||
|
||||
1.7. "Larger Work"
|
||||
|
||||
means a work that combines Covered Software with other material, in a
|
||||
separate file or files, that is not Covered Software.
|
||||
|
||||
1.8. "License"
|
||||
|
||||
means this document.
|
||||
|
||||
1.9. "Licensable"
|
||||
|
||||
means having the right to grant, to the maximum extent possible, whether
|
||||
at the time of the initial grant or subsequently, any and all of the
|
||||
rights conveyed by this License.
|
||||
|
||||
1.10. "Modifications"
|
||||
|
||||
means any of the following:
|
||||
|
||||
a. any file in Source Code Form that results from an addition to,
|
||||
deletion from, or modification of the contents of Covered Software; or
|
||||
|
||||
b. any new file in Source Code Form that contains any Covered Software.
|
||||
|
||||
1.11. "Patent Claims" of a Contributor
|
||||
|
||||
means any patent claim(s), including without limitation, method,
|
||||
process, and apparatus claims, in any patent Licensable by such
|
||||
Contributor that would be infringed, but for the grant of the License,
|
||||
by the making, using, selling, offering for sale, having made, import,
|
||||
or transfer of either its Contributions or its Contributor Version.
|
||||
|
||||
1.12. "Secondary License"
|
||||
|
||||
means either the GNU General Public License, Version 2.0, the GNU Lesser
|
||||
General Public License, Version 2.1, the GNU Affero General Public
|
||||
License, Version 3.0, or any later versions of those licenses.
|
||||
|
||||
1.13. "Source Code Form"
|
||||
|
||||
means the form of the work preferred for making modifications.
|
||||
|
||||
1.14. "You" (or "Your")
|
||||
|
||||
means an individual or a legal entity exercising rights under this
|
||||
License. For legal entities, "You" includes any entity that controls, is
|
||||
controlled by, or is under common control with You. For purposes of this
|
||||
definition, "control" means (a) the power, direct or indirect, to cause
|
||||
the direction or management of such entity, whether by contract or
|
||||
otherwise, or (b) ownership of more than fifty percent (50%) of the
|
||||
outstanding shares or beneficial ownership of such entity.
|
||||
|
||||
|
||||
2. License Grants and Conditions
|
||||
|
||||
2.1. Grants
|
||||
|
||||
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||
non-exclusive license:
|
||||
|
||||
a. under intellectual property rights (other than patent or trademark)
|
||||
Licensable by such Contributor to use, reproduce, make available,
|
||||
modify, display, perform, distribute, and otherwise exploit its
|
||||
Contributions, either on an unmodified basis, with Modifications, or
|
||||
as part of a Larger Work; and
|
||||
|
||||
b. under Patent Claims of such Contributor to make, use, sell, offer for
|
||||
sale, have made, import, and otherwise transfer either its
|
||||
Contributions or its Contributor Version.
|
||||
|
||||
2.2. Effective Date
|
||||
|
||||
The licenses granted in Section 2.1 with respect to any Contribution
|
||||
become effective for each Contribution on the date the Contributor first
|
||||
distributes such Contribution.
|
||||
|
||||
2.3. Limitations on Grant Scope
|
||||
|
||||
The licenses granted in this Section 2 are the only rights granted under
|
||||
this License. No additional rights or licenses will be implied from the
|
||||
distribution or licensing of Covered Software under this License.
|
||||
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
||||
Contributor:
|
||||
|
||||
a. for any code that a Contributor has removed from Covered Software; or
|
||||
|
||||
b. for infringements caused by: (i) Your and any other third party's
|
||||
modifications of Covered Software, or (ii) the combination of its
|
||||
Contributions with other software (except as part of its Contributor
|
||||
Version); or
|
||||
|
||||
c. under Patent Claims infringed by Covered Software in the absence of
|
||||
its Contributions.
|
||||
|
||||
This License does not grant any rights in the trademarks, service marks,
|
||||
or logos of any Contributor (except as may be necessary to comply with
|
||||
the notice requirements in Section 3.4).
|
||||
|
||||
2.4. Subsequent Licenses
|
||||
|
||||
No Contributor makes additional grants as a result of Your choice to
|
||||
distribute the Covered Software under a subsequent version of this
|
||||
License (see Section 10.2) or under the terms of a Secondary License (if
|
||||
permitted under the terms of Section 3.3).
|
||||
|
||||
2.5. Representation
|
||||
|
||||
Each Contributor represents that the Contributor believes its
|
||||
Contributions are its original creation(s) or it has sufficient rights to
|
||||
grant the rights to its Contributions conveyed by this License.
|
||||
|
||||
2.6. Fair Use
|
||||
|
||||
This License is not intended to limit any rights You have under
|
||||
applicable copyright doctrines of fair use, fair dealing, or other
|
||||
equivalents.
|
||||
|
||||
2.7. Conditions
|
||||
|
||||
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
|
||||
Section 2.1.
|
||||
|
||||
|
||||
3. Responsibilities
|
||||
|
||||
3.1. Distribution of Source Form
|
||||
|
||||
All distribution of Covered Software in Source Code Form, including any
|
||||
Modifications that You create or to which You contribute, must be under
|
||||
the terms of this License. You must inform recipients that the Source
|
||||
Code Form of the Covered Software is governed by the terms of this
|
||||
License, and how they can obtain a copy of this License. You may not
|
||||
attempt to alter or restrict the recipients' rights in the Source Code
|
||||
Form.
|
||||
|
||||
3.2. Distribution of Executable Form
|
||||
|
||||
If You distribute Covered Software in Executable Form then:
|
||||
|
||||
a. such Covered Software must also be made available in Source Code Form,
|
||||
as described in Section 3.1, and You must inform recipients of the
|
||||
Executable Form how they can obtain a copy of such Source Code Form by
|
||||
reasonable means in a timely manner, at a charge no more than the cost
|
||||
of distribution to the recipient; and
|
||||
|
||||
b. You may distribute such Executable Form under the terms of this
|
||||
License, or sublicense it under different terms, provided that the
|
||||
license for the Executable Form does not attempt to limit or alter the
|
||||
recipients' rights in the Source Code Form under this License.
|
||||
|
||||
3.3. Distribution of a Larger Work
|
||||
|
||||
You may create and distribute a Larger Work under terms of Your choice,
|
||||
provided that You also comply with the requirements of this License for
|
||||
the Covered Software. If the Larger Work is a combination of Covered
|
||||
Software with a work governed by one or more Secondary Licenses, and the
|
||||
Covered Software is not Incompatible With Secondary Licenses, this
|
||||
License permits You to additionally distribute such Covered Software
|
||||
under the terms of such Secondary License(s), so that the recipient of
|
||||
the Larger Work may, at their option, further distribute the Covered
|
||||
Software under the terms of either this License or such Secondary
|
||||
License(s).
|
||||
|
||||
3.4. Notices
|
||||
|
||||
You may not remove or alter the substance of any license notices
|
||||
(including copyright notices, patent notices, disclaimers of warranty, or
|
||||
limitations of liability) contained within the Source Code Form of the
|
||||
Covered Software, except that You may alter any license notices to the
|
||||
extent required to remedy known factual inaccuracies.
|
||||
|
||||
3.5. Application of Additional Terms
|
||||
|
||||
You may choose to offer, and to charge a fee for, warranty, support,
|
||||
indemnity or liability obligations to one or more recipients of Covered
|
||||
Software. However, You may do so only on Your own behalf, and not on
|
||||
behalf of any Contributor. You must make it absolutely clear that any
|
||||
such warranty, support, indemnity, or liability obligation is offered by
|
||||
You alone, and You hereby agree to indemnify every Contributor for any
|
||||
liability incurred by such Contributor as a result of warranty, support,
|
||||
indemnity or liability terms You offer. You may include additional
|
||||
disclaimers of warranty and limitations of liability specific to any
|
||||
jurisdiction.
|
||||
|
||||
4. Inability to Comply Due to Statute or Regulation
|
||||
|
||||
If it is impossible for You to comply with any of the terms of this License
|
||||
with respect to some or all of the Covered Software due to statute,
|
||||
judicial order, or regulation then You must: (a) comply with the terms of
|
||||
this License to the maximum extent possible; and (b) describe the
|
||||
limitations and the code they affect. Such description must be placed in a
|
||||
text file included with all distributions of the Covered Software under
|
||||
this License. Except to the extent prohibited by statute or regulation,
|
||||
such description must be sufficiently detailed for a recipient of ordinary
|
||||
skill to be able to understand it.
|
||||
|
||||
5. Termination
|
||||
|
||||
5.1. The rights granted under this License will terminate automatically if You
|
||||
fail to comply with any of its terms. However, if You become compliant,
|
||||
then the rights granted under this License from a particular Contributor
|
||||
are reinstated (a) provisionally, unless and until such Contributor
|
||||
explicitly and finally terminates Your grants, and (b) on an ongoing
|
||||
basis, if such Contributor fails to notify You of the non-compliance by
|
||||
some reasonable means prior to 60 days after You have come back into
|
||||
compliance. Moreover, Your grants from a particular Contributor are
|
||||
reinstated on an ongoing basis if such Contributor notifies You of the
|
||||
non-compliance by some reasonable means, this is the first time You have
|
||||
received notice of non-compliance with this License from such
|
||||
Contributor, and You become compliant prior to 30 days after Your receipt
|
||||
of the notice.
|
||||
|
||||
5.2. If You initiate litigation against any entity by asserting a patent
|
||||
infringement claim (excluding declaratory judgment actions,
|
||||
counter-claims, and cross-claims) alleging that a Contributor Version
|
||||
directly or indirectly infringes any patent, then the rights granted to
|
||||
You by any and all Contributors for the Covered Software under Section
|
||||
2.1 of this License shall terminate.
|
||||
|
||||
5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
|
||||
license agreements (excluding distributors and resellers) which have been
|
||||
validly granted by You or Your distributors under this License prior to
|
||||
termination shall survive termination.
|
||||
|
||||
6. Disclaimer of Warranty
|
||||
|
||||
Covered Software is provided under this License on an "as is" basis,
|
||||
without warranty of any kind, either expressed, implied, or statutory,
|
||||
including, without limitation, warranties that the Covered Software is free
|
||||
of defects, merchantable, fit for a particular purpose or non-infringing.
|
||||
The entire risk as to the quality and performance of the Covered Software
|
||||
is with You. Should any Covered Software prove defective in any respect,
|
||||
You (not any Contributor) assume the cost of any necessary servicing,
|
||||
repair, or correction. This disclaimer of warranty constitutes an essential
|
||||
part of this License. No use of any Covered Software is authorized under
|
||||
this License except under this disclaimer.
|
||||
|
||||
7. Limitation of Liability
|
||||
|
||||
Under no circumstances and under no legal theory, whether tort (including
|
||||
negligence), contract, or otherwise, shall any Contributor, or anyone who
|
||||
distributes Covered Software as permitted above, be liable to You for any
|
||||
direct, indirect, special, incidental, or consequential damages of any
|
||||
character including, without limitation, damages for lost profits, loss of
|
||||
goodwill, work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses, even if such party shall have been
|
||||
informed of the possibility of such damages. This limitation of liability
|
||||
shall not apply to liability for death or personal injury resulting from
|
||||
such party's negligence to the extent applicable law prohibits such
|
||||
limitation. Some jurisdictions do not allow the exclusion or limitation of
|
||||
incidental or consequential damages, so this exclusion and limitation may
|
||||
not apply to You.
|
||||
|
||||
8. Litigation
|
||||
|
||||
Any litigation relating to this License may be brought only in the courts
|
||||
of a jurisdiction where the defendant maintains its principal place of
|
||||
business and such litigation shall be governed by laws of that
|
||||
jurisdiction, without reference to its conflict-of-law provisions. Nothing
|
||||
in this Section shall prevent a party's ability to bring cross-claims or
|
||||
counter-claims.
|
||||
|
||||
9. Miscellaneous
|
||||
|
||||
This License represents the complete agreement concerning the subject
|
||||
matter hereof. If any provision of this License is held to be
|
||||
unenforceable, such provision shall be reformed only to the extent
|
||||
necessary to make it enforceable. Any law or regulation which provides that
|
||||
the language of a contract shall be construed against the drafter shall not
|
||||
be used to construe this License against a Contributor.
|
||||
|
||||
|
||||
10. Versions of the License
|
||||
|
||||
10.1. New Versions
|
||||
|
||||
Mozilla Foundation is the license steward. Except as provided in Section
|
||||
10.3, no one other than the license steward has the right to modify or
|
||||
publish new versions of this License. Each version will be given a
|
||||
distinguishing version number.
|
||||
|
||||
10.2. Effect of New Versions
|
||||
|
||||
You may distribute the Covered Software under the terms of the version
|
||||
of the License under which You originally received the Covered Software,
|
||||
or under the terms of any subsequent version published by the license
|
||||
steward.
|
||||
|
||||
10.3. Modified Versions
|
||||
|
||||
If you create software not governed by this License, and you want to
|
||||
create a new license for such software, you may create and use a
|
||||
modified version of this License if you rename the license and remove
|
||||
any references to the name of the license steward (except to note that
|
||||
such modified license differs from this License).
|
||||
|
||||
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
||||
Licenses If You choose to distribute Source Code Form that is
|
||||
Incompatible With Secondary Licenses under the terms of this version of
|
||||
the License, the notice described in Exhibit B of this License must be
|
||||
attached.
|
||||
|
||||
Exhibit A - Source Code Form License Notice
|
||||
|
||||
This Source Code Form is subject to the
|
||||
terms of the Mozilla Public License, v.
|
||||
2.0. If a copy of the MPL was not
|
||||
distributed with this file, You can
|
||||
obtain one at
|
||||
http://mozilla.org/MPL/2.0/.
|
||||
|
||||
If it is not possible or desirable to put the notice in a particular file,
|
||||
then You may include the notice in a location (such as a LICENSE file in a
|
||||
relevant directory) where a recipient would be likely to look for such a
|
||||
notice.
|
||||
|
||||
You may add additional accurate notices of copyright ownership.
|
||||
|
||||
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
||||
|
||||
This Source Code Form is "Incompatible
|
||||
With Secondary Licenses", as defined by
|
||||
the Mozilla Public License, v. 2.0.
|
|
@@ -0,0 +1,20 @@
# jwt
[![Go Reference](https://pkg.go.dev/badge/github.com/hashicorp/cap/jwt.svg)](https://pkg.go.dev/github.com/hashicorp/cap/jwt)

Package jwt provides signature verification and claims set validation for JSON Web Tokens (JWT)
of the JSON Web Signature (JWS) form.

Primary types provided by the package:

* `KeySet`: Represents a set of keys that can be used to verify the signatures of JWTs.
  A KeySet is expected to be backed by a set of local or remote keys.

* `Validator`: Provides signature verification and claims set validation behavior for JWTs.

* `Expected`: Defines the expected claims values to assert when validating a JWT.

* `Alg`: Represents asymmetric signing algorithms.

### Examples:

Please see [docs_test.go](./docs_test.go) for additional usage examples.
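A minimal end-to-end sketch of the types listed above, built around the OIDC discovery KeySet added in this vendored package; the issuer, audience, and token values are placeholders, not taken from this commit.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/hashicorp/cap/jwt"
)

func main() {
	ctx := context.Background()

	// Build a KeySet backed by the issuer's JWKS, discovered via OIDC.
	// An empty issuerCAPEM means system CA certificates are used.
	keySet, err := jwt.NewOIDCDiscoveryKeySet(ctx, "https://team-vault.auth0.com/", "")
	if err != nil {
		log.Fatal(err)
	}

	validator, err := jwt.NewValidator(keySet)
	if err != nil {
		log.Fatal(err)
	}

	token := "eyJ..." // placeholder compact JWS

	// Verify the signature and assert the expected claims.
	claims, err := validator.Validate(ctx, token, jwt.Expected{
		Issuer:            "https://team-vault.auth0.com/",
		Audiences:         []string{"https://vault.example.com"},
		SigningAlgorithms: []jwt.Alg{jwt.RS256, jwt.ES256},
		ClockSkewLeeway:   30 * time.Second,
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(claims["sub"])
}
```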
@@ -0,0 +1,46 @@
package jwt

import "fmt"

// Alg represents asymmetric signing algorithms
type Alg string

const (
	// JOSE asymmetric signing algorithm values as defined by RFC 7518.
	//
	// See: https://tools.ietf.org/html/rfc7518#section-3.1
	RS256 Alg = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
	RS384 Alg = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
	RS512 Alg = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
	ES256 Alg = "ES256" // ECDSA using P-256 and SHA-256
	ES384 Alg = "ES384" // ECDSA using P-384 and SHA-384
	ES512 Alg = "ES512" // ECDSA using P-521 and SHA-512
	PS256 Alg = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
	PS384 Alg = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
	PS512 Alg = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
	EdDSA Alg = "EdDSA" // Ed25519 using SHA-512
)

var supportedAlgorithms = map[Alg]bool{
	RS256: true,
	RS384: true,
	RS512: true,
	ES256: true,
	ES384: true,
	ES512: true,
	PS256: true,
	PS384: true,
	PS512: true,
	EdDSA: true,
}

// SupportedSigningAlgorithm returns an error if any of the given Algs
// are not supported signing algorithms.
func SupportedSigningAlgorithm(algs ...Alg) error {
	for _, a := range algs {
		if !supportedAlgorithms[a] {
			return fmt.Errorf("unsupported signing algorithm %q", a)
		}
	}
	return nil
}
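A short sketch of how a caller might vet user-supplied algorithm names, for example values an operator passes via `jwt_supported_algs`, before converting them to `Alg`; the input slice is illustrative.

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/cap/jwt"
)

func main() {
	// e.g. values read from a role or mount configuration.
	raw := []string{"RS256", "EdDSA"}

	algs := make([]jwt.Alg, 0, len(raw))
	for _, a := range raw {
		algs = append(algs, jwt.Alg(a))
	}

	// Rejects anything outside the supported set (e.g. "HS256" or "none").
	if err := jwt.SupportedSigningAlgorithm(algs...); err != nil {
		log.Fatal(err)
	}
	fmt.Println("algorithms accepted:", algs)
}
```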
@@ -0,0 +1,31 @@
/*
Package jwt provides signature verification and claims set validation for JSON Web Tokens (JWT)
of the JSON Web Signature (JWS) form.

JWT claims set validation provided by the package includes the option to validate
all registered claim names defined in https://tools.ietf.org/html/rfc7519#section-4.1.

JOSE header validation provided by the package includes the option to validate the "alg"
(Algorithm) Header Parameter defined in https://tools.ietf.org/html/rfc7515#section-4.1.

JWT signature verification is supported by providing keys from the following sources:

	- JSON Web Key Set (JWKS) URL
	- OIDC Discovery mechanism
	- Local public keys

JWT signature verification supports the following asymmetric algorithms as defined in
https://www.rfc-editor.org/rfc/rfc7518.html#section-3.1:

	- RS256: RSASSA-PKCS1-v1_5 using SHA-256
	- RS384: RSASSA-PKCS1-v1_5 using SHA-384
	- RS512: RSASSA-PKCS1-v1_5 using SHA-512
	- ES256: ECDSA using P-256 and SHA-256
	- ES384: ECDSA using P-384 and SHA-384
	- ES512: ECDSA using P-521 and SHA-512
	- PS256: RSASSA-PSS using SHA-256 and MGF1 with SHA-256
	- PS384: RSASSA-PSS using SHA-384 and MGF1 with SHA-384
	- PS512: RSASSA-PSS using SHA-512 and MGF1 with SHA-512
	- EdDSA: Ed25519 using SHA-512
*/
package jwt
@@ -0,0 +1,265 @@
package jwt

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"gopkg.in/square/go-jose.v2"
	"gopkg.in/square/go-jose.v2/jwt"
)

// DefaultLeewaySeconds defines the amount of leeway that's used by default
// for validating the "nbf" (Not Before) and "exp" (Expiration Time) claims.
const DefaultLeewaySeconds = 150

// Validator validates JSON Web Tokens (JWT) by providing signature
// verification and claims set validation.
type Validator struct {
	keySet KeySet
}

// NewValidator returns a Validator that uses the given KeySet to verify JWT signatures.
func NewValidator(keySet KeySet) (*Validator, error) {
	if keySet == nil {
		return nil, errors.New("keySet must not be nil")
	}

	return &Validator{
		keySet: keySet,
	}, nil
}

// Expected defines the expected claims values to assert when validating a JWT.
// For claims that involve validation of the JWT with respect to time, leeway
// fields are provided to account for potential clock skew.
type Expected struct {
	// The expected JWT "iss" (issuer) claim value. If empty, validation is skipped.
	Issuer string

	// The expected JWT "sub" (subject) claim value. If empty, validation is skipped.
	Subject string

	// The expected JWT "jti" (JWT ID) claim value. If empty, validation is skipped.
	ID string

	// The list of expected JWT "aud" (audience) claim values to match against.
	// The JWT claim will be considered valid if it matches any of the expected
	// audiences. If empty, validation is skipped.
	Audiences []string

	// SigningAlgorithms provides the list of expected JWS "alg" (algorithm) header
	// parameter values to match against. The JWS header parameter will be considered
	// valid if it matches any of the expected signing algorithms. The following
	// algorithms are supported: RS256, RS384, RS512, ES256, ES384, ES512, PS256,
	// PS384, PS512, EdDSA. If empty, defaults to RS256.
	SigningAlgorithms []Alg

	// NotBeforeLeeway provides the option to set an amount of leeway to use when
	// validating the "nbf" (Not Before) claim. If the duration is zero or not
	// provided, a default leeway of 150 seconds will be used. If the duration is
	// negative, no leeway will be used.
	NotBeforeLeeway time.Duration

	// ExpirationLeeway provides the option to set an amount of leeway to use when
	// validating the "exp" (Expiration Time) claim. If the duration is zero or not
	// provided, a default leeway of 150 seconds will be used. If the duration is
	// negative, no leeway will be used.
	ExpirationLeeway time.Duration

	// ClockSkewLeeway provides the option to set an amount of leeway to use when
	// validating the "nbf" (Not Before), "exp" (Expiration Time), and "iat" (Issued At)
	// claims. If the duration is zero or not provided, a default leeway of 60 seconds
	// will be used. If the duration is negative, no leeway will be used.
	ClockSkewLeeway time.Duration

	// Now provides the option to specify a func for determining what the current time is.
	// The func will be used to provide the current time when validating a JWT with respect to
	// the "nbf" (Not Before), "exp" (Expiration Time), and "iat" (Issued At) claims. If not
	// provided, defaults to returning time.Now().
	Now func() time.Time
}

// Validate validates JWTs of the JWS compact serialization form.
//
// The given JWT is considered valid if:
//   1. Its signature is successfully verified.
//   2. Its claims set and header parameter values match what's given by Expected.
//   3. It's valid with respect to the current time. This means that the current
//      time must be within the times (inclusive) given by the "nbf" (Not Before)
//      and "exp" (Expiration Time) claims and after the time given by the "iat"
//      (Issued At) claim, with configurable leeway. See Expected.Now() for details
//      on how the current time is provided for validation.
func (v *Validator) Validate(ctx context.Context, token string, expected Expected) (map[string]interface{}, error) {
	// First, verify the signature to ensure subsequent validation is against verified claims
	allClaims, err := v.keySet.VerifySignature(ctx, token)
	if err != nil {
		return nil, fmt.Errorf("error verifying token signature: %w", err)
	}

	// Validate the signing algorithm in the JWS header
	if err := validateSigningAlgorithm(token, expected.SigningAlgorithms); err != nil {
		return nil, fmt.Errorf("invalid algorithm (alg) header parameter: %w", err)
	}

	// Unmarshal all claims into the set of public JWT registered claims
	claims := jwt.Claims{}
	allClaimsJSON, err := json.Marshal(allClaims)
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(allClaimsJSON, &claims); err != nil {
		return nil, err
	}

	// At least one of the "nbf" (Not Before), "exp" (Expiration Time), or "iat" (Issued At)
	// claims are required to be set.
	if claims.IssuedAt == nil {
		claims.IssuedAt = new(jwt.NumericDate)
	}
	if claims.Expiry == nil {
		claims.Expiry = new(jwt.NumericDate)
	}
	if claims.NotBefore == nil {
		claims.NotBefore = new(jwt.NumericDate)
	}
	if *claims.IssuedAt == 0 && *claims.Expiry == 0 && *claims.NotBefore == 0 {
		return nil, errors.New("no issued at (iat), not before (nbf), or expiration time (exp) claims in token")
	}

	// If "exp" (Expiration Time) is not set, then set it to the latest of
	// either the "iat" (Issued At) or "nbf" (Not Before) claims plus leeway.
	if *claims.Expiry == 0 {
		latestStart := *claims.IssuedAt
		if *claims.NotBefore > *claims.IssuedAt {
			latestStart = *claims.NotBefore
		}
		leeway := expected.ExpirationLeeway.Seconds()
		if expected.ExpirationLeeway.Seconds() < 0 {
			leeway = 0
		} else if expected.ExpirationLeeway.Seconds() == 0 {
			leeway = DefaultLeewaySeconds
		}
		*claims.Expiry = jwt.NumericDate(int64(latestStart) + int64(leeway))
	}

	// If "nbf" (Not Before) is not set, then set it to the "iat" (Issued At) if set.
	// Otherwise, set it to the "exp" (Expiration Time) minus leeway.
	if *claims.NotBefore == 0 {
		if *claims.IssuedAt != 0 {
			*claims.NotBefore = *claims.IssuedAt
		} else {
			leeway := expected.NotBeforeLeeway.Seconds()
			if expected.NotBeforeLeeway.Seconds() < 0 {
				leeway = 0
			} else if expected.NotBeforeLeeway.Seconds() == 0 {
				leeway = DefaultLeewaySeconds
			}
			*claims.NotBefore = jwt.NumericDate(int64(*claims.Expiry) - int64(leeway))
		}
	}

	// Set clock skew leeway to apply when validating all time-related claims
	cksLeeway := expected.ClockSkewLeeway
	if expected.ClockSkewLeeway.Seconds() < 0 {
		cksLeeway = 0
	} else if expected.ClockSkewLeeway.Seconds() == 0 {
		cksLeeway = jwt.DefaultLeeway
	}

	// Validate claims by asserting they're as expected
	if expected.Issuer != "" && expected.Issuer != claims.Issuer {
		return nil, fmt.Errorf("invalid issuer (iss) claim")
	}
	if expected.Subject != "" && expected.Subject != claims.Subject {
		return nil, fmt.Errorf("invalid subject (sub) claim")
	}
	if expected.ID != "" && expected.ID != claims.ID {
		return nil, fmt.Errorf("invalid ID (jti) claim")
	}
	if err := validateAudience(expected.Audiences, claims.Audience); err != nil {
		return nil, fmt.Errorf("invalid audience (aud) claim: %w", err)
	}

	// Validate that the token is not expired with respect to the current time
	now := time.Now()
	if expected.Now != nil {
		now = expected.Now()
	}
	if claims.NotBefore != nil && now.Add(cksLeeway).Before(claims.NotBefore.Time()) {
		return nil, errors.New("invalid not before (nbf) claim: token not yet valid")
	}
	if claims.Expiry != nil && now.Add(-cksLeeway).After(claims.Expiry.Time()) {
		return nil, errors.New("invalid expiration time (exp) claim: token is expired")
	}
	if claims.IssuedAt != nil && now.Add(cksLeeway).Before(claims.IssuedAt.Time()) {
		return nil, errors.New("invalid issued at (iat) claim: token issued in the future")
	}

	return allClaims, nil
}

// validateSigningAlgorithm checks whether the JWS "alg" (Algorithm) header
// parameter value for the given JWT matches any given in expectedAlgorithms.
// If expectedAlgorithms is empty, RS256 will be expected by default.
func validateSigningAlgorithm(token string, expectedAlgorithms []Alg) error {
	if err := SupportedSigningAlgorithm(expectedAlgorithms...); err != nil {
		return err
	}

	jws, err := jose.ParseSigned(token)
	if err != nil {
		return err
	}

	if len(jws.Signatures) == 0 {
		return fmt.Errorf("token must be signed")
	}
	if len(jws.Signatures) == 1 && len(jws.Signatures[0].Signature) == 0 {
		return fmt.Errorf("token must be signed")
	}
	if len(jws.Signatures) > 1 {
		return fmt.Errorf("token with multiple signatures not supported")
	}

	if len(expectedAlgorithms) == 0 {
		expectedAlgorithms = []Alg{RS256}
	}

	actual := Alg(jws.Signatures[0].Header.Algorithm)
	for _, expected := range expectedAlgorithms {
		if expected == actual {
			return nil
		}
	}

	return fmt.Errorf("token signed with unexpected algorithm")
}

// validateAudience returns an error if audClaim does not contain any audiences
// given by expectedAudiences. If expectedAudiences is empty, it skips validation
// and returns nil.
func validateAudience(expectedAudiences, audClaim []string) error {
	if len(expectedAudiences) == 0 {
		return nil
	}

	for _, v := range expectedAudiences {
		if contains(audClaim, v) {
			return nil
		}
	}

	return errors.New("audience claim does not match any expected audience")
}

func contains(sl []string, st string) bool {
	for _, s := range sl {
		if s == st {
			return true
		}
	}
	return false
}
@ -0,0 +1,241 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// KeySet represents a set of keys that can be used to verify the signatures of JWTs.
|
||||
// A KeySet is expected to be backed by a set of local or remote keys.
|
||||
type KeySet interface {
|
||||
|
||||
// VerifySignature parses the given JWT, verifies its signature, and returns the claims in its payload.
|
||||
// The given JWT must be of the JWS compact serialization form.
|
||||
VerifySignature(ctx context.Context, token string) (claims map[string]interface{}, err error)
|
||||
}
|
||||
|
||||
// jsonWebKeySet verifies JWT signatures using keys obtained from a JWKS URL.
|
||||
type jsonWebKeySet struct {
|
||||
remoteJWKS oidc.KeySet
|
||||
}
|
||||
|
||||
// staticKeySet verifies JWT signatures using local public keys.
|
||||
type staticKeySet struct {
|
||||
publicKeys []crypto.PublicKey
|
||||
}
|
||||
|
||||
// NewOIDCDiscoveryKeySet returns a KeySet that verifies JWT signatures using keys from the
|
||||
// JSON Web Key Set (JWKS) published in the discovery document at the given issuer URL.
|
||||
// The client used to obtain the remote keys will verify server certificates using the root
|
||||
// certificates provided by issuerCAPEM. If issuerCAPEM is not provided, system certificates
|
||||
// are used.
|
||||
func NewOIDCDiscoveryKeySet(ctx context.Context, issuer string, issuerCAPEM string) (KeySet, error) {
|
||||
if issuer == "" {
|
||||
return nil, errors.New("issuer must not be empty")
|
||||
}
|
||||
|
||||
// Configure an http client with the given certificates
|
||||
caCtx, err := createCAContext(ctx, issuerCAPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := http.DefaultClient
|
||||
if c, ok := caCtx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
client = c
|
||||
}
|
||||
|
||||
// Create and send the http request for the OIDC discovery document
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
req, err := http.NewRequest(http.MethodGet, wellKnown, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(req.WithContext(caCtx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read the response body and status code
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read response body: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Status, body)
|
||||
}
|
||||
|
||||
// Unmarshal the response body to obtain the issuer and JWKS URL
|
||||
var p struct {
|
||||
Issuer string `json:"issuer"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
}
|
||||
if err := unmarshalResp(resp, body, &p); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode OIDC discovery document: %w", err)
|
||||
}
|
||||
|
||||
// Ensure that the returned issuer matches what was given by issuer
|
||||
if p.Issuer != issuer {
|
||||
return nil, fmt.Errorf("issuer did not match the returned issuer, expected %q got %q",
|
||||
issuer, p.Issuer)
|
||||
}
|
||||
|
||||
return &jsonWebKeySet{
|
||||
remoteJWKS: oidc.NewRemoteKeySet(caCtx, p.JWKSURL),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewJSONWebKeySet returns a KeySet that verifies JWT signatures using keys from the JSON Web
|
||||
// Key Set (JWKS) at the given jwksURL. The client used to obtain the remote JWKS will verify
|
||||
// server certificates using the root certificates provided by jwksCAPEM. If jwksCAPEM is not
|
||||
// provided, system certificates are used.
|
||||
func NewJSONWebKeySet(ctx context.Context, jwksURL string, jwksCAPEM string) (KeySet, error) {
|
||||
if jwksURL == "" {
|
||||
return nil, errors.New("jwksURL must not be empty")
|
||||
}
|
||||
|
||||
caCtx, err := createCAContext(ctx, jwksCAPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &jsonWebKeySet{
|
||||
remoteJWKS: oidc.NewRemoteKeySet(caCtx, jwksURL),
|
||||
}, nil
|
||||
}
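As a usage sketch (assuming the package above is imported as `github.com/hashicorp/cap/jwt`; adjust the path to wherever it actually lives), verifying a token against a remote JWKS might look like this. The JWKS URL and token are placeholders:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/hashicorp/cap/jwt" // assumed import path for the package above
)

func main() {
	ctx := context.Background()

	// Build a KeySet backed by a remote JWKS endpoint. An empty CA PEM means
	// system root certificates are used for the TLS connection.
	keySet, err := jwt.NewJSONWebKeySet(ctx, "https://example.com/.well-known/jwks.json", "")
	if err != nil {
		log.Fatal(err)
	}

	// token would be a compact-serialized JWS obtained elsewhere.
	token := "eyJ..." // placeholder
	claims, err := keySet.VerifySignature(ctx, token)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(claims["sub"])
}
```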
|
||||
|
||||
// VerifySignature parses the given JWT, verifies its signature using JWKS keys, and returns
|
||||
// the claims in its payload. The given JWT must be of the JWS compact serialization form.
|
||||
func (ks *jsonWebKeySet) VerifySignature(ctx context.Context, token string) (map[string]interface{}, error) {
|
||||
payload, err := ks.remoteJWKS.VerifySignature(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unmarshal payload into a set of all received claims
|
||||
allClaims := map[string]interface{}{}
|
||||
if err := json.Unmarshal(payload, &allClaims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return allClaims, nil
|
||||
}
|
||||
|
||||
// NewStaticKeySet returns a KeySet that verifies JWT signatures using the given publicKeys.
|
||||
func NewStaticKeySet(publicKeys []crypto.PublicKey) (KeySet, error) {
|
||||
if len(publicKeys) == 0 {
|
||||
return nil, errors.New("publicKeys must not be empty")
|
||||
}
|
||||
|
||||
return &staticKeySet{
|
||||
publicKeys: publicKeys,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifySignature parses the given JWT, verifies its signature using local public keys, and
|
||||
// returns the claims in its payload. The given JWT must be of the JWS compact serialization form.
|
||||
func (ks *staticKeySet) VerifySignature(_ context.Context, token string) (map[string]interface{}, error) {
|
||||
parsedJWT, err := jwt.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var valid bool
|
||||
allClaims := map[string]interface{}{}
|
||||
for _, key := range ks.publicKeys {
|
||||
if err := parsedJWT.Claims(key, &allClaims); err == nil {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("no known key successfully validated the token signature")
|
||||
}
|
||||
|
||||
return allClaims, nil
|
||||
}
|
||||
|
||||
// ParsePublicKeyPEM is used to parse RSA and ECDSA public keys from PEMs. The given
|
||||
// data must be of PEM-encoded x509 certificate or PKIX public key forms. It returns
|
||||
// an *rsa.PublicKey or *ecdsa.PublicKey.
|
||||
func ParsePublicKeyPEM(data []byte) (crypto.PublicKey, error) {
|
||||
block, data := pem.Decode(data)
|
||||
if block != nil {
|
||||
var rawKey interface{}
|
||||
var err error
|
||||
if rawKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||
rawKey = cert.PublicKey
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if rsaPublicKey, ok := rawKey.(*rsa.PublicKey); ok {
|
||||
return rsaPublicKey, nil
|
||||
}
|
||||
if ecPublicKey, ok := rawKey.(*ecdsa.PublicKey); ok {
|
||||
return ecPublicKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("data does not contain any valid RSA or ECDSA public keys")
|
||||
}
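A sketch of pairing ParsePublicKeyPEM with NewStaticKeySet, under the same import-path assumption as above; the PEM contents and token are placeholders:

```go
package main

import (
	"context"
	"crypto"
	"fmt"
	"log"

	"github.com/hashicorp/cap/jwt" // assumed import path for the package above
)

func main() {
	// pemBytes would hold a PEM-encoded PKIX public key or x509 certificate.
	pemBytes := []byte("-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----")

	pub, err := jwt.ParsePublicKeyPEM(pemBytes)
	if err != nil {
		log.Fatal(err)
	}

	keySet, err := jwt.NewStaticKeySet([]crypto.PublicKey{pub})
	if err != nil {
		log.Fatal(err)
	}

	claims, err := keySet.VerifySignature(context.Background(), "eyJ..." /* compact JWS */)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(claims)
}
```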
|
||||
|
||||
// createCAContext returns a context with a custom TLS client that's configured with the root
|
||||
// certificates from caPEM. If no certificates are configured, the original context is returned.
|
||||
func createCAContext(ctx context.Context, caPEM string) (context.Context, error) {
|
||||
if caPEM == "" {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
certPool := x509.NewCertPool()
|
||||
if ok := certPool.AppendCertsFromPEM([]byte(caPEM)); !ok {
|
||||
return nil, errors.New("could not parse CA PEM value successfully")
|
||||
}
|
||||
|
||||
tr := cleanhttp.DefaultPooledTransport()
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
RootCAs: certPool,
|
||||
}
|
||||
tc := &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
|
||||
caCtx := context.WithValue(ctx, oauth2.HTTPClient, tc)
|
||||
|
||||
return caCtx, nil
|
||||
}
|
||||
|
||||
// unmarshalResp JSON unmarshals the given body into the value pointed to by v.
|
||||
// If it is unable to JSON unmarshal body into v, then it returns an appropriate
|
||||
// error based on the Content-Type header of r.
|
||||
func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
|
||||
err := json.Unmarshal(body, &v)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
ct := r.Header.Get("Content-Type")
|
||||
mediaType, _, parseErr := mime.ParseMediaType(ct)
|
||||
if parseErr == nil && mediaType == "application/json" {
|
||||
return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err)
|
||||
}
|
||||
return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err)
|
||||
}
|
|
@ -0,0 +1,131 @@
|
|||
# oidc
|
||||
[![Go Reference](https://pkg.go.dev/badge/github.com/hashicorp/cap/oidc.svg)](https://pkg.go.dev/github.com/hashicorp/cap/oidc)
|
||||
|
||||
oidc is a package for writing clients that integrate with OIDC Providers using
|
||||
OIDC flows.
|
||||
|
||||
Primary types provided by the package:
|
||||
|
||||
* `Request`: represents one OIDC authentication flow for a user. It contains the
|
||||
data needed to uniquely represent that one-time flow across the multiple
|
||||
interactions needed to complete the OIDC flow the user is attempting. All
|
||||
Requests contain an expiration for the user's OIDC flow.
|
||||
|
||||
* `Token`: represents an OIDC id_token, as well as an Oauth2 access_token and
|
||||
refresh_token (including the access_token expiry)
|
||||
|
||||
* `Config`: provides the configuration for a typical 3-legged OIDC
|
||||
authorization code flow (for example: client ID/Secret, redirectURL, supported
|
||||
signing algorithms, additional scopes requested, etc)
|
||||
|
||||
* `Provider`: provides integration with an OIDC provider.
|
||||
It provides capabilities like: generating an auth URL, exchanging
|
||||
codes for tokens, verifying tokens, making user info requests, etc.
|
||||
|
||||
* `Alg`: represents asymmetric signing algorithms
|
||||
|
||||
* `Error`: provides errors that can specify an error code, the
|
||||
operation that raised the error, the kind of error, and any wrapped error
|
||||
|
||||
#### [oidc.callback](callback/)
|
||||
[![Go Reference](https://pkg.go.dev/badge/github.com/hashicorp/cap/oidc/callback.svg)](https://pkg.go.dev/github.com/hashicorp/cap/oidc/callback)
|
||||
|
||||
The callback package includes handlers (http.HandlerFunc) which can be used
|
||||
for the callback leg of an OIDC flow. Callback handlers for both the authorization
|
||||
code flow (with optional PKCE) and the implicit flow are provided.
|
||||
|
||||
<hr>
|
||||
|
||||
### Examples:
|
||||
|
||||
* [CLI example](examples/cli/) which implements an OIDC
|
||||
user authentication CLI.
|
||||
|
||||
* [SPA example](examples/spa) which implements an OIDC user
|
||||
authentication SPA (single page app).
|
||||
|
||||
<hr>
|
||||
|
||||
Example of a provider using an authorization code flow:
|
||||
|
||||
```go
|
||||
// Create a new provider config
|
||||
pc, err := oidc.NewConfig(
|
||||
"http://your-issuer.com/",
|
||||
"your_client_id",
|
||||
"your_client_secret",
|
||||
[]oidc.Alg{oidc.RS256},
|
||||
[]string{"http://your_redirect_url"},
|
||||
)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Create a provider
|
||||
p, err := oidc.NewProvider(pc)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
defer p.Done()
|
||||
|
||||
|
||||
// Create a Request for a user's authentication attempt that will use the
|
||||
// authorization code flow. (See NewRequest(...) using the WithPKCE and
|
||||
// WithImplicit options for creating a Request that uses those flows.)
|
||||
oidcRequest, err := oidc.NewRequest(2 * time.Minute, "http://your_redirect_url/callback")
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Create an auth URL
|
||||
authURL, err := p.AuthURL(context.Background(), oidcRequest)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
fmt.Println("open url to kick-off authentication: ", authURL)
|
||||
```
|
||||
|
||||
Create a http.Handler for OIDC authentication response redirects.
|
||||
|
||||
```go
|
||||
func NewHandler(ctx context.Context, p *oidc.Provider, rw callback.RequestReader) (http.HandlerFunc, error) {
|
||||
if p == nil {
|
||||
// handle error
|
||||
}
|
||||
if rw == nil {
|
||||
// handle error
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
oidcRequest, err := rw.Read(ctx, r.FormValue("state"))
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
// Exchange(...) will verify the tokens before returning.
|
||||
token, err := p.Exchange(ctx, oidcRequest, r.FormValue("state"), r.FormValue("code"))
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
var claims map[string]interface{}
|
||||
if err := token.IDToken().Claims(&claims); err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Get the user's claims via the provider's UserInfo endpoint
|
||||
var infoClaims map[string]interface{}
|
||||
err = p.UserInfo(ctx, token.StaticTokenSource(), claims["sub"].(string), &infoClaims)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
resp := struct {
|
||||
IDTokenClaims map[string]interface{}
|
||||
UserInfoClaims map[string]interface{}
|
||||
}{claims, infoClaims}
|
||||
enc := json.NewEncoder(w)
|
||||
if err := enc.Encode(resp); err != nil {
|
||||
// handle error
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
```
|
||||
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
package oidc
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// AccessToken is an oauth access_token.
|
||||
type AccessToken string
|
||||
|
||||
// RedactedAccessToken is the redacted string or json for an oauth access_token.
|
||||
const RedactedAccessToken = "[REDACTED: access_token]"
|
||||
|
||||
// String will redact the token.
|
||||
func (t AccessToken) String() string {
|
||||
return RedactedAccessToken
|
||||
}
|
||||
|
||||
// MarshalJSON will redact the token.
|
||||
func (t AccessToken) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(RedactedAccessToken)
|
||||
}
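A quick illustration of the redaction behavior: both fmt output and JSON marshaling produce the redacted constant rather than the raw token, while an explicit string conversion still yields the underlying value.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/hashicorp/cap/oidc"
)

func main() {
	tok := oidc.AccessToken("super-secret-access-token")

	fmt.Println(tok) // prints: [REDACTED: access_token]

	b, _ := json.Marshal(tok)
	fmt.Println(string(b)) // prints: "[REDACTED: access_token]"

	// The raw value is still available via an explicit conversion.
	fmt.Println(string(tok) == "super-secret-access-token") // true
}
```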
|
|
@ -0,0 +1,33 @@
|
|||
package oidc
|
||||
|
||||
// Alg represents asymmetric signing algorithms
|
||||
type Alg string
|
||||
|
||||
const (
|
||||
// JOSE asymmetric signing algorithm values as defined by RFC 7518.
|
||||
//
|
||||
// See: https://tools.ietf.org/html/rfc7518#section-3.1
|
||||
RS256 Alg = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
|
||||
RS384 Alg = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
|
||||
RS512 Alg = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
|
||||
ES256 Alg = "ES256" // ECDSA using P-256 and SHA-256
|
||||
ES384 Alg = "ES384" // ECDSA using P-384 and SHA-384
|
||||
ES512 Alg = "ES512" // ECDSA using P-521 and SHA-512
|
||||
PS256 Alg = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
|
||||
PS384 Alg = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
|
||||
PS512 Alg = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
|
||||
EdDSA Alg = "EdDSA"
|
||||
)
|
||||
|
||||
var supportedAlgorithms = map[Alg]bool{
|
||||
RS256: true,
|
||||
RS384: true,
|
||||
RS512: true,
|
||||
ES256: true,
|
||||
ES384: true,
|
||||
ES512: true,
|
||||
PS256: true,
|
||||
PS384: true,
|
||||
PS512: true,
|
||||
EdDSA: true,
|
||||
}
|
|
@ -0,0 +1,239 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/cap/oidc/internal/strutils"
|
||||
)
|
||||
|
||||
// ClientSecret is an oauth client Secret.
|
||||
type ClientSecret string
|
||||
|
||||
// RedactedClientSecret is the redacted string or json for an oauth client secret.
|
||||
const RedactedClientSecret = "[REDACTED: client secret]"
|
||||
|
||||
// String will redact the client secret.
|
||||
func (t ClientSecret) String() string {
|
||||
return RedactedClientSecret
|
||||
}
|
||||
|
||||
// MarshalJSON will redact the client secret.
|
||||
func (t ClientSecret) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(RedactedClientSecret)
|
||||
}
|
||||
|
||||
// Config represents the configuration for an OIDC provider used by a relying
|
||||
// party.
|
||||
type Config struct {
|
||||
// ClientID is the relying party ID.
|
||||
ClientID string
|
||||
|
||||
// ClientSecret is the relying party secret. This may be empty if you only
|
||||
// intend to use the provider with the authorization Code with PKCE or the
|
||||
// implicit flows.
|
||||
ClientSecret ClientSecret
|
||||
|
||||
// Scopes is a list of default oidc scopes to request of the provider. The
|
||||
// required "oidc" scope is requested by default, and does not need to be
|
||||
// part of this optional list. If a Request has scopes, they will override
|
||||
// this configured list for a specific authentication attempt.
|
||||
Scopes []string
|
||||
|
||||
// Issuer is a case-sensitive URL string using the https scheme that
|
||||
// contains scheme, host, and optionally, port number and path components
|
||||
// and no query or fragment components.
|
||||
// See the Issuer Identifier spec: https://openid.net/specs/openid-connect-core-1_0.html#IssuerIdentifier
|
||||
// See the OIDC connect discovery spec: https://openid.net/specs/openid-connect-discovery-1_0.html#IdentifierNormalization
|
||||
// See the id_token spec: https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
Issuer string
|
||||
|
||||
// SupportedSigningAlgs is a list of supported signing algorithms. List of
|
||||
// currently supported algs: RS256, RS384, RS512, ES256, ES384, ES512,
|
||||
// PS256, PS384, PS512
|
||||
//
|
||||
// The list can be used to limit the supported algorithms when verifying
|
||||
// id_token signatures, an id_token's at_hash claim against an
|
||||
// access_token, etc.
|
||||
SupportedSigningAlgs []Alg
|
||||
|
||||
// AllowedRedirectURLs is a list of allowed URLs for the provider to
|
||||
// redirect to after a user authenticates. If AllowedRedirects is empty,
|
||||
// the package will not check the Request.RedirectURL() to see if it's
|
||||
// allowed, and the check will be left to the OIDC provider's /authorize
|
||||
// endpoint.
|
||||
AllowedRedirectURLs []string
|
||||
|
||||
// Audiences is an optional default list of case-sensitive strings to use when
|
||||
// verifying an id_token's "aud" claim (which is also a list) If provided,
|
||||
// the audiences of an id_token must match one of the configured audiences.
|
||||
// If a Request has audiences, they will override this configured list for a
|
||||
// specific authentication attempt.
|
||||
Audiences []string
|
||||
|
||||
// ProviderCA is an optional set of CA certs (PEM encoded) to use when sending
|
||||
// requests to the provider. If you have a list of *x509.Certificates, then
|
||||
// see EncodeCertificates(...) to PEM encode them.
|
||||
ProviderCA string
|
||||
|
||||
// NowFunc is a time func that returns the current time.
|
||||
NowFunc func() time.Time
|
||||
}
|
||||
|
||||
// NewConfig composes a new config for a provider.
|
||||
//
|
||||
// The "oidc" scope will always be added to the new configuration's Scopes,
|
||||
// regardless of what additional scopes are requested via the WithScopes option
|
||||
// and duplicate scopes are allowed.
|
||||
//
|
||||
// Supported options: WithProviderCA, WithScopes, WithAudiences, WithNow
|
||||
func NewConfig(issuer string, clientID string, clientSecret ClientSecret, supported []Alg, allowedRedirectURLs []string, opt ...Option) (*Config, error) {
|
||||
const op = "NewConfig"
|
||||
opts := getConfigOpts(opt...)
|
||||
c := &Config{
|
||||
Issuer: issuer,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
SupportedSigningAlgs: supported,
|
||||
Scopes: opts.withScopes,
|
||||
ProviderCA: opts.withProviderCA,
|
||||
Audiences: opts.withAudiences,
|
||||
NowFunc: opts.withNowFunc,
|
||||
AllowedRedirectURLs: allowedRedirectURLs,
|
||||
}
|
||||
if err := c.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("%s: invalid provider config: %w", op, err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Validate the provider configuration. Among other validations, it verifies
|
||||
// the issuer is not empty, but it doesn't verify the Issuer is discoverable via
|
||||
// an http request. SupportedSigningAlgs are validated against the list of
|
||||
// currently supported algs: RS256, RS384, RS512, ES256, ES384, ES512, PS256,
|
||||
// PS384, PS512
|
||||
func (c *Config) Validate() error {
|
||||
const op = "Config.Validate"
|
||||
|
||||
// Note: c.ClientSecret is intentionally not checked for empty, in order to
|
||||
// support providers that only use the implicit flow or PKCE.
|
||||
if c == nil {
|
||||
return fmt.Errorf("%s: provider config is nil: %w", op, ErrNilParameter)
|
||||
}
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("%s: client ID is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if c.Issuer == "" {
|
||||
return fmt.Errorf("%s: discovery URL is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if len(c.AllowedRedirectURLs) > 0 {
|
||||
var invalidURLs []string
|
||||
for _, allowed := range c.AllowedRedirectURLs {
|
||||
if _, err := url.Parse(allowed); err != nil {
|
||||
invalidURLs = append(invalidURLs, allowed)
|
||||
}
|
||||
}
|
||||
if len(invalidURLs) > 0 {
|
||||
return fmt.Errorf("%s: Invalid AllowedRedirectURLs provided %s: %w", op, strings.Join(invalidURLs, ", "), ErrInvalidParameter)
|
||||
}
|
||||
}
|
||||
|
||||
u, err := url.Parse(c.Issuer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: issuer %s is invalid (%s): %w", op, c.Issuer, err, ErrInvalidIssuer)
|
||||
}
|
||||
if !strutils.StrListContains([]string{"https", "http"}, u.Scheme) {
|
||||
return fmt.Errorf("%s: issuer %s schema is not http or https: %w", op, c.Issuer, ErrInvalidIssuer)
|
||||
}
|
||||
if len(c.SupportedSigningAlgs) == 0 {
|
||||
return fmt.Errorf("%s: supported algorithms is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
for _, a := range c.SupportedSigningAlgs {
|
||||
if !supportedAlgorithms[a] {
|
||||
return fmt.Errorf("%s: unsupported algorithm %s: %w", op, a, ErrInvalidParameter)
|
||||
}
|
||||
}
|
||||
if c.ProviderCA != "" {
|
||||
certPool := x509.NewCertPool()
|
||||
if ok := certPool.AppendCertsFromPEM([]byte(c.ProviderCA)); !ok {
|
||||
return fmt.Errorf("%s: %w", op, ErrInvalidCACert)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Now will return the current time which can be overridden by the NowFunc
|
||||
func (c *Config) Now() time.Time {
|
||||
if c.NowFunc != nil {
|
||||
return c.NowFunc()
|
||||
}
|
||||
return time.Now() // fallback to this default
|
||||
}
|
||||
|
||||
// configOptions is the set of available options
|
||||
type configOptions struct {
|
||||
withScopes []string
|
||||
withAudiences []string
|
||||
withProviderCA string
|
||||
withNowFunc func() time.Time
|
||||
}
|
||||
|
||||
// configDefaults is a handy way to get the defaults at runtime and
|
||||
// during unit tests.
|
||||
func configDefaults() configOptions {
|
||||
return configOptions{
|
||||
withScopes: []string{oidc.ScopeOpenID},
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigOpts gets the defaults and applies the opt overrides passed
|
||||
// in.
|
||||
func getConfigOpts(opt ...Option) configOptions {
|
||||
opts := configDefaults()
|
||||
ApplyOpts(&opts, opt...)
|
||||
return opts
|
||||
}
|
||||
|
||||
// WithProviderCA provides optional CA certs (PEM encoded) for the provider's
|
||||
// config. These certs can be used when making http requests to the
|
||||
// provider.
|
||||
//
|
||||
// Valid for: Config
|
||||
//
|
||||
// See EncodeCertificates(...) to PEM encode a number of certs.
|
||||
func WithProviderCA(cert string) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*configOptions); ok {
|
||||
o.withProviderCA = cert
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeCertificates will encode a number of x509 certificates to PEM. It will
|
||||
// help encode certs for use with the WithProviderCA(...) option.
|
||||
func EncodeCertificates(certs ...*x509.Certificate) (string, error) {
|
||||
const op = "EncodeCert"
|
||||
var buffer bytes.Buffer
|
||||
if len(certs) == 0 {
|
||||
return "", fmt.Errorf("%s: no certs provided: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
for _, cert := range certs {
|
||||
if cert == nil {
|
||||
return "", fmt.Errorf("%s: empty cert: %w", op, ErrNilParameter)
|
||||
}
|
||||
if err := pem.Encode(&buffer, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("%s: unable to encode cert: %w", op, err)
|
||||
}
|
||||
}
|
||||
return buffer.String(), nil
|
||||
}
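A sketch of wiring a custom CA into a Config via EncodeCertificates and WithProviderCA; the certificate source, issuer, and redirect URL below are placeholders:

```go
package main

import (
	"crypto/x509"
	"log"

	"github.com/hashicorp/cap/oidc"
)

// buildConfig shows the intended pairing: PEM-encode the CA certs, then pass
// the PEM to WithProviderCA so the provider's HTTP client trusts them.
func buildConfig(caCert *x509.Certificate) (*oidc.Config, error) {
	caPEM, err := oidc.EncodeCertificates(caCert)
	if err != nil {
		return nil, err
	}
	return oidc.NewConfig(
		"https://your-issuer.example.com/", // placeholder issuer
		"your_client_id",
		"your_client_secret",
		[]oidc.Alg{oidc.RS256},
		[]string{"https://your-app.example.com/callback"}, // placeholder redirect
		oidc.WithProviderCA(caPEM),
	)
}

func main() {
	// A real caller would load or generate a CA certificate here; a nil cert
	// simply exercises EncodeCertificates' error path.
	var caCert *x509.Certificate
	if _, err := buildConfig(caCert); err != nil {
		log.Println("expected error for a nil cert:", err)
	}
}
```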
|
|
@ -0,0 +1,18 @@
|
|||
package oidc
|
||||
|
||||
// Display is a string value that specifies how the Authorization Server
|
||||
// displays the authentication and consent user interface pages to the End-User.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
type Display string
|
||||
|
||||
const (
|
||||
// Defines the Display values that specify how the Authorization Server
|
||||
// displays the authentication and consent user interface pages to the End-User.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
Page Display = "page"
|
||||
Popup Display = "popup"
|
||||
Touch Display = "touch"
|
||||
WAP Display = "wap"
|
||||
)
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
oidc is a package for writing clients that integrate with OIDC Providers using
|
||||
OIDC flows.
|
||||
|
||||
|
||||
Primary types provided by the package:
|
||||
|
||||
* Request: represents one OIDC authentication flow for a user. It contains the
|
||||
data needed to uniquely represent that one-time flow across the multiple
|
||||
interactions needed to complete the OIDC flow the user is attempting. All
|
||||
Requests contain an expiration for the user's OIDC flow. Optionally, Requests may
|
||||
contain overrides of configured provider defaults for audiences, scopes and a
|
||||
redirect URL.
|
||||
|
||||
* Token: represents an OIDC id_token, as well as an Oauth2 access_token and
|
||||
refresh_token (including the access_token expiry)
|
||||
|
||||
* Config: provides the configuration for an OIDC provider used by a relying
|
||||
party (for example: client ID/Secret, redirectURL, supported
|
||||
signing algorithms, additional scopes requested, etc)
|
||||
|
||||
* Provider: provides integration with a provider. The provider provides
|
||||
capabilities like: generating an auth URL, exchanging codes for tokens,
|
||||
verifying tokens, making user info requests, etc.
|
||||
|
||||
The oidc.callback package
|
||||
|
||||
The callback package includes handlers (http.HandlerFunc) which can be used
|
||||
for the callback leg of an OIDC flow. Callback handlers for both the authorization
|
||||
code flow (with optional PKCE) and the implicit flow are provided.
|
||||
|
||||
Example apps
|
||||
|
||||
Complete concise example solutions:
|
||||
|
||||
* OIDC authentication CLI:
|
||||
https://github.com/hashicorp/cap/tree/main/oidc/examples/cli/
|
||||
|
||||
* OIDC authentication SPA:
|
||||
https://github.com/hashicorp/cap/tree/main/oidc/examples/spa/
|
||||
|
||||
*/
|
||||
package oidc
|
|
@ -0,0 +1,40 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidParameter = errors.New("invalid parameter")
|
||||
ErrNilParameter = errors.New("nil parameter")
|
||||
ErrInvalidCACert = errors.New("invalid CA certificate")
|
||||
ErrInvalidIssuer = errors.New("invalid issuer")
|
||||
ErrExpiredRequest = errors.New("request is expired")
|
||||
ErrInvalidResponseState = errors.New("invalid response state")
|
||||
ErrInvalidSignature = errors.New("invalid signature")
|
||||
ErrInvalidSubject = errors.New("invalid subject")
|
||||
ErrInvalidAudience = errors.New("invalid audience")
|
||||
ErrInvalidNonce = errors.New("invalid nonce")
|
||||
ErrInvalidNotBefore = errors.New("invalid not before")
|
||||
ErrExpiredToken = errors.New("token is expired")
|
||||
ErrInvalidJWKs = errors.New("invalid jwks")
|
||||
ErrInvalidIssuedAt = errors.New("invalid issued at (iat)")
|
||||
ErrInvalidAuthorizedParty = errors.New("invalid authorized party (azp)")
|
||||
ErrInvalidAtHash = errors.New("access_token hash does not match value in id_token")
|
||||
ErrInvalidCodeHash = errors.New("authorization code hash does not match value in id_token")
|
||||
ErrTokenNotSigned = errors.New("token is not signed")
|
||||
ErrMalformedToken = errors.New("token malformed")
|
||||
ErrUnsupportedAlg = errors.New("unsupported signing algorithm")
|
||||
ErrIDGeneratorFailed = errors.New("id generation failed")
|
||||
ErrMissingIDToken = errors.New("id_token is missing")
|
||||
ErrMissingAccessToken = errors.New("access_token is missing")
|
||||
ErrIDTokenVerificationFailed = errors.New("id_token verification failed")
|
||||
ErrNotFound = errors.New("not found")
|
||||
ErrLoginFailed = errors.New("login failed")
|
||||
ErrUserInfoFailed = errors.New("user info failed")
|
||||
ErrUnauthorizedRedirectURI = errors.New("unauthorized redirect_uri")
|
||||
ErrInvalidFlow = errors.New("invalid OIDC flow")
|
||||
ErrUnsupportedChallengeMethod = errors.New("unsupported PKCE challenge method")
|
||||
ErrExpiredAuthTime = errors.New("expired auth_time")
|
||||
ErrMissingClaim = errors.New("missing required claim")
|
||||
)
|
|
@ -0,0 +1,71 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/cap/oidc/internal/base62"
|
||||
)
|
||||
|
||||
// DefaultIDLength is the default length for generated IDs, which are used for
|
||||
// state and nonce parameters during OIDC flows.
|
||||
//
|
||||
// For ID length requirements see:
|
||||
// https://tools.ietf.org/html/rfc6749#section-10.10
|
||||
const DefaultIDLength = 20
|
||||
|
||||
// NewID generates a ID with an optional prefix. The ID generated is suitable
|
||||
// for a Request's State or Nonce. The ID length will be DefaultIDLength, unless an
|
||||
// optional prefix is provided which will add the prefix's length + an
|
||||
// underscore. The WithPrefix, WithLen options are supported.
|
||||
//
|
||||
// For ID length requirements see:
|
||||
// https://tools.ietf.org/html/rfc6749#section-10.10
|
||||
func NewID(opt ...Option) (string, error) {
|
||||
const op = "NewID"
|
||||
opts := getIDOpts(opt...)
|
||||
id, err := base62.Random(opts.withLen)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%s: unable to generate id: %w", op, err)
|
||||
}
|
||||
switch {
|
||||
case opts.withPrefix != "":
|
||||
return fmt.Sprintf("%s_%s", opts.withPrefix, id), nil
|
||||
default:
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
|
||||
// idOptions is the set of available options.
|
||||
type idOptions struct {
|
||||
withPrefix string
|
||||
withLen int
|
||||
}
|
||||
|
||||
// idDefaults is a handy way to get the defaults at runtime and
|
||||
// during unit tests.
|
||||
func idDefaults() idOptions {
|
||||
return idOptions{
|
||||
withLen: DefaultIDLength,
|
||||
}
|
||||
}
|
||||
|
||||
// getIDOpts gets the defaults and applies the opt overrides passed
|
||||
// in.
|
||||
func getIDOpts(opt ...Option) idOptions {
|
||||
opts := idDefaults()
|
||||
ApplyOpts(&opts, opt...)
|
||||
return opts
|
||||
}
|
||||
|
||||
// WithPrefix provides an optional prefix for a new ID. When this option is
|
||||
// provided, NewID will prepend the prefix and an underscore to the new
|
||||
// identifier.
|
||||
//
|
||||
// Valid for: ID
|
||||
func WithPrefix(prefix string) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*idOptions); ok {
|
||||
o.withPrefix = prefix
|
||||
}
|
||||
}
|
||||
}
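For example, generating request state and nonce values; the "n" prefix is arbitrary:

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/cap/oidc"
)

func main() {
	// A bare ID of DefaultIDLength characters.
	state, err := oidc.NewID()
	if err != nil {
		log.Fatal(err)
	}

	// A prefixed ID, e.g. "n_<20 random base62 characters>".
	nonce, err := oidc.NewID(oidc.WithPrefix("n"))
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(state, nonce)
}
```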
|
|
@ -0,0 +1,145 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// IDToken is an oidc id_token.
|
||||
// See https://openid.net/specs/openid-connect-core-1_0.html#IDToken.
|
||||
type IDToken string
|
||||
|
||||
// RedactedIDToken is the redacted string or json for an oidc id_token.
|
||||
const RedactedIDToken = "[REDACTED: id_token]"
|
||||
|
||||
// String will redact the token.
|
||||
func (t IDToken) String() string {
|
||||
return RedactedIDToken
|
||||
}
|
||||
|
||||
// MarshalJSON will redact the token.
|
||||
func (t IDToken) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(RedactedIDToken)
|
||||
}
|
||||
|
||||
// Claims retrieves the IDToken claims.
|
||||
func (t IDToken) Claims(claims interface{}) error {
|
||||
const op = "IDToken.Claims"
|
||||
if len(t) == 0 {
|
||||
return fmt.Errorf("%s: id_token is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if claims == nil {
|
||||
return fmt.Errorf("%s: claims interface is nil: %w", op, ErrNilParameter)
|
||||
}
|
||||
return UnmarshalClaims(string(t), claims)
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies the at_hash claim of the id_token against the hash
|
||||
// of the access_token.
|
||||
//
|
||||
// It will return true when it can verify the access_token. It will return false
|
||||
// when it's unable to verify the access_token.
|
||||
//
|
||||
// It will return an error whenever it's possible to verify the access_token and
|
||||
// the verification fails.
|
||||
//
|
||||
// Note: while we support signing id_tokens with EdDSA, unfortunately the
|
||||
// access_token hash cannot be verified without knowing the key's curve. See:
|
||||
// https://bitbucket.org/openid/connect/issues/1125
|
||||
//
|
||||
// For more info about verifying access_tokens returned during an OIDC flow see:
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
|
||||
func (t IDToken) VerifyAccessToken(accessToken AccessToken) (bool, error) {
|
||||
const op = "VerifyAccessToken"
|
||||
canVerify, err := t.verifyHashClaim("at_hash", string(accessToken))
|
||||
if err != nil {
|
||||
return canVerify, fmt.Errorf("%s: %w", op, err)
|
||||
}
|
||||
return canVerify, nil
|
||||
}
|
||||
|
||||
// VerifyAuthorizationCode verifies the c_hash claim of the id_token against the
|
||||
// hash of the authorization code.
|
||||
//
|
||||
// It will return true when it can verify the authorization code. It will return
|
||||
// false when it's unable to verify the authorization code.
|
||||
//
|
||||
// It will return an error whenever it's possible to verify the authorization
|
||||
// code and the verification fails.
|
||||
//
|
||||
// Note: while we support signing id_tokens with EdDSA, unfortunately the
|
||||
// authorization code hash cannot be verified without knowing the key's curve.
|
||||
// See: https://bitbucket.org/openid/connect/issues/1125
|
||||
//
|
||||
// For more info about authorization code verification using the id_token's
|
||||
// c_hash claim see:
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken
|
||||
func (t IDToken) VerifyAuthorizationCode(code string) (bool, error) {
|
||||
const op = "VerifyAccessToken"
|
||||
canVerify, err := t.verifyHashClaim("c_hash", code)
|
||||
if err != nil {
|
||||
return canVerify, fmt.Errorf("%s: %w", op, err)
|
||||
}
|
||||
return canVerify, nil
|
||||
}
|
||||
|
||||
func (t IDToken) verifyHashClaim(claimName string, token string) (bool, error) {
|
||||
const op = "verifyHashClaim"
|
||||
var claims map[string]interface{}
|
||||
if err := t.Claims(&claims); err != nil {
|
||||
return false, fmt.Errorf("%s: %w", op, err)
|
||||
}
|
||||
tokenHash, ok := claims[claimName].(string)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
jws, err := jose.ParseSigned(string(t))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%s: malformed jwt (%v): %w", op, err, ErrMalformedToken)
|
||||
}
|
||||
switch len(jws.Signatures) {
|
||||
case 0:
|
||||
return false, fmt.Errorf("%s: id_token not signed: %w", op, ErrTokenNotSigned)
|
||||
case 1:
|
||||
default:
|
||||
return false, fmt.Errorf("%s: multiple signatures on id_token not supported", op)
|
||||
}
|
||||
sig := jws.Signatures[0]
|
||||
if _, ok := supportedAlgorithms[Alg(sig.Header.Algorithm)]; !ok {
|
||||
return false, fmt.Errorf("%s: id_token signed with algorithm %q: %w", op, sig.Header.Algorithm, ErrUnsupportedAlg)
|
||||
}
|
||||
sigAlgorithm := Alg(sig.Header.Algorithm)
|
||||
|
||||
var h hash.Hash
|
||||
switch sigAlgorithm {
|
||||
case RS256, ES256, PS256:
|
||||
h = sha256.New()
|
||||
case RS384, ES384, PS384:
|
||||
h = sha512.New384()
|
||||
case RS512, ES512, PS512:
|
||||
h = sha512.New()
|
||||
case EdDSA:
|
||||
return false, nil
|
||||
default:
|
||||
return false, fmt.Errorf("%s: unsupported signing algorithm %s: %w", op, sigAlgorithm, ErrUnsupportedAlg)
|
||||
}
|
||||
_, _ = h.Write([]byte(token)) // hash documents that Write will never return an error
|
||||
sum := h.Sum(nil)[:h.Size()/2]
|
||||
actual := base64.RawURLEncoding.EncodeToString(sum)
|
||||
if actual != tokenHash {
|
||||
switch claimName {
|
||||
case "at_hash":
|
||||
return false, fmt.Errorf("%s: %w", op, ErrInvalidAtHash)
|
||||
case "c_hash":
|
||||
return false, fmt.Errorf("%s: %w", op, ErrInvalidCodeHash)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
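Given the three possible outcomes documented above (verified, unverifiable, or failed), callers typically branch on both return values. A hedged sketch; the id_token/access_token pair below is a placeholder:

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/cap/oidc"
)

// checkAtHash demonstrates handling VerifyAccessToken's (bool, error) result:
// a non-nil error covers malformed tokens and a failed at_hash check, while
// (false, nil) means the claim couldn't be checked at all (e.g. no at_hash
// claim, or an EdDSA-signed id_token).
func checkAtHash(idToken oidc.IDToken, at oidc.AccessToken) {
	verified, err := idToken.VerifyAccessToken(at)
	switch {
	case err != nil:
		log.Printf("at_hash verification failed: %v", err)
	case !verified:
		fmt.Println("at_hash not present or not verifiable")
	default:
		fmt.Println("access_token matches the id_token's at_hash claim")
	}
}

func main() {
	// Placeholder values; a real id_token/access_token pair comes from a flow.
	checkAtHash(oidc.IDToken("eyJ..."), oidc.AccessToken("opaque-access-token"))
}
```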
|
|
@ -0,0 +1,50 @@
|
|||
// Package base62 provides utilities for working with base62 strings.
|
||||
// base62 strings will only contain characters: 0-9, a-z, A-Z
|
||||
package base62
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"io"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
const csLen = byte(len(charset))
|
||||
|
||||
// Random generates a random string using base-62 characters.
|
||||
// Resulting entropy is ~5.95 bits/character.
|
||||
func Random(length int) (string, error) {
|
||||
return RandomWithReader(length, rand.Reader)
|
||||
}
|
||||
|
||||
// RandomWithReader generates a random string using base-62 characters and a given reader.
|
||||
// Resulting entropy is ~5.95 bits/character.
|
||||
func RandomWithReader(length int, reader io.Reader) (string, error) {
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
output := make([]byte, 0, length)
|
||||
|
||||
// Request a bit more than length to reduce the chance
|
||||
// of needing more than one batch of random bytes
|
||||
batchSize := length + length/4
|
||||
|
||||
for {
|
||||
buf, err := uuid.GenerateRandomBytesWithReader(batchSize, reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, b := range buf {
|
||||
// Avoid bias by using a value range that's a multiple of 62
|
||||
if b < (csLen * 4) {
|
||||
output = append(output, charset[b%csLen])
|
||||
|
||||
if len(output) == length {
|
||||
return string(output), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
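Because this is an internal package, the most natural demonstration is an example-style test alongside it; a minimal sketch showing that Random returns a string of the requested length drawn only from the base62 character set:

```go
package base62

import (
	"fmt"
	"log"
	"strings"
)

// ExampleRandom verifies the length and character set of a generated string.
func ExampleRandom() {
	s, err := Random(16)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(s) == 16)
	// Trimming every charset character from both ends leaves nothing behind.
	fmt.Println(strings.Trim(s, charset) == "")
	// Output:
	// true
	// true
}
```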
|
|
@ -0,0 +1,35 @@
|
|||
package strutils
|
||||
|
||||
import "strings"
|
||||
|
||||
// StrListContains looks for a string in a list of strings.
|
||||
func StrListContains(haystack []string, needle string) bool {
|
||||
for _, item := range haystack {
|
||||
if item == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RemoveDuplicatesStable removes duplicate and empty elements from a slice of
|
||||
// strings, preserving order (and case) of the original slice.
|
||||
// In all cases, strings are compared after trimming whitespace.
|
||||
// If caseInsensitive, strings will be compared after ToLower()
|
||||
func RemoveDuplicatesStable(items []string, caseInsensitive bool) []string {
|
||||
itemsMap := make(map[string]bool, len(items))
|
||||
deduplicated := make([]string, 0, len(items))
|
||||
|
||||
for _, item := range items {
|
||||
key := strings.TrimSpace(item)
|
||||
if caseInsensitive {
|
||||
key = strings.ToLower(key)
|
||||
}
|
||||
if key == "" || itemsMap[key] {
|
||||
continue
|
||||
}
|
||||
itemsMap[key] = true
|
||||
deduplicated = append(deduplicated, item)
|
||||
}
|
||||
return deduplicated
|
||||
}
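A small in-package sketch of the behavior described above: the first occurrence wins and keeps its original form, while comparison happens on the trimmed, optionally lower-cased key.

```go
package strutils

import "fmt"

// ExampleRemoveDuplicatesStable illustrates case-insensitive de-duplication.
func ExampleRemoveDuplicatesStable() {
	in := []string{" Foo", "foo", "Bar", "", "bar "}
	fmt.Println(RemoveDuplicatesStable(in, true))
	// Output:
	// [ Foo Bar]
}
```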
|
|
@ -0,0 +1,85 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/cap/oidc/internal/strutils"
|
||||
)
|
||||
|
||||
// Option defines a common functional options type which can be used in a
|
||||
// variadic parameter pattern.
|
||||
type Option func(interface{})
|
||||
|
||||
// ApplyOpts takes a pointer to the options struct as a set of default options
|
||||
// and applies the slice of opts as overrides.
|
||||
func ApplyOpts(opts interface{}, opt ...Option) {
|
||||
for _, o := range opt {
|
||||
if o == nil { // ignore any nil Options
|
||||
continue
|
||||
}
|
||||
o(opts)
|
||||
}
|
||||
}
|
||||
|
||||
// WithNow provides an optional func for determining what the current time
|
||||
// is.
|
||||
//
|
||||
// Valid for: Config, Tk and Request
|
||||
func WithNow(now func() time.Time) Option {
|
||||
return func(o interface{}) {
|
||||
if now == nil {
|
||||
return
|
||||
}
|
||||
switch v := o.(type) {
|
||||
case *configOptions:
|
||||
v.withNowFunc = now
|
||||
case *tokenOptions:
|
||||
v.withNowFunc = now
|
||||
case *reqOptions:
|
||||
v.withNowFunc = now
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithScopes provides an optional list of scopes.
|
||||
//
|
||||
// Valid for: Config and Request
|
||||
func WithScopes(scopes ...string) Option {
|
||||
return func(o interface{}) {
|
||||
if len(scopes) == 0 {
|
||||
return
|
||||
}
|
||||
switch v := o.(type) {
|
||||
case *configOptions:
|
||||
// configOptions already has the oidc.ScopeOpenID in its defaults.
|
||||
scopes = strutils.RemoveDuplicatesStable(scopes, false)
|
||||
v.withScopes = append(v.withScopes, scopes...)
|
||||
case *reqOptions:
|
||||
// need to prepend the oidc.ScopeOpenID
|
||||
ts := append([]string{oidc.ScopeOpenID}, scopes...)
|
||||
scopes = strutils.RemoveDuplicatesStable(ts, false)
|
||||
v.withScopes = append(v.withScopes, scopes...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAudiences provides an optional list of audiences.
|
||||
//
|
||||
// Valid for: Config and Request
|
||||
func WithAudiences(auds ...string) Option {
|
||||
return func(o interface{}) {
|
||||
if len(auds) == 0 {
|
||||
return
|
||||
}
|
||||
auds := strutils.RemoveDuplicatesStable(auds, false)
|
||||
switch v := o.(type) {
|
||||
case *configOptions:
|
||||
v.withAudiences = append(v.withAudiences, auds...)
|
||||
case *reqOptions:
|
||||
v.withAudiences = append(v.withAudiences, auds...)
|
||||
case *userInfoOptions:
|
||||
v.withAudiences = append(v.withAudiences, auds...)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/cap/oidc/internal/base62"
|
||||
)
|
||||
|
||||
// ChallengeMethod represents PKCE code challenge methods as defined by RFC
|
||||
// 7636.
|
||||
type ChallengeMethod string
|
||||
|
||||
const (
|
||||
// PKCE code challenge methods as defined by RFC 7636.
|
||||
//
|
||||
// See: https://tools.ietf.org/html/rfc7636#page-9
|
||||
S256 ChallengeMethod = "S256" // SHA-256
|
||||
)
|
||||
|
||||
// CodeVerifier represents an OAuth PKCE code verifier.
|
||||
//
|
||||
// See: https://tools.ietf.org/html/rfc7636#section-4.1
|
||||
type CodeVerifier interface {
|
||||
|
||||
// Verifier returns the code verifier (see:
|
||||
// https://tools.ietf.org/html/rfc7636#section-4.1)
|
||||
Verifier() string
|
||||
|
||||
// Challenge returns the code verifier's code challenge (see:
|
||||
// https://tools.ietf.org/html/rfc7636#section-4.2)
|
||||
Challenge() string
|
||||
|
||||
// Method returns the code verifier's challenge method (see
|
||||
// https://tools.ietf.org/html/rfc7636#section-4.2)
|
||||
Method() ChallengeMethod
|
||||
|
||||
// Copy returns a copy of the verifier
|
||||
Copy() CodeVerifier
|
||||
}
|
||||
|
||||
// S256Verifier represents an OAuth PKCE code verifier that uses the S256
|
||||
// challenge method. It implements the CodeVerifier interface.
|
||||
type S256Verifier struct {
|
||||
verifier string
|
||||
challenge string
|
||||
method ChallengeMethod
|
||||
}
|
||||
|
||||
// min len of 43 chars per https://tools.ietf.org/html/rfc7636#section-4.1
|
||||
const verifierLen = 43
|
||||
|
||||
// NewCodeVerifier creates a new CodeVerifier (*S256Verifier).
|
||||
//
|
||||
// See: https://tools.ietf.org/html/rfc7636#section-4.1
|
||||
func NewCodeVerifier() (*S256Verifier, error) {
|
||||
const op = "NewCodeVerifier"
|
||||
data, err := base62.Random(verifierLen)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to create verifier data %w", op, err)
|
||||
}
|
||||
v := &S256Verifier{
|
||||
verifier: data, // no need to encode it, since base62.Random uses a limited set of characters.
|
||||
method: S256,
|
||||
}
|
||||
if v.challenge, err = CreateCodeChallenge(v); err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to create code challenge: %w", op, err)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (v *S256Verifier) Verifier() string { return v.verifier } // Verifier implements the CodeVerifier.Verifier() interface function.
|
||||
func (v *S256Verifier) Challenge() string { return v.challenge } // Challenge implements the CodeVerifier.Challenge() interface function.
|
||||
func (v *S256Verifier) Method() ChallengeMethod { return v.method } // Method implements the CodeVerifier.Method() interface function.
|
||||
|
||||
// Copy returns a copy of the verifier.
|
||||
func (v *S256Verifier) Copy() CodeVerifier {
|
||||
return &S256Verifier{
|
||||
verifier: v.verifier,
|
||||
challenge: v.challenge,
|
||||
method: v.method,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCodeChallenge creates a code challenge from the verifier. Supported
|
||||
// ChallengeMethods: S256
|
||||
//
|
||||
// See: https://tools.ietf.org/html/rfc7636#section-4.2
|
||||
func CreateCodeChallenge(v CodeVerifier) (string, error) {
|
||||
// currently, we only support S256
|
||||
if v.Method() != S256 {
|
||||
return "", fmt.Errorf("CreateCodeChallenge: %s is invalid: %w", v.Method(), ErrUnsupportedChallengeMethod)
|
||||
}
|
||||
h := sha256.New()
|
||||
_, _ = h.Write([]byte(v.Verifier())) // hash documents that Write will never return an Error
|
||||
sum := h.Sum(nil)
|
||||
return base64.RawURLEncoding.EncodeToString(sum), nil
|
||||
}
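A short sketch of generating a PKCE verifier and inspecting its challenge; how the verifier is attached to a Request is flow-specific and not shown here:

```go
package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/cap/oidc"
)

func main() {
	v, err := oidc.NewCodeVerifier()
	if err != nil {
		log.Fatal(err)
	}

	// The challenge and method are what get sent on the authorization URL;
	// the verifier itself is only sent later, during the code exchange.
	fmt.Println("method:   ", v.Method())    // S256
	fmt.Println("challenge:", v.Challenge()) // base64url(SHA-256(verifier))
	fmt.Println("verifier length:", len(v.Verifier()))
}
```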
|
|
@ -0,0 +1,18 @@
|
|||
package oidc
|
||||
|
||||
// Prompt is a string value that specifies whether the Authorization Server
|
||||
// prompts the End-User for reauthentication and consent.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
type Prompt string
|
||||
|
||||
const (
|
||||
// Defines the Prompt values that specify whether the Authorization Server
|
||||
// prompts the End-User for reauthentication and consent.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
None Prompt = "none"
|
||||
Login Prompt = "login"
|
||||
Consent Prompt = "consent"
|
||||
SelectAccount Prompt = "select_account"
|
||||
)
|
|
@ -0,0 +1,655 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/cap/oidc/internal/strutils"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Provider provides integration with an OIDC provider.
|
||||
// Its primary capabilities include:
|
||||
// * Kicking off a user authentication via either the authorization code flow
|
||||
// (with optional PKCE) or implicit flow via the URL from p.AuthURL(...)
|
||||
//
|
||||
// * The authorization code flow (with optional PKCE) by exchanging an auth
|
||||
// code for tokens in p.Exchange(...)
|
||||
//
|
||||
// * Verifying an id_token issued by a provider with p.VerifyIDToken(...)
|
||||
//
|
||||
// * Retrieving a user's OAuth claims with p.UserInfo(...)
|
||||
type Provider struct {
|
||||
config *Config
|
||||
provider *oidc.Provider
|
||||
|
||||
// client uses a pooled transport that uses the config's ProviderCA if
|
||||
// provided, otherwise it will use the installed system CA chain. This
|
||||
// client's idle connections are closed in Provider.Done()
|
||||
client *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
// backgroundCtx is the context used by the provider for background
|
||||
// activities like: refreshing JWKs Key sets, refreshing tokens, etc
|
||||
backgroundCtx context.Context
|
||||
|
||||
// backgroundCtxCancel is used to cancel any background activities running
|
||||
// in spawned go routines.
|
||||
backgroundCtxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewProvider creates and initializes a Provider. Initializing the provider
|
||||
// includes making an http request to the provider's issuer.
|
||||
//
|
||||
// See Provider.Done() which must be called to release provider resources.
|
||||
func NewProvider(c *Config) (*Provider, error) {
|
||||
const op = "NewProvider"
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("%s: provider config is nil: %w", op, ErrNilParameter)
|
||||
}
|
||||
if err := c.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("%s: provider config is invalid: %w", op, err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// initializing the Provider with its background ctx/cancel will
|
||||
// allow us to use p.Stop() to release any resources when returning errors
|
||||
// from this function.
|
||||
p := &Provider{
|
||||
config: c,
|
||||
backgroundCtx: ctx,
|
||||
backgroundCtxCancel: cancel,
|
||||
}
|
||||
|
||||
oidcCtx, err := p.HTTPClientContext(p.backgroundCtx)
|
||||
if err != nil {
|
||||
p.Done() // release the backgroundCtxCancel resources
|
||||
return nil, fmt.Errorf("%s: unable to create http client: %w", op, err)
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(oidcCtx, c.Issuer) // makes http req to issuer for discovery
|
||||
if err != nil {
|
||||
p.Done() // release the backgroundCtxCancel resources
|
||||
// we don't know what's causing the problem, so we won't classify the
|
||||
// error with a Kind
|
||||
return nil, fmt.Errorf("%s: unable to create provider: %w", op, err)
|
||||
}
|
||||
p.provider = provider
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Done releases the provider's background resources and must be called for every
|
||||
// Provider created
|
||||
func (p *Provider) Done() {
|
||||
// checking for nil here prevents a panic when developers neglect to check
|
||||
// for an error before deferring a call to p.Done():
|
||||
// p, err := NewProvider(...)
|
||||
// defer p.Done()
|
||||
// if err != nil { ... }
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.backgroundCtxCancel != nil {
|
||||
p.backgroundCtxCancel()
|
||||
p.backgroundCtxCancel = nil
|
||||
}
|
||||
|
||||
// release the http.Client's pooled transport resources.
|
||||
if p.client != nil {
|
||||
p.client.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthURL will generate a URL the caller can use to kick off an OIDC
|
||||
// authorization code (with optional PKCE) or an implicit flow with an IdP.
|
||||
//
|
||||
// See NewRequest() to create an oidc flow Request with a valid state and Nonce that
|
||||
// will uniquely identify the user's authentication attempt throughout the flow.
|
||||
func (p *Provider) AuthURL(ctx context.Context, oidcRequest Request) (url string, e error) {
|
||||
const op = "Provider.AuthURL"
|
||||
if oidcRequest.State() == "" {
|
||||
return "", fmt.Errorf("%s: request id is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if oidcRequest.Nonce() == "" {
|
||||
return "", fmt.Errorf("%s: request nonce is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if oidcRequest.State() == oidcRequest.Nonce() {
|
||||
return "", fmt.Errorf("%s: request id and nonce cannot be equal: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
withImplicit, withImplicitAccessToken := oidcRequest.ImplicitFlow()
|
||||
if oidcRequest.PKCEVerifier() != nil && withImplicit {
|
||||
return "", fmt.Errorf("%s: request requests both implicit flow and authorization code with PKCE: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if oidcRequest.RedirectURL() == "" {
|
||||
return "", fmt.Errorf("%s: request redirect URL is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if err := p.validRedirect(oidcRequest.RedirectURL()); err != nil {
|
||||
return "", fmt.Errorf("%s: %w", op, err)
|
||||
}
|
||||
var scopes []string
|
||||
switch {
|
||||
case len(oidcRequest.Scopes()) > 0:
|
||||
scopes = oidcRequest.Scopes()
|
||||
default:
|
||||
scopes = p.config.Scopes
|
||||
}
|
||||
// Add the "openid" scope, which is a required scope for oidc flows
|
||||
if !strutils.StrListContains(scopes, oidc.ScopeOpenID) {
|
||||
scopes = append([]string{oidc.ScopeOpenID}, scopes...)
|
||||
}
|
||||
|
||||
// Configure an OpenID Connect aware OAuth2 client
|
||||
oauth2Config := oauth2.Config{
|
||||
ClientID: p.config.ClientID,
|
||||
ClientSecret: string(p.config.ClientSecret),
|
||||
RedirectURL: oidcRequest.RedirectURL(),
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
Scopes: scopes,
|
||||
}
|
||||
authCodeOpts := []oauth2.AuthCodeOption{
|
||||
oidc.Nonce(oidcRequest.Nonce()),
|
||||
}
|
||||
if withImplicit {
|
||||
reqTokens := []string{"id_token"}
|
||||
if withImplicitAccessToken {
|
||||
reqTokens = append(reqTokens, "token")
|
||||
}
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("response_mode", "form_post"), oauth2.SetAuthURLParam("response_type", strings.Join(reqTokens, " ")))
|
||||
}
|
||||
if oidcRequest.PKCEVerifier() != nil {
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("code_challenge", oidcRequest.PKCEVerifier().Challenge()), oauth2.SetAuthURLParam("code_challenge_method", string(oidcRequest.PKCEVerifier().Method())))
|
||||
}
|
||||
if secs, exp := oidcRequest.MaxAge(); !exp.IsZero() {
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("max_age", strconv.Itoa(int(secs))))
|
||||
}
|
||||
if len(oidcRequest.Prompts()) > 0 {
|
||||
prompts := make([]string, 0, len(oidcRequest.Prompts()))
|
||||
for _, v := range oidcRequest.Prompts() {
|
||||
prompts = append(prompts, string(v))
|
||||
}
|
||||
prompts = strutils.RemoveDuplicatesStable(prompts, false)
|
||||
if strutils.StrListContains(prompts, string(None)) && len(prompts) > 1 {
|
||||
return "", fmt.Errorf(`%s: prompts (%s) includes "none" with other values: %w`, op, prompts, ErrInvalidParameter)
|
||||
}
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("prompt", strings.Join(prompts, " ")))
|
||||
}
|
||||
if oidcRequest.Display() != "" {
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("display", string(oidcRequest.Display())))
|
||||
}
|
||||
if len(oidcRequest.UILocales()) > 0 {
|
||||
locales := make([]string, 0, len(oidcRequest.UILocales()))
|
||||
for _, l := range oidcRequest.UILocales() {
|
||||
locales = append(locales, l.String())
|
||||
}
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("ui_locales", strings.Join(locales, " ")))
|
||||
}
|
||||
if len(oidcRequest.Claims()) > 0 {
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("claims", string(oidcRequest.Claims())))
|
||||
}
|
||||
if len(oidcRequest.ACRValues()) > 0 {
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("acr_values", strings.Join(oidcRequest.ACRValues(), " ")))
|
||||
}
|
||||
return oauth2Config.AuthCodeURL(oidcRequest.State(), authCodeOpts...), nil
|
||||
}
|
||||
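// exampleAuthURL is a hedged, illustrative sketch (not part of this change)
// showing how a caller might pair NewRequest with AuthURL. The redirect URL
// and expiry below are assumptions chosen for the example, and the redirect
// URL is assumed to be among the provider's AllowedRedirectURLs.
func exampleAuthURL(ctx context.Context, p *Provider) (string, error) {
	// Create a one-time Request carrying a unique state and nonce.
	oidcRequest, err := NewRequest(2*time.Minute, "https://example.com/callback")
	if err != nil {
		return "", err
	}
	// Send the end user to the returned URL to authenticate with the IdP.
	return p.AuthURL(ctx, oidcRequest)
}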
|
||||
// Exchange will request a token from the oidc token endpoint, using the
|
||||
// authorizationCode and authorizationState it received in an earlier successful
|
||||
// oidc authentication response.
|
||||
//
|
||||
// Exchange will use PKCE when the user's oidc Request specifies its use.
|
||||
//
|
||||
// It will also validate the authorizationState it receives against the
|
||||
// existing Request for the user's oidc authentication flow.
|
||||
//
|
||||
// On success, the Token returned will include an IDToken and may
|
||||
// include an AccessToken and RefreshToken.
|
||||
//
|
||||
// Any tokens returned will have been verified.
|
||||
// See: Provider.VerifyIDToken for info about id_token verification.
|
||||
//
|
||||
// When present, the id_token at_hash claim is verified against the
|
||||
// access_token. (see:
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation)
|
||||
//
|
||||
// The id_token c_hash claim is verified when present.
|
||||
func (p *Provider) Exchange(ctx context.Context, oidcRequest Request, authorizationState string, authorizationCode string) (*Tk, error) {
|
||||
const op = "Provider.Exchange"
|
||||
if p.config == nil {
|
||||
return nil, fmt.Errorf("%s: provider config is nil: %w", op, ErrNilParameter)
|
||||
}
|
||||
if oidcRequest == nil {
|
||||
return nil, fmt.Errorf("%s: request is nil: %w", op, ErrNilParameter)
|
||||
}
|
||||
if withImplicit, _ := oidcRequest.ImplicitFlow(); withImplicit {
|
||||
return nil, fmt.Errorf("%s: request (%s) should not be using the implicit flow: %w", op, oidcRequest.State(), ErrInvalidFlow)
|
||||
}
|
||||
if oidcRequest.State() != authorizationState {
|
||||
return nil, fmt.Errorf("%s: authentication request state and authorization state are not equal: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if oidcRequest.RedirectURL() == "" {
|
||||
return nil, fmt.Errorf("%s: authentication request redirect URL is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if err := p.validRedirect(oidcRequest.RedirectURL()); err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", op, err)
|
||||
}
|
||||
if oidcRequest.IsExpired() {
|
||||
return nil, fmt.Errorf("%s: authentication request is expired: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
|
||||
oidcCtx, err := p.HTTPClientContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to create http client: %w", op, err)
|
||||
}
|
||||
var scopes []string
|
||||
switch {
|
||||
case len(oidcRequest.Scopes()) > 0:
|
||||
scopes = oidcRequest.Scopes()
|
||||
default:
|
||||
scopes = p.config.Scopes
|
||||
}
|
||||
// Add the "openid" scope, which is a required scope for oidc flows
|
||||
scopes = append([]string{oidc.ScopeOpenID}, scopes...)
|
||||
var oauth2Config = oauth2.Config{
|
||||
ClientID: p.config.ClientID,
|
||||
ClientSecret: string(p.config.ClientSecret),
|
||||
RedirectURL: oidcRequest.RedirectURL(),
|
||||
Endpoint: p.provider.Endpoint(),
|
||||
Scopes: scopes,
|
||||
}
|
||||
var authCodeOpts []oauth2.AuthCodeOption
|
||||
if oidcRequest.PKCEVerifier() != nil {
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("code_verifier", oidcRequest.PKCEVerifier().Verifier()))
|
||||
}
|
||||
oauth2Token, err := oauth2Config.Exchange(oidcCtx, authorizationCode, authCodeOpts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to exchange auth code with provider: %w", op, p.convertError(err))
|
||||
}
|
||||
|
||||
idToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s: id_token is missing from auth code exchange: %w", op, ErrMissingIDToken)
|
||||
}
|
||||
t, err := NewToken(IDToken(idToken), oauth2Token, WithNow(p.config.NowFunc))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to create new id_token: %w", op, err)
|
||||
}
|
||||
claims, err := p.VerifyIDToken(ctx, t.IDToken(), oidcRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: id_token failed verification: %w", op, err)
|
||||
}
|
||||
if t.AccessToken() != "" {
|
||||
if _, err := t.IDToken().VerifyAccessToken(t.AccessToken()); err != nil {
|
||||
return nil, fmt.Errorf("%s: access_token failed verification: %w", op, err)
|
||||
}
|
||||
}
|
||||
|
||||
// when the optional c_hash claim is present, it needs to be verified.
|
||||
c_hash, ok := claims["c_hash"].(string)
|
||||
if ok && c_hash != "" {
|
||||
_, err := t.IDToken().VerifyAuthorizationCode(authorizationCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: code hash failed verification: %w", op, err)
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
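// exampleExchange is an illustrative sketch (not part of this change) of the
// callback side of the authorization code flow: once the IdP posts back the
// state and code, they are exchanged for verified tokens.
func exampleExchange(ctx context.Context, p *Provider, oidcRequest Request, state, code string) (*Tk, error) {
	t, err := p.Exchange(ctx, oidcRequest, state, code)
	if err != nil {
		return nil, err
	}
	// The id_token (and the at_hash/c_hash claims, when present) have already
	// been verified by Exchange.
	return t, nil
}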
|
||||
// UserInfo gets the UserInfo claims from the provider using the token produced
|
||||
// by the tokenSource. Only JSON user info responses are supported (signed JWT
|
||||
// responses are not). The WithAudiences option is supported to specify
|
||||
// optional audiences to verify when the aud claim is present in the response.
|
||||
//
|
||||
// It verifies:
|
||||
// * subject (sub) is required and must match
|
||||
// * issuer (iss) - if the iss claim is included in returned claims
|
||||
// * audiences (aud) - if the aud claim is included in returned claims and
|
||||
// WithAudiences option is provided.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
|
||||
func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource, validSubject string, claims interface{}, opt ...Option) error {
|
||||
const op = "Provider.UserInfo"
|
||||
opts := getUserInfoOpts(opt...)
|
||||
|
||||
if tokenSource == nil {
|
||||
return fmt.Errorf("%s: token source is nil: %w", op, ErrNilParameter)
|
||||
}
|
||||
if claims == nil {
|
||||
return fmt.Errorf("%s: claims interface is nil: %w", op, ErrNilParameter)
|
||||
}
|
||||
if reflect.ValueOf(claims).Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("%s: interface parameter must to be a pointer: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
oidcCtx, err := p.HTTPClientContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: unable to create http client: %w", op, err)
|
||||
}
|
||||
|
||||
userinfo, err := p.provider.UserInfo(oidcCtx, tokenSource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: provider UserInfo request failed: %w", op, p.convertError(err))
|
||||
}
|
||||
type verifyClaims struct {
|
||||
Sub string
|
||||
Iss string
|
||||
Aud []string
|
||||
}
|
||||
var vc verifyClaims
|
||||
err = userinfo.Claims(&vc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: failed to parse claims for UserInfo verification: %w", op, err)
|
||||
}
|
||||
// Subject is required to match
|
||||
if vc.Sub != validSubject {
|
||||
return fmt.Errorf("%s: %w", op, ErrInvalidSubject)
|
||||
}
|
||||
// optional issuer check...
|
||||
if vc.Iss != "" && vc.Iss != p.config.Issuer {
|
||||
return fmt.Errorf("%s: %w", op, ErrInvalidIssuer)
|
||||
}
|
||||
// optional audiences check...
|
||||
if len(opts.withAudiences) > 0 {
|
||||
if err := p.verifyAudience(opts.withAudiences, vc.Aud); err != nil {
|
||||
return fmt.Errorf("%s: %w", op, ErrInvalidAudience)
|
||||
}
|
||||
}
|
||||
|
||||
err = userinfo.Claims(&claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: failed to get UserInfo claims: %w", op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
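// exampleUserInfo is an illustrative sketch (not part of this change) showing
// one way to call UserInfo with a static token source built from the access
// token returned by Exchange; the subject is assumed to come from the verified
// id_token's sub claim.
func exampleUserInfo(ctx context.Context, p *Provider, t *Tk, subject string) (map[string]interface{}, error) {
	var claims map[string]interface{}
	ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: string(t.AccessToken())})
	if err := p.UserInfo(ctx, ts, subject, &claims); err != nil {
		return nil, err
	}
	return claims, nil
}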
|
||||
// userInfoOptions is the set of available options for the Provider.UserInfo
|
||||
// function
|
||||
type userInfoOptions struct {
|
||||
withAudiences []string
|
||||
}
|
||||
|
||||
// userInfoDefaults is a handy way to get the defaults at runtime and during unit
|
||||
// tests.
|
||||
func userInfoDefaults() userInfoOptions {
|
||||
return userInfoOptions{}
|
||||
}
|
||||
|
||||
// getUserInfoOpts gets the provider.UserInfo defaults and applies the opt
|
||||
// overrides passed in
|
||||
func getUserInfoOpts(opt ...Option) userInfoOptions {
|
||||
opts := userInfoDefaults()
|
||||
ApplyOpts(&opts, opt...)
|
||||
return opts
|
||||
}
|
||||
|
||||
// VerifyIDToken will verify the inbound IDToken and return its claims.
|
||||
// It verifies:
|
||||
// * signature (including if a supported signing algorithm was used)
|
||||
// * issuer (iss)
|
||||
// * expiration (exp)
|
||||
// * issued at (iat) (with a leeway of 1 min)
|
||||
// * not before (nbf) (with a leeway of 1 min)
|
||||
// * nonce (nonce)
|
||||
// * audience (aud) contains all audiences required from the provider's config
|
||||
// * when there are multiple audiences (aud), then one of them must equal
|
||||
// the client_id
|
||||
// * when present, the authorized party (azp) must equal the client id
|
||||
// * when there are multiple audiences (aud), then the authorized party (azp)
|
||||
// must equal the client id
|
||||
// * when there is a single audience (aud) and it is not equal to the client
|
||||
// id, then the authorized party (azp) must equal the client id
|
||||
// * when max_age was requested, the auth_time claim is verified (with a leeway
|
||||
// of 1 min)
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (p *Provider) VerifyIDToken(ctx context.Context, t IDToken, oidcRequest Request, opt ...Option) (map[string]interface{}, error) {
|
||||
const op = "Provider.VerifyIDToken"
|
||||
if t == "" {
|
||||
return nil, fmt.Errorf("%s: id_token is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if oidcRequest.Nonce() == "" {
|
||||
return nil, fmt.Errorf("%s: nonce is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
algs := []string{}
|
||||
for _, a := range p.config.SupportedSigningAlgs {
|
||||
algs = append(algs, string(a))
|
||||
}
|
||||
oidcConfig := &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
SupportedSigningAlgs: algs,
|
||||
Now: p.config.Now,
|
||||
}
|
||||
verifier := p.provider.Verifier(oidcConfig)
|
||||
nowTime := p.config.Now() // initialized right after the Verifier so their notions of "now" roughly correspond.
|
||||
leeway := 1 * time.Minute
|
||||
|
||||
// verifier.Verify will check the supported algs, signature, iss, exp, nbf.
|
||||
// aud will be checked later in this function.
|
||||
oidcIDToken, err := verifier.Verify(ctx, string(t))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: invalid id_token: %w", op, p.convertError(err))
|
||||
}
|
||||
// We still need to check: nonce, iat, auth_time, azp, and that the aud includes
|
||||
// any additional audiences configured.
|
||||
if oidcIDToken.Nonce != oidcRequest.Nonce() {
|
||||
return nil, fmt.Errorf("%s: invalid id_token nonce: %w", op, ErrInvalidNonce)
|
||||
}
|
||||
if nowTime.Add(leeway).Before(oidcIDToken.IssuedAt) {
|
||||
return nil, fmt.Errorf(
|
||||
"%s: invalid id_token current time %v before the iat (issued at) time %v: %w",
|
||||
op,
|
||||
nowTime,
|
||||
oidcIDToken.IssuedAt,
|
||||
ErrInvalidIssuedAt,
|
||||
)
|
||||
}
|
||||
|
||||
var audiences []string
|
||||
switch {
|
||||
case len(oidcRequest.Audiences()) > 0:
|
||||
audiences = oidcRequest.Audiences()
|
||||
default:
|
||||
audiences = p.config.Audiences
|
||||
}
|
||||
if err := p.verifyAudience(audiences, oidcIDToken.Audience); err != nil {
|
||||
return nil, fmt.Errorf("%s: invalid id_token audiences: %w", op, err)
|
||||
}
|
||||
if len(oidcIDToken.Audience) > 1 && !strutils.StrListContains(oidcIDToken.Audience, p.config.ClientID) {
|
||||
return nil, fmt.Errorf("%s: invalid id_token: multiple audiences (%s) and one of them is not equal client_id (%s): %w", op, oidcIDToken.Audience, p.config.ClientID, ErrInvalidAudience)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := t.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", op, err)
|
||||
}
|
||||
|
||||
azp, foundAzp := claims["azp"]
|
||||
if foundAzp {
|
||||
if azp != p.config.ClientID {
|
||||
return nil, fmt.Errorf("%s: invalid id_token: authorized party (%s) is not equal client_id (%s): %w", op, azp, p.config.ClientID, ErrInvalidAuthorizedParty)
|
||||
}
|
||||
}
|
||||
if len(oidcIDToken.Audience) > 1 && azp != p.config.ClientID {
|
||||
return nil, fmt.Errorf("%s: invalid id_token: multiple audiences and authorized party (%s) is not equal client_id (%s): %w", op, azp, p.config.ClientID, ErrInvalidAuthorizedParty)
|
||||
}
|
||||
if (len(oidcIDToken.Audience) == 1 && oidcIDToken.Audience[0] != p.config.ClientID) && azp != p.config.ClientID {
|
||||
return nil, fmt.Errorf(
|
||||
"%s: invalid id_token: one audience (%s) which is not the client_id (%s) and authorized party (%s) is not equal client_id (%s): %w",
|
||||
op,
|
||||
oidcIDToken.Audience[0],
|
||||
p.config.ClientID,
|
||||
azp,
|
||||
p.config.ClientID,
|
||||
ErrInvalidAuthorizedParty)
|
||||
}
|
||||
|
||||
if secs, authAfter := oidcRequest.MaxAge(); !authAfter.IsZero() {
|
||||
atClaim, ok := claims["auth_time"].(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s: missing auth_time claim when max age was requested: %w", op, ErrMissingClaim)
|
||||
}
|
||||
authTime := time.Unix(int64(atClaim), 0)
|
||||
if !authTime.Add(leeway).After(authAfter) {
|
||||
return nil, fmt.Errorf("%s: auth_time (%s) is beyond max age (%d): %w", op, authTime, secs, ErrExpiredAuthTime)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
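// exampleVerifyImplicitIDToken is an illustrative sketch (not part of this
// change): when using the implicit flow's form_post response, the raw id_token
// comes back directly and can be verified against the original Request.
func exampleVerifyImplicitIDToken(ctx context.Context, p *Provider, oidcRequest Request, rawIDToken string) (map[string]interface{}, error) {
	return p.VerifyIDToken(ctx, IDToken(rawIDToken), oidcRequest)
}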
|
||||
// verifyAudience simply verifies the aud claim against the allowed
|
||||
// audiences.
|
||||
func (p *Provider) verifyAudience(allowedAudiences, audienceClaim []string) error {
|
||||
const op = "verifyAudiences"
|
||||
if len(allowedAudiences) > 0 {
|
||||
found := false
|
||||
for _, v := range allowedAudiences {
|
||||
if strutils.StrListContains(audienceClaim, v) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("%s: invalid id_token audiences: %w", op, ErrInvalidAudience)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertError is used to convert errors from the core-os and oauth2 library
|
||||
// calls of: provider.Exchange, verifier.Verify and provider.UserInfo
|
||||
func (p *Provider) convertError(e error) error {
|
||||
switch {
|
||||
case strings.Contains(e.Error(), "id token issued by a different provider"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrInvalidIssuer)
|
||||
case strings.Contains(e.Error(), "signed with unsupported algorithm"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrUnsupportedAlg)
|
||||
case strings.Contains(e.Error(), "before the nbf (not before) time"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrInvalidNotBefore)
|
||||
case strings.Contains(e.Error(), "before the iat (issued at) time"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrInvalidIssuedAt)
|
||||
case strings.Contains(e.Error(), "token is expired"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrExpiredToken)
|
||||
case strings.Contains(e.Error(), "failed to verify id token signature"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrInvalidSignature)
|
||||
case strings.Contains(e.Error(), "failed to decode keys"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrInvalidJWKs)
|
||||
case strings.Contains(e.Error(), "get keys failed"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrInvalidJWKs)
|
||||
case strings.Contains(e.Error(), "server response missing access_token"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrMissingAccessToken)
|
||||
case strings.Contains(e.Error(), "404 Not Found"):
|
||||
return fmt.Errorf("%s: %w", e.Error(), ErrNotFound)
|
||||
default:
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPClient returns an http.Client for the provider. The returned client uses
|
||||
// a pooled transport (so it can reuse connections) that uses the provider's
|
||||
// config CA certificate PEM if provided, otherwise it will use the installed
|
||||
// system CA chain. This client's idle connections are closed in
|
||||
// Provider.Done()
|
||||
func (p *Provider) HTTPClient() (*http.Client, error) {
|
||||
const op = "Provider.NewHTTPClient"
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.client != nil {
|
||||
return p.client, nil
|
||||
}
|
||||
// since it's called by the provider factory, we need to check that the
|
||||
// config isn't nil
|
||||
if p.config == nil {
|
||||
return nil, fmt.Errorf("%s: the provider's config is nil %w", op, ErrNilParameter)
|
||||
}
|
||||
|
||||
// use the cleanhttp package to create a "pooled" transport that's better
|
||||
// configured for requests that re-use the same provider host. Among other
|
||||
// things, this transport supports better concurrency when making requests
|
||||
// to the same host. On the downside, this transport can leak file
|
||||
// descriptors over time, so we'll be sure to call
|
||||
// client.CloseIdleConnections() in the Provider.Done() to stave that off.
|
||||
tr := cleanhttp.DefaultPooledTransport()
|
||||
|
||||
if p.config.ProviderCA != "" {
|
||||
certPool := x509.NewCertPool()
|
||||
if ok := certPool.AppendCertsFromPEM([]byte(p.config.ProviderCA)); !ok {
|
||||
return nil, fmt.Errorf("%s: %w", op, ErrInvalidCACert)
|
||||
}
|
||||
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
RootCAs: certPool,
|
||||
}
|
||||
}
|
||||
|
||||
c := &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
p.client = c
|
||||
return p.client, nil
|
||||
}
|
||||
|
||||
// HTTPClientContext returns a new Context that carries the provider's HTTP
|
||||
// client. This method sets the same context key used by the
|
||||
// github.com/coreos/go-oidc and golang.org/x/oauth2 packages, so the returned
|
||||
// context works for those packages as well.
|
||||
func (p *Provider) HTTPClientContext(ctx context.Context) (context.Context, error) {
|
||||
const op = "Provider.HTTPClientContext"
|
||||
c, err := p.HTTPClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", op, err)
|
||||
|
||||
}
|
||||
// simple to implement as a wrapper for the coreos package
|
||||
return oidc.ClientContext(ctx, c), nil
|
||||
}
|
||||
|
||||
// validRedirect checks whether uri is in the allowed list, using special handling for
|
||||
// loopback uris. Ref: https://tools.ietf.org/html/rfc8252#section-7.3
|
||||
func (p *Provider) validRedirect(uri string) error {
|
||||
const op = "Provider.validRedirect"
|
||||
if len(p.config.AllowedRedirectURLs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputURI, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: redirect URI %s is an invalid URI %s: %w", op, uri, err.Error(), ErrInvalidParameter)
|
||||
}
|
||||
|
||||
// if uri isn't a loopback, just string search the allowed list
|
||||
if !strutils.StrListContains([]string{"localhost", "127.0.0.1", "::1"}, inputURI.Hostname()) {
|
||||
if !strutils.StrListContains(p.config.AllowedRedirectURLs, uri) {
|
||||
return fmt.Errorf("%s: redirect URI %s: %w", op, uri, ErrUnauthorizedRedirectURI)
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise, search for a match in a port-agnostic manner, per the OAuth RFC.
|
||||
inputURI.Host = inputURI.Hostname()
|
||||
|
||||
for _, a := range p.config.AllowedRedirectURLs {
|
||||
allowedURI, err := url.Parse(a)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: allowed redirect URI %s is an invalid URI %s: %w", op, allowedURI, err.Error(), ErrInvalidParameter)
|
||||
}
|
||||
allowedURI.Host = allowedURI.Hostname()
|
||||
|
||||
if inputURI.String() == allowedURI.String() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("%s: redirect URI %s: %w", op, uri, ErrUnauthorizedRedirectURI)
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package oidc
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// RefreshToken is an oauth refresh_token.
|
||||
// See https://tools.ietf.org/html/rfc6749#section-1.5.
|
||||
type RefreshToken string
|
||||
|
||||
// RedactedRefreshToken is the redacted string or json for an oauth refresh_token.
|
||||
const RedactedRefreshToken = "[REDACTED: refresh_token]"
|
||||
|
||||
// String will redact the token.
|
||||
func (t RefreshToken) String() string {
|
||||
return RedactedRefreshToken
|
||||
}
|
||||
|
||||
// MarshalJSON will redact the token.
|
||||
func (t RefreshToken) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(RedactedRefreshToken)
|
||||
}
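// exampleRefreshTokenRedaction is an illustrative sketch (not part of this
// change) showing that both String() and JSON encoding redact the secret.
func exampleRefreshTokenRedaction() (string, []byte) {
	rt := RefreshToken("example-secret-value")
	b, _ := json.Marshal(rt) // b holds `"[REDACTED: refresh_token]"`
	return rt.String(), b    // String() returns "[REDACTED: refresh_token]"
}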
|
|
@ -0,0 +1,627 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Request represents one OIDC authentication flow for a user. It
|
||||
// contains the data needed to uniquely represent that one-time flow across the
|
||||
// multiple interactions needed to complete the OIDC flow the user is
|
||||
// attempting.
|
||||
//
|
||||
// Request() is passed throughout the OIDC interactions to uniquely identify the
|
||||
// flow's request. The Request.State() and Request.Nonce() cannot be equal, and
|
||||
// will be used during the OIDC flow to prevent CSRF and replay attacks (see the
|
||||
// oidc spec for specifics).
|
||||
//
|
||||
// Audiences and Scopes are optional overrides of configured provider defaults
|
||||
// for specific authentication attempts
|
||||
type Request interface {
|
||||
// State is a unique identifier and an opaque value used to maintain state
|
||||
// between the oidc request and the callback. State cannot equal the Nonce.
|
||||
// See https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest.
|
||||
State() string
|
||||
|
||||
// Nonce is a unique nonce and a string value used to associate a Client
|
||||
// session with an ID Token, and to mitigate replay attacks. Nonce cannot
|
||||
// equal the State.
|
||||
// See https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
// and https://openid.net/specs/openid-connect-core-1_0.html#NonceNotes.
|
||||
Nonce() string
|
||||
|
||||
// IsExpired returns true if the request has expired. Implementations should
|
||||
// support a time skew (perhaps RequestExpirySkew) when checking expiration.
|
||||
IsExpired() bool
|
||||
|
||||
// Audiences is a specific authentication attempt's list of optional
|
||||
// case-sensitive strings to use when verifying an id_token's "aud" claim
|
||||
// (which is also a list). If provided, the audiences of an id_token must
|
||||
// match one of the configured audiences. If a Request does not have
|
||||
// audiences, then the configured list of default audiences will be used.
|
||||
Audiences() []string
|
||||
|
||||
// Scopes is a specific authentication attempt's list of optional
|
||||
// scopes to request of the provider. The required "oidc" scope is requested
|
||||
// by default, and does not need to be part of this optional list. If a
|
||||
// Request does not have Scopes, then the configured list of default
|
||||
// requested scopes will be used.
|
||||
Scopes() []string
|
||||
|
||||
// RedirectURL is a URL where providers will redirect responses to
|
||||
// authentication requests.
|
||||
RedirectURL() string
|
||||
|
||||
// ImplicitFlow indicates whether or not to use the implicit flow with form
|
||||
// post. Getting only an id_token for an implicit flow should be the
|
||||
// default for implementations, but at times it's necessary to also request
|
||||
// an access_token, so this function and the WithImplicitFlow(...) option
|
||||
// allows for those scenarios. Overall, it is recommended not to request
|
||||
// access_tokens during the implicit flow. If you need an access_token,
|
||||
// then use the authorization code flows and if you can't secure a client
|
||||
// secret then use the authorization code flow with PKCE.
|
||||
//
|
||||
// The first returned bool represents if the implicit flow has been requested.
|
||||
// The second returned bool represents if an access token has been requested
|
||||
// during the implicit flow.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth
|
||||
// See: https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html
|
||||
ImplicitFlow() (useImplicitFlow bool, includeAccessToken bool)
|
||||
|
||||
// PKCEVerifier indicates whether or not to use the authorization code flow
|
||||
// with PKCE. PKCE should be used for any client which cannot secure a
|
||||
// client secret (SPA and native apps) or is susceptible to authorization
|
||||
// code intercept attacks. When supported by your OIDC provider, PKCE should
|
||||
// be used instead of the implicit flow.
|
||||
//
|
||||
// See: https://tools.ietf.org/html/rfc7636
|
||||
PKCEVerifier() CodeVerifier
|
||||
|
||||
// MaxAge: when authAfter is not a zero value (!authAfter.IsZero()), then the
|
||||
// id_token's auth_time claim must be after the specified time.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
MaxAge() (seconds uint, authAfter time.Time)
|
||||
|
||||
// Prompts optionally defines a list of values that specifies whether the
|
||||
// Authorization Server prompts the End-User for reauthentication and
|
||||
// consent. See MaxAge() if you wish to specify an allowable elapsed time in
|
||||
// seconds since the last time the End-User was actively authenticated by
|
||||
// the OP.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
Prompts() []Prompt
|
||||
|
||||
// Display optionally specifies how the Authorization Server displays the
|
||||
// authentication and consent user interface pages to the End-User.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
Display() Display
|
||||
|
||||
// UILocales optionally specifies End-User's preferred languages via
|
||||
// language Tags, ordered by preference.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
UILocales() []language.Tag
|
||||
|
||||
// Claims optionally requests that specific claims be returned using
|
||||
// the claims parameter.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter
|
||||
Claims() []byte
|
||||
|
||||
// ACRValues() optionally specifies the acr values that the Authorization
|
||||
// Server is being requested to use for processing this Authentication
|
||||
// Request, with the values appearing in order of preference.
|
||||
//
|
||||
// NOTE: Requested acr_values are not verified by the Provider.Exchange(...)
|
||||
// or Provider.VerifyIDToken() functions, since the request/return values
|
||||
// are determined by the provider's implementation. You'll need to verify
|
||||
// the claims returned yourself based on values provided by your OIDC
|
||||
// Provider's documentation.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
ACRValues() []string
|
||||
}
|
||||
|
||||
// Req represents the oidc request used for oidc flows and implements the Request interface.
|
||||
type Req struct {
|
||||
// state is a unique identifier and an opaque value used to maintain state
|
||||
// between the oidc request and the callback.
|
||||
state string
|
||||
|
||||
// nonce is a unique nonce and suitable for use as an oidc nonce.
|
||||
nonce string
|
||||
|
||||
// Expiration is the expiration time for the Request.
|
||||
expiration time.Time
|
||||
|
||||
// redirectURL is a URL where providers will redirect responses to
|
||||
// authentication requests.
|
||||
redirectURL string
|
||||
|
||||
// scopes is a specific authentication attempt's list of optional
|
||||
// scopes to request of the provider. The required "oidc" scope is requested
|
||||
// by default, and does not need to be part of this optional list. If a
|
||||
// Request does not have Scopes, then the configured list of default
|
||||
// requested scopes will be used.
|
||||
scopes []string
|
||||
|
||||
// audiences is a specific authentication attempt's list of optional
|
||||
// case-sensitive strings to use when verifying an id_token's "aud" claim
|
||||
// (which is also a list). If provided, the audiences of an id_token must
|
||||
// match one of the configured audiences. If a Request does not have
|
||||
// audiences, then the configured list of default audiences will be used.
|
||||
audiences []string
|
||||
|
||||
// nowFunc is an optional function that returns the current time
|
||||
nowFunc func() time.Time
|
||||
|
||||
// withImplicit indicates whether or not to use the implicit flow. Getting
|
||||
// only an id_token for an implicit flow is the default. If an access_token
|
||||
// is also required, then withImplicit.withAccessToken will be true. It
|
||||
// is recommended not to request access_tokens during the implicit flow. If
|
||||
// you need an access_token, then use the authorization code flows (with
|
||||
// optional PKCE).
|
||||
withImplicit *implicitFlow
|
||||
|
||||
// withVerifier indicates whether or not to use the authorization code flow
|
||||
// with PKCE. It supplies the required CodeVerifier for PKCE.
|
||||
withVerifier CodeVerifier
|
||||
|
||||
// withMaxAge: when withMaxAge.authAfter is not a zero value
|
||||
// (i.e., !authAfter.IsZero()), then the id_token's auth_time claim must be after the
|
||||
// specified time.
|
||||
withMaxAge *maxAge
|
||||
|
||||
// withPrompts optionally defines a list of values that specifies whether
|
||||
// the Authorization Server prompts the End-User for reauthentication and
|
||||
// consent.
|
||||
withPrompts []Prompt
|
||||
|
||||
// withDisplay optionally specifies how the Authorization Server displays the
|
||||
// authentication and consent user interface pages to the End-User.
|
||||
withDisplay Display
|
||||
|
||||
// withUILocales optionally specifies End-User's preferred languages via
|
||||
// language Tags, ordered by preference.
|
||||
withUILocales []language.Tag
|
||||
|
||||
// withClaims optionally requests that specific claims be returned
|
||||
// using the claims parameter.
|
||||
withClaims []byte
|
||||
|
||||
// withACRValues() optionally specifies the acr values that the Authorization
|
||||
// Server is being requested to use for processing this Authentication
|
||||
// Request, with the values appearing in order of preference.
|
||||
withACRValues []string
|
||||
}
|
||||
|
||||
// ensure that Req implements the Request interface.
|
||||
var _ Request = (*Req)(nil)
|
||||
|
||||
// NewRequest creates a new Request (*Req).
|
||||
// Supports the options:
|
||||
// * WithState
|
||||
// * WithNow
|
||||
// * WithAudiences
|
||||
// * WithScopes
|
||||
// * WithImplicitFlow
|
||||
// * WithPKCE
|
||||
// * WithMaxAge
|
||||
// * WithPrompts
|
||||
// * WithDisplay
|
||||
// * WithUILocales
|
||||
// * WithClaims
// * WithACRValues
|
||||
func NewRequest(expireIn time.Duration, redirectURL string, opt ...Option) (*Req, error) {
|
||||
const op = "oidc.NewRequest"
|
||||
opts := getReqOpts(opt...)
|
||||
if redirectURL == "" {
|
||||
return nil, fmt.Errorf("%s: redirect URL is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
nonce, err := NewID(WithPrefix("n"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to generate a request's nonce: %w", op, err)
|
||||
}
|
||||
|
||||
var state string
|
||||
switch {
|
||||
case opts.withState != "":
|
||||
state = opts.withState
|
||||
default:
|
||||
var err error
|
||||
state, err = NewID(WithPrefix("st"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to generate a request's state: %w", op, err)
|
||||
}
|
||||
}
|
||||
|
||||
if expireIn <= 0 {
|
||||
return nil, fmt.Errorf("%s: expireIn not greater than zero: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
if opts.withVerifier != nil && opts.withImplicitFlow != nil {
|
||||
return nil, fmt.Errorf("%s: requested both implicit flow and authorization code with PKCE: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
r := &Req{
|
||||
state: state,
|
||||
nonce: nonce,
|
||||
redirectURL: redirectURL,
|
||||
nowFunc: opts.withNowFunc,
|
||||
audiences: opts.withAudiences,
|
||||
scopes: opts.withScopes,
|
||||
withImplicit: opts.withImplicitFlow,
|
||||
withVerifier: opts.withVerifier,
|
||||
withPrompts: opts.withPrompts,
|
||||
withDisplay: opts.withDisplay,
|
||||
withUILocales: opts.withUILocales,
|
||||
withClaims: opts.withClaims,
|
||||
withACRValues: opts.withACRValues,
|
||||
}
|
||||
r.expiration = r.now().Add(expireIn)
|
||||
if opts.withMaxAge != nil {
|
||||
opts.withMaxAge.authAfter = r.now().Add(time.Duration(-opts.withMaxAge.seconds) * time.Second)
|
||||
r.withMaxAge = opts.withMaxAge
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
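// exampleNewRequest is an illustrative sketch (not part of this change): a
// request that requires authentication within the last 5 minutes and pins a
// caller-chosen state. The redirect URL and values are assumptions chosen for
// the example.
func exampleNewRequest() (*Req, error) {
	return NewRequest(
		2*time.Minute, // how long the request's state/nonce remain valid
		"https://example.com/callback",
		WithMaxAge(300),
		WithState("st_example-opaque-value-at-least-20-chars"),
	)
}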
|
||||
// State implements the Request.State() interface function.
|
||||
func (r *Req) State() string { return r.state }
|
||||
|
||||
// Nonce implements the Request.Nonce() interface function.
|
||||
func (r *Req) Nonce() string { return r.nonce }
|
||||
|
||||
// Audiences implements the Request.Audiences() interface function and returns a
|
||||
// copy of the audiences.
|
||||
func (r *Req) Audiences() []string {
|
||||
if r.audiences == nil {
|
||||
return nil
|
||||
}
|
||||
cp := make([]string, len(r.audiences))
|
||||
copy(cp, r.audiences)
|
||||
return cp
|
||||
}
|
||||
|
||||
// Scopes implements the Request.Scopes() interface function and returns a copy of
|
||||
// the scopes.
|
||||
func (r *Req) Scopes() []string {
|
||||
if r.scopes == nil {
|
||||
return nil
|
||||
}
|
||||
cp := make([]string, len(r.scopes))
|
||||
copy(cp, r.scopes)
|
||||
return cp
|
||||
}
|
||||
|
||||
// RedirectURL implements the Request.RedirectURL() interface function.
|
||||
func (r *Req) RedirectURL() string { return r.redirectURL }
|
||||
|
||||
// PKCEVerifier implements the Request.PKCEVerifier() interface function and
|
||||
// returns a copy of the CodeVerifier
|
||||
func (r *Req) PKCEVerifier() CodeVerifier {
|
||||
if r.withVerifier == nil {
|
||||
return nil
|
||||
}
|
||||
return r.withVerifier.Copy()
|
||||
}
|
||||
|
||||
// Prompts() implements the Request.Prompts() interface function and returns a
|
||||
// copy of the prompts.
|
||||
func (r *Req) Prompts() []Prompt {
|
||||
if r.withPrompts == nil {
|
||||
return nil
|
||||
}
|
||||
cp := make([]Prompt, len(r.withPrompts))
|
||||
copy(cp, r.withPrompts)
|
||||
return cp
|
||||
}
|
||||
|
||||
// Display() implements the Request.Display() interface function.
|
||||
func (r *Req) Display() Display { return r.withDisplay }
|
||||
|
||||
// UILocales() implements the Request.UILocales() interface function and returns a
|
||||
// copy of the UILocales
|
||||
func (r *Req) UILocales() []language.Tag {
|
||||
if r.withUILocales == nil {
|
||||
return nil
|
||||
}
|
||||
cp := make([]language.Tag, len(r.withUILocales))
|
||||
copy(cp, r.withUILocales)
|
||||
return cp
|
||||
}
|
||||
|
||||
// Claims() implements the Request.Claims() interface function
|
||||
// and returns a copy of the claims request.
|
||||
func (r *Req) Claims() []byte {
|
||||
if r.withClaims == nil {
|
||||
return nil
|
||||
}
|
||||
cp := make([]byte, len(r.withClaims))
|
||||
copy(cp, r.withClaims)
|
||||
return cp
|
||||
}
|
||||
|
||||
// ACRValues() implements the Request.ACRValues() interface function and returns a
|
||||
// copy of the acr values
|
||||
func (r *Req) ACRValues() []string {
|
||||
if len(r.withACRValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
cp := make([]string, len(r.withACRValues))
|
||||
copy(cp, r.withACRValues)
|
||||
return cp
|
||||
}
|
||||
|
||||
// MaxAge: when authAfter is not a zero value (!authAfter.IsZero()), then the
|
||||
// id_token's auth_time claim must be after the specified time.
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
func (r *Req) MaxAge() (uint, time.Time) {
|
||||
if r.withMaxAge == nil {
|
||||
return 0, time.Time{}
|
||||
}
|
||||
return r.withMaxAge.seconds, r.withMaxAge.authAfter.Truncate(time.Second)
|
||||
}
|
||||
|
||||
// ImplicitFlow indicates whether or not to use the implicit flow. Getting
|
||||
// only an id_token for an implicit flow is the default, but at times
|
||||
// it's necessary to also request an access_token, so this function and the
|
||||
// WithImplicitFlow(...) option allows for those scenarios. Overall, it is
|
||||
// recommended not to request access_tokens during the implicit flow. If you need
|
||||
// an access_token, then use the authorization code flows and if you can't
|
||||
// secure a client secret then use the authorization code flow with PKCE.
|
||||
//
|
||||
// The first returned bool represents if the implicit flow has been requested.
|
||||
// The second returned bool represents if an access token has been requested
|
||||
// during the implicit flow.
|
||||
func (r *Req) ImplicitFlow() (bool, bool) {
|
||||
if r.withImplicit == nil {
|
||||
return false, false
|
||||
}
|
||||
switch {
|
||||
case r.withImplicit.withAccessToken:
|
||||
return true, true
|
||||
default:
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
// RequestExpirySkew defines a time skew when checking a Request's expiration.
|
||||
const RequestExpirySkew = 1 * time.Second
|
||||
|
||||
// IsExpired returns true if the request has expired.
|
||||
func (r *Req) IsExpired() bool {
|
||||
return r.expiration.Before(time.Now().Add(RequestExpirySkew))
|
||||
}
|
||||
|
||||
// now returns the current time using the optional timeFn
|
||||
func (r *Req) now() time.Time {
|
||||
if r.nowFunc != nil {
|
||||
return r.nowFunc()
|
||||
}
|
||||
return time.Now() // fallback to this default
|
||||
}
|
||||
|
||||
type implicitFlow struct {
|
||||
withAccessToken bool
|
||||
}
|
||||
|
||||
type maxAge struct {
|
||||
seconds uint
|
||||
authAfter time.Time
|
||||
}
|
||||
|
||||
// reqOptions is the set of available options for Req functions
|
||||
type reqOptions struct {
|
||||
withNowFunc func() time.Time
|
||||
withScopes []string
|
||||
withAudiences []string
|
||||
withImplicitFlow *implicitFlow
|
||||
withVerifier CodeVerifier
|
||||
withMaxAge *maxAge
|
||||
withPrompts []Prompt
|
||||
withDisplay Display
|
||||
withUILocales []language.Tag
|
||||
withClaims []byte
|
||||
withACRValues []string
|
||||
withState string
|
||||
}
|
||||
|
||||
// reqDefaults is a handy way to get the defaults at runtime and during unit
|
||||
// tests.
|
||||
func reqDefaults() reqOptions {
|
||||
return reqOptions{}
|
||||
}
|
||||
|
||||
// getReqOpts gets the request defaults and applies the opt overrides passed in
|
||||
func getReqOpts(opt ...Option) reqOptions {
|
||||
opts := reqDefaults()
|
||||
ApplyOpts(&opts, opt...)
|
||||
return opts
|
||||
}
|
||||
|
||||
// WithImplicitFlow provides an option to use an OIDC implicit flow with form
|
||||
// post. It should be noted that if your OIDC provider supports PKCE, then use
|
||||
// it over the implicit flow. Getting only an id_token is the default, and
|
||||
// optionally passing a true bool will request an access_token as well during
|
||||
// the flow. You cannot use WithImplicitFlow and WithPKCE together. It is
|
||||
// recommended not to request access_tokens during the implicit flow. If you need
|
||||
// an access_token, then use the authorization code flows.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth
|
||||
// See: https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html
|
||||
func WithImplicitFlow(args ...interface{}) Option {
|
||||
withAccessToken := false
|
||||
for _, arg := range args {
|
||||
switch arg := arg.(type) {
|
||||
case bool:
|
||||
if arg {
|
||||
withAccessToken = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withImplicitFlow = &implicitFlow{
|
||||
withAccessToken: withAccessToken,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
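// exampleImplicitRequest is an illustrative sketch (not part of this change):
// WithImplicitFlow() requests only an id_token, while WithImplicitFlow(true)
// also requests an access_token (which is generally discouraged). The redirect
// URL is an assumption for the example.
func exampleImplicitRequest(withAccessToken bool) (*Req, error) {
	if withAccessToken {
		return NewRequest(2*time.Minute, "https://example.com/callback", WithImplicitFlow(true))
	}
	return NewRequest(2*time.Minute, "https://example.com/callback", WithImplicitFlow())
}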
|
||||
// WithPKCE provides an option to use a CodeVerifier with the authorization
|
||||
// code flow with PKCE. You cannot use WithImplicitFlow and WithPKCE together.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// See: https://tools.ietf.org/html/rfc7636
|
||||
func WithPKCE(v CodeVerifier) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withVerifier = v
|
||||
}
|
||||
}
|
||||
}
|
||||
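// examplePKCERequest is an illustrative sketch (not part of this change). It
// assumes a CodeVerifier constructor such as NewCodeVerifier() exists elsewhere
// in this package; that constructor is not shown in this diff, so treat it as
// an assumption. The redirect URL is likewise only an example.
func examplePKCERequest() (*Req, error) {
	v, err := NewCodeVerifier()
	if err != nil {
		return nil, err
	}
	return NewRequest(2*time.Minute, "https://example.com/callback", WithPKCE(v))
}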
|
||||
// WithMaxAge provides an optional maximum authentication age, which is the
|
||||
// allowable elapsed time in seconds since the last time the user was actively
|
||||
// authenticated by the provider. When a max age is specified, the provider
|
||||
// must include an auth_time claim in the returned id_token. This makes it
|
||||
// preferable to prompt=login, where you have no way to verify when an
|
||||
// authentication took place.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
func WithMaxAge(seconds uint) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
// authAfter will be a zero value, since it's not set until the
|
||||
// NewRequest() factory, when it can determine its nowFunc
|
||||
o.withMaxAge = &maxAge{
|
||||
seconds: seconds,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrompts provides an optional list of values that specifies whether the
|
||||
// Authorization Server prompts the End-User for reauthentication and consent.
|
||||
//
|
||||
// See MaxAge() if you wish to specify an allowable elapsed time in seconds since
|
||||
// the last time the End-User was actively authenticated by the OP.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
func WithPrompts(prompts ...Prompt) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withPrompts = prompts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisplay optionally specifies how the Authorization Server displays the
|
||||
// authentication and consent user interface pages to the End-User.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
func WithDisplay(d Display) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withDisplay = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithUILocales optionally specifies End-User's preferred languages via
|
||||
// language Tags, ordered by preference.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
func WithUILocales(locales ...language.Tag) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withUILocales = locales
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithClaims optionally requests that specific claims be returned using
|
||||
// the claims parameter.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter
|
||||
func WithClaims(json []byte) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withClaims = json
|
||||
}
|
||||
}
|
||||
}
|
||||
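// exampleClaimsRequest is an illustrative sketch (not part of this change)
// showing the raw JSON form of the OIDC "claims" request parameter; the
// requested claim and redirect URL are assumptions chosen for the example.
func exampleClaimsRequest() (*Req, error) {
	claims := []byte(`{"id_token": {"email": {"essential": true}}}`)
	return NewRequest(2*time.Minute, "https://example.com/callback", WithClaims(claims))
}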
|
||||
// WithACRValues optionally specifies the acr values that the Authorization
|
||||
// Server is being requested to use for processing this Authentication
|
||||
// Request, with the values appearing in order of preference.
|
||||
//
|
||||
// NOTE: Requested acr_values are not verified by the Provider.Exchange(...)
|
||||
// or Provider.VerifyIDToken() functions, since the request/return values
|
||||
// are determined by the provider's implementation. You'll need to verify
|
||||
// the claims returned yourself based on values provided by your OIDC
|
||||
// Provider's documentation.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
func WithACRValues(values ...string) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withACRValues = values
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithState optionally specifies a value to use for the request's state.
|
||||
// Typically, state is a random string generated for you when you create
|
||||
// a new Request. This option allows you to override that auto-generated value
|
||||
// with a specific value of your own choosing.
|
||||
//
|
||||
// The primary reason for using the state parameter is to mitigate CSRF attacks
|
||||
// by using a unique and non-guessable value associated with each authentication
|
||||
// request about to be initiated. That value allows you to prevent the attack by
|
||||
// confirming that the value coming from the response matches the one you sent.
|
||||
// Since the state parameter is a string, you can encode any other information
|
||||
// in it.
|
||||
//
|
||||
// Some care must be taken to not use a state which is longer than your OIDC
|
||||
// Provider allows. The specification places no limit on the length, but there
|
||||
// are many practical limitations placed on the length by browsers, proxies and
|
||||
// of course your OIDC provider.
|
||||
//
|
||||
// State should be at least 20 chars long (see:
|
||||
// https://tools.ietf.org/html/rfc6749#section-10.10).
|
||||
//
|
||||
// See NewID(...) for a function that generates a sufficiently
|
||||
// random string and supports the WithPrefix(...) option, which can be used
|
||||
// to prefix your custom state payload.
|
||||
//
|
||||
// Neither a max nor a min length is enforced when you use the WithState option.
|
||||
//
|
||||
// Option is valid for: Request
|
||||
//
|
||||
func WithState(s string) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*reqOptions); ok {
|
||||
o.withState = s
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,214 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"hash"
|
||||
"math/big"
|
||||
"net"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// TestGenerateKeys will generate a test ECDSA P-256 pub/priv key pair.
|
||||
func TestGenerateKeys(t *testing.T) (crypto.PublicKey, crypto.PrivateKey) {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(err)
|
||||
return &priv.PublicKey, priv
|
||||
}
|
||||
|
||||
// TestSignJWT will bundle the provided claims into a test signed JWT.
|
||||
func TestSignJWT(t *testing.T, key crypto.PrivateKey, alg string, claims interface{}, keyID []byte) string {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
|
||||
hdr := map[jose.HeaderKey]interface{}{}
|
||||
if keyID != nil {
|
||||
hdr["key_id"] = string(keyID)
|
||||
}
|
||||
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.SignatureAlgorithm(alg), Key: key},
|
||||
(&jose.SignerOptions{ExtraHeaders: hdr}).WithType("JWT"),
|
||||
)
|
||||
require.NoError(err)
|
||||
|
||||
raw, err := jwt.Signed(sig).
|
||||
Claims(claims).
|
||||
CompactSerialize()
|
||||
require.NoError(err)
|
||||
return raw
|
||||
}
|
||||
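// exampleTestJWT is an illustrative sketch (not part of this change) of how
// the helpers above are typically combined in a test to mint a signed JWT; the
// claim values are assumptions chosen for the example.
func exampleTestJWT(t *testing.T) string {
	t.Helper()
	_, priv := TestGenerateKeys(t)
	claims := map[string]interface{}{
		"iss": "https://example.com/",
		"sub": "alice@example.com",
	}
	return TestSignJWT(t, priv, string(ES256), claims, nil)
}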
|
||||
// TestGenerateCA will generate a test x509 CA cert, along with it encoded in a
|
||||
// PEM format.
|
||||
func TestGenerateCA(t *testing.T, hosts []string) (*x509.Certificate, string) {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
require.NoError(err)
|
||||
|
||||
// ECDSA, ED25519 and RSA subject keys should have the DigitalSignature
|
||||
// KeyUsage bits set in the x509.Certificate template
|
||||
keyUsage := x509.KeyUsageDigitalSignature
|
||||
|
||||
validFor := 2 * time.Minute
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(validFor)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
require.NoError(err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Acme Co"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
for _, h := range hosts {
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, h)
|
||||
}
|
||||
}
|
||||
|
||||
template.IsCA = true
|
||||
template.KeyUsage |= x509.KeyUsageCertSign
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
require.NoError(err)
|
||||
|
||||
c, err := x509.ParseCertificate(derBytes)
|
||||
require.NoError(err)
|
||||
|
||||
return c, string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
|
||||
}
|
||||
|
||||
// testHash will generate a hash using a signature algorithm. It is used to
|
||||
// test at_hash and c_hash id_token claims. This is helpful internally, but
|
||||
// intentionally not exported.
|
||||
func testHash(t *testing.T, signatureAlg Alg, data string) string {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
var h hash.Hash
|
||||
switch signatureAlg {
|
||||
case RS256, ES256, PS256:
|
||||
h = sha256.New()
|
||||
case RS384, ES384, PS384:
|
||||
h = sha512.New384()
|
||||
case RS512, ES512, PS512:
|
||||
h = sha512.New()
|
||||
case EdDSA:
|
||||
return "EdDSA-hash"
|
||||
default:
|
||||
require.FailNowf("", "testHash: unsupported signing algorithm %s", string(signatureAlg))
|
||||
}
|
||||
require.NotNil(h)
|
||||
_, _ = h.Write([]byte(data)) // hash documents that Write will never return an error
|
||||
sum := h.Sum(nil)[:h.Size()/2]
|
||||
actual := base64.RawURLEncoding.EncodeToString(sum)
|
||||
return actual
|
||||
}
|
||||
|
||||
// testDefaultJWT creates a default test JWT and is internally helpful, but for now we won't export it.
|
||||
func testDefaultJWT(t *testing.T, privKey crypto.PrivateKey, expireIn time.Duration, nonce string, additionalClaims map[string]interface{}) string {
|
||||
t.Helper()
|
||||
now := float64(time.Now().Unix())
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://example.com/",
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": float64(time.Now().Unix()),
|
||||
"aud": []string{"www.example.com"},
|
||||
"sub": "alice@example.com",
|
||||
"nonce": nonce,
|
||||
}
|
||||
for k, v := range additionalClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
testJWT := TestSignJWT(t, privKey, string(ES256), claims, nil)
|
||||
return testJWT
|
||||
}
|
||||
|
||||
// testNewConfig creates a new config from the TestProvider. It will set the
|
||||
// TestProvider's client ID/secret and use the TestProvider's signing algorithm
|
||||
// when building the configuration. This is helpful internally, but
|
||||
// intentionally not exported.
|
||||
func testNewConfig(t *testing.T, clientID, clientSecret, allowedRedirectURL string, tp *TestProvider) *Config {
|
||||
const op = "testNewConfig"
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
|
||||
require.NotEmptyf(clientID, "%s: client id is empty", op)
|
||||
require.NotEmptyf(clientSecret, "%s: client secret is empty", op)
|
||||
require.NotEmptyf(allowedRedirectURL, "%s: redirect URL is empty", op)
|
||||
|
||||
tp.SetClientCreds(clientID, clientSecret)
|
||||
_, _, alg, _ := tp.SigningKeys()
|
||||
c, err := NewConfig(
|
||||
tp.Addr(),
|
||||
clientID,
|
||||
ClientSecret(clientSecret),
|
||||
[]Alg{alg},
|
||||
[]string{allowedRedirectURL},
|
||||
nil,
|
||||
WithProviderCA(tp.CACert()),
|
||||
)
|
||||
require.NoError(err)
|
||||
return c
|
||||
}
|
||||
|
||||
// testNewProvider creates a new Provider. It uses the TestProvider (tp) to properly
|
||||
// construct the provider's configuration (see testNewConfig). This is helpful internally, but
|
||||
// intentionally not exported.
|
||||
func testNewProvider(t *testing.T, clientID, clientSecret, redirectURL string, tp *TestProvider) *Provider {
|
||||
const op = "testNewProvider"
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
require.NotEmptyf(clientID, "%s: client id is empty", op)
|
||||
require.NotEmptyf(clientSecret, "%s: client secret is empty", op)
|
||||
require.NotEmptyf(redirectURL, "%s: redirect URL is empty", op)
|
||||
|
||||
tc := testNewConfig(t, clientID, clientSecret, redirectURL, tp)
|
||||
p, err := NewProvider(tc)
|
||||
require.NoError(err)
|
||||
t.Cleanup(p.Done)
|
||||
return p
|
||||
}
|
||||
|
||||
// testAssertEqualFunc gives you a way to assert that two functions (passed as
|
||||
// interface{}) are equal. This is helpful internally, but intentionally not
|
||||
// exported.
|
||||
func testAssertEqualFunc(t *testing.T, wantFunc, gotFunc interface{}, format string, args ...interface{}) {
|
||||
t.Helper()
|
||||
want := runtime.FuncForPC(reflect.ValueOf(wantFunc).Pointer()).Name()
|
||||
got := runtime.FuncForPC(reflect.ValueOf(gotFunc).Pointer()).Name()
|
||||
assert.Equalf(t, want, got, format, args...)
|
||||
}
|
|
@ -0,0 +1,910 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/cap/oidc/internal/strutils"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
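// The following is an illustrative sketch (not part of this change) of how a
// test might wire the TestProvider below to a real Provider. It assumes the
// StartTestProvider(...) constructor referenced in the comment below; its exact
// signature is not shown in this diff, so treat this as a hedged example only:
//
//	tp := StartTestProvider(t)
//	tp.SetClientCreds("test-client-id", "test-client-secret")
//	p := testNewProvider(t, "test-client-id", "test-client-secret", "https://example.com/callback", tp)
//	_ = p
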
// TestProvider is a local http server that supports test provider capabilities
|
||||
// which makes writing tests much easier. Much of this TestProvider
|
||||
// design/implementation comes from Consul's oauthtest package. A big thanks to
|
||||
// the original package's contributors.
|
||||
//
|
||||
// It's important to remember that the TestProvider is stateful (see any of its
|
||||
// receiver functions that begin with Set*).
|
||||
//
|
||||
// Once you've started a TestProvider http server with StartTestProvider(...),
|
||||
// the following test endpoints are supported:
|
||||
//
|
||||
// * GET /.well-known/openid-configuration OIDC Discovery
|
||||
//
|
||||
// * GET or POST /authorize OIDC authorization supporting both
|
||||
// the authorization code flow (with
|
||||
// optional PKCE) and the implicit
|
||||
// flow with form_post.
|
||||
//
|
||||
// * POST /token OIDC token
|
||||
//
|
||||
// * GET /userinfo OAuth UserInfo
|
||||
//
|
||||
// * GET /.well-known/jwks.json JWKs used to verify issued JWT tokens
|
||||
//
|
||||
// Making requests to these endpoints is facilitated by
|
||||
// * TestProvider.HTTPClient which returns an http.Client for making requests.
|
||||
// * TestProvider.CACert which returns the pem-encoded CA certificate used by the HTTPS server.
|
||||
//
|
||||
// Runtime Configuration:
|
||||
// * Issuer: Addr() returns the current base URL for the test provider's
|
||||
// running webserver, which can be used as an OIDC Issuer for discovery and
|
||||
// is also used for the iss claim when issuing JWTs.
|
||||
//
|
||||
// * Relying Party ClientID/ClientSecret: SetClientCreds(...) updates the
|
||||
// creds and they are empty by default.
|
||||
//
|
||||
// * Now: SetNowFunc(...) updates the provider's "now" function and time.Now
|
||||
// is the default.
|
||||
//
|
||||
// * Expiry: SetExpectedExpiry(exp time.Duration) updates the expiry and
|
||||
// now + 5 * time.Second is the default.
|
||||
//
|
||||
// * Signing keys: SetSigningKeys(...) updates the keys and an ECDSA P-256 pair
|
||||
// of priv/pub keys are the default with a signing algorithm of ES256
|
||||
//
|
||||
// * Authorization Code: SetExpectedAuthCode(...) updates the auth code
|
||||
// required by the /authorize endpoint and the code is empty by default.
|
||||
//
|
||||
// * Authorization Nonce: SetExpectedAuthNonce(...) updates the nonce required
|
||||
// by the /authorize endpoint and the nonce is empty by default.
|
||||
//
|
||||
// * Allowed RedirectURIs: SetAllowedRedirectURIs(...) updates the allowed
|
||||
// redirect URIs and "https://example.com" is the default.
|
||||
//
|
||||
// * Custom Claims: SetCustomClaims(...) updates custom claims added to JWTs issued
|
||||
// and the custom claims are empty by default.
|
||||
//
|
||||
// * Audiences: SetCustomAudience(...) updates the audience claim of JWTs issued
|
||||
// and the ClientID is the default.
|
||||
//
|
||||
// * Authentication Time (auth_time): SetOmitAuthTimeClaim(...) allows you to
|
||||
// turn off/on the inclusion of an auth_time claim in issued JWTs and the claim
|
||||
// is included by default.
|
||||
//
|
||||
// * Issuing id_tokens: SetOmitIDTokens(...) allows you to turn off/on the issuing of
|
||||
// id_tokens from the /token endpoint. id_tokens are issued by default.
|
||||
//
|
||||
// * Issuing access_tokens: SetOmitAccessTokens(...) allows you to turn off/on
|
||||
// the issuing of access_tokens from the /token endpoint. access_tokens are issued
|
||||
// by default.
|
||||
//
|
||||
// * Authorization State: SetExpectedState sets the value for the state parameter
|
||||
// returned from the /authorize endpoint.
|
||||
//
|
||||
// * Token Responses: SetDisableToken disables the /token endpoint, causing
|
||||
// it to return a 401 http status.
|
||||
//
|
||||
// * Implicit Flow Responses: SetDisableImplicit disables implicit flow responses,
|
||||
// causing them to return a 401 http status.
|
||||
//
|
||||
// * PKCE verifier: SetPKCEVerifier(oidc.CodeVerifier) sets the PKCE code_verifier
|
||||
// and PKCEVerifier() returns the current verifier.
|
||||
//
|
||||
// * UserInfo: SetUserInfoReply sets the UserInfo endpoint response and
|
||||
// UserInfoReply() returns the current response.
|
||||
type TestProvider struct {
|
||||
httpServer *httptest.Server
|
||||
caCert string
|
||||
|
||||
jwks *jose.JSONWebKeySet
|
||||
allowedRedirectURIs []string
|
||||
replySubject string
|
||||
replyUserinfo interface{}
|
||||
replyExpiry time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
clientID string
|
||||
clientSecret string
|
||||
expectedAuthCode string
|
||||
expectedAuthNonce string
|
||||
expectedState string
|
||||
customClaims map[string]interface{}
|
||||
customAudiences []string
|
||||
omitAuthTimeClaim bool
|
||||
omitIDToken bool
|
||||
omitAccessToken bool
|
||||
disableUserInfo bool
|
||||
disableJWKs bool
|
||||
disableToken bool
|
||||
disableImplicit bool
|
||||
invalidJWKs bool
|
||||
nowFunc func() time.Time
|
||||
pkceVerifier CodeVerifier
|
||||
|
||||
privKey crypto.PrivateKey
|
||||
pubKey crypto.PublicKey
|
||||
keyID string
|
||||
alg Alg
|
||||
|
||||
t *testing.T
|
||||
|
||||
client *http.Client
|
||||
}
|
||||
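// testProviderUsageSketch is an illustrative sketch only (not part of the
// original change). It shows how a test might start the provider, configure
// relying-party creds and an expected auth code, and then talk to it over the
// CA-trusting HTTP client; the credential and code values are made up.
func testProviderUsageSketch(t *testing.T) {
	t.Helper()
	tp := StartTestProvider(t) // stopped automatically via t.Cleanup
	tp.SetClientCreds("test-rp-id", "test-rp-secret")
	tp.SetExpectedAuthCode("test-code")

	client := tp.HTTPClient()
	resp, err := client.Get(tp.Addr() + "/.well-known/openid-configuration")
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	// From here a test would decode the discovery document and exercise
	// /authorize, /token, /userinfo and /.well-known/jwks.json as needed.
}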
|
||||
// Stop stops the running TestProvider.
|
||||
func (p *TestProvider) Stop() {
|
||||
p.httpServer.Close()
|
||||
if p.client != nil {
|
||||
p.client.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// StartTestProvider creates and starts a running TestProvider http server. The
|
||||
// WithTestPort option is supported. The TestProvider will be shut down when the
|
||||
// test and all its subtests complete via a registered function with
|
||||
// t.Cleanup(...).
|
||||
func StartTestProvider(t *testing.T, opt ...Option) *TestProvider {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
opts := getTestProviderOpts(opt...)
|
||||
|
||||
v, err := NewCodeVerifier()
|
||||
require.NoError(err)
|
||||
p := &TestProvider{
|
||||
t: t,
|
||||
nowFunc: time.Now,
|
||||
pkceVerifier: v,
|
||||
customClaims: map[string]interface{}{},
|
||||
replyExpiry: 5 * time.Second,
|
||||
|
||||
allowedRedirectURIs: []string{
|
||||
"https://example.com",
|
||||
},
|
||||
replySubject: "alice@example.com",
|
||||
replyUserinfo: map[string]interface{}{
|
||||
"sub": "alice@example.com",
|
||||
"dob": "1978",
|
||||
"friend": "bob",
|
||||
"nickname": "A",
|
||||
"advisor": "Faythe",
|
||||
"nosy-neighbor": "Eve",
|
||||
},
|
||||
}
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(err)
|
||||
p.pubKey, p.privKey = &priv.PublicKey, priv
|
||||
p.alg = ES256
|
||||
p.keyID = strconv.Itoa(int(time.Now().Unix()))
|
||||
p.jwks = &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: p.pubKey,
|
||||
KeyID: p.keyID,
|
||||
},
|
||||
},
|
||||
}
|
||||
p.httpServer = httptestNewUnstartedServerWithPort(t, p, opts.withPort)
|
||||
p.httpServer.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
|
||||
p.httpServer.StartTLS()
|
||||
t.Cleanup(p.Stop)
|
||||
|
||||
cert := p.httpServer.Certificate()
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
require.NoError(err)
|
||||
p.caCert = buf.String()
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// testProviderOptions is the set of available options for TestProvider
|
||||
// functions
|
||||
type testProviderOptions struct {
|
||||
withPort int
|
||||
withAtHashOf string
|
||||
withCHashOf string
|
||||
}
|
||||
|
||||
// testProviderDefaults is a handy way to get the defaults at runtime and during unit
|
||||
// tests.
|
||||
func testProviderDefaults() testProviderOptions {
|
||||
return testProviderOptions{}
|
||||
}
|
||||
|
||||
// getTestProviderOpts gets the test provider defaults and applies the opt
|
||||
// overrides passed in
|
||||
func getTestProviderOpts(opt ...Option) testProviderOptions {
|
||||
opts := testProviderDefaults()
|
||||
ApplyOpts(&opts, opt...)
|
||||
return opts
|
||||
}
|
||||
|
||||
// WithTestPort provides an optional port for the test provider.
|
||||
//
|
||||
// Valid for: TestProvider.StartTestProvider
|
||||
func WithTestPort(port int) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*testProviderOptions); ok {
|
||||
o.withPort = port
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// withTestAtHash provides an option to request the at_hash claim. Valid for:
|
||||
// TestProvider.issueSignedJWT
|
||||
func withTestAtHash(accessToken string) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*testProviderOptions); ok {
|
||||
o.withAtHashOf = accessToken
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// withTestCHash provides an option to request the c_hash claim. Valid for:
|
||||
// TestProvider.issueSignedJWT
|
||||
func withTestCHash(authorizationCode string) Option {
|
||||
return func(o interface{}) {
|
||||
if o, ok := o.(*testProviderOptions); ok {
|
||||
o.withCHashOf = authorizationCode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPClient returns an http.Client for the test provider. The returned client
|
||||
// uses a pooled transport (so it can reuse connections) that uses the
|
||||
// test provider's CA certificate. This client's idle connections are closed in
|
||||
// TestProvider.Stop().
|
||||
func (p *TestProvider) HTTPClient() *http.Client {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.client != nil {
|
||||
return p.client
|
||||
}
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
|
||||
// use the cleanhttp package to create a "pooled" transport that's better
|
||||
// configured for requests that re-use the same provider host. Among other
|
||||
// things, this transport supports better concurrency when making requests
|
||||
// to the same host. On the downside, this transport can leak file
|
||||
// descriptors over time, so we'll be sure to call
|
||||
// client.CloseIdleConnections() in TestProvider.Stop() to stave that off.
|
||||
tr := cleanhttp.DefaultPooledTransport()
|
||||
|
||||
certPool := x509.NewCertPool()
|
||||
ok := certPool.AppendCertsFromPEM([]byte(p.caCert))
|
||||
require.True(ok)
|
||||
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
RootCAs: certPool,
|
||||
}
|
||||
|
||||
p.client = &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
return p.client
|
||||
}
|
||||
|
||||
// SetExpectedExpiry is for configuring the expected expiry for any JWTs issued
|
||||
// by the provider (the default is 5 seconds)
|
||||
func (p *TestProvider) SetExpectedExpiry(exp time.Duration) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.replyExpiry = exp
|
||||
}
|
||||
|
||||
// SetClientCreds is for configuring the relying party client ID and client
|
||||
// secret information required for the OIDC workflows.
|
||||
func (p *TestProvider) SetClientCreds(clientID, clientSecret string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.clientID = clientID
|
||||
p.clientSecret = clientSecret
|
||||
}
|
||||
|
||||
// ClientCreds returns the relying party client information required for the
|
||||
// OIDC workflows.
|
||||
func (p *TestProvider) ClientCreds() (clientID, clientSecret string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.clientID, p.clientSecret
|
||||
}
|
||||
|
||||
// SetExpectedAuthCode configures the auth code to return from /auth and the
|
||||
// allowed auth code for /token.
|
||||
func (p *TestProvider) SetExpectedAuthCode(code string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.expectedAuthCode = code
|
||||
}
|
||||
|
||||
// SetExpectedAuthNonce configures the nonce value required for /auth.
|
||||
func (p *TestProvider) SetExpectedAuthNonce(nonce string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.expectedAuthNonce = nonce
|
||||
}
|
||||
|
||||
// SetAllowedRedirectURIs allows you to configure the allowed redirect URIs for
|
||||
// the OIDC workflow. If not configured a sample of "https://example.com" is
|
||||
// used.
|
||||
func (p *TestProvider) SetAllowedRedirectURIs(uris []string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.allowedRedirectURIs = uris
|
||||
}
|
||||
|
||||
// SetCustomClaims lets you set claims to return in the JWT issued by the OIDC
|
||||
// workflow.
|
||||
func (p *TestProvider) SetCustomClaims(customClaims map[string]interface{}) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.customClaims = customClaims
|
||||
}
|
||||
|
||||
// SetCustomAudience configures what audience value to embed in the JWT issued
|
||||
// by the OIDC workflow.
|
||||
func (p *TestProvider) SetCustomAudience(customAudiences ...string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.customAudiences = customAudiences
|
||||
}
|
||||
|
||||
// SetNowFunc configures how the test provider will determine the current time. The
|
||||
// default is time.Now()
|
||||
func (p *TestProvider) SetNowFunc(n func() time.Time) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
require.NotNilf(n, "TestProvider.SetNowFunc: time func is nil")
|
||||
p.nowFunc = n
|
||||
}
|
||||
|
||||
// SetOmitAuthTimeClaim turns on/off the omitting of an auth_time claim from
|
||||
// id_tokens from the /token endpoint. If set to true, the test provider will
|
||||
// not include the auth_time claim in issued id_tokens from the /token endpoint.
|
||||
func (p *TestProvider) SetOmitAuthTimeClaim(omitAuthTime bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.omitAuthTimeClaim = omitAuthTime
|
||||
}
|
||||
|
||||
// SetOmitIDTokens turns on/off the omitting of id_tokens from the /token
|
||||
// endpoint. If set to true, the test provider will not issue id_tokens
|
||||
// from the /token endpoint.
|
||||
func (p *TestProvider) SetOmitIDTokens(omitIDTokens bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.omitIDToken = omitIDTokens
|
||||
}
|
||||
|
||||
// SetOmitAccessTokens turns on/off the omitting of access_tokens from the /token
|
||||
// endpoint. If set to true, the test provider will not issue
|
||||
// access_tokens from the /token endpoint.
|
||||
func (p *TestProvider) SetOmitAccessTokens(omitAccessTokens bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.omitAccessToken = omitAccessTokens
|
||||
}
|
||||
|
||||
// SetDisableUserInfo makes the userinfo endpoint return 404 and omits it from the
|
||||
// discovery config.
|
||||
func (p *TestProvider) SetDisableUserInfo(disable bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.disableUserInfo = disable
|
||||
}
|
||||
|
||||
// SetDisableJWKs makes the JWKs endpoint return 404
|
||||
func (p *TestProvider) SetDisableJWKs(disable bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.disableJWKs = disable
|
||||
}
|
||||
|
||||
// SetInvalidJWKS makes the JWKs endpoint return an invalid response
|
||||
func (p *TestProvider) SetInvalidJWKS(invalid bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.invalidJWKs = invalid
|
||||
}
|
||||
|
||||
// SetExpectedState sets the value for the state parameter returned from
|
||||
// /authorize.
|
||||
func (p *TestProvider) SetExpectedState(s string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.expectedState = s
|
||||
}
|
||||
|
||||
// SetDisableToken makes the /token endpoint return 401
|
||||
func (p *TestProvider) SetDisableToken(disable bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.disableToken = disable
|
||||
}
|
||||
|
||||
// SetDisableImplicit makes implicit flow responses return 401
|
||||
func (p *TestProvider) SetDisableImplicit(disable bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.disableImplicit = disable
|
||||
}
|
||||
|
||||
// SetPKCEVerifier sets the PKCE oidc.CodeVerifier
|
||||
func (p *TestProvider) SetPKCEVerifier(verifier CodeVerifier) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.t.Helper()
|
||||
require.NotNil(p.t, verifier)
|
||||
p.pkceVerifier = verifier
|
||||
}
|
||||
|
||||
// PKCEVerifier returns the PKCE oidc.CodeVerifier
|
||||
func (p *TestProvider) PKCEVerifier() CodeVerifier {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.pkceVerifier
|
||||
}
|
||||
|
||||
// SetUserInfoReply sets the UserInfo endpoint response.
|
||||
func (p *TestProvider) SetUserInfoReply(resp interface{}) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.replyUserinfo = resp
|
||||
}
|
||||
|
||||
// UserInfoReply returns the current UserInfo endpoint response.
|
||||
func (p *TestProvider) UserInfoReply() interface{} {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.replyUserinfo
|
||||
}
|
||||
|
||||
// Addr returns the current base URL for the test provider's running webserver,
|
||||
// which can be used as an OIDC issuer for discovery and is also used for the
|
||||
// iss claim when issuing JWTs.
|
||||
func (p *TestProvider) Addr() string { return p.httpServer.URL }
|
||||
|
||||
// CACert returns the pem-encoded CA certificate used by the test provider's
|
||||
// HTTPS server.
|
||||
func (p *TestProvider) CACert() string { return p.caCert }
|
||||
|
||||
// SigningKeys returns the test provider's keys used to sign JWTs, its Alg and
|
||||
// Key ID.
|
||||
func (p *TestProvider) SigningKeys() (crypto.PrivateKey, crypto.PublicKey, Alg, string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.privKey, p.pubKey, p.alg, p.keyID
|
||||
}
|
||||
|
||||
// SetSigningKeys sets the test provider's keys and alg used to sign JWTs.
|
||||
func (p *TestProvider) SetSigningKeys(privKey crypto.PrivateKey, pubKey crypto.PublicKey, alg Alg, KeyID string) {
|
||||
const op = "TestProvider.SetSigningKeys"
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
require.NotNilf(privKey, "%s: private key is nil", op)
|
||||
require.NotNilf(pubKey, "%s: public key is nil", op)
|
||||
require.NotEmptyf(alg, "%s: alg is empty", op)
|
||||
require.NotEmptyf(KeyID, "%s: key id is empty", op)
|
||||
p.privKey = privKey
|
||||
p.pubKey = pubKey
|
||||
p.alg = alg
|
||||
p.keyID = KeyID
|
||||
p.jwks = &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: p.pubKey,
|
||||
KeyID: p.keyID,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TestProvider) writeJSON(w http.ResponseWriter, out interface{}) error {
|
||||
const op = "TestProvider.writeJSON"
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
require.NotNilf(w, "%s: http.ResponseWriter is nil", op)
|
||||
enc := json.NewEncoder(w)
|
||||
return enc.Encode(out)
|
||||
}
|
||||
|
||||
// writeImplicitResponse will write the required form data response for an
|
||||
// implicit flow response to the OIDC authorize endpoint
|
||||
func (p *TestProvider) writeImplicitResponse(w http.ResponseWriter, state, redirectURL string) error {
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
require.NotNil(w, "writeImplicitResponse: http.ResponseWriter is nil")
|
||||
|
||||
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
const respForm = `
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head><title>Submit This Form</title></head>
|
||||
<body onload="javascript:document.forms[0].submit()">
|
||||
<form method="post" action="%s">
|
||||
<input type="hidden" name="state" id="state" value="%s"/>
|
||||
%s
|
||||
</form>
|
||||
</body>
|
||||
</html>`
|
||||
const tokenField = `<input type="hidden" name="%s" id="%s" value="%s"/>
|
||||
`
|
||||
accessToken := p.issueSignedJWT()
|
||||
idToken := p.issueSignedJWT(withTestAtHash(accessToken))
|
||||
var respTokens strings.Builder
|
||||
if !p.omitAccessToken {
|
||||
respTokens.WriteString(fmt.Sprintf(tokenField, "access_token", "access_token", accessToken))
|
||||
}
|
||||
if !p.omitIDToken {
|
||||
respTokens.WriteString(fmt.Sprintf(tokenField, "id_token", "id_token", idToken))
|
||||
}
|
||||
if _, err := w.Write([]byte(fmt.Sprintf(respForm, redirectURL, state, respTokens.String()))); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TestProvider) issueSignedJWT(opt ...Option) string {
|
||||
opts := getTestProviderOpts(opt...)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": p.replySubject,
|
||||
"iss": p.Addr(),
|
||||
"nbf": float64(p.nowFunc().Add(-p.replyExpiry).Unix()),
|
||||
"exp": float64(p.nowFunc().Add(p.replyExpiry).Unix()),
|
||||
"auth_time": float64(p.nowFunc().Unix()),
|
||||
"iat": float64(p.nowFunc().Unix()),
|
||||
"aud": []string{p.clientID},
|
||||
}
|
||||
if len(p.customAudiences) != 0 {
|
||||
claims["aud"] = append(claims["aud"].([]string), p.customAudiences...)
|
||||
}
|
||||
if p.expectedAuthNonce != "" {
|
||||
p.customClaims["nonce"] = p.expectedAuthNonce
|
||||
}
|
||||
for k, v := range p.customClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
if opts.withAtHashOf != "" {
|
||||
claims["at_hash"] = p.testHash(opts.withAtHashOf)
|
||||
}
|
||||
if opts.withCHashOf != "" {
|
||||
claims["c_hash"] = p.testHash(opts.withCHashOf)
|
||||
}
|
||||
return TestSignJWT(p.t, p.privKey, string(p.alg), claims, nil)
|
||||
}
|
||||
|
||||
// testHash will generate a hash using a signature algorithm. It is used to
|
||||
// test at_hash and c_hash id_token claims. This is helpful internally, but
|
||||
// intentionally not exported.
|
||||
func (p *TestProvider) testHash(data string) string {
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
require.NotEmptyf(data, "testHash: data to hash is empty")
|
||||
var h hash.Hash
|
||||
switch p.alg {
|
||||
case RS256, ES256, PS256:
|
||||
h = sha256.New()
|
||||
case RS384, ES384, PS384:
|
||||
h = sha512.New384()
|
||||
case RS512, ES512, PS512:
|
||||
h = sha512.New()
|
||||
case EdDSA:
|
||||
return "EdDSA-hash"
|
||||
default:
|
||||
require.FailNowf("", "testHash: unsupported signing algorithm %s", string(p.alg))
|
||||
}
|
||||
require.NotNil(h)
|
||||
_, _ = h.Write([]byte(data)) // hash documents that Write will never return an error
|
||||
sum := h.Sum(nil)[:h.Size()/2]
|
||||
actual := base64.RawURLEncoding.EncodeToString(sum)
|
||||
return actual
|
||||
}
|
||||
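// expectedAtHash is an illustrative sketch only (not part of the original
// change). It shows the verification-side counterpart of testHash for the
// 256-bit algorithms (ES256/RS256/PS256): a relying party recomputes at_hash
// as the base64url-encoded left half of the SHA-256 digest of the access_token
// and compares it to the id_token claim.
func expectedAtHash(accessToken string) string {
	sum := sha256.Sum256([]byte(accessToken))
	return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2])
}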
|
||||
// writeAuthErrorResponse writes a standard OIDC authentication error response.
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
||||
func (p *TestProvider) writeAuthErrorResponse(w http.ResponseWriter, req *http.Request, redirectURL, state, errorCode, errorMessage string) {
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
require.NotNil(w, "writeAuthErrorResponse: http.ResponseWriter is nil")
|
||||
require.NotNil(req, "writeAuthErrorResponse: http.Request is nil")
|
||||
require.NotEmpty(errorCode, "writeAuthErrorResponse: errorCode is empty")
|
||||
|
||||
// state and error are required error response parameters
|
||||
redirectURI := redirectURL +
|
||||
"?state=" + url.QueryEscape(state) +
|
||||
"&error=" + url.QueryEscape(errorCode)
|
||||
|
||||
if errorMessage != "" {
|
||||
// add optional error response parameter
|
||||
redirectURI += "&error_description=" + url.QueryEscape(errorMessage)
|
||||
}
|
||||
|
||||
http.Redirect(w, req, redirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// writeTokenErrorResponse writes a standard OIDC token error response.
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#TokenErrorResponse
|
||||
func (p *TestProvider) writeTokenErrorResponse(w http.ResponseWriter, statusCode int, errorCode, errorMessage string) error {
|
||||
require := require.New(p.t)
|
||||
require.NotNil(w, "writeTokenErrorResponse: http.ResponseWriter is nil")
|
||||
require.NotEmpty(errorCode, "writeTokenErrorResponse: errorCode is empty")
|
||||
require.NotEmpty(statusCode, "writeTokenErrorResponse: statusCode is empty")
|
||||
|
||||
body := struct {
|
||||
Code string `json:"error"`
|
||||
Desc string `json:"error_description,omitempty"`
|
||||
}{
|
||||
Code: errorCode,
|
||||
Desc: errorMessage,
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
return p.writeJSON(w, &body)
|
||||
}
|
||||
|
||||
// ServeHTTP implements the test provider's http.Handler.
|
||||
func (p *TestProvider) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// define all the endpoints supported
|
||||
const (
|
||||
openidConfiguration = "/.well-known/openid-configuration"
|
||||
authorize = "/authorize"
|
||||
token = "/token"
|
||||
userInfo = "/userinfo"
|
||||
wellKnownJwks = "/.well-known/jwks.json"
|
||||
)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.t.Helper()
|
||||
require := require.New(p.t)
|
||||
require.NotNil(w, "ServeHTTP: http.ResponseWriter is nil")
|
||||
require.NotNil(req, "ServeHTTP: http.Request is nil")
|
||||
|
||||
// set a default Content-Type which will be overridden as needed.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch req.URL.Path {
|
||||
case openidConfiguration:
|
||||
// OIDC Discovery endpoint request
|
||||
// See: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
reply := struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
JWKSURI string `json:"jwks_uri"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||
}{
|
||||
Issuer: p.Addr(),
|
||||
AuthEndpoint: p.Addr() + authorize,
|
||||
TokenEndpoint: p.Addr() + token,
|
||||
JWKSURI: p.Addr() + wellKnownJwks,
|
||||
UserinfoEndpoint: p.Addr() + userInfo,
|
||||
}
|
||||
if p.disableUserInfo {
|
||||
reply.UserinfoEndpoint = ""
|
||||
}
|
||||
|
||||
err := p.writeJSON(w, &reply)
|
||||
require.NoErrorf(err, "%s: internal error: %v", openidConfiguration, err)
|
||||
|
||||
return
|
||||
case authorize:
|
||||
// Supports both the authorization code and implicit flows
|
||||
// See: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
|
||||
if !strutils.StrListContains([]string{"POST", "GET"}, req.Method) {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
err := req.ParseForm()
|
||||
require.NoErrorf(err, "%s: internal error: %v", authorize, err)
|
||||
|
||||
respType := req.FormValue("response_type")
|
||||
scopes := req.Form["scope"]
|
||||
state := req.FormValue("state")
|
||||
redirectURI := req.FormValue("redirect_uri")
|
||||
respMode := req.FormValue("response_mode")
|
||||
|
||||
if respType != "code" && !strings.Contains(respType, "id_token") {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "unsupported_response_type", "")
|
||||
return
|
||||
}
|
||||
if !strutils.StrListContains(scopes, "openid") {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "invalid_scope", "")
|
||||
return
|
||||
}
|
||||
|
||||
if p.expectedAuthCode == "" {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "access_denied", "")
|
||||
return
|
||||
}
|
||||
|
||||
nonce := req.FormValue("nonce")
|
||||
if p.expectedAuthNonce != "" && p.expectedAuthNonce != nonce {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "access_denied", "")
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "invalid_request", "missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
if redirectURI == "" {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "invalid_request", "missing redirect_uri parameter")
|
||||
return
|
||||
}
|
||||
|
||||
var s string
|
||||
switch {
|
||||
case p.expectedState != "":
|
||||
s = p.expectedState
|
||||
default:
|
||||
s = state
|
||||
}
|
||||
|
||||
if strings.Contains(respType, "id_token") {
|
||||
if respMode != "form_post" {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "unsupported_response_mode", "must be form_post")
return
|
||||
}
|
||||
if p.disableImplicit {
|
||||
p.writeAuthErrorResponse(w, req, redirectURI, state, "access_denied", "")
return
|
||||
}
|
||||
err := p.writeImplicitResponse(w, s, redirectURI)
|
||||
require.NoErrorf(err, "%s: internal error: %v", token, err)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI += "?state=" + url.QueryEscape(s) +
|
||||
"&code=" + url.QueryEscape(p.expectedAuthCode)
|
||||
|
||||
http.Redirect(w, req, redirectURI, http.StatusFound)
|
||||
|
||||
return
|
||||
|
||||
case wellKnownJwks:
|
||||
if p.disableJWKs {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if p.invalidJWKs {
|
||||
_, err := w.Write([]byte("It's not a keyset!"))
|
||||
require.NoErrorf(err, "%s: internal error: %v", wellKnownJwks, err)
|
||||
return
|
||||
}
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
err := p.writeJSON(w, p.jwks)
|
||||
require.NoErrorf(err, "%s: internal error: %v", wellKnownJwks, err)
|
||||
return
|
||||
case token:
|
||||
if p.disableToken {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if req.Method != "POST" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case req.FormValue("grant_type") != "authorization_code":
|
||||
_ = p.writeTokenErrorResponse(w, http.StatusBadRequest, "invalid_request", "bad grant_type")
|
||||
return
|
||||
case !strutils.StrListContains(p.allowedRedirectURIs, req.FormValue("redirect_uri")):
|
||||
_ = p.writeTokenErrorResponse(w, http.StatusBadRequest, "invalid_request", "redirect_uri is not allowed")
|
||||
return
|
||||
case req.FormValue("code") != p.expectedAuthCode:
|
||||
_ = p.writeTokenErrorResponse(w, http.StatusUnauthorized, "invalid_grant", "unexpected auth code")
|
||||
return
|
||||
case req.FormValue("code_verifier") != "" && req.FormValue("code_verifier") != p.pkceVerifier.Verifier():
|
||||
_ = p.writeTokenErrorResponse(w, http.StatusUnauthorized, "invalid_verifier", "unexpected verifier")
|
||||
return
|
||||
}
|
||||
|
||||
accessToken := p.issueSignedJWT()
|
||||
idToken := p.issueSignedJWT(withTestAtHash(accessToken), withTestCHash(p.expectedAuthCode))
|
||||
reply := struct {
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
}
|
||||
if p.omitIDToken {
|
||||
reply.IDToken = ""
|
||||
}
|
||||
if p.omitAccessToken {
|
||||
reply.AccessToken = ""
|
||||
}
|
||||
|
||||
if err := p.writeJSON(w, &reply); err != nil {
|
||||
require.NoErrorf(err, "%s: internal error: %v", token, err)
|
||||
return
|
||||
}
|
||||
return
|
||||
case userInfo:
|
||||
if p.disableUserInfo {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if req.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.writeJSON(w, p.replyUserinfo); err != nil {
|
||||
require.NoErrorf(err, "%s: internal error: %v", userInfo, err)
|
||||
return
|
||||
}
|
||||
return
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// httptestNewUnstartedServerWithPort is roughly the same as
|
||||
// httptest.NewUnstartedServer() but allows the caller to explicitly choose the
|
||||
// port if desired.
|
||||
func httptestNewUnstartedServerWithPort(t *testing.T, handler http.Handler, port int) *httptest.Server {
|
||||
t.Helper()
|
||||
require := require.New(t)
|
||||
require.NotNil(handler)
|
||||
if port == 0 {
|
||||
return httptest.NewUnstartedServer(handler)
|
||||
}
|
||||
addr := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
|
||||
l, err := net.Listen("tcp", addr)
|
||||
require.NoError(err)
|
||||
|
||||
return &httptest.Server{
|
||||
Listener: l,
|
||||
Config: &http.Server{Handler: handler},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,184 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Token interface represents an OIDC id_token, as well as an OAuth2
|
||||
// access_token and refresh_token (including the access_token expiry).
|
||||
type Token interface {
|
||||
// RefreshToken returns the Token's refresh_token.
|
||||
RefreshToken() RefreshToken
|
||||
|
||||
// AccessToken returns the Token's access_token.
|
||||
AccessToken() AccessToken
|
||||
|
||||
// IDToken returns the Token's id_token.
|
||||
IDToken() IDToken
|
||||
|
||||
// Expiry returns the expiration of the access_token.
|
||||
Expiry() time.Time
|
||||
|
||||
// Valid will ensure that the access_token is not empty or expired.
|
||||
Valid() bool
|
||||
|
||||
// IsExpired returns true if the token has expired. Implementations should
|
||||
// support a time skew (perhaps TokenExpirySkew) when checking expiration.
|
||||
IsExpired() bool
|
||||
}
|
||||
|
||||
// StaticTokenSource is a single function interface that defines a method to
|
||||
// create an oauth2.TokenSource that always returns the same token, because the
|
||||
// token is never refreshed. A TokenSource can be used when calling a
|
||||
// provider's UserInfo(), among other things.
|
||||
type StaticTokenSource interface {
|
||||
StaticTokenSource() oauth2.TokenSource
|
||||
}
|
||||
|
||||
// Tk satisfies the Token interface and represents an OAuth2 access_token and
|
||||
// refresh_token (including the access_token expiry), as well as an OIDC
|
||||
// id_token. The access_token and refresh_token may be empty.
|
||||
type Tk struct {
|
||||
idToken IDToken
|
||||
underlying *oauth2.Token
|
||||
|
||||
// nowFunc is an optional function that returns the current time
|
||||
nowFunc func() time.Time
|
||||
}
|
||||
|
||||
// ensure that Tk implements the Token interface
|
||||
var _ Token = (*Tk)(nil)
|
||||
|
||||
// NewToken creates a new Token (*Tk). The IDToken is required and the
|
||||
// *oauth2.Token may be nil. Supports the WithNow option (with a default to
|
||||
// time.Now).
|
||||
func NewToken(i IDToken, t *oauth2.Token, opt ...Option) (*Tk, error) {
|
||||
// since golang.org/x/oauth2 is the de facto standard package, we're not going to worry about it leaking
|
||||
// into our abstraction in this factory
|
||||
const op = "NewToken"
|
||||
if i == "" {
|
||||
return nil, fmt.Errorf("%s: id_token is empty: %w", op, ErrInvalidParameter)
|
||||
}
|
||||
opts := getTokenOpts(opt...)
|
||||
return &Tk{
|
||||
idToken: i,
|
||||
underlying: t,
|
||||
nowFunc: opts.withNowFunc,
|
||||
}, nil
|
||||
}
|
||||
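// newTokenSketch is an illustrative sketch only (not part of the original
// change). It wraps a raw id_token and an oauth2.Token into a Token; all of
// the literal values below are placeholders.
func newTokenSketch(rawIDToken string) (Token, error) {
	t, err := NewToken(
		IDToken(rawIDToken),
		&oauth2.Token{
			AccessToken:  "an-access-token",
			RefreshToken: "a-refresh-token",
			Expiry:       time.Now().Add(5 * time.Minute),
		},
	)
	if err != nil {
		return nil, err
	}
	_ = t.Valid() // true while the access_token is present and unexpired
	return t, nil
}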
|
||||
// AccessToken implements the Token.AccessToken() interface function and may
|
||||
// return an empty AccessToken.
|
||||
func (t *Tk) AccessToken() AccessToken {
|
||||
if t.underlying == nil {
|
||||
return ""
|
||||
}
|
||||
return AccessToken(t.underlying.AccessToken)
|
||||
}
|
||||
|
||||
// RefreshToken implements the Token.RefreshToken() interface function and may
|
||||
// return an empty RefreshToken.
|
||||
func (t *Tk) RefreshToken() RefreshToken {
|
||||
if t.underlying == nil {
|
||||
return ""
|
||||
}
|
||||
return RefreshToken(t.underlying.RefreshToken)
|
||||
}
|
||||
|
||||
// IDToken implements the IDToken.IDToken() interface function.
|
||||
func (t *Tk) IDToken() IDToken { return IDToken(t.idToken) }
|
||||
|
||||
// TokenExpirySkew defines a time skew when checking a Token's expiration.
|
||||
const TokenExpirySkew = 10 * time.Second
|
||||
|
||||
// Expiry implements the Token.Expiry() interface function and may return a
|
||||
// "zero" time if the token's AccessToken is empty.
|
||||
func (t *Tk) Expiry() time.Time {
|
||||
if t.underlying == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return t.underlying.Expiry
|
||||
}
|
||||
|
||||
// StaticTokenSource returns a TokenSource that always returns the same token,
|
||||
// because the provided token t is never refreshed. It will return nil if
|
||||
// t.AccessToken() is empty.
|
||||
func (t *Tk) StaticTokenSource() oauth2.TokenSource {
|
||||
if t.underlying == nil {
|
||||
return nil
|
||||
}
|
||||
return oauth2.StaticTokenSource(t.underlying)
|
||||
}
|
||||
|
||||
// IsExpired will return true if the token's access token is expired or empty.
|
||||
func (t *Tk) IsExpired() bool {
|
||||
if t.underlying == nil {
|
||||
return true
|
||||
}
|
||||
if t.underlying.Expiry.IsZero() {
|
||||
return false
|
||||
}
|
||||
return t.underlying.Expiry.Round(0).Before(t.now().Add(TokenExpirySkew))
|
||||
}
|
||||
|
||||
// Valid will ensure that the access_token is not empty or expired. It will
|
||||
// return false if t.AccessToken() is empty.
|
||||
func (t *Tk) Valid() bool {
|
||||
if t == nil || t.underlying == nil {
|
||||
return false
|
||||
}
|
||||
if t.underlying.AccessToken == "" {
|
||||
return false
|
||||
}
|
||||
return !t.IsExpired()
|
||||
}
|
||||
|
||||
// now returns the current time using the optional nowFunc.
|
||||
func (t *Tk) now() time.Time {
|
||||
if t.nowFunc != nil {
|
||||
return t.nowFunc()
|
||||
}
|
||||
return time.Now() // fallback to this default
|
||||
}
|
||||
|
||||
// tokenOptions is the set of available options for Token functions
|
||||
type tokenOptions struct {
|
||||
withNowFunc func() time.Time
|
||||
}
|
||||
|
||||
// tokenDefaults is a handy way to get the defaults at runtime and during unit
|
||||
// tests.
|
||||
func tokenDefaults() tokenOptions {
|
||||
return tokenOptions{}
|
||||
}
|
||||
|
||||
// getTokenOpts gets the token defaults and applies the opt overrides passed
|
||||
// in
|
||||
func getTokenOpts(opt ...Option) tokenOptions {
|
||||
opts := tokenDefaults()
|
||||
ApplyOpts(&opts, opt...)
|
||||
return opts
|
||||
}
|
||||
|
||||
// UnmarshalClaims will retrieve the claims from the provided raw JWT token.
|
||||
func UnmarshalClaims(rawToken string, claims interface{}) error {
|
||||
const op = "UnmarshalClaims"
|
||||
parts := strings.Split(string(rawToken), ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("%s: malformed jwt, expected 3 parts got %d: %w", op, len(parts), ErrInvalidParameter)
|
||||
}
|
||||
raw, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: malformed jwt claims: %w", op, err)
|
||||
}
|
||||
if err := json.Unmarshal(raw, claims); err != nil {
|
||||
return fmt.Errorf("%s: unable to unmarshal jwt claims JSON: %w", op, err)
|
||||
}
|
||||
return nil
|
||||
}
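// unmarshalClaimsSketch is an illustrative sketch only (not part of the
// original change). It pulls two registered claims out of a raw id_token with
// UnmarshalClaims; the struct is defined just for the example and assumes the
// "aud" claim is a JSON array.
func unmarshalClaimsSketch(rawIDToken string) (string, []string, error) {
	var claims struct {
		Iss string   `json:"iss"`
		Aud []string `json:"aud"`
	}
	if err := UnmarshalClaims(rawIDToken, &claims); err != nil {
		return "", nil, fmt.Errorf("unmarshalClaimsSketch: %w", err)
	}
	return claims.Iss, claims.Aud, nil
}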
|
|
@ -1,5 +1,8 @@
|
|||
## Next
|
||||
|
||||
## 0.8.1
|
||||
### December 14th, 2020
|
||||
|
||||
BUG FIXES:
|
||||
|
||||
* Fixes `bound_claims` validation for provider-specific group and user info fetching [[GH-149](https://github.com/hashicorp/vault-plugin-auth-jwt/pull/149)]
|
||||
|
|
|
@ -3,11 +3,11 @@ package jwtauth
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/cap/jwt"
|
||||
"github.com/hashicorp/cap/oidc"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/patrickmn/go-cache"
|
||||
|
@ -32,9 +32,9 @@ type jwtAuthBackend struct {
|
|||
|
||||
l sync.RWMutex
|
||||
provider *oidc.Provider
|
||||
keySet oidc.KeySet
|
||||
validator *jwt.Validator
|
||||
cachedConfig *jwtConfig
|
||||
oidcStates *cache.Cache
|
||||
oidcRequests *cache.Cache
|
||||
|
||||
providerCtx context.Context
|
||||
providerCtxCancel context.CancelFunc
|
||||
|
@ -43,7 +43,7 @@ type jwtAuthBackend struct {
|
|||
func backend() *jwtAuthBackend {
|
||||
b := new(jwtAuthBackend)
|
||||
b.providerCtx, b.providerCtxCancel = context.WithCancel(context.Background())
|
||||
b.oidcStates = cache.New(oidcStateTimeout, 1*time.Minute)
|
||||
b.oidcRequests = cache.New(oidcRequestTimeout, oidcRequestCleanupInterval)
|
||||
|
||||
b.Backend = &framework.Backend{
|
||||
AuthRenew: b.pathLoginRenew,
|
||||
|
@ -86,6 +86,9 @@ func (b *jwtAuthBackend) cleanup(_ context.Context) {
|
|||
if b.providerCtxCancel != nil {
|
||||
b.providerCtxCancel()
|
||||
}
|
||||
if b.provider != nil {
|
||||
b.provider.Done()
|
||||
}
|
||||
b.l.Unlock()
|
||||
}
|
||||
|
||||
|
@ -98,23 +101,18 @@ func (b *jwtAuthBackend) invalidate(ctx context.Context, key string) {
|
|||
|
||||
func (b *jwtAuthBackend) reset() {
|
||||
b.l.Lock()
|
||||
if b.provider != nil {
|
||||
b.provider.Done()
|
||||
}
|
||||
b.provider = nil
|
||||
b.cachedConfig = nil
|
||||
b.validator = nil
|
||||
b.l.Unlock()
|
||||
}
|
||||
|
||||
func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error) {
|
||||
b.l.RLock()
|
||||
unlockFunc := b.l.RUnlock
|
||||
defer func() { unlockFunc() }()
|
||||
|
||||
if b.provider != nil {
|
||||
return b.provider, nil
|
||||
}
|
||||
|
||||
b.l.RUnlock()
|
||||
b.l.Lock()
|
||||
unlockFunc = b.l.Unlock
|
||||
defer b.l.Unlock()
|
||||
|
||||
if b.provider != nil {
|
||||
return b.provider, nil
|
||||
|
@ -129,27 +127,42 @@ func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error)
|
|||
return provider, nil
|
||||
}
|
||||
|
||||
// getKeySet returns a new JWKS KeySet based on the provided config.
|
||||
func (b *jwtAuthBackend) getKeySet(config *jwtConfig) (oidc.KeySet, error) {
|
||||
// jwtValidator returns a new JWT validator based on the provided config.
|
||||
func (b *jwtAuthBackend) jwtValidator(config *jwtConfig) (*jwt.Validator, error) {
|
||||
b.l.Lock()
|
||||
defer b.l.Unlock()
|
||||
|
||||
if b.keySet != nil {
|
||||
return b.keySet, nil
|
||||
if b.validator != nil {
|
||||
return b.validator, nil
|
||||
}
|
||||
|
||||
if config.JWKSURL == "" {
|
||||
return nil, errors.New("keyset error: jwks_url not configured")
|
||||
var err error
|
||||
var keySet jwt.KeySet
|
||||
|
||||
// Configure the key set for the validator
|
||||
switch config.authType() {
|
||||
case JWKS:
|
||||
keySet, err = jwt.NewJSONWebKeySet(b.providerCtx, config.JWKSURL, config.JWKSCAPEM)
|
||||
case StaticKeys:
|
||||
keySet, err = jwt.NewStaticKeySet(config.ParsedJWTPubKeys)
|
||||
case OIDCDiscovery:
|
||||
keySet, err = jwt.NewOIDCDiscoveryKeySet(b.providerCtx, config.OIDCDiscoveryURL, config.OIDCDiscoveryCAPEM)
|
||||
default:
|
||||
return nil, errors.New("unsupported config type")
|
||||
}
|
||||
|
||||
ctx, err := b.createCAContext(b.providerCtx, config.JWKSCAPEM)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error parsing jwks_ca_pem: {{err}}", err)
|
||||
return nil, fmt.Errorf("keyset configuration error: %w", err)
|
||||
}
|
||||
|
||||
b.keySet = oidc.NewRemoteKeySet(ctx, config.JWKSURL)
|
||||
validator, err := jwt.NewValidator(keySet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JWT validator configuration error: %w", err)
|
||||
}
|
||||
|
||||
return b.keySet, nil
|
||||
b.validator = validator
|
||||
|
||||
return b.validator, nil
|
||||
}
|
||||
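// Illustrative sketch (not part of this commit): once jwtValidator has built a
// *jwt.Validator, login handling can validate a raw token against a set of
// expectations and get the full claim set back. The jwt.Expected shape and the
// Validate signature below reflect the hashicorp/cap/jwt API as understood
// here; config.BoundIssuer and the toAlg helper are assumed from elsewhere in
// this plugin.
//
//	expected := jwt.Expected{
//		Issuer:            config.BoundIssuer,
//		SigningAlgorithms: toAlg(config.JWTSupportedAlgs),
//	}
//	allClaims, err := validator.Validate(ctx, token, expected)
//	if err != nil {
//		return logical.ErrorResponse("error validating token: %s", err), nil
//	}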
|
||||
const (
|
||||
|
|
|
@ -3,13 +3,12 @@ module github.com/hashicorp/vault-plugin-auth-jwt
|
|||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible
|
||||
github.com/go-test/deep v1.0.2-0.20181118220953-042da051cf31
|
||||
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7
|
||||
github.com/hashicorp/errwrap v1.0.0
|
||||
github.com/hashicorp/go-cleanhttp v0.5.1
|
||||
github.com/hashicorp/go-hclog v0.12.0
|
||||
github.com/hashicorp/go-sockaddr v1.0.2
|
||||
github.com/hashicorp/go-uuid v1.0.2
|
||||
github.com/hashicorp/go-version v1.2.0 // indirect
|
||||
github.com/hashicorp/vault/api v1.0.5-0.20200215224050-f6547fa8e820
|
||||
github.com/hashicorp/vault/sdk v0.1.14-0.20200215224050-f6547fa8e820
|
||||
|
@ -17,11 +16,10 @@ require (
|
|||
github.com/mitchellh/mapstructure v1.1.2
|
||||
github.com/mitchellh/pointerstructure v1.0.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/ryanuber/go-glob v1.0.0
|
||||
github.com/stretchr/testify v1.4.0
|
||||
github.com/stretchr/testify v1.6.1
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
|
||||
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a
|
||||
google.golang.org/api v0.29.0
|
||||
gopkg.in/square/go-jose.v2 v2.4.1
|
||||
gopkg.in/square/go-jose.v2 v2.5.1
|
||||
)
|
||||
|
|
|
@ -40,8 +40,8 @@ github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6D
|
|||
github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
|
||||
github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk=
|
||||
github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
@ -87,6 +87,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
|||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M=
|
||||
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
|
||||
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
|
||||
|
@ -97,6 +99,8 @@ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm4
|
|||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7 h1:6OHvaQs9ys66bR1yqHuoI231JAoalgGgxeqzQuVOfX0=
|
||||
github.com/hashicorp/cap v0.0.0-20210204173447-5fcddadbf7c7/go.mod h1:tIk5rB1nihW5+9bZjI7xlc8LGw8FYfiFMKOpHPbWgug=
|
||||
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
|
||||
|
@ -184,8 +188,8 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
|
|||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac h1:jWKYCNlX4J5s8M0nHYkh7Y7c9gRVDEb3mq51j5J0F5M=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac/go.mod h1:hoLfEwdY11HjRfKFH6KqnPsfxlo3BP6bJehpDv8t6sQ=
|
||||
github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
|
@ -201,7 +205,10 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0
|
|||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
||||
github.com/yhat/scrape v0.0.0-20161128144610-24b7890b0945/go.mod h1:4vRFPPNYllgCacoj+0FoKOjTW68rUhEfqPLiEJaK2w8=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||
|
@ -216,6 +223,8 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
|
|||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
|
@ -266,6 +275,8 @@ golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLL
|
|||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k=
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
@ -312,6 +323,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
|||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
@ -410,10 +423,12 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
|||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/square/go-jose.v2 v2.4.1 h1:H0TmLt7/KmzlrDOpa1F+zr0Tk90PbJYBfsVUmRLrf9Y=
|
||||
gopkg.in/square/go-jose.v2 v2.4.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w=
|
||||
gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
|
|
@ -2,14 +2,15 @@ package jwtauth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/cap/jwt"
|
||||
"github.com/hashicorp/cap/oidc"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
|
@ -17,7 +18,6 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"golang.org/x/oauth2"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -236,22 +236,25 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
|
|||
return logical.ErrorResponse("both 'oidc_client_id' and 'oidc_client_secret' must be set for OIDC"), nil
|
||||
|
||||
case config.OIDCDiscoveryURL != "":
|
||||
_, err := b.createProvider(config)
|
||||
var err error
|
||||
if config.OIDCClientID != "" && config.OIDCClientSecret != "" {
|
||||
_, err = b.createProvider(config)
|
||||
} else {
|
||||
_, err = jwt.NewOIDCDiscoveryKeySet(ctx, config.OIDCDiscoveryURL, config.OIDCDiscoveryCAPEM)
|
||||
}
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(errwrap.Wrapf("error checking oidc discovery URL: {{err}}", err).Error()), nil
|
||||
return logical.ErrorResponse("error checking oidc discovery URL: %s", err.Error()), nil
|
||||
}
|
||||
|
||||
case config.OIDCClientID != "" && config.OIDCDiscoveryURL == "":
|
||||
return logical.ErrorResponse("'oidc_discovery_url' must be set for OIDC"), nil
|
||||
|
||||
case config.JWKSURL != "":
|
||||
ctx, err := b.createCAContext(context.Background(), config.JWKSCAPEM)
|
||||
keyset, err := jwt.NewJSONWebKeySet(ctx, config.JWKSURL, config.JWKSCAPEM)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(errwrap.Wrapf("error checking jwks_ca_pem: {{err}}", err).Error()), nil
|
||||
}
|
||||
|
||||
keyset := oidc.NewRemoteKeySet(ctx, config.JWKSURL)
|
||||
|
||||
// Try to verify a correctly formatted JWT. The signature will fail to match, but other
|
||||
// errors with fetching the remote keyset should be reported.
|
||||
testJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"
|
||||
|
@ -278,12 +281,8 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
|
|||
// NOTE: the OIDC lib states that if nothing is passed into its config, it
|
||||
// defaults to "RS256". So in the case of a zero value here it won't
|
||||
// default to e.g. "none".
|
||||
for _, a := range config.JWTSupportedAlgs {
|
||||
switch a {
|
||||
case oidc.RS256, oidc.RS384, oidc.RS512, oidc.ES256, oidc.ES384, oidc.ES512, oidc.PS256, oidc.PS384, oidc.PS512, string(jose.EdDSA):
|
||||
default:
|
||||
return logical.ErrorResponse(fmt.Sprintf("Invalid supported algorithm: %s", a)), nil
|
||||
}
|
||||
if err := jwt.SupportedSigningAlgorithm(toAlg(config.JWTSupportedAlgs)...); err != nil {
|
||||
return logical.ErrorResponse("invalid jwt_supported_algs: %s", err), nil
|
||||
}
|
||||
|
||||
// Validate response_types
|
||||
|
@ -321,12 +320,23 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
|
|||
}
|
||||
|
||||
func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) {
|
||||
oidcCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM)
|
||||
supportedSigAlgs := make([]oidc.Alg, len(config.JWTSupportedAlgs))
|
||||
for i, a := range config.JWTSupportedAlgs {
|
||||
supportedSigAlgs[i] = oidc.Alg(a)
|
||||
}
|
||||
|
||||
if len(supportedSigAlgs) == 0 {
|
||||
supportedSigAlgs = []oidc.Alg{oidc.RS256}
|
||||
}
|
||||
|
||||
c, err := oidc.NewConfig(config.OIDCDiscoveryURL, config.OIDCClientID,
|
||||
oidc.ClientSecret(config.OIDCClientSecret), supportedSigAlgs, []string{},
|
||||
oidc.WithProviderCA(config.OIDCDiscoveryCAPEM))
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error creating provider: {{err}}", err)
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL)
|
||||
provider, err := oidc.NewProvider(c)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err)
|
||||
}
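Outside of the backend, the same cap/oidc provider construction looks roughly like the following; the issuer and client values are placeholders, and the empty redirect list mirrors the handler (per-role `allowed_redirect_uris` are validated separately).

```go
package main

import (
	"log"

	"github.com/hashicorp/cap/oidc"
)

func main() {
	// Placeholder issuer/client values.
	issuer := "https://team-vault.auth0.com/"
	clientID := "client-id"
	clientSecret := "client-secret"

	// An empty jwt_supported_algs list defaults to RS256, mirroring the
	// zero-value handling above.
	supported := []oidc.Alg{oidc.RS256}

	cfg, err := oidc.NewConfig(issuer, clientID, oidc.ClientSecret(clientSecret), supported, []string{})
	if err != nil {
		log.Fatalf("error creating provider: %v", err)
	}

	provider, err := oidc.NewProvider(cfg)
	if err != nil {
		log.Fatalf("error creating provider with given values: %v", err)
	}

	// provider is then used for AuthURL, Exchange, and UserInfo as in path_oidc.go.
	_ = provider
}
```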
|
||||
|
@ -377,7 +387,7 @@ type jwtConfig struct {
|
|||
ProviderConfig map[string]interface{} `json:"provider_config"`
|
||||
NamespaceInState bool `json:"namespace_in_state"`
|
||||
|
||||
ParsedJWTPubKeys []interface{} `json:"-"`
|
||||
ParsedJWTPubKeys []crypto.PublicKey `json:"-"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
|
@ -2,17 +2,15 @@ package jwtauth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/cap/jwt"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/cidrutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func pathLogin(b *jwtAuthBackend) *framework.Path {
|
||||
|
@ -88,132 +86,37 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
|
|||
}
|
||||
}
|
||||
|
||||
// Here is where things diverge. If it is using OIDC Discovery, validate that way;
|
||||
// otherwise validate against the locally configured or JWKS keys. Once things are
|
||||
// validated, we re-unify the request path when evaluating the claims.
|
||||
allClaims := map[string]interface{}{}
|
||||
configType := config.authType()
|
||||
|
||||
switch {
|
||||
case configType == StaticKeys || configType == JWKS:
|
||||
claims := jwt.Claims{}
|
||||
if configType == JWKS {
|
||||
keySet, err := b.getKeySet(config)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(errwrap.Wrapf("error fetching jwks keyset: {{err}}", err).Error()), nil
|
||||
}
|
||||
|
||||
// Verify signature (and only signature... other elements are checked later)
|
||||
payload, err := keySet.VerifySignature(ctx, token)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(errwrap.Wrapf("error verifying token: {{err}}", err).Error()), nil
|
||||
}
|
||||
|
||||
// Unmarshal payload into two copies: public claims for library verification, and a set
|
||||
// of all received claims.
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(payload, &allClaims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
|
||||
}
|
||||
} else {
|
||||
parsedJWT, err := jwt.ParseSigned(token)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(errwrap.Wrapf("error parsing token: {{err}}", err).Error()), nil
|
||||
}
|
||||
|
||||
var valid bool
|
||||
for _, key := range config.ParsedJWTPubKeys {
|
||||
if err := parsedJWT.Claims(key, &claims, &allClaims); err == nil {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return logical.ErrorResponse("no known key successfully validated the token signature"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// We require notbefore or expiry; if only one is provided, we allow 5 minutes of leeway by default.
|
||||
// Configurable by ExpirationLeeway and NotBeforeLeeway
|
||||
if claims.IssuedAt == nil {
|
||||
claims.IssuedAt = new(jwt.NumericDate)
|
||||
}
|
||||
if claims.Expiry == nil {
|
||||
claims.Expiry = new(jwt.NumericDate)
|
||||
}
|
||||
if claims.NotBefore == nil {
|
||||
claims.NotBefore = new(jwt.NumericDate)
|
||||
}
|
||||
if *claims.IssuedAt == 0 && *claims.Expiry == 0 && *claims.NotBefore == 0 {
|
||||
return logical.ErrorResponse("no issue time, notbefore, or expiration time encoded in token"), nil
|
||||
}
|
||||
|
||||
if *claims.Expiry == 0 {
|
||||
latestStart := *claims.IssuedAt
|
||||
if *claims.NotBefore > *claims.IssuedAt {
|
||||
latestStart = *claims.NotBefore
|
||||
}
|
||||
leeway := role.ExpirationLeeway.Seconds()
|
||||
if role.ExpirationLeeway.Seconds() < 0 {
|
||||
leeway = 0
|
||||
} else if role.ExpirationLeeway.Seconds() == 0 {
|
||||
leeway = claimDefaultLeeway
|
||||
}
|
||||
*claims.Expiry = jwt.NumericDate(int64(latestStart) + int64(leeway))
|
||||
}
|
||||
|
||||
if *claims.NotBefore == 0 {
|
||||
if *claims.IssuedAt != 0 {
|
||||
*claims.NotBefore = *claims.IssuedAt
|
||||
} else {
|
||||
leeway := role.NotBeforeLeeway.Seconds()
|
||||
if role.NotBeforeLeeway.Seconds() < 0 {
|
||||
leeway = 0
|
||||
} else if role.NotBeforeLeeway.Seconds() == 0 {
|
||||
leeway = claimDefaultLeeway
|
||||
}
|
||||
*claims.NotBefore = jwt.NumericDate(int64(*claims.Expiry) - int64(leeway))
|
||||
}
|
||||
}
|
||||
|
||||
if len(claims.Audience) > 0 && len(role.BoundAudiences) == 0 {
|
||||
return logical.ErrorResponse("audience claim found in JWT but no audiences bound to the role"), nil
|
||||
}
|
||||
|
||||
expected := jwt.Expected{
|
||||
Issuer: config.BoundIssuer,
|
||||
Subject: role.BoundSubject,
|
||||
Time: time.Now(),
|
||||
}
|
||||
|
||||
cksLeeway := role.ClockSkewLeeway
|
||||
if role.ClockSkewLeeway.Seconds() < 0 {
|
||||
cksLeeway = 0
|
||||
} else if role.ClockSkewLeeway.Seconds() == 0 {
|
||||
cksLeeway = jwt.DefaultLeeway
|
||||
}
|
||||
|
||||
if err := claims.ValidateWithLeeway(expected, cksLeeway); err != nil {
|
||||
return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil
|
||||
}
|
||||
|
||||
if err := validateAudience(role.BoundAudiences, claims.Audience, true); err != nil {
|
||||
return logical.ErrorResponse(errwrap.Wrapf("error validating claims: {{err}}", err).Error()), nil
|
||||
}
|
||||
|
||||
case configType == OIDCDiscovery:
|
||||
allClaims, err = b.verifyOIDCToken(ctx, config, role, token)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, errors.New("unhandled case during login")
|
||||
// Get the JWT validator based on the configured auth type
|
||||
validator, err := b.jwtValidator(config)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("error configuring token validator: %s", err.Error()), nil
|
||||
}
|
||||
|
||||
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role)
|
||||
// Set expected claims values to assert on the JWT
|
||||
expected := jwt.Expected{
|
||||
Issuer: config.BoundIssuer,
|
||||
Subject: role.BoundSubject,
|
||||
Audiences: role.BoundAudiences,
|
||||
SigningAlgorithms: toAlg(config.JWTSupportedAlgs),
|
||||
NotBeforeLeeway: role.NotBeforeLeeway,
|
||||
ExpirationLeeway: role.ExpirationLeeway,
|
||||
ClockSkewLeeway: role.ClockSkewLeeway,
|
||||
}
|
||||
|
||||
// Validate the JWT by verifying its signature and asserting expected claims values
|
||||
allClaims, err := validator.Validate(ctx, token, expected)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("error validating token: %s", err.Error()), nil
|
||||
}
|
||||
|
||||
// If there are no bound audiences for the role, then the existence of any audience
|
||||
// in the audience claim should result in an error.
|
||||
aud, ok := getClaim(b.Logger(), allClaims, "aud").([]interface{})
|
||||
if ok && len(aud) > 0 && len(role.BoundAudiences) == 0 {
|
||||
return logical.ErrorResponse("audience claim found in JWT but no audiences bound to the role"), nil
|
||||
}
|
||||
|
||||
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role, nil)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
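The unified login path above replaces the hand-rolled leeway and claim handling with a single `Validate` call. A hedged, self-contained sketch of that flow for a statically keyed backend; the generated key, the token, and the bound issuer/audience values are placeholders, and `NewStaticKeySet`/`NewValidator` are the cap/jwt constructors the backend is assumed to wrap in `jwtValidator`.

```go
package main

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"fmt"
	"log"
	"time"

	"github.com/hashicorp/cap/jwt"
)

func main() {
	ctx := context.Background()

	// Placeholder static key; the backend uses jwt_validation_pubkeys here.
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		log.Fatal(err)
	}

	keySet, err := jwt.NewStaticKeySet([]crypto.PublicKey{priv.Public()})
	if err != nil {
		log.Fatal(err)
	}

	validator, err := jwt.NewValidator(keySet)
	if err != nil {
		log.Fatal(err)
	}

	// Mirrors the jwt.Expected built in pathLogin: issuer/subject/audiences
	// come from config and role, leeways from the role's tunables.
	expected := jwt.Expected{
		Issuer:            "https://team-vault.auth0.com/",
		Audiences:         []string{"https://vault.plugin.auth.jwt.test"},
		SigningAlgorithms: []jwt.Alg{jwt.ES256},
		ClockSkewLeeway:   60 * time.Second,
	}

	// token would be the caller-supplied JWT; Validate verifies the signature
	// and asserts the expected claims, returning all claims on success.
	token := "eyJ..." // placeholder
	allClaims, err := validator.Validate(ctx, token, expected)
	if err != nil {
		fmt.Println("error validating token:", err)
		return
	}
	fmt.Println("claims:", allClaims)
}
```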
|
||||
|
@ -266,49 +169,9 @@ func (b *jwtAuthBackend) pathLoginRenew(ctx context.Context, req *logical.Reques
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig, role *jwtRole, rawToken string) (map[string]interface{}, error) {
|
||||
allClaims := make(map[string]interface{})
|
||||
|
||||
provider, err := b.getProvider(config)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err)
|
||||
}
|
||||
|
||||
oidcConfig := &oidc.Config{
|
||||
SupportedSigningAlgs: config.JWTSupportedAlgs,
|
||||
}
|
||||
|
||||
if role.RoleType == "oidc" {
|
||||
oidcConfig.ClientID = config.OIDCClientID
|
||||
} else {
|
||||
oidcConfig.SkipClientIDCheck = true
|
||||
}
|
||||
|
||||
verifier := provider.Verifier(oidcConfig)
|
||||
|
||||
idToken, err := verifier.Verify(ctx, rawToken)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error validating signature: {{err}}", err)
|
||||
}
|
||||
|
||||
if err := idToken.Claims(&allClaims); err != nil {
|
||||
return nil, errwrap.Wrapf("unable to successfully parse all claims from token: {{err}}", err)
|
||||
}
|
||||
|
||||
if role.BoundSubject != "" && role.BoundSubject != idToken.Subject {
|
||||
return nil, errors.New("sub claim does not match bound subject")
|
||||
}
|
||||
|
||||
if err := validateAudience(role.BoundAudiences, idToken.Audience, false); err != nil {
|
||||
return nil, errwrap.Wrapf("error validating claims: {{err}}", err)
|
||||
}
|
||||
|
||||
return allClaims, nil
|
||||
}
|
||||
|
||||
// createIdentity creates an alias and set of groups aliases based on the role
|
||||
// definition and received claims.
|
||||
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
|
||||
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (*logical.Alias, []*logical.Alias, error) {
|
||||
userClaimRaw, ok := allClaims[role.UserClaim]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("claim %q not found in token", role.UserClaim)
|
||||
|
@ -343,7 +206,7 @@ func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[strin
|
|||
return alias, groupAliases, nil
|
||||
}
|
||||
|
||||
groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role)
|
||||
groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role, tokenSource)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to fetch groups: %s", err)
|
||||
}
|
||||
|
@ -382,12 +245,12 @@ func (b *jwtAuthBackend) fetchUserInfo(ctx context.Context, pConfig CustomProvid
|
|||
}
|
||||
|
||||
// Checks if there's a custom provider_config and calls FetchGroups() if implemented
|
||||
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
|
||||
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (interface{}, error) {
|
||||
// If the custom provider implements interface GroupsFetcher, call it,
|
||||
// otherwise fall through to the default method
|
||||
if pConfig != nil {
|
||||
if gf, ok := pConfig.(GroupsFetcher); ok {
|
||||
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role)
|
||||
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role, tokenSource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -406,6 +269,14 @@ func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider
|
|||
return groupsClaimRaw, nil
|
||||
}
|
||||
|
||||
func toAlg(a []string) []jwt.Alg {
|
||||
alg := make([]jwt.Alg, len(a))
|
||||
for i, e := range a {
|
||||
alg[i] = jwt.Alg(e)
|
||||
}
|
||||
return alg
|
||||
}
|
||||
|
||||
const (
|
||||
pathLoginHelpSyn = `
|
||||
Authenticates to Vault using a JWT (or OIDC) token.
|
||||
|
|
|
@ -10,17 +10,18 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/hashicorp/cap/oidc"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/cidrutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var oidcStateTimeout = 10 * time.Minute
|
||||
const (
|
||||
oidcRequestTimeout = 10 * time.Minute
|
||||
oidcRequestCleanupInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
// OIDC error prefixes. These are searched for specifically by the UI, so any
|
||||
|
@ -33,14 +34,15 @@ const (
|
|||
noCode = "no_code"
|
||||
)
|
||||
|
||||
// oidcState is created when an authURL is requested. The state identifier is
|
||||
// passed throughout the OAuth process.
|
||||
type oidcState struct {
|
||||
rolename string
|
||||
nonce string
|
||||
redirectURI string
|
||||
code string
|
||||
idToken string
|
||||
// oidcRequest represents a single OIDC authentication flow. It is created when
|
||||
// an authURL is requested. It is uniquely identified by a state, which is passed
|
||||
// throughout the multiple interactions needed to complete the flow.
|
||||
type oidcRequest struct {
|
||||
oidc.Request
|
||||
|
||||
rolename string
|
||||
code string
|
||||
idToken string
|
||||
|
||||
// clientNonce is used between Vault and the client/application (e.g. CLI) making the request,
|
||||
// and is unrelated to the OIDC nonce above. It is optional.
|
||||
|
@ -136,13 +138,13 @@ func (b *jwtAuthBackend) pathCallbackPost(ctx context.Context, req *logical.Requ
|
|||
},
|
||||
}
|
||||
|
||||
// Store the provided code and/or token into state, which must already exist.
|
||||
state, err := b.amendState(stateID, code, idToken)
|
||||
// Store the provided code and/or token into its OIDC request, which must already exist.
|
||||
oidcReq, err := b.amendOIDCRequest(stateID, code, idToken)
|
||||
if err != nil {
|
||||
resp.Data[logical.HTTPRawBody] = []byte(errorHTML(errLoginFailed, "Expired or missing OAuth state."))
|
||||
resp.Data[logical.HTTPStatusCode] = http.StatusBadRequest
|
||||
} else {
|
||||
mount := parseMount(state.redirectURI)
|
||||
mount := parseMount(oidcReq.RedirectURL())
|
||||
if mount == "" {
|
||||
resp.Data[logical.HTTPRawBody] = []byte(errorHTML(errLoginFailed, "Invalid redirect path."))
|
||||
resp.Data[logical.HTTPStatusCode] = http.StatusBadRequest
|
||||
|
@ -165,8 +167,8 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
|
|||
|
||||
stateID := d.Get("state").(string)
|
||||
|
||||
state := b.verifyState(stateID)
|
||||
if state == nil {
|
||||
oidcReq := b.verifyOIDCRequest(stateID)
|
||||
if oidcReq == nil {
|
||||
return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil
|
||||
}
|
||||
|
||||
|
@ -174,11 +176,11 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
|
|||
|
||||
// If a client_nonce was provided at the start of the auth process as part of the auth_url
|
||||
// request, require that it is present and matching during the callback phase.
|
||||
if state.clientNonce != "" && clientNonce != state.clientNonce {
|
||||
if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce {
|
||||
return logical.ErrorResponse("invalid client_nonce"), nil
|
||||
}
|
||||
|
||||
roleName := state.rolename
|
||||
roleName := oidcReq.rolename
|
||||
role, err := b.role(ctx, req.Storage, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -202,74 +204,67 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
|
|||
return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err)
|
||||
}
|
||||
|
||||
oidcCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error preparing context for login operation: {{err}}", err)
|
||||
}
|
||||
|
||||
var oauth2Config = oauth2.Config{
|
||||
ClientID: config.OIDCClientID,
|
||||
ClientSecret: config.OIDCClientSecret,
|
||||
RedirectURL: state.redirectURI,
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID},
|
||||
}
|
||||
|
||||
var rawToken string
|
||||
var oauth2Token *oauth2.Token
|
||||
var rawToken oidc.IDToken
|
||||
var token *oidc.Tk
|
||||
|
||||
code := d.Get("code").(string)
|
||||
if code == noCode {
|
||||
code = state.code
|
||||
code = oidcReq.code
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
if state.idToken == "" {
|
||||
if oidcReq.idToken == "" {
|
||||
return logical.ErrorResponse(errLoginFailed + " No code or id_token received."), nil
|
||||
}
|
||||
rawToken = state.idToken
|
||||
rawToken = oidc.IDToken(oidcReq.idToken)
|
||||
} else {
|
||||
oauth2Token, err = oauth2Config.Exchange(oidcCtx, code)
|
||||
// ID token verification takes place in exchange
|
||||
token, err = provider.Exchange(ctx, oidcReq, stateID, code)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(errLoginFailed+" Error exchanging oidc code: %q.", err.Error()), nil
|
||||
}
|
||||
|
||||
// Extract the ID Token from OAuth2 token.
|
||||
var ok bool
|
||||
rawToken, ok = oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return logical.ErrorResponse(errTokenVerification + " No id_token found in response."), nil
|
||||
}
|
||||
rawToken = token.IDToken()
|
||||
}
|
||||
|
||||
if role.VerboseOIDCLogging {
|
||||
b.Logger().Debug("OIDC provider response", "ID token", rawToken)
|
||||
loggedToken := "invalid token format"
|
||||
|
||||
parts := strings.Split(string(rawToken), ".")
|
||||
if len(parts) == 3 {
|
||||
// strip signature from logged token
|
||||
loggedToken = fmt.Sprintf("%s.%s.xxxxxxxxxxx", parts[0], parts[1])
|
||||
}
|
||||
|
||||
b.Logger().Debug("OIDC provider response", "id_token", loggedToken)
|
||||
}
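The verbose-logging change above avoids writing a usable credential to the debug log by dropping the signature segment. The same transformation in isolation:

```go
package main

import (
	"fmt"
	"strings"
)

// redactJWT keeps the header and payload of a JWT for debugging but replaces
// the signature, mirroring the verbose OIDC logging change.
func redactJWT(raw string) string {
	parts := strings.Split(raw, ".")
	if len(parts) != 3 {
		return "invalid token format"
	}
	return fmt.Sprintf("%s.%s.xxxxxxxxxxx", parts[0], parts[1])
}

func main() {
	fmt.Println(redactJWT("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk"))
}
```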
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
allClaims, err := b.verifyOIDCToken(ctx, config, role, rawToken)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse("%s %s", errTokenVerification, err.Error()), nil
|
||||
}
|
||||
|
||||
if allClaims["nonce"] != state.nonce {
|
||||
return logical.ErrorResponse(errTokenVerification + " Invalid ID token nonce."), nil
|
||||
// Parse claims from the ID token payload.
|
||||
var allClaims map[string]interface{}
|
||||
if err := rawToken.Claims(&allClaims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
delete(allClaims, "nonce")
|
||||
|
||||
// Get the subject claim for bound subject and user info validation
|
||||
var subject string
|
||||
if subStr, ok := allClaims["sub"].(string); ok {
|
||||
subject = subStr
|
||||
}
|
||||
|
||||
if role.BoundSubject != "" && role.BoundSubject != subject {
|
||||
return nil, errors.New("sub claim does not match bound subject")
|
||||
}
|
||||
|
||||
// If we have a token, attempt to fetch information from the /userinfo endpoint
|
||||
// and merge it with the existing claims data. A failure to fetch additional information
|
||||
// from this endpoint will not invalidate the authorization flow.
|
||||
if oauth2Token != nil {
|
||||
if userinfo, err := provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil {
|
||||
_ = userinfo.Claims(&allClaims)
|
||||
} else {
|
||||
logFunc := b.Logger().Warn
|
||||
if strings.Contains(err.Error(), "user info endpoint is not supported") {
|
||||
logFunc = b.Logger().Info
|
||||
}
|
||||
logFunc("error reading /userinfo endpoint", "error", err)
|
||||
if err := provider.UserInfo(ctx, token.StaticTokenSource(), subject, &allClaims); err != nil {
|
||||
logFunc := b.Logger().Warn
|
||||
if strings.Contains(err.Error(), "user info endpoint is not supported") {
|
||||
logFunc = b.Logger().Info
|
||||
}
|
||||
logFunc("error reading /userinfo endpoint", "error", err)
|
||||
}
|
||||
|
||||
if role.VerboseOIDCLogging {
|
||||
|
@ -280,7 +275,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
|
|||
}
|
||||
}
|
||||
|
||||
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role)
|
||||
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role, token.StaticTokenSource())
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), nil
|
||||
}
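Putting the callback pieces together: the code-flow path now exchanges the authorization code through cap/oidc (which performs ID token verification, including nonce and audience checks), parses the claims, then best-effort merges `/userinfo` data keyed by the verified subject. A condensed sketch, with the provider and request assumed to be set up as elsewhere in `path_oidc.go`:

```go
package jwtauthsketch

import (
	"context"
	"fmt"

	"github.com/hashicorp/cap/oidc"
)

// exchangeAndMergeClaims mirrors the callback flow: exchange the code,
// read the verified ID token's claims, then best-effort merge /userinfo.
func exchangeAndMergeClaims(ctx context.Context, provider *oidc.Provider, req oidc.Request, state, code string) (map[string]interface{}, error) {
	// ID token verification takes place inside Exchange.
	tk, err := provider.Exchange(ctx, req, state, code)
	if err != nil {
		return nil, fmt.Errorf("error exchanging oidc code: %w", err)
	}

	var allClaims map[string]interface{}
	if err := tk.IDToken().Claims(&allClaims); err != nil {
		return nil, err
	}
	delete(allClaims, "nonce")

	// The verified subject gates the /userinfo merge; a failure here is logged
	// by the backend but does not fail the login.
	subject, _ := allClaims["sub"].(string)
	if err := provider.UserInfo(ctx, tk.StaticTokenSource(), subject, &allClaims); err != nil {
		fmt.Println("error reading /userinfo endpoint:", err)
	}

	return allClaims, nil
}
```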
|
||||
|
@ -396,7 +391,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f
|
|||
|
||||
// If configured for form_post, redirect directly to Vault instead of the UI,
|
||||
// if this was initiated by the UI (which currently has no knowledge of mode).
|
||||
///
|
||||
//
|
||||
// TODO: it would be better to convey this to the UI and have it send the
|
||||
// correct URL directly.
|
||||
if config.OIDCResponseMode == responseModeFormPost {
|
||||
|
@ -409,106 +404,84 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
// "openid" is a required scope for OpenID Connect flows
|
||||
scopes := append([]string{oidc.ScopeOpenID}, role.OIDCScopes...)
|
||||
|
||||
// Configure an OpenID Connect aware OAuth2 client
|
||||
oauth2Config := oauth2.Config{
|
||||
ClientID: config.OIDCClientID,
|
||||
ClientSecret: config.OIDCClientSecret,
|
||||
RedirectURL: redirectURI,
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
stateID, nonce, err := b.createState(roleName, redirectURI, clientNonce)
|
||||
oidcReq, err := b.createOIDCRequest(config, role, roleName, redirectURI, clientNonce)
|
||||
if err != nil {
|
||||
logger.Warn("error generating OAuth state", "error", err)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
urlStr, err := provider.AuthURL(ctx, oidcReq)
|
||||
if err != nil {
|
||||
logger.Warn("error generating auth URL", "error", err)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// embed namespace in state in the auth_url
|
||||
if config.NamespaceInState && len(namespace) > 0 {
|
||||
// embed namespace in state in the auth_url
|
||||
stateID = fmt.Sprintf("%s,ns=%s", stateID, namespace)
|
||||
stateWithNamespace := fmt.Sprintf("%s,ns=%s", oidcReq.State(), namespace)
|
||||
urlStr = strings.Replace(urlStr, oidcReq.State(), url.QueryEscape(stateWithNamespace), 1)
|
||||
}
|
||||
|
||||
authCodeOpts := []oauth2.AuthCodeOption{
|
||||
oidc.Nonce(nonce),
|
||||
}
|
||||
|
||||
// Add "form_post" param if requested. Note: the operator is allowed to configure "query"
|
||||
// as well, but that is the default for the AuthCode method and needn't be explicitly added.
|
||||
if config.OIDCResponseMode == responseModeFormPost {
|
||||
authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("response_mode", responseModeFormPost))
|
||||
}
|
||||
|
||||
// Build the final authorization URL. oauth2Config doesn't support response types other than
|
||||
// code, so some manual tweaking is required.
|
||||
urlStr := oauth2Config.AuthCodeURL(stateID, authCodeOpts...)
|
||||
|
||||
var rt string
|
||||
if config.hasType(responseTypeCode) {
|
||||
rt += responseTypeCode + " "
|
||||
}
|
||||
if config.hasType(responseTypeIDToken) {
|
||||
rt += responseTypeIDToken + " "
|
||||
}
|
||||
|
||||
rt = strings.TrimSpace(rt)
|
||||
urlStr = strings.Replace(urlStr, "response_type=code",
|
||||
fmt.Sprintf("response_type=%s", url.QueryEscape(rt)), 1)
|
||||
|
||||
resp.Data["auth_url"] = urlStr
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// createState make an expiring state object, associated with a random state ID
|
||||
// that is passed throughout the OAuth process. A nonce is also included in the
|
||||
// auth process, and for simplicity will be identical in length/format as the state ID.
|
||||
func (b *jwtAuthBackend) createState(rolename, redirectURI, clientNonce string) (string, string, error) {
|
||||
// Get enough bytes for 2 160-bit IDs (per rfc6749#section-10.10)
|
||||
bytes, err := uuid.GenerateRandomBytes(2 * 20)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
// createOIDCRequest makes an expiring request object, associated with a random state ID
|
||||
// that is passed throughout the OAuth process. A nonce is also included in the auth process.
|
||||
func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rolename, redirectURI, clientNonce string) (*oidcRequest, error) {
|
||||
options := []oidc.Option{
|
||||
oidc.WithAudiences(role.BoundAudiences...),
|
||||
oidc.WithScopes(role.OIDCScopes...),
|
||||
}
|
||||
|
||||
stateID := fmt.Sprintf("%x", bytes[:20])
|
||||
nonce := fmt.Sprintf("%x", bytes[20:])
|
||||
if config.hasType(responseTypeIDToken) {
|
||||
options = append(options, oidc.WithImplicitFlow())
|
||||
}
|
||||
|
||||
b.oidcStates.SetDefault(stateID, &oidcState{
|
||||
if role.MaxAge > 0 {
|
||||
options = append(options, oidc.WithMaxAge(uint(role.MaxAge.Seconds())))
|
||||
}
|
||||
|
||||
request, err := oidc.NewRequest(oidcRequestTimeout, redirectURI, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oidcReq := &oidcRequest{
|
||||
Request: request,
|
||||
rolename: rolename,
|
||||
nonce: nonce,
|
||||
redirectURI: redirectURI,
|
||||
clientNonce: clientNonce,
|
||||
})
|
||||
}
|
||||
b.oidcRequests.SetDefault(request.State(), oidcReq)
|
||||
|
||||
return stateID, nonce, nil
|
||||
return oidcReq, nil
|
||||
}
|
||||
|
||||
func (b *jwtAuthBackend) amendState(stateID, code, idToken string) (*oidcState, error) {
|
||||
stateRaw, ok := b.oidcStates.Get(stateID)
|
||||
func (b *jwtAuthBackend) amendOIDCRequest(stateID, code, idToken string) (*oidcRequest, error) {
|
||||
requestRaw, ok := b.oidcRequests.Get(stateID)
|
||||
if !ok {
|
||||
return nil, errors.New("OIDC state not found")
|
||||
}
|
||||
|
||||
state := stateRaw.(*oidcState)
|
||||
state.code = code
|
||||
state.idToken = idToken
|
||||
oidcReq := requestRaw.(*oidcRequest)
|
||||
oidcReq.code = code
|
||||
oidcReq.idToken = idToken
|
||||
|
||||
b.oidcStates.SetDefault(stateID, state)
|
||||
b.oidcRequests.SetDefault(stateID, oidcReq)
|
||||
|
||||
return state, nil
|
||||
return oidcReq, nil
|
||||
}
|
||||
|
||||
// verifyState tests whether the provided state ID is valid and returns the
|
||||
// associated state object if so. A nil state is returned if the ID is not found
|
||||
// or expired. The state should only ever be retrieved once and is deleted as
|
||||
// verifyOIDCRequest tests whether the provided state ID is valid and returns the
|
||||
// associated oidcRequest if so. A nil oidcRequest is returned if the ID is not found
|
||||
// or expired. The oidcRequest should only ever be retrieved once and is deleted as
|
||||
// part of this request.
|
||||
func (b *jwtAuthBackend) verifyState(stateID string) *oidcState {
|
||||
defer b.oidcStates.Delete(stateID)
|
||||
func (b *jwtAuthBackend) verifyOIDCRequest(stateID string) *oidcRequest {
|
||||
defer b.oidcRequests.Delete(stateID)
|
||||
|
||||
if stateRaw, ok := b.oidcStates.Get(stateID); ok {
|
||||
return stateRaw.(*oidcState)
|
||||
if requestRaw, ok := b.oidcRequests.Get(stateID); ok {
|
||||
return requestRaw.(*oidcRequest)
|
||||
}
|
||||
|
||||
return nil
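The request-creation side of that lifecycle, reduced to a standalone sketch: `oidc.NewRequest` bundles the state, nonce, redirect URL, and expiry that the removed `createState` used to assemble by hand, and the options mirror those built in `createOIDCRequest` above (bound audiences, scopes, implicit flow, and the new `max_age`). The redirect URL and option values are placeholders.

```go
package main

import (
	"fmt"
	"log"
	"time"

	"github.com/hashicorp/cap/oidc"
)

func main() {
	// Role/config-driven options, as in createOIDCRequest.
	options := []oidc.Option{
		oidc.WithAudiences("https://vault.plugin.auth.jwt.test"),
		oidc.WithScopes("profile", "groups"),
		// oidc.WithImplicitFlow(),         // when response_types includes id_token
		oidc.WithMaxAge(600), // when the role sets max_age
	}

	// 10 minutes mirrors oidcRequestTimeout; state and nonce are generated
	// by the library itself.
	req, err := oidc.NewRequest(10*time.Minute, "http://localhost:8250/oidc/callback", options...)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("state:", req.State())
	fmt.Println("redirect:", req.RedirectURL())

	// With a configured *oidc.Provider (see createProvider), the auth URL is
	// then produced by provider.AuthURL(ctx, req).
}
```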
|
||||
|
|
|
@ -7,13 +7,12 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/hashicorp/go-sockaddr"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/tokenutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
var reservedMetadata = []string{"role"}
|
||||
|
@ -144,6 +143,11 @@ Defaults to 60 (1 minute) if set to 0 and can be disabled if set to -1.`,
|
|||
Not recommended in production since sensitive information may be present
|
||||
in OIDC responses.`,
|
||||
},
|
||||
"max_age": {
|
||||
Type: framework.TypeDurationSecond,
|
||||
Description: `Specifies the allowable elapsed time in seconds since the last time the
|
||||
user was actively authenticated.`,
|
||||
},
|
||||
},
|
||||
ExistenceCheck: b.pathRoleExistenceCheck,
|
||||
Operations: map[logical.Operation]framework.OperationHandler{
|
||||
|
@ -202,6 +206,7 @@ type jwtRole struct {
|
|||
OIDCScopes []string `json:"oidc_scopes"`
|
||||
AllowedRedirectURIs []string `json:"allowed_redirect_uris"`
|
||||
VerboseOIDCLogging bool `json:"verbose_oidc_logging"`
|
||||
MaxAge time.Duration `json:"max_age"`
|
||||
|
||||
// Deprecated by TokenParams
|
||||
Policies []string `json:"policies"`
|
||||
|
@ -308,6 +313,7 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request,
|
|||
"allowed_redirect_uris": role.AllowedRedirectURIs,
|
||||
"oidc_scopes": role.OIDCScopes,
|
||||
"verbose_oidc_logging": role.VerboseOIDCLogging,
|
||||
"max_age": int64(role.MaxAge.Seconds()),
|
||||
}
|
||||
|
||||
role.PopulateTokenData(d)
|
||||
|
@ -441,6 +447,10 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical.
|
|||
role.VerboseOIDCLogging = verboseOIDCLoggingRaw.(bool)
|
||||
}
|
||||
|
||||
if maxAgeRaw, ok := data.GetOk("max_age"); ok {
|
||||
role.MaxAge = time.Duration(maxAgeRaw.(int)) * time.Second
|
||||
}
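From an operator's point of view, the new `max_age` knob is just another role field. A hedged example of setting it with the Vault Go API client; the mount path, role name, and redirect URI are placeholders. With `max_age` set, the generated auth URL carries the corresponding OIDC `max_age` parameter and the returned ID token's `auth_time` claim is validated against it, per the release note.

```go
package main

import (
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Placeholder mount and role; require re-authentication within 10 minutes.
	_, err = client.Logical().Write("auth/oidc/role/my-role", map[string]interface{}{
		"role_type":             "oidc",
		"user_claim":            "email",
		"allowed_redirect_uris": "http://localhost:8250/oidc/callback",
		"max_age":               "10m",
	})
	if err != nil {
		log.Fatal(err)
	}
}
```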
|
||||
|
||||
boundClaimsType := data.Get("bound_claims_type").(string)
|
||||
switch boundClaimsType {
|
||||
case boundClaimsTypeString, boundClaimsTypeGlob:
|
||||
|
|
|
@ -3,22 +3,24 @@ package jwtauth
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
const (
|
||||
// The old MS graph API requires setting an api-version query parameter
|
||||
windowsGraphHost = "graph.windows.net"
|
||||
windowsAPIVersion = "1.6"
|
||||
// Deprecated: The host of the Azure Active Directory (AAD) graph API
|
||||
azureADGraphHost = "graph.windows.net"
|
||||
|
||||
// The host and version of the Microsoft Graph API
|
||||
microsoftGraphHost = "graph.microsoft.com"
|
||||
microsoftGraphAPIVersion = "/v1.0"
|
||||
|
||||
// Distributed claim fields
|
||||
claimNamesField = "_claim_names"
|
||||
|
@ -29,9 +31,6 @@ const (
|
|||
type AzureProvider struct {
|
||||
// Context for azure calls
|
||||
ctx context.Context
|
||||
|
||||
// OIDC provider
|
||||
provider *oidc.Provider
|
||||
}
|
||||
|
||||
// Initialize anything in the AzureProvider struct - satisfying the CustomProvider interface
|
||||
|
@ -45,7 +44,7 @@ func (a *AzureProvider) SensitiveKeys() []string {
|
|||
}
|
||||
|
||||
// FetchGroups - custom groups fetching for azure - satisfying GroupsFetcher interface
|
||||
func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
|
||||
func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, tokenSource oauth2.TokenSource) (interface{}, error) {
|
||||
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)
|
||||
|
||||
if groupsClaimRaw == nil {
|
||||
|
@ -57,20 +56,12 @@ func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allCla
|
|||
return nil, fmt.Errorf("unable to get claim sources: %s", err)
|
||||
}
|
||||
|
||||
// Get provider because we'll need to get a new token for microsoft's
|
||||
// graph API, specifically the old graph API
|
||||
provider, err := b.getProvider(b.cachedConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get provider: %s", err)
|
||||
}
|
||||
a.provider = provider
|
||||
|
||||
a.ctx, err = b.createCAContext(b.providerCtx, b.cachedConfig.OIDCDiscoveryCAPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create CA Context: %s", err)
|
||||
}
|
||||
|
||||
azureGroups, err := a.getAzureGroups(azureClaimSourcesURL, b.cachedConfig)
|
||||
azureGroups, err := a.getAzureGroups(azureClaimSourcesURL, tokenSource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q claim not found in token: %v", role.GroupsClaim, err)
|
||||
}
|
||||
|
@ -112,46 +103,62 @@ func (a *AzureProvider) getClaimSource(logger log.Logger, allClaims map[string]i
|
|||
if val == nil {
|
||||
return "", fmt.Errorf("unable to locate %s in claims", endpoint)
|
||||
}
|
||||
logger.Debug(fmt.Sprintf("found Azure Graph API endpoint for group membership: %v", val))
|
||||
return fmt.Sprintf("%v", val), nil
|
||||
|
||||
urlParsed, err := url.Parse(fmt.Sprintf("%v", val))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to parse claim source URL: %w", err)
|
||||
}
|
||||
|
||||
// If the endpoint source for the groups claim has a host of the deprecated AAD graph API,
|
||||
// then replace it to instead use the Microsoft graph API. The AAD graph API is deprecated
|
||||
// and will eventually stop servicing requests. See details at:
|
||||
// - https://developer.microsoft.com/en-us/office/blogs/microsoft-graph-or-azure-ad-graph/
|
||||
// - https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
|
||||
if urlParsed.Host == azureADGraphHost {
|
||||
urlParsed.Host = microsoftGraphHost
|
||||
urlParsed.Path = microsoftGraphAPIVersion + urlParsed.Path
|
||||
}
|
||||
|
||||
logger.Debug(fmt.Sprintf("found Azure Graph API endpoint for group membership: %v", urlParsed.String()))
|
||||
return urlParsed.String(), nil
|
||||
}
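The host rewrite above is a small, mechanical URL transformation; sketched in isolation (the constants are copied from this file, the tenant/user path in the example is a placeholder):

```go
package main

import (
	"fmt"
	"log"
	"net/url"
)

const (
	azureADGraphHost         = "graph.windows.net"
	microsoftGraphHost       = "graph.microsoft.com"
	microsoftGraphAPIVersion = "/v1.0"
)

// rewriteGraphEndpoint upgrades a deprecated AAD Graph claim-source URL to the
// equivalent Microsoft Graph endpoint, leaving other hosts untouched.
func rewriteGraphEndpoint(raw string) (string, error) {
	u, err := url.Parse(raw)
	if err != nil {
		return "", fmt.Errorf("unable to parse claim source URL: %w", err)
	}
	if u.Host == azureADGraphHost {
		u.Host = microsoftGraphHost
		u.Path = microsoftGraphAPIVersion + u.Path
	}
	return u.String(), nil
}

func main() {
	out, err := rewriteGraphEndpoint("https://graph.windows.net/tenant-id/users/user-id/getMemberObjects")
	if err != nil {
		log.Fatal(err)
	}
	// Prints https://graph.microsoft.com/v1.0/tenant-id/users/user-id/getMemberObjects
	fmt.Println(out)
}
```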
|
||||
|
||||
// Fetch user groups from the Azure AD Graph API
|
||||
func (a *AzureProvider) getAzureGroups(groupsURL string, c *jwtConfig) (interface{}, error) {
|
||||
// Fetch user groups from the Microsoft Graph API
|
||||
func (a *AzureProvider) getAzureGroups(groupsURL string, tokenSource oauth2.TokenSource) (interface{}, error) {
|
||||
urlParsed, err := url.Parse(groupsURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse distributed groups source url %s: %s", groupsURL, err)
|
||||
}
|
||||
token, err := a.getAzureToken(c, urlParsed.Host)
|
||||
|
||||
// Use the Access Token that was pre-negotiated between the Claims Provider and RP
|
||||
// via https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims.
|
||||
if tokenSource == nil {
|
||||
return nil, errors.New("token unavailable to call Microsoft Graph API")
|
||||
}
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get token: %s", err)
|
||||
}
|
||||
payload := strings.NewReader("{\"securityEnabledOnly\": false}")
|
||||
req, err := http.NewRequest("POST", groupsURL, payload)
|
||||
req, err := http.NewRequest("POST", urlParsed.String(), payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing groups endpoint request: %s", err)
|
||||
}
|
||||
req.Header.Add("content-type", "application/json")
|
||||
req.Header.Add("authorization", fmt.Sprintf("Bearer %s", token))
|
||||
token.SetAuthHeader(req)
|
||||
|
||||
// If endpoint is the old windows graph api, add api-version
|
||||
if urlParsed.Host == windowsGraphHost {
|
||||
query := req.URL.Query()
|
||||
query.Add("api-version", windowsAPIVersion)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
}
|
||||
client := http.DefaultClient
|
||||
if c, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
client = c
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to call Azure AD Graph API: %s", err)
|
||||
return nil, fmt.Errorf("unable to call Microsoft Graph API: %s", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read Azure AD Graph API response: %s", err)
|
||||
return nil, fmt.Errorf("failed to read Microsoft Graph API response: %s", err)
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get groups: %s", string(body))
|
||||
|
@ -164,25 +171,6 @@ func (a *AzureProvider) getAzureGroups(groupsURL string, c *jwtConfig) (interfac
|
|||
return target.Value, nil
|
||||
}
|
||||
|
||||
// Login to Azure, using client id and secret.
|
||||
func (a *AzureProvider) getAzureToken(c *jwtConfig, host string) (string, error) {
|
||||
config := &clientcredentials.Config{
|
||||
ClientID: c.OIDCClientID,
|
||||
ClientSecret: c.OIDCClientSecret,
|
||||
TokenURL: a.provider.Endpoint().TokenURL,
|
||||
Scopes: []string{
|
||||
"openid",
|
||||
"profile",
|
||||
"https://" + host + "/.default",
|
||||
},
|
||||
}
|
||||
token, err := config.Token(a.ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch Azure token: %s", err)
|
||||
}
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
type azureGroups struct {
|
||||
Value []interface{} `json:"value"`
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package jwtauth
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Provider-specific configuration interfaces
|
||||
|
@ -58,5 +60,5 @@ type UserInfoFetcher interface {
|
|||
// GroupsFetcher - Optional support for custom groups handling
|
||||
type GroupsFetcher interface {
|
||||
// FetchGroups queries for groups claims during login
|
||||
FetchGroups(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole) (interface{}, error)
|
||||
FetchGroups(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole, oauth2.TokenSource) (interface{}, error)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"io/ioutil"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
"golang.org/x/oauth2/jwt"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
|
@ -104,7 +105,7 @@ func (g *GSuiteProvider) SensitiveKeys() []string {
|
|||
}
|
||||
|
||||
// FetchGroups fetches and returns groups from G Suite.
|
||||
func (g *GSuiteProvider) FetchGroups(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
|
||||
func (g *GSuiteProvider) FetchGroups(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource) (interface{}, error) {
|
||||
if !g.config.FetchGroups {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
arch:
|
||||
- amd64
|
||||
- ppc64le
|
||||
language: go
|
||||
|
||||
install:
|
||||
- go get -d -v ./...
|
||||
- go get -u github.com/stretchr/testify/require
|
||||
|
||||
go:
|
||||
- 1.7
|
||||
- 1.8
|
||||
- tip
|
||||
- "1.14"
|
||||
- "1.15"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# cachecontrol: HTTP Caching Parser and Interpretation
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/pquerna/cachecontrol?status.svg)](https://godoc.org/github.com/pquerna/cachecontrol)[![Build Status](https://travis-ci.org/pquerna/cachecontrol.svg?branch=master)](https://travis-ci.org/pquerna/cachecontrol)
|
||||
[![PkgGoDev](https://pkg.go.dev/badge/github.com/pquerna/cachecontrol?tab=doc)](https://pkg.go.dev/github.com/pquerna/cachecontrol?tab=doc)[![Build Status](https://travis-ci.org/pquerna/cachecontrol.svg?branch=master)](https://travis-ci.org/pquerna/cachecontrol)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ import (
|
|||
)
|
||||
|
||||
type Options struct {
|
||||
// Set to True for a prviate cache, which is not shared amoung users (eg, in a browser)
|
||||
// Set to True for a private cache, which is not shared among users (eg, in a browser)
|
||||
// Set to False for a "shared" cache, which is more common in a server context.
|
||||
PrivateCache bool
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ var (
|
|||
ErrQuoteMismatch = errors.New("Missing closing quote")
|
||||
ErrMaxAgeDeltaSeconds = errors.New("Failed to parse delta-seconds in `max-age`")
|
||||
ErrSMaxAgeDeltaSeconds = errors.New("Failed to parse delta-seconds in `s-maxage`")
|
||||
ErrMaxStaleDeltaSeconds = errors.New("Failed to parse delta-seconds in `min-fresh`")
|
||||
ErrMaxStaleDeltaSeconds = errors.New("Failed to parse delta-seconds in `max-stale`")
|
||||
ErrMinFreshDeltaSeconds = errors.New("Failed to parse delta-seconds in `min-fresh`")
|
||||
ErrNoCacheNoArgs = errors.New("Unexpected argument to `no-cache`")
|
||||
ErrNoStoreNoArgs = errors.New("Unexpected argument to `no-store`")
|
||||
|
@ -164,7 +164,7 @@ type cacheDirective interface {
|
|||
addPair(s string, v string) error
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Repersentation of possible request directives in a `Cache-Control` header: http://tools.ietf.org/html/rfc7234#section-5.2.1
|
||||
// LOW LEVEL API: Representation of possible request directives in a `Cache-Control` header: http://tools.ietf.org/html/rfc7234#section-5.2.1
|
||||
//
|
||||
// Note: Many fields will be `nil` in practice.
|
||||
//
|
||||
|
@ -189,6 +189,7 @@ type RequestCacheDirectives struct {
|
|||
// assigned to max-stale, then the client is willing to accept a stale
|
||||
// response of any age.
|
||||
MaxStale DeltaSeconds
|
||||
MaxStaleSet bool
|
||||
|
||||
// min-fresh(delta seconds): http://tools.ietf.org/html/rfc7234#section-5.2.1.3
|
||||
//
|
||||
|
@ -240,10 +241,10 @@ func (cd *RequestCacheDirectives) addToken(token string) error {
|
|||
switch token {
|
||||
case "max-age":
|
||||
err = ErrMaxAgeDeltaSeconds
|
||||
case "max-stale":
|
||||
err = ErrMaxStaleDeltaSeconds
|
||||
case "min-fresh":
|
||||
err = ErrMinFreshDeltaSeconds
|
||||
case "max-stale":
|
||||
cd.MaxStaleSet = true
|
||||
case "no-cache":
|
||||
cd.NoCache = true
|
||||
case "no-store":
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// LOW LEVEL API: Repersents a potentially cachable HTTP object.
|
||||
// LOW LEVEL API: Represents a potentially cachable HTTP object.
|
||||
//
|
||||
// This struct is designed to be serialized efficiently, so in a high
|
||||
// performance caching server, things like Date-Strings don't need to be
|
||||
|
@ -44,7 +44,7 @@ type Object struct {
|
|||
NowUTC time.Time
|
||||
}
|
||||
|
||||
// LOW LEVEL API: Repersents the results of examinig an Object with
|
||||
// LOW LEVEL API: Represents the results of examining an Object with
|
||||
// CachableObject and ExpirationObject.
|
||||
//
|
||||
// TODO(pquerna): decide if this is a good idea or bad
|
||||
|
@ -103,10 +103,10 @@ func CachableObject(obj *Object, rv *ObjectResults) {
|
|||
// To my knowledge, none of them are cachable. Please open a ticket if this is not the case!
|
||||
//
|
||||
default:
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodUnkown)
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestMethodUnknown)
|
||||
}
|
||||
|
||||
if obj.ReqDirectives.NoStore {
|
||||
if obj.ReqDirectives != nil && obj.ReqDirectives.NoStore {
|
||||
rv.OutReasons = append(rv.OutReasons, ReasonRequestNoStore)
|
||||
}
|
||||
|
||||
|
@ -232,7 +232,7 @@ func ExpirationObject(obj *Object, rv *ObjectResults) {
|
|||
println("Expiration: ", expiresTime.String())
|
||||
}
|
||||
} else {
|
||||
// TODO(pquerna): what should the default behavoir be for expiration time?
|
||||
// TODO(pquerna): what should the default behavior be for expiration time?
|
||||
}
|
||||
|
||||
rv.OutExpirationTime = expiresTime
|
||||
|
|
|
@ -45,7 +45,7 @@ const (
|
|||
ReasonRequestMethodTRACE
|
||||
|
||||
// The request method was not recognized by cachecontrol, and should not be cached.
|
||||
ReasonRequestMethodUnkown
|
||||
ReasonRequestMethodUnknown
|
||||
|
||||
// The request included an Cache-Control: no-store header
|
||||
ReasonRequestNoStore
|
||||
|
@ -77,7 +77,7 @@ func (r Reason) String() string {
|
|||
return "ReasonRequestMethodOPTIONS"
|
||||
case ReasonRequestMethodTRACE:
|
||||
return "ReasonRequestMethodTRACE"
|
||||
case ReasonRequestMethodUnkown:
|
||||
case ReasonRequestMethodUnknown:
|
||||
return "ReasonRequestMethodUnkown"
|
||||
case ReasonRequestNoStore:
|
||||
return "ReasonRequestNoStore"
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
module github.com/pquerna/cachecontrol
|
||||
|
||||
go 1.14
|
||||
|
||||
require github.com/stretchr/testify v1.6.1
|
|
@ -0,0 +1,11 @@
|
|||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@ -0,0 +1,16 @@
|
|||
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
|
||||
|
||||
package language
|
||||
|
||||
// This file contains code common to the maketables.go and the package code.
|
||||
|
||||
// AliasType is the type of an alias in AliasMap.
|
||||
type AliasType int8
|
||||
|
||||
const (
|
||||
Deprecated AliasType = iota
|
||||
Macro
|
||||
Legacy
|
||||
|
||||
AliasTypeUnknown AliasType = -1
|
||||
)
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
// CompactCoreInfo is a compact integer with the three core tags encoded.
|
||||
type CompactCoreInfo uint32
|
||||
|
||||
// GetCompactCore generates a uint32 value that is guaranteed to be unique for
|
||||
// different language, region, and script values.
|
||||
func GetCompactCore(t Tag) (cci CompactCoreInfo, ok bool) {
|
||||
if t.LangID > langNoIndexOffset {
|
||||
return 0, false
|
||||
}
|
||||
cci |= CompactCoreInfo(t.LangID) << (8 + 12)
|
||||
cci |= CompactCoreInfo(t.ScriptID) << 12
|
||||
cci |= CompactCoreInfo(t.RegionID)
|
||||
return cci, true
|
||||
}
|
||||
|
||||
// Tag generates a tag from c.
|
||||
func (c CompactCoreInfo) Tag() Tag {
|
||||
return Tag{
|
||||
LangID: Language(c >> 20),
|
||||
RegionID: Region(c & 0x3ff),
|
||||
ScriptID: Script(c>>12) & 0xff,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package compact defines a compact representation of language tags.
|
||||
//
|
||||
// Common language tags (at least all for which locale information is defined
|
||||
// in CLDR) are assigned a unique index. Each Tag is associated with such an
|
||||
// ID for selecting language-related resources (such as translations) as well
|
||||
// as one for selecting regional defaults (currency, number formatting, etc.)
|
||||
//
|
||||
// It may want to export this functionality at some point, but at this point
|
||||
// this is only available for use within x/text.
|
||||
package compact // import "golang.org/x/text/internal/language/compact"
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/internal/language"
|
||||
)
|
||||
|
||||
// ID is an integer identifying a single tag.
|
||||
type ID uint16
|
||||
|
||||
func getCoreIndex(t language.Tag) (id ID, ok bool) {
|
||||
cci, ok := language.GetCompactCore(t)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
i := sort.Search(len(coreTags), func(i int) bool {
|
||||
return cci <= coreTags[i]
|
||||
})
|
||||
if i == len(coreTags) || coreTags[i] != cci {
|
||||
return 0, false
|
||||
}
|
||||
return ID(i), true
|
||||
}
|
||||
|
||||
// Parent returns the ID of the parent or the root ID if id is already the root.
|
||||
func (id ID) Parent() ID {
|
||||
return parents[id]
|
||||
}
|
||||
|
||||
// Tag converts id to an internal language Tag.
|
||||
func (id ID) Tag() language.Tag {
|
||||
if int(id) >= len(coreTags) {
|
||||
return specialTags[int(id)-len(coreTags)]
|
||||
}
|
||||
return coreTags[id].Tag()
|
||||
}
|
||||
|
||||
var specialTags []language.Tag
|
||||
|
||||
func init() {
|
||||
tags := strings.Split(specialTagsStr, " ")
|
||||
specialTags = make([]language.Tag, len(tags))
|
||||
for i, t := range tags {
|
||||
specialTags[i] = language.MustParse(t)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,260 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:generate go run gen.go gen_index.go -output tables.go
|
||||
//go:generate go run gen_parents.go
|
||||
|
||||
package compact
|
||||
|
||||
// TODO: Remove above NOTE after:
|
||||
// - verifying that tables are dropped correctly (most notably matcher tables).
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/internal/language"
|
||||
)
|
||||
|
||||
// Tag represents a BCP 47 language tag. It is used to specify an instance of a
|
||||
// specific language or locale. All language tag values are guaranteed to be
|
||||
// well-formed.
|
||||
type Tag struct {
|
||||
// NOTE: exported tags will become part of the public API.
|
||||
language ID
|
||||
locale ID
|
||||
full fullTag // always a language.Tag for now.
|
||||
}
|
||||
|
||||
const _und = 0
|
||||
|
||||
type fullTag interface {
|
||||
IsRoot() bool
|
||||
Parent() language.Tag
|
||||
}
|
||||
|
||||
// Make a compact Tag from a fully specified internal language Tag.
|
||||
func Make(t language.Tag) (tag Tag) {
|
||||
if region := t.TypeForKey("rg"); len(region) == 6 && region[2:] == "zzzz" {
|
||||
if r, err := language.ParseRegion(region[:2]); err == nil {
|
||||
tFull := t
|
||||
t, _ = t.SetTypeForKey("rg", "")
|
||||
// TODO: should we not consider "va" for the language tag?
|
||||
var exact1, exact2 bool
|
||||
tag.language, exact1 = FromTag(t)
|
||||
t.RegionID = r
|
||||
tag.locale, exact2 = FromTag(t)
|
||||
if !exact1 || !exact2 {
|
||||
tag.full = tFull
|
||||
}
|
||||
return tag
|
||||
}
|
||||
}
|
||||
lang, ok := FromTag(t)
|
||||
tag.language = lang
|
||||
tag.locale = lang
|
||||
if !ok {
|
||||
tag.full = t
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
// Tag returns an internal language Tag version of this tag.
|
||||
func (t Tag) Tag() language.Tag {
|
||||
if t.full != nil {
|
||||
return t.full.(language.Tag)
|
||||
}
|
||||
tag := t.language.Tag()
|
||||
if t.language != t.locale {
|
||||
loc := t.locale.Tag()
|
||||
tag, _ = tag.SetTypeForKey("rg", strings.ToLower(loc.RegionID.String())+"zzzz")
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
// IsCompact reports whether this tag is fully defined in terms of ID.
|
||||
func (t *Tag) IsCompact() bool {
|
||||
return t.full == nil
|
||||
}
|
||||
|
||||
// MayHaveVariants reports whether a tag may have variants. If it returns false
|
||||
// it is guaranteed the tag does not have variants.
|
||||
func (t Tag) MayHaveVariants() bool {
|
||||
return t.full != nil || int(t.language) >= len(coreTags)
|
||||
}
|
||||
|
||||
// MayHaveExtensions reports whether a tag may have extensions. If it returns
|
||||
// false it is guaranteed the tag does not have them.
|
||||
func (t Tag) MayHaveExtensions() bool {
|
||||
return t.full != nil ||
|
||||
int(t.language) >= len(coreTags) ||
|
||||
t.language != t.locale
|
||||
}
|
||||
|
||||
// IsRoot returns true if t is equal to language "und".
|
||||
func (t Tag) IsRoot() bool {
|
||||
if t.full != nil {
|
||||
return t.full.IsRoot()
|
||||
}
|
||||
return t.language == _und
|
||||
}
|
||||
|
||||
// Parent returns the CLDR parent of t. In CLDR, missing fields in data for a
|
||||
// specific language are substituted with fields from the parent language.
|
||||
// The parent for a language may change for newer versions of CLDR.
|
||||
func (t Tag) Parent() Tag {
|
||||
if t.full != nil {
|
||||
return Make(t.full.Parent())
|
||||
}
|
||||
if t.language != t.locale {
|
||||
// Simulate stripping -u-rg-xxxxxx
|
||||
return Tag{language: t.language, locale: t.language}
|
||||
}
|
||||
// TODO: use parent lookup table once cycle from internal package is
|
||||
// removed. Probably by internalizing the table and declaring this fast
|
||||
// enough.
|
||||
// lang := compactID(internal.Parent(uint16(t.language)))
|
||||
lang, _ := FromTag(t.language.Tag().Parent())
|
||||
return Tag{language: lang, locale: lang}
|
||||
}
|
||||
|
||||
// returns token t and the rest of the string.
|
||||
func nextToken(s string) (t, tail string) {
|
||||
p := strings.Index(s[1:], "-")
|
||||
if p == -1 {
|
||||
return s[1:], ""
|
||||
}
|
||||
p++
|
||||
return s[1:p], s[p:]
|
||||
}
|
||||
|
||||
// LanguageID returns an index, where 0 <= index < NumCompactTags, for tags
|
||||
// for which data exists in the text repository.The index will change over time
|
||||
// and should not be stored in persistent storage. If t does not match a compact
|
||||
// index, exact will be false and the compact index will be returned for the
|
||||
// first match after repeatedly taking the Parent of t.
|
||||
func LanguageID(t Tag) (id ID, exact bool) {
|
||||
return t.language, t.full == nil
|
||||
}
|
||||
|
||||
// RegionalID returns the ID for the regional variant of this tag. This index is
|
||||
// used to indicate region-specific overrides, such as default currency, default
|
||||
// calendar and week data, default time cycle, and default measurement system
|
||||
// and unit preferences.
|
||||
//
|
||||
// For instance, the tag en-GB-u-rg-uszzzz specifies British English with US
|
||||
// settings for currency, number formatting, etc. The CompactIndex for this tag
|
||||
// will be that for en-GB, while the RegionalID will be the one corresponding to
|
||||
// en-US.
|
||||
func RegionalID(t Tag) (id ID, exact bool) {
|
||||
return t.locale, t.full == nil
|
||||
}
|
||||
|
||||
// LanguageTag returns t stripped of regional variant indicators.
|
||||
//
|
||||
// At the moment this means it is stripped of a regional and variant subtag "rg"
|
||||
// and "va" in the "u" extension.
|
||||
func (t Tag) LanguageTag() Tag {
|
||||
if t.full == nil {
|
||||
return Tag{language: t.language, locale: t.language}
|
||||
}
|
||||
tt := t.Tag()
|
||||
tt.SetTypeForKey("rg", "")
|
||||
tt.SetTypeForKey("va", "")
|
||||
return Make(tt)
|
||||
}
|
||||
|
||||
// RegionalTag returns the regional variant of the tag.
|
||||
//
|
||||
// At the moment this means that the region is set from the regional subtag
|
||||
// "rg" in the "u" extension.
|
||||
func (t Tag) RegionalTag() Tag {
|
||||
rt := Tag{language: t.locale, locale: t.locale}
|
||||
if t.full == nil {
|
||||
return rt
|
||||
}
|
||||
b := language.Builder{}
|
||||
tag := t.Tag()
|
||||
// tag, _ = tag.SetTypeForKey("rg", "")
|
||||
b.SetTag(t.locale.Tag())
|
||||
if v := tag.Variants(); v != "" {
|
||||
for _, v := range strings.Split(v, "-") {
|
||||
b.AddVariant(v)
|
||||
}
|
||||
}
|
||||
for _, e := range tag.Extensions() {
|
||||
b.AddExt(e)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// FromTag reports closest matching ID for an internal language Tag.
|
||||
func FromTag(t language.Tag) (id ID, exact bool) {
|
||||
// TODO: perhaps give more frequent tags a lower index.
|
||||
// TODO: we could make the indexes stable. This will excluded some
|
||||
// possibilities for optimization, so don't do this quite yet.
|
||||
exact = true
|
||||
|
||||
b, s, r := t.Raw()
|
||||
if t.HasString() {
|
||||
if t.IsPrivateUse() {
|
||||
// We have no entries for user-defined tags.
|
||||
return 0, false
|
||||
}
|
||||
hasExtra := false
|
||||
if t.HasVariants() {
|
||||
if t.HasExtensions() {
|
||||
build := language.Builder{}
|
||||
build.SetTag(language.Tag{LangID: b, ScriptID: s, RegionID: r})
|
||||
build.AddVariant(t.Variants())
|
||||
exact = false
|
||||
t = build.Make()
|
||||
}
|
||||
hasExtra = true
|
||||
} else if _, ok := t.Extension('u'); ok {
|
||||
// TODO: va may mean something else. Consider not considering it.
|
||||
// Strip all but the 'va' entry.
|
||||
old := t
|
||||
variant := t.TypeForKey("va")
|
||||
t = language.Tag{LangID: b, ScriptID: s, RegionID: r}
|
||||
if variant != "" {
|
||||
t, _ = t.SetTypeForKey("va", variant)
|
||||
hasExtra = true
|
||||
}
|
||||
exact = old == t
|
||||
} else {
|
||||
exact = false
|
||||
}
|
||||
if hasExtra {
|
||||
// We have some variants.
|
||||
for i, s := range specialTags {
|
||||
if s == t {
|
||||
return ID(i + len(coreTags)), exact
|
||||
}
|
||||
}
|
||||
exact = false
|
||||
}
|
||||
}
|
||||
if x, ok := getCoreIndex(t); ok {
|
||||
return x, exact
|
||||
}
|
||||
exact = false
|
||||
if r != 0 && s == 0 {
|
||||
// Deal with cases where an extra script is inserted for the region.
|
||||
t, _ := t.Maximize()
|
||||
if x, ok := getCoreIndex(t); ok {
|
||||
return x, exact
|
||||
}
|
||||
}
|
||||
for t = t.Parent(); t != root; t = t.Parent() {
|
||||
// No variants specified: just compare core components.
|
||||
// The key has the form lllssrrr, where l, s, and r are nibbles for
|
||||
// respectively the langID, scriptID, and regionID.
|
||||
if x, ok := getCoreIndex(t); ok {
|
||||
return x, exact
|
||||
}
|
||||
}
|
||||
return 0, exact
|
||||
}
|
||||
|
||||
var root = language.Tag{}

@@ -0,0 +1,120 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
|
||||
|
||||
package compact
|
||||
|
||||
// parents maps a compact index of a tag to the compact index of the parent of
|
||||
// this tag.
|
||||
var parents = []ID{ // 775 elements
|
||||
// Entry 0 - 3F
|
||||
0x0000, 0x0000, 0x0001, 0x0001, 0x0000, 0x0004, 0x0000, 0x0006,
|
||||
0x0000, 0x0008, 0x0000, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a,
|
||||
0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a,
|
||||
0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a,
|
||||
0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x000a, 0x0000,
|
||||
0x0000, 0x0028, 0x0000, 0x002a, 0x0000, 0x002c, 0x0000, 0x0000,
|
||||
0x002f, 0x002e, 0x002e, 0x0000, 0x0033, 0x0000, 0x0035, 0x0000,
|
||||
0x0037, 0x0000, 0x0039, 0x0000, 0x003b, 0x0000, 0x0000, 0x003e,
|
||||
// Entry 40 - 7F
|
||||
0x0000, 0x0040, 0x0040, 0x0000, 0x0043, 0x0043, 0x0000, 0x0046,
|
||||
0x0000, 0x0048, 0x0000, 0x0000, 0x004b, 0x004a, 0x004a, 0x0000,
|
||||
0x004f, 0x004f, 0x004f, 0x004f, 0x0000, 0x0054, 0x0054, 0x0000,
|
||||
0x0057, 0x0000, 0x0059, 0x0000, 0x005b, 0x0000, 0x005d, 0x005d,
|
||||
0x0000, 0x0060, 0x0000, 0x0062, 0x0000, 0x0064, 0x0000, 0x0066,
|
||||
0x0066, 0x0000, 0x0069, 0x0000, 0x006b, 0x006b, 0x006b, 0x006b,
|
||||
0x006b, 0x006b, 0x006b, 0x0000, 0x0073, 0x0000, 0x0075, 0x0000,
|
||||
0x0077, 0x0000, 0x0000, 0x007a, 0x0000, 0x007c, 0x0000, 0x007e,
|
||||
// Entry 80 - BF
|
||||
0x0000, 0x0080, 0x0080, 0x0000, 0x0083, 0x0083, 0x0000, 0x0086,
|
||||
0x0087, 0x0087, 0x0087, 0x0086, 0x0088, 0x0087, 0x0087, 0x0087,
|
||||
0x0086, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0088,
|
||||
0x0087, 0x0087, 0x0087, 0x0087, 0x0088, 0x0087, 0x0088, 0x0087,
|
||||
0x0087, 0x0088, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087,
|
||||
0x0087, 0x0087, 0x0087, 0x0086, 0x0087, 0x0087, 0x0087, 0x0087,
|
||||
0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087,
|
||||
0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0086, 0x0087, 0x0086,
|
||||
// Entry C0 - FF
|
||||
0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087,
|
||||
0x0088, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087,
|
||||
0x0086, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0088, 0x0087,
|
||||
0x0087, 0x0088, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087,
|
||||
0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0086, 0x0086, 0x0087,
|
||||
0x0087, 0x0086, 0x0087, 0x0087, 0x0087, 0x0087, 0x0087, 0x0000,
|
||||
0x00ef, 0x0000, 0x00f1, 0x00f2, 0x00f2, 0x00f2, 0x00f2, 0x00f2,
|
||||
0x00f2, 0x00f2, 0x00f2, 0x00f2, 0x00f1, 0x00f2, 0x00f1, 0x00f1,
|
||||
// Entry 100 - 13F
|
||||
0x00f2, 0x00f2, 0x00f1, 0x00f2, 0x00f2, 0x00f2, 0x00f2, 0x00f1,
|
||||
0x00f2, 0x00f2, 0x00f2, 0x00f2, 0x00f2, 0x00f2, 0x0000, 0x010e,
|
||||
0x0000, 0x0110, 0x0000, 0x0112, 0x0000, 0x0114, 0x0114, 0x0000,
|
||||
0x0117, 0x0117, 0x0117, 0x0117, 0x0000, 0x011c, 0x0000, 0x011e,
|
||||
0x0000, 0x0120, 0x0120, 0x0000, 0x0123, 0x0123, 0x0123, 0x0123,
|
||||
0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123,
|
||||
0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123,
|
||||
0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123,
|
||||
// Entry 140 - 17F
|
||||
0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123,
|
||||
0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123, 0x0123,
|
||||
0x0123, 0x0123, 0x0000, 0x0152, 0x0000, 0x0154, 0x0000, 0x0156,
|
||||
0x0000, 0x0158, 0x0000, 0x015a, 0x0000, 0x015c, 0x015c, 0x015c,
|
||||
0x0000, 0x0160, 0x0000, 0x0000, 0x0163, 0x0000, 0x0165, 0x0000,
|
||||
0x0167, 0x0167, 0x0167, 0x0000, 0x016b, 0x0000, 0x016d, 0x0000,
|
||||
0x016f, 0x0000, 0x0171, 0x0171, 0x0000, 0x0174, 0x0000, 0x0176,
|
||||
0x0000, 0x0178, 0x0000, 0x017a, 0x0000, 0x017c, 0x0000, 0x017e,
|
||||
// Entry 180 - 1BF
|
||||
0x0000, 0x0000, 0x0000, 0x0182, 0x0000, 0x0184, 0x0184, 0x0184,
|
||||
0x0184, 0x0000, 0x0000, 0x0000, 0x018b, 0x0000, 0x0000, 0x018e,
|
||||
0x0000, 0x0000, 0x0191, 0x0000, 0x0000, 0x0000, 0x0195, 0x0000,
|
||||
0x0197, 0x0000, 0x0000, 0x019a, 0x0000, 0x0000, 0x019d, 0x0000,
|
||||
0x019f, 0x0000, 0x01a1, 0x0000, 0x01a3, 0x0000, 0x01a5, 0x0000,
|
||||
0x01a7, 0x0000, 0x01a9, 0x0000, 0x01ab, 0x0000, 0x01ad, 0x0000,
|
||||
0x01af, 0x0000, 0x01b1, 0x01b1, 0x0000, 0x01b4, 0x0000, 0x01b6,
|
||||
0x0000, 0x01b8, 0x0000, 0x01ba, 0x0000, 0x01bc, 0x0000, 0x0000,
|
||||
// Entry 1C0 - 1FF
|
||||
0x01bf, 0x0000, 0x01c1, 0x0000, 0x01c3, 0x0000, 0x01c5, 0x0000,
|
||||
0x01c7, 0x0000, 0x01c9, 0x0000, 0x01cb, 0x01cb, 0x01cb, 0x01cb,
|
||||
0x0000, 0x01d0, 0x0000, 0x01d2, 0x01d2, 0x0000, 0x01d5, 0x0000,
|
||||
0x01d7, 0x0000, 0x01d9, 0x0000, 0x01db, 0x0000, 0x01dd, 0x0000,
|
||||
0x01df, 0x01df, 0x0000, 0x01e2, 0x0000, 0x01e4, 0x0000, 0x01e6,
|
||||
0x0000, 0x01e8, 0x0000, 0x01ea, 0x0000, 0x01ec, 0x0000, 0x01ee,
|
||||
0x0000, 0x01f0, 0x0000, 0x0000, 0x01f3, 0x0000, 0x01f5, 0x01f5,
|
||||
0x01f5, 0x0000, 0x01f9, 0x0000, 0x01fb, 0x0000, 0x01fd, 0x0000,
|
||||
// Entry 200 - 23F
|
||||
0x01ff, 0x0000, 0x0000, 0x0202, 0x0000, 0x0204, 0x0204, 0x0000,
|
||||
0x0207, 0x0000, 0x0209, 0x0209, 0x0000, 0x020c, 0x020c, 0x0000,
|
||||
0x020f, 0x020f, 0x020f, 0x020f, 0x020f, 0x020f, 0x020f, 0x0000,
|
||||
0x0217, 0x0000, 0x0219, 0x0000, 0x021b, 0x0000, 0x0000, 0x0000,
|
||||
0x0000, 0x0000, 0x0221, 0x0000, 0x0000, 0x0224, 0x0000, 0x0226,
|
||||
0x0226, 0x0000, 0x0229, 0x0000, 0x022b, 0x022b, 0x0000, 0x0000,
|
||||
0x022f, 0x022e, 0x022e, 0x0000, 0x0000, 0x0234, 0x0000, 0x0236,
|
||||
0x0000, 0x0238, 0x0000, 0x0244, 0x023a, 0x0244, 0x0244, 0x0244,
|
||||
// Entry 240 - 27F
|
||||
0x0244, 0x0244, 0x0244, 0x0244, 0x023a, 0x0244, 0x0244, 0x0000,
|
||||
0x0247, 0x0247, 0x0247, 0x0000, 0x024b, 0x0000, 0x024d, 0x0000,
|
||||
0x024f, 0x024f, 0x0000, 0x0252, 0x0000, 0x0254, 0x0254, 0x0254,
|
||||
0x0254, 0x0254, 0x0254, 0x0000, 0x025b, 0x0000, 0x025d, 0x0000,
|
||||
0x025f, 0x0000, 0x0261, 0x0000, 0x0263, 0x0000, 0x0265, 0x0000,
|
||||
0x0000, 0x0268, 0x0268, 0x0268, 0x0000, 0x026c, 0x0000, 0x026e,
|
||||
0x0000, 0x0270, 0x0000, 0x0000, 0x0000, 0x0274, 0x0273, 0x0273,
|
||||
0x0000, 0x0278, 0x0000, 0x027a, 0x0000, 0x027c, 0x0000, 0x0000,
|
||||
// Entry 280 - 2BF
|
||||
0x0000, 0x0000, 0x0281, 0x0000, 0x0000, 0x0284, 0x0000, 0x0286,
|
||||
0x0286, 0x0286, 0x0286, 0x0000, 0x028b, 0x028b, 0x028b, 0x0000,
|
||||
0x028f, 0x028f, 0x028f, 0x028f, 0x028f, 0x0000, 0x0295, 0x0295,
|
||||
0x0295, 0x0295, 0x0000, 0x0000, 0x0000, 0x0000, 0x029d, 0x029d,
|
||||
0x029d, 0x0000, 0x02a1, 0x02a1, 0x02a1, 0x02a1, 0x0000, 0x0000,
|
||||
0x02a7, 0x02a7, 0x02a7, 0x02a7, 0x0000, 0x02ac, 0x0000, 0x02ae,
|
||||
0x02ae, 0x0000, 0x02b1, 0x0000, 0x02b3, 0x0000, 0x02b5, 0x02b5,
|
||||
0x0000, 0x0000, 0x02b9, 0x0000, 0x0000, 0x0000, 0x02bd, 0x0000,
|
||||
// Entry 2C0 - 2FF
|
||||
0x02bf, 0x02bf, 0x0000, 0x0000, 0x02c3, 0x0000, 0x02c5, 0x0000,
|
||||
0x02c7, 0x0000, 0x02c9, 0x0000, 0x02cb, 0x0000, 0x02cd, 0x02cd,
|
||||
0x0000, 0x0000, 0x02d1, 0x0000, 0x02d3, 0x02d0, 0x02d0, 0x0000,
|
||||
0x0000, 0x02d8, 0x02d7, 0x02d7, 0x0000, 0x0000, 0x02dd, 0x0000,
|
||||
0x02df, 0x0000, 0x02e1, 0x0000, 0x0000, 0x02e4, 0x0000, 0x02e6,
|
||||
0x0000, 0x0000, 0x02e9, 0x0000, 0x02eb, 0x0000, 0x02ed, 0x0000,
|
||||
0x02ef, 0x02ef, 0x0000, 0x0000, 0x02f3, 0x02f2, 0x02f2, 0x0000,
|
||||
0x02f7, 0x0000, 0x02f9, 0x02f9, 0x02f9, 0x02f9, 0x02f9, 0x0000,
|
||||
// Entry 300 - 33F
|
||||
0x02ff, 0x0300, 0x02ff, 0x0000, 0x0303, 0x0051, 0x00e6,
|
||||
} // Size: 1574 bytes
|
||||
|
||||
// Total table size 1574 bytes (1KiB); checksum: 895AAF0B

File diff suppressed because it is too large

@@ -0,0 +1,91 @@
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package compact
|
||||
|
||||
var (
|
||||
und = Tag{}
|
||||
|
||||
Und Tag = Tag{}
|
||||
|
||||
Afrikaans Tag = Tag{language: afIndex, locale: afIndex}
|
||||
Amharic Tag = Tag{language: amIndex, locale: amIndex}
|
||||
Arabic Tag = Tag{language: arIndex, locale: arIndex}
|
||||
ModernStandardArabic Tag = Tag{language: ar001Index, locale: ar001Index}
|
||||
Azerbaijani Tag = Tag{language: azIndex, locale: azIndex}
|
||||
Bulgarian Tag = Tag{language: bgIndex, locale: bgIndex}
|
||||
Bengali Tag = Tag{language: bnIndex, locale: bnIndex}
|
||||
Catalan Tag = Tag{language: caIndex, locale: caIndex}
|
||||
Czech Tag = Tag{language: csIndex, locale: csIndex}
|
||||
Danish Tag = Tag{language: daIndex, locale: daIndex}
|
||||
German Tag = Tag{language: deIndex, locale: deIndex}
|
||||
Greek Tag = Tag{language: elIndex, locale: elIndex}
|
||||
English Tag = Tag{language: enIndex, locale: enIndex}
|
||||
AmericanEnglish Tag = Tag{language: enUSIndex, locale: enUSIndex}
|
||||
BritishEnglish Tag = Tag{language: enGBIndex, locale: enGBIndex}
|
||||
Spanish Tag = Tag{language: esIndex, locale: esIndex}
|
||||
EuropeanSpanish Tag = Tag{language: esESIndex, locale: esESIndex}
|
||||
LatinAmericanSpanish Tag = Tag{language: es419Index, locale: es419Index}
|
||||
Estonian Tag = Tag{language: etIndex, locale: etIndex}
|
||||
Persian Tag = Tag{language: faIndex, locale: faIndex}
|
||||
Finnish Tag = Tag{language: fiIndex, locale: fiIndex}
|
||||
Filipino Tag = Tag{language: filIndex, locale: filIndex}
|
||||
French Tag = Tag{language: frIndex, locale: frIndex}
|
||||
CanadianFrench Tag = Tag{language: frCAIndex, locale: frCAIndex}
|
||||
Gujarati Tag = Tag{language: guIndex, locale: guIndex}
|
||||
Hebrew Tag = Tag{language: heIndex, locale: heIndex}
|
||||
Hindi Tag = Tag{language: hiIndex, locale: hiIndex}
|
||||
Croatian Tag = Tag{language: hrIndex, locale: hrIndex}
|
||||
Hungarian Tag = Tag{language: huIndex, locale: huIndex}
|
||||
Armenian Tag = Tag{language: hyIndex, locale: hyIndex}
|
||||
Indonesian Tag = Tag{language: idIndex, locale: idIndex}
|
||||
Icelandic Tag = Tag{language: isIndex, locale: isIndex}
|
||||
Italian Tag = Tag{language: itIndex, locale: itIndex}
|
||||
Japanese Tag = Tag{language: jaIndex, locale: jaIndex}
|
||||
Georgian Tag = Tag{language: kaIndex, locale: kaIndex}
|
||||
Kazakh Tag = Tag{language: kkIndex, locale: kkIndex}
|
||||
Khmer Tag = Tag{language: kmIndex, locale: kmIndex}
|
||||
Kannada Tag = Tag{language: knIndex, locale: knIndex}
|
||||
Korean Tag = Tag{language: koIndex, locale: koIndex}
|
||||
Kirghiz Tag = Tag{language: kyIndex, locale: kyIndex}
|
||||
Lao Tag = Tag{language: loIndex, locale: loIndex}
|
||||
Lithuanian Tag = Tag{language: ltIndex, locale: ltIndex}
|
||||
Latvian Tag = Tag{language: lvIndex, locale: lvIndex}
|
||||
Macedonian Tag = Tag{language: mkIndex, locale: mkIndex}
|
||||
Malayalam Tag = Tag{language: mlIndex, locale: mlIndex}
|
||||
Mongolian Tag = Tag{language: mnIndex, locale: mnIndex}
|
||||
Marathi Tag = Tag{language: mrIndex, locale: mrIndex}
|
||||
Malay Tag = Tag{language: msIndex, locale: msIndex}
|
||||
Burmese Tag = Tag{language: myIndex, locale: myIndex}
|
||||
Nepali Tag = Tag{language: neIndex, locale: neIndex}
|
||||
Dutch Tag = Tag{language: nlIndex, locale: nlIndex}
|
||||
Norwegian Tag = Tag{language: noIndex, locale: noIndex}
|
||||
Punjabi Tag = Tag{language: paIndex, locale: paIndex}
|
||||
Polish Tag = Tag{language: plIndex, locale: plIndex}
|
||||
Portuguese Tag = Tag{language: ptIndex, locale: ptIndex}
|
||||
BrazilianPortuguese Tag = Tag{language: ptBRIndex, locale: ptBRIndex}
|
||||
EuropeanPortuguese Tag = Tag{language: ptPTIndex, locale: ptPTIndex}
|
||||
Romanian Tag = Tag{language: roIndex, locale: roIndex}
|
||||
Russian Tag = Tag{language: ruIndex, locale: ruIndex}
|
||||
Sinhala Tag = Tag{language: siIndex, locale: siIndex}
|
||||
Slovak Tag = Tag{language: skIndex, locale: skIndex}
|
||||
Slovenian Tag = Tag{language: slIndex, locale: slIndex}
|
||||
Albanian Tag = Tag{language: sqIndex, locale: sqIndex}
|
||||
Serbian Tag = Tag{language: srIndex, locale: srIndex}
|
||||
SerbianLatin Tag = Tag{language: srLatnIndex, locale: srLatnIndex}
|
||||
Swedish Tag = Tag{language: svIndex, locale: svIndex}
|
||||
Swahili Tag = Tag{language: swIndex, locale: swIndex}
|
||||
Tamil Tag = Tag{language: taIndex, locale: taIndex}
|
||||
Telugu Tag = Tag{language: teIndex, locale: teIndex}
|
||||
Thai Tag = Tag{language: thIndex, locale: thIndex}
|
||||
Turkish Tag = Tag{language: trIndex, locale: trIndex}
|
||||
Ukrainian Tag = Tag{language: ukIndex, locale: ukIndex}
|
||||
Urdu Tag = Tag{language: urIndex, locale: urIndex}
|
||||
Uzbek Tag = Tag{language: uzIndex, locale: uzIndex}
|
||||
Vietnamese Tag = Tag{language: viIndex, locale: viIndex}
|
||||
Chinese Tag = Tag{language: zhIndex, locale: zhIndex}
|
||||
SimplifiedChinese Tag = Tag{language: zhHansIndex, locale: zhHansIndex}
|
||||
TraditionalChinese Tag = Tag{language: zhHantIndex, locale: zhHantIndex}
|
||||
Zulu Tag = Tag{language: zuIndex, locale: zuIndex}
|
||||
)

@@ -0,0 +1,167 @@
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A Builder allows constructing a Tag from individual components.
|
||||
// Its main user is Compose in the top-level language package.
|
||||
type Builder struct {
|
||||
Tag Tag
|
||||
|
||||
private string // the x extension
|
||||
variants []string
|
||||
extensions []string
|
||||
}
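For context only: the main consumer of this Builder is Compose in the public language package. A hedged sketch of that public entry point (assuming the usual x/text API, not something added by this diff):

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	// Compose builds a tag from individual parts; it is backed by this Builder.
	t, err := language.Compose(
		language.MustParseBase("sr"),
		language.MustParseScript("Latn"),
		language.MustParseRegion("RS"),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(t) // sr-Latn-RS
}
```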
|
||||
|
||||
// Make returns a new Tag from the current settings.
|
||||
func (b *Builder) Make() Tag {
|
||||
t := b.Tag
|
||||
|
||||
if len(b.extensions) > 0 || len(b.variants) > 0 {
|
||||
sort.Sort(sortVariants(b.variants))
|
||||
sort.Strings(b.extensions)
|
||||
|
||||
if b.private != "" {
|
||||
b.extensions = append(b.extensions, b.private)
|
||||
}
|
||||
n := maxCoreSize + tokenLen(b.variants...) + tokenLen(b.extensions...)
|
||||
buf := make([]byte, n)
|
||||
p := t.genCoreBytes(buf)
|
||||
t.pVariant = byte(p)
|
||||
p += appendTokens(buf[p:], b.variants...)
|
||||
t.pExt = uint16(p)
|
||||
p += appendTokens(buf[p:], b.extensions...)
|
||||
t.str = string(buf[:p])
|
||||
// We may not always need to remake the string, but when or when not
|
||||
// to do so is rather tricky.
|
||||
scan := makeScanner(buf[:p])
|
||||
t, _ = parse(&scan, "")
|
||||
return t
|
||||
|
||||
} else if b.private != "" {
|
||||
t.str = b.private
|
||||
t.RemakeString()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// SetTag copies all the settings from a given Tag. Any previously set values
|
||||
// are discarded.
|
||||
func (b *Builder) SetTag(t Tag) {
|
||||
b.Tag.LangID = t.LangID
|
||||
b.Tag.RegionID = t.RegionID
|
||||
b.Tag.ScriptID = t.ScriptID
|
||||
// TODO: optimize
|
||||
b.variants = b.variants[:0]
|
||||
if variants := t.Variants(); variants != "" {
|
||||
for _, vr := range strings.Split(variants[1:], "-") {
|
||||
b.variants = append(b.variants, vr)
|
||||
}
|
||||
}
|
||||
b.extensions, b.private = b.extensions[:0], ""
|
||||
for _, e := range t.Extensions() {
|
||||
b.AddExt(e)
|
||||
}
|
||||
}
|
||||
|
||||
// AddExt adds extension e to the tag. e must be a valid extension as returned
|
||||
// by Tag.Extension. If the extension already exists, it will be discarded,
|
||||
// except for a -u extension, where non-existing key-type pairs will be added.
|
||||
func (b *Builder) AddExt(e string) {
|
||||
if e[0] == 'x' {
|
||||
if b.private == "" {
|
||||
b.private = e
|
||||
}
|
||||
return
|
||||
}
|
||||
for i, s := range b.extensions {
|
||||
if s[0] == e[0] {
|
||||
if e[0] == 'u' {
|
||||
b.extensions[i] += e[1:]
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
b.extensions = append(b.extensions, e)
|
||||
}
|
||||
|
||||
// SetExt sets the extension e to the tag. e must be a valid extension as
|
||||
// returned by Tag.Extension. If the extension already exists, it will be
|
||||
// overwritten, except for a -u extension, where the individual key-type pairs
|
||||
// will be set.
|
||||
func (b *Builder) SetExt(e string) {
|
||||
if e[0] == 'x' {
|
||||
b.private = e
|
||||
return
|
||||
}
|
||||
for i, s := range b.extensions {
|
||||
if s[0] == e[0] {
|
||||
if e[0] == 'u' {
|
||||
b.extensions[i] = e + s[1:]
|
||||
} else {
|
||||
b.extensions[i] = e
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
b.extensions = append(b.extensions, e)
|
||||
}
|
||||
|
||||
// AddVariant adds any number of variants.
|
||||
func (b *Builder) AddVariant(v ...string) {
|
||||
for _, v := range v {
|
||||
if v != "" {
|
||||
b.variants = append(b.variants, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClearVariants removes any variants previously added, including those
|
||||
// copied from a Tag in SetTag.
|
||||
func (b *Builder) ClearVariants() {
|
||||
b.variants = b.variants[:0]
|
||||
}
|
||||
|
||||
// ClearExtensions removes any extensions previously added, including those
|
||||
// copied from a Tag in SetTag.
|
||||
func (b *Builder) ClearExtensions() {
|
||||
b.private = ""
|
||||
b.extensions = b.extensions[:0]
|
||||
}
|
||||
|
||||
func tokenLen(token ...string) (n int) {
|
||||
for _, t := range token {
|
||||
n += len(t) + 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func appendTokens(b []byte, token ...string) int {
|
||||
p := 0
|
||||
for _, t := range token {
|
||||
b[p] = '-'
|
||||
copy(b[p+1:], t)
|
||||
p += 1 + len(t)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
type sortVariants []string
|
||||
|
||||
func (s sortVariants) Len() int {
|
||||
return len(s)
|
||||
}
|
||||
|
||||
func (s sortVariants) Swap(i, j int) {
|
||||
s[j], s[i] = s[i], s[j]
|
||||
}
|
||||
|
||||
func (s sortVariants) Less(i, j int) bool {
|
||||
return variantIndex[s[i]] < variantIndex[s[j]]
|
||||
}

@@ -0,0 +1,28 @@
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
// BaseLanguages returns the list of all supported base languages. It generates
|
||||
// the list by traversing the internal structures.
|
||||
func BaseLanguages() []Language {
|
||||
base := make([]Language, 0, NumLanguages)
|
||||
for i := 0; i < langNoIndexOffset; i++ {
|
||||
// We included "und" already for the value 0.
|
||||
if i != nonCanonicalUnd {
|
||||
base = append(base, Language(i))
|
||||
}
|
||||
}
|
||||
i := langNoIndexOffset
|
||||
for _, v := range langNoIndex {
|
||||
for k := 0; k < 8; k++ {
|
||||
if v&1 == 1 {
|
||||
base = append(base, Language(i))
|
||||
}
|
||||
v >>= 1
|
||||
i++
|
||||
}
|
||||
}
|
||||
return base
|
||||
}

@@ -0,0 +1,596 @@
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:generate go run gen.go gen_common.go -output tables.go
|
||||
|
||||
package language // import "golang.org/x/text/internal/language"
|
||||
|
||||
// TODO: Remove above NOTE after:
|
||||
// - verifying that tables are dropped correctly (most notably matcher tables).
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxCoreSize is the maximum size of a BCP 47 tag without variants and
|
||||
// extensions. Equals max lang (3) + script (4) + max reg (3) + 2 dashes.
|
||||
maxCoreSize = 12
|
||||
|
||||
// max99thPercentileSize is a somewhat arbitrary buffer size that presumably
|
||||
// is large enough to hold at least 99% of the BCP 47 tags.
|
||||
max99thPercentileSize = 32
|
||||
|
||||
// maxSimpleUExtensionSize is the maximum size of a -u extension with one
|
||||
// key-type pair. Equals len("-u-") + key (2) + dash + max value (8).
|
||||
maxSimpleUExtensionSize = 14
|
||||
)
|
||||
|
||||
// Tag represents a BCP 47 language tag. It is used to specify an instance of a
|
||||
// specific language or locale. All language tag values are guaranteed to be
|
||||
// well-formed. The zero value of Tag is Und.
|
||||
type Tag struct {
|
||||
// TODO: the following fields have the form TagTypeID. This name is chosen
|
||||
// to allow refactoring the public package without conflicting with its
|
||||
// Base, Script, and Region methods. Once the transition is fully completed
|
||||
// the ID can be stripped from the name.
|
||||
|
||||
LangID Language
|
||||
RegionID Region
|
||||
// TODO: we will soon run out of positions for ScriptID. Idea: instead of
|
||||
// storing lang, region, and ScriptID codes, store only the compact index and
|
||||
// have a lookup table from this code to its expansion. This greatly speeds
|
||||
// up table lookup and speeds up common variant cases.
|
||||
// This will also immediately free up 3 extra bytes. Also, the pVariant
|
||||
// field can now be moved to the lookup table, as the compact index uniquely
|
||||
// determines the offset of a possible variant.
|
||||
ScriptID Script
|
||||
pVariant byte // offset in str, includes preceding '-'
|
||||
pExt uint16 // offset of first extension, includes preceding '-'
|
||||
|
||||
// str is the string representation of the Tag. It will only be used if the
|
||||
// tag has variants or extensions.
|
||||
str string
|
||||
}
|
||||
|
||||
// Make is a convenience wrapper for Parse that omits the error.
|
||||
// In case of an error, a sensible default is returned.
|
||||
func Make(s string) Tag {
|
||||
t, _ := Parse(s)
|
||||
return t
|
||||
}
|
||||
|
||||
// Raw returns the raw base language, script and region, without making an
|
||||
// attempt to infer their values.
|
||||
// TODO: consider removing
|
||||
func (t Tag) Raw() (b Language, s Script, r Region) {
|
||||
return t.LangID, t.ScriptID, t.RegionID
|
||||
}
|
||||
|
||||
// equalTags compares language, script and region subtags only.
|
||||
func (t Tag) equalTags(a Tag) bool {
|
||||
return t.LangID == a.LangID && t.ScriptID == a.ScriptID && t.RegionID == a.RegionID
|
||||
}
|
||||
|
||||
// IsRoot returns true if t is equal to language "und".
|
||||
func (t Tag) IsRoot() bool {
|
||||
if int(t.pVariant) < len(t.str) {
|
||||
return false
|
||||
}
|
||||
return t.equalTags(Und)
|
||||
}
|
||||
|
||||
// IsPrivateUse reports whether the Tag consists solely of a private use
|
||||
// tag.
|
||||
func (t Tag) IsPrivateUse() bool {
|
||||
return t.str != "" && t.pVariant == 0
|
||||
}
|
||||
|
||||
// RemakeString is used to update t.str in case lang, script or region changed.
|
||||
// It is assumed that pExt and pVariant still point to the start of the
|
||||
// respective parts.
|
||||
func (t *Tag) RemakeString() {
|
||||
if t.str == "" {
|
||||
return
|
||||
}
|
||||
extra := t.str[t.pVariant:]
|
||||
if t.pVariant > 0 {
|
||||
extra = extra[1:]
|
||||
}
|
||||
if t.equalTags(Und) && strings.HasPrefix(extra, "x-") {
|
||||
t.str = extra
|
||||
t.pVariant = 0
|
||||
t.pExt = 0
|
||||
return
|
||||
}
|
||||
var buf [max99thPercentileSize]byte // avoid extra memory allocation in most cases.
|
||||
b := buf[:t.genCoreBytes(buf[:])]
|
||||
if extra != "" {
|
||||
diff := len(b) - int(t.pVariant)
|
||||
b = append(b, '-')
|
||||
b = append(b, extra...)
|
||||
t.pVariant = uint8(int(t.pVariant) + diff)
|
||||
t.pExt = uint16(int(t.pExt) + diff)
|
||||
} else {
|
||||
t.pVariant = uint8(len(b))
|
||||
t.pExt = uint16(len(b))
|
||||
}
|
||||
t.str = string(b)
|
||||
}
|
||||
|
||||
// genCoreBytes writes a string for the base languages, script and region tags
|
||||
// to the given buffer and returns the number of bytes written. It will never
|
||||
// write more than maxCoreSize bytes.
|
||||
func (t *Tag) genCoreBytes(buf []byte) int {
|
||||
n := t.LangID.StringToBuf(buf[:])
|
||||
if t.ScriptID != 0 {
|
||||
n += copy(buf[n:], "-")
|
||||
n += copy(buf[n:], t.ScriptID.String())
|
||||
}
|
||||
if t.RegionID != 0 {
|
||||
n += copy(buf[n:], "-")
|
||||
n += copy(buf[n:], t.RegionID.String())
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// String returns the canonical string representation of the language tag.
|
||||
func (t Tag) String() string {
|
||||
if t.str != "" {
|
||||
return t.str
|
||||
}
|
||||
if t.ScriptID == 0 && t.RegionID == 0 {
|
||||
return t.LangID.String()
|
||||
}
|
||||
buf := [maxCoreSize]byte{}
|
||||
return string(buf[:t.genCoreBytes(buf[:])])
|
||||
}
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (t Tag) MarshalText() (text []byte, err error) {
|
||||
if t.str != "" {
|
||||
text = append(text, t.str...)
|
||||
} else if t.ScriptID == 0 && t.RegionID == 0 {
|
||||
text = append(text, t.LangID.String()...)
|
||||
} else {
|
||||
buf := [maxCoreSize]byte{}
|
||||
text = buf[:t.genCoreBytes(buf[:])]
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
func (t *Tag) UnmarshalText(text []byte) error {
|
||||
tag, err := Parse(string(text))
|
||||
*t = tag
|
||||
return err
|
||||
}
|
||||
|
||||
// Variants returns the part of the tag holding all variants or the empty string
|
||||
// if there are no variants defined.
|
||||
func (t Tag) Variants() string {
|
||||
if t.pVariant == 0 {
|
||||
return ""
|
||||
}
|
||||
return t.str[t.pVariant:t.pExt]
|
||||
}
|
||||
|
||||
// VariantOrPrivateUseTags returns variants or private use tags.
|
||||
func (t Tag) VariantOrPrivateUseTags() string {
|
||||
if t.pExt > 0 {
|
||||
return t.str[t.pVariant:t.pExt]
|
||||
}
|
||||
return t.str[t.pVariant:]
|
||||
}
|
||||
|
||||
// HasString reports whether this tag defines more than just the raw
|
||||
// components.
|
||||
func (t Tag) HasString() bool {
|
||||
return t.str != ""
|
||||
}
|
||||
|
||||
// Parent returns the CLDR parent of t. In CLDR, missing fields in data for a
|
||||
// specific language are substituted with fields from the parent language.
|
||||
// The parent for a language may change for newer versions of CLDR.
|
||||
func (t Tag) Parent() Tag {
|
||||
if t.str != "" {
|
||||
// Strip the variants and extensions.
|
||||
b, s, r := t.Raw()
|
||||
t = Tag{LangID: b, ScriptID: s, RegionID: r}
|
||||
if t.RegionID == 0 && t.ScriptID != 0 && t.LangID != 0 {
|
||||
base, _ := addTags(Tag{LangID: t.LangID})
|
||||
if base.ScriptID == t.ScriptID {
|
||||
return Tag{LangID: t.LangID}
|
||||
}
|
||||
}
|
||||
return t
|
||||
}
|
||||
if t.LangID != 0 {
|
||||
if t.RegionID != 0 {
|
||||
maxScript := t.ScriptID
|
||||
if maxScript == 0 {
|
||||
max, _ := addTags(t)
|
||||
maxScript = max.ScriptID
|
||||
}
|
||||
|
||||
for i := range parents {
|
||||
if Language(parents[i].lang) == t.LangID && Script(parents[i].maxScript) == maxScript {
|
||||
for _, r := range parents[i].fromRegion {
|
||||
if Region(r) == t.RegionID {
|
||||
return Tag{
|
||||
LangID: t.LangID,
|
||||
ScriptID: Script(parents[i].script),
|
||||
RegionID: Region(parents[i].toRegion),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Strip the script if it is the default one.
|
||||
base, _ := addTags(Tag{LangID: t.LangID})
|
||||
if base.ScriptID != maxScript {
|
||||
return Tag{LangID: t.LangID, ScriptID: maxScript}
|
||||
}
|
||||
return Tag{LangID: t.LangID}
|
||||
} else if t.ScriptID != 0 {
|
||||
// The parent for a base-script pair with a non-default script is
|
||||
// "und" instead of the base language.
|
||||
base, _ := addTags(Tag{LangID: t.LangID})
|
||||
if base.ScriptID != t.ScriptID {
|
||||
return Und
|
||||
}
|
||||
return Tag{LangID: t.LangID}
|
||||
}
|
||||
}
|
||||
return Und
|
||||
}
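Illustrative note (not part of this file): the same parent chain is exposed on the public language.Tag type, e.g. this small sketch walks a tag up to the root:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	// Walk the CLDR parent chain until the root tag (und) is reached.
	for t := language.MustParse("es-AR"); !t.IsRoot(); t = t.Parent() {
		fmt.Println(t)
	}
}
```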
|
||||
|
||||
// ParseExtension parses s as an extension and returns it on success.
|
||||
func ParseExtension(s string) (ext string, err error) {
|
||||
scan := makeScannerString(s)
|
||||
var end int
|
||||
if n := len(scan.token); n != 1 {
|
||||
return "", ErrSyntax
|
||||
}
|
||||
scan.toLower(0, len(scan.b))
|
||||
end = parseExtension(&scan)
|
||||
if end != len(s) {
|
||||
return "", ErrSyntax
|
||||
}
|
||||
return string(scan.b), nil
|
||||
}
|
||||
|
||||
// HasVariants reports whether t has variants.
|
||||
func (t Tag) HasVariants() bool {
|
||||
return uint16(t.pVariant) < t.pExt
|
||||
}
|
||||
|
||||
// HasExtensions reports whether t has extensions.
|
||||
func (t Tag) HasExtensions() bool {
|
||||
return int(t.pExt) < len(t.str)
|
||||
}
|
||||
|
||||
// Extension returns the extension of type x for tag t. It will return
|
||||
// false for ok if t does not have the requested extension. The returned
|
||||
// extension will be invalid in this case.
|
||||
func (t Tag) Extension(x byte) (ext string, ok bool) {
|
||||
for i := int(t.pExt); i < len(t.str)-1; {
|
||||
var ext string
|
||||
i, ext = getExtension(t.str, i)
|
||||
if ext[0] == x {
|
||||
return ext, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Extensions returns all extensions of t.
|
||||
func (t Tag) Extensions() []string {
|
||||
e := []string{}
|
||||
for i := int(t.pExt); i < len(t.str)-1; {
|
||||
var ext string
|
||||
i, ext = getExtension(t.str, i)
|
||||
e = append(e, ext)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// TypeForKey returns the type associated with the given key, where key and type
|
||||
// are of the allowed values defined for the Unicode locale extension ('u') in
|
||||
// https://www.unicode.org/reports/tr35/#Unicode_Language_and_Locale_Identifiers.
|
||||
// TypeForKey will traverse the inheritance chain to get the correct value.
|
||||
func (t Tag) TypeForKey(key string) string {
|
||||
if start, end, _ := t.findTypeForKey(key); end != start {
|
||||
return t.str[start:end]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var (
|
||||
errPrivateUse = errors.New("cannot set a key on a private use tag")
|
||||
errInvalidArguments = errors.New("invalid key or type")
|
||||
)
|
||||
|
||||
// SetTypeForKey returns a new Tag with the key set to type, where key and type
|
||||
// are of the allowed values defined for the Unicode locale extension ('u') in
|
||||
// https://www.unicode.org/reports/tr35/#Unicode_Language_and_Locale_Identifiers.
|
||||
// An empty value removes an existing pair with the same key.
|
||||
func (t Tag) SetTypeForKey(key, value string) (Tag, error) {
|
||||
if t.IsPrivateUse() {
|
||||
return t, errPrivateUse
|
||||
}
|
||||
if len(key) != 2 {
|
||||
return t, errInvalidArguments
|
||||
}
|
||||
|
||||
// Remove the setting if value is "".
|
||||
if value == "" {
|
||||
start, end, _ := t.findTypeForKey(key)
|
||||
if start != end {
|
||||
// Remove key tag and leading '-'.
|
||||
start -= 4
|
||||
|
||||
// Remove a possible empty extension.
|
||||
if (end == len(t.str) || t.str[end+2] == '-') && t.str[start-2] == '-' {
|
||||
start -= 2
|
||||
}
|
||||
if start == int(t.pVariant) && end == len(t.str) {
|
||||
t.str = ""
|
||||
t.pVariant, t.pExt = 0, 0
|
||||
} else {
|
||||
t.str = fmt.Sprintf("%s%s", t.str[:start], t.str[end:])
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
if len(value) < 3 || len(value) > 8 {
|
||||
return t, errInvalidArguments
|
||||
}
|
||||
|
||||
var (
|
||||
buf [maxCoreSize + maxSimpleUExtensionSize]byte
|
||||
uStart int // start of the -u extension.
|
||||
)
|
||||
|
||||
// Generate the tag string if needed.
|
||||
if t.str == "" {
|
||||
uStart = t.genCoreBytes(buf[:])
|
||||
buf[uStart] = '-'
|
||||
uStart++
|
||||
}
|
||||
|
||||
// Create new key-type pair and parse it to verify.
|
||||
b := buf[uStart:]
|
||||
copy(b, "u-")
|
||||
copy(b[2:], key)
|
||||
b[4] = '-'
|
||||
b = b[:5+copy(b[5:], value)]
|
||||
scan := makeScanner(b)
|
||||
if parseExtensions(&scan); scan.err != nil {
|
||||
return t, scan.err
|
||||
}
|
||||
|
||||
// Assemble the replacement string.
|
||||
if t.str == "" {
|
||||
t.pVariant, t.pExt = byte(uStart-1), uint16(uStart-1)
|
||||
t.str = string(buf[:uStart+len(b)])
|
||||
} else {
|
||||
s := t.str
|
||||
start, end, hasExt := t.findTypeForKey(key)
|
||||
if start == end {
|
||||
if hasExt {
|
||||
b = b[2:]
|
||||
}
|
||||
t.str = fmt.Sprintf("%s-%s%s", s[:start], b, s[end:])
|
||||
} else {
|
||||
t.str = fmt.Sprintf("%s%s%s", s[:start], value, s[end:])
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
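For reference, a minimal sketch (not part of this diff) of the equivalent public API for reading and writing Unicode extension key-type pairs:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	// Attach a key-type pair in the -u extension and read it back.
	t := language.MustParse("th")
	t, err := t.SetTypeForKey("ca", "buddhist")
	if err != nil {
		panic(err)
	}
	fmt.Println(t)                  // th-u-ca-buddhist
	fmt.Println(t.TypeForKey("ca")) // buddhist
}
```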
|
||||
|
||||
// findTypeForKey returns the start and end position for the type corresponding
|
||||
// to key or the point at which to insert the key-value pair if the type
|
||||
// wasn't found. The hasExt return value reports whether an -u extension was present.
|
||||
// Note: the extensions are typically very small and are likely to contain
|
||||
// only one key-type pair.
|
||||
func (t Tag) findTypeForKey(key string) (start, end int, hasExt bool) {
|
||||
p := int(t.pExt)
|
||||
if len(key) != 2 || p == len(t.str) || p == 0 {
|
||||
return p, p, false
|
||||
}
|
||||
s := t.str
|
||||
|
||||
// Find the correct extension.
|
||||
for p++; s[p] != 'u'; p++ {
|
||||
if s[p] > 'u' {
|
||||
p--
|
||||
return p, p, false
|
||||
}
|
||||
if p = nextExtension(s, p); p == len(s) {
|
||||
return len(s), len(s), false
|
||||
}
|
||||
}
|
||||
// Proceed to the hyphen following the extension name.
|
||||
p++
|
||||
|
||||
// curKey is the key currently being processed.
|
||||
curKey := ""
|
||||
|
||||
// Iterate over keys until we get the end of a section.
|
||||
for {
|
||||
// p points to the hyphen preceding the current token.
|
||||
if p3 := p + 3; s[p3] == '-' {
|
||||
// Found a key.
|
||||
// Check whether we just processed the key that was requested.
|
||||
if curKey == key {
|
||||
return start, p, true
|
||||
}
|
||||
// Set to the next key and continue scanning type tokens.
|
||||
curKey = s[p+1 : p3]
|
||||
if curKey > key {
|
||||
return p, p, true
|
||||
}
|
||||
// Start of the type token sequence.
|
||||
start = p + 4
|
||||
// A type is at least 3 characters long.
|
||||
p += 7 // 4 + 3
|
||||
} else {
|
||||
// Attribute or type, which is at least 3 characters long.
|
||||
p += 4
|
||||
}
|
||||
// p points past the third character of a type or attribute.
|
||||
max := p + 5 // maximum length of token plus hyphen.
|
||||
if len(s) < max {
|
||||
max = len(s)
|
||||
}
|
||||
for ; p < max && s[p] != '-'; p++ {
|
||||
}
|
||||
// Bail if we have exhausted all tokens or if the next token starts
|
||||
// a new extension.
|
||||
if p == len(s) || s[p+2] == '-' {
|
||||
if curKey == key {
|
||||
return start, p, true
|
||||
}
|
||||
return p, p, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ParseBase parses a 2- or 3-letter ISO 639 code.
|
||||
// It returns a ValueError if s is a well-formed but unknown language identifier
|
||||
// or another error if another error occurred.
|
||||
func ParseBase(s string) (Language, error) {
|
||||
if n := len(s); n < 2 || 3 < n {
|
||||
return 0, ErrSyntax
|
||||
}
|
||||
var buf [3]byte
|
||||
return getLangID(buf[:copy(buf[:], s)])
|
||||
}
|
||||
|
||||
// ParseScript parses a 4-letter ISO 15924 code.
|
||||
// It returns a ValueError if s is a well-formed but unknown script identifier
|
||||
// or another error if another error occurred.
|
||||
func ParseScript(s string) (Script, error) {
|
||||
if len(s) != 4 {
|
||||
return 0, ErrSyntax
|
||||
}
|
||||
var buf [4]byte
|
||||
return getScriptID(script, buf[:copy(buf[:], s)])
|
||||
}
|
||||
|
||||
// EncodeM49 returns the Region for the given UN M.49 code.
|
||||
// It returns an error if r is not a valid code.
|
||||
func EncodeM49(r int) (Region, error) {
|
||||
return getRegionM49(r)
|
||||
}
|
||||
|
||||
// ParseRegion parses a 2- or 3-letter ISO 3166-1 or a UN M.49 code.
|
||||
// It returns a ValueError if s is a well-formed but unknown region identifier
|
||||
// or another error if another error occurred.
|
||||
func ParseRegion(s string) (Region, error) {
|
||||
if n := len(s); n < 2 || 3 < n {
|
||||
return 0, ErrSyntax
|
||||
}
|
||||
var buf [3]byte
|
||||
return getRegionID(buf[:copy(buf[:], s)])
|
||||
}
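A small usage sketch of these public parse helpers (illustrative only, using the top-level language package rather than this internal one):

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	b, _ := language.ParseBase("nld")    // 3-letter ISO 639 code, canonicalized to "nl"
	s, _ := language.ParseScript("latn") // case is fixed to "Latn"
	r, _ := language.ParseRegion("419")  // UN M.49 code for Latin America
	fmt.Println(b, s, r)
}
```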
|
||||
|
||||
// IsCountry returns whether this region is a country or autonomous area. This
|
||||
// includes non-standard definitions from CLDR.
|
||||
func (r Region) IsCountry() bool {
|
||||
if r == 0 || r.IsGroup() || r.IsPrivateUse() && r != _XK {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsGroup returns whether this region defines a collection of regions. This
|
||||
// includes non-standard definitions from CLDR.
|
||||
func (r Region) IsGroup() bool {
|
||||
if r == 0 {
|
||||
return false
|
||||
}
|
||||
return int(regionInclusion[r]) < len(regionContainment)
|
||||
}
|
||||
|
||||
// Contains returns whether Region c is contained by Region r. It returns true
|
||||
// if c == r.
|
||||
func (r Region) Contains(c Region) bool {
|
||||
if r == c {
|
||||
return true
|
||||
}
|
||||
g := regionInclusion[r]
|
||||
if g >= nRegionGroups {
|
||||
return false
|
||||
}
|
||||
m := regionContainment[g]
|
||||
|
||||
d := regionInclusion[c]
|
||||
b := regionInclusionBits[d]
|
||||
|
||||
// A contained country may belong to multiple disjoint groups. Matching any
|
||||
// of these indicates containment. If the contained region is a group, it
|
||||
// must strictly be a subset.
|
||||
if d >= nRegionGroups {
|
||||
return b&m != 0
|
||||
}
|
||||
return b&^m == 0
|
||||
}
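Illustrative sketch (not part of this file) of the containment check through the public Region type:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	latam := language.MustParseRegion("419")
	br := language.MustParseRegion("BR")
	// Reports whether the grouping region 419 (Latin America) contains BR.
	fmt.Println(latam.Contains(br))
}
```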
|
||||
|
||||
var errNoTLD = errors.New("language: region is not a valid ccTLD")
|
||||
|
||||
// TLD returns the country code top-level domain (ccTLD). UK is returned for GB.
|
||||
// In all other cases it returns either the region itself or an error.
|
||||
//
|
||||
// This method may return an error for a region for which there exists a
|
||||
// canonical form with a ccTLD. To get that ccTLD canonicalize r first. The
|
||||
// region will already be canonicalized if it was obtained from a Tag that was
|
||||
// obtained using any of the default methods.
|
||||
func (r Region) TLD() (Region, error) {
|
||||
// See http://en.wikipedia.org/wiki/Country_code_top-level_domain for the
|
||||
// difference between ISO 3166-1 and IANA ccTLD.
|
||||
if r == _GB {
|
||||
r = _UK
|
||||
}
|
||||
if (r.typ() & ccTLD) == 0 {
|
||||
return 0, errNoTLD
|
||||
}
|
||||
return r, nil
|
||||
}
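For reference, the GB-to-UK special case mentioned above as seen through the public Region API (a sketch, not part of the vendored code):

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	gb := language.MustParseRegion("GB")
	tld, err := gb.TLD()
	if err != nil {
		panic(err)
	}
	fmt.Println(tld) // UK
}
```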
|
||||
|
||||
// Canonicalize returns the region or a possible replacement if the region is
|
||||
// deprecated. It will not return a replacement for deprecated regions that
|
||||
// are split into multiple regions.
|
||||
func (r Region) Canonicalize() Region {
|
||||
if cr := normRegion(r); cr != 0 {
|
||||
return cr
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Variant represents a registered variant of a language as defined by BCP 47.
|
||||
type Variant struct {
|
||||
ID uint8
|
||||
str string
|
||||
}
|
||||
|
||||
// ParseVariant parses and returns a Variant. An error is returned if s is not
|
||||
// a valid variant.
|
||||
func ParseVariant(s string) (Variant, error) {
|
||||
s = strings.ToLower(s)
|
||||
if id, ok := variantIndex[s]; ok {
|
||||
return Variant{id, s}, nil
|
||||
}
|
||||
return Variant{}, NewValueError([]byte(s))
|
||||
}
|
||||
|
||||
// String returns the string representation of the variant.
|
||||
func (v Variant) String() string {
|
||||
return v.str
|
||||
}

@@ -0,0 +1,412 @@
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/text/internal/tag"
|
||||
)
|
||||
|
||||
// findIndex tries to find the given tag in idx and returns a standardized error
|
||||
// if it could not be found.
|
||||
func findIndex(idx tag.Index, key []byte, form string) (index int, err error) {
|
||||
if !tag.FixCase(form, key) {
|
||||
return 0, ErrSyntax
|
||||
}
|
||||
i := idx.Index(key)
|
||||
if i == -1 {
|
||||
return 0, NewValueError(key)
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func searchUint(imap []uint16, key uint16) int {
|
||||
return sort.Search(len(imap), func(i int) bool {
|
||||
return imap[i] >= key
|
||||
})
|
||||
}
|
||||
|
||||
type Language uint16
|
||||
|
||||
// getLangID returns the langID of s if s is a canonical subtag
|
||||
// or langUnknown if s is not a canonical subtag.
|
||||
func getLangID(s []byte) (Language, error) {
|
||||
if len(s) == 2 {
|
||||
return getLangISO2(s)
|
||||
}
|
||||
return getLangISO3(s)
|
||||
}
|
||||
|
||||
// TODO language normalization as well as the AliasMaps could be moved to the
|
||||
// higher level package, but it is a bit tricky to separate the generation.
|
||||
|
||||
func (id Language) Canonicalize() (Language, AliasType) {
|
||||
return normLang(id)
|
||||
}
|
||||
|
||||
// mapLang returns the mapped langID of id according to mapping m.
|
||||
func normLang(id Language) (Language, AliasType) {
|
||||
k := sort.Search(len(AliasMap), func(i int) bool {
|
||||
return AliasMap[i].From >= uint16(id)
|
||||
})
|
||||
if k < len(AliasMap) && AliasMap[k].From == uint16(id) {
|
||||
return Language(AliasMap[k].To), AliasTypes[k]
|
||||
}
|
||||
return id, AliasTypeUnknown
|
||||
}
|
||||
|
||||
// getLangISO2 returns the langID for the given 2-letter ISO language code
|
||||
// or unknownLang if this does not exist.
|
||||
func getLangISO2(s []byte) (Language, error) {
|
||||
if !tag.FixCase("zz", s) {
|
||||
return 0, ErrSyntax
|
||||
}
|
||||
if i := lang.Index(s); i != -1 && lang.Elem(i)[3] != 0 {
|
||||
return Language(i), nil
|
||||
}
|
||||
return 0, NewValueError(s)
|
||||
}
|
||||
|
||||
const base = 'z' - 'a' + 1
|
||||
|
||||
func strToInt(s []byte) uint {
|
||||
v := uint(0)
|
||||
for i := 0; i < len(s); i++ {
|
||||
v *= base
|
||||
v += uint(s[i] - 'a')
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// intToStr converts the given integer to the original ASCII string passed to strToInt.
|
||||
// len(s) must match the number of characters obtained.
|
||||
func intToStr(v uint, s []byte) {
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
s[i] = byte(v%base) + 'a'
|
||||
v /= base
|
||||
}
|
||||
}
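To make the base-26 packing used by strToInt/intToStr concrete, here is a self-contained sketch; packAlpha and unpackAlpha are hypothetical names mirroring the two functions above:

```go
package main

import "fmt"

const base = 'z' - 'a' + 1 // 26

// packAlpha treats a lowercase ASCII code as a base-26 number, like strToInt.
func packAlpha(s string) uint {
	v := uint(0)
	for i := 0; i < len(s); i++ {
		v = v*base + uint(s[i]-'a')
	}
	return v
}

// unpackAlpha reverses packAlpha for a code of n letters, like intToStr.
func unpackAlpha(v uint, n int) string {
	b := make([]byte, n)
	for i := n - 1; i >= 0; i-- {
		b[i] = byte(v%base) + 'a'
		v /= base
	}
	return string(b)
}

func main() {
	v := packAlpha("tlh")
	fmt.Println(v, unpackAlpha(v, 3)) // round-trips back to "tlh"
}
```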
|
||||
|
||||
// getLangISO3 returns the langID for the given 3-letter ISO language code
|
||||
// or unknownLang if this does not exist.
|
||||
func getLangISO3(s []byte) (Language, error) {
|
||||
if tag.FixCase("und", s) {
|
||||
// first try to match canonical 3-letter entries
|
||||
for i := lang.Index(s[:2]); i != -1; i = lang.Next(s[:2], i) {
|
||||
if e := lang.Elem(i); e[3] == 0 && e[2] == s[2] {
|
||||
// We treat "und" as special and always translate it to "unspecified".
|
||||
// Note that ZZ and Zzzz are private use and are not treated as
|
||||
// unspecified by default.
|
||||
id := Language(i)
|
||||
if id == nonCanonicalUnd {
|
||||
return 0, nil
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
if i := altLangISO3.Index(s); i != -1 {
|
||||
return Language(altLangIndex[altLangISO3.Elem(i)[3]]), nil
|
||||
}
|
||||
n := strToInt(s)
|
||||
if langNoIndex[n/8]&(1<<(n%8)) != 0 {
|
||||
return Language(n) + langNoIndexOffset, nil
|
||||
}
|
||||
// Check for non-canonical uses of ISO3.
|
||||
for i := lang.Index(s[:1]); i != -1; i = lang.Next(s[:1], i) {
|
||||
if e := lang.Elem(i); e[2] == s[1] && e[3] == s[2] {
|
||||
return Language(i), nil
|
||||
}
|
||||
}
|
||||
return 0, NewValueError(s)
|
||||
}
|
||||
return 0, ErrSyntax
|
||||
}
|
||||
|
||||
// StringToBuf writes the string to b and returns the number of bytes
|
||||
// written. cap(b) must be >= 3.
|
||||
func (id Language) StringToBuf(b []byte) int {
|
||||
if id >= langNoIndexOffset {
|
||||
intToStr(uint(id)-langNoIndexOffset, b[:3])
|
||||
return 3
|
||||
} else if id == 0 {
|
||||
return copy(b, "und")
|
||||
}
|
||||
l := lang[id<<2:]
|
||||
if l[3] == 0 {
|
||||
return copy(b, l[:3])
|
||||
}
|
||||
return copy(b, l[:2])
|
||||
}
|
||||
|
||||
// String returns the BCP 47 representation of the langID.
|
||||
// Use b as variable name, instead of id, to ensure the variable
|
||||
// used is consistent with that of Base in which this type is embedded.
|
||||
func (b Language) String() string {
|
||||
if b == 0 {
|
||||
return "und"
|
||||
} else if b >= langNoIndexOffset {
|
||||
b -= langNoIndexOffset
|
||||
buf := [3]byte{}
|
||||
intToStr(uint(b), buf[:])
|
||||
return string(buf[:])
|
||||
}
|
||||
l := lang.Elem(int(b))
|
||||
if l[3] == 0 {
|
||||
return l[:3]
|
||||
}
|
||||
return l[:2]
|
||||
}
|
||||
|
||||
// ISO3 returns the ISO 639-3 language code.
|
||||
func (b Language) ISO3() string {
|
||||
if b == 0 || b >= langNoIndexOffset {
|
||||
return b.String()
|
||||
}
|
||||
l := lang.Elem(int(b))
|
||||
if l[3] == 0 {
|
||||
return l[:3]
|
||||
} else if l[2] == 0 {
|
||||
return altLangISO3.Elem(int(l[3]))[:3]
|
||||
}
|
||||
// This allocation will only happen for 3-letter ISO codes
|
||||
// that are non-canonical BCP 47 language identifiers.
|
||||
return l[0:1] + l[2:4]
|
||||
}
|
||||
|
||||
// IsPrivateUse reports whether this language code is reserved for private use.
|
||||
func (b Language) IsPrivateUse() bool {
|
||||
return langPrivateStart <= b && b <= langPrivateEnd
|
||||
}
|
||||
|
||||
// SuppressScript returns the script marked as SuppressScript in the IANA
|
||||
// language tag repository, or 0 if there is no such script.
|
||||
func (b Language) SuppressScript() Script {
|
||||
if b < langNoIndexOffset {
|
||||
return Script(suppressScript[b])
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type Region uint16
|
||||
|
||||
// getRegionID returns the region id for s if s is a valid 2-letter region code
|
||||
// or unknownRegion.
|
||||
func getRegionID(s []byte) (Region, error) {
|
||||
if len(s) == 3 {
|
||||
if isAlpha(s[0]) {
|
||||
return getRegionISO3(s)
|
||||
}
|
||||
if i, err := strconv.ParseUint(string(s), 10, 10); err == nil {
|
||||
return getRegionM49(int(i))
|
||||
}
|
||||
}
|
||||
return getRegionISO2(s)
|
||||
}
|
||||
|
||||
// getRegionISO2 returns the regionID for the given 2-letter ISO country code
|
||||
// or unknownRegion if this does not exist.
|
||||
func getRegionISO2(s []byte) (Region, error) {
|
||||
i, err := findIndex(regionISO, s, "ZZ")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return Region(i) + isoRegionOffset, nil
|
||||
}
|
||||
|
||||
// getRegionISO3 returns the regionID for the given 3-letter ISO country code
|
||||
// or unknownRegion if this does not exist.
|
||||
func getRegionISO3(s []byte) (Region, error) {
|
||||
if tag.FixCase("ZZZ", s) {
|
||||
for i := regionISO.Index(s[:1]); i != -1; i = regionISO.Next(s[:1], i) {
|
||||
if e := regionISO.Elem(i); e[2] == s[1] && e[3] == s[2] {
|
||||
return Region(i) + isoRegionOffset, nil
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(altRegionISO3); i += 3 {
|
||||
if tag.Compare(altRegionISO3[i:i+3], s) == 0 {
|
||||
return Region(altRegionIDs[i/3]), nil
|
||||
}
|
||||
}
|
||||
return 0, NewValueError(s)
|
||||
}
|
||||
return 0, ErrSyntax
|
||||
}
|
||||
|
||||
func getRegionM49(n int) (Region, error) {
|
||||
if 0 < n && n <= 999 {
|
||||
const (
|
||||
searchBits = 7
|
||||
regionBits = 9
|
||||
regionMask = 1<<regionBits - 1
|
||||
)
|
||||
idx := n >> searchBits
|
||||
buf := fromM49[m49Index[idx]:m49Index[idx+1]]
|
||||
val := uint16(n) << regionBits // we rely on bits shifting out
|
||||
i := sort.Search(len(buf), func(i int) bool {
|
||||
return buf[i] >= val
|
||||
})
|
||||
if r := fromM49[int(m49Index[idx])+i]; r&^regionMask == val {
|
||||
return Region(r & regionMask), nil
|
||||
}
|
||||
}
|
||||
var e ValueError
|
||||
fmt.Fprint(bytes.NewBuffer([]byte(e.v[:])), n)
|
||||
return 0, e
|
||||
}
|
||||
|
||||
// normRegion returns a region if r is deprecated or 0 otherwise.
|
||||
// TODO: consider supporting BYS (-> BLR), CSK (-> 200 or CZ), PHI (-> PHL) and AFI (-> DJ).
|
||||
// TODO: consider mapping split up regions to new most populous one (like CLDR).
|
||||
func normRegion(r Region) Region {
|
||||
m := regionOldMap
|
||||
k := sort.Search(len(m), func(i int) bool {
|
||||
return m[i].From >= uint16(r)
|
||||
})
|
||||
if k < len(m) && m[k].From == uint16(r) {
|
||||
return Region(m[k].To)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const (
|
||||
iso3166UserAssigned = 1 << iota
|
||||
ccTLD
|
||||
bcp47Region
|
||||
)
|
||||
|
||||
func (r Region) typ() byte {
|
||||
return regionTypes[r]
|
||||
}
|
||||
|
||||
// String returns the BCP 47 representation for the region.
|
||||
// It returns "ZZ" for an unspecified region.
|
||||
func (r Region) String() string {
|
||||
if r < isoRegionOffset {
|
||||
if r == 0 {
|
||||
return "ZZ"
|
||||
}
|
||||
return fmt.Sprintf("%03d", r.M49())
|
||||
}
|
||||
r -= isoRegionOffset
|
||||
return regionISO.Elem(int(r))[:2]
|
||||
}
|
||||
|
||||
// ISO3 returns the 3-letter ISO code of r.
|
||||
// Note that not all regions have a 3-letter ISO code.
|
||||
// In such cases this method returns "ZZZ".
|
||||
func (r Region) ISO3() string {
|
||||
if r < isoRegionOffset {
|
||||
return "ZZZ"
|
||||
}
|
||||
r -= isoRegionOffset
|
||||
reg := regionISO.Elem(int(r))
|
||||
switch reg[2] {
|
||||
case 0:
|
||||
return altRegionISO3[reg[3]:][:3]
|
||||
case ' ':
|
||||
return "ZZZ"
|
||||
}
|
||||
return reg[0:1] + reg[2:4]
|
||||
}
|
||||
|
||||
// M49 returns the UN M.49 encoding of r, or 0 if this encoding
|
||||
// is not defined for r.
|
||||
func (r Region) M49() int {
|
||||
return int(m49[r])
|
||||
}
|
||||
|
||||
// IsPrivateUse reports whether r has the ISO 3166 User-assigned status. This
|
||||
// may include private-use tags that are assigned by CLDR and used in this
|
||||
// implementation. So IsPrivateUse and IsCountry can be simultaneously true.
|
||||
func (r Region) IsPrivateUse() bool {
|
||||
return r.typ()&iso3166UserAssigned != 0
|
||||
}
|
||||
|
||||
type Script uint8
|
||||
|
||||
// getScriptID returns the script id for string s. It assumes that s
|
||||
// is of the format [A-Z][a-z]{3}.
|
||||
func getScriptID(idx tag.Index, s []byte) (Script, error) {
|
||||
i, err := findIndex(idx, s, "Zzzz")
|
||||
return Script(i), err
|
||||
}
|
||||
|
||||
// String returns the script code in title case.
|
||||
// It returns "Zzzz" for an unspecified script.
|
||||
func (s Script) String() string {
|
||||
if s == 0 {
|
||||
return "Zzzz"
|
||||
}
|
||||
return script.Elem(int(s))
|
||||
}
|
||||
|
||||
// IsPrivateUse reports whether this script code is reserved for private use.
|
||||
func (s Script) IsPrivateUse() bool {
|
||||
return _Qaaa <= s && s <= _Qabx
|
||||
}
|
||||
|
||||
const (
|
||||
maxAltTaglen = len("en-US-POSIX")
|
||||
maxLen = maxAltTaglen
|
||||
)
|
||||
|
||||
var (
|
||||
// grandfatheredMap holds a mapping from legacy and grandfathered tags to
|
||||
// their base language or an index to a more elaborate tag.
|
||||
grandfatheredMap = map[[maxLen]byte]int16{
|
||||
[maxLen]byte{'a', 'r', 't', '-', 'l', 'o', 'j', 'b', 'a', 'n'}: _jbo, // art-lojban
|
||||
[maxLen]byte{'i', '-', 'a', 'm', 'i'}: _ami, // i-ami
|
||||
[maxLen]byte{'i', '-', 'b', 'n', 'n'}: _bnn, // i-bnn
|
||||
[maxLen]byte{'i', '-', 'h', 'a', 'k'}: _hak, // i-hak
|
||||
[maxLen]byte{'i', '-', 'k', 'l', 'i', 'n', 'g', 'o', 'n'}: _tlh, // i-klingon
|
||||
[maxLen]byte{'i', '-', 'l', 'u', 'x'}: _lb, // i-lux
|
||||
[maxLen]byte{'i', '-', 'n', 'a', 'v', 'a', 'j', 'o'}: _nv, // i-navajo
|
||||
[maxLen]byte{'i', '-', 'p', 'w', 'n'}: _pwn, // i-pwn
|
||||
[maxLen]byte{'i', '-', 't', 'a', 'o'}: _tao, // i-tao
|
||||
[maxLen]byte{'i', '-', 't', 'a', 'y'}: _tay, // i-tay
|
||||
[maxLen]byte{'i', '-', 't', 's', 'u'}: _tsu, // i-tsu
|
||||
[maxLen]byte{'n', 'o', '-', 'b', 'o', 'k'}: _nb, // no-bok
|
||||
[maxLen]byte{'n', 'o', '-', 'n', 'y', 'n'}: _nn, // no-nyn
|
||||
[maxLen]byte{'s', 'g', 'n', '-', 'b', 'e', '-', 'f', 'r'}: _sfb, // sgn-BE-FR
|
||||
[maxLen]byte{'s', 'g', 'n', '-', 'b', 'e', '-', 'n', 'l'}: _vgt, // sgn-BE-NL
|
||||
[maxLen]byte{'s', 'g', 'n', '-', 'c', 'h', '-', 'd', 'e'}: _sgg, // sgn-CH-DE
|
||||
[maxLen]byte{'z', 'h', '-', 'g', 'u', 'o', 'y', 'u'}: _cmn, // zh-guoyu
|
||||
[maxLen]byte{'z', 'h', '-', 'h', 'a', 'k', 'k', 'a'}: _hak, // zh-hakka
|
||||
[maxLen]byte{'z', 'h', '-', 'm', 'i', 'n', '-', 'n', 'a', 'n'}: _nan, // zh-min-nan
|
||||
[maxLen]byte{'z', 'h', '-', 'x', 'i', 'a', 'n', 'g'}: _hsn, // zh-xiang
|
||||
|
||||
// Grandfathered tags with no modern replacement will be converted as
|
||||
// follows:
|
||||
[maxLen]byte{'c', 'e', 'l', '-', 'g', 'a', 'u', 'l', 'i', 's', 'h'}: -1, // cel-gaulish
|
||||
[maxLen]byte{'e', 'n', '-', 'g', 'b', '-', 'o', 'e', 'd'}: -2, // en-GB-oed
|
||||
[maxLen]byte{'i', '-', 'd', 'e', 'f', 'a', 'u', 'l', 't'}: -3, // i-default
|
||||
[maxLen]byte{'i', '-', 'e', 'n', 'o', 'c', 'h', 'i', 'a', 'n'}: -4, // i-enochian
|
||||
[maxLen]byte{'i', '-', 'm', 'i', 'n', 'g', 'o'}: -5, // i-mingo
|
||||
[maxLen]byte{'z', 'h', '-', 'm', 'i', 'n'}: -6, // zh-min
|
||||
|
||||
// CLDR-specific tag.
|
||||
[maxLen]byte{'r', 'o', 'o', 't'}: 0, // root
|
||||
[maxLen]byte{'e', 'n', '-', 'u', 's', '-', 'p', 'o', 's', 'i', 'x'}: -7, // en_US_POSIX"
|
||||
}
|
||||
|
||||
altTagIndex = [...]uint8{0, 17, 31, 45, 61, 74, 86, 102}
|
||||
|
||||
altTags = "xtg-x-cel-gaulishen-GB-oxendicten-x-i-defaultund-x-i-enochiansee-x-i-mingonan-x-zh-minen-US-u-va-posix"
|
||||
)
|
||||
|
||||
func grandfathered(s [maxAltTaglen]byte) (t Tag, ok bool) {
|
||||
if v, ok := grandfatheredMap[s]; ok {
|
||||
if v < 0 {
|
||||
return Make(altTags[altTagIndex[-v-1]:altTagIndex[-v]]), true
|
||||
}
|
||||
t.LangID = Language(v)
|
||||
return t, true
|
||||
}
|
||||
return t, false
|
||||
}
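Illustrative note only: assuming the default canonicalization in the public package maps legacy tags as the table above suggests, parsing a grandfathered tag yields its modern replacement:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	// Legacy/grandfathered tags are rewritten to their modern equivalents.
	fmt.Println(language.Make("i-klingon")) // expected: tlh
	fmt.Println(language.Make("zh-hakka"))  // expected: hak
}
```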

@@ -0,0 +1,226 @@
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
import "errors"
|
||||
|
||||
type scriptRegionFlags uint8
|
||||
|
||||
const (
|
||||
isList = 1 << iota
|
||||
scriptInFrom
|
||||
regionInFrom
|
||||
)
|
||||
|
||||
func (t *Tag) setUndefinedLang(id Language) {
|
||||
if t.LangID == 0 {
|
||||
t.LangID = id
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tag) setUndefinedScript(id Script) {
|
||||
if t.ScriptID == 0 {
|
||||
t.ScriptID = id
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tag) setUndefinedRegion(id Region) {
|
||||
if t.RegionID == 0 || t.RegionID.Contains(id) {
|
||||
t.RegionID = id
|
||||
}
|
||||
}
|
||||
|
||||
// ErrMissingLikelyTagsData indicates no information was available
|
||||
// to compute likely values of missing tags.
|
||||
var ErrMissingLikelyTagsData = errors.New("missing likely tags data")
|
||||
|
||||
// addLikelySubtags sets subtags to their most likely value, given the locale.
|
||||
// In most cases this means setting fields for unknown values, but in some
|
||||
// cases it may alter a value. It returns an ErrMissingLikelyTagsData error
|
||||
// if the given locale cannot be expanded.
|
||||
func (t Tag) addLikelySubtags() (Tag, error) {
|
||||
id, err := addTags(t)
|
||||
if err != nil {
|
||||
return t, err
|
||||
} else if id.equalTags(t) {
|
||||
return t, nil
|
||||
}
|
||||
id.RemakeString()
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// specializeRegion attempts to specialize a group region.
|
||||
func specializeRegion(t *Tag) bool {
|
||||
if i := regionInclusion[t.RegionID]; i < nRegionGroups {
|
||||
x := likelyRegionGroup[i]
|
||||
if Language(x.lang) == t.LangID && Script(x.script) == t.ScriptID {
|
||||
t.RegionID = Region(x.region)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Maximize returns a new tag with missing tags filled in.
|
||||
func (t Tag) Maximize() (Tag, error) {
|
||||
return addTags(t)
|
||||
}
|
||||
|
||||
func addTags(t Tag) (Tag, error) {
|
||||
// We leave private use identifiers alone.
|
||||
if t.IsPrivateUse() {
|
||||
return t, nil
|
||||
}
|
||||
if t.ScriptID != 0 && t.RegionID != 0 {
|
||||
if t.LangID != 0 {
|
||||
// already fully specified
|
||||
specializeRegion(&t)
|
||||
return t, nil
|
||||
}
|
||||
// Search matches for und-script-region. Note that for these cases
|
||||
// region will never be a group so there is no need to check for this.
|
||||
list := likelyRegion[t.RegionID : t.RegionID+1]
|
||||
if x := list[0]; x.flags&isList != 0 {
|
||||
list = likelyRegionList[x.lang : x.lang+uint16(x.script)]
|
||||
}
|
||||
for _, x := range list {
|
||||
// Deviating from the spec. See match_test.go for details.
|
||||
if Script(x.script) == t.ScriptID {
|
||||
t.setUndefinedLang(Language(x.lang))
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.LangID != 0 {
|
||||
// Search matches for lang-script and lang-region, where lang != und.
|
||||
if t.LangID < langNoIndexOffset {
|
||||
x := likelyLang[t.LangID]
|
||||
if x.flags&isList != 0 {
|
||||
list := likelyLangList[x.region : x.region+uint16(x.script)]
|
||||
if t.ScriptID != 0 {
|
||||
for _, x := range list {
|
||||
if Script(x.script) == t.ScriptID && x.flags&scriptInFrom != 0 {
|
||||
t.setUndefinedRegion(Region(x.region))
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
} else if t.RegionID != 0 {
|
||||
count := 0
|
||||
goodScript := true
|
||||
tt := t
|
||||
for _, x := range list {
|
||||
// We visit all entries for which the script was not
|
||||
// defined, including the ones where the region was not
|
||||
// defined. This allows for proper disambiguation within
|
||||
// regions.
|
||||
if x.flags&scriptInFrom == 0 && t.RegionID.Contains(Region(x.region)) {
|
||||
tt.RegionID = Region(x.region)
|
||||
tt.setUndefinedScript(Script(x.script))
|
||||
goodScript = goodScript && tt.ScriptID == Script(x.script)
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count == 1 {
|
||||
return tt, nil
|
||||
}
|
||||
// Even if we fail to find a unique Region, we might have
|
||||
// an unambiguous script.
|
||||
if goodScript {
|
||||
t.ScriptID = tt.ScriptID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Search matches for und-script.
|
||||
if t.ScriptID != 0 {
|
||||
x := likelyScript[t.ScriptID]
|
||||
if x.region != 0 {
|
||||
t.setUndefinedRegion(Region(x.region))
|
||||
t.setUndefinedLang(Language(x.lang))
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
// Search matches for und-region. If und-script-region exists, it would
|
||||
// have been found earlier.
|
||||
if t.RegionID != 0 {
|
||||
if i := regionInclusion[t.RegionID]; i < nRegionGroups {
|
||||
x := likelyRegionGroup[i]
|
||||
if x.region != 0 {
|
||||
t.setUndefinedLang(Language(x.lang))
|
||||
t.setUndefinedScript(Script(x.script))
|
||||
t.RegionID = Region(x.region)
|
||||
}
|
||||
} else {
|
||||
x := likelyRegion[t.RegionID]
|
||||
if x.flags&isList != 0 {
|
||||
x = likelyRegionList[x.lang]
|
||||
}
|
||||
if x.script != 0 && x.flags != scriptInFrom {
|
||||
t.setUndefinedLang(Language(x.lang))
|
||||
t.setUndefinedScript(Script(x.script))
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Search matches for lang.
|
||||
if t.LangID < langNoIndexOffset {
|
||||
x := likelyLang[t.LangID]
|
||||
if x.flags&isList != 0 {
|
||||
x = likelyLangList[x.region]
|
||||
}
|
||||
if x.region != 0 {
|
||||
t.setUndefinedScript(Script(x.script))
|
||||
t.setUndefinedRegion(Region(x.region))
|
||||
}
|
||||
specializeRegion(&t)
|
||||
if t.LangID == 0 {
|
||||
t.LangID = _en // default language
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
return t, ErrMissingLikelyTagsData
|
||||
}
|
||||
|
||||
func (t *Tag) setTagsFrom(id Tag) {
|
||||
t.LangID = id.LangID
|
||||
t.ScriptID = id.ScriptID
|
||||
t.RegionID = id.RegionID
|
||||
}
|
||||
|
||||
// minimize removes the region or script subtags from t such that
|
||||
// t.addLikelySubtags() == t.minimize().addLikelySubtags().
|
||||
func (t Tag) minimize() (Tag, error) {
|
||||
t, err := minimizeTags(t)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
t.RemakeString()
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// minimizeTags mimics the behavior of the ICU 51 C implementation.
|
||||
func minimizeTags(t Tag) (Tag, error) {
|
||||
if t.equalTags(Und) {
|
||||
return t, nil
|
||||
}
|
||||
max, err := addTags(t)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
for _, id := range [...]Tag{
|
||||
{LangID: t.LangID},
|
||||
{LangID: t.LangID, RegionID: t.RegionID},
|
||||
{LangID: t.LangID, ScriptID: t.ScriptID},
|
||||
} {
|
||||
if x, err := addTags(id); err == nil && max.equalTags(x) {
|
||||
t.setTagsFrom(id)
|
||||
break
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
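The comment on minimize states the contract that minimizing a tag never changes its maximized form. A hedged in-package sketch of checking that invariant, written as it might appear in a _test.go file of this package (the input tag is illustrative only; Parse, minimizeTags, addTags and equalTags are the functions defined in this file and above).

```go
package language

import "testing"

// Hypothetical in-package test: the maximized form is unchanged by minimizing.
func TestMinimizeKeepsMaximizedForm(t *testing.T) {
	in, err := Parse("zh-Hans-CN") // illustrative input
	if err != nil {
		t.Fatal(err)
	}
	min, err := minimizeTags(in)
	if err != nil {
		t.Fatal(err)
	}
	max1, err1 := addTags(in)
	max2, err2 := addTags(min)
	if err1 != nil || err2 != nil {
		t.Fatal(err1, err2)
	}
	if !max1.equalTags(max2) {
		t.Errorf("maximized forms differ: %v vs %v", max1, max2)
	}
}
```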
|
|
@ -0,0 +1,594 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/text/internal/tag"
|
||||
)
|
||||
|
||||
// isAlpha returns true if the byte is not a digit.
|
||||
// b must be an ASCII letter or digit.
|
||||
func isAlpha(b byte) bool {
|
||||
return b > '9'
|
||||
}
|
||||
|
||||
// isAlphaNum returns true if the string contains only ASCII letters or digits.
|
||||
func isAlphaNum(s []byte) bool {
|
||||
for _, c := range s {
|
||||
if !('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ErrSyntax is returned by any of the parsing functions when the
|
||||
// input is not well-formed, according to BCP 47.
|
||||
// TODO: return the position at which the syntax error occurred?
|
||||
var ErrSyntax = errors.New("language: tag is not well-formed")
|
||||
|
||||
// ErrDuplicateKey is returned when a tag contains the same key twice with
|
||||
// different values in the -u section.
|
||||
var ErrDuplicateKey = errors.New("language: different values for same key in -u extension")
|
||||
|
||||
// ValueError is returned by any of the parsing functions when the
|
||||
// input is well-formed but the respective subtag is not recognized
|
||||
// as a valid value.
|
||||
type ValueError struct {
|
||||
v [8]byte
|
||||
}
|
||||
|
||||
// NewValueError creates a new ValueError.
|
||||
func NewValueError(tag []byte) ValueError {
|
||||
var e ValueError
|
||||
copy(e.v[:], tag)
|
||||
return e
|
||||
}
|
||||
|
||||
func (e ValueError) tag() []byte {
|
||||
n := bytes.IndexByte(e.v[:], 0)
|
||||
if n == -1 {
|
||||
n = 8
|
||||
}
|
||||
return e.v[:n]
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ValueError) Error() string {
|
||||
return fmt.Sprintf("language: subtag %q is well-formed but unknown", e.tag())
|
||||
}
|
||||
|
||||
// Subtag returns the subtag for which the error occurred.
|
||||
func (e ValueError) Subtag() string {
|
||||
return string(e.tag())
|
||||
}
|
||||
|
||||
// scanner is used to scan BCP 47 tokens, which are separated by _ or -.
|
||||
type scanner struct {
|
||||
b []byte
|
||||
bytes [max99thPercentileSize]byte
|
||||
token []byte
|
||||
start int // start position of the current token
|
||||
end int // end position of the current token
|
||||
next int // next point for scan
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
func makeScannerString(s string) scanner {
|
||||
scan := scanner{}
|
||||
if len(s) <= len(scan.bytes) {
|
||||
scan.b = scan.bytes[:copy(scan.bytes[:], s)]
|
||||
} else {
|
||||
scan.b = []byte(s)
|
||||
}
|
||||
scan.init()
|
||||
return scan
|
||||
}
|
||||
|
||||
// makeScanner returns a scanner using b as the input buffer.
|
||||
// b is not copied and may be modified by the scanner routines.
|
||||
func makeScanner(b []byte) scanner {
|
||||
scan := scanner{b: b}
|
||||
scan.init()
|
||||
return scan
|
||||
}
|
||||
|
||||
func (s *scanner) init() {
|
||||
for i, c := range s.b {
|
||||
if c == '_' {
|
||||
s.b[i] = '-'
|
||||
}
|
||||
}
|
||||
s.scan()
|
||||
}
|
||||
|
||||
// toLower converts the string between start and end to lower case.
|
||||
func (s *scanner) toLower(start, end int) {
|
||||
for i := start; i < end; i++ {
|
||||
c := s.b[i]
|
||||
if 'A' <= c && c <= 'Z' {
|
||||
s.b[i] += 'a' - 'A'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *scanner) setError(e error) {
|
||||
if s.err == nil || (e == ErrSyntax && s.err != ErrSyntax) {
|
||||
s.err = e
|
||||
}
|
||||
}
|
||||
|
||||
// resizeRange shrinks or grows the array at position oldStart such that
|
||||
// a new string of size newSize can fit between oldStart and oldEnd.
|
||||
// Sets the scan point to after the resized range.
|
||||
func (s *scanner) resizeRange(oldStart, oldEnd, newSize int) {
|
||||
s.start = oldStart
|
||||
if end := oldStart + newSize; end != oldEnd {
|
||||
diff := end - oldEnd
|
||||
if end < cap(s.b) {
|
||||
b := make([]byte, len(s.b)+diff)
|
||||
copy(b, s.b[:oldStart])
|
||||
copy(b[end:], s.b[oldEnd:])
|
||||
s.b = b
|
||||
} else {
|
||||
s.b = append(s.b[end:], s.b[oldEnd:]...)
|
||||
}
|
||||
s.next = end + (s.next - s.end)
|
||||
s.end = end
|
||||
}
|
||||
}
|
||||
|
||||
// replace replaces the current token with repl.
|
||||
func (s *scanner) replace(repl string) {
|
||||
s.resizeRange(s.start, s.end, len(repl))
|
||||
copy(s.b[s.start:], repl)
|
||||
}
|
||||
|
||||
// gobble removes the current token from the input.
|
||||
// Caller must call scan after calling gobble.
|
||||
func (s *scanner) gobble(e error) {
|
||||
s.setError(e)
|
||||
if s.start == 0 {
|
||||
s.b = s.b[:+copy(s.b, s.b[s.next:])]
|
||||
s.end = 0
|
||||
} else {
|
||||
s.b = s.b[:s.start-1+copy(s.b[s.start-1:], s.b[s.end:])]
|
||||
s.end = s.start - 1
|
||||
}
|
||||
s.next = s.start
|
||||
}
|
||||
|
||||
// deleteRange removes the given range from s.b before the current token.
|
||||
func (s *scanner) deleteRange(start, end int) {
|
||||
s.b = s.b[:start+copy(s.b[start:], s.b[end:])]
|
||||
diff := end - start
|
||||
s.next -= diff
|
||||
s.start -= diff
|
||||
s.end -= diff
|
||||
}
|
||||
|
||||
// scan parses the next token of a BCP 47 string. Tokens that are larger
|
||||
// than 8 characters or include non-alphanumeric characters result in an error
|
||||
// and are gobbled and removed from the output.
|
||||
// It returns the end position of the last token consumed.
|
||||
func (s *scanner) scan() (end int) {
|
||||
end = s.end
|
||||
s.token = nil
|
||||
for s.start = s.next; s.next < len(s.b); {
|
||||
i := bytes.IndexByte(s.b[s.next:], '-')
|
||||
if i == -1 {
|
||||
s.end = len(s.b)
|
||||
s.next = len(s.b)
|
||||
i = s.end - s.start
|
||||
} else {
|
||||
s.end = s.next + i
|
||||
s.next = s.end + 1
|
||||
}
|
||||
token := s.b[s.start:s.end]
|
||||
if i < 1 || i > 8 || !isAlphaNum(token) {
|
||||
s.gobble(ErrSyntax)
|
||||
continue
|
||||
}
|
||||
s.token = token
|
||||
return end
|
||||
}
|
||||
if n := len(s.b); n > 0 && s.b[n-1] == '-' {
|
||||
s.setError(ErrSyntax)
|
||||
s.b = s.b[:len(s.b)-1]
|
||||
}
|
||||
s.done = true
|
||||
return end
|
||||
}
|
||||
|
||||
// acceptMinSize parses multiple tokens of the given size or greater.
|
||||
// It returns the end position of the last token consumed.
|
||||
func (s *scanner) acceptMinSize(min int) (end int) {
|
||||
end = s.end
|
||||
s.scan()
|
||||
for ; len(s.token) >= min; s.scan() {
|
||||
end = s.end
|
||||
}
|
||||
return end
|
||||
}
|
||||
|
||||
// Parse parses the given BCP 47 string and returns a valid Tag. If parsing
|
||||
// failed it returns an error and any part of the tag that could be parsed.
|
||||
// If parsing succeeded but an unknown value was found, it returns
|
||||
// ValueError. The Tag returned in this case is just stripped of the unknown
|
||||
// value. All other values are preserved. It accepts tags in the BCP 47 format
|
||||
// and extensions to this standard defined in
|
||||
// https://www.unicode.org/reports/tr35/#Unicode_Language_and_Locale_Identifiers.
|
||||
func Parse(s string) (t Tag, err error) {
|
||||
// TODO: consider supporting old-style locale key-value pairs.
|
||||
if s == "" {
|
||||
return Und, ErrSyntax
|
||||
}
|
||||
if len(s) <= maxAltTaglen {
|
||||
b := [maxAltTaglen]byte{}
|
||||
for i, c := range s {
|
||||
// Generating invalid UTF-8 is okay as it won't match.
|
||||
if 'A' <= c && c <= 'Z' {
|
||||
c += 'a' - 'A'
|
||||
} else if c == '_' {
|
||||
c = '-'
|
||||
}
|
||||
b[i] = byte(c)
|
||||
}
|
||||
if t, ok := grandfathered(b); ok {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
scan := makeScannerString(s)
|
||||
return parse(&scan, s)
|
||||
}
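Parse distinguishes malformed input (ErrSyntax) from well-formed but unknown subtags, which produce a ValueError and a tag stripped of the unknown value. A small sketch of the same behavior through the public golang.org/x/text/language package, which wraps this parser; the sample inputs are assumptions chosen for illustration.

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	// "en-US" parses cleanly; "en-abcde" is well-formed but carries an
	// unknown variant, so a usable tag is returned alongside the error;
	// "e" is not well-formed at all.
	for _, s := range []string{"en-US", "en-abcde", "e"} {
		t, err := language.Parse(s)
		fmt.Printf("%-10s -> %v (err: %v)\n", s, t, err)
	}
}
```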
|
||||
|
||||
func parse(scan *scanner, s string) (t Tag, err error) {
|
||||
t = Und
|
||||
var end int
|
||||
if n := len(scan.token); n <= 1 {
|
||||
scan.toLower(0, len(scan.b))
|
||||
if n == 0 || scan.token[0] != 'x' {
|
||||
return t, ErrSyntax
|
||||
}
|
||||
end = parseExtensions(scan)
|
||||
} else if n >= 4 {
|
||||
return Und, ErrSyntax
|
||||
} else { // the usual case
|
||||
t, end = parseTag(scan)
|
||||
if n := len(scan.token); n == 1 {
|
||||
t.pExt = uint16(end)
|
||||
end = parseExtensions(scan)
|
||||
} else if end < len(scan.b) {
|
||||
scan.setError(ErrSyntax)
|
||||
scan.b = scan.b[:end]
|
||||
}
|
||||
}
|
||||
if int(t.pVariant) < len(scan.b) {
|
||||
if end < len(s) {
|
||||
s = s[:end]
|
||||
}
|
||||
if len(s) > 0 && tag.Compare(s, scan.b) == 0 {
|
||||
t.str = s
|
||||
} else {
|
||||
t.str = string(scan.b)
|
||||
}
|
||||
} else {
|
||||
t.pVariant, t.pExt = 0, 0
|
||||
}
|
||||
return t, scan.err
|
||||
}
|
||||
|
||||
// parseTag parses language, script, region and variants.
|
||||
// It returns a Tag and the end position in the input that was parsed.
|
||||
func parseTag(scan *scanner) (t Tag, end int) {
|
||||
var e error
|
||||
// TODO: set an error if an unknown lang, script or region is encountered.
|
||||
t.LangID, e = getLangID(scan.token)
|
||||
scan.setError(e)
|
||||
scan.replace(t.LangID.String())
|
||||
langStart := scan.start
|
||||
end = scan.scan()
|
||||
for len(scan.token) == 3 && isAlpha(scan.token[0]) {
|
||||
// From http://tools.ietf.org/html/bcp47, <lang>-<extlang> tags are equivalent
|
||||
// to a tag of the form <extlang>.
|
||||
lang, e := getLangID(scan.token)
|
||||
if lang != 0 {
|
||||
t.LangID = lang
|
||||
copy(scan.b[langStart:], lang.String())
|
||||
scan.b[langStart+3] = '-'
|
||||
scan.start = langStart + 4
|
||||
}
|
||||
scan.gobble(e)
|
||||
end = scan.scan()
|
||||
}
|
||||
if len(scan.token) == 4 && isAlpha(scan.token[0]) {
|
||||
t.ScriptID, e = getScriptID(script, scan.token)
|
||||
if t.ScriptID == 0 {
|
||||
scan.gobble(e)
|
||||
}
|
||||
end = scan.scan()
|
||||
}
|
||||
if n := len(scan.token); n >= 2 && n <= 3 {
|
||||
t.RegionID, e = getRegionID(scan.token)
|
||||
if t.RegionID == 0 {
|
||||
scan.gobble(e)
|
||||
} else {
|
||||
scan.replace(t.RegionID.String())
|
||||
}
|
||||
end = scan.scan()
|
||||
}
|
||||
scan.toLower(scan.start, len(scan.b))
|
||||
t.pVariant = byte(end)
|
||||
end = parseVariants(scan, end, t)
|
||||
t.pExt = uint16(end)
|
||||
return t, end
|
||||
}
|
||||
|
||||
var separator = []byte{'-'}
|
||||
|
||||
// parseVariants scans tokens as long as each token is a valid variant string.
|
||||
// Duplicate variants are removed.
|
||||
func parseVariants(scan *scanner, end int, t Tag) int {
|
||||
start := scan.start
|
||||
varIDBuf := [4]uint8{}
|
||||
variantBuf := [4][]byte{}
|
||||
varID := varIDBuf[:0]
|
||||
variant := variantBuf[:0]
|
||||
last := -1
|
||||
needSort := false
|
||||
for ; len(scan.token) >= 4; scan.scan() {
|
||||
// TODO: measure the impact of needing this conversion and redesign
|
||||
// the data structure if there is an issue.
|
||||
v, ok := variantIndex[string(scan.token)]
|
||||
if !ok {
|
||||
// unknown variant
|
||||
// TODO: allow user-defined variants?
|
||||
scan.gobble(NewValueError(scan.token))
|
||||
continue
|
||||
}
|
||||
varID = append(varID, v)
|
||||
variant = append(variant, scan.token)
|
||||
if !needSort {
|
||||
if last < int(v) {
|
||||
last = int(v)
|
||||
} else {
|
||||
needSort = true
|
||||
// There are no legal combinations of more than 7 variants
|
||||
// (and this is by no means a useful sequence).
|
||||
const maxVariants = 8
|
||||
if len(varID) > maxVariants {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
end = scan.end
|
||||
}
|
||||
if needSort {
|
||||
sort.Sort(variantsSort{varID, variant})
|
||||
k, l := 0, -1
|
||||
for i, v := range varID {
|
||||
w := int(v)
|
||||
if l == w {
|
||||
// Remove duplicates.
|
||||
continue
|
||||
}
|
||||
varID[k] = varID[i]
|
||||
variant[k] = variant[i]
|
||||
k++
|
||||
l = w
|
||||
}
|
||||
if str := bytes.Join(variant[:k], separator); len(str) == 0 {
|
||||
end = start - 1
|
||||
} else {
|
||||
scan.resizeRange(start, end, len(str))
|
||||
copy(scan.b[scan.start:], str)
|
||||
end = scan.end
|
||||
}
|
||||
}
|
||||
return end
|
||||
}
|
||||
|
||||
type variantsSort struct {
|
||||
i []uint8
|
||||
v [][]byte
|
||||
}
|
||||
|
||||
func (s variantsSort) Len() int {
|
||||
return len(s.i)
|
||||
}
|
||||
|
||||
func (s variantsSort) Swap(i, j int) {
|
||||
s.i[i], s.i[j] = s.i[j], s.i[i]
|
||||
s.v[i], s.v[j] = s.v[j], s.v[i]
|
||||
}
|
||||
|
||||
func (s variantsSort) Less(i, j int) bool {
|
||||
return s.i[i] < s.i[j]
|
||||
}
|
||||
|
||||
type bytesSort struct {
|
||||
b [][]byte
|
||||
n int // first n bytes to compare
|
||||
}
|
||||
|
||||
func (b bytesSort) Len() int {
|
||||
return len(b.b)
|
||||
}
|
||||
|
||||
func (b bytesSort) Swap(i, j int) {
|
||||
b.b[i], b.b[j] = b.b[j], b.b[i]
|
||||
}
|
||||
|
||||
func (b bytesSort) Less(i, j int) bool {
|
||||
for k := 0; k < b.n; k++ {
|
||||
if b.b[i][k] == b.b[j][k] {
|
||||
continue
|
||||
}
|
||||
return b.b[i][k] < b.b[j][k]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseExtensions parses and normalizes the extensions in the buffer.
|
||||
// It returns the last position of scan.b that is part of any extension.
|
||||
// It also trims scan.b to remove excess parts accordingly.
|
||||
func parseExtensions(scan *scanner) int {
|
||||
start := scan.start
|
||||
exts := [][]byte{}
|
||||
private := []byte{}
|
||||
end := scan.end
|
||||
for len(scan.token) == 1 {
|
||||
extStart := scan.start
|
||||
ext := scan.token[0]
|
||||
end = parseExtension(scan)
|
||||
extension := scan.b[extStart:end]
|
||||
if len(extension) < 3 || (ext != 'x' && len(extension) < 4) {
|
||||
scan.setError(ErrSyntax)
|
||||
end = extStart
|
||||
continue
|
||||
} else if start == extStart && (ext == 'x' || scan.start == len(scan.b)) {
|
||||
scan.b = scan.b[:end]
|
||||
return end
|
||||
} else if ext == 'x' {
|
||||
private = extension
|
||||
break
|
||||
}
|
||||
exts = append(exts, extension)
|
||||
}
|
||||
sort.Sort(bytesSort{exts, 1})
|
||||
if len(private) > 0 {
|
||||
exts = append(exts, private)
|
||||
}
|
||||
scan.b = scan.b[:start]
|
||||
if len(exts) > 0 {
|
||||
scan.b = append(scan.b, bytes.Join(exts, separator)...)
|
||||
} else if start > 0 {
|
||||
// Strip trailing '-'.
|
||||
scan.b = scan.b[:start-1]
|
||||
}
|
||||
return end
|
||||
}
|
||||
|
||||
// parseExtension parses a single extension and returns the position of
|
||||
// the extension end.
|
||||
func parseExtension(scan *scanner) int {
|
||||
start, end := scan.start, scan.end
|
||||
switch scan.token[0] {
|
||||
case 'u':
|
||||
attrStart := end
|
||||
scan.scan()
|
||||
for last := []byte{}; len(scan.token) > 2; scan.scan() {
|
||||
if bytes.Compare(scan.token, last) != -1 {
|
||||
// Attributes are unsorted. Start over from scratch.
|
||||
p := attrStart + 1
|
||||
scan.next = p
|
||||
attrs := [][]byte{}
|
||||
for scan.scan(); len(scan.token) > 2; scan.scan() {
|
||||
attrs = append(attrs, scan.token)
|
||||
end = scan.end
|
||||
}
|
||||
sort.Sort(bytesSort{attrs, 3})
|
||||
copy(scan.b[p:], bytes.Join(attrs, separator))
|
||||
break
|
||||
}
|
||||
last = scan.token
|
||||
end = scan.end
|
||||
}
|
||||
var last, key []byte
|
||||
for attrEnd := end; len(scan.token) == 2; last = key {
|
||||
key = scan.token
|
||||
keyEnd := scan.end
|
||||
end = scan.acceptMinSize(3)
|
||||
// TODO: check key value validity
|
||||
if keyEnd == end || bytes.Compare(key, last) != 1 {
|
||||
// We have an invalid key or the keys are not sorted.
|
||||
// Start scanning keys from scratch and reorder.
|
||||
p := attrEnd + 1
|
||||
scan.next = p
|
||||
keys := [][]byte{}
|
||||
for scan.scan(); len(scan.token) == 2; {
|
||||
keyStart, keyEnd := scan.start, scan.end
|
||||
end = scan.acceptMinSize(3)
|
||||
if keyEnd != end {
|
||||
keys = append(keys, scan.b[keyStart:end])
|
||||
} else {
|
||||
scan.setError(ErrSyntax)
|
||||
end = keyStart
|
||||
}
|
||||
}
|
||||
sort.Stable(bytesSort{keys, 2})
|
||||
if n := len(keys); n > 0 {
|
||||
k := 0
|
||||
for i := 1; i < n; i++ {
|
||||
if !bytes.Equal(keys[k][:2], keys[i][:2]) {
|
||||
k++
|
||||
keys[k] = keys[i]
|
||||
} else if !bytes.Equal(keys[k], keys[i]) {
|
||||
scan.setError(ErrDuplicateKey)
|
||||
}
|
||||
}
|
||||
keys = keys[:k+1]
|
||||
}
|
||||
reordered := bytes.Join(keys, separator)
|
||||
if e := p + len(reordered); e < end {
|
||||
scan.deleteRange(e, end)
|
||||
end = e
|
||||
}
|
||||
copy(scan.b[p:], reordered)
|
||||
break
|
||||
}
|
||||
}
|
||||
case 't':
|
||||
scan.scan()
|
||||
if n := len(scan.token); n >= 2 && n <= 3 && isAlpha(scan.token[1]) {
|
||||
_, end = parseTag(scan)
|
||||
scan.toLower(start, end)
|
||||
}
|
||||
for len(scan.token) == 2 && !isAlpha(scan.token[1]) {
|
||||
end = scan.acceptMinSize(3)
|
||||
}
|
||||
case 'x':
|
||||
end = scan.acceptMinSize(1)
|
||||
default:
|
||||
end = scan.acceptMinSize(2)
|
||||
}
|
||||
return end
|
||||
}
|
||||
|
||||
// getExtension returns the name, body and end position of the extension.
|
||||
func getExtension(s string, p int) (end int, ext string) {
|
||||
if s[p] == '-' {
|
||||
p++
|
||||
}
|
||||
if s[p] == 'x' {
|
||||
return len(s), s[p:]
|
||||
}
|
||||
end = nextExtension(s, p)
|
||||
return end, s[p:end]
|
||||
}
|
||||
|
||||
// nextExtension finds the next extension within the string, searching
|
||||
// for the -<char>- pattern from position p.
|
||||
// In the vast majority of cases, language tags will have at most
|
||||
// one extension and extensions tend to be small.
|
||||
func nextExtension(s string, p int) int {
|
||||
for n := len(s) - 3; p < n; {
|
||||
if s[p] == '-' {
|
||||
if s[p+2] == '-' {
|
||||
return p
|
||||
}
|
||||
p += 3
|
||||
} else {
|
||||
p++
|
||||
}
|
||||
}
|
||||
return len(s)
|
||||
}
|
File diff suppressed because it is too large
|
@ -0,0 +1,48 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
// MustParse is like Parse, but panics if the given BCP 47 tag cannot be parsed.
|
||||
// It simplifies safe initialization of Tag values.
|
||||
func MustParse(s string) Tag {
|
||||
t, err := Parse(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
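As the doc comment says, MustParse exists for initializing tag values from literals that are known to be valid. A short usage sketch with the public golang.org/x/text/language package, which exposes the same helper; panicking is acceptable here because the strings are fixed at compile time.

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

// Package-level tags built from fixed literals; MustParse panics on a bad literal.
var (
	enGB = language.MustParse("en-GB")
	ptBR = language.MustParse("pt-BR")
)

func main() {
	fmt.Println(enGB, ptBR)
}
```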
|
||||
|
||||
// MustParseBase is like ParseBase, but panics if the given base cannot be parsed.
|
||||
// It simplifies safe initialization of Base values.
|
||||
func MustParseBase(s string) Language {
|
||||
b, err := ParseBase(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// MustParseScript is like ParseScript, but panics if the given script cannot be
|
||||
// parsed. It simplifies safe initialization of Script values.
|
||||
func MustParseScript(s string) Script {
|
||||
scr, err := ParseScript(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return scr
|
||||
}
|
||||
|
||||
// MustParseRegion is like ParseRegion, but panics if the given region cannot be
|
||||
// parsed. It simplifies safe initialization of Region values.
|
||||
func MustParseRegion(s string) Region {
|
||||
r, err := ParseRegion(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Und is the root language.
|
||||
var Und Tag
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tag contains functionality handling tags and related data.
|
||||
package tag // import "golang.org/x/text/internal/tag"
|
||||
|
||||
import "sort"
|
||||
|
||||
// An Index converts tags to a compact numeric value.
|
||||
//
|
||||
// All elements are of size 4. Tags may be up to 4 bytes long. Excess bytes can
|
||||
// be used to store additional information about the tag.
|
||||
type Index string
|
||||
|
||||
// Elem returns the element data at the given index.
|
||||
func (s Index) Elem(x int) string {
|
||||
return string(s[x*4 : x*4+4])
|
||||
}
|
||||
|
||||
// Index reports the index of the given key or -1 if it could not be found.
|
||||
// Only the first len(key) bytes from the start of the 4-byte entries will be
|
||||
// considered for the search and the first match in Index will be returned.
|
||||
func (s Index) Index(key []byte) int {
|
||||
n := len(key)
|
||||
// search the index of the first entry with an equal or higher value than
|
||||
// key in s.
|
||||
index := sort.Search(len(s)/4, func(i int) bool {
|
||||
return cmp(s[i*4:i*4+n], key) != -1
|
||||
})
|
||||
i := index * 4
|
||||
if cmp(s[i:i+len(key)], key) != 0 {
|
||||
return -1
|
||||
}
|
||||
return index
|
||||
}
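Index packs sorted 4-byte entries into one string and binary-searches on the key prefix, as the doc comment above describes. Since this is an internal package, here is a standalone sketch of the same technique with a made-up index string; it illustrates the lookup idea only and is not the package's actual data or padding scheme.

```go
package main

import (
	"fmt"
	"sort"
)

// idx holds sorted 4-byte entries; short keys are padded (here with spaces).
const idx = "de  en  fr  nl  "

// lookup returns the entry number of key, or -1, mirroring Index.Index.
func lookup(key string) int {
	n := sort.Search(len(idx)/4, func(i int) bool {
		return idx[i*4:i*4+len(key)] >= key
	})
	if n*4 < len(idx) && idx[n*4:n*4+len(key)] == key {
		return n
	}
	return -1
}

func main() {
	fmt.Println(lookup("fr"), lookup("xx")) // 2 -1
}
```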
|
||||
|
||||
// Next finds the next occurrence of key after index x, which must have been
|
||||
// obtained from a call to Index using the same key. It returns x+1 or -1.
|
||||
func (s Index) Next(key []byte, x int) int {
|
||||
if x++; x*4 < len(s) && cmp(s[x*4:x*4+len(key)], key) == 0 {
|
||||
return x
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// cmp returns an integer comparing a and b lexicographically.
|
||||
func cmp(a Index, b []byte) int {
|
||||
n := len(a)
|
||||
if len(b) < n {
|
||||
n = len(b)
|
||||
}
|
||||
for i, c := range b[:n] {
|
||||
switch {
|
||||
case a[i] > c:
|
||||
return 1
|
||||
case a[i] < c:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case len(a) < len(b):
|
||||
return -1
|
||||
case len(a) > len(b):
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Compare returns an integer comparing a and b lexicographically.
|
||||
func Compare(a string, b []byte) int {
|
||||
return cmp(Index(a), b)
|
||||
}
|
||||
|
||||
// FixCase reformats b to the same pattern of cases as form.
|
||||
// It returns false if string b is malformed.
|
||||
func FixCase(form string, b []byte) bool {
|
||||
if len(form) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, c := range b {
|
||||
if form[i] <= 'Z' {
|
||||
if c >= 'a' {
|
||||
c -= 'z' - 'Z'
|
||||
}
|
||||
if c < 'A' || 'Z' < c {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if c <= 'Z' {
|
||||
c += 'z' - 'Z'
|
||||
}
|
||||
if c < 'a' || 'z' < c {
|
||||
return false
|
||||
}
|
||||
}
|
||||
b[i] = c
|
||||
}
|
||||
return true
|
||||
}
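FixCase forces b into the exact upper/lower pattern of form and rejects any byte that cannot be case-mapped to a letter, which is how script and region subtags get their canonical casing. A hedged in-package sketch, written as a hypothetical test in this package; the inputs are illustrative.

```go
package tag

import "testing"

// Hypothetical in-package test illustrating FixCase.
func TestFixCaseSketch(t *testing.T) {
	b := []byte("lATN")
	if !FixCase("Latn", b) || string(b) != "Latn" {
		t.Errorf(`FixCase("Latn", "lATN") = %q, want "Latn"`, b)
	}
	if FixCase("Latn", []byte("l4TN")) { // '4' cannot be case-mapped to a letter
		t.Error("FixCase accepted a non-letter byte")
	}
}
```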
|
|
@ -0,0 +1,187 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/text/internal/language"
|
||||
)
|
||||
|
||||
// The Coverage interface is used to define the level of coverage of an
|
||||
// internationalization service. Note that not all types are supported by all
|
||||
// services. As lists may be generated on the fly, it is recommended that users
|
||||
// of a Coverage cache the results.
|
||||
type Coverage interface {
|
||||
// Tags returns the list of supported tags.
|
||||
Tags() []Tag
|
||||
|
||||
// BaseLanguages returns the list of supported base languages.
|
||||
BaseLanguages() []Base
|
||||
|
||||
// Scripts returns the list of supported scripts.
|
||||
Scripts() []Script
|
||||
|
||||
// Regions returns the list of supported regions.
|
||||
Regions() []Region
|
||||
}
|
||||
|
||||
var (
|
||||
// Supported defines a Coverage that lists all supported subtags. Tags
|
||||
// always returns nil.
|
||||
Supported Coverage = allSubtags{}
|
||||
)
|
||||
|
||||
// TODO:
|
||||
// - Support Variants, numbering systems.
|
||||
// - CLDR coverage levels.
|
||||
// - Set of common tags defined in this package.
|
||||
|
||||
type allSubtags struct{}
|
||||
|
||||
// Regions returns the list of supported regions. As all regions are in a
|
||||
// consecutive range, it simply returns a slice of numbers in increasing order.
|
||||
// The "undefined" region is not returned.
|
||||
func (s allSubtags) Regions() []Region {
|
||||
reg := make([]Region, language.NumRegions)
|
||||
for i := range reg {
|
||||
reg[i] = Region{language.Region(i + 1)}
|
||||
}
|
||||
return reg
|
||||
}
|
||||
|
||||
// Scripts returns the list of supported scripts. As all scripts are in a
|
||||
// consecutive range, it simply returns a slice of numbers in increasing order.
|
||||
// The "undefined" script is not returned.
|
||||
func (s allSubtags) Scripts() []Script {
|
||||
scr := make([]Script, language.NumScripts)
|
||||
for i := range scr {
|
||||
scr[i] = Script{language.Script(i + 1)}
|
||||
}
|
||||
return scr
|
||||
}
|
||||
|
||||
// BaseLanguages returns the list of all supported base languages. It generates
|
||||
// the list by traversing the internal structures.
|
||||
func (s allSubtags) BaseLanguages() []Base {
|
||||
bs := language.BaseLanguages()
|
||||
base := make([]Base, len(bs))
|
||||
for i, b := range bs {
|
||||
base[i] = Base{b}
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// Tags always returns nil.
|
||||
func (s allSubtags) Tags() []Tag {
|
||||
return nil
|
||||
}
|
||||
|
||||
// coverage is used by NewCoverage which is used as a convenient way for
|
||||
// creating Coverage implementations for partially defined data. Very often a
|
||||
// package will only need to define a subset of slices. coverage provides a
|
||||
// convenient way to do this. Moreover, packages using NewCoverage, instead of
|
||||
// their own implementation, will not break if later new slice types are added.
|
||||
type coverage struct {
|
||||
tags func() []Tag
|
||||
bases func() []Base
|
||||
scripts func() []Script
|
||||
regions func() []Region
|
||||
}
|
||||
|
||||
func (s *coverage) Tags() []Tag {
|
||||
if s.tags == nil {
|
||||
return nil
|
||||
}
|
||||
return s.tags()
|
||||
}
|
||||
|
||||
// bases implements sort.Interface and is used to sort base languages.
|
||||
type bases []Base
|
||||
|
||||
func (b bases) Len() int {
|
||||
return len(b)
|
||||
}
|
||||
|
||||
func (b bases) Swap(i, j int) {
|
||||
b[i], b[j] = b[j], b[i]
|
||||
}
|
||||
|
||||
func (b bases) Less(i, j int) bool {
|
||||
return b[i].langID < b[j].langID
|
||||
}
|
||||
|
||||
// BaseLanguages returns the result from calling s.bases if it is specified or
|
||||
// otherwise derives the set of supported base languages from tags.
|
||||
func (s *coverage) BaseLanguages() []Base {
|
||||
if s.bases == nil {
|
||||
tags := s.Tags()
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
}
|
||||
a := make([]Base, len(tags))
|
||||
for i, t := range tags {
|
||||
a[i] = Base{language.Language(t.lang())}
|
||||
}
|
||||
sort.Sort(bases(a))
|
||||
k := 0
|
||||
for i := 1; i < len(a); i++ {
|
||||
if a[k] != a[i] {
|
||||
k++
|
||||
a[k] = a[i]
|
||||
}
|
||||
}
|
||||
return a[:k+1]
|
||||
}
|
||||
return s.bases()
|
||||
}
|
||||
|
||||
func (s *coverage) Scripts() []Script {
|
||||
if s.scripts == nil {
|
||||
return nil
|
||||
}
|
||||
return s.scripts()
|
||||
}
|
||||
|
||||
func (s *coverage) Regions() []Region {
|
||||
if s.regions == nil {
|
||||
return nil
|
||||
}
|
||||
return s.regions()
|
||||
}
|
||||
|
||||
// NewCoverage returns a Coverage for the given lists. It is typically used by
|
||||
// packages providing internationalization services to define their level of
|
||||
// coverage. A list may be of type []T or func() []T, where T is either Tag,
|
||||
// Base, Script or Region. The returned Coverage derives the value for Bases
|
||||
// from Tags if no func or slice for []Base is specified. For other unspecified
|
||||
// types the returned Coverage will return nil for the respective methods.
|
||||
func NewCoverage(list ...interface{}) Coverage {
|
||||
s := &coverage{}
|
||||
for _, x := range list {
|
||||
switch v := x.(type) {
|
||||
case func() []Base:
|
||||
s.bases = v
|
||||
case func() []Script:
|
||||
s.scripts = v
|
||||
case func() []Region:
|
||||
s.regions = v
|
||||
case func() []Tag:
|
||||
s.tags = v
|
||||
case []Base:
|
||||
s.bases = func() []Base { return v }
|
||||
case []Script:
|
||||
s.scripts = func() []Script { return v }
|
||||
case []Region:
|
||||
s.regions = func() []Region { return v }
|
||||
case []Tag:
|
||||
s.tags = func() []Tag { return v }
|
||||
default:
|
||||
panic(fmt.Sprintf("language: unsupported set type %T", v))
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
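A usage sketch for NewCoverage: a package passes the slices or functions it has, and the returned Coverage derives BaseLanguages from Tags when no base list is given, per the doc comment above. The tag list here is a hypothetical set of supported translations.

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

// myTags is a hypothetical list of tags a service has translations for.
var myTags = []language.Tag{language.English, language.Danish, language.Chinese}

// Only []Tag is supplied, so BaseLanguages is derived from the tags.
var coverage = language.NewCoverage(myTags)

func main() {
	for _, b := range coverage.BaseLanguages() {
		fmt.Println(b) // deduplicated, sorted base languages
	}
	fmt.Println(coverage.Scripts() == nil) // true: no script list was given
}
```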
|
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package language implements BCP 47 language tags and related functionality.
|
||||
//
|
||||
// The most important function of package language is to match a list of
|
||||
// user-preferred languages to a list of supported languages.
|
||||
// It relieves the developer of dealing with the complexity of this process
|
||||
// and provides the user with the best experience
|
||||
// (see https://blog.golang.org/matchlang).
|
||||
//
|
||||
//
|
||||
// Matching preferred against supported languages
|
||||
//
|
||||
// A Matcher for an application that supports English, Australian English,
|
||||
// Danish, and standard Mandarin can be created as follows:
|
||||
//
|
||||
// var matcher = language.NewMatcher([]language.Tag{
|
||||
// language.English, // The first language is used as fallback.
|
||||
// language.MustParse("en-AU"),
|
||||
// language.Danish,
|
||||
// language.Chinese,
|
||||
// })
|
||||
//
|
||||
// This list of supported languages is typically implied by the languages for
|
||||
// which there exists translations of the user interface.
|
||||
//
|
||||
// User-preferred languages usually come as a comma-separated list of BCP 47
|
||||
// language tags.
|
||||
// MatchStrings finds the best matches for such strings:
|
||||
//
|
||||
// handler(w http.ResponseWriter, r *http.Request) {
|
||||
// lang, _ := r.Cookie("lang")
|
||||
// accept := r.Header.Get("Accept-Language")
|
||||
// tag, _ := language.MatchStrings(matcher, lang.String(), accept)
|
||||
//
|
||||
// // tag should now be used for the initialization of any
|
||||
// // locale-specific service.
|
||||
// }
|
||||
//
|
||||
// The Matcher's Match method can be used to match Tags directly.
|
||||
//
|
||||
// Matchers are aware of the intricacies of equivalence between languages, such
|
||||
// as deprecated subtags, legacy tags, macro languages, mutual
|
||||
// intelligibility between scripts and languages, and transparently passing
|
||||
// BCP 47 user configuration.
|
||||
// For instance, it will know that a reader of Bokmål Danish can read Norwegian
|
||||
// and will know that Cantonese ("yue") is a good match for "zh-HK".
|
||||
//
|
||||
//
|
||||
// Using match results
|
||||
//
|
||||
// To guarantee a consistent user experience to the user it is important to
|
||||
// use the same language tag for the selection of any locale-specific services.
|
||||
// For example, it is utterly confusing to substitute spelled-out numbers
|
||||
// or dates in one language in text of another language.
|
||||
// More subtly confusing is using the wrong sorting order or casing
|
||||
// algorithm for a certain language.
|
||||
//
|
||||
// All the packages in x/text that provide locale-specific services
|
||||
// (e.g. collate, cases) should be initialized with the tag that was
|
||||
// obtained at the start of an interaction with the user.
|
||||
//
|
||||
// Note that the Tag returned by Match and MatchStrings may differ from any
|
||||
// of the supported languages, as it may contain carried over settings from
|
||||
// the user tags.
|
||||
// This may be inconvenient when your application has some additional
|
||||
// locale-specific data for your supported languages.
|
||||
// Match and MatchStrings both return the index of the matched supported tag
|
||||
// to simplify associating such data with the matched tag.
|
||||
//
|
||||
//
|
||||
// Canonicalization
|
||||
//
|
||||
// If one uses the Matcher to compare languages one does not need to
|
||||
// worry about canonicalization.
|
||||
//
|
||||
// The meaning of a Tag varies per application. The language package
|
||||
// therefore delays canonicalization and preserves information as much
|
||||
// as possible. The Matcher, however, will always take into account that
|
||||
// two different tags may represent the same language.
|
||||
//
|
||||
// By default, only legacy and deprecated tags are converted into their
|
||||
// canonical equivalent. All other information is preserved. This approach makes
|
||||
// the confidence scores more accurate and allows matchers to distinguish
|
||||
// between variants that are otherwise lost.
|
||||
//
|
||||
// As a consequence, two tags that should be treated as identical according to
|
||||
// BCP 47 or CLDR, like "en-Latn" and "en", will be represented differently. The
|
||||
// Matcher handles such distinctions, though, and is aware of the
|
||||
// equivalence relations. The CanonType type can be used to alter the
|
||||
// canonicalization form.
|
||||
//
|
||||
// References
|
||||
//
|
||||
// BCP 47 - Tags for Identifying Languages http://tools.ietf.org/html/bcp47
|
||||
//
|
||||
package language // import "golang.org/x/text/language"
|
||||
|
||||
// TODO: explanation on how to match languages for your own locale-specific
|
||||
// service.
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.2
|
||||
|
||||
package language
|
||||
|
||||
import "sort"
|
||||
|
||||
func sortStable(s sort.Interface) {
|
||||
ss := stableSort{
|
||||
s: s,
|
||||
pos: make([]int, s.Len()),
|
||||
}
|
||||
for i := range ss.pos {
|
||||
ss.pos[i] = i
|
||||
}
|
||||
sort.Sort(&ss)
|
||||
}
|
||||
|
||||
type stableSort struct {
|
||||
s sort.Interface
|
||||
pos []int
|
||||
}
|
||||
|
||||
func (s *stableSort) Len() int {
|
||||
return len(s.pos)
|
||||
}
|
||||
|
||||
func (s *stableSort) Less(i, j int) bool {
|
||||
return s.s.Less(i, j) || !s.s.Less(j, i) && s.pos[i] < s.pos[j]
|
||||
}
|
||||
|
||||
func (s *stableSort) Swap(i, j int) {
|
||||
s.s.Swap(i, j)
|
||||
s.pos[i], s.pos[j] = s.pos[j], s.pos[i]
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.2
|
||||
|
||||
package language
|
||||
|
||||
import "sort"
|
||||
|
||||
var sortStable = sort.Stable
|
|
@ -0,0 +1,601 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:generate go run gen.go -output tables.go
|
||||
|
||||
package language
|
||||
|
||||
// TODO: Remove above NOTE after:
|
||||
// - verifying that tables are dropped correctly (most notably matcher tables).
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/internal/language"
|
||||
"golang.org/x/text/internal/language/compact"
|
||||
)
|
||||
|
||||
// Tag represents a BCP 47 language tag. It is used to specify an instance of a
|
||||
// specific language or locale. All language tag values are guaranteed to be
|
||||
// well-formed.
|
||||
type Tag compact.Tag
|
||||
|
||||
func makeTag(t language.Tag) (tag Tag) {
|
||||
return Tag(compact.Make(t))
|
||||
}
|
||||
|
||||
func (t *Tag) tag() language.Tag {
|
||||
return (*compact.Tag)(t).Tag()
|
||||
}
|
||||
|
||||
func (t *Tag) isCompact() bool {
|
||||
return (*compact.Tag)(t).IsCompact()
|
||||
}
|
||||
|
||||
// TODO: improve performance.
|
||||
func (t *Tag) lang() language.Language { return t.tag().LangID }
|
||||
func (t *Tag) region() language.Region { return t.tag().RegionID }
|
||||
func (t *Tag) script() language.Script { return t.tag().ScriptID }
|
||||
|
||||
// Make is a convenience wrapper for Parse that omits the error.
|
||||
// In case of an error, a sensible default is returned.
|
||||
func Make(s string) Tag {
|
||||
return Default.Make(s)
|
||||
}
|
||||
|
||||
// Make is a convenience wrapper for c.Parse that omits the error.
|
||||
// In case of an error, a sensible default is returned.
|
||||
func (c CanonType) Make(s string) Tag {
|
||||
t, _ := c.Parse(s)
|
||||
return t
|
||||
}
|
||||
|
||||
// Raw returns the raw base language, script and region, without making an
|
||||
// attempt to infer their values.
|
||||
func (t Tag) Raw() (b Base, s Script, r Region) {
|
||||
tt := t.tag()
|
||||
return Base{tt.LangID}, Script{tt.ScriptID}, Region{tt.RegionID}
|
||||
}
|
||||
|
||||
// IsRoot returns true if t is equal to language "und".
|
||||
func (t Tag) IsRoot() bool {
|
||||
return compact.Tag(t).IsRoot()
|
||||
}
|
||||
|
||||
// CanonType can be used to enable or disable various types of canonicalization.
|
||||
type CanonType int
|
||||
|
||||
const (
|
||||
// Replace deprecated base languages with their preferred replacements.
|
||||
DeprecatedBase CanonType = 1 << iota
|
||||
// Replace deprecated scripts with their preferred replacements.
|
||||
DeprecatedScript
|
||||
// Replace deprecated regions with their preferred replacements.
|
||||
DeprecatedRegion
|
||||
// Remove redundant scripts.
|
||||
SuppressScript
|
||||
// Normalize legacy encodings. This includes legacy languages defined in
|
||||
// CLDR as well as bibliographic codes defined in ISO-639.
|
||||
Legacy
|
||||
// Map the dominant language of a macro language group to the macro language
|
||||
// subtag. For example cmn -> zh.
|
||||
Macro
|
||||
// The CLDR flag should be used if full compatibility with CLDR is required.
|
||||
// There are a few cases where language.Tag may differ from CLDR. To follow all
|
||||
// of CLDR's suggestions, use All|CLDR.
|
||||
CLDR
|
||||
|
||||
// Raw can be used to Compose or Parse without Canonicalization.
|
||||
Raw CanonType = 0
|
||||
|
||||
// Replace all deprecated tags with their preferred replacements.
|
||||
Deprecated = DeprecatedBase | DeprecatedScript | DeprecatedRegion
|
||||
|
||||
// All canonicalizations recommended by BCP 47.
|
||||
BCP47 = Deprecated | SuppressScript
|
||||
|
||||
// All canonicalizations.
|
||||
All = BCP47 | Legacy | Macro
|
||||
|
||||
// Default is the canonicalization used by Parse, Make and Compose. To
|
||||
// preserve as much information as possible, canonicalizations that remove
|
||||
// potentially valuable information are not included. The Matcher is
|
||||
// designed to recognize similar tags that would be the same if
|
||||
// they were canonicalized using All.
|
||||
Default = Deprecated | Legacy
|
||||
|
||||
canonLang = DeprecatedBase | Legacy | Macro
|
||||
|
||||
// TODO: LikelyScript, LikelyRegion: suppress similar to ICU.
|
||||
)
|
||||
|
||||
// canonicalize returns the canonicalized equivalent of the tag and
|
||||
// whether there was any change.
|
||||
func canonicalize(c CanonType, t language.Tag) (language.Tag, bool) {
|
||||
if c == Raw {
|
||||
return t, false
|
||||
}
|
||||
changed := false
|
||||
if c&SuppressScript != 0 {
|
||||
if t.LangID.SuppressScript() == t.ScriptID {
|
||||
t.ScriptID = 0
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if c&canonLang != 0 {
|
||||
for {
|
||||
if l, aliasType := t.LangID.Canonicalize(); l != t.LangID {
|
||||
switch aliasType {
|
||||
case language.Legacy:
|
||||
if c&Legacy != 0 {
|
||||
if t.LangID == _sh && t.ScriptID == 0 {
|
||||
t.ScriptID = _Latn
|
||||
}
|
||||
t.LangID = l
|
||||
changed = true
|
||||
}
|
||||
case language.Macro:
|
||||
if c&Macro != 0 {
|
||||
// We deviate here from CLDR. The mapping "nb" -> "no"
|
||||
// qualifies as a typical Macro language mapping. However,
|
||||
// for legacy reasons, CLDR maps "no", the macro language
|
||||
// code for Norwegian, to the dominant variant "nb". This
|
||||
// change is currently under consideration for CLDR as well.
|
||||
// See https://unicode.org/cldr/trac/ticket/2698 and also
|
||||
// https://unicode.org/cldr/trac/ticket/1790 for some of the
|
||||
// practical implications. TODO: this check could be removed
|
||||
// if CLDR adopts this change.
|
||||
if c&CLDR == 0 || t.LangID != _nb {
|
||||
changed = true
|
||||
t.LangID = l
|
||||
}
|
||||
}
|
||||
case language.Deprecated:
|
||||
if c&DeprecatedBase != 0 {
|
||||
if t.LangID == _mo && t.RegionID == 0 {
|
||||
t.RegionID = _MD
|
||||
}
|
||||
t.LangID = l
|
||||
changed = true
|
||||
// Other canonicalization types may still apply.
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else if c&Legacy != 0 && t.LangID == _no && c&CLDR != 0 {
|
||||
t.LangID = _nb
|
||||
changed = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if c&DeprecatedScript != 0 {
|
||||
if t.ScriptID == _Qaai {
|
||||
changed = true
|
||||
t.ScriptID = _Zinh
|
||||
}
|
||||
}
|
||||
if c&DeprecatedRegion != 0 {
|
||||
if r := t.RegionID.Canonicalize(); r != t.RegionID {
|
||||
changed = true
|
||||
t.RegionID = r
|
||||
}
|
||||
}
|
||||
return t, changed
|
||||
}
|
||||
|
||||
// Canonicalize returns the canonicalized equivalent of the tag.
|
||||
func (c CanonType) Canonicalize(t Tag) (Tag, error) {
|
||||
// First try fast path.
|
||||
if t.isCompact() {
|
||||
if _, changed := canonicalize(c, compact.Tag(t).Tag()); !changed {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
// It is unlikely that one will canonicalize a tag after matching. So do
|
||||
// a slow but simple approach here.
|
||||
if tag, changed := canonicalize(c, t.tag()); changed {
|
||||
tag.RemakeString()
|
||||
return makeTag(tag), nil
|
||||
}
|
||||
return t, nil
|
||||
|
||||
}
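Canonicalization is opt-in per CanonType flag; for instance the Macro flag maps a dominant member of a macro-language group such as cmn to zh, per the constant's comment above. A hedged sketch of calling it directly; the exact output depends on the CLDR data compiled into the library.

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	// Raw parsing keeps "cmn" as-is; applying the Macro canonicalization
	// is expected to map it to the macro language "zh".
	t := language.Raw.Make("cmn")
	c, err := language.Macro.Canonicalize(t)
	fmt.Println(t, "->", c, err)
}
```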
|
||||
|
||||
// Confidence indicates the level of certainty for a given return value.
|
||||
// For example, Serbian may be written in Cyrillic or Latin script.
|
||||
// The confidence level indicates whether a value was explicitly specified,
|
||||
// whether it is typically the only possible value, or whether there is
|
||||
// an ambiguity.
|
||||
type Confidence int
|
||||
|
||||
const (
|
||||
No Confidence = iota // full confidence that there was no match
|
||||
Low // most likely value picked out of a set of alternatives
|
||||
High // value is generally assumed to be the correct match
|
||||
Exact // exact match or explicitly specified value
|
||||
)
|
||||
|
||||
var confName = []string{"No", "Low", "High", "Exact"}
|
||||
|
||||
func (c Confidence) String() string {
|
||||
return confName[c]
|
||||
}
|
||||
|
||||
// String returns the canonical string representation of the language tag.
|
||||
func (t Tag) String() string {
|
||||
return t.tag().String()
|
||||
}
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (t Tag) MarshalText() (text []byte, err error) {
|
||||
return t.tag().MarshalText()
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
func (t *Tag) UnmarshalText(text []byte) error {
|
||||
var tag language.Tag
|
||||
err := tag.UnmarshalText(text)
|
||||
*t = makeTag(tag)
|
||||
return err
|
||||
}
|
||||
|
||||
// Base returns the base language of the language tag. If the base language is
|
||||
// unspecified, an attempt will be made to infer it from the context.
|
||||
// It uses a variant of CLDR's Add Likely Subtags algorithm. This is subject to change.
|
||||
func (t Tag) Base() (Base, Confidence) {
|
||||
if b := t.lang(); b != 0 {
|
||||
return Base{b}, Exact
|
||||
}
|
||||
tt := t.tag()
|
||||
c := High
|
||||
if tt.ScriptID == 0 && !tt.RegionID.IsCountry() {
|
||||
c = Low
|
||||
}
|
||||
if tag, err := tt.Maximize(); err == nil && tag.LangID != 0 {
|
||||
return Base{tag.LangID}, c
|
||||
}
|
||||
return Base{0}, No
|
||||
}
|
||||
|
||||
// Script infers the script for the language tag. If it was not explicitly given, it will infer
|
||||
// a most likely candidate.
|
||||
// If more than one script is commonly used for a language, the most likely one
|
||||
// is returned with a low confidence indication. For example, it returns (Cyrl, Low)
|
||||
// for Serbian.
|
||||
// If a script cannot be inferred (Zzzz, No) is returned. We do not use Zyyy (undetermined)
|
||||
// as one would suspect from the IANA registry for BCP 47. In a Unicode context Zyyy marks
|
||||
// common characters (like 1, 2, 3, '.', etc.) and is therefore more like multiple scripts.
|
||||
// See https://www.unicode.org/reports/tr24/#Values for more details. Zzzz is also used for
|
||||
// unknown value in CLDR. (Zzzz, Exact) is returned if Zzzz was explicitly specified.
|
||||
// Note that an inferred script is never guaranteed to be the correct one. Latin is
|
||||
// almost exclusively used for Afrikaans, but Arabic has been used for some texts
|
||||
// in the past. Also, the script that is commonly used may change over time.
|
||||
// It uses a variant of CLDR's Add Likely Subtags algorithm. This is subject to change.
|
||||
func (t Tag) Script() (Script, Confidence) {
|
||||
if scr := t.script(); scr != 0 {
|
||||
return Script{scr}, Exact
|
||||
}
|
||||
tt := t.tag()
|
||||
sc, c := language.Script(_Zzzz), No
|
||||
if scr := tt.LangID.SuppressScript(); scr != 0 {
|
||||
// Note: it is not always the case that a language with a suppress
|
||||
// script value is only written in one script (e.g. kk, ms, pa).
|
||||
if tt.RegionID == 0 {
|
||||
return Script{scr}, High
|
||||
}
|
||||
sc, c = scr, High
|
||||
}
|
||||
if tag, err := tt.Maximize(); err == nil {
|
||||
if tag.ScriptID != sc {
|
||||
sc, c = tag.ScriptID, Low
|
||||
}
|
||||
} else {
|
||||
tt, _ = canonicalize(Deprecated|Macro, tt)
|
||||
if tag, err := tt.Maximize(); err == nil && tag.ScriptID != sc {
|
||||
sc, c = tag.ScriptID, Low
|
||||
}
|
||||
}
|
||||
return Script{sc}, c
|
||||
}
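As the comment above notes, Serbian is commonly written in more than one script, so the inferred script comes back with Low confidence, while an explicit script subtag is Exact. A hedged sketch; the outputs depend on the CLDR likely-subtags data bundled with the library.

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	sc, conf := language.Make("sr").Script()
	fmt.Println(sc, conf) // expected per the comment above: Cyrl Low

	sc, conf = language.Make("sr-Latn").Script()
	fmt.Println(sc, conf) // explicitly specified: Latn Exact
}
```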
|
||||
|
||||
// Region returns the region for the language tag. If it was not explicitly given, it will
|
||||
// infer a most likely candidate from the context.
|
||||
// It uses a variant of CLDR's Add Likely Subtags algorithm. This is subject to change.
|
||||
func (t Tag) Region() (Region, Confidence) {
|
||||
if r := t.region(); r != 0 {
|
||||
return Region{r}, Exact
|
||||
}
|
||||
tt := t.tag()
|
||||
if tt, err := tt.Maximize(); err == nil {
|
||||
return Region{tt.RegionID}, Low // TODO: differentiate between high and low.
|
||||
}
|
||||
tt, _ = canonicalize(Deprecated|Macro, tt)
|
||||
if tag, err := tt.Maximize(); err == nil {
|
||||
return Region{tag.RegionID}, Low
|
||||
}
|
||||
return Region{_ZZ}, No // TODO: return world instead of undetermined?
|
||||
}
|
||||
|
||||
// Variants returns the variants specified explicitly for this language tag,
|
||||
// or nil if no variant was specified.
|
||||
func (t Tag) Variants() []Variant {
|
||||
if !compact.Tag(t).MayHaveVariants() {
|
||||
return nil
|
||||
}
|
||||
v := []Variant{}
|
||||
x, str := "", t.tag().Variants()
|
||||
for str != "" {
|
||||
x, str = nextToken(str)
|
||||
v = append(v, Variant{x})
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Parent returns the CLDR parent of t. In CLDR, missing fields in data for a
|
||||
// specific language are substituted with fields from the parent language.
|
||||
// The parent for a language may change for newer versions of CLDR.
|
||||
//
|
||||
// Parent returns a tag for a less specific language that is mutually
|
||||
// intelligible or Und if there is no such language. This may not be the same as
|
||||
// simply stripping the last BCP 47 subtag. For instance, the parent of "zh-TW"
|
||||
// is "zh-Hant", and the parent of "zh-Hant" is "und".
|
||||
func (t Tag) Parent() Tag {
|
||||
return Tag(compact.Tag(t).Parent())
|
||||
}
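The doc comment gives zh-TW, then zh-Hant, then und as an example parent chain. A short sketch that walks the chain until it reaches the root; purely illustrative, using the IsRoot method defined earlier in this file.

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	t := language.MustParse("zh-TW")
	for {
		fmt.Println(t) // expected, per the comment above: zh-TW, zh-Hant, und
		if t.IsRoot() {
			break
		}
		t = t.Parent()
	}
}
```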
|
||||
|
||||
// returns token t and the rest of the string.
|
||||
func nextToken(s string) (t, tail string) {
|
||||
p := strings.Index(s[1:], "-")
|
||||
if p == -1 {
|
||||
return s[1:], ""
|
||||
}
|
||||
p++
|
||||
return s[1:p], s[p:]
|
||||
}
|
||||
|
||||
// Extension is a single BCP 47 extension.
|
||||
type Extension struct {
|
||||
s string
|
||||
}
|
||||
|
||||
// String returns the string representation of the extension, including the
|
||||
// type tag.
|
||||
func (e Extension) String() string {
|
||||
return e.s
|
||||
}
|
||||
|
||||
// ParseExtension parses s as an extension and returns it on success.
|
||||
func ParseExtension(s string) (e Extension, err error) {
|
||||
ext, err := language.ParseExtension(s)
|
||||
return Extension{ext}, err
|
||||
}
|
||||
|
||||
// Type returns the one-byte extension type of e. It returns 0 for the zero
|
||||
// extension.
|
||||
func (e Extension) Type() byte {
|
||||
if e.s == "" {
|
||||
return 0
|
||||
}
|
||||
return e.s[0]
|
||||
}
|
||||
|
||||
// Tokens returns the list of tokens of e.
|
||||
func (e Extension) Tokens() []string {
|
||||
return strings.Split(e.s, "-")
|
||||
}
|
||||
|
||||
// Extension returns the extension of type x for tag t. It will return
|
||||
// false for ok if t does not have the requested extension. The returned
|
||||
// extension will be invalid in this case.
|
||||
func (t Tag) Extension(x byte) (ext Extension, ok bool) {
|
||||
if !compact.Tag(t).MayHaveExtensions() {
|
||||
return Extension{}, false
|
||||
}
|
||||
e, ok := t.tag().Extension(x)
|
||||
return Extension{e}, ok
|
||||
}
|
||||
|
||||
// Extensions returns all extensions of t.
|
||||
func (t Tag) Extensions() []Extension {
|
||||
if !compact.Tag(t).MayHaveExtensions() {
|
||||
return nil
|
||||
}
|
||||
e := []Extension{}
|
||||
for _, ext := range t.tag().Extensions() {
|
||||
e = append(e, Extension{ext})
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// TypeForKey returns the type associated with the given key, where key and type
|
||||
// are of the allowed values defined for the Unicode locale extension ('u') in
|
||||
// https://www.unicode.org/reports/tr35/#Unicode_Language_and_Locale_Identifiers.
|
||||
// TypeForKey will traverse the inheritance chain to get the correct value.
|
||||
func (t Tag) TypeForKey(key string) string {
|
||||
if !compact.Tag(t).MayHaveExtensions() {
|
||||
if key != "rg" && key != "va" {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return t.tag().TypeForKey(key)
|
||||
}
|
||||
|
||||
// SetTypeForKey returns a new Tag with the key set to type, where key and type
|
||||
// are of the allowed values defined for the Unicode locale extension ('u') in
|
||||
// https://www.unicode.org/reports/tr35/#Unicode_Language_and_Locale_Identifiers.
|
||||
// An empty value removes an existing pair with the same key.
|
||||
func (t Tag) SetTypeForKey(key, value string) (Tag, error) {
|
||||
tt, err := t.tag().SetTypeForKey(key, value)
|
||||
return makeTag(tt), err
|
||||
}
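A hedged sketch of round-tripping a Unicode -u key through SetTypeForKey and TypeForKey, including the removal-by-empty-value behavior described in the doc comment. The key "ca" (calendar) with value "gregory" is an assumption: a common key/type pair from the Unicode locale extension, not something defined in this file.

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	base := language.MustParse("de")
	t, err := base.SetTypeForKey("ca", "gregory") // de-u-ca-gregory (assumed valid pair)
	if err != nil {
		fmt.Println("set failed:", err)
		return
	}
	fmt.Println(t, t.TypeForKey("ca"))

	// An empty value removes the pair again, per the doc comment above.
	t, _ = t.SetTypeForKey("ca", "")
	fmt.Println(t, t.TypeForKey("ca") == "")
}
```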
|
||||
|
||||
// NumCompactTags is the number of compact tags. The maximum tag is
|
||||
// NumCompactTags-1.
|
||||
const NumCompactTags = compact.NumCompactTags
|
||||
|
||||
// CompactIndex returns an index, where 0 <= index < NumCompactTags, for tags
|
||||
// for which data exists in the text repository. The index will change over time
|
||||
// and should not be stored in persistent storage. If t does not match a compact
|
||||
// index, exact will be false and the compact index will be returned for the
|
||||
// first match after repeatedly taking the Parent of t.
|
||||
func CompactIndex(t Tag) (index int, exact bool) {
|
||||
id, exact := compact.LanguageID(compact.Tag(t))
|
||||
return int(id), exact
|
||||
}
|
||||
|
||||
var root = language.Tag{}
|
||||
|
||||
// Base is an ISO 639 language code, used for encoding the base language
|
||||
// of a language tag.
|
||||
type Base struct {
|
||||
langID language.Language
|
||||
}
|
||||
|
||||
// ParseBase parses a 2- or 3-letter ISO 639 code.
|
||||
// It returns a ValueError if s is a well-formed but unknown language identifier
|
||||
// or a different error if parsing failed for another reason.
|
||||
func ParseBase(s string) (Base, error) {
|
||||
l, err := language.ParseBase(s)
|
||||
return Base{l}, err
|
||||
}
|
||||
|
||||
// String returns the BCP 47 representation of the base language.
|
||||
func (b Base) String() string {
|
||||
return b.langID.String()
|
||||
}
|
||||
|
||||
// ISO3 returns the ISO 639-3 language code.
|
||||
func (b Base) ISO3() string {
|
||||
return b.langID.ISO3()
|
||||
}
|
||||
|
||||
// IsPrivateUse reports whether this language code is reserved for private use.
|
||||
func (b Base) IsPrivateUse() bool {
|
||||
return b.langID.IsPrivateUse()
|
||||
}
|
||||
|
||||
// Script is a 4-letter ISO 15924 code for representing scripts.
|
||||
// It is idiomatically represented in title case.
|
||||
type Script struct {
|
||||
scriptID language.Script
|
||||
}
|
||||
|
||||
// ParseScript parses a 4-letter ISO 15924 code.
|
||||
// It returns a ValueError if s is a well-formed but unknown script identifier
|
||||
// or a different error if parsing failed for another reason.
|
||||
func ParseScript(s string) (Script, error) {
|
||||
sc, err := language.ParseScript(s)
|
||||
return Script{sc}, err
|
||||
}
|
||||
|
||||
// String returns the script code in title case.
|
||||
// It returns "Zzzz" for an unspecified script.
|
||||
func (s Script) String() string {
|
||||
return s.scriptID.String()
|
||||
}
|
||||
|
||||
// IsPrivateUse reports whether this script code is reserved for private use.
|
||||
func (s Script) IsPrivateUse() bool {
|
||||
return s.scriptID.IsPrivateUse()
|
||||
}
|
||||
|
||||
// Region is an ISO 3166-1 or UN M.49 code for representing countries and regions.
|
||||
type Region struct {
|
||||
regionID language.Region
|
||||
}
|
||||
|
||||
// EncodeM49 returns the Region for the given UN M.49 code.
|
||||
// It returns an error if r is not a valid code.
|
||||
func EncodeM49(r int) (Region, error) {
|
||||
rid, err := language.EncodeM49(r)
|
||||
return Region{rid}, err
|
||||
}
|
||||
|
||||
// ParseRegion parses a 2- or 3-letter ISO 3166-1 or a UN M.49 code.
|
||||
// It returns a ValueError if s is a well-formed but unknown region identifier
|
||||
// or a different error if parsing failed for another reason.
|
||||
func ParseRegion(s string) (Region, error) {
|
||||
r, err := language.ParseRegion(s)
|
||||
return Region{r}, err
|
||||
}
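A hedged sketch of the subtag parsers above; the specific codes ("de", "Latn", "CH", 419) are illustrative assumptions:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	b, _ := language.ParseBase("de")     // ISO 639 base language
	s, _ := language.ParseScript("Latn") // ISO 15924 script
	r, _ := language.ParseRegion("CH")   // ISO 3166-1 region
	m, _ := language.EncodeM49(419)      // UN M.49 code (Latin America)
	fmt.Println(b, s, r, m, b.ISO3(), r.M49())
}
```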
|
||||
|
||||
// String returns the BCP 47 representation for the region.
|
||||
// It returns "ZZ" for an unspecified region.
|
||||
func (r Region) String() string {
|
||||
return r.regionID.String()
|
||||
}
|
||||
|
||||
// ISO3 returns the 3-letter ISO code of r.
|
||||
// Note that not all regions have a 3-letter ISO code.
|
||||
// In such cases this method returns "ZZZ".
|
||||
func (r Region) ISO3() string {
|
||||
return r.regionID.ISO3()
|
||||
}
|
||||
|
||||
// M49 returns the UN M.49 encoding of r, or 0 if this encoding
|
||||
// is not defined for r.
|
||||
func (r Region) M49() int {
|
||||
return r.regionID.M49()
|
||||
}
|
||||
|
||||
// IsPrivateUse reports whether r has the ISO 3166 User-assigned status. This
|
||||
// may include private-use tags that are assigned by CLDR and used in this
|
||||
// implementation. So IsPrivateUse and IsCountry can be simultaneously true.
|
||||
func (r Region) IsPrivateUse() bool {
|
||||
return r.regionID.IsPrivateUse()
|
||||
}
|
||||
|
||||
// IsCountry returns whether this region is a country or autonomous area. This
|
||||
// includes non-standard definitions from CLDR.
|
||||
func (r Region) IsCountry() bool {
|
||||
return r.regionID.IsCountry()
|
||||
}
|
||||
|
||||
// IsGroup returns whether this region defines a collection of regions. This
|
||||
// includes non-standard definitions from CLDR.
|
||||
func (r Region) IsGroup() bool {
|
||||
return r.regionID.IsGroup()
|
||||
}
|
||||
|
||||
// Contains returns whether Region c is contained by Region r. It returns true
|
||||
// if c == r.
|
||||
func (r Region) Contains(c Region) bool {
|
||||
return r.regionID.Contains(c.regionID)
|
||||
}
|
||||
|
||||
// TLD returns the country code top-level domain (ccTLD). UK is returned for GB.
|
||||
// In all other cases it returns either the region itself or an error.
|
||||
//
|
||||
// This method may return an error for a region for which there exists a
|
||||
// canonical form with a ccTLD. To get that ccTLD canonicalize r first. The
|
||||
// region will already be canonicalized if it was obtained from a Tag that was
|
||||
// obtained using any of the default methods.
|
||||
func (r Region) TLD() (Region, error) {
|
||||
tld, err := r.regionID.TLD()
|
||||
return Region{tld}, err
|
||||
}
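A one-line illustration of the GB/UK special case documented above, using MustParseRegion from the same package:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	gb := language.MustParseRegion("GB")
	tld, err := gb.TLD()
	fmt.Println(tld, err) // "UK", per the special case documented above
}
```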
|
||||
|
||||
// Canonicalize returns the region or a possible replacement if the region is
|
||||
// deprecated. It will not return a replacement for deprecated regions that
|
||||
// are split into multiple regions.
|
||||
func (r Region) Canonicalize() Region {
|
||||
return Region{r.regionID.Canonicalize()}
|
||||
}
|
||||
|
||||
// Variant represents a registered variant of a language as defined by BCP 47.
|
||||
type Variant struct {
|
||||
variant string
|
||||
}
|
||||
|
||||
// ParseVariant parses and returns a Variant. An error is returned if s is not
|
||||
// a valid variant.
|
||||
func ParseVariant(s string) (Variant, error) {
|
||||
v, err := language.ParseVariant(s)
|
||||
return Variant{v.String()}, err
|
||||
}
|
||||
|
||||
// String returns the string representation of the variant.
|
||||
func (v Variant) String() string {
|
||||
return v.variant
|
||||
}
|
|
@@ -0,0 +1,735 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/internal/language"
|
||||
)
|
||||
|
||||
// A MatchOption configures a Matcher.
|
||||
type MatchOption func(*matcher)
|
||||
|
||||
// PreferSameScript will, in the absence of a match, cause the first
|
||||
// preferred tag that shares a script with a supported tag to match that supported
|
||||
// tag. The default is currently true, but this may change in the future.
|
||||
func PreferSameScript(preferSame bool) MatchOption {
|
||||
return func(m *matcher) { m.preferSameScript = preferSame }
|
||||
}
|
||||
|
||||
// TODO(v1.0.0): consider making Matcher a concrete type, instead of interface.
|
||||
// There doesn't seem to be too much need for multiple types.
|
||||
// Making it a concrete type allows MatchStrings to be a method, which will
|
||||
// improve its discoverability.
|
||||
|
||||
// MatchStrings parses and matches the given strings until one of them matches
|
||||
// the language in the Matcher. A string may be an Accept-Language header as
|
||||
// handled by ParseAcceptLanguage. The default language is returned if no
|
||||
// other language matched.
|
||||
func MatchStrings(m Matcher, lang ...string) (tag Tag, index int) {
|
||||
for _, accept := range lang {
|
||||
desired, _, err := ParseAcceptLanguage(accept)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if tag, index, conf := m.Match(desired...); conf != No {
|
||||
return tag, index
|
||||
}
|
||||
}
|
||||
tag, index, _ = m.Match()
|
||||
return
|
||||
}
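A typical use of MatchStrings for server-side language negotiation; the supported tags and the Accept-Language value are assumptions for illustration:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

var matcher = language.NewMatcher([]language.Tag{
	language.English, // the first tag is the default
	language.German,
	language.Japanese,
})

func main() {
	accept := "de-CH;q=0.9, ja;q=0.8" // e.g. an Accept-Language header value
	tag, index := language.MatchStrings(matcher, accept)
	fmt.Println(tag, index)
}
```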
|
||||
|
||||
// Matcher is the interface that wraps the Match method.
|
||||
//
|
||||
// Match returns the best match for any of the given tags, along with
|
||||
// a unique index associated with the returned tag and a confidence
|
||||
// score.
|
||||
type Matcher interface {
|
||||
Match(t ...Tag) (tag Tag, index int, c Confidence)
|
||||
}
|
||||
|
||||
// Comprehends reports the confidence score for a speaker of a given language
|
||||
// to be able to comprehend the written form of an alternative language.
|
||||
func Comprehends(speaker, alternative Tag) Confidence {
|
||||
_, _, c := NewMatcher([]Tag{alternative}).Match(speaker)
|
||||
return c
|
||||
}
|
||||
|
||||
// NewMatcher returns a Matcher that matches an ordered list of preferred tags
|
||||
// against a list of supported tags based on written intelligibility, closeness
|
||||
// of dialect, equivalence of subtags and various other rules. It is initialized
|
||||
// with the list of supported tags. The first element is used as the default
|
||||
// value in case no match is found.
|
||||
//
|
||||
// Its Match method matches the first of the given Tags to reach a certain
|
||||
// confidence threshold. The tags passed to Match should therefore be specified
|
||||
// in order of preference. Extensions are ignored for matching.
|
||||
//
|
||||
// The index returned by the Match method corresponds to the index of the
|
||||
// matched tag in t, but is augmented with the Unicode extension ('u') of the
|
||||
// corresponding preferred tag. This allows user locale options to be passed
|
||||
// transparently.
|
||||
func NewMatcher(t []Tag, options ...MatchOption) Matcher {
|
||||
return newMatcher(t, options)
|
||||
}
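A minimal sketch of Match as described above; the supported and desired tags are illustrative:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	m := language.NewMatcher([]language.Tag{language.English, language.BritishEnglish})
	desired := language.MustParse("en-GB-u-co-phonebk")
	// The returned index refers to the supported list; the returned tag is
	// expected to carry over the -u- settings of the desired tag.
	tag, index, conf := m.Match(desired)
	fmt.Println(tag, index, conf)
}
```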
|
||||
|
||||
func (m *matcher) Match(want ...Tag) (t Tag, index int, c Confidence) {
|
||||
var tt language.Tag
|
||||
match, w, c := m.getBest(want...)
|
||||
if match != nil {
|
||||
tt, index = match.tag, match.index
|
||||
} else {
|
||||
// TODO: this should be an option
|
||||
tt = m.default_.tag
|
||||
if m.preferSameScript {
|
||||
outer:
|
||||
for _, w := range want {
|
||||
script, _ := w.Script()
|
||||
if script.scriptID == 0 {
|
||||
// Don't do anything if there is no script, such as with
|
||||
// private subtags.
|
||||
continue
|
||||
}
|
||||
for i, h := range m.supported {
|
||||
if script.scriptID == h.maxScript {
|
||||
tt, index = h.tag, i
|
||||
break outer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: select first language tag based on script.
|
||||
}
|
||||
if w.RegionID != tt.RegionID && w.RegionID != 0 {
|
||||
if w.RegionID != 0 && tt.RegionID != 0 && tt.RegionID.Contains(w.RegionID) {
|
||||
tt.RegionID = w.RegionID
|
||||
tt.RemakeString()
|
||||
} else if r := w.RegionID.String(); len(r) == 2 {
|
||||
// TODO: also filter macro and deprecated.
|
||||
tt, _ = tt.SetTypeForKey("rg", strings.ToLower(r)+"zzzz")
|
||||
}
|
||||
}
|
||||
// Copy options from the user-provided tag into the result tag. This is hard
|
||||
// to do after the fact, so we do it here.
|
||||
// TODO: add in alternative variants to -u-va-.
|
||||
// TODO: add preferred region to -u-rg-.
|
||||
if e := w.Extensions(); len(e) > 0 {
|
||||
b := language.Builder{}
|
||||
b.SetTag(tt)
|
||||
for _, e := range e {
|
||||
b.AddExt(e)
|
||||
}
|
||||
tt = b.Make()
|
||||
}
|
||||
return makeTag(tt), index, c
|
||||
}
|
||||
|
||||
// ErrMissingLikelyTagsData indicates no information was available
|
||||
// to compute likely values of missing tags.
|
||||
var ErrMissingLikelyTagsData = errors.New("missing likely tags data")
|
||||
|
||||
// func (t *Tag) setTagsFrom(id Tag) {
|
||||
// t.LangID = id.LangID
|
||||
// t.ScriptID = id.ScriptID
|
||||
// t.RegionID = id.RegionID
|
||||
// }
|
||||
|
||||
// Tag Matching
|
||||
// CLDR defines an algorithm for finding the best match between two sets of language
|
||||
// tags. The basic algorithm defines how to score a possible match and then find
|
||||
// the match with the best score
|
||||
// (see https://www.unicode.org/reports/tr35/#LanguageMatching).
|
||||
// Using scoring has several disadvantages. The scoring obfuscates the importance of
|
||||
// the various factors considered, making the algorithm harder to understand. Using
|
||||
// scoring also requires the full score to be computed for each pair of tags.
|
||||
//
|
||||
// We will use a different algorithm which aims to have the following properties:
|
||||
// - clarity on the precedence of the various selection factors, and
|
||||
// - improved performance by allowing early termination of a comparison.
|
||||
//
|
||||
// Matching algorithm (overview)
|
||||
// Input:
|
||||
// - supported: a set of supported tags
|
||||
// - default: the default tag to return in case there is no match
|
||||
// - desired: list of desired tags, ordered by preference, starting with
|
||||
// the most-preferred.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1) Set the best match to the lowest confidence level
|
||||
// 2) For each tag in "desired":
|
||||
// a) For each tag in "supported":
|
||||
// 1) compute the match between the two tags.
|
||||
// 2) if the match is better than the previous best match, replace it
|
||||
// with the new match. (see next section)
|
||||
// b) if the current best match is Exact and pin is true the result will be
|
||||
// frozen to the language found thus far, although better matches may
|
||||
// still be found for the same language.
|
||||
// 3) If the best match so far is below a certain threshold, return "default".
|
||||
//
|
||||
// Ranking:
|
||||
// We use two phases to determine whether one pair of tags are a better match
|
||||
// than another pair of tags. First, we determine a rough confidence level. If the
|
||||
// levels are different, the one with the highest confidence wins.
|
||||
// Second, if the rough confidence levels are identical, we use a set of tie-breaker
|
||||
// rules.
|
||||
//
|
||||
// The confidence level of matching a pair of tags is determined by finding the
|
||||
// lowest confidence level of any matches of the corresponding subtags (the
|
||||
// result is deemed as good as its weakest link).
|
||||
// We define the following levels:
|
||||
// Exact - An exact match of a subtag, before adding likely subtags.
|
||||
// MaxExact - An exact match of a subtag, after adding likely subtags.
|
||||
// [See Note 2].
|
||||
// High - High level of mutual intelligibility between different subtag
|
||||
// variants.
|
||||
// Low - Low level of mutual intelligibility between different subtag
|
||||
// variants.
|
||||
// No - No mutual intelligibility.
|
||||
//
|
||||
// The following levels can occur for each type of subtag:
|
||||
// Base: Exact, MaxExact, High, Low, No
|
||||
// Script: Exact, MaxExact [see Note 3], Low, No
|
||||
// Region: Exact, MaxExact, High
|
||||
// Variant: Exact, High
|
||||
// Private: Exact, No
|
||||
//
|
||||
// Any result with a confidence level of Low or higher is deemed a possible match.
|
||||
// Once a desired tag matches any of the supported tags with a level of MaxExact
|
||||
// or higher, the next desired tag is not considered (see Step 2.b).
|
||||
// Note that CLDR provides languageMatching data that defines close equivalence
|
||||
// classes for base languages, scripts and regions.
|
||||
//
|
||||
// Tie-breaking
|
||||
// If we get the same confidence level for two matches, we apply a sequence of
|
||||
// tie-breaking rules. The first that succeeds defines the result. The rules are
|
||||
// applied in the following order.
|
||||
// 1) Original language was defined and was identical.
|
||||
// 2) Original region was defined and was identical.
|
||||
// 3) Distance between two maximized regions was the smallest.
|
||||
// 4) Original script was defined and was identical.
|
||||
// 5) Distance from want tag to have tag using the parent relation [see Note 5.]
|
||||
// If there is still no winner after these rules are applied, the first match
|
||||
// found wins.
|
||||
//
|
||||
// Notes:
|
||||
// [2] In practice, as matching of Exact is done in a separate phase from
|
||||
// matching the other levels, we reuse the Exact level to mean MaxExact in
|
||||
// the second phase. As a consequence, we only need the levels defined by
|
||||
// the Confidence type. The MaxExact confidence level is mapped to High in
|
||||
// the public API.
|
||||
// [3] We do not differentiate between maximized script values that were derived
|
||||
// from suppressScript versus most likely tag data. We determined that in
|
||||
// ranking the two, one ranks just after the other. Moreover, the two cannot
|
||||
// occur concurrently. As a consequence, they are identical for practical
|
||||
// purposes.
|
||||
// [4] In case of deprecated, macro-equivalents and legacy mappings, we assign
|
||||
// the MaxExact level to allow iw vs he to still be a closer match than
|
||||
// en-AU vs en-US, for example.
|
||||
// [5] In CLDR a locale inherits fields that are unspecified for this locale
|
||||
// from its parent. Therefore, if a locale is a parent of another locale,
|
||||
// it is a strong measure for closeness, especially when no other tie
|
||||
// breaker rule applies. One could also argue it is inconsistent, for
|
||||
// example, when pt-AO matches pt (which CLDR equates with pt-BR), even
|
||||
// though its parent is pt-PT according to the inheritance rules.
|
||||
//
|
||||
// Implementation Details:
|
||||
// There are several performance considerations worth pointing out. Most notably,
|
||||
// we preprocess as much as possible (within reason) at the time of creation of a
|
||||
// matcher. This includes:
|
||||
// - creating a per-language map, which includes data for the raw base language
|
||||
// and its canonicalized variant (if applicable),
|
||||
// - expanding entries for the equivalence classes defined in CLDR's
|
||||
// languageMatch data.
|
||||
// The per-language map ensures that typically only a very small number of tags
|
||||
// need to be considered. The pre-expansion of canonicalized subtags and
|
||||
// equivalence classes reduces the amount of map lookups that need to be done at
|
||||
// runtime.
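A rough way to observe the confidence levels described above; the supported and desired tags are assumptions, and the exact levels depend on CLDR data, so the printed values are not guaranteed by this sketch:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	m := language.NewMatcher([]language.Tag{
		language.AmericanEnglish, // en-US
		language.Norwegian,       // no
	})
	// Closely related tags typically score higher than unrelated ones.
	for _, s := range []string{"en-US", "en-GB", "da", "zh"} {
		_, _, conf := m.Match(language.MustParse(s))
		fmt.Println(s, conf)
	}
}
```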
|
||||
|
||||
// matcher keeps a set of supported language tags, indexed by language.
|
||||
type matcher struct {
|
||||
default_ *haveTag
|
||||
supported []*haveTag
|
||||
index map[language.Language]*matchHeader
|
||||
passSettings bool
|
||||
preferSameScript bool
|
||||
}
|
||||
|
||||
// matchHeader has the lists of tags for exact matches and matches based on
|
||||
// maximized and canonicalized tags for a given language.
|
||||
type matchHeader struct {
|
||||
haveTags []*haveTag
|
||||
original bool
|
||||
}
|
||||
|
||||
// haveTag holds a supported Tag and its maximized script and region. The maximized
|
||||
// or canonicalized language is not stored as it is not needed during matching.
|
||||
type haveTag struct {
|
||||
tag language.Tag
|
||||
|
||||
// index of this tag in the original list of supported tags.
|
||||
index int
|
||||
|
||||
// conf is the maximum confidence that can result from matching this haveTag.
|
||||
// When conf < Exact this means it was inserted after applying a CLDR equivalence rule.
|
||||
conf Confidence
|
||||
|
||||
// Maximized region and script.
|
||||
maxRegion language.Region
|
||||
maxScript language.Script
|
||||
|
||||
// altScript may be checked as an alternative match to maxScript. If altScript
|
||||
// matches, the confidence level for this match is Low. Theoretically there
|
||||
// could be multiple alternative scripts. This does not occur in practice.
|
||||
altScript language.Script
|
||||
|
||||
// nextMax is the index of the next haveTag with the same maximized tags.
|
||||
nextMax uint16
|
||||
}
|
||||
|
||||
func makeHaveTag(tag language.Tag, index int) (haveTag, language.Language) {
|
||||
max := tag
|
||||
if tag.LangID != 0 || tag.RegionID != 0 || tag.ScriptID != 0 {
|
||||
max, _ = canonicalize(All, max)
|
||||
max, _ = max.Maximize()
|
||||
max.RemakeString()
|
||||
}
|
||||
return haveTag{tag, index, Exact, max.RegionID, max.ScriptID, altScript(max.LangID, max.ScriptID), 0}, max.LangID
|
||||
}
|
||||
|
||||
// altScript returns an alternative script that may match the given script with
|
||||
// a low confidence. At the moment, the langMatch data allows for at most one
|
||||
// script to map to another and we rely on this to keep the code simple.
|
||||
func altScript(l language.Language, s language.Script) language.Script {
|
||||
for _, alt := range matchScript {
|
||||
// TODO: also match cases where language is not the same.
|
||||
if (language.Language(alt.wantLang) == l || language.Language(alt.haveLang) == l) &&
|
||||
language.Script(alt.haveScript) == s {
|
||||
return language.Script(alt.wantScript)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// addIfNew adds a haveTag to the list of tags only if it is a unique tag.
|
||||
// Tags that have the same maximized values are linked by index.
|
||||
func (h *matchHeader) addIfNew(n haveTag, exact bool) {
|
||||
h.original = h.original || exact
|
||||
// Don't add new exact matches.
|
||||
for _, v := range h.haveTags {
|
||||
if equalsRest(v.tag, n.tag) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Allow duplicate maximized tags, but create a linked list to allow quickly
|
||||
// comparing the equivalents and bailing out early.
|
||||
for i, v := range h.haveTags {
|
||||
if v.maxScript == n.maxScript &&
|
||||
v.maxRegion == n.maxRegion &&
|
||||
v.tag.VariantOrPrivateUseTags() == n.tag.VariantOrPrivateUseTags() {
|
||||
for h.haveTags[i].nextMax != 0 {
|
||||
i = int(h.haveTags[i].nextMax)
|
||||
}
|
||||
h.haveTags[i].nextMax = uint16(len(h.haveTags))
|
||||
break
|
||||
}
|
||||
}
|
||||
h.haveTags = append(h.haveTags, &n)
|
||||
}
|
||||
|
||||
// header returns the matchHeader for the given language. It creates one if
|
||||
// it doesn't already exist.
|
||||
func (m *matcher) header(l language.Language) *matchHeader {
|
||||
if h := m.index[l]; h != nil {
|
||||
return h
|
||||
}
|
||||
h := &matchHeader{}
|
||||
m.index[l] = h
|
||||
return h
|
||||
}
|
||||
|
||||
func toConf(d uint8) Confidence {
|
||||
if d <= 10 {
|
||||
return High
|
||||
}
|
||||
if d < 30 {
|
||||
return Low
|
||||
}
|
||||
return No
|
||||
}
|
||||
|
||||
// newMatcher builds an index for the given supported tags and returns it as
|
||||
// a matcher. It also expands the index by considering various equivalence classes
|
||||
// for a given tag.
|
||||
func newMatcher(supported []Tag, options []MatchOption) *matcher {
|
||||
m := &matcher{
|
||||
index: make(map[language.Language]*matchHeader),
|
||||
preferSameScript: true,
|
||||
}
|
||||
for _, o := range options {
|
||||
o(m)
|
||||
}
|
||||
if len(supported) == 0 {
|
||||
m.default_ = &haveTag{}
|
||||
return m
|
||||
}
|
||||
// Add supported languages to the index. Add exact matches first to give
|
||||
// them precedence.
|
||||
for i, tag := range supported {
|
||||
tt := tag.tag()
|
||||
pair, _ := makeHaveTag(tt, i)
|
||||
m.header(tt.LangID).addIfNew(pair, true)
|
||||
m.supported = append(m.supported, &pair)
|
||||
}
|
||||
m.default_ = m.header(supported[0].lang()).haveTags[0]
|
||||
// Keep these in two different loops to support the case that two equivalent
|
||||
// languages are distinguished, such as iw and he.
|
||||
for i, tag := range supported {
|
||||
tt := tag.tag()
|
||||
pair, max := makeHaveTag(tt, i)
|
||||
if max != tt.LangID {
|
||||
m.header(max).addIfNew(pair, true)
|
||||
}
|
||||
}
|
||||
|
||||
// update is used to add indexes in the map for equivalent languages.
|
||||
// update will only add entries to original indexes, thus not computing any
|
||||
// transitive relations.
|
||||
update := func(want, have uint16, conf Confidence) {
|
||||
if hh := m.index[language.Language(have)]; hh != nil {
|
||||
if !hh.original {
|
||||
return
|
||||
}
|
||||
hw := m.header(language.Language(want))
|
||||
for _, ht := range hh.haveTags {
|
||||
v := *ht
|
||||
if conf < v.conf {
|
||||
v.conf = conf
|
||||
}
|
||||
v.nextMax = 0 // this value needs to be recomputed
|
||||
if v.altScript != 0 {
|
||||
v.altScript = altScript(language.Language(want), v.maxScript)
|
||||
}
|
||||
hw.addIfNew(v, conf == Exact && hh.original)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add entries for languages with mutual intelligibility as defined by CLDR's
|
||||
// languageMatch data.
|
||||
for _, ml := range matchLang {
|
||||
update(ml.want, ml.have, toConf(ml.distance))
|
||||
if !ml.oneway {
|
||||
update(ml.have, ml.want, toConf(ml.distance))
|
||||
}
|
||||
}
|
||||
|
||||
// Add entries for possible canonicalizations. This is an optimization to
|
||||
// ensure that only one map lookup needs to be done at runtime per desired tag.
|
||||
// First we match deprecated equivalents. If they are perfect equivalents
|
||||
// (their canonicalization simply substitutes a different language code, but
|
||||
// nothing else), the match confidence is Exact, otherwise it is High.
|
||||
for i, lm := range language.AliasMap {
|
||||
// If deprecated codes match and there is no fiddling with the script or
|
||||
// region, we consider it an exact match.
|
||||
conf := Exact
|
||||
if language.AliasTypes[i] != language.Macro {
|
||||
if !isExactEquivalent(language.Language(lm.From)) {
|
||||
conf = High
|
||||
}
|
||||
update(lm.To, lm.From, conf)
|
||||
}
|
||||
update(lm.From, lm.To, conf)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// getBest gets the best matching tag in m for any of the given tags, taking into
|
||||
// account the order of preference of the given tags.
|
||||
func (m *matcher) getBest(want ...Tag) (got *haveTag, orig language.Tag, c Confidence) {
|
||||
best := bestMatch{}
|
||||
for i, ww := range want {
|
||||
w := ww.tag()
|
||||
var max language.Tag
|
||||
// Check for exact match first.
|
||||
h := m.index[w.LangID]
|
||||
if w.LangID != 0 {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
// Base language is defined.
|
||||
max, _ = canonicalize(Legacy|Deprecated|Macro, w)
|
||||
// A region that is added through canonicalization is stronger than
|
||||
// a maximized region: set it in the original (e.g. mo -> ro-MD).
|
||||
if w.RegionID != max.RegionID {
|
||||
w.RegionID = max.RegionID
|
||||
}
|
||||
// TODO: should we do the same for scripts?
|
||||
// See test case: en, sr, nl ; sh ; sr
|
||||
max, _ = max.Maximize()
|
||||
} else {
|
||||
// Base language is not defined.
|
||||
if h != nil {
|
||||
for i := range h.haveTags {
|
||||
have := h.haveTags[i]
|
||||
if equalsRest(have.tag, w) {
|
||||
return have, w, Exact
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.ScriptID == 0 && w.RegionID == 0 {
|
||||
// We skip all tags matching und for approximate matching, including
|
||||
// private tags.
|
||||
continue
|
||||
}
|
||||
max, _ = w.Maximize()
|
||||
if h = m.index[max.LangID]; h == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
pin := true
|
||||
for _, t := range want[i+1:] {
|
||||
if w.LangID == t.lang() {
|
||||
pin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
// Check for match based on maximized tag.
|
||||
for i := range h.haveTags {
|
||||
have := h.haveTags[i]
|
||||
best.update(have, w, max.ScriptID, max.RegionID, pin)
|
||||
if best.conf == Exact {
|
||||
for have.nextMax != 0 {
|
||||
have = h.haveTags[have.nextMax]
|
||||
best.update(have, w, max.ScriptID, max.RegionID, pin)
|
||||
}
|
||||
return best.have, best.want, best.conf
|
||||
}
|
||||
}
|
||||
}
|
||||
if best.conf <= No {
|
||||
if len(want) != 0 {
|
||||
return nil, want[0].tag(), No
|
||||
}
|
||||
return nil, language.Tag{}, No
|
||||
}
|
||||
return best.have, best.want, best.conf
|
||||
}
|
||||
|
||||
// bestMatch accumulates the best match so far.
|
||||
type bestMatch struct {
|
||||
have *haveTag
|
||||
want language.Tag
|
||||
conf Confidence
|
||||
pinnedRegion language.Region
|
||||
pinLanguage bool
|
||||
sameRegionGroup bool
|
||||
// Cached results from applying tie-breaking rules.
|
||||
origLang bool
|
||||
origReg bool
|
||||
paradigmReg bool
|
||||
regGroupDist uint8
|
||||
origScript bool
|
||||
}
|
||||
|
||||
// update updates the existing best match if the new pair is considered to be a
|
||||
// better match. To determine if the given pair is a better match, it first
|
||||
// computes the rough confidence level. If this surpasses the current match, it
|
||||
// will replace it and update the tie-breaker rule cache. If there is a tie, it
|
||||
// proceeds with applying a series of tie-breaker rules. If there is no
|
||||
// conclusive winner after applying the tie-breaker rules, it leaves the current
|
||||
// match as the preferred match.
|
||||
//
|
||||
// If pin is true and have and tag are a strong match, it will henceforth only
|
||||
// consider matches for this language. This corresponds to the notion that most
|
||||
// users have a strong preference for the first defined language. A user can
|
||||
// still prefer a second language over a dialect of the preferred language by
|
||||
// explicitly specifying dialects, e.g. "en, nl, en-GB". In this case pin should
|
||||
// be false.
|
||||
func (m *bestMatch) update(have *haveTag, tag language.Tag, maxScript language.Script, maxRegion language.Region, pin bool) {
|
||||
// Bail if the maximum attainable confidence is below that of the current best match.
|
||||
c := have.conf
|
||||
if c < m.conf {
|
||||
return
|
||||
}
|
||||
// Don't change the language once we already have found an exact match.
|
||||
if m.pinLanguage && tag.LangID != m.want.LangID {
|
||||
return
|
||||
}
|
||||
// Pin the region group if we are comparing tags for the same language.
|
||||
if tag.LangID == m.want.LangID && m.sameRegionGroup {
|
||||
_, sameGroup := regionGroupDist(m.pinnedRegion, have.maxRegion, have.maxScript, m.want.LangID)
|
||||
if !sameGroup {
|
||||
return
|
||||
}
|
||||
}
|
||||
if c == Exact && have.maxScript == maxScript {
|
||||
// If there is another language and then another entry of this language,
|
||||
// don't pin anything, otherwise pin the language.
|
||||
m.pinLanguage = pin
|
||||
}
|
||||
if equalsRest(have.tag, tag) {
|
||||
} else if have.maxScript != maxScript {
|
||||
// There is usually very little comprehension between different scripts.
|
||||
// In a few cases there may still be Low comprehension. This possibility
|
||||
// is pre-computed and stored in have.altScript.
|
||||
if Low < m.conf || have.altScript != maxScript {
|
||||
return
|
||||
}
|
||||
c = Low
|
||||
} else if have.maxRegion != maxRegion {
|
||||
if High < c {
|
||||
// There is usually a small difference between languages across regions.
|
||||
c = High
|
||||
}
|
||||
}
|
||||
|
||||
// We store the results of the computations of the tie-breaker rules along
|
||||
// with the best match. There is no need to do the checks once we determine
|
||||
// we have a winner, but we do still need to do the tie-breaker computations.
|
||||
// We use "beaten" to keep track if we still need to do the checks.
|
||||
beaten := false // true if the new pair defeats the current one.
|
||||
if c != m.conf {
|
||||
if c < m.conf {
|
||||
return
|
||||
}
|
||||
beaten = true
|
||||
}
|
||||
|
||||
// Tie-breaker rules:
|
||||
// We prefer if the pre-maximized language was specified and identical.
|
||||
origLang := have.tag.LangID == tag.LangID && tag.LangID != 0
|
||||
if !beaten && m.origLang != origLang {
|
||||
if m.origLang {
|
||||
return
|
||||
}
|
||||
beaten = true
|
||||
}
|
||||
|
||||
// We prefer if the pre-maximized region was specified and identical.
|
||||
origReg := have.tag.RegionID == tag.RegionID && tag.RegionID != 0
|
||||
if !beaten && m.origReg != origReg {
|
||||
if m.origReg {
|
||||
return
|
||||
}
|
||||
beaten = true
|
||||
}
|
||||
|
||||
regGroupDist, sameGroup := regionGroupDist(have.maxRegion, maxRegion, maxScript, tag.LangID)
|
||||
if !beaten && m.regGroupDist != regGroupDist {
|
||||
if regGroupDist > m.regGroupDist {
|
||||
return
|
||||
}
|
||||
beaten = true
|
||||
}
|
||||
|
||||
paradigmReg := isParadigmLocale(tag.LangID, have.maxRegion)
|
||||
if !beaten && m.paradigmReg != paradigmReg {
|
||||
if !paradigmReg {
|
||||
return
|
||||
}
|
||||
beaten = true
|
||||
}
|
||||
|
||||
// Next we prefer if the pre-maximized script was specified and identical.
|
||||
origScript := have.tag.ScriptID == tag.ScriptID && tag.ScriptID != 0
|
||||
if !beaten && m.origScript != origScript {
|
||||
if m.origScript {
|
||||
return
|
||||
}
|
||||
beaten = true
|
||||
}
|
||||
|
||||
// Update m to the newly found best match.
|
||||
if beaten {
|
||||
m.have = have
|
||||
m.want = tag
|
||||
m.conf = c
|
||||
m.pinnedRegion = maxRegion
|
||||
m.sameRegionGroup = sameGroup
|
||||
m.origLang = origLang
|
||||
m.origReg = origReg
|
||||
m.paradigmReg = paradigmReg
|
||||
m.origScript = origScript
|
||||
m.regGroupDist = regGroupDist
|
||||
}
|
||||
}
|
||||
|
||||
func isParadigmLocale(lang language.Language, r language.Region) bool {
|
||||
for _, e := range paradigmLocales {
|
||||
if language.Language(e[0]) == lang && (r == language.Region(e[1]) || r == language.Region(e[2])) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// regionGroupDist computes the distance between two regions based on their
|
||||
// CLDR grouping.
|
||||
func regionGroupDist(a, b language.Region, script language.Script, lang language.Language) (dist uint8, same bool) {
|
||||
const defaultDistance = 4
|
||||
|
||||
aGroup := uint(regionToGroups[a]) << 1
|
||||
bGroup := uint(regionToGroups[b]) << 1
|
||||
for _, ri := range matchRegion {
|
||||
if language.Language(ri.lang) == lang && (ri.script == 0 || language.Script(ri.script) == script) {
|
||||
group := uint(1 << (ri.group &^ 0x80))
|
||||
if 0x80&ri.group == 0 {
|
||||
if aGroup&bGroup&group != 0 { // Both regions are in the group.
|
||||
return ri.distance, ri.distance == defaultDistance
|
||||
}
|
||||
} else {
|
||||
if (aGroup|bGroup)&group == 0 { // Both regions are not in the group.
|
||||
return ri.distance, ri.distance == defaultDistance
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return defaultDistance, true
|
||||
}
|
||||
|
||||
// equalsRest compares everything except the language.
|
||||
func equalsRest(a, b language.Tag) bool {
|
||||
// TODO: don't include extensions in this comparison. To do this efficiently,
|
||||
// though, we should handle private tags separately.
|
||||
return a.ScriptID == b.ScriptID && a.RegionID == b.RegionID && a.VariantOrPrivateUseTags() == b.VariantOrPrivateUseTags()
|
||||
}
|
||||
|
||||
// isExactEquivalent returns true if canonicalizing the language will not alter
|
||||
// the script or region of a tag.
|
||||
func isExactEquivalent(l language.Language) bool {
|
||||
for _, o := range notEquivalent {
|
||||
if o == l {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var notEquivalent []language.Language
|
||||
|
||||
func init() {
|
||||
// Create a list of all languages for which canonicalization may alter the
|
||||
// script or region.
|
||||
for _, lm := range language.AliasMap {
|
||||
tag := language.Tag{LangID: language.Language(lm.From)}
|
||||
if tag, _ = canonicalize(All, tag); tag.ScriptID != 0 || tag.RegionID != 0 {
|
||||
notEquivalent = append(notEquivalent, language.Language(lm.From))
|
||||
}
|
||||
}
|
||||
// Maximize undefined regions of paradigm locales.
|
||||
for i, v := range paradigmLocales {
|
||||
t := language.Tag{LangID: language.Language(v[0])}
|
||||
max, _ := t.Maximize()
|
||||
if v[1] == 0 {
|
||||
paradigmLocales[i][1] = uint16(max.RegionID)
|
||||
}
|
||||
if v[2] == 0 {
|
||||
paradigmLocales[i][2] = uint16(max.RegionID)
|
||||
}
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,228 @@
|
|||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package language
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/internal/language"
|
||||
)
|
||||
|
||||
// ValueError is returned by any of the parsing functions when the
|
||||
// input is well-formed but the respective subtag is not recognized
|
||||
// as a valid value.
|
||||
type ValueError interface {
|
||||
error
|
||||
|
||||
// Subtag returns the subtag for which the error occurred.
|
||||
Subtag() string
|
||||
}
|
||||
|
||||
// Parse parses the given BCP 47 string and returns a valid Tag. If parsing
|
||||
// failed it returns an error and any part of the tag that could be parsed.
|
||||
// If parsing succeeded but an unknown value was found, it returns
|
||||
// ValueError. The Tag returned in this case is just stripped of the unknown
|
||||
// value. All other values are preserved. It accepts tags in the BCP 47 format
|
||||
// and extensions to this standard defined in
|
||||
// https://www.unicode.org/reports/tr35/#Unicode_Language_and_Locale_Identifiers.
|
||||
// The resulting tag is canonicalized using the default canonicalization type.
|
||||
func Parse(s string) (t Tag, err error) {
|
||||
return Default.Parse(s)
|
||||
}
|
||||
|
||||
// Parse parses the given BCP 47 string and returns a valid Tag. If parsing
|
||||
// failed it returns an error and any part of the tag that could be parsed.
|
||||
// If parsing succeeded but an unknown value was found, it returns
|
||||
// ValueError. The Tag returned in this case is just stripped of the unknown
|
||||
// value. All other values are preserved. It accepts tags in the BCP 47 format
|
||||
// and extensions to this standard defined in
|
||||
// https://www.unicode.org/reports/tr35/#Unicode_Language_and_Locale_Identifiers.
|
||||
// The resulting tag is canonicalized using the canonicalization type c.
|
||||
func (c CanonType) Parse(s string) (t Tag, err error) {
|
||||
tt, err := language.Parse(s)
|
||||
if err != nil {
|
||||
return makeTag(tt), err
|
||||
}
|
||||
tt, changed := canonicalize(c, tt)
|
||||
if changed {
|
||||
tt.RemakeString()
|
||||
}
|
||||
return makeTag(tt), err
|
||||
}
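A short sketch of the parse behaviour described above, including the ValueError case; the inputs, and the assumption that "ZY" is an unassigned region, are illustrative:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	// Well-formed input is parsed and canonicalized with the default CanonType.
	t, err := language.Parse("en-Latn-US")
	fmt.Println(t, err)

	// Syntactically invalid input returns an error plus whatever could be parsed.
	t, err = language.Parse("en-US-!!")
	fmt.Println(t, err)

	// A well-formed but unrecognized subtag ("ZY" is assumed to be an
	// unassigned region here) should surface as a ValueError.
	t, err = language.Parse("en-ZY")
	if ve, ok := err.(language.ValueError); ok {
		fmt.Println(t, "unknown subtag:", ve.Subtag())
	}
}
```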
|
||||
|
||||
// Compose creates a Tag from individual parts, which may be of type Tag, Base,
|
||||
// Script, Region, Variant, []Variant, Extension, []Extension or error. If a
|
||||
// Base, Script or Region or slice of type Variant or Extension is passed more
|
||||
// than once, the latter will overwrite the former. Variants and Extensions are
|
||||
// accumulated, but if two extensions of the same type are passed, the latter
|
||||
// will replace the former. For -u extensions, though, the key-type pairs are
|
||||
// added, where later values overwrite older ones. A Tag overwrites all former
|
||||
// values and typically only makes sense as the first argument. The resulting
|
||||
// tag is returned after canonicalizing using the Default CanonType. If one or
|
||||
// more errors are encountered, one of the errors is returned.
|
||||
func Compose(part ...interface{}) (t Tag, err error) {
|
||||
return Default.Compose(part...)
|
||||
}
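A small sketch of the Compose semantics described above; the parts used are illustrative assumptions:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	base, _ := language.ParseBase("nl")
	region, _ := language.ParseRegion("BE")

	t, err := language.Compose(base, region)
	fmt.Println(t, err) // expected: nl-BE

	// A Tag argument overwrites earlier parts, and later parts overwrite the
	// Tag's own fields, so a Tag usually goes first.
	t, err = language.Compose(language.MustParse("de-CH"), region)
	fmt.Println(t, err) // expected: de-BE
}
```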
|
||||
|
||||
// Compose creates a Tag from individual parts, which may be of type Tag, Base,
|
||||
// Script, Region, Variant, []Variant, Extension, []Extension or error. If a
|
||||
// Base, Script or Region or slice of type Variant or Extension is passed more
|
||||
// than once, the latter will overwrite the former. Variants and Extensions are
|
||||
// accumulated, but if two extensions of the same type are passed, the latter
|
||||
// will replace the former. For -u extensions, though, the key-type pairs are
|
||||
// added, where later values overwrite older ones. A Tag overwrites all former
|
||||
// values and typically only makes sense as the first argument. The resulting
|
||||
// tag is returned after canonicalizing using CanonType c. If one or more errors
|
||||
// are encountered, one of the errors is returned.
|
||||
func (c CanonType) Compose(part ...interface{}) (t Tag, err error) {
|
||||
var b language.Builder
|
||||
if err = update(&b, part...); err != nil {
|
||||
return und, err
|
||||
}
|
||||
b.Tag, _ = canonicalize(c, b.Tag)
|
||||
return makeTag(b.Make()), err
|
||||
}
|
||||
|
||||
var errInvalidArgument = errors.New("invalid Extension or Variant")
|
||||
|
||||
func update(b *language.Builder, part ...interface{}) (err error) {
|
||||
for _, x := range part {
|
||||
switch v := x.(type) {
|
||||
case Tag:
|
||||
b.SetTag(v.tag())
|
||||
case Base:
|
||||
b.Tag.LangID = v.langID
|
||||
case Script:
|
||||
b.Tag.ScriptID = v.scriptID
|
||||
case Region:
|
||||
b.Tag.RegionID = v.regionID
|
||||
case Variant:
|
||||
if v.variant == "" {
|
||||
err = errInvalidArgument
|
||||
break
|
||||
}
|
||||
b.AddVariant(v.variant)
|
||||
case Extension:
|
||||
if v.s == "" {
|
||||
err = errInvalidArgument
|
||||
break
|
||||
}
|
||||
b.SetExt(v.s)
|
||||
case []Variant:
|
||||
b.ClearVariants()
|
||||
for _, v := range v {
|
||||
b.AddVariant(v.variant)
|
||||
}
|
||||
case []Extension:
|
||||
b.ClearExtensions()
|
||||
for _, e := range v {
|
||||
b.SetExt(e.s)
|
||||
}
|
||||
// TODO: support parsing of raw strings based on morphology or just extensions?
|
||||
case error:
|
||||
if v != nil {
|
||||
err = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var errInvalidWeight = errors.New("ParseAcceptLanguage: invalid weight")
|
||||
|
||||
// ParseAcceptLanguage parses the contents of an Accept-Language header as
|
||||
// defined in http://www.ietf.org/rfc/rfc2616.txt and returns a list of Tags and
|
||||
// a list of corresponding quality weights. It is more permissive than RFC 2616
|
||||
// and may return non-nil slices even if the input is not valid.
|
||||
// The Tags will be sorted by highest weight first and then by first occurrence.
|
||||
// Tags with a weight of zero will be dropped. An error will be returned if the
|
||||
// input could not be parsed.
|
||||
func ParseAcceptLanguage(s string) (tag []Tag, q []float32, err error) {
|
||||
var entry string
|
||||
for s != "" {
|
||||
if entry, s = split(s, ','); entry == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
entry, weight := split(entry, ';')
|
||||
|
||||
// Scan the language.
|
||||
t, err := Parse(entry)
|
||||
if err != nil {
|
||||
id, ok := acceptFallback[entry]
|
||||
if !ok {
|
||||
return nil, nil, err
|
||||
}
|
||||
t = makeTag(language.Tag{LangID: id})
|
||||
}
|
||||
|
||||
// Scan the optional weight.
|
||||
w := 1.0
|
||||
if weight != "" {
|
||||
weight = consume(weight, 'q')
|
||||
weight = consume(weight, '=')
|
||||
// consume returns the empty string when a token could not be
|
||||
// consumed, resulting in an error for ParseFloat.
|
||||
if w, err = strconv.ParseFloat(weight, 32); err != nil {
|
||||
return nil, nil, errInvalidWeight
|
||||
}
|
||||
// Drop tags with a quality weight of 0.
|
||||
if w <= 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
tag = append(tag, t)
|
||||
q = append(q, float32(w))
|
||||
}
|
||||
sortStable(&tagSort{tag, q})
|
||||
return tag, q, nil
|
||||
}
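A minimal sketch of ParseAcceptLanguage on an assumed header value:

```go
package main

import (
	"fmt"

	"golang.org/x/text/language"
)

func main() {
	header := "fr-CH, fr;q=0.9, en;q=0.8, *;q=0.5"
	tags, weights, err := language.ParseAcceptLanguage(header)
	if err != nil {
		panic(err)
	}
	// Tags come back sorted by weight, highest first.
	for i, t := range tags {
		fmt.Printf("%v q=%.1f\n", t, weights[i])
	}
}
```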
|
||||
|
||||
// consume removes a leading token c from s and returns the result or the empty
|
||||
// string if there is no such token.
|
||||
func consume(s string, c byte) string {
|
||||
if s == "" || s[0] != c {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s[1:])
|
||||
}
|
||||
|
||||
func split(s string, c byte) (head, tail string) {
|
||||
if i := strings.IndexByte(s, c); i >= 0 {
|
||||
return strings.TrimSpace(s[:i]), strings.TrimSpace(s[i+1:])
|
||||
}
|
||||
return strings.TrimSpace(s), ""
|
||||
}
|
||||
|
||||
// acceptFallback is a hack mapping to deal with a small number of cases that occur
|
||||
// in Accept-Language headers (with reasonable frequency).
|
||||
var acceptFallback = map[string]language.Language{
|
||||
"english": _en,
|
||||
"deutsch": _de,
|
||||
"italian": _it,
|
||||
"french": _fr,
|
||||
"*": _mul, // defined in the spec to match all languages.
|
||||
}
|
||||
|
||||
type tagSort struct {
|
||||
tag []Tag
|
||||
q []float32
|
||||
}
|
||||
|
||||
func (s *tagSort) Len() int {
|
||||
return len(s.q)
|
||||
}
|
||||
|
||||
func (s *tagSort) Less(i, j int) bool {
|
||||
return s.q[i] > s.q[j]
|
||||
}
|
||||
|
||||
func (s *tagSort) Swap(i, j int) {
|
||||
s.tag[i], s.tag[j] = s.tag[j], s.tag[i]
|
||||
s.q[i], s.q[j] = s.q[j], s.q[i]
|
||||
}
|