diff --git a/builtin/logical/pki/acme/errors.go b/builtin/logical/pki/acme/errors.go new file mode 100644 index 000000000..996c0f66a --- /dev/null +++ b/builtin/logical/pki/acme/errors.go @@ -0,0 +1,179 @@ +package acme + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/sdk/logical" +) + +// Error prefix; see RFC 8555 Section 6.7. Errors. +const ErrorPrefix = "urn:ietf:params:acme:error:" +const ErrorContentType = "application/problem+json" + +// See RFC 8555 Section 6.7. Errors. +var ErrAccountDoesNotExist = errors.New("The request specified an account that does not exist") + +var ( + ErrAlreadyRevoked = errors.New("The request specified a certificate to be revoked that has already been revoked") + ErrBadCSR = errors.New("The CSR is unacceptable (e.g., due to a short key)") + ErrBadNonce = errors.New("The client sent an unacceptable anti-replay nonce") + ErrBadPublicKey = errors.New("The JWS was signed by a public key the server does not support") + ErrBadRevocationReason = errors.New("The revocation reason provided is not allowed by the server") + ErrBadSignatureAlgorithm = errors.New("The JWS was signed with an algorithm the server does not support") + ErrCAA = errors.New("Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate") + ErrCompound = errors.New("Specific error conditions are indicated in the 'subproblems' array") + ErrConnection = errors.New("The server could not connect to validation target") + ErrDNS = errors.New("There was a problem with a DNS query during identifier validation") + ErrExternalAccountRequired = errors.New("The request must include a value for the 'externalAccountBinding' field") + ErrIncorrectResponse = errors.New("Response received didn't match the challenge's requirements") + ErrInvalidContact = errors.New("A contact URL for an account was invalid") + ErrMalformed = errors.New("The request message was malformed") + ErrOrderNotReady = errors.New("The request attempted to finalize an order that is not ready to be finalized") + ErrRateLimited = errors.New("The request exceeds a rate limit") + ErrRejectedIdentifier = errors.New("The server will not issue certificates for the identifier") + ErrServerInternal = errors.New("The server experienced an internal error") + ErrTLS = errors.New("The server received a TLS error during validation") + ErrUnauthorized = errors.New("The client lacks sufficient authorization") + ErrUnsupportedContact = errors.New("A contact URL for an account used an unsupported protocol scheme") + ErrUnsupportedIdentifier = errors.New("An identifier is of an unsupported type") + ErrUserActionRequired = errors.New("Visit the 'instance' URL and take actions specified there") +) + +// Mapping of err->name; see table in RFC 8555 Section 6.7. Errors. +var errIdMappings = map[error]string{ + ErrAccountDoesNotExist: "accountDoesNotExist", + ErrAlreadyRevoked: "alreadyRevoked", + ErrBadCSR: "badCSR", + ErrBadNonce: "badNonce", + ErrBadPublicKey: "badPublicKey", + ErrBadRevocationReason: "badRevocationReason", + ErrBadSignatureAlgorithm: "badSignatureAlgorithm", + ErrCAA: "caa", + ErrCompound: "compound", + ErrConnection: "connection", + ErrDNS: "dns", + ErrExternalAccountRequired: "externalAccountRequired", + ErrIncorrectResponse: "incorrectResponse", + ErrInvalidContact: "invalidContact", + ErrMalformed: "malformed", + ErrOrderNotReady: "orderNotReady", + ErrRateLimited: "rateLimited", + ErrRejectedIdentifier: "rejectedIdentifier", + ErrServerInternal: "serverInternal", + ErrTLS: "tls", + ErrUnauthorized: "unauthorized", + ErrUnsupportedContact: "unsupportedContact", + ErrUnsupportedIdentifier: "unsupportedIdentifier", + ErrUserActionRequired: "userActionRequired", +} + +// Mapping of err->status codes; see table in RFC 8555 Section 6.7. Errors. +var errCodeMappings = map[error]int{ + ErrAccountDoesNotExist: http.StatusNotFound, + ErrAlreadyRevoked: http.StatusBadRequest, + ErrBadCSR: http.StatusBadRequest, + ErrBadNonce: http.StatusBadRequest, + ErrBadPublicKey: http.StatusBadRequest, + ErrBadRevocationReason: http.StatusBadRequest, + ErrBadSignatureAlgorithm: http.StatusBadRequest, + ErrCAA: http.StatusForbidden, + ErrCompound: http.StatusBadRequest, + ErrConnection: http.StatusInternalServerError, + ErrDNS: http.StatusInternalServerError, + ErrExternalAccountRequired: http.StatusUnauthorized, + ErrIncorrectResponse: http.StatusBadRequest, + ErrInvalidContact: http.StatusBadRequest, + ErrMalformed: http.StatusBadRequest, + ErrOrderNotReady: http.StatusForbidden, // See RFC 8555 Section 7.4. Applying for Certificate Issuance. + ErrRateLimited: http.StatusTooManyRequests, + ErrRejectedIdentifier: http.StatusBadRequest, + ErrServerInternal: http.StatusInternalServerError, + ErrTLS: http.StatusInternalServerError, + ErrUnauthorized: http.StatusUnauthorized, + ErrUnsupportedContact: http.StatusBadRequest, + ErrUnsupportedIdentifier: http.StatusBadRequest, + ErrUserActionRequired: http.StatusUnauthorized, +} + +type ErrorResponse struct { + StatusCode int `json:"-"` + Type string `json:"type"` + Detail string `json:"detail"` + Subproblems []*ErrorResponse `json:"subproblems"` +} + +func (e *ErrorResponse) Marshal() (*logical.Response, error) { + body, err := json.Marshal(e) + if err != nil { + return nil, fmt.Errorf("failed marshalling of error response: %w", err) + } + + var resp logical.Response + resp.Data = map[string]interface{}{ + logical.HTTPContentType: ErrorContentType, + logical.HTTPRawBody: body, + logical.HTTPStatusCode: e.StatusCode, + } + + return &resp, nil +} + +func FindType(given error) (err error, id string, code int, found bool) { + for err, id = range errIdMappings { + if errors.Is(given, err) { + break + } + } + + if err == nil { + err = ErrServerInternal + id = errIdMappings[err] + } + + code = errCodeMappings[err] + + return +} + +func TranslateError(given error) (*logical.Response, error) { + if errors.Is(given, logical.ErrReadOnly) { + return nil, given + } + + // We're multierror aware here: if we're given a list of errors, assume + // they're structured so the first error is the outer error and the inner + // subproblems are subsequent in the multierror. + var remaining []error + if unwrapped, ok := given.(*multierror.Error); ok { + remaining = unwrapped.Errors[1:] + given = unwrapped.Errors[0] + } + + _, id, code, found := FindType(given) + if !found && len(remaining) > 0 { + // Translate multierrors into a generic error code. + id = errIdMappings[ErrCompound] + code = errCodeMappings[ErrCompound] + } + + var body ErrorResponse + body.Type = ErrorPrefix + id + body.Detail = given.Error() + body.StatusCode = code + + for _, subgiven := range remaining { + _, subid, _, _ := FindType(subgiven) + + var sub ErrorResponse + sub.Type = ErrorPrefix + subid + body.Detail = subgiven.Error() + + body.Subproblems = append(body.Subproblems, &sub) + } + + return body.Marshal() +} diff --git a/builtin/logical/pki/acme/jws.go b/builtin/logical/pki/acme/jws.go new file mode 100644 index 000000000..c63b23590 --- /dev/null +++ b/builtin/logical/pki/acme/jws.go @@ -0,0 +1,96 @@ +package acme + +import ( + "encoding/json" + "fmt" + + jose "gopkg.in/square/go-jose.v2" +) + +// This wraps a JWS message structure. +type JWSCtx struct { + Algo string `json:"alg"` + Kid string `json:"kid"` + jwk json.RawMessage `json:"jwk"` + Nonce string `json:"nonce"` + Url string `json:"url"` + key jose.JSONWebKey `json:"-"` +} + +func (c *JWSCtx) UnmarshalJSON(a *ACMEState, jws []byte) error { + var err error + if err = json.Unmarshal(jws, c); err != nil { + return err + } + + if c.Kid != "" && len(c.jwk) > 0 { + // See RFC 8555 Section 6.2. Request Authentication: + // + // > The "jwk" and "kid" fields are mutually exclusive. Servers MUST + // > reject requests that contain both. + return fmt.Errorf("invalid header: got both account 'kid' and 'jwk' in the same message; expected only one") + } + + if c.Kid == "" && len(c.jwk) == 0 { + // See RFC 8555 Section 6.2 Request Authorization: + // + // > Either "jwk" (JSON Web Key) or "kid" (Key ID) as specified + // > below + return fmt.Errorf("invalid header: got neither required fields of 'kid' nor 'jwk'") + } + + if c.Kid != "" { + // Load KID from storage first. + c.jwk, err = a.LoadJWK(c.Kid) + if err != nil { + return err + } + } + + if err = c.key.UnmarshalJSON(c.jwk); err != nil { + return err + } + + if !c.key.Valid() { + return fmt.Errorf("received invalid jwk") + } + + return nil +} + +func hasValues(h jose.Header) bool { + return h.KeyID != "" || h.JSONWebKey != nil || h.Algorithm != "" || h.Nonce != "" || len(h.ExtraHeaders) > 0 +} + +func (c *JWSCtx) VerifyJWS(signature string) (map[string]interface{}, error) { + sig, err := jose.ParseSigned(signature) + if err != nil { + return nil, fmt.Errorf("error parsing signature: %w", err) + } + + if len(sig.Signatures) > 1 { + // See RFC 8555 Section 6.2. Request Authentication: + // + // > The JWS MUST NOT have multiple signatures + return nil, fmt.Errorf("request had multiple signatures") + } + + if hasValues(sig.Signatures[0].Unprotected) { + // See RFC 8555 Section 6.2. Request Authentication: + // + // > The JWS Unprotected Header [RFC7515] MUST NOT be used + return nil, fmt.Errorf("request had unprotected headers") + } + + payload, err := sig.Verify(c.key) + if err != nil { + return nil, err + } + + var m map[string]interface{} + if err := json.Unmarshal(payload, &m); err != nil { + return nil, fmt.Errorf("failed to json unmarshal 'payload': %w", err) + } + + return m, nil +} diff --git a/builtin/logical/pki/acme/state.go b/builtin/logical/pki/acme/state.go new file mode 100644 index 000000000..889d75eb5 --- /dev/null +++ b/builtin/logical/pki/acme/state.go @@ -0,0 +1,153 @@ +package acme + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/vault/sdk/framework" +) + +// How long nonces are considered valid. +const nonceExpiry = 15 * time.Minute + +type ACMEState struct { + nextExpiry *atomic.Int64 + nonces *sync.Map // map[string]time.Time +} + +func NewACMEState() (*ACMEState, error) { + return &ACMEState{ + nextExpiry: new(atomic.Int64), + nonces: new(sync.Map), + }, nil +} + +func generateNonce() (string, error) { + data := make([]byte, 21) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(data), nil +} + +func (a *ACMEState) GetNonce() (string, time.Time, error) { + now := time.Now() + nonce, err := generateNonce() + if err != nil { + return "", now, err + } + + then := now.Add(nonceExpiry) + a.nonces.Store(nonce, then) + + nextExpiry := a.nextExpiry.Load() + next := time.Unix(nextExpiry, 0) + if now.After(next) || then.Before(next) { + a.nextExpiry.Store(then.Unix()) + } + + return nonce, then, nil +} + +func (a *ACMEState) RedeemNonce(nonce string) bool { + rawTimeout, present := a.nonces.LoadAndDelete(nonce) + if !present { + return false + } + + timeout := rawTimeout.(time.Time) + if time.Now().After(timeout) { + return false + } + + return true +} + +func (a *ACMEState) DoTidyNonces() { + now := time.Now() + expiry := a.nextExpiry.Load() + then := time.Unix(expiry, 0) + + if expiry == 0 || now.After(then) { + a.TidyNonces() + } +} + +func (a *ACMEState) TidyNonces() { + now := time.Now() + nextRun := now.Add(nonceExpiry) + + a.nonces.Range(func(key, value any) bool { + timeout := value.(time.Time) + if now.After(timeout) { + a.nonces.Delete(key) + } + + if timeout.Before(nextRun) { + nextRun = timeout + } + + return false /* don't quit looping */ + }) + + a.nextExpiry.Store(nextRun.Unix()) +} + +func (a *ACMEState) LoadKey(keyID string) (map[string]interface{}, error) { + // TODO + return nil, nil +} + +func (a *ACMEState) LoadJWK(keyID string) ([]byte, error) { + key, err := a.LoadKey(keyID) + if err != nil { + return nil, err + } + + jwk, present := key["jwk"] + if !present { + return nil, fmt.Errorf("malformed key entry lacks JWK") + } + + return jwk.([]byte), nil +} + +func (a *ACMEState) ParseRequestParams(data *framework.FieldData) (*JWSCtx, map[string]interface{}, error) { + var c JWSCtx + var m map[string]interface{} + + // Parse the key out. + jwkBase64 := data.Get("protected").(string) + jwkBytes, err := base64.RawURLEncoding.DecodeString(jwkBase64) + if err != nil { + return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %w", err) + } + if err = c.UnmarshalJSON(a, jwkBytes); err != nil { + return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err) + } + + // Since we already parsed the header to verify the JWS context, we + // should read and redeem the nonce here too, to avoid doing any extra + // work if it is invalid. + if !a.RedeemNonce(c.Nonce) { + return nil, nil, fmt.Errorf("invalid or reused nonce") + } + + payloadBase64 := data.Get("payload").(string) + signatureBase64 := data.Get("signature").(string) + + // go-jose only seems to support compact signature encodings. + compactSig := fmt.Sprintf("%v.%v.%v", jwkBase64, payloadBase64, signatureBase64) + m, err = c.VerifyJWS(compactSig) + if err != nil { + return nil, nil, fmt.Errorf("failed to verify signature: %w", err) + } + + return &c, m, nil +} diff --git a/builtin/logical/pki/acme/state_test.go b/builtin/logical/pki/acme/state_test.go new file mode 100644 index 000000000..e33dd91b8 --- /dev/null +++ b/builtin/logical/pki/acme/state_test.go @@ -0,0 +1,40 @@ +package acme + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAcmeNonces(t *testing.T) { + t.Parallel() + + a, err := NewACMEState() + require.NoError(t, err) + + // Simple operation should succeed. + nonce, _, err := a.GetNonce() + require.NoError(t, err) + require.NotEmpty(t, nonce) + + require.True(t, a.RedeemNonce(nonce)) + require.False(t, a.RedeemNonce(nonce)) + + // Redeeming in opposite order should work. + var nonces []string + for i := 0; i < len(nonce); i++ { + nonce, _, err = a.GetNonce() + require.NoError(t, err) + require.NotEmpty(t, nonce) + } + + for i := len(nonces) - 1; i >= 0; i-- { + nonce = nonces[i] + require.True(t, a.RedeemNonce(nonce)) + } + + for i := 0; i < len(nonces); i++ { + nonce = nonces[i] + require.False(t, a.RedeemNonce(nonce)) + } +}