Move all ACME wrappers into a dedicated go file (#20174)

* Move all ACME wrappers into a dedicated go file

 - Make it easier to figure out where the various wrappers for
   ACME exist by locating them inside a dedicated go file instead
   of spread out across the various path_acme_xxx files.

* Add missing copyright headers to PKI files
This commit is contained in:
Steven Clark 2023-04-14 14:12:31 -04:00 committed by GitHub
parent 045de8a0b6
commit 138f36eafe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 214 additions and 166 deletions

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -0,0 +1,178 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
type acmeContext struct {
baseUrl *url.URL
sc *storageContext
}
type (
acmeOperation func(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error)
acmeParsedOperation func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error)
acmeAccountRequiredOperation func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, acct *acmeAccount) (*logical.Response, error)
)
func acmeErrorWrapper(op framework.OperationFunc) framework.OperationFunc {
return func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
resp, err := op(ctx, r, data)
if err != nil {
return TranslateError(err)
}
return resp, nil
}
}
func (b *backend) acmeWrapper(op acmeOperation) framework.OperationFunc {
return acmeErrorWrapper(func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, r.Storage)
if false {
// TODO sclark: Check if ACME is enable here
return nil, fmt.Errorf("ACME is disabled in configuration: %w", ErrServerInternal)
}
baseUrl, err := getAcmeBaseUrl(sc, r.Path)
if err != nil {
return nil, err
}
acmeCtx := &acmeContext{
baseUrl: baseUrl,
sc: sc,
}
return op(acmeCtx, r, data)
})
}
func (b *backend) acmeParsedWrapper(op acmeParsedOperation) framework.OperationFunc {
return b.acmeWrapper(func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) {
user, data, err := b.acmeState.ParseRequestParams(acmeCtx, fields)
if err != nil {
return nil, err
}
resp, err := op(acmeCtx, r, fields, user, data)
// Our response handlers might not add the necessary headers.
if resp != nil {
if resp.Headers == nil {
resp.Headers = map[string][]string{}
}
if _, ok := resp.Headers["Replay-Nonce"]; !ok {
nonce, _, err := b.acmeState.GetNonce()
if err != nil {
return nil, err
}
resp.Headers["Replay-Nonce"] = []string{nonce}
}
if _, ok := resp.Headers["Link"]; !ok {
resp.Headers["Link"] = genAcmeLinkHeader(acmeCtx)
} else {
directory := genAcmeLinkHeader(acmeCtx)[0]
addDirectory := true
for _, item := range resp.Headers["Link"] {
if item == directory {
addDirectory = false
break
}
}
if addDirectory {
resp.Headers["Link"] = append(resp.Headers["Link"], directory)
}
}
// ACME responses don't understand Vault's default encoding
// format. Rather than expecting everything to handle creating
// ACME-formatted responses, do the marshaling in one place.
if _, ok := resp.Data[logical.HTTPRawBody]; !ok {
ignored_values := map[string]bool{logical.HTTPContentType: true, logical.HTTPStatusCode: true}
fields := map[string]interface{}{}
body := map[string]interface{}{
logical.HTTPContentType: "application/json",
logical.HTTPStatusCode: http.StatusOK,
}
for key, value := range resp.Data {
if _, present := ignored_values[key]; !present {
fields[key] = value
} else {
body[key] = value
}
}
rawBody, err := json.Marshal(fields)
if err != nil {
return nil, fmt.Errorf("Error marshaling JSON body: %w", err)
}
body[logical.HTTPRawBody] = rawBody
resp.Data = body
}
}
return resp, err
})
}
func (b *backend) acmeAccountRequiredWrapper(op acmeAccountRequiredOperation) framework.OperationFunc {
return b.acmeParsedWrapper(func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, uc *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
if !uc.Existing {
return nil, fmt.Errorf("cannot process request without a 'kid': %w", ErrMalformed)
}
account, err := b.acmeState.LoadAccount(acmeCtx, uc.Kid)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
if account.Status != StatusValid {
// Treating "revoked" and "deactivated" as the same here.
return nil, fmt.Errorf("%w: account in status: %s", ErrUnauthorized, account.Status)
}
return op(acmeCtx, r, fields, uc, data, account)
})
}
func getAcmeBaseUrl(sc *storageContext, path string) (*url.URL, error) {
cfg, err := sc.getClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed loading cluster config: %w", err)
}
if cfg.Path == "" {
return nil, fmt.Errorf("ACME feature requires local cluster path configuration to be set: %w", ErrServerInternal)
}
baseUrl, err := url.Parse(cfg.Path)
if err != nil {
return nil, fmt.Errorf("ACME feature a proper URL configured in local cluster path: %w", ErrServerInternal)
}
directoryPrefix := ""
lastIndex := strings.LastIndex(path, "/acme/")
if lastIndex != -1 {
directoryPrefix = path[0:lastIndex]
}
return baseUrl.JoinPath(directoryPrefix, "/acme/"), nil
}

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -1,12 +1,12 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
@ -55,71 +55,6 @@ func patternAcmeDirectory(b *backend, pattern string) *framework.Path {
}
}
type acmeOperation func(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error)
type acmeContext struct {
baseUrl *url.URL
sc *storageContext
}
func (b *backend) acmeWrapper(op acmeOperation) framework.OperationFunc {
return acmeErrorWrapper(func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, r.Storage)
if false {
// TODO sclark: Check if ACME is enable here
return nil, fmt.Errorf("ACME is disabled in configuration: %w", ErrServerInternal)
}
baseUrl, err := getAcmeBaseUrl(sc, r.Path)
if err != nil {
return nil, err
}
acmeCtx := &acmeContext{
baseUrl: baseUrl,
sc: sc,
}
return op(acmeCtx, r, data)
})
}
func getAcmeBaseUrl(sc *storageContext, path string) (*url.URL, error) {
cfg, err := sc.getClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed loading cluster config: %w", err)
}
if cfg.Path == "" {
return nil, fmt.Errorf("ACME feature requires local cluster path configuration to be set: %w", ErrServerInternal)
}
baseUrl, err := url.Parse(cfg.Path)
if err != nil {
return nil, fmt.Errorf("ACME feature a proper URL configured in local cluster path: %w", ErrServerInternal)
}
directoryPrefix := ""
lastIndex := strings.LastIndex(path, "/acme/")
if lastIndex != -1 {
directoryPrefix = path[0:lastIndex]
}
return baseUrl.JoinPath(directoryPrefix, "/acme/"), nil
}
func acmeErrorWrapper(op framework.OperationFunc) framework.OperationFunc {
return func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
resp, err := op(ctx, r, data)
if err != nil {
return TranslateError(err)
}
return resp, nil
}
}
func (b *backend) acmeDirectoryHandler(acmeCtx *acmeContext, r *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
rawBody, err := json.Marshal(map[string]interface{}{
"newNonce": acmeCtx.baseUrl.JoinPath("new-nonce").String(),

View File

@ -1,7 +1,9 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (
"encoding/json"
"fmt"
"net/http"
"strings"
@ -127,81 +129,6 @@ func patternAcmeNewAccount(b *backend, pattern string) *framework.Path {
}
}
type acmeParsedOperation func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error)
func (b *backend) acmeParsedWrapper(op acmeParsedOperation) framework.OperationFunc {
return b.acmeWrapper(func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData) (*logical.Response, error) {
user, data, err := b.acmeState.ParseRequestParams(acmeCtx, fields)
if err != nil {
return nil, err
}
resp, err := op(acmeCtx, r, fields, user, data)
// Our response handlers might not add the necessary headers.
if resp != nil {
if resp.Headers == nil {
resp.Headers = map[string][]string{}
}
if _, ok := resp.Headers["Replay-Nonce"]; !ok {
nonce, _, err := b.acmeState.GetNonce()
if err != nil {
return nil, err
}
resp.Headers["Replay-Nonce"] = []string{nonce}
}
if _, ok := resp.Headers["Link"]; !ok {
resp.Headers["Link"] = genAcmeLinkHeader(acmeCtx)
} else {
directory := genAcmeLinkHeader(acmeCtx)[0]
addDirectory := true
for _, item := range resp.Headers["Link"] {
if item == directory {
addDirectory = false
break
}
}
if addDirectory {
resp.Headers["Link"] = append(resp.Headers["Link"], directory)
}
}
// ACME responses don't understand Vault's default encoding
// format. Rather than expecting everything to handle creating
// ACME-formatted responses, do the marshaling in one place.
if _, ok := resp.Data[logical.HTTPRawBody]; !ok {
ignored_values := map[string]bool{logical.HTTPContentType: true, logical.HTTPStatusCode: true}
fields := map[string]interface{}{}
body := map[string]interface{}{
logical.HTTPContentType: "application/json",
logical.HTTPStatusCode: http.StatusOK,
}
for key, value := range resp.Data {
if _, present := ignored_values[key]; !present {
fields[key] = value
} else {
body[key] = value
}
}
rawBody, err := json.Marshal(fields)
if err != nil {
return nil, fmt.Errorf("Error marshaling JSON body: %w", err)
}
body[logical.HTTPRawBody] = rawBody
resp.Data = body
}
}
return resp, err
})
}
func (b *backend) acmeNewAccountHandler(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
// Parameters
var ok bool

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (
@ -133,28 +136,6 @@ func patternAcmeGetOrder(b *backend, pattern string) *framework.Path {
}
}
type acmeAccountRequiredOperation func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, userCtx *jwsCtx, data map[string]interface{}, acct *acmeAccount) (*logical.Response, error)
func (b *backend) acmeAccountRequiredWrapper(op acmeAccountRequiredOperation) framework.OperationFunc {
return b.acmeParsedWrapper(func(acmeCtx *acmeContext, r *logical.Request, fields *framework.FieldData, uc *jwsCtx, data map[string]interface{}) (*logical.Response, error) {
if !uc.Existing {
return nil, fmt.Errorf("cannot process request without a 'kid': %w", ErrMalformed)
}
account, err := b.acmeState.LoadAccount(acmeCtx, uc.Kid)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
if account.Status != StatusValid {
// Treating "revoked" and "deactivated" as the same here.
return nil, fmt.Errorf("%w: account status is %s", ErrUnauthorized, account.Status)
}
return op(acmeCtx, r, fields, uc, data, account)
})
}
func (b *backend) acmeGetOrderHandler(ac *acmeContext, _ *logical.Request, fields *framework.FieldData, uc *jwsCtx, _ map[string]interface{}, acct *acmeAccount) (*logical.Response, error) {
orderId := fields.Get("order_id").(string)

View File

@ -1,3 +1,6 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pki
import (