2023-03-29 21:08:31 +00:00
|
|
|
package pki
|
2023-03-28 17:29:54 +00:00
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/rand"
|
|
|
|
"encoding/base64"
|
2023-04-12 15:29:54 +00:00
|
|
|
"errors"
|
2023-03-28 17:29:54 +00:00
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/hashicorp/vault/sdk/framework"
|
2023-04-03 20:08:25 +00:00
|
|
|
"github.com/hashicorp/vault/sdk/logical"
|
2023-03-28 17:29:54 +00:00
|
|
|
)
|
|
|
|
|
2023-04-03 20:08:25 +00:00
|
|
|
// nonceExpiry is how long an issued nonce remains redeemable.
const nonceExpiry = 15 * time.Minute

// Storage key prefixes for ACME data.
const (
	// acmePathPrefix is the common root of every ACME storage key.
	acmePathPrefix = "acme/"

	// acmeAccountPrefix + <kid> holds an account record.
	acmeAccountPrefix = acmePathPrefix + "accounts/"

	// acmeThumbprintPrefix + <key thumbprint> holds the kid of the account
	// registered with that key.
	acmeThumbprintPrefix = acmePathPrefix + "account-thumbprints/"
)
|
2023-03-28 17:29:54 +00:00
|
|
|
|
2023-03-29 21:08:31 +00:00
|
|
|
// acmeState holds the in-memory nonce bookkeeping used by the ACME code:
// the set of outstanding nonces and the deadline for the next tidy pass.
type acmeState struct {
	// nextExpiry is the time, in Unix seconds, after which DoTidyNonces
	// should run TidyNonces. Zero means no deadline has been recorded yet.
	nextExpiry *atomic.Int64

	// nonces maps each outstanding nonce (string) to its expiry (time.Time).
	nonces *sync.Map // map[string]time.Time
}
|
|
|
|
|
2023-04-12 15:29:54 +00:00
|
|
|
// acmeThumbprint is the index record stored under
// acmeThumbprintPrefix+<key thumbprint>; it points from a public key's
// thumbprint to the kid of the account registered with that key.
type acmeThumbprint struct {
	// Kid is the referenced account's key identifier.
	Kid string `json:"kid"`

	// Thumbprint is never serialized (tag "-"); it is already encoded in
	// the storage key.
	Thumbprint string `json:"-"`
}
|
|
|
|
|
2023-03-29 21:08:31 +00:00
|
|
|
func NewACMEState() *acmeState {
|
|
|
|
return &acmeState{
|
2023-03-28 17:29:54 +00:00
|
|
|
nextExpiry: new(atomic.Int64),
|
|
|
|
nonces: new(sync.Map),
|
2023-03-29 18:22:48 +00:00
|
|
|
}
|
2023-03-28 17:29:54 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// generateNonce returns 21 bytes from crypto/rand encoded as unpadded
// base64url (28 characters). Any error from the random source is returned
// with an empty string.
func generateNonce() (string, error) {
	var raw [21]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
|
2023-03-29 21:08:31 +00:00
|
|
|
func (a *acmeState) GetNonce() (string, time.Time, error) {
|
2023-03-28 17:29:54 +00:00
|
|
|
now := time.Now()
|
|
|
|
nonce, err := generateNonce()
|
|
|
|
if err != nil {
|
|
|
|
return "", now, err
|
|
|
|
}
|
|
|
|
|
|
|
|
then := now.Add(nonceExpiry)
|
|
|
|
a.nonces.Store(nonce, then)
|
|
|
|
|
|
|
|
nextExpiry := a.nextExpiry.Load()
|
|
|
|
next := time.Unix(nextExpiry, 0)
|
|
|
|
if now.After(next) || then.Before(next) {
|
|
|
|
a.nextExpiry.Store(then.Unix())
|
|
|
|
}
|
|
|
|
|
|
|
|
return nonce, then, nil
|
|
|
|
}
|
|
|
|
|
2023-03-29 21:08:31 +00:00
|
|
|
func (a *acmeState) RedeemNonce(nonce string) bool {
|
2023-03-28 17:29:54 +00:00
|
|
|
rawTimeout, present := a.nonces.LoadAndDelete(nonce)
|
|
|
|
if !present {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
timeout := rawTimeout.(time.Time)
|
|
|
|
if time.Now().After(timeout) {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
|
2023-03-29 21:08:31 +00:00
|
|
|
func (a *acmeState) DoTidyNonces() {
|
2023-03-28 17:29:54 +00:00
|
|
|
now := time.Now()
|
|
|
|
expiry := a.nextExpiry.Load()
|
|
|
|
then := time.Unix(expiry, 0)
|
|
|
|
|
|
|
|
if expiry == 0 || now.After(then) {
|
|
|
|
a.TidyNonces()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-29 21:08:31 +00:00
|
|
|
func (a *acmeState) TidyNonces() {
|
2023-03-28 17:29:54 +00:00
|
|
|
now := time.Now()
|
|
|
|
nextRun := now.Add(nonceExpiry)
|
|
|
|
|
|
|
|
a.nonces.Range(func(key, value any) bool {
|
|
|
|
timeout := value.(time.Time)
|
|
|
|
if now.After(timeout) {
|
|
|
|
a.nonces.Delete(key)
|
|
|
|
}
|
|
|
|
|
|
|
|
if timeout.Before(nextRun) {
|
|
|
|
nextRun = timeout
|
|
|
|
}
|
|
|
|
|
|
|
|
return false /* don't quit looping */
|
|
|
|
})
|
|
|
|
|
|
|
|
a.nextExpiry.Store(nextRun.Unix())
|
|
|
|
}
|
|
|
|
|
2023-04-12 13:05:42 +00:00
|
|
|
// ACMEAccountStatus is the status recorded on an ACME account. The defined
// values are the Status* constants below.
type ACMEAccountStatus string

// String returns the status as a plain string.
func (s ACMEAccountStatus) String() string { return string(s) }

// Account statuses.
const (
	StatusValid       ACMEAccountStatus = "valid"
	StatusDeactivated ACMEAccountStatus = "deactivated"
	StatusRevoked     ACMEAccountStatus = "revoked"
)
|
|
|
|
|
|
|
|
type acmeAccount struct {
|
2023-04-12 13:05:42 +00:00
|
|
|
KeyId string `json:"-"`
|
|
|
|
Status ACMEAccountStatus `json:"status"`
|
|
|
|
Contact []string `json:"contact"`
|
|
|
|
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
|
|
|
|
Jwk []byte `json:"jwk"`
|
2023-03-28 17:29:54 +00:00
|
|
|
}
|
|
|
|
|
2023-04-03 20:08:25 +00:00
|
|
|
func (a *acmeState) CreateAccount(ac *acmeContext, c *jwsCtx, contact []string, termsOfServiceAgreed bool) (*acmeAccount, error) {
|
2023-04-12 15:29:54 +00:00
|
|
|
// Write out the thumbprint value/entry out first, if we get an error mid-way through
|
|
|
|
// this is easier to recover from. The new kid with the same existing public key
|
|
|
|
// will rewrite the thumbprint entry. This goes in hand with LoadAccountByKey that
|
|
|
|
// will return a nil, nil value if the referenced kid in a loaded thumbprint does not
|
|
|
|
// exist. This effectively makes this self-healing IF the end-user re-attempts the
|
|
|
|
// account creation with the same public key.
|
|
|
|
thumbprint, err := c.GetKeyThumbprint()
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed generating thumbprint: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
thumbPrint := &acmeThumbprint{
|
|
|
|
Kid: c.Kid,
|
|
|
|
Thumbprint: thumbprint,
|
|
|
|
}
|
|
|
|
thumbPrintEntry, err := logical.StorageEntryJSON(acmeThumbprintPrefix+thumbprint, thumbPrint)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error generating account thumbprint entry: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err = ac.sc.Storage.Put(ac.sc.Context, thumbPrintEntry); err != nil {
|
|
|
|
return nil, fmt.Errorf("error writing account thumbprint entry: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Now write out the main value that the thumbprint points too.
|
2023-04-03 20:08:25 +00:00
|
|
|
acct := &acmeAccount{
|
|
|
|
KeyId: c.Kid,
|
|
|
|
Contact: contact,
|
|
|
|
TermsOfServiceAgreed: termsOfServiceAgreed,
|
|
|
|
Jwk: c.Jwk,
|
2023-04-12 13:05:42 +00:00
|
|
|
Status: StatusValid,
|
2023-04-03 20:08:25 +00:00
|
|
|
}
|
2023-04-12 15:29:54 +00:00
|
|
|
json, err := logical.StorageEntryJSON(acmeAccountPrefix+c.Kid, acct)
|
2023-04-03 20:08:25 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error creating account entry: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := ac.sc.Storage.Put(ac.sc.Context, json); err != nil {
|
|
|
|
return nil, fmt.Errorf("error writing account entry: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return acct, nil
|
2023-03-29 19:06:09 +00:00
|
|
|
}
|
|
|
|
|
2023-04-12 13:05:42 +00:00
|
|
|
func (a *acmeState) UpdateAccount(ac *acmeContext, acct *acmeAccount) error {
|
2023-04-12 15:29:54 +00:00
|
|
|
json, err := logical.StorageEntryJSON(acmeAccountPrefix+acct.KeyId, acct)
|
2023-04-12 13:05:42 +00:00
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error creating account entry: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := ac.sc.Storage.Put(ac.sc.Context, json); err != nil {
|
|
|
|
return fmt.Errorf("error writing account entry: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-04-12 15:29:54 +00:00
|
|
|
// LoadAccount will load the account object based on the passed in keyId field value
|
|
|
|
// otherwise will return an error if the account does not exist.
|
|
|
|
func (a *acmeState) LoadAccount(ac *acmeContext, keyId string) (*acmeAccount, error) {
|
|
|
|
entry, err := ac.sc.Storage.Get(ac.sc.Context, acmeAccountPrefix+keyId)
|
2023-04-03 20:08:25 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error loading account: %w", err)
|
|
|
|
}
|
|
|
|
if entry == nil {
|
2023-04-12 15:29:54 +00:00
|
|
|
return nil, fmt.Errorf("account not found: %w", ErrAccountDoesNotExist)
|
2023-04-03 20:08:25 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
var acct acmeAccount
|
|
|
|
err = entry.DecodeJSON(&acct)
|
|
|
|
if err != nil {
|
2023-04-12 15:29:54 +00:00
|
|
|
return nil, fmt.Errorf("error decoding account: %w", err)
|
2023-04-03 20:08:25 +00:00
|
|
|
}
|
|
|
|
|
2023-04-12 15:29:54 +00:00
|
|
|
acct.KeyId = keyId
|
2023-04-12 13:05:42 +00:00
|
|
|
|
2023-04-03 20:08:25 +00:00
|
|
|
return &acct, nil
|
|
|
|
}
|
|
|
|
|
2023-04-12 15:29:54 +00:00
|
|
|
// LoadAccountByKey will attempt to load the account based on a key thumbprint. If the thumbprint
|
|
|
|
// or kid is unknown a nil, nil will be returned.
|
|
|
|
func (a *acmeState) LoadAccountByKey(ac *acmeContext, keyThumbprint string) (*acmeAccount, error) {
|
|
|
|
thumbprintEntry, err := ac.sc.Storage.Get(ac.sc.Context, acmeThumbprintPrefix+keyThumbprint)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed loading acme thumbprintEntry for key: %w", err)
|
|
|
|
}
|
|
|
|
if thumbprintEntry == nil {
|
|
|
|
return nil, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
var thumbprint acmeThumbprint
|
|
|
|
err = thumbprintEntry.DecodeJSON(&thumbprint)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed decoding thumbprint entry: %s: %w", keyThumbprint, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(thumbprint.Kid) == 0 {
|
|
|
|
return nil, fmt.Errorf("empty kid within thumbprint entry: %s", keyThumbprint)
|
|
|
|
}
|
|
|
|
|
|
|
|
acct, err := a.LoadAccount(ac, thumbprint.Kid)
|
|
|
|
if err != nil {
|
|
|
|
// If we fail to lookup the account that the thumbprint entry references, assume a bad
|
|
|
|
// write previously occurred in which we managed to write out the thumbprint but failed
|
|
|
|
// writing out the main account information.
|
|
|
|
if errors.Is(err, ErrAccountDoesNotExist) {
|
|
|
|
return nil, nil
|
|
|
|
}
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return acct, nil
|
2023-04-03 20:08:25 +00:00
|
|
|
}
|
|
|
|
|
2023-04-12 15:29:54 +00:00
|
|
|
func (a *acmeState) LoadJWK(ac *acmeContext, keyId string) ([]byte, error) {
|
|
|
|
key, err := a.LoadAccount(ac, keyId)
|
2023-03-28 17:29:54 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2023-04-03 20:08:25 +00:00
|
|
|
if len(key.Jwk) == 0 {
|
2023-03-28 17:29:54 +00:00
|
|
|
return nil, fmt.Errorf("malformed key entry lacks JWK")
|
|
|
|
}
|
|
|
|
|
2023-04-03 20:08:25 +00:00
|
|
|
return key.Jwk, nil
|
2023-03-28 17:29:54 +00:00
|
|
|
}
|
|
|
|
|
2023-04-03 20:08:25 +00:00
|
|
|
func (a *acmeState) ParseRequestParams(ac *acmeContext, data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) {
|
2023-03-29 21:08:31 +00:00
|
|
|
var c jwsCtx
|
2023-03-28 17:29:54 +00:00
|
|
|
var m map[string]interface{}
|
|
|
|
|
|
|
|
// Parse the key out.
|
2023-03-29 21:08:31 +00:00
|
|
|
rawJWKBase64, ok := data.GetOk("protected")
|
|
|
|
if !ok {
|
|
|
|
return nil, nil, fmt.Errorf("missing required field 'protected': %w", ErrMalformed)
|
|
|
|
}
|
|
|
|
jwkBase64 := rawJWKBase64.(string)
|
|
|
|
|
2023-03-28 17:29:54 +00:00
|
|
|
jwkBytes, err := base64.RawURLEncoding.DecodeString(jwkBase64)
|
|
|
|
if err != nil {
|
2023-03-29 21:08:31 +00:00
|
|
|
return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %s: %w", err, ErrMalformed)
|
2023-03-28 17:29:54 +00:00
|
|
|
}
|
2023-04-03 20:08:25 +00:00
|
|
|
if err = c.UnmarshalJSON(a, ac, jwkBytes); err != nil {
|
2023-03-28 17:29:54 +00:00
|
|
|
return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Since we already parsed the header to verify the JWS context, we
|
|
|
|
// should read and redeem the nonce here too, to avoid doing any extra
|
|
|
|
// work if it is invalid.
|
|
|
|
if !a.RedeemNonce(c.Nonce) {
|
2023-03-29 21:08:31 +00:00
|
|
|
return nil, nil, fmt.Errorf("invalid or reused nonce: %w", ErrBadNonce)
|
2023-03-28 17:29:54 +00:00
|
|
|
}
|
|
|
|
|
2023-03-29 21:08:31 +00:00
|
|
|
rawPayloadBase64, ok := data.GetOk("payload")
|
|
|
|
if !ok {
|
|
|
|
return nil, nil, fmt.Errorf("missing required field 'payload': %w", ErrMalformed)
|
|
|
|
}
|
|
|
|
payloadBase64 := rawPayloadBase64.(string)
|
|
|
|
|
|
|
|
rawSignatureBase64, ok := data.GetOk("signature")
|
|
|
|
if !ok {
|
|
|
|
return nil, nil, fmt.Errorf("missing required field 'signature': %w", ErrMalformed)
|
|
|
|
}
|
|
|
|
signatureBase64 := rawSignatureBase64.(string)
|
2023-03-28 17:29:54 +00:00
|
|
|
|
|
|
|
// go-jose only seems to support compact signature encodings.
|
|
|
|
compactSig := fmt.Sprintf("%v.%v.%v", jwkBase64, payloadBase64, signatureBase64)
|
|
|
|
m, err = c.VerifyJWS(compactSig)
|
|
|
|
if err != nil {
|
|
|
|
return nil, nil, fmt.Errorf("failed to verify signature: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &c, m, nil
|
|
|
|
}
|