JWT wrapping tokens (#2172)
This commit is contained in:
parent
8ef964c496
commit
3129187dc2
|
@ -87,6 +87,14 @@ func (f *AuditFormatter) FormatRequest(
|
|||
errString = err.Error()
|
||||
}
|
||||
|
||||
var reqWrapInfo *AuditRequestWrapInfo
|
||||
if req.WrapInfo != nil {
|
||||
reqWrapInfo = &AuditRequestWrapInfo{
|
||||
TTL: int(req.WrapInfo.TTL / time.Second),
|
||||
Format: req.WrapInfo.Format,
|
||||
}
|
||||
}
|
||||
|
||||
reqEntry := &AuditRequestEntry{
|
||||
Type: "request",
|
||||
Error: errString,
|
||||
|
@ -105,7 +113,7 @@ func (f *AuditFormatter) FormatRequest(
|
|||
Path: req.Path,
|
||||
Data: req.Data,
|
||||
RemoteAddr: getRemoteAddr(req),
|
||||
WrapTTL: int(req.WrapTTL / time.Second),
|
||||
WrapInfo: reqWrapInfo,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -238,9 +246,17 @@ func (f *AuditFormatter) FormatResponse(
|
|||
}
|
||||
}
|
||||
|
||||
var respWrapInfo *AuditWrapInfo
|
||||
var reqWrapInfo *AuditRequestWrapInfo
|
||||
if req.WrapInfo != nil {
|
||||
reqWrapInfo = &AuditRequestWrapInfo{
|
||||
TTL: int(req.WrapInfo.TTL / time.Second),
|
||||
Format: req.WrapInfo.Format,
|
||||
}
|
||||
}
|
||||
|
||||
var respWrapInfo *AuditResponseWrapInfo
|
||||
if resp.WrapInfo != nil {
|
||||
respWrapInfo = &AuditWrapInfo{
|
||||
respWrapInfo = &AuditResponseWrapInfo{
|
||||
TTL: int(resp.WrapInfo.TTL / time.Second),
|
||||
Token: resp.WrapInfo.Token,
|
||||
CreationTime: resp.WrapInfo.CreationTime.Format(time.RFC3339Nano),
|
||||
|
@ -266,7 +282,7 @@ func (f *AuditFormatter) FormatResponse(
|
|||
Path: req.Path,
|
||||
Data: req.Data,
|
||||
RemoteAddr: getRemoteAddr(req),
|
||||
WrapTTL: int(req.WrapTTL / time.Second),
|
||||
WrapInfo: reqWrapInfo,
|
||||
},
|
||||
|
||||
Response: AuditResponse{
|
||||
|
@ -312,7 +328,7 @@ type AuditRequest struct {
|
|||
Path string `json:"path"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
RemoteAddr string `json:"remote_address"`
|
||||
WrapTTL int `json:"wrap_ttl"`
|
||||
WrapInfo *AuditRequestWrapInfo `json:"wrap_info,omitempty"`
|
||||
}
|
||||
|
||||
type AuditResponse struct {
|
||||
|
@ -320,7 +336,7 @@ type AuditResponse struct {
|
|||
Secret *AuditSecret `json:"secret,omitempty"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Redirect string `json:"redirect,omitempty"`
|
||||
WrapInfo *AuditWrapInfo `json:"wrap_info,omitempty"`
|
||||
WrapInfo *AuditResponseWrapInfo `json:"wrap_info,omitempty"`
|
||||
}
|
||||
|
||||
type AuditAuth struct {
|
||||
|
@ -335,7 +351,12 @@ type AuditSecret struct {
|
|||
LeaseID string `json:"lease_id"`
|
||||
}
|
||||
|
||||
type AuditWrapInfo struct {
|
||||
type AuditRequestWrapInfo struct {
|
||||
TTL int `json:"ttl"`
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
type AuditResponseWrapInfo struct {
|
||||
TTL int `json:"ttl"`
|
||||
Token string `json:"token"`
|
||||
CreationTime string `json:"creation_time"`
|
||||
|
|
|
@ -29,7 +29,9 @@ func TestFormatJSON_formatRequest(t *testing.T) {
|
|||
Connection: &logical.Connection{
|
||||
RemoteAddr: "127.0.0.1",
|
||||
},
|
||||
WrapTTL: 60 * time.Second,
|
||||
WrapInfo: &logical.RequestWrapInfo{
|
||||
TTL: 60 * time.Second,
|
||||
},
|
||||
},
|
||||
errors.New("this is an error"),
|
||||
testFormatJSONReqBasicStr,
|
||||
|
@ -74,5 +76,5 @@ func TestFormatJSON_formatRequest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
const testFormatJSONReqBasicStr = `{"time":"2015-08-05T13:45:46Z","type":"request","auth":{"display_name":"","policies":["root"],"metadata":null},"request":{"operation":"update","path":"/foo","data":null,"wrap_ttl":60,"remote_address":"127.0.0.1"},"error":"this is an error"}
|
||||
const testFormatJSONReqBasicStr = `{"time":"2015-08-05T13:45:46Z","type":"request","auth":{"display_name":"","policies":["root"],"metadata":null},"request":{"operation":"update","path":"/foo","data":null,"wrap_ttl":60,"remote_address":"127.0.0.1","wrap_info":{"ttl":60,"format":""}},"error":"this is an error"}
|
||||
`
|
||||
|
|
|
@ -28,11 +28,13 @@ func TestFormatJSONx_formatRequest(t *testing.T) {
|
|||
Connection: &logical.Connection{
|
||||
RemoteAddr: "127.0.0.1",
|
||||
},
|
||||
WrapTTL: 60 * time.Second,
|
||||
WrapInfo: &logical.RequestWrapInfo{
|
||||
TTL: 60 * time.Second,
|
||||
},
|
||||
},
|
||||
errors.New("this is an error"),
|
||||
"",
|
||||
`<json:object name="auth"><json:string name="accessor"></json:string><json:string name="client_token"></json:string><json:string name="display_name"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:number name="wrap_ttl">60</json:number></json:object><json:string name="type">request</json:string>`,
|
||||
`<json:object name="auth"><json:string name="accessor"></json:string><json:string name="client_token"></json:string><json:string name="display_name"></json:string><json:null name="metadata" /><json:array name="policies"><json:string>root</json:string></json:array></json:object><json:string name="error">this is an error</json:string><json:object name="request"><json:string name="client_token"></json:string><json:string name="client_token_accessor"></json:string><json:null name="data" /><json:string name="id"></json:string><json:string name="operation">update</json:string><json:string name="path">/foo</json:string><json:string name="remote_address">127.0.0.1</json:string><json:object name="wrap_info"><json:string name="format"></json:string><json:number name="ttl">60</json:number></json:object></json:object><json:string name="type">request</json:string>`,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ func Hash(salter *salt.Salt, raw interface{}) error {
|
|||
|
||||
s.Data = data.(map[string]interface{})
|
||||
|
||||
case *logical.WrapInfo:
|
||||
case *logical.ResponseWrapInfo:
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -44,7 +44,9 @@ func TestCopy_request(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapTTL: 60 * time.Second,
|
||||
WrapInfo: &logical.RequestWrapInfo{
|
||||
TTL: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
arg := expected
|
||||
|
||||
|
@ -67,7 +69,7 @@ func TestCopy_response(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapInfo: &logical.WrapInfo{
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "foo",
|
||||
CreationTime: time.Now(),
|
||||
|
@ -138,7 +140,7 @@ func TestHash(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
WrapInfo: &logical.WrapInfo{
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "bar",
|
||||
CreationTime: now,
|
||||
|
@ -149,7 +151,7 @@ func TestHash(t *testing.T) {
|
|||
Data: map[string]interface{}{
|
||||
"foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
},
|
||||
WrapInfo: &logical.WrapInfo{
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
TTL: 60,
|
||||
Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
|
||||
CreationTime: now,
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
)
|
||||
|
@ -37,12 +36,6 @@ func (c *UnwrapCommand) Run(args []string) int {
|
|||
case 0:
|
||||
case 1:
|
||||
tokenID = args[0]
|
||||
_, err = uuid.ParseUUID(tokenID)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"Given token could not be parsed as a UUID: %v", err))
|
||||
return 1
|
||||
}
|
||||
default:
|
||||
c.Ui.Error("Unwrap expects zero or one argument (the ID of the wrapping token)")
|
||||
flags.Usage()
|
||||
|
|
|
@ -19,10 +19,14 @@ const (
|
|||
// AuthHeaderName is the name of the header containing the token.
|
||||
AuthHeaderName = "X-Vault-Token"
|
||||
|
||||
// WrapHeaderName is the name of the header containing a directive to wrap the
|
||||
// response.
|
||||
// WrapTTLHeaderName is the name of the header containing a directive to
|
||||
// wrap the response
|
||||
WrapTTLHeaderName = "X-Vault-Wrap-TTL"
|
||||
|
||||
// WrapFormatHeaderName is the name of the header containing the format to
|
||||
// wrap in; has no effect if the wrap TTL is not set
|
||||
WrapFormatHeaderName = "X-Vault-Wrap-Format"
|
||||
|
||||
// NoRequestForwardingHeaderName is the name of the header telling Vault
|
||||
// not to use request forwarding
|
||||
NoRequestForwardingHeaderName = "X-Vault-No-Request-Forwarding"
|
||||
|
@ -91,20 +95,7 @@ func wrappingVerificationFunc(core *vault.Core, req *logical.Request) error {
|
|||
return fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
var token string
|
||||
if req.Data != nil && req.Data["token"] != nil {
|
||||
if tokenStr, ok := req.Data["token"].(string); !ok {
|
||||
return fmt.Errorf("could not decode token in request body")
|
||||
} else if tokenStr == "" {
|
||||
return fmt.Errorf("empty token in request body")
|
||||
} else {
|
||||
token = tokenStr
|
||||
}
|
||||
} else {
|
||||
token = req.ClientToken
|
||||
}
|
||||
|
||||
valid, err := core.ValidateWrappingToken(token)
|
||||
valid, err := core.ValidateWrappingToken(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error validating wrapping token: %v", err)
|
||||
}
|
||||
|
@ -288,9 +279,8 @@ func requestAuth(core *vault.Core, r *http.Request, req *logical.Request) *logic
|
|||
return req
|
||||
}
|
||||
|
||||
// requestWrapTTL adds the WrapTTL value to the logical.Request if it
|
||||
// exists.
|
||||
func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, error) {
|
||||
// requestWrapInfo adds the WrapInfo value to the logical.Request if wrap info exists
|
||||
func requestWrapInfo(r *http.Request, req *logical.Request) (*logical.Request, error) {
|
||||
// First try for the header value
|
||||
wrapTTL := r.Header.Get(WrapTTLHeaderName)
|
||||
if wrapTTL == "" {
|
||||
|
@ -305,7 +295,16 @@ func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, er
|
|||
if int64(dur) < 0 {
|
||||
return req, fmt.Errorf("requested wrap ttl cannot be negative")
|
||||
}
|
||||
req.WrapTTL = dur
|
||||
|
||||
req.WrapInfo = &logical.RequestWrapInfo{
|
||||
TTL: dur,
|
||||
}
|
||||
|
||||
wrapFormat := r.Header.Get(WrapFormatHeaderName)
|
||||
switch wrapFormat {
|
||||
case "jwt":
|
||||
req.WrapInfo.Format = "jwt"
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
|||
Connection: getConnection(r),
|
||||
})
|
||||
|
||||
req, err = requestWrapTTL(r, req)
|
||||
req, err = requestWrapInfo(r, req)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,18 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// RequestWrapInfo is a struct that stores information about desired response
|
||||
// wrapping behavior
|
||||
type RequestWrapInfo struct {
|
||||
// Setting to non-zero specifies that the response should be wrapped.
|
||||
// Specifies the desired TTL of the wrapping token.
|
||||
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
|
||||
|
||||
// The format to use for the wrapped response; if not specified it's a bare
|
||||
// token
|
||||
Format string `json:"format" structs:"format" mapstructure:"format"`
|
||||
}
|
||||
|
||||
// Request is a struct that stores the parameters and context
|
||||
// of a request being made to Vault. It is used to abstract
|
||||
// the details of the higher level request protocol from the handlers.
|
||||
|
@ -61,9 +73,8 @@ type Request struct {
|
|||
// request path with the MountPoint trimmed off.
|
||||
MountPoint string `json:"mount_point" structs:"mount_point" mapstructure:"mount_point"`
|
||||
|
||||
// WrapTTL contains the requested TTL of the token used to wrap the
|
||||
// response in a cubbyhole.
|
||||
WrapTTL time.Duration `json:"wrap_ttl" struct:"wrap_ttl" mapstructure:"wrap_ttl"`
|
||||
// WrapInfo contains requested response wrapping parameters
|
||||
WrapInfo *RequestWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"`
|
||||
}
|
||||
|
||||
// Get returns a data field and guards for nil Data
|
||||
|
|
|
@ -28,7 +28,7 @@ const (
|
|||
HTTPStatusCode = "http_status_code"
|
||||
)
|
||||
|
||||
type WrapInfo struct {
|
||||
type ResponseWrapInfo struct {
|
||||
// Setting to non-zero specifies that the response should be wrapped.
|
||||
// Specifies the desired TTL of the wrapping token.
|
||||
TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"`
|
||||
|
@ -43,6 +43,9 @@ type WrapInfo struct {
|
|||
// If the contained response is the output of a token creation call, the
|
||||
// created token's accessor will be accessible here
|
||||
WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"`
|
||||
|
||||
// The format to use. This doesn't get returned, it's only internal.
|
||||
Format string `json:"format" structs:"format" mapstructure:"format"`
|
||||
}
|
||||
|
||||
// Response is a struct that stores the response of a request.
|
||||
|
@ -75,7 +78,7 @@ type Response struct {
|
|||
warnings []string `json:"warnings" structs:"warnings" mapstructure:"warnings"`
|
||||
|
||||
// Information for wrapping the response in a cubbyhole
|
||||
WrapInfo *WrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"`
|
||||
WrapInfo *ResponseWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -120,7 +123,7 @@ func init() {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("error copying WrapInfo: %v", err)
|
||||
}
|
||||
ret.WrapInfo = retWrapInfo.(*WrapInfo)
|
||||
ret.WrapInfo = retWrapInfo.(*ResponseWrapInfo)
|
||||
}
|
||||
|
||||
return &ret, nil
|
||||
|
|
|
@ -29,7 +29,8 @@ const (
|
|||
// Storage path where the local cluster name and identifier are stored
|
||||
coreLocalClusterInfoPath = "core/cluster/local/info"
|
||||
|
||||
corePrivateKeyTypeP521 = "p521"
|
||||
corePrivateKeyTypeP521 = "p521"
|
||||
corePrivateKeyTypeED25519 = "ed25519"
|
||||
|
||||
// Internal so as not to log a trace message
|
||||
IntNoForwardingHeaderName = "X-Vault-Internal-No-Request-Forwarding"
|
||||
|
@ -39,11 +40,13 @@ var (
|
|||
ErrCannotForward = errors.New("cannot forward request; no connection or address not known")
|
||||
)
|
||||
|
||||
// This can be one of a few key types so the different params may or may not be filled
|
||||
type clusterKeyParams struct {
|
||||
Type string `json:"type"`
|
||||
X *big.Int `json:"x"`
|
||||
Y *big.Int `json:"y"`
|
||||
D *big.Int `json:"d"`
|
||||
Type string `json:"type"`
|
||||
X *big.Int `json:"x,omitempty"`
|
||||
Y *big.Int `json:"y,omitempty"`
|
||||
D *big.Int `json:"d,omitempty"`
|
||||
ED25519Key []byte `json:"ed25519_key,omitempty"`
|
||||
}
|
||||
|
||||
type activeConnection struct {
|
||||
|
|
|
@ -251,6 +251,10 @@ type Core struct {
|
|||
// reloadFuncsLock controlls access to the funcs
|
||||
reloadFuncsLock sync.RWMutex
|
||||
|
||||
// wrappingJWTKey is the key used for generating JWTs containing response
|
||||
// wrapping information
|
||||
wrappingJWTKey *ecdsa.PrivateKey
|
||||
|
||||
//
|
||||
// Cluster information
|
||||
//
|
||||
|
@ -820,7 +824,7 @@ func (c *Core) Unseal(key []byte) (bool, error) {
|
|||
// Do post-unseal setup if HA is not enabled
|
||||
if c.ha == nil {
|
||||
// We still need to set up cluster info even if it's not part of a
|
||||
// cluster right now
|
||||
// cluster right now. This also populates the cached cluster object.
|
||||
if err := c.setupCluster(); err != nil {
|
||||
c.logger.Error("core: cluster setup failed", "error", err)
|
||||
c.barrier.Seal()
|
||||
|
@ -1139,6 +1143,9 @@ func (c *Core) postUnseal() (retErr error) {
|
|||
return err
|
||||
}
|
||||
}
|
||||
if err := c.ensureWrappingKey(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.loadMounts(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1556,27 +1563,3 @@ func (c *Core) BarrierKeyLength() (min, max int) {
|
|||
max += shamir.ShareOverhead
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Core) ValidateWrappingToken(token string) (bool, error) {
|
||||
if token == "" {
|
||||
return false, fmt.Errorf("token is empty")
|
||||
}
|
||||
|
||||
te, err := c.tokenStore.Lookup(token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if te == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(te.Policies) != 1 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if te.Policies[0] != responseWrappingPolicyName {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
|
|
@ -41,6 +41,10 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
|
|||
"raw/*",
|
||||
"rotate",
|
||||
},
|
||||
|
||||
Unauthenticated: []string{
|
||||
"wrapping/pubkey",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
|
@ -542,6 +546,20 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
|
|||
HelpDescription: strings.TrimSpace(sysHelp["rotate"][1]),
|
||||
},
|
||||
|
||||
/*
|
||||
// Disabled for the moment as we don't support this externally
|
||||
&framework.Path{
|
||||
Pattern: "wrapping/pubkey$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.handleWrappingPubkey,
|
||||
},
|
||||
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["wrappubkey"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["wrappubkey"][1]),
|
||||
},
|
||||
*/
|
||||
|
||||
&framework.Path{
|
||||
Pattern: "wrapping/wrap$",
|
||||
|
||||
|
@ -1472,9 +1490,22 @@ func (b *SystemBackend) handleRotate(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handleWrappingPubkey(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
x, _ := b.Core.wrappingJWTKey.X.MarshalText()
|
||||
y, _ := b.Core.wrappingJWTKey.Y.MarshalText()
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"jwt_x": string(x),
|
||||
"jwt_y": string(y),
|
||||
"jwt_curve": corePrivateKeyTypeP521,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handleWrappingWrap(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
if req.WrapTTL == 0 {
|
||||
if req.WrapInfo == nil || req.WrapInfo.TTL == 0 {
|
||||
return logical.ErrorResponse("endpoint requires response wrapping to be used"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
|
@ -1497,6 +1528,10 @@ func (b *SystemBackend) handleWrappingUnwrap(
|
|||
token = req.ClientToken
|
||||
}
|
||||
|
||||
if wt := b.Core.parseVaultTokenFromJWT(token); wt != nil {
|
||||
token = *wt
|
||||
}
|
||||
|
||||
if thirdParty {
|
||||
// Use the token to decrement the use count to avoid a second operation on the token.
|
||||
_, err := b.Core.tokenStore.UseTokenByID(token)
|
||||
|
@ -1557,6 +1592,10 @@ func (b *SystemBackend) handleWrappingLookup(
|
|||
return logical.ErrorResponse("missing \"token\" value in input"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
if wt := b.Core.parseVaultTokenFromJWT(token); wt != nil {
|
||||
token = *wt
|
||||
}
|
||||
|
||||
cubbyReq := &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "cubbyhole/wrapinfo",
|
||||
|
@ -1613,6 +1652,10 @@ func (b *SystemBackend) handleWrappingRewrap(
|
|||
token = req.ClientToken
|
||||
}
|
||||
|
||||
if wt := b.Core.parseVaultTokenFromJWT(token); wt != nil {
|
||||
token = *wt
|
||||
}
|
||||
|
||||
if thirdParty {
|
||||
// Use the token to decrement the use count to avoid a second operation on the token.
|
||||
_, err := b.Core.tokenStore.UseTokenByID(token)
|
||||
|
@ -1683,7 +1726,7 @@ func (b *SystemBackend) handleWrappingRewrap(
|
|||
Data: map[string]interface{}{
|
||||
"response": response,
|
||||
},
|
||||
WrapInfo: &logical.WrapInfo{
|
||||
WrapInfo: &logical.ResponseWrapInfo{
|
||||
TTL: time.Duration(creationTTL),
|
||||
},
|
||||
}, nil
|
||||
|
@ -2087,6 +2130,11 @@ Enable a new audit backend or disable an existing backend.
|
|||
`Round trips the given input data into a response-wrapped token.`,
|
||||
},
|
||||
|
||||
"wrappubkey": {
|
||||
"Returns pubkeys used in some wrapping formats.",
|
||||
"Returns pubkeys used in some wrapping formats.",
|
||||
},
|
||||
|
||||
"unwrap": {
|
||||
"Unwraps a response-wrapped token.",
|
||||
`Unwraps a response-wrapped token. Unlike simply reading from cubbyhole/response,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -188,27 +187,37 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
|
|||
if resp != nil {
|
||||
// If wrapping is used, use the shortest between the request and response
|
||||
var wrapTTL time.Duration
|
||||
var wrapFormat string
|
||||
|
||||
// Ensure no wrap info information is set other than, possibly, the TTL
|
||||
if resp.WrapInfo != nil {
|
||||
if resp.WrapInfo.TTL > 0 {
|
||||
wrapTTL = resp.WrapInfo.TTL
|
||||
}
|
||||
wrapFormat = resp.WrapInfo.Format
|
||||
resp.WrapInfo = nil
|
||||
}
|
||||
|
||||
if req.WrapTTL > 0 {
|
||||
switch {
|
||||
case wrapTTL == 0:
|
||||
wrapTTL = req.WrapTTL
|
||||
case req.WrapTTL < wrapTTL:
|
||||
wrapTTL = req.WrapTTL
|
||||
if req.WrapInfo != nil {
|
||||
if req.WrapInfo.TTL > 0 {
|
||||
switch {
|
||||
case wrapTTL == 0:
|
||||
wrapTTL = req.WrapInfo.TTL
|
||||
case req.WrapInfo.TTL < wrapTTL:
|
||||
wrapTTL = req.WrapInfo.TTL
|
||||
}
|
||||
}
|
||||
// If the wrap format hasn't been set by the response, set it to
|
||||
// the request format
|
||||
if req.WrapInfo.Format != "" && wrapFormat == "" {
|
||||
wrapFormat = req.WrapInfo.Format
|
||||
}
|
||||
}
|
||||
|
||||
if wrapTTL > 0 {
|
||||
resp.WrapInfo = &logical.WrapInfo{
|
||||
TTL: wrapTTL,
|
||||
resp.WrapInfo = &logical.ResponseWrapInfo{
|
||||
TTL: wrapTTL,
|
||||
Format: wrapFormat,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -325,30 +334,37 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
if resp != nil {
|
||||
// If wrapping is used, use the shortest between the request and response
|
||||
var wrapTTL time.Duration
|
||||
var wrapFormat string
|
||||
|
||||
// Ensure no wrap info information is set other than, possibly, the TTL
|
||||
if resp.WrapInfo != nil {
|
||||
if resp.WrapInfo.TTL > 0 {
|
||||
wrapTTL = resp.WrapInfo.TTL
|
||||
}
|
||||
wrapFormat = resp.WrapInfo.Format
|
||||
resp.WrapInfo = nil
|
||||
}
|
||||
|
||||
if req.WrapTTL > 0 {
|
||||
switch {
|
||||
case wrapTTL == 0:
|
||||
wrapTTL = req.WrapTTL
|
||||
case req.WrapTTL < wrapTTL:
|
||||
wrapTTL = req.WrapTTL
|
||||
if req.WrapInfo != nil {
|
||||
if req.WrapInfo.TTL > 0 {
|
||||
switch {
|
||||
case wrapTTL == 0:
|
||||
wrapTTL = req.WrapInfo.TTL
|
||||
case req.WrapInfo.TTL < wrapTTL:
|
||||
wrapTTL = req.WrapInfo.TTL
|
||||
}
|
||||
}
|
||||
if req.WrapInfo.Format != "" && wrapFormat == "" {
|
||||
wrapFormat = req.WrapInfo.Format
|
||||
}
|
||||
}
|
||||
|
||||
if wrapTTL > 0 {
|
||||
resp.WrapInfo = &logical.WrapInfo{
|
||||
TTL: wrapTTL,
|
||||
resp.WrapInfo = &logical.ResponseWrapInfo{
|
||||
TTL: wrapTTL,
|
||||
Format: wrapFormat,
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// A login request should never return a secret!
|
||||
|
@ -431,138 +447,3 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log
|
|||
|
||||
return resp, auth, err
|
||||
}
|
||||
|
||||
func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*logical.Response, error) {
|
||||
// Before wrapping, obey special rules for listing: if no entries are
|
||||
// found, 404. This prevents unwrapping only to find empty data.
|
||||
if req.Operation == logical.ListOperation {
|
||||
if resp == nil || len(resp.Data) == 0 {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
keysRaw, ok := resp.Data["keys"]
|
||||
if !ok || keysRaw == nil {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
keys, ok := keysRaw.([]string)
|
||||
if !ok {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
}
|
||||
|
||||
// If we are wrapping, the first part (performed in this functions) happens
|
||||
// before auditing so that resp.WrapInfo.Token can contain the HMAC'd
|
||||
// wrapping token ID in the audit logs, so that it can be determined from
|
||||
// the audit logs whether the token was ever actually used.
|
||||
creationTime := time.Now()
|
||||
te := TokenEntry{
|
||||
Path: req.Path,
|
||||
Policies: []string{"response-wrapping"},
|
||||
CreationTime: creationTime.Unix(),
|
||||
TTL: resp.WrapInfo.TTL,
|
||||
NumUses: 1,
|
||||
ExplicitMaxTTL: resp.WrapInfo.TTL,
|
||||
}
|
||||
|
||||
if err := c.tokenStore.create(&te); err != nil {
|
||||
c.logger.Error("core: failed to create wrapping token", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
resp.WrapInfo.Token = te.ID
|
||||
resp.WrapInfo.CreationTime = creationTime
|
||||
|
||||
// This will only be non-nil if this response contains a token, so in that
|
||||
// case put the accessor in the wrap info.
|
||||
if resp.Auth != nil {
|
||||
resp.WrapInfo.WrappedAccessor = resp.Auth.Accessor
|
||||
}
|
||||
|
||||
cubbyReq := &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "cubbyhole/response",
|
||||
ClientToken: te.ID,
|
||||
}
|
||||
|
||||
// During a rewrap, store the original response, don't wrap it again.
|
||||
if req.Path == "sys/wrapping/rewrap" {
|
||||
cubbyReq.Data = map[string]interface{}{
|
||||
"response": resp.Data["response"],
|
||||
}
|
||||
} else {
|
||||
httpResponse := logical.LogicalResponseToHTTPResponse(resp)
|
||||
|
||||
// Add the unique identifier of the original request to the response
|
||||
httpResponse.RequestID = req.ID
|
||||
|
||||
// Because of the way that JSON encodes (likely just in Go) we actually get
|
||||
// mixed-up values for ints if we simply put this object in the response
|
||||
// and encode the whole thing; so instead we marshal it first, then store
|
||||
// the string response. This actually ends up making it easier on the
|
||||
// client side, too, as it becomes a straight read-string-pass-to-unmarshal
|
||||
// operation.
|
||||
|
||||
marshaledResponse, err := json.Marshal(httpResponse)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to marshal wrapped response", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
cubbyReq.Data = map[string]interface{}{
|
||||
"response": string(marshaledResponse),
|
||||
}
|
||||
}
|
||||
|
||||
cubbyResp, err := c.router.Route(cubbyReq)
|
||||
if err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapped response information", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
if cubbyResp != nil && cubbyResp.IsError() {
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapped response information", "error", cubbyResp.Data["error"])
|
||||
return cubbyResp, nil
|
||||
}
|
||||
|
||||
// Store info for lookup
|
||||
cubbyReq.Path = "cubbyhole/wrapinfo"
|
||||
cubbyReq.Data = map[string]interface{}{
|
||||
"creation_ttl": resp.WrapInfo.TTL,
|
||||
"creation_time": creationTime,
|
||||
}
|
||||
cubbyResp, err = c.router.Route(cubbyReq)
|
||||
if err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapping information", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
if cubbyResp != nil && cubbyResp.IsError() {
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapping information", "error", cubbyResp.Data["error"])
|
||||
return cubbyResp, nil
|
||||
}
|
||||
|
||||
auth := &logical.Auth{
|
||||
ClientToken: te.ID,
|
||||
Policies: []string{"response-wrapping"},
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
TTL: te.TTL,
|
||||
Renewable: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Register the wrapped token with the expiration manager
|
||||
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -46,7 +46,9 @@ func TestRequestHandling_Wrapping(t *testing.T) {
|
|||
Path: "wraptest/foo",
|
||||
ClientToken: root,
|
||||
Operation: logical.ReadOperation,
|
||||
WrapTTL: time.Duration(15 * time.Second),
|
||||
WrapInfo: &logical.RequestWrapInfo{
|
||||
TTL: time.Duration(15 * time.Second),
|
||||
},
|
||||
}
|
||||
resp, err = core.HandleRequest(req)
|
||||
if err != nil {
|
||||
|
@ -120,7 +122,9 @@ func TestRequestHandling_LoginWrapping(t *testing.T) {
|
|||
req = &logical.Request{
|
||||
Path: "auth/userpass/login/test",
|
||||
Operation: logical.UpdateOperation,
|
||||
WrapTTL: time.Duration(15 * time.Second),
|
||||
WrapInfo: &logical.RequestWrapInfo{
|
||||
TTL: time.Duration(15 * time.Second),
|
||||
},
|
||||
Data: map[string]interface{}{
|
||||
"password": "foo",
|
||||
},
|
||||
|
|
|
@ -252,8 +252,14 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
|
|||
// Cache the identifier of the request
|
||||
originalReqID := req.ID
|
||||
|
||||
// Cache the wrap TTL of the request
|
||||
originalWrapTTL := req.WrapTTL
|
||||
// Cache the wrap info of the request
|
||||
var wrapInfo *logical.RequestWrapInfo
|
||||
if req.WrapInfo != nil {
|
||||
wrapInfo = &logical.RequestWrapInfo{
|
||||
TTL: req.WrapInfo.TTL,
|
||||
Format: req.WrapInfo.Format,
|
||||
}
|
||||
}
|
||||
|
||||
// Reset the request before returning
|
||||
defer func() {
|
||||
|
@ -263,7 +269,7 @@ func (r *Router) routeCommon(req *logical.Request, existenceCheck bool) (*logica
|
|||
req.ID = originalReqID
|
||||
req.Storage = nil
|
||||
req.ClientToken = clientToken
|
||||
req.WrapTTL = originalWrapTTL
|
||||
req.WrapInfo = wrapInfo
|
||||
}()
|
||||
|
||||
// Invoke the backend
|
||||
|
|
|
@ -0,0 +1,309 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/SermoDigital/jose/crypto"
|
||||
"github.com/SermoDigital/jose/jws"
|
||||
"github.com/SermoDigital/jose/jwt"
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
const (
|
||||
// The location of the key used to generate response-wrapping JWTs
|
||||
coreWrappingJWTKeyPath = "core/wrapping/jwtkey"
|
||||
)
|
||||
|
||||
func (c *Core) ensureWrappingKey() error {
|
||||
entry, err := c.barrier.Get(coreWrappingJWTKeyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var keyParams clusterKeyParams
|
||||
|
||||
if entry == nil {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to generate wrapping key: {{err}}", err)
|
||||
}
|
||||
keyParams.D = key.D
|
||||
keyParams.X = key.X
|
||||
keyParams.Y = key.Y
|
||||
keyParams.Type = corePrivateKeyTypeP521
|
||||
val, err := jsonutil.EncodeJSON(keyParams)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to encode wrapping key: {{err}}", err)
|
||||
}
|
||||
entry = &Entry{
|
||||
Key: coreWrappingJWTKeyPath,
|
||||
Value: val,
|
||||
}
|
||||
if err = c.barrier.Put(entry); err != nil {
|
||||
return errwrap.Wrapf("failed to store wrapping key: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Redundant if we just created it, but in this case serves as a check anyways
|
||||
if err = jsonutil.DecodeJSON(entry.Value, &keyParams); err != nil {
|
||||
return errwrap.Wrapf("failed to decode wrapping key parameters: {{err}}", err)
|
||||
}
|
||||
|
||||
c.wrappingJWTKey = &ecdsa.PrivateKey{
|
||||
PublicKey: ecdsa.PublicKey{
|
||||
Curve: elliptic.P521(),
|
||||
X: keyParams.X,
|
||||
Y: keyParams.Y,
|
||||
},
|
||||
D: keyParams.D,
|
||||
}
|
||||
|
||||
c.logger.Info("core: loaded wrapping token key")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Core) wrapInCubbyhole(req *logical.Request, resp *logical.Response) (*logical.Response, error) {
|
||||
// Before wrapping, obey special rules for listing: if no entries are
|
||||
// found, 404. This prevents unwrapping only to find empty data.
|
||||
if req.Operation == logical.ListOperation {
|
||||
if resp == nil || len(resp.Data) == 0 {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
keysRaw, ok := resp.Data["keys"]
|
||||
if !ok || keysRaw == nil {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
keys, ok := keysRaw.([]string)
|
||||
if !ok {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil, logical.ErrUnsupportedPath
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// If we are wrapping, the first part (performed in this functions) happens
|
||||
// before auditing so that resp.WrapInfo.Token can contain the HMAC'd
|
||||
// wrapping token ID in the audit logs, so that it can be determined from
|
||||
// the audit logs whether the token was ever actually used.
|
||||
creationTime := time.Now()
|
||||
te := TokenEntry{
|
||||
Path: req.Path,
|
||||
Policies: []string{"response-wrapping"},
|
||||
CreationTime: creationTime.Unix(),
|
||||
TTL: resp.WrapInfo.TTL,
|
||||
NumUses: 1,
|
||||
ExplicitMaxTTL: resp.WrapInfo.TTL,
|
||||
}
|
||||
|
||||
if err := c.tokenStore.create(&te); err != nil {
|
||||
c.logger.Error("core: failed to create wrapping token", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
resp.WrapInfo.Token = te.ID
|
||||
resp.WrapInfo.CreationTime = creationTime
|
||||
|
||||
// This will only be non-nil if this response contains a token, so in that
|
||||
// case put the accessor in the wrap info.
|
||||
if resp.Auth != nil {
|
||||
resp.WrapInfo.WrappedAccessor = resp.Auth.Accessor
|
||||
}
|
||||
|
||||
switch resp.WrapInfo.Format {
|
||||
case "jwt":
|
||||
// Create the JWT
|
||||
claims := jws.Claims{}
|
||||
// Map the JWT ID to the token ID for ease ofuse
|
||||
claims.SetJWTID(te.ID)
|
||||
// Set the issue time to the creation time
|
||||
claims.SetIssuedAt(creationTime)
|
||||
// Set the expiration to the TTL
|
||||
claims.SetExpiration(creationTime.Add(resp.WrapInfo.TTL))
|
||||
if resp.Auth != nil {
|
||||
claims.Set("accessor", resp.Auth.Accessor)
|
||||
}
|
||||
claims.Set("type", "wrapping")
|
||||
claims.Set("addr", c.redirectAddr)
|
||||
jwt := jws.NewJWT(claims, crypto.SigningMethodES512)
|
||||
serWebToken, err := jwt.Serialize(c.wrappingJWTKey)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to serialize JWT", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
resp.WrapInfo.Token = string(serWebToken)
|
||||
}
|
||||
|
||||
cubbyReq := &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: "cubbyhole/response",
|
||||
ClientToken: te.ID,
|
||||
}
|
||||
|
||||
// During a rewrap, store the original response, don't wrap it again.
|
||||
if req.Path == "sys/wrapping/rewrap" {
|
||||
cubbyReq.Data = map[string]interface{}{
|
||||
"response": resp.Data["response"],
|
||||
}
|
||||
} else {
|
||||
httpResponse := logical.LogicalResponseToHTTPResponse(resp)
|
||||
|
||||
// Add the unique identifier of the original request to the response
|
||||
httpResponse.RequestID = req.ID
|
||||
|
||||
// Because of the way that JSON encodes (likely just in Go) we actually get
|
||||
// mixed-up values for ints if we simply put this object in the response
|
||||
// and encode the whole thing; so instead we marshal it first, then store
|
||||
// the string response. This actually ends up making it easier on the
|
||||
// client side, too, as it becomes a straight read-string-pass-to-unmarshal
|
||||
// operation.
|
||||
|
||||
marshaledResponse, err := json.Marshal(httpResponse)
|
||||
if err != nil {
|
||||
c.logger.Error("core: failed to marshal wrapped response", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
cubbyReq.Data = map[string]interface{}{
|
||||
"response": string(marshaledResponse),
|
||||
}
|
||||
}
|
||||
|
||||
cubbyResp, err := c.router.Route(cubbyReq)
|
||||
if err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapped response information", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
if cubbyResp != nil && cubbyResp.IsError() {
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapped response information", "error", cubbyResp.Data["error"])
|
||||
return cubbyResp, nil
|
||||
}
|
||||
|
||||
// Store info for lookup
|
||||
cubbyReq.Path = "cubbyhole/wrapinfo"
|
||||
cubbyReq.Data = map[string]interface{}{
|
||||
"creation_ttl": resp.WrapInfo.TTL,
|
||||
"creation_time": creationTime,
|
||||
}
|
||||
cubbyResp, err = c.router.Route(cubbyReq)
|
||||
if err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapping information", "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
if cubbyResp != nil && cubbyResp.IsError() {
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to store wrapping information", "error", cubbyResp.Data["error"])
|
||||
return cubbyResp, nil
|
||||
}
|
||||
|
||||
auth := &logical.Auth{
|
||||
ClientToken: te.ID,
|
||||
Policies: []string{"response-wrapping"},
|
||||
LeaseOptions: logical.LeaseOptions{
|
||||
TTL: te.TTL,
|
||||
Renewable: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Register the wrapped token with the expiration manager
|
||||
if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
|
||||
// Revoke since it's not yet being tracked for expiration
|
||||
c.tokenStore.Revoke(te.ID)
|
||||
c.logger.Error("core: failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err)
|
||||
return nil, ErrInternalError
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Core) ValidateWrappingToken(req *logical.Request) (bool, error) {
|
||||
if req == nil {
|
||||
return false, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
var token string
|
||||
if req.Data != nil && req.Data["token"] != nil {
|
||||
if tokenStr, ok := req.Data["token"].(string); !ok {
|
||||
return false, fmt.Errorf("could not decode token in request body")
|
||||
} else if tokenStr == "" {
|
||||
return false, fmt.Errorf("empty token in request body")
|
||||
} else {
|
||||
token = tokenStr
|
||||
}
|
||||
} else {
|
||||
token = req.ClientToken
|
||||
}
|
||||
|
||||
// Check for it being a JWT. If it is, and it is valid, we extract the
|
||||
// internal client token from it and use that during lookup.
|
||||
if strings.Count(token, ".") == 2 {
|
||||
wt, err := jws.ParseJWT([]byte(token))
|
||||
// If there's an error we simply fall back to attempting to use it as a regular token
|
||||
if err == nil && wt != nil {
|
||||
validator := &jwt.Validator{}
|
||||
validator.SetClaim("type", "wrapping")
|
||||
if err = wt.Validate(&c.wrappingJWTKey.PublicKey, crypto.SigningMethodES512, []*jwt.Validator{validator}...); err != nil {
|
||||
return false, errwrap.Wrapf("wrapping token signature could not be validated: {{err}}", err)
|
||||
}
|
||||
token, _ = wt.Claims().JWTID()
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return false, fmt.Errorf("token is empty")
|
||||
}
|
||||
|
||||
te, err := c.tokenStore.Lookup(token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if te == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(te.Policies) != 1 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if te.Policies[0] != responseWrappingPolicyName {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// parseVaultTokenFromJWT returns a string iff the token was a JWT and we could
|
||||
// extract the original token ID from inside
|
||||
func (c *Core) parseVaultTokenFromJWT(token string) *string {
|
||||
var result string
|
||||
if strings.Count(token, ".") != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
wt, err := jws.ParseJWT([]byte(token))
|
||||
if err != nil || wt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result, _ = wt.Claims().JWTID()
|
||||
|
||||
return &result
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Sermo Digital LLC
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
JOSE
|
||||
============
|
||||
[![Build Status](https://travis-ci.org/SermoDigital/jose.svg?branch=master)](https://travis-ci.org/SermoDigital/jose)
|
||||
[![GoDoc](https://godoc.org/github.com/SermoDigital/jose?status.svg)](https://godoc.org/github.com/SermoDigital/jose)
|
||||
|
||||
JOSE is a comprehensive set of JWT, JWS, and JWE libraries.
|
||||
|
||||
## Why
|
||||
|
||||
The only other JWS/JWE/JWT implementations are specific to JWT, and none
|
||||
were particularly pleasant to work with.
|
||||
|
||||
These libraries should provide an easy, straightforward way to securely
|
||||
create, parse, and validate JWS, JWE, and JWTs.
|
||||
|
||||
## Notes:
|
||||
JWE is currently unimplemented.
|
||||
|
||||
## Version 0.9:
|
||||
|
||||
## Documentation
|
||||
|
||||
The docs can be found at [godoc.org] [docs], as usual.
|
||||
|
||||
A gopkg.in mirror can be found at https://gopkg.in/jose.v1, thanks to
|
||||
@zia-newversion. (For context, see issue #30.)
|
||||
|
||||
### [JWS RFC][jws]
|
||||
### [JWE RFC][jwe]
|
||||
### [JWT RFC][jwt]
|
||||
|
||||
## License
|
||||
|
||||
[MIT] [license].
|
||||
|
||||
[docs]: https://godoc.org/github.com/SermoDigital/jose
|
||||
[license]: https://github.com/SermoDigital/jose/blob/master/LICENSE.md
|
||||
[jws]: https://tools.ietf.org/html/rfc7515
|
||||
[jwe]: https://tools.ietf.org/html/rfc7516
|
||||
[jwt]: https://tools.ietf.org/html/rfc7519
|
|
@ -0,0 +1,8 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
go build ./...
|
||||
go test ./...
|
||||
golint ./...
|
||||
go vet ./...
|
|
@ -0,0 +1,44 @@
|
|||
package jose
|
||||
|
||||
import "encoding/base64"
|
||||
|
||||
// Encoder is satisfied if the type can marshal itself into a valid
|
||||
// structure for a JWS.
|
||||
type Encoder interface {
|
||||
// Base64 implies T -> JSON -> RawURLEncodingBase64
|
||||
Base64() ([]byte, error)
|
||||
}
|
||||
|
||||
// Base64Decode decodes a base64-encoded byte slice.
|
||||
func Base64Decode(b []byte) ([]byte, error) {
|
||||
buf := make([]byte, base64.RawURLEncoding.DecodedLen(len(b)))
|
||||
n, err := base64.RawURLEncoding.Decode(buf, b)
|
||||
return buf[:n], err
|
||||
}
|
||||
|
||||
// Base64Encode encodes a byte slice.
|
||||
func Base64Encode(b []byte) []byte {
|
||||
buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b)))
|
||||
base64.RawURLEncoding.Encode(buf, b)
|
||||
return buf
|
||||
}
|
||||
|
||||
// EncodeEscape base64-encodes a byte slice but escapes it for JSON.
|
||||
// It'll return the format: `"base64"`
|
||||
func EncodeEscape(b []byte) []byte {
|
||||
buf := make([]byte, base64.RawURLEncoding.EncodedLen(len(b))+2)
|
||||
buf[0] = '"'
|
||||
base64.RawURLEncoding.Encode(buf[1:], b)
|
||||
buf[len(buf)-1] = '"'
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeEscaped decodes a base64-encoded byte slice straight from a JSON
|
||||
// structure. It assumes it's in the format: `"base64"`, but can handle
|
||||
// cases where it's not.
|
||||
func DecodeEscaped(b []byte) ([]byte, error) {
|
||||
if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
|
||||
b = b[1 : len(b)-1]
|
||||
}
|
||||
return Base64Decode(b)
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
// Package crypto implements "SigningMethods" and "EncryptionMethods";
|
||||
// that is, ways to sign and encrypt JWS and JWEs, respectively, as well
|
||||
// as JWTs.
|
||||
package crypto
|
|
@ -0,0 +1,117 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"encoding/asn1"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// ErrECDSAVerification is missing from crypto/ecdsa compared to crypto/rsa
|
||||
var ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
|
||||
|
||||
// SigningMethodECDSA implements the ECDSA family of signing methods signing
|
||||
// methods
|
||||
type SigningMethodECDSA struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// ECPoint is a marshalling structure for the EC points R and S.
|
||||
type ECPoint struct {
|
||||
R *big.Int
|
||||
S *big.Int
|
||||
}
|
||||
|
||||
// Specific instances of EC SigningMethods.
|
||||
var (
|
||||
// SigningMethodES256 implements ES256.
|
||||
SigningMethodES256 = &SigningMethodECDSA{
|
||||
Name: "ES256",
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// SigningMethodES384 implements ES384.
|
||||
SigningMethodES384 = &SigningMethodECDSA{
|
||||
Name: "ES384",
|
||||
Hash: crypto.SHA384,
|
||||
}
|
||||
|
||||
// SigningMethodES512 implements ES512.
|
||||
SigningMethodES512 = &SigningMethodECDSA{
|
||||
Name: "ES512",
|
||||
Hash: crypto.SHA512,
|
||||
}
|
||||
)
|
||||
|
||||
// Alg returns the name of the SigningMethodECDSA instance.
|
||||
func (m *SigningMethodECDSA) Alg() string { return m.Name }
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this verify method, key must be an *ecdsa.PublicKey.
|
||||
func (m *SigningMethodECDSA) Verify(raw []byte, signature Signature, key interface{}) error {
|
||||
|
||||
ecdsaKey, ok := key.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
// Unmarshal asn1 ECPoint
|
||||
var ecpoint ECPoint
|
||||
if _, err := asn1.Unmarshal(signature, &ecpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the signature
|
||||
if !ecdsa.Verify(ecdsaKey, m.sum(raw), ecpoint.R, ecpoint.S) {
|
||||
return ErrECDSAVerification
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod.
|
||||
// For this signing method, key must be an *ecdsa.PrivateKey.
|
||||
func (m *SigningMethodECDSA) Sign(data []byte, key interface{}) (Signature, error) {
|
||||
|
||||
ecdsaKey, ok := key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
|
||||
r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, m.sum(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signature, err := asn1.Marshal(ECPoint{R: r, S: s})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Signature(signature), nil
|
||||
}
|
||||
|
||||
func (m *SigningMethodECDSA) sum(b []byte) []byte {
|
||||
h := m.Hash.New()
|
||||
h.Write(b)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// Hasher implements the Hasher method from SigningMethod.
|
||||
func (m *SigningMethodECDSA) Hasher() crypto.Hash {
|
||||
return m.Hash
|
||||
}
|
||||
|
||||
// MarshalJSON is in case somebody decides to place SigningMethodECDSA
|
||||
// inside the Header, presumably because they (wrongly) decided it was a good
|
||||
// idea to use the SigningMethod itself instead of the SigningMethod's Alg
|
||||
// method. In order to keep things sane, marshalling this will simply
|
||||
// return the JSON-compatible representation of m.Alg().
|
||||
func (m *SigningMethodECDSA) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodECDSA)(nil)
|
|
@ -0,0 +1,48 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ECDSA parsing errors.
|
||||
var (
|
||||
ErrNotECPublicKey = errors.New("Key is not a valid ECDSA public key")
|
||||
ErrNotECPrivateKey = errors.New("Key is not a valid ECDSA private key")
|
||||
)
|
||||
|
||||
// ParseECPrivateKeyFromPEM will parse a PEM encoded EC Private
|
||||
// Key Structure.
|
||||
func ParseECPrivateKeyFromPEM(key []byte) (*ecdsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
}
|
||||
|
||||
// ParseECPublicKeyFromPEM will parse a PEM encoded PKCS1 or PKCS8 public key
|
||||
func ParseECPublicKeyFromPEM(key []byte) (*ecdsa.PublicKey, error) {
|
||||
|
||||
block, _ := pem.Decode(key)
|
||||
if block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
parsedKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedKey = cert.PublicKey
|
||||
}
|
||||
|
||||
pkey, ok := parsedKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, ErrNotECPublicKey
|
||||
}
|
||||
return pkey, nil
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package crypto
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrInvalidKey means the key argument passed to SigningMethod.Verify
|
||||
// was not the correct type.
|
||||
ErrInvalidKey = errors.New("key is invalid")
|
||||
)
|
|
@ -0,0 +1,81 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// SigningMethodHMAC implements the HMAC-SHA family of SigningMethods.
|
||||
type SigningMethodHMAC struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// Specific instances of HMAC-SHA SigningMethods.
|
||||
var (
|
||||
// SigningMethodHS256 implements HS256.
|
||||
SigningMethodHS256 = &SigningMethodHMAC{
|
||||
Name: "HS256",
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// SigningMethodHS384 implements HS384.
|
||||
SigningMethodHS384 = &SigningMethodHMAC{
|
||||
Name: "HS384",
|
||||
Hash: crypto.SHA384,
|
||||
}
|
||||
|
||||
// SigningMethodHS512 implements HS512.
|
||||
SigningMethodHS512 = &SigningMethodHMAC{
|
||||
Name: "HS512",
|
||||
Hash: crypto.SHA512,
|
||||
}
|
||||
|
||||
// ErrSignatureInvalid is returned when the provided signature is found
|
||||
// to be invalid.
|
||||
ErrSignatureInvalid = errors.New("signature is invalid")
|
||||
)
|
||||
|
||||
// Alg implements the SigningMethod interface.
|
||||
func (m *SigningMethodHMAC) Alg() string { return m.Name }
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this signing method, must be a []byte.
|
||||
func (m *SigningMethodHMAC) Verify(raw []byte, signature Signature, key interface{}) error {
|
||||
keyBytes, ok := key.([]byte)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write(raw)
|
||||
if hmac.Equal(signature, hasher.Sum(nil)) {
|
||||
return nil
|
||||
}
|
||||
return ErrSignatureInvalid
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod for this signing method.
|
||||
// Key must be a []byte.
|
||||
func (m *SigningMethodHMAC) Sign(data []byte, key interface{}) (Signature, error) {
|
||||
keyBytes, ok := key.([]byte)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write(data)
|
||||
return Signature(hasher.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// Hasher implements the SigningMethod interface.
|
||||
func (m *SigningMethodHMAC) Hasher() crypto.Hash { return m.Hash }
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodHMAC) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodHMAC)(nil)
|
|
@ -0,0 +1,72 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/json"
|
||||
"hash"
|
||||
"io"
|
||||
)
|
||||
|
||||
func init() {
|
||||
crypto.RegisterHash(crypto.Hash(0), h)
|
||||
}
|
||||
|
||||
// h is passed to crypto.RegisterHash.
|
||||
func h() hash.Hash {
|
||||
return &f{Writer: nil}
|
||||
}
|
||||
|
||||
type f struct{ io.Writer }
|
||||
|
||||
// Sum helps implement the hash.Hash interface.
|
||||
func (_ *f) Sum(b []byte) []byte { return nil }
|
||||
|
||||
// Reset helps implement the hash.Hash interface.
|
||||
func (_ *f) Reset() {}
|
||||
|
||||
// Size helps implement the hash.Hash interface.
|
||||
func (_ *f) Size() int { return -1 }
|
||||
|
||||
// BlockSize helps implement the hash.Hash interface.
|
||||
func (_ *f) BlockSize() int { return -1 }
|
||||
|
||||
// Unsecured is the default "none" algorithm.
|
||||
var Unsecured = &SigningMethodNone{
|
||||
Name: "none",
|
||||
Hash: crypto.Hash(0),
|
||||
}
|
||||
|
||||
// SigningMethodNone is the default "none" algorithm.
|
||||
type SigningMethodNone struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// Verify helps implement the SigningMethod interface.
|
||||
func (_ *SigningMethodNone) Verify(_ []byte, _ Signature, _ interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign helps implement the SigningMethod interface.
|
||||
func (_ *SigningMethodNone) Sign(_ []byte, _ interface{}) (Signature, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Alg helps implement the SigningMethod interface.
|
||||
func (m *SigningMethodNone) Alg() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
// Hasher helps implement the SigningMethod interface.
|
||||
func (m *SigningMethodNone) Hasher() crypto.Hash {
|
||||
return m.Hash
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodNone) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodNone)(nil)
|
|
@ -0,0 +1,80 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// SigningMethodRSA implements the RSA family of SigningMethods.
|
||||
type SigningMethodRSA struct {
|
||||
Name string
|
||||
Hash crypto.Hash
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// Specific instances of RSA SigningMethods.
|
||||
var (
|
||||
// SigningMethodRS256 implements RS256.
|
||||
SigningMethodRS256 = &SigningMethodRSA{
|
||||
Name: "RS256",
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// SigningMethodRS384 implements RS384.
|
||||
SigningMethodRS384 = &SigningMethodRSA{
|
||||
Name: "RS384",
|
||||
Hash: crypto.SHA384,
|
||||
}
|
||||
|
||||
// SigningMethodRS512 implements RS512.
|
||||
SigningMethodRS512 = &SigningMethodRSA{
|
||||
Name: "RS512",
|
||||
Hash: crypto.SHA512,
|
||||
}
|
||||
)
|
||||
|
||||
// Alg implements the SigningMethod interface.
|
||||
func (m *SigningMethodRSA) Alg() string { return m.Name }
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this signing method, must be an *rsa.PublicKey.
|
||||
func (m *SigningMethodRSA) Verify(raw []byte, sig Signature, key interface{}) error {
|
||||
rsaKey, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, m.sum(raw), sig)
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod.
|
||||
// For this signing method, must be an *rsa.PrivateKey structure.
|
||||
func (m *SigningMethodRSA) Sign(data []byte, key interface{}) (Signature, error) {
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, m.sum(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Signature(sigBytes), nil
|
||||
}
|
||||
|
||||
func (m *SigningMethodRSA) sum(b []byte) []byte {
|
||||
h := m.Hash.New()
|
||||
h.Write(b)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// Hasher implements the SigningMethod interface.
|
||||
func (m *SigningMethodRSA) Hasher() crypto.Hash { return m.Hash }
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodRSA) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodRSA)(nil)
|
|
@ -0,0 +1,96 @@
|
|||
// +build go1.4
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// SigningMethodRSAPSS implements the RSAPSS family of SigningMethods.
|
||||
type SigningMethodRSAPSS struct {
|
||||
*SigningMethodRSA
|
||||
Options *rsa.PSSOptions
|
||||
}
|
||||
|
||||
// Specific instances for RS/PS SigningMethods.
|
||||
var (
|
||||
// SigningMethodPS256 implements PS256.
|
||||
SigningMethodPS256 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS256",
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}
|
||||
|
||||
// SigningMethodPS384 implements PS384.
|
||||
SigningMethodPS384 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS384",
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA384,
|
||||
},
|
||||
}
|
||||
|
||||
// SigningMethodPS512 implements PS512.
|
||||
SigningMethodPS512 = &SigningMethodRSAPSS{
|
||||
&SigningMethodRSA{
|
||||
Name: "PS512",
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
&rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
Hash: crypto.SHA512,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Verify implements the Verify method from SigningMethod.
|
||||
// For this verify method, key must be an *rsa.PublicKey.
|
||||
func (m *SigningMethodRSAPSS) Verify(raw []byte, signature Signature, key interface{}) error {
|
||||
rsaKey, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
return rsa.VerifyPSS(rsaKey, m.Hash, m.sum(raw), signature, m.Options)
|
||||
}
|
||||
|
||||
// Sign implements the Sign method from SigningMethod.
|
||||
// For this signing method, key must be an *rsa.PrivateKey.
|
||||
func (m *SigningMethodRSAPSS) Sign(raw []byte, key interface{}) (Signature, error) {
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, m.sum(raw), m.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Signature(sigBytes), nil
|
||||
}
|
||||
|
||||
func (m *SigningMethodRSAPSS) sum(b []byte) []byte {
|
||||
h := m.Hash.New()
|
||||
h.Write(b)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// Hasher implements the Hasher method from SigningMethod.
|
||||
func (m *SigningMethodRSAPSS) Hasher() crypto.Hash { return m.Hash }
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
// See SigningMethodECDSA.MarshalJSON() for information.
|
||||
func (m *SigningMethodRSAPSS) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + m.Alg() + `"`), nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*SigningMethodRSAPSS)(nil)
|
|
@ -0,0 +1,70 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Errors specific to rsa_utils.
|
||||
var (
|
||||
ErrKeyMustBePEMEncoded = errors.New("invalid key: Key must be PEM encoded PKCS1 or PKCS8 private key")
|
||||
ErrNotRSAPrivateKey = errors.New("key is not a valid RSA private key")
|
||||
ErrNotRSAPublicKey = errors.New("key is not a valid RSA public key")
|
||||
)
|
||||
|
||||
// ParseRSAPrivateKeyFromPEM parses a PEM encoded PKCS1 or PKCS8 private key.
|
||||
func ParseRSAPrivateKeyFromPEM(key []byte) (*rsa.PrivateKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||
if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *rsa.PrivateKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok {
|
||||
return nil, ErrNotRSAPrivateKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
||||
|
||||
// ParseRSAPublicKeyFromPEM parses PEM encoded PKCS1 or PKCS8 public key.
|
||||
func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
var block *pem.Block
|
||||
if block, _ = pem.Decode(key); block == nil {
|
||||
return nil, ErrKeyMustBePEMEncoded
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
var parsedKey interface{}
|
||||
if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil {
|
||||
if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
|
||||
parsedKey = cert.PublicKey
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var pkey *rsa.PublicKey
|
||||
var ok bool
|
||||
if pkey, ok = parsedKey.(*rsa.PublicKey); !ok {
|
||||
return nil, ErrNotRSAPublicKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
)
|
||||
|
||||
// Signature is a JWS signature.
|
||||
type Signature []byte
|
||||
|
||||
// MarshalJSON implements json.Marshaler for a signature.
|
||||
func (s Signature) MarshalJSON() ([]byte, error) {
|
||||
return jose.EncodeEscape(s), nil
|
||||
}
|
||||
|
||||
// Base64 helps implements jose.Encoder for Signature.
|
||||
func (s Signature) Base64() ([]byte, error) {
|
||||
return jose.Base64Encode(s), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for signature.
|
||||
func (s *Signature) UnmarshalJSON(b []byte) error {
|
||||
dec, err := jose.DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*s = Signature(dec)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (Signature)(nil)
|
||||
_ json.Unmarshaler = (*Signature)(nil)
|
||||
_ jose.Encoder = (Signature)(nil)
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
package crypto
|
||||
|
||||
import "crypto"
|
||||
|
||||
// SigningMethod is an interface that provides a way to sign JWS tokens.
|
||||
type SigningMethod interface {
|
||||
// Alg describes the signing algorithm, and is used to uniquely
|
||||
// describe the specific crypto.SigningMethod.
|
||||
Alg() string
|
||||
|
||||
// Verify accepts the raw content, the signature, and the key used
|
||||
// to sign the raw content, and returns any errors found while validating
|
||||
// the signature and content.
|
||||
Verify(raw []byte, sig Signature, key interface{}) error
|
||||
|
||||
// Sign returns a Signature for the raw bytes, as well as any errors
|
||||
// that occurred during the signing.
|
||||
Sign(raw []byte, key interface{}) (Signature, error)
|
||||
|
||||
// Used to cause quick panics when a crypto.SigningMethod whose form of hashing
|
||||
// isn't linked in the binary when you register a crypto.SigningMethod.
|
||||
// To spoof this, see "crypto.SigningMethodNone".
|
||||
Hasher() crypto.Hash
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
// Package jose implements some helper functions and types for the children
|
||||
// packages, jws, jwt, and jwe.
|
||||
package jose
|
|
@ -0,0 +1,124 @@
|
|||
package jose
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Header implements a JOSE Header with the addition of some helper
|
||||
// methods, similar to net/url.Values.
|
||||
type Header map[string]interface{}
|
||||
|
||||
// Get retrieves the value corresponding with key from the Header.
|
||||
func (h Header) Get(key string) interface{} {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
return h[key]
|
||||
}
|
||||
|
||||
// Set sets Claims[key] = val. It'll overwrite without warning.
|
||||
func (h Header) Set(key string, val interface{}) {
|
||||
h[key] = val
|
||||
}
|
||||
|
||||
// Del removes the value that corresponds with key from the Header.
|
||||
func (h Header) Del(key string) {
|
||||
delete(h, key)
|
||||
}
|
||||
|
||||
// Has returns true if a value for the given key exists inside the Header.
|
||||
func (h Header) Has(key string) bool {
|
||||
_, ok := h[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Header.
|
||||
func (h Header) MarshalJSON() ([]byte, error) {
|
||||
if len(h) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := json.Marshal(map[string]interface{}(h))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return EncodeEscape(b), nil
|
||||
}
|
||||
|
||||
// Base64 implements the Encoder interface.
|
||||
func (h Header) Base64() ([]byte, error) {
|
||||
return h.MarshalJSON()
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Header.
|
||||
func (h *Header) UnmarshalJSON(b []byte) error {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
b, err := DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, (*map[string]interface{})(h))
|
||||
}
|
||||
|
||||
// Protected Headers are base64-encoded after they're marshaled into
|
||||
// JSON.
|
||||
type Protected Header
|
||||
|
||||
// Get retrieves the value corresponding with key from the Protected Header.
|
||||
func (p Protected) Get(key string) interface{} {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
return p[key]
|
||||
}
|
||||
|
||||
// Set sets Protected[key] = val. It'll overwrite without warning.
|
||||
func (p Protected) Set(key string, val interface{}) {
|
||||
p[key] = val
|
||||
}
|
||||
|
||||
// Del removes the value that corresponds with key from the Protected Header.
|
||||
func (p Protected) Del(key string) {
|
||||
delete(p, key)
|
||||
}
|
||||
|
||||
// Has returns true if a value for the given key exists inside the Protected
|
||||
// Header.
|
||||
func (p Protected) Has(key string) bool {
|
||||
_, ok := p[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Protected.
|
||||
func (p Protected) MarshalJSON() ([]byte, error) {
|
||||
b, err := json.Marshal(map[string]interface{}(p))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return EncodeEscape(b), nil
|
||||
}
|
||||
|
||||
// Base64 implements the Encoder interface.
|
||||
func (p Protected) Base64() ([]byte, error) {
|
||||
b, err := json.Marshal(map[string]interface{}(p))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Base64Encode(b), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Protected.
|
||||
func (p *Protected) UnmarshalJSON(b []byte) error {
|
||||
var h Header
|
||||
if err := h.UnmarshalJSON(b); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = Protected(h)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (Protected)(nil)
|
||||
_ json.Unmarshaler = (*Protected)(nil)
|
||||
_ json.Marshaler = (Header)(nil)
|
||||
_ json.Unmarshaler = (*Header)(nil)
|
||||
)
|
|
@ -0,0 +1,190 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
"github.com/SermoDigital/jose/jwt"
|
||||
)
|
||||
|
||||
// Claims represents a set of JOSE Claims.
|
||||
type Claims jwt.Claims
|
||||
|
||||
// Get retrieves the value corresponding with key from the Claims.
|
||||
func (c Claims) Get(key string) interface{} {
|
||||
return jwt.Claims(c).Get(key)
|
||||
}
|
||||
|
||||
// Set sets Claims[key] = val. It'll overwrite without warning.
|
||||
func (c Claims) Set(key string, val interface{}) {
|
||||
jwt.Claims(c).Set(key, val)
|
||||
}
|
||||
|
||||
// Del removes the value that corresponds with key from the Claims.
|
||||
func (c Claims) Del(key string) {
|
||||
jwt.Claims(c).Del(key)
|
||||
}
|
||||
|
||||
// Has returns true if a value for the given key exists inside the Claims.
|
||||
func (c Claims) Has(key string) bool {
|
||||
return jwt.Claims(c).Has(key)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Claims.
|
||||
func (c Claims) MarshalJSON() ([]byte, error) {
|
||||
return jwt.Claims(c).MarshalJSON()
|
||||
}
|
||||
|
||||
// Base64 implements the Encoder interface.
|
||||
func (c Claims) Base64() ([]byte, error) {
|
||||
return jwt.Claims(c).Base64()
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Claims.
|
||||
func (c *Claims) UnmarshalJSON(b []byte) error {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
b, err := jose.DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Since json.Unmarshal calls UnmarshalJSON,
|
||||
// calling json.Unmarshal on *p would be infinitely recursive
|
||||
// A temp variable is needed because &map[string]interface{}(*p) is
|
||||
// invalid Go.
|
||||
|
||||
tmp := map[string]interface{}(*c)
|
||||
if err = json.Unmarshal(b, &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
*c = Claims(tmp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Issuer retrieves claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
func (c Claims) Issuer() (string, bool) {
|
||||
return jwt.Claims(c).Issuer()
|
||||
}
|
||||
|
||||
// Subject retrieves claim "sub" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.2
|
||||
func (c Claims) Subject() (string, bool) {
|
||||
return jwt.Claims(c).Subject()
|
||||
}
|
||||
|
||||
// Audience retrieves claim "aud" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
func (c Claims) Audience() ([]string, bool) {
|
||||
return jwt.Claims(c).Audience()
|
||||
}
|
||||
|
||||
// Expiration retrieves claim "exp" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
func (c Claims) Expiration() (time.Time, bool) {
|
||||
return jwt.Claims(c).Expiration()
|
||||
}
|
||||
|
||||
// NotBefore retrieves claim "nbf" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
func (c Claims) NotBefore() (time.Time, bool) {
|
||||
return jwt.Claims(c).NotBefore()
|
||||
}
|
||||
|
||||
// IssuedAt retrieves claim "iat" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.6
|
||||
func (c Claims) IssuedAt() (time.Time, bool) {
|
||||
return jwt.Claims(c).IssuedAt()
|
||||
}
|
||||
|
||||
// JWTID retrieves claim "jti" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.7
|
||||
func (c Claims) JWTID() (string, bool) {
|
||||
return jwt.Claims(c).JWTID()
|
||||
}
|
||||
|
||||
// RemoveIssuer deletes claim "iss" from c.
|
||||
func (c Claims) RemoveIssuer() {
|
||||
jwt.Claims(c).RemoveIssuer()
|
||||
}
|
||||
|
||||
// RemoveSubject deletes claim "sub" from c.
|
||||
func (c Claims) RemoveSubject() {
|
||||
jwt.Claims(c).RemoveIssuer()
|
||||
}
|
||||
|
||||
// RemoveAudience deletes claim "aud" from c.
|
||||
func (c Claims) RemoveAudience() {
|
||||
jwt.Claims(c).Audience()
|
||||
}
|
||||
|
||||
// RemoveExpiration deletes claim "exp" from c.
|
||||
func (c Claims) RemoveExpiration() {
|
||||
jwt.Claims(c).RemoveExpiration()
|
||||
}
|
||||
|
||||
// RemoveNotBefore deletes claim "nbf" from c.
|
||||
func (c Claims) RemoveNotBefore() {
|
||||
jwt.Claims(c).NotBefore()
|
||||
}
|
||||
|
||||
// RemoveIssuedAt deletes claim "iat" from c.
|
||||
func (c Claims) RemoveIssuedAt() {
|
||||
jwt.Claims(c).IssuedAt()
|
||||
}
|
||||
|
||||
// RemoveJWTID deletes claim "jti" from c.
|
||||
func (c Claims) RemoveJWTID() {
|
||||
jwt.Claims(c).RemoveJWTID()
|
||||
}
|
||||
|
||||
// SetIssuer sets claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
func (c Claims) SetIssuer(issuer string) {
|
||||
jwt.Claims(c).SetIssuer(issuer)
|
||||
}
|
||||
|
||||
// SetSubject sets claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.2
|
||||
func (c Claims) SetSubject(subject string) {
|
||||
jwt.Claims(c).SetSubject(subject)
|
||||
}
|
||||
|
||||
// SetAudience sets claim "aud" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
func (c Claims) SetAudience(audience ...string) {
|
||||
jwt.Claims(c).SetAudience(audience...)
|
||||
}
|
||||
|
||||
// SetExpiration sets claim "exp" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
func (c Claims) SetExpiration(expiration time.Time) {
|
||||
jwt.Claims(c).SetExpiration(expiration)
|
||||
}
|
||||
|
||||
// SetNotBefore sets claim "nbf" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
func (c Claims) SetNotBefore(notBefore time.Time) {
|
||||
jwt.Claims(c).SetNotBefore(notBefore)
|
||||
}
|
||||
|
||||
// SetIssuedAt sets claim "iat" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.6
|
||||
func (c Claims) SetIssuedAt(issuedAt time.Time) {
|
||||
jwt.Claims(c).SetIssuedAt(issuedAt)
|
||||
}
|
||||
|
||||
// SetJWTID sets claim "jti" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.7
|
||||
func (c Claims) SetJWTID(uniqueID string) {
|
||||
jwt.Claims(c).SetJWTID(uniqueID)
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (Claims)(nil)
|
||||
_ json.Unmarshaler = (*Claims)(nil)
|
||||
)
|
|
@ -0,0 +1,2 @@
|
|||
// Package jws implements JWSs per RFC 7515
|
||||
package jws
|
|
@ -0,0 +1,62 @@
|
|||
package jws
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
|
||||
// ErrNotEnoughMethods is returned if New was called _or_ the Flat/Compact
|
||||
// methods were called with 0 SigningMethods.
|
||||
ErrNotEnoughMethods = errors.New("not enough methods provided")
|
||||
|
||||
// ErrCouldNotUnmarshal is returned when Parse's json.Unmarshaler
|
||||
// parameter returns an error.
|
||||
ErrCouldNotUnmarshal = errors.New("custom unmarshal failed")
|
||||
|
||||
// ErrNotCompact signals that the provided potential JWS is not
|
||||
// in its compact representation.
|
||||
ErrNotCompact = errors.New("not a compact JWS")
|
||||
|
||||
// ErrDuplicateHeaderParameter signals that there are duplicate parameters
|
||||
// in the provided Headers.
|
||||
ErrDuplicateHeaderParameter = errors.New("duplicate parameters in the JOSE Header")
|
||||
|
||||
// ErrTwoEmptyHeaders is returned if both Headers are empty.
|
||||
ErrTwoEmptyHeaders = errors.New("both headers cannot be empty")
|
||||
|
||||
// ErrNotEnoughKeys is returned when not enough keys are provided for
|
||||
// the given SigningMethods.
|
||||
ErrNotEnoughKeys = errors.New("not enough keys (for given methods)")
|
||||
|
||||
// ErrDidNotValidate means the given JWT did not properly validate
|
||||
ErrDidNotValidate = errors.New("did not validate")
|
||||
|
||||
// ErrNoAlgorithm means no algorithm ("alg") was found in the Protected
|
||||
// Header.
|
||||
ErrNoAlgorithm = errors.New("no algorithm found")
|
||||
|
||||
// ErrAlgorithmDoesntExist means the algorithm asked for cannot be
|
||||
// found inside the signingMethod cache.
|
||||
ErrAlgorithmDoesntExist = errors.New("algorithm doesn't exist")
|
||||
|
||||
// ErrMismatchedAlgorithms means the algorithm inside the JWT was
|
||||
// different than the algorithm the caller wanted to use.
|
||||
ErrMismatchedAlgorithms = errors.New("mismatched algorithms")
|
||||
|
||||
// ErrCannotValidate means the JWS cannot be validated for various
|
||||
// reasons. For example, if there aren't any signatures/payloads/headers
|
||||
// to actually validate.
|
||||
ErrCannotValidate = errors.New("cannot validate")
|
||||
|
||||
// ErrIsNotJWT means the given JWS is not a JWT.
|
||||
ErrIsNotJWT = errors.New("JWS is not a JWT")
|
||||
|
||||
// ErrHoldsJWE means the given JWS holds a JWE inside its payload.
|
||||
ErrHoldsJWE = errors.New("JWS holds JWE")
|
||||
|
||||
// ErrNotEnoughValidSignatures means the JWS did not meet the required
|
||||
// number of signatures.
|
||||
ErrNotEnoughValidSignatures = errors.New("not enough valid signatures in the JWS")
|
||||
|
||||
// ErrNoTokenInRequest means there's no token present inside the *http.Request.
|
||||
ErrNoTokenInRequest = errors.New("no token present in request")
|
||||
)
|
|
@ -0,0 +1,490 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
"github.com/SermoDigital/jose/crypto"
|
||||
)
|
||||
|
||||
// JWS implements a JWS per RFC 7515.
|
||||
type JWS interface {
|
||||
// Payload Returns the payload.
|
||||
Payload() interface{}
|
||||
|
||||
// SetPayload sets the payload with the given value.
|
||||
SetPayload(p interface{})
|
||||
|
||||
// Protected returns the JWS' Protected Header.
|
||||
Protected() jose.Protected
|
||||
|
||||
// ProtectedAt returns the JWS' Protected Header.
|
||||
// i represents the index of the Protected Header.
|
||||
ProtectedAt(i int) jose.Protected
|
||||
|
||||
// Header returns the JWS' unprotected Header.
|
||||
Header() jose.Header
|
||||
|
||||
// HeaderAt returns the JWS' unprotected Header.
|
||||
// i represents the index of the unprotected Header.
|
||||
HeaderAt(i int) jose.Header
|
||||
|
||||
// Verify validates the current JWS' signature as-is. Refer to
|
||||
// ValidateMulti for more information.
|
||||
Verify(key interface{}, method crypto.SigningMethod) error
|
||||
|
||||
// ValidateMulti validates the current JWS' signature as-is. Since it's
|
||||
// meant to be called after parsing a stream of bytes into a JWS, it
|
||||
// shouldn't do any internal parsing like the Sign, Flat, Compact, or
|
||||
// General methods do.
|
||||
VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error
|
||||
|
||||
// VerifyCallback validates the current JWS' signature as-is. It
|
||||
// accepts a callback function that can be used to access header
|
||||
// parameters to lookup needed information. For example, looking
|
||||
// up the "kid" parameter.
|
||||
// The return slice must be a slice of keys used in the verification
|
||||
// of the JWS.
|
||||
VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error
|
||||
|
||||
// General serializes the JWS into its "general" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
General(keys ...interface{}) ([]byte, error)
|
||||
|
||||
// Flat serializes the JWS to its "flattened" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.2
|
||||
Flat(key interface{}) ([]byte, error)
|
||||
|
||||
// Compact serializes the JWS into its "compact" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.1
|
||||
Compact(key interface{}) ([]byte, error)
|
||||
|
||||
// IsJWT returns true if the JWS is a JWT.
|
||||
IsJWT() bool
|
||||
}
|
||||
|
||||
// jws represents a specific jws.
|
||||
type jws struct {
|
||||
payload *payload
|
||||
plcache rawBase64
|
||||
clean bool
|
||||
|
||||
sb []sigHead
|
||||
|
||||
isJWT bool
|
||||
}
|
||||
|
||||
// Payload returns the jws' payload.
|
||||
func (j *jws) Payload() interface{} {
|
||||
return j.payload.v
|
||||
}
|
||||
|
||||
// SetPayload sets the jws' raw, unexported payload.
|
||||
func (j *jws) SetPayload(val interface{}) {
|
||||
j.payload.v = val
|
||||
}
|
||||
|
||||
// Protected returns the JWS' Protected Header.
|
||||
func (j *jws) Protected() jose.Protected {
|
||||
return j.sb[0].protected
|
||||
}
|
||||
|
||||
// Protected returns the JWS' Protected Header.
|
||||
// i represents the index of the Protected Header.
|
||||
// Left empty, it defaults to 0.
|
||||
func (j *jws) ProtectedAt(i int) jose.Protected {
|
||||
return j.sb[i].protected
|
||||
}
|
||||
|
||||
// Header returns the JWS' unprotected Header.
|
||||
func (j *jws) Header() jose.Header {
|
||||
return j.sb[0].unprotected
|
||||
}
|
||||
|
||||
// HeaderAt returns the JWS' unprotected Header.
|
||||
// |i| is the index of the unprotected Header.
|
||||
func (j *jws) HeaderAt(i int) jose.Header {
|
||||
return j.sb[i].unprotected
|
||||
}
|
||||
|
||||
// sigHead represents the 'signatures' member of the jws' "general"
|
||||
// serialization form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
//
|
||||
// It's embedded inside the "flat" structure in order to properly
|
||||
// create the "flat" jws.
|
||||
type sigHead struct {
|
||||
Protected rawBase64 `json:"protected,omitempty"`
|
||||
Unprotected rawBase64 `json:"header,omitempty"`
|
||||
Signature crypto.Signature `json:"signature"`
|
||||
|
||||
protected jose.Protected
|
||||
unprotected jose.Header
|
||||
clean bool
|
||||
|
||||
method crypto.SigningMethod
|
||||
}
|
||||
|
||||
func (s *sigHead) unmarshal() error {
|
||||
if err := s.protected.UnmarshalJSON(s.Protected); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.unprotected.UnmarshalJSON(s.Unprotected)
|
||||
}
|
||||
|
||||
// New creates a JWS with the provided crypto.SigningMethods.
|
||||
func New(content interface{}, methods ...crypto.SigningMethod) JWS {
|
||||
sb := make([]sigHead, len(methods))
|
||||
for i := range methods {
|
||||
sb[i] = sigHead{
|
||||
protected: jose.Protected{
|
||||
"alg": methods[i].Alg(),
|
||||
},
|
||||
unprotected: jose.Header{},
|
||||
method: methods[i],
|
||||
}
|
||||
}
|
||||
return &jws{
|
||||
payload: &payload{v: content},
|
||||
sb: sb,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sigHead) assignMethod(p jose.Protected) error {
|
||||
alg, ok := p.Get("alg").(string)
|
||||
if !ok {
|
||||
return ErrNoAlgorithm
|
||||
}
|
||||
|
||||
sm := GetSigningMethod(alg)
|
||||
if sm == nil {
|
||||
return ErrNoAlgorithm
|
||||
}
|
||||
s.method = sm
|
||||
return nil
|
||||
}
|
||||
|
||||
type generic struct {
|
||||
Payload rawBase64 `json:"payload"`
|
||||
sigHead
|
||||
Signatures []sigHead `json:"signatures,omitempty"`
|
||||
}
|
||||
|
||||
// Parse parses any of the three serialized jws forms into a physical
|
||||
// jws per https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// It accepts a json.Unmarshaler in order to properly parse
|
||||
// the payload. In order to keep the caller from having to do extra
|
||||
// parsing of the payload, a json.Unmarshaler can be passed
|
||||
// which will be then to unmarshal the payload however the caller
|
||||
// wishes. Do note that if json.Unmarshal returns an error the
|
||||
// original payload will be used as if no json.Unmarshaler was
|
||||
// passed.
|
||||
//
|
||||
// Internally, Parse applies some heuristics and then calls either
|
||||
// ParseGeneral, ParseFlat, or ParseCompact.
|
||||
// It should only be called if, for whatever reason, you do not
|
||||
// know which form the serialized JWT is in.
|
||||
//
|
||||
// It cannot parse a JWT.
|
||||
func Parse(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
// Try and unmarshal into a generic struct that'll
|
||||
// hopefully hold either of the two JSON serialization
|
||||
// formats.
|
||||
var g generic
|
||||
|
||||
// Not valid JSON. Let's try compact.
|
||||
if err := json.Unmarshal(encoded, &g); err != nil {
|
||||
return ParseCompact(encoded, u...)
|
||||
}
|
||||
|
||||
if g.Signatures == nil {
|
||||
return g.parseFlat(u...)
|
||||
}
|
||||
return g.parseGeneral(u...)
|
||||
}
|
||||
|
||||
// ParseGeneral parses a jws serialized into its "general" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
// into a physical jws per
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// For information on the json.Unmarshaler parameter, see Parse.
|
||||
func ParseGeneral(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
var g generic
|
||||
if err := json.Unmarshal(encoded, &g); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.parseGeneral(u...)
|
||||
}
|
||||
|
||||
func (g *generic) parseGeneral(u ...json.Unmarshaler) (JWS, error) {
|
||||
|
||||
var p payload
|
||||
if len(u) > 0 {
|
||||
p.u = u[0]
|
||||
}
|
||||
|
||||
if err := p.UnmarshalJSON(g.Payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range g.Signatures {
|
||||
if err := g.Signatures[i].unmarshal(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := checkHeaders(jose.Header(g.Signatures[i].protected), g.Signatures[i].unprotected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := g.Signatures[i].assignMethod(g.Signatures[i].protected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
g.clean = len(g.Signatures) != 0
|
||||
|
||||
return &jws{
|
||||
payload: &p,
|
||||
plcache: g.Payload,
|
||||
clean: true,
|
||||
sb: g.Signatures,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseFlat parses a jws serialized into its "flat" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.2
|
||||
// into a physical jws per
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// For information on the json.Unmarshaler parameter, see Parse.
|
||||
func ParseFlat(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
var g generic
|
||||
if err := json.Unmarshal(encoded, &g); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.parseFlat(u...)
|
||||
}
|
||||
|
||||
func (g *generic) parseFlat(u ...json.Unmarshaler) (JWS, error) {
|
||||
|
||||
var p payload
|
||||
if len(u) > 0 {
|
||||
p.u = u[0]
|
||||
}
|
||||
|
||||
if err := p.UnmarshalJSON(g.Payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := g.sigHead.unmarshal(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
g.sigHead.clean = true
|
||||
|
||||
if err := checkHeaders(jose.Header(g.sigHead.protected), g.sigHead.unprotected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := g.sigHead.assignMethod(g.sigHead.protected); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &jws{
|
||||
payload: &p,
|
||||
plcache: g.Payload,
|
||||
clean: true,
|
||||
sb: []sigHead{g.sigHead},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParseCompact parses a jws serialized into its "compact" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.1
|
||||
// into a physical jws per
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
//
|
||||
// For information on the json.Unmarshaler parameter, see Parse.
|
||||
func ParseCompact(encoded []byte, u ...json.Unmarshaler) (JWS, error) {
|
||||
return parseCompact(encoded, false, u...)
|
||||
}
|
||||
|
||||
func parseCompact(encoded []byte, jwt bool, u ...json.Unmarshaler) (*jws, error) {
|
||||
|
||||
// This section loosely follows
|
||||
// https://tools.ietf.org/html/rfc7519#section-7.2
|
||||
// because it's used to parse _both_ jws and JWTs.
|
||||
|
||||
parts := bytes.Split(encoded, []byte{'.'})
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrNotCompact
|
||||
}
|
||||
|
||||
var p jose.Protected
|
||||
if err := p.UnmarshalJSON(parts[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := sigHead{
|
||||
Protected: parts[0],
|
||||
protected: p,
|
||||
Signature: parts[2],
|
||||
clean: true,
|
||||
}
|
||||
|
||||
if err := s.assignMethod(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pl payload
|
||||
if len(u) > 0 {
|
||||
pl.u = u[0]
|
||||
}
|
||||
|
||||
j := jws{
|
||||
payload: &pl,
|
||||
plcache: parts[1],
|
||||
sb: []sigHead{s},
|
||||
isJWT: jwt,
|
||||
}
|
||||
|
||||
if err := j.payload.UnmarshalJSON(parts[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
j.clean = true
|
||||
|
||||
if err := j.sb[0].Signature.UnmarshalJSON(parts[2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc7519#section-7.2.8
|
||||
cty, ok := p.Get("cty").(string)
|
||||
if ok && cty == "JWT" {
|
||||
return &j, ErrHoldsJWE
|
||||
}
|
||||
return &j, nil
|
||||
}
|
||||
|
||||
var (
|
||||
// JWSFormKey is the form "key" which should be used inside
|
||||
// ParseFromRequest if the request is a multipart.Form.
|
||||
JWSFormKey = "access_token"
|
||||
|
||||
// MaxMemory is maximum amount of memory which should be used
|
||||
// inside ParseFromRequest while parsing the multipart.Form
|
||||
// if the request is a multipart.Form.
|
||||
MaxMemory int64 = 10e6
|
||||
)
|
||||
|
||||
// Format specifies which "format" the JWS is in -- Flat, General,
|
||||
// or compact. Additionally, constants for JWT/Unknown are added.
|
||||
type Format uint8
|
||||
|
||||
const (
|
||||
// Unknown format.
|
||||
Unknown Format = iota
|
||||
|
||||
// Flat format.
|
||||
Flat
|
||||
|
||||
// General format.
|
||||
General
|
||||
|
||||
// Compact format.
|
||||
Compact
|
||||
)
|
||||
|
||||
var parseJumpTable = [...]func([]byte, ...json.Unmarshaler) (JWS, error){
|
||||
Unknown: Parse,
|
||||
Flat: ParseFlat,
|
||||
General: ParseGeneral,
|
||||
Compact: ParseCompact,
|
||||
1<<8 - 1: Parse, // Max uint8.
|
||||
}
|
||||
|
||||
func init() {
|
||||
for i := range parseJumpTable {
|
||||
if parseJumpTable[i] == nil {
|
||||
parseJumpTable[i] = Parse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func fromHeader(req *http.Request) ([]byte, bool) {
|
||||
if ah := req.Header.Get("Authorization"); len(ah) > 7 && strings.EqualFold(ah[0:7], "BEARER ") {
|
||||
return []byte(ah[7:]), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func fromForm(req *http.Request) ([]byte, bool) {
|
||||
if err := req.ParseMultipartForm(MaxMemory); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if tokStr := req.Form.Get(JWSFormKey); tokStr != "" {
|
||||
return []byte(tokStr), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ParseFromHeader tries to find the JWS in an http.Request header.
|
||||
func ParseFromHeader(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
|
||||
if b, ok := fromHeader(req); ok {
|
||||
return parseJumpTable[format](b, u...)
|
||||
}
|
||||
return nil, ErrNoTokenInRequest
|
||||
}
|
||||
|
||||
// ParseFromForm tries to find the JWS in an http.Request form request.
|
||||
func ParseFromForm(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
|
||||
if b, ok := fromForm(req); ok {
|
||||
return parseJumpTable[format](b, u...)
|
||||
}
|
||||
return nil, ErrNoTokenInRequest
|
||||
}
|
||||
|
||||
// ParseFromRequest tries to find the JWS in an http.Request.
|
||||
// This method will call ParseMultipartForm if there's no token in the header.
|
||||
func ParseFromRequest(req *http.Request, format Format, u ...json.Unmarshaler) (JWS, error) {
|
||||
token, err := ParseFromHeader(req, format, u...)
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
token, err = ParseFromForm(req, format, u...)
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// IgnoreDupes should be set to true if the internal duplicate header key check
|
||||
// should ignore duplicate Header keys instead of reporting an error when
|
||||
// duplicate Header keys are found.
|
||||
//
|
||||
// Note:
|
||||
// Duplicate Header keys are defined in
|
||||
// https://tools.ietf.org/html/rfc7515#section-5.2
|
||||
// meaning keys that both the protected and unprotected
|
||||
// Headers possess.
|
||||
var IgnoreDupes bool
|
||||
|
||||
// checkHeaders returns an error per the constraints described in
|
||||
// IgnoreDupes' comment.
|
||||
func checkHeaders(a, b jose.Header) error {
|
||||
if len(a)+len(b) == 0 {
|
||||
return ErrTwoEmptyHeaders
|
||||
}
|
||||
for key := range a {
|
||||
if b.Has(key) && !IgnoreDupes {
|
||||
return ErrDuplicateHeaderParameter
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ JWS = (*jws)(nil)
|
|
@ -0,0 +1,132 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// Flat serializes the JWS to its "flattened" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.2
|
||||
func (j *jws) Flat(key interface{}) ([]byte, error) {
|
||||
if len(j.sb) < 1 {
|
||||
return nil, ErrNotEnoughMethods
|
||||
}
|
||||
if err := j.sign(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(struct {
|
||||
Payload rawBase64 `json:"payload"`
|
||||
sigHead
|
||||
}{
|
||||
Payload: j.plcache,
|
||||
sigHead: j.sb[0],
|
||||
})
|
||||
}
|
||||
|
||||
// General serializes the JWS into its "general" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.2.1
|
||||
//
|
||||
// If only one key is passed it's used for all the provided
|
||||
// crypto.SigningMethods. Otherwise, len(keys) must equal the number
|
||||
// of crypto.SigningMethods added.
|
||||
func (j *jws) General(keys ...interface{}) ([]byte, error) {
|
||||
if err := j.sign(keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(struct {
|
||||
Payload rawBase64 `json:"payload"`
|
||||
Signatures []sigHead `json:"signatures"`
|
||||
}{
|
||||
Payload: j.plcache,
|
||||
Signatures: j.sb,
|
||||
})
|
||||
}
|
||||
|
||||
// Compact serializes the JWS into its "compact" form per
|
||||
// https://tools.ietf.org/html/rfc7515#section-7.1
|
||||
func (j *jws) Compact(key interface{}) ([]byte, error) {
|
||||
if len(j.sb) < 1 {
|
||||
return nil, ErrNotEnoughMethods
|
||||
}
|
||||
|
||||
if err := j.sign(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sig, err := j.sb[0].Signature.Base64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return format(
|
||||
j.sb[0].Protected,
|
||||
j.plcache,
|
||||
sig,
|
||||
), nil
|
||||
}
|
||||
|
||||
// sign signs each index of j's sb member.
|
||||
func (j *jws) sign(keys ...interface{}) error {
|
||||
if err := j.cache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) < 1 ||
|
||||
len(keys) > 1 && len(keys) != len(j.sb) {
|
||||
return ErrNotEnoughKeys
|
||||
}
|
||||
|
||||
if len(keys) == 1 {
|
||||
k := keys[0]
|
||||
keys = make([]interface{}, len(j.sb))
|
||||
for i := range keys {
|
||||
keys[i] = k
|
||||
}
|
||||
}
|
||||
|
||||
for i := range j.sb {
|
||||
if err := j.sb[i].cache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
raw := format(j.sb[i].Protected, j.plcache)
|
||||
sig, err := j.sb[i].method.Sign(raw, keys[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
j.sb[i].Signature = sig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cache marshals the payload, but only if it's changed since the last cache.
|
||||
func (j *jws) cache() (err error) {
|
||||
if !j.clean {
|
||||
j.plcache, err = j.payload.Base64()
|
||||
j.clean = err == nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// cache marshals the protected and unprotected headers, but only if
|
||||
// they've changed since their last cache.
|
||||
func (s *sigHead) cache() (err error) {
|
||||
if !s.clean {
|
||||
s.Protected, err = s.protected.Base64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Unprotected, err = s.unprotected.Base64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.clean = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// format formats a slice of bytes in the order given, joining
|
||||
// them with a period.
|
||||
func format(a ...[]byte) []byte {
|
||||
return bytes.Join(a, []byte{'.'})
|
||||
}
|
|
@ -0,0 +1,203 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/SermoDigital/jose/crypto"
|
||||
)
|
||||
|
||||
// VerifyCallback is a callback function that can be used to access header
|
||||
// parameters to lookup needed information. For example, looking
|
||||
// up the "kid" parameter.
|
||||
// The return slice must be a slice of keys used in the verification
|
||||
// of the JWS.
|
||||
type VerifyCallback func(JWS) ([]interface{}, error)
|
||||
|
||||
// VerifyCallback validates the current JWS' signature as-is. It
|
||||
// accepts a callback function that can be used to access header
|
||||
// parameters to lookup needed information. For example, looking
|
||||
// up the "kid" parameter.
|
||||
// The return slice must be a slice of keys used in the verification
|
||||
// of the JWS.
|
||||
func (j *jws) VerifyCallback(fn VerifyCallback, methods []crypto.SigningMethod, o *SigningOpts) error {
|
||||
keys, err := fn(j)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return j.VerifyMulti(keys, methods, o)
|
||||
}
|
||||
|
||||
// IsMultiError returns true if the given error is type *MultiError.
|
||||
func IsMultiError(err error) bool {
|
||||
_, ok := err.(*MultiError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// MultiError is a slice of errors.
|
||||
type MultiError []error
|
||||
|
||||
// Errors implements the error interface.
|
||||
func (m *MultiError) Error() string {
|
||||
var s string
|
||||
var n int
|
||||
for _, err := range *m {
|
||||
if err != nil {
|
||||
if n == 0 {
|
||||
s = err.Error()
|
||||
}
|
||||
n++
|
||||
}
|
||||
}
|
||||
switch n {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return s
|
||||
case 2:
|
||||
return s + " and 1 other error"
|
||||
}
|
||||
return fmt.Sprintf("%s (and %d other errors)", s, n-1)
|
||||
}
|
||||
|
||||
// Any means any of the JWS signatures need to verify.
|
||||
// Refer to verifyMulti for more information.
|
||||
const Any int = 0
|
||||
|
||||
// VerifyMulti verifies the current JWS as-is. Since it's meant to be
|
||||
// called after parsing a stream of bytes into a JWS, it doesn't do any
|
||||
// internal parsing like the Sign, Flat, Compact, or General methods do.
|
||||
func (j *jws) VerifyMulti(keys []interface{}, methods []crypto.SigningMethod, o *SigningOpts) error {
|
||||
|
||||
// Catch a simple mistake. Parameter o is irrelevant in this scenario.
|
||||
if len(keys) == 1 &&
|
||||
len(methods) == 1 &&
|
||||
len(j.sb) == 1 {
|
||||
return j.Verify(keys[0], methods[0])
|
||||
}
|
||||
|
||||
if len(j.sb) != len(methods) {
|
||||
return ErrNotEnoughMethods
|
||||
}
|
||||
|
||||
if len(keys) < 1 ||
|
||||
len(keys) > 1 && len(keys) != len(j.sb) {
|
||||
return ErrNotEnoughKeys
|
||||
}
|
||||
|
||||
// TODO do this better.
|
||||
if len(keys) == 1 {
|
||||
k := keys[0]
|
||||
keys = make([]interface{}, len(methods))
|
||||
for i := range keys {
|
||||
keys[i] = k
|
||||
}
|
||||
}
|
||||
|
||||
var o2 SigningOpts
|
||||
if o == nil {
|
||||
o = new(SigningOpts)
|
||||
}
|
||||
|
||||
var m MultiError
|
||||
for i := range j.sb {
|
||||
err := j.sb[i].verify(j.plcache, keys[i], methods[i])
|
||||
if err != nil {
|
||||
m = append(m, err)
|
||||
} else {
|
||||
o2.Inc()
|
||||
if o.Needs(i) {
|
||||
o.ptr++
|
||||
o2.Append(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := o.Validate(&o2)
|
||||
if err != nil {
|
||||
m = append(m, err)
|
||||
}
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &m
|
||||
}
|
||||
|
||||
// SigningOpts is a struct which holds options for validating
|
||||
// JWS signatures.
|
||||
// Number represents the cumulative which signatures need to verify
|
||||
// in order for the JWS to be considered valid.
|
||||
// Leave 'Number' empty or set it to the constant 'Any' if any number of
|
||||
// valid signatures (greater than one) should verify the JWS.
|
||||
//
|
||||
// Use the indices of the signatures that need to verify in order
|
||||
// for the JWS to be considered valid if specific signatures need
|
||||
// to verify in order for the JWS to be considered valid.
|
||||
//
|
||||
// Note:
|
||||
// The JWS spec requires *at least* one
|
||||
// signature to verify in order for the JWS to be considered valid.
|
||||
type SigningOpts struct {
|
||||
// Minimum of signatures which need to verify.
|
||||
Number int
|
||||
|
||||
// Indices of specific signatures which need to verify.
|
||||
Indices []int
|
||||
ptr int
|
||||
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// Append appends x to s' Indices member.
|
||||
func (s *SigningOpts) Append(x int) {
|
||||
s.Indices = append(s.Indices, x)
|
||||
}
|
||||
|
||||
// Needs returns true if x resides inside s' Indices member
|
||||
// for the given index. It's used to match two SigningOpts Indices members.
|
||||
func (s *SigningOpts) Needs(x int) bool {
|
||||
return s.ptr < len(s.Indices) && s.Indices[s.ptr] == x
|
||||
}
|
||||
|
||||
// Inc increments s' Number member by one.
|
||||
func (s *SigningOpts) Inc() { s.Number++ }
|
||||
|
||||
// Validate returns any errors found while validating the
|
||||
// provided SigningOpts. The receiver validates |have|.
|
||||
// It'll return an error if the passed SigningOpts' Number member is less
|
||||
// than s' or if the passed SigningOpts' Indices slice isn't equal to s'.
|
||||
func (s *SigningOpts) Validate(have *SigningOpts) error {
|
||||
if have.Number < s.Number ||
|
||||
(s.Indices != nil &&
|
||||
!eq(s.Indices, have.Indices)) {
|
||||
return ErrNotEnoughValidSignatures
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func eq(a, b []int) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Verify verifies the current JWS as-is. Refer to verifyMulti
|
||||
// for more information.
|
||||
func (j *jws) Verify(key interface{}, method crypto.SigningMethod) error {
|
||||
if len(j.sb) < 1 {
|
||||
return ErrCannotValidate
|
||||
}
|
||||
return j.sb[0].verify(j.plcache, key, method)
|
||||
}
|
||||
|
||||
func (s *sigHead) verify(pl []byte, key interface{}, method crypto.SigningMethod) error {
|
||||
if s.method.Alg() != method.Alg() || s.method.Hasher() != method.Hasher() {
|
||||
return ErrMismatchedAlgorithms
|
||||
}
|
||||
return method.Verify(format(s.Protected, pl), s.Signature, key)
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
"github.com/SermoDigital/jose/crypto"
|
||||
"github.com/SermoDigital/jose/jwt"
|
||||
)
|
||||
|
||||
// NewJWT creates a new JWT with the given claims.
|
||||
func NewJWT(claims Claims, method crypto.SigningMethod) jwt.JWT {
|
||||
j, ok := New(claims, method).(*jws)
|
||||
if !ok {
|
||||
panic("jws.NewJWT: runtime panic: New(...).(*jws) != true")
|
||||
}
|
||||
j.sb[0].protected.Set("typ", "JWT")
|
||||
j.isJWT = true
|
||||
return j
|
||||
}
|
||||
|
||||
// Serialize helps implements jwt.JWT.
|
||||
func (j *jws) Serialize(key interface{}) ([]byte, error) {
|
||||
if j.isJWT {
|
||||
return j.Compact(key)
|
||||
}
|
||||
return nil, ErrIsNotJWT
|
||||
}
|
||||
|
||||
// Claims helps implements jwt.JWT.
|
||||
func (j *jws) Claims() jwt.Claims {
|
||||
if j.isJWT {
|
||||
if c, ok := j.payload.v.(Claims); ok {
|
||||
return jwt.Claims(c)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseJWTFromRequest tries to find the JWT in an http.Request.
|
||||
// This method will call ParseMultipartForm if there's no token in the header.
|
||||
func ParseJWTFromRequest(req *http.Request) (jwt.JWT, error) {
|
||||
if b, ok := fromHeader(req); ok {
|
||||
return ParseJWT(b)
|
||||
}
|
||||
if b, ok := fromForm(req); ok {
|
||||
return ParseJWT(b)
|
||||
}
|
||||
return nil, ErrNoTokenInRequest
|
||||
}
|
||||
|
||||
// ParseJWT parses a serialized jwt.JWT into a physical jwt.JWT.
|
||||
// If its payload isn't a set of claims (or able to be coerced into
|
||||
// a set of claims) it'll return an error stating the
|
||||
// JWT isn't a JWT.
|
||||
func ParseJWT(encoded []byte) (jwt.JWT, error) {
|
||||
t, err := parseCompact(encoded, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, ok := t.Payload().(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, ErrIsNotJWT
|
||||
}
|
||||
t.SetPayload(Claims(c))
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// IsJWT returns true if the JWS is a JWT.
|
||||
func (j *jws) IsJWT() bool {
|
||||
return j.isJWT
|
||||
}
|
||||
|
||||
func (j *jws) Validate(key interface{}, m crypto.SigningMethod, v ...*jwt.Validator) error {
|
||||
if j.isJWT {
|
||||
if err := j.Verify(key, m); err != nil {
|
||||
return err
|
||||
}
|
||||
var v1 jwt.Validator
|
||||
if len(v) > 0 {
|
||||
v1 = *v[0]
|
||||
}
|
||||
c, ok := j.payload.v.(Claims)
|
||||
if ok {
|
||||
if err := v1.Validate(j); err != nil {
|
||||
return err
|
||||
}
|
||||
return jwt.Claims(c).Validate(jose.Now(), v1.EXP, v1.NBF)
|
||||
}
|
||||
}
|
||||
return ErrIsNotJWT
|
||||
}
|
||||
|
||||
// Conv converts a func(Claims) error to type jwt.ValidateFunc.
|
||||
func Conv(fn func(Claims) error) jwt.ValidateFunc {
|
||||
if fn == nil {
|
||||
return nil
|
||||
}
|
||||
return func(c jwt.Claims) error {
|
||||
return fn(Claims(c))
|
||||
}
|
||||
}
|
||||
|
||||
// NewValidator returns a jwt.Validator.
|
||||
func NewValidator(c Claims, exp, nbf time.Duration, fn func(Claims) error) *jwt.Validator {
|
||||
return &jwt.Validator{
|
||||
Expected: jwt.Claims(c),
|
||||
EXP: exp,
|
||||
NBF: nbf,
|
||||
Fn: Conv(fn),
|
||||
}
|
||||
}
|
||||
|
||||
var _ jwt.JWT = (*jws)(nil)
|
|
@ -0,0 +1,52 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
)
|
||||
|
||||
// payload represents the payload of a JWS.
|
||||
type payload struct {
|
||||
v interface{}
|
||||
u json.Unmarshaler
|
||||
_ struct{}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for payload.
|
||||
func (p *payload) MarshalJSON() ([]byte, error) {
|
||||
b, err := json.Marshal(p.v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jose.EncodeEscape(b), nil
|
||||
}
|
||||
|
||||
// Base64 implements jose.Encoder.
|
||||
func (p *payload) Base64() ([]byte, error) {
|
||||
b, err := json.Marshal(p.v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jose.Base64Encode(b), nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Unmarshaler for payload.
|
||||
func (p *payload) UnmarshalJSON(b []byte) error {
|
||||
b2, err := jose.DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if p.u != nil {
|
||||
err := p.u.UnmarshalJSON(b2)
|
||||
p.v = p.u
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b2, &p.v)
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (*payload)(nil)
|
||||
_ json.Unmarshaler = (*payload)(nil)
|
||||
_ jose.Encoder = (*payload)(nil)
|
||||
)
|
|
@ -0,0 +1,28 @@
|
|||
package jws
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type rawBase64 []byte
|
||||
|
||||
// MarshalJSON implements json.Marshaler for rawBase64.
|
||||
func (r rawBase64) MarshalJSON() ([]byte, error) {
|
||||
buf := make([]byte, len(r)+2)
|
||||
buf[0] = '"'
|
||||
copy(buf[1:], r)
|
||||
buf[len(buf)-1] = '"'
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Unmarshaler for rawBase64.
|
||||
func (r *rawBase64) UnmarshalJSON(b []byte) error {
|
||||
if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
|
||||
b = b[1 : len(b)-1]
|
||||
}
|
||||
*r = rawBase64(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (rawBase64)(nil)
|
||||
_ json.Unmarshaler = (*rawBase64)(nil)
|
||||
)
|
|
@ -0,0 +1,63 @@
|
|||
package jws
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/SermoDigital/jose/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
|
||||
signingMethods = map[string]crypto.SigningMethod{
|
||||
crypto.SigningMethodES256.Alg(): crypto.SigningMethodES256,
|
||||
crypto.SigningMethodES384.Alg(): crypto.SigningMethodES384,
|
||||
crypto.SigningMethodES512.Alg(): crypto.SigningMethodES512,
|
||||
|
||||
crypto.SigningMethodPS256.Alg(): crypto.SigningMethodPS256,
|
||||
crypto.SigningMethodPS384.Alg(): crypto.SigningMethodPS384,
|
||||
crypto.SigningMethodPS512.Alg(): crypto.SigningMethodPS512,
|
||||
|
||||
crypto.SigningMethodRS256.Alg(): crypto.SigningMethodRS256,
|
||||
crypto.SigningMethodRS384.Alg(): crypto.SigningMethodRS384,
|
||||
crypto.SigningMethodRS512.Alg(): crypto.SigningMethodRS512,
|
||||
|
||||
crypto.SigningMethodHS256.Alg(): crypto.SigningMethodHS256,
|
||||
crypto.SigningMethodHS384.Alg(): crypto.SigningMethodHS384,
|
||||
crypto.SigningMethodHS512.Alg(): crypto.SigningMethodHS512,
|
||||
|
||||
crypto.Unsecured.Alg(): crypto.Unsecured,
|
||||
}
|
||||
)
|
||||
|
||||
// RegisterSigningMethod registers the crypto.SigningMethod in the global map.
|
||||
// This is typically done inside the caller's init function.
|
||||
func RegisterSigningMethod(sm crypto.SigningMethod) {
|
||||
alg := sm.Alg()
|
||||
if GetSigningMethod(alg) != nil {
|
||||
panic("jose/jws: cannot duplicate signing methods")
|
||||
}
|
||||
|
||||
if !sm.Hasher().Available() {
|
||||
panic("jose/jws: specific hash is unavailable")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
signingMethods[alg] = sm
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// RemoveSigningMethod removes the crypto.SigningMethod from the global map.
|
||||
func RemoveSigningMethod(sm crypto.SigningMethod) {
|
||||
mu.Lock()
|
||||
delete(signingMethods, sm.Alg())
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// GetSigningMethod retrieves a crypto.SigningMethod from the global map.
|
||||
func GetSigningMethod(alg string) (method crypto.SigningMethod) {
|
||||
mu.RLock()
|
||||
method = signingMethods[alg]
|
||||
mu.RUnlock()
|
||||
return method
|
||||
}
|
|
@ -0,0 +1,274 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/SermoDigital/jose"
|
||||
)
|
||||
|
||||
// Claims implements a set of JOSE Claims with the addition of some helper
|
||||
// methods, similar to net/url.Values.
|
||||
type Claims map[string]interface{}
|
||||
|
||||
// Validate validates the Claims per the claims found in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1
|
||||
func (c Claims) Validate(now time.Time, expLeeway, nbfLeeway time.Duration) error {
|
||||
if exp, ok := c.Expiration(); ok {
|
||||
if now.After(exp.Add(expLeeway)) {
|
||||
return ErrTokenIsExpired
|
||||
}
|
||||
}
|
||||
|
||||
if nbf, ok := c.NotBefore(); ok {
|
||||
if !now.After(nbf.Add(-nbfLeeway)) {
|
||||
return ErrTokenNotYetValid
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves the value corresponding with key from the Claims.
|
||||
func (c Claims) Get(key string) interface{} {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c[key]
|
||||
}
|
||||
|
||||
// Set sets Claims[key] = val. It'll overwrite without warning.
|
||||
func (c Claims) Set(key string, val interface{}) {
|
||||
c[key] = val
|
||||
}
|
||||
|
||||
// Del removes the value that corresponds with key from the Claims.
|
||||
func (c Claims) Del(key string) {
|
||||
delete(c, key)
|
||||
}
|
||||
|
||||
// Has returns true if a value for the given key exists inside the Claims.
|
||||
func (c Claims) Has(key string) bool {
|
||||
_, ok := c[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for Claims.
|
||||
func (c Claims) MarshalJSON() ([]byte, error) {
|
||||
if c == nil || len(c) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(map[string]interface{}(c))
|
||||
}
|
||||
|
||||
// Base64 implements the jose.Encoder interface.
|
||||
func (c Claims) Base64() ([]byte, error) {
|
||||
b, err := c.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jose.Base64Encode(b), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for Claims.
|
||||
func (c *Claims) UnmarshalJSON(b []byte) error {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
b, err := jose.DecodeEscaped(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Since json.Unmarshal calls UnmarshalJSON,
|
||||
// calling json.Unmarshal on *p would be infinitely recursive
|
||||
// A temp variable is needed because &map[string]interface{}(*p) is
|
||||
// invalid Go. (Address of unaddressable object and all that...)
|
||||
|
||||
tmp := map[string]interface{}(*c)
|
||||
if err = json.Unmarshal(b, &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
*c = Claims(tmp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Issuer retrieves claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
func (c Claims) Issuer() (string, bool) {
|
||||
v, ok := c.Get("iss").(string)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Subject retrieves claim "sub" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.2
|
||||
func (c Claims) Subject() (string, bool) {
|
||||
v, ok := c.Get("sub").(string)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Audience retrieves claim "aud" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
func (c Claims) Audience() ([]string, bool) {
|
||||
// Audience claim must be stringy. That is, it may be one string
|
||||
// or multiple strings but it should not be anything else. E.g. an int.
|
||||
switch t := c.Get("aud").(type) {
|
||||
case string:
|
||||
return []string{t}, true
|
||||
case []string:
|
||||
return t, true
|
||||
case []interface{}:
|
||||
return stringify(t...)
|
||||
case interface{}:
|
||||
return stringify(t)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func stringify(a ...interface{}) ([]string, bool) {
|
||||
if len(a) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
s := make([]string, len(a))
|
||||
for i := range a {
|
||||
str, ok := a[i].(string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
s[i] = str
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// Expiration retrieves claim "exp" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
func (c Claims) Expiration() (time.Time, bool) {
|
||||
return c.GetTime("exp")
|
||||
}
|
||||
|
||||
// NotBefore retrieves claim "nbf" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
func (c Claims) NotBefore() (time.Time, bool) {
|
||||
return c.GetTime("nbf")
|
||||
}
|
||||
|
||||
// IssuedAt retrieves claim "iat" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.6
|
||||
func (c Claims) IssuedAt() (time.Time, bool) {
|
||||
return c.GetTime("iat")
|
||||
}
|
||||
|
||||
// JWTID retrieves claim "jti" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.7
|
||||
func (c Claims) JWTID() (string, bool) {
|
||||
v, ok := c.Get("jti").(string)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// RemoveIssuer deletes claim "iss" from c.
|
||||
func (c Claims) RemoveIssuer() { c.Del("iss") }
|
||||
|
||||
// RemoveSubject deletes claim "sub" from c.
|
||||
func (c Claims) RemoveSubject() { c.Del("sub") }
|
||||
|
||||
// RemoveAudience deletes claim "aud" from c.
|
||||
func (c Claims) RemoveAudience() { c.Del("aud") }
|
||||
|
||||
// RemoveExpiration deletes claim "exp" from c.
|
||||
func (c Claims) RemoveExpiration() { c.Del("exp") }
|
||||
|
||||
// RemoveNotBefore deletes claim "nbf" from c.
|
||||
func (c Claims) RemoveNotBefore() { c.Del("nbf") }
|
||||
|
||||
// RemoveIssuedAt deletes claim "iat" from c.
|
||||
func (c Claims) RemoveIssuedAt() { c.Del("iat") }
|
||||
|
||||
// RemoveJWTID deletes claim "jti" from c.
|
||||
func (c Claims) RemoveJWTID() { c.Del("jti") }
|
||||
|
||||
// SetIssuer sets claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
func (c Claims) SetIssuer(issuer string) {
|
||||
c.Set("iss", issuer)
|
||||
}
|
||||
|
||||
// SetSubject sets claim "iss" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.2
|
||||
func (c Claims) SetSubject(subject string) {
|
||||
c.Set("sub", subject)
|
||||
}
|
||||
|
||||
// SetAudience sets claim "aud" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
func (c Claims) SetAudience(audience ...string) {
|
||||
if len(audience) == 1 {
|
||||
c.Set("aud", audience[0])
|
||||
} else {
|
||||
c.Set("aud", audience)
|
||||
}
|
||||
}
|
||||
|
||||
// SetExpiration sets claim "exp" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
func (c Claims) SetExpiration(expiration time.Time) {
|
||||
c.SetTime("exp", expiration)
|
||||
}
|
||||
|
||||
// SetNotBefore sets claim "nbf" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
func (c Claims) SetNotBefore(notBefore time.Time) {
|
||||
c.SetTime("nbf", notBefore)
|
||||
}
|
||||
|
||||
// SetIssuedAt sets claim "iat" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.6
|
||||
func (c Claims) SetIssuedAt(issuedAt time.Time) {
|
||||
c.SetTime("iat", issuedAt)
|
||||
}
|
||||
|
||||
// SetJWTID sets claim "jti" per its type in
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.7
|
||||
func (c Claims) SetJWTID(uniqueID string) {
|
||||
c.Set("jti", uniqueID)
|
||||
}
|
||||
|
||||
// GetTime returns a Unix timestamp for the given key.
|
||||
//
|
||||
// It converts an int, int32, int64, uint, uint32, uint64 or float64 into a Unix
|
||||
// timestamp (epoch seconds). float32 does not have sufficient precision to
|
||||
// store a Unix timestamp.
|
||||
//
|
||||
// Numeric values parsed from JSON will always be stored as float64 since
|
||||
// Claims is a map[string]interface{}. However, the values may be stored directly
|
||||
// in the claims as a different type.
|
||||
func (c Claims) GetTime(key string) (time.Time, bool) {
|
||||
switch t := c.Get(key).(type) {
|
||||
case int:
|
||||
return time.Unix(int64(t), 0), true
|
||||
case int32:
|
||||
return time.Unix(int64(t), 0), true
|
||||
case int64:
|
||||
return time.Unix(int64(t), 0), true
|
||||
case uint:
|
||||
return time.Unix(int64(t), 0), true
|
||||
case uint32:
|
||||
return time.Unix(int64(t), 0), true
|
||||
case uint64:
|
||||
return time.Unix(int64(t), 0), true
|
||||
case float64:
|
||||
return time.Unix(int64(t), 0), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// SetTime stores a UNIX time for the given key.
|
||||
func (c Claims) SetTime(key string, t time.Time) {
|
||||
c.Set(key, t.Unix())
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (Claims)(nil)
|
||||
_ json.Unmarshaler = (*Claims)(nil)
|
||||
)
|
|
@ -0,0 +1,2 @@
|
|||
// Package jwt implements JWTs per RFC 7519
|
||||
package jwt
|
|
@ -0,0 +1,47 @@
|
|||
package jwt
|
||||
|
||||
func verifyPrincipals(pcpls, auds []string) bool {
|
||||
// "Each principal intended to process the JWT MUST
|
||||
// identify itself with a value in the audience claim."
|
||||
// - https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
|
||||
found := -1
|
||||
for i, p := range pcpls {
|
||||
for _, v := range auds {
|
||||
if p == v {
|
||||
found++
|
||||
break
|
||||
}
|
||||
}
|
||||
if found != i {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidAudience returns true iff:
|
||||
// - a and b are strings and a == b
|
||||
// - a is string, b is []string and a is in b
|
||||
// - a is []string, b is []string and all of a is in b
|
||||
// - a is []string, b is string and len(a) == 1 and a[0] == b
|
||||
func ValidAudience(a, b interface{}) bool {
|
||||
s1, ok := a.(string)
|
||||
if ok {
|
||||
if s2, ok := b.(string); ok {
|
||||
return s1 == s2
|
||||
}
|
||||
a2, ok := b.([]string)
|
||||
return ok && verifyPrincipals([]string{s1}, a2)
|
||||
}
|
||||
|
||||
a1, ok := a.([]string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if a2, ok := b.([]string); ok {
|
||||
return verifyPrincipals(a1, a2)
|
||||
}
|
||||
s2, ok := b.(string)
|
||||
return ok && len(a1) == 1 && a1[0] == s2
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package jwt
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrTokenIsExpired is return when time.Now().Unix() is after
|
||||
// the token's "exp" claim.
|
||||
ErrTokenIsExpired = errors.New("token is expired")
|
||||
|
||||
// ErrTokenNotYetValid is return when time.Now().Unix() is before
|
||||
// the token's "nbf" claim.
|
||||
ErrTokenNotYetValid = errors.New("token is not yet valid")
|
||||
|
||||
// ErrInvalidISSClaim means the "iss" claim is invalid.
|
||||
ErrInvalidISSClaim = errors.New("claim \"iss\" is invalid")
|
||||
|
||||
// ErrInvalidSUBClaim means the "sub" claim is invalid.
|
||||
ErrInvalidSUBClaim = errors.New("claim \"sub\" is invalid")
|
||||
|
||||
// ErrInvalidIATClaim means the "iat" claim is invalid.
|
||||
ErrInvalidIATClaim = errors.New("claim \"iat\" is invalid")
|
||||
|
||||
// ErrInvalidJTIClaim means the "jti" claim is invalid.
|
||||
ErrInvalidJTIClaim = errors.New("claim \"jti\" is invalid")
|
||||
|
||||
// ErrInvalidAUDClaim means the "aud" claim is invalid.
|
||||
ErrInvalidAUDClaim = errors.New("claim \"aud\" is invalid")
|
||||
)
|
|
@ -0,0 +1,144 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/SermoDigital/jose/crypto"
|
||||
)
|
||||
|
||||
// JWT represents a JWT per RFC 7519.
|
||||
// It's described as an interface instead of a physical structure
|
||||
// because both JWS and JWEs can be JWTs. So, in order to use either,
|
||||
// import one of those two packages and use their "NewJWT" (and other)
|
||||
// functions.
|
||||
type JWT interface {
|
||||
// Claims returns the set of Claims.
|
||||
Claims() Claims
|
||||
|
||||
// Validate returns an error describing any issues found while
|
||||
// validating the JWT. For info on the fn parameter, see the
|
||||
// comment on ValidateFunc.
|
||||
Validate(key interface{}, method crypto.SigningMethod, v ...*Validator) error
|
||||
|
||||
// Serialize serializes the JWT into its on-the-wire
|
||||
// representation.
|
||||
Serialize(key interface{}) ([]byte, error)
|
||||
}
|
||||
|
||||
// ValidateFunc is a function that provides access to the JWT
|
||||
// and allows for custom validation. Keep in mind that the Verify
|
||||
// methods in the JWS/JWE sibling packages call ValidateFunc *after*
|
||||
// validating the JWS/JWE, but *before* any validation per the JWT
|
||||
// RFC. Therefore, the ValidateFunc can be used to short-circuit
|
||||
// verification, but cannot be used to circumvent the RFC.
|
||||
// Custom JWT implementations are free to abuse this, but it is
|
||||
// not recommended.
|
||||
type ValidateFunc func(Claims) error
|
||||
|
||||
// Validator represents some of the validation options.
|
||||
type Validator struct {
|
||||
Expected Claims // If non-nil, these are required to match.
|
||||
EXP time.Duration // EXPLeeway
|
||||
NBF time.Duration // NBFLeeway
|
||||
Fn ValidateFunc // See ValidateFunc for more information.
|
||||
|
||||
_ struct{} // Require explicitly-named struct fields.
|
||||
}
|
||||
|
||||
// Validate validates the JWT based on the expected claims in v.
|
||||
// Note: it only validates the registered claims per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1
|
||||
//
|
||||
// Custom claims should be validated using v's Fn member.
|
||||
func (v *Validator) Validate(j JWT) error {
|
||||
if iss, ok := v.Expected.Issuer(); ok &&
|
||||
j.Claims().Get("iss") != iss {
|
||||
return ErrInvalidISSClaim
|
||||
}
|
||||
if sub, ok := v.Expected.Subject(); ok &&
|
||||
j.Claims().Get("sub") != sub {
|
||||
return ErrInvalidSUBClaim
|
||||
}
|
||||
if iat, ok := v.Expected.IssuedAt(); ok {
|
||||
if t, ok := j.Claims().GetTime("iat"); !t.Equal(iat) || !ok {
|
||||
return ErrInvalidIATClaim
|
||||
}
|
||||
}
|
||||
if jti, ok := v.Expected.JWTID(); ok &&
|
||||
j.Claims().Get("jti") != jti {
|
||||
return ErrInvalidJTIClaim
|
||||
}
|
||||
|
||||
if aud, ok := v.Expected.Audience(); ok {
|
||||
aud2, ok := j.Claims().Audience()
|
||||
if !ok || !ValidAudience(aud, aud2) {
|
||||
return ErrInvalidAUDClaim
|
||||
}
|
||||
}
|
||||
|
||||
if v.Fn != nil {
|
||||
return v.Fn(j.Claims())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetClaim sets the claim with the given val.
|
||||
func (v *Validator) SetClaim(claim string, val interface{}) {
|
||||
v.expect()
|
||||
v.Expected.Set(claim, val)
|
||||
}
|
||||
|
||||
// SetIssuer sets the "iss" claim per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.1
|
||||
func (v *Validator) SetIssuer(iss string) {
|
||||
v.expect()
|
||||
v.Expected.Set("iss", iss)
|
||||
}
|
||||
|
||||
// SetSubject sets the "sub" claim per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.2
|
||||
func (v *Validator) SetSubject(sub string) {
|
||||
v.expect()
|
||||
v.Expected.Set("sub", sub)
|
||||
}
|
||||
|
||||
// SetAudience sets the "aud" claim per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.3
|
||||
func (v *Validator) SetAudience(aud string) {
|
||||
v.expect()
|
||||
v.Expected.Set("aud", aud)
|
||||
}
|
||||
|
||||
// SetExpiration sets the "exp" claim per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.4
|
||||
func (v *Validator) SetExpiration(exp time.Time) {
|
||||
v.expect()
|
||||
v.Expected.Set("exp", exp)
|
||||
}
|
||||
|
||||
// SetNotBefore sets the "nbf" claim per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.5
|
||||
func (v *Validator) SetNotBefore(nbf time.Time) {
|
||||
v.expect()
|
||||
v.Expected.Set("nbf", nbf)
|
||||
}
|
||||
|
||||
// SetIssuedAt sets the "iat" claim per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.6
|
||||
func (v *Validator) SetIssuedAt(iat time.Time) {
|
||||
v.expect()
|
||||
v.Expected.Set("iat", iat)
|
||||
}
|
||||
|
||||
// SetJWTID sets the "jti" claim per
|
||||
// https://tools.ietf.org/html/rfc7519#section-4.1.7
|
||||
func (v *Validator) SetJWTID(jti string) {
|
||||
v.expect()
|
||||
v.Expected.Set("jti", jti)
|
||||
}
|
||||
|
||||
func (v *Validator) expect() {
|
||||
if v.Expected == nil {
|
||||
v.Expected = make(Claims)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
package jose
|
||||
|
||||
import "time"
|
||||
|
||||
// Now returns the current time in UTC.
|
||||
func Now() time.Time { return time.Now().UTC() }
|
|
@ -60,6 +60,30 @@
|
|||
"revision": "2a3aa15961d5fee6047b8151b67ac2f08ba2c48c",
|
||||
"revisionTime": "2016-11-16T21:39:02Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "t+uej2kiyqRyQYguygI8t9nJH2w=",
|
||||
"path": "github.com/SermoDigital/jose",
|
||||
"revision": "2bd9b81ac51d6d6134fcd4fd846bd2e7347a15f9",
|
||||
"revisionTime": "2016-12-05T22:51:55Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "u92C5yEz1FLBeoXp8jujn3tUNFI=",
|
||||
"path": "github.com/SermoDigital/jose/crypto",
|
||||
"revision": "2bd9b81ac51d6d6134fcd4fd846bd2e7347a15f9",
|
||||
"revisionTime": "2016-12-05T22:51:55Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "FJP5enLaw1JzZNu6Fen+eEKY7Lo=",
|
||||
"path": "github.com/SermoDigital/jose/jws",
|
||||
"revision": "2bd9b81ac51d6d6134fcd4fd846bd2e7347a15f9",
|
||||
"revisionTime": "2016-12-05T22:51:55Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "3gGGWQ3PcKHE+2cVlkJmIlBfBlw=",
|
||||
"path": "github.com/SermoDigital/jose/jwt",
|
||||
"revision": "2bd9b81ac51d6d6134fcd4fd846bd2e7347a15f9",
|
||||
"revisionTime": "2016-12-05T22:51:55Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "LLVyR2dAgkihu0+HdZF+JK0gMMs=",
|
||||
"path": "github.com/agl/ed25519",
|
||||
|
|
Loading…
Reference in New Issue