Add acme account storage (#19953)

* Enable creation of accounts

 - Refactors many methods to take an acmeContext, which holds the
   storageContext on it.
 - Updates the core ACME Handlers to use *acmeContext, to avoid
   copying structs.
 - Makes JWK exported so the JSON parser can find it.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Finish ACME account creation

 - This ensures a Kid is created when one doesn't exist
 - Expands the parsed handler capabilities, to format the response and
   set required headers.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

---------

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
This commit is contained in:
Alexander Scheel 2023-04-03 16:08:25 -04:00 committed by GitHub
parent b86a09fb2a
commit 3ed31ff262
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 198 additions and 59 deletions

View File

@ -26,20 +26,20 @@ var AllowedOuterJWSTypes = map[string]interface{}{
type jwsCtx struct {
Algo string `json:"alg"`
Kid string `json:"kid"`
jwk json.RawMessage `json:"jwk"`
Jwk json.RawMessage `json:"jwk"`
Nonce string `json:"nonce"`
Url string `json:"url"`
key jose.JSONWebKey `json:"-"`
Key jose.JSONWebKey `json:"-"`
Existing bool `json:"-"`
}
func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error {
func (c *jwsCtx) UnmarshalJSON(a *acmeState, ac *acmeContext, jws []byte) error {
var err error
if err = json.Unmarshal(jws, c); err != nil {
return err
}
if c.Kid != "" && len(c.jwk) > 0 {
if c.Kid != "" && len(c.Jwk) > 0 {
// See RFC 8555 Section 6.2. Request Authentication:
//
// > The "jwk" and "kid" fields are mutually exclusive. Servers MUST
@ -47,7 +47,7 @@ func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error {
return fmt.Errorf("invalid header: got both account 'kid' and 'jwk' in the same message; expected only one: %w", ErrMalformed)
}
if c.Kid == "" && len(c.jwk) == 0 {
if c.Kid == "" && len(c.Jwk) == 0 {
// See RFC 8555 Section 6.2. Request Authentication:
//
// > Either "jwk" (JSON Web Key) or "kid" (Key ID) as specified
@ -70,24 +70,24 @@ func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error {
if c.Kid != "" {
// Load KID from storage first.
c.jwk, err = a.LoadJWK(c.Kid)
c.Jwk, err = a.LoadJWK(ac, c.Kid)
if err != nil {
return err
}
c.Existing = true
}
if err = c.key.UnmarshalJSON(c.jwk); err != nil {
if err = c.Key.UnmarshalJSON(c.Jwk); err != nil {
return err
}
if !c.key.Valid() {
if !c.Key.Valid() {
return fmt.Errorf("received invalid jwk: %w", ErrMalformed)
}
if c.Kid != "" {
if c.Kid == "" {
// Create a key ID
kid, err := c.key.Thumbprint(crypto.SHA256)
kid, err := c.Key.Thumbprint(crypto.SHA256)
if err != nil {
return fmt.Errorf("failed creating thumbprint: %w", err)
}
@ -128,7 +128,7 @@ func (c *jwsCtx) VerifyJWS(signature string) (map[string]interface{}, error) {
return nil, fmt.Errorf("request had unprotected headers: %w", ErrMalformed)
}
payload, err := sig.Verify(c.key)
payload, err := sig.Verify(c.Key)
if err != nil {
return nil, err
}

View File

@ -5,15 +5,23 @@ import (
"encoding/base64"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
// How long nonces are considered valid.
const nonceExpiry = 15 * time.Minute
const (
// How long nonces are considered valid.
nonceExpiry = 15 * time.Minute
// Path Prefixes
acmePathPrefix = "acme/"
acmeAccountPrefix = acmePathPrefix + "accounts/"
)
type acmeState struct {
nextExpiry *atomic.Int64
@ -99,36 +107,86 @@ func (a *acmeState) TidyNonces() {
a.nextExpiry.Store(nextRun.Unix())
}
func (a *acmeState) CreateAccount(c *jwsCtx, contact []string, termsOfServiceAgreed bool) (map[string]interface{}, error) {
// TODO
return nil, nil
type ACMEStates string
const (
StatusValid = "valid"
StatusDeactivated = "deactivated"
StatusRevoked = "revoked"
)
type acmeAccount struct {
KeyId string `json:"-"`
Status ACMEStates `json:"state"`
Contact []string `json:"contact"`
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"`
Jwk []byte `json:"jwk"`
}
func (a *acmeState) LoadAccount(keyID string) (map[string]interface{}, error) {
// TODO
return nil, nil
func (a *acmeState) CreateAccount(ac *acmeContext, c *jwsCtx, contact []string, termsOfServiceAgreed bool) (*acmeAccount, error) {
acct := &acmeAccount{
KeyId: c.Kid,
Contact: contact,
TermsOfServiceAgreed: termsOfServiceAgreed,
Jwk: c.Jwk,
}
json, err := logical.StorageEntryJSON(acmeAccountPrefix+c.Kid, acct)
if err != nil {
return nil, fmt.Errorf("error creating account entry: %w", err)
}
if err := ac.sc.Storage.Put(ac.sc.Context, json); err != nil {
return nil, fmt.Errorf("error writing account entry: %w", err)
}
return acct, nil
}
func (a *acmeState) DoesAccountExist(keyId string) bool {
account, err := a.LoadAccount(keyId)
return err == nil && len(account) > 0
func cleanKid(keyID string) string {
pieces := strings.Split(keyID, "/")
return pieces[len(pieces)-1]
}
func (a *acmeState) LoadJWK(keyID string) ([]byte, error) {
key, err := a.LoadAccount(keyID)
func (a *acmeState) LoadAccount(ac *acmeContext, keyID string) (*acmeAccount, error) {
kid := cleanKid(keyID)
entry, err := ac.sc.Storage.Get(ac.sc.Context, acmeAccountPrefix+kid)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
if entry == nil {
return nil, fmt.Errorf("account not found: %w", ErrMalformed)
}
var acct acmeAccount
err = entry.DecodeJSON(&acct)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
return &acct, nil
}
func (a *acmeState) DoesAccountExist(ac *acmeContext, keyId string) bool {
account, err := a.LoadAccount(ac, keyId)
return err == nil && account != nil
}
func (a *acmeState) LoadJWK(ac *acmeContext, keyID string) ([]byte, error) {
key, err := a.LoadAccount(ac, keyID)
if err != nil {
return nil, err
}
jwk, present := key["jwk"]
if !present {
if len(key.Jwk) == 0 {
return nil, fmt.Errorf("malformed key entry lacks JWK")
}
return jwk.([]byte), nil
return key.Jwk, nil
}
func (a *acmeState) ParseRequestParams(data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) {
func (a *acmeState) ParseRequestParams(ac *acmeContext, data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) {
var c jwsCtx
var m map[string]interface{}
@ -143,7 +201,7 @@ func (a *acmeState) ParseRequestParams(data *framework.FieldData) (*jwsCtx, map[
if err != nil {
return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %s: %w", err, ErrMalformed)
}
if err = c.UnmarshalJSON(a, jwkBytes); err != nil {
if err = c.UnmarshalJSON(a, ac, jwkBytes); err != nil {
return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err)
}

View File

@ -55,7 +55,7 @@ func patternAcmeDirectory(b *backend, pattern string) *framework.Path {
}
}
type acmeOperation func(acmeCtx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error)
type acmeOperation func(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error)
type acmeContext struct {
baseUrl *url.URL
@ -76,7 +76,7 @@ func (b *backend) acmeWrapper(op acmeOperation) framework.OperationFunc {
return nil, err
}
acmeCtx := acmeContext{
acmeCtx := &acmeContext{
baseUrl: baseUrl,
sc: sc,
}
@ -120,7 +120,7 @@ func acmeErrorWrapper(op framework.OperationFunc) framework.OperationFunc {
}
}
func (b *backend) acmeDirectoryHandler(acmeCtx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
func (b *backend) acmeDirectoryHandler(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
rawBody, err := json.Marshal(map[string]interface{}{
"newNonce": acmeCtx.baseUrl.JoinPath("new-nonce").String(),
"newAccount": acmeCtx.baseUrl.JoinPath("new-account").String(),

View File

@ -1,7 +1,9 @@
package pki
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/hashicorp/vault/sdk/framework"
@ -88,31 +90,102 @@ func patternAcmeNewAccount(b *backend, pattern string) *framework.Path {
}
}
type acmeParsedOperation func(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error)
type acmeParsedOperation func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error)
func (b *backend) acmeParsedWrapper(op acmeParsedOperation) framework.OperationFunc {
return b.acmeWrapper(func(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) {
user, data, err := b.acmeState.ParseRequestParams(fields)
return b.acmeWrapper(func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) {
user, data, err := b.acmeState.ParseRequestParams(acmeCtx, fields)
if err != nil {
return nil, err
}
return op(acmeCtx, r, fields, user, data)
resp, err := op(acmeCtx, r, fields, user, data)
// Our response handlers might not add the necessary headers.
if resp != nil {
if resp.Headers == nil {
resp.Headers = map[string][]string{}
}
if _, ok := resp.Headers["Replay-Nonce"]; !ok {
nonce, _, err := b.acmeState.GetNonce()
if err != nil {
return nil, err
}
resp.Headers["Replay-Nonce"] = []string{nonce}
}
if _, ok := resp.Headers["Link"]; !ok {
resp.Headers["Link"] = genAcmeLinkHeader(acmeCtx)
} else {
directory := genAcmeLinkHeader(acmeCtx)[0]
addDirectory := true
for _, item := range resp.Headers["Link"] {
if item == directory {
addDirectory = false
break
}
}
if addDirectory {
resp.Headers["Link"] = append(resp.Headers["Link"], directory)
}
}
// ACME responses don't understand Vault's default encoding
// format. Rather than expecting everything to handle creating
// ACME-formatted responses, do the marshaling in one place.
if _, ok := resp.Data[logical.HTTPRawBody]; !ok {
ignored_values := map[string]bool{logical.HTTPContentType: true, logical.HTTPStatusCode: true}
fields := map[string]interface{}{}
body := map[string]interface{}{
logical.HTTPContentType: "application/json",
logical.HTTPStatusCode: http.StatusOK,
}
for key, value := range resp.Data {
if _, present := ignored_values[key]; !present {
fields[key] = value
} else {
body[key] = value
}
}
rawBody, err := json.Marshal(fields)
if err != nil {
return nil, fmt.Errorf("Error marshaling JSON body: %w", err)
}
body[logical.HTTPRawBody] = rawBody
resp.Data = body
}
}
return resp, err
})
}
func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
func (b *backend) acmeNewAccountHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
// Parameters
var ok bool
var onlyReturnExisting bool
var contact []string
var contacts []string
var termsOfServiceAgreed bool
rawContact, present := data["contact"]
if present {
contact, ok = rawContact.([]string)
listContact, ok := rawContact.([]interface{})
if !ok {
return nil, fmt.Errorf("invalid type for field 'contact': %w", ErrMalformed)
return nil, fmt.Errorf("invalid type (%T) for field 'contact': %w", rawContact, ErrMalformed)
}
for index, singleContact := range listContact {
contact, ok := singleContact.(string)
if !ok {
return nil, fmt.Errorf("invalid type (%T) for field 'contact' item %d: %w", singleContact, index, ErrMalformed)
}
contacts = append(contacts, contact)
}
}
@ -120,7 +193,7 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request,
if present {
termsOfServiceAgreed, ok = rawTermsOfServiceAgreed.(bool)
if !ok {
return nil, fmt.Errorf("invalid type for field 'termsOfServiceAgreed': %w", ErrMalformed)
return nil, fmt.Errorf("invalid type (%T) for field 'termsOfServiceAgreed': %w", rawTermsOfServiceAgreed, ErrMalformed)
}
}
@ -128,7 +201,7 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request,
if present {
onlyReturnExisting, ok = rawOnlyReturnExisting.(bool)
if !ok {
return nil, fmt.Errorf("invalid type for field 'onlyReturnExisting': %w", ErrMalformed)
return nil, fmt.Errorf("invalid type (%T) for field 'onlyReturnExisting': %w", rawOnlyReturnExisting, ErrMalformed)
}
}
@ -139,38 +212,39 @@ func (b *backend) acmeNewAccountHandler(acmeCtx acmeContext, r *logical.Request,
return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data)
}
return b.acmeNewAccountCreateHandler(acmeCtx, r, fields, userCtx, data, contact, termsOfServiceAgreed)
return b.acmeNewAccountCreateHandler(acmeCtx, r, fields, userCtx, data, contacts, termsOfServiceAgreed)
}
func formatAccountResponse(location string, status string, contact []string) *logical.Response {
func formatAccountResponse(location string, acct *acmeAccount) *logical.Response {
resp := &logical.Response{
Data: map[string]interface{}{
"status": status,
"status": acct.Status,
"orders": location + "/orders",
},
Headers: map[string][]string{
"Location": {location},
},
}
if len(contact) > 0 {
resp.Data["contact"] = contact
if len(acct.Contact) > 0 {
resp.Data["contact"] = acct.Contact
}
resp.Headers["Location"] = []string{location}
return resp
}
func (b *backend) acmeNewAccountSearchHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
if userCtx.Existing || b.acmeState.DoesAccountExist(userCtx.Kid) {
func (b *backend) acmeNewAccountSearchHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
if userCtx.Existing || b.acmeState.DoesAccountExist(acmeCtx, userCtx.Kid) {
// This account exists; return its details. It would be slightly
// weird to specify a kid in the request (and not use an explicit
// jwk here), but we might as well support it too.
account, err := b.acmeState.LoadAccount(userCtx.Kid)
account, err := b.acmeState.LoadAccount(acmeCtx, userCtx.Kid)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid
return formatAccountResponse(location, account["status"].(string), account["contact"].([]string)), nil
return formatAccountResponse(location, account), nil
}
// Per RFC 8555 Section 7.3.1. Finding an Account URL Given a Key:
@ -181,13 +255,13 @@ func (b *backend) acmeNewAccountSearchHandler(acmeCtx acmeContext, r *logical.Re
return nil, fmt.Errorf("An account with this key does not exist: %w", ErrAccountDoesNotExist)
}
func (b *backend) acmeNewAccountCreateHandler(acmeCtx acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, contact []string, termsOfServiceAgreed bool) (*logical.Response, error) {
func (b *backend) acmeNewAccountCreateHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, contact []string, termsOfServiceAgreed bool) (*logical.Response, error) {
if userCtx.Existing {
return nil, fmt.Errorf("cannot submit to newAccount with 'kid': %w", ErrMalformed)
}
// If the account already exists, return the existing one.
if b.acmeState.DoesAccountExist(userCtx.Kid) {
if b.acmeState.DoesAccountExist(acmeCtx, userCtx.Kid) {
return b.acmeNewAccountSearchHandler(acmeCtx, r, fields, userCtx, data)
}
@ -196,11 +270,18 @@ func (b *backend) acmeNewAccountCreateHandler(acmeCtx acmeContext, r *logical.Re
return nil, fmt.Errorf("terms of service not agreed to: %w", ErrUserActionRequired)
}
account, err := b.acmeState.CreateAccount(userCtx, contact, termsOfServiceAgreed)
account, err := b.acmeState.CreateAccount(acmeCtx, userCtx, contact, termsOfServiceAgreed)
if err != nil {
return nil, fmt.Errorf("failed to create account: %w", err)
}
location := acmeCtx.baseUrl.String() + "account/" + userCtx.Kid
return formatAccountResponse(location, account["status"].(string), account["contact"].([]string)), nil
resp := formatAccountResponse(location, account)
// Per RFC 8555 Section 7.3. Account Management:
//
// > The server returns this account object in a 201 (Created) response,
// > with the account URL in a Location header field.
resp.Data[logical.HTTPStatusCode] = http.StatusCreated
return resp, nil
}

View File

@ -51,7 +51,7 @@ func patternAcmeNonce(b *backend, pattern string) *framework.Path {
}
}
func (b *backend) acmeNonceHandler(ctx acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
func (b *backend) acmeNonceHandler(ctx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
nonce, _, err := b.acmeState.GetNonce()
if err != nil {
return nil, err
@ -78,7 +78,7 @@ func (b *backend) acmeNonceHandler(ctx acmeContext, r *logical.Request, _ *frame
}, nil
}
func genAcmeLinkHeader(ctx acmeContext) []string {
func genAcmeLinkHeader(ctx *acmeContext) []string {
path := fmt.Sprintf("<%s>;rel=\"index\"", ctx.baseUrl.JoinPath("directory").String())
return []string{path}
}