OSS side changes for PKI HSM type handling fix (#14364)

This commit is contained in:
Scott Miller 2022-03-03 15:30:18 -06:00 committed by GitHub
parent 003d8fb1fe
commit f753db2783
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 77 additions and 8 deletions

View File

@ -1,15 +1,20 @@
package pki
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"time"
"golang.org/x/crypto/ed25519"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/logical"
)
func (b *backend) getGenerationParams(
data *framework.FieldData,
func (b *backend) getGenerationParams(ctx context.Context,
data *framework.FieldData, mountPoint string,
) (exported bool, format string, role *roleEntry, errorResp *logical.Response) {
exportedStr := data.Get("exported").(string)
switch exportedStr {
@ -30,6 +35,8 @@ func (b *backend) getGenerationParams(
return
}
keyType := data.Get("key_type").(string)
keyBits := data.Get("key_bits").(int)
if exportedStr == "kms" {
_, okKeyType := data.Raw["key_type"]
_, okKeyBits := data.Raw["key_bits"]
@ -39,12 +46,35 @@ func (b *backend) getGenerationParams(
`invalid parameter for the kms path parameter, key_type nor key_bits arguments can be set in this mode`)
return
}
keyId, err := getManagedKeyId(data)
if err != nil {
errorResp = logical.ErrorResponse("unable to determine managed key id")
return
}
// Determine key type and key bits from the managed public key
withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
pubKey, err := key.GetPublicKey(ctx)
if err != nil {
return err
}
switch pubKey.(type) {
case *rsa.PublicKey:
keyType = "rsa"
keyBits = pubKey.(*rsa.PublicKey).Size() * 8
case *ecdsa.PublicKey:
keyType = "ec"
case *ed25519.PublicKey:
keyType = "ed25519"
}
return nil
})
}
role = &roleEntry{
TTL: time.Duration(data.Get("ttl").(int)) * time.Second,
KeyType: data.Get("key_type").(string),
KeyBits: data.Get("key_bits").(int),
KeyType: keyType,
KeyBits: keyBits,
SignatureBits: data.Get("signature_bits").(int),
AllowLocalhost: true,
AllowAnyName: true,

View File

@ -31,3 +31,7 @@ func generateCSRBundle(_ context.Context, _ *backend, input *inputBundle, data *
func parseCABundle(_ context.Context, _ *backend, _ *logical.Request, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
return bundle.ToParsedCertBundle()
}
func withManagedPKIKey(_ context.Context, _ *backend, _ keyId, _ string, _ logical.ManagedSigningKeyConsumer) error {
return errEntOnly
}

View File

@ -72,7 +72,7 @@ endpoint.`,
func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var err error
exported, format, role, errorResp := b.getGenerationParams(data)
exported, format, role, errorResp := b.getGenerationParams(ctx, data, req.MountPoint)
if errorResp != nil {
return errorResp, nil
}

View File

@ -155,7 +155,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
return resp, nil
}
exported, format, role, errorResp := b.getGenerationParams(data)
exported, format, role, errorResp := b.getGenerationParams(ctx, data, req.MountPoint)
if errorResp != nil {
return errorResp, nil
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
"strings"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
)
@ -28,9 +30,42 @@ func kmsRequested(input *inputBundle) bool {
return exportedStr.(string) == "kms"
}
func getManagedKeyNameOrUUID(input *inputBundle) (name string, UUID string, err error) {
type keyId interface {
String() string
}
type (
UUIDKey string
NameKey string
)
func (u UUIDKey) String() string {
return string(u)
}
func (n NameKey) String() string {
return string(n)
}
// getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the
// request API data.
func getManagedKeyId(data *framework.FieldData) (keyId, error) {
name, UUID, err := getManagedKeyNameOrUUID(data)
if err != nil {
return nil, err
}
var keyId keyId = NameKey(name)
if len(UUID) > 0 {
keyId = UUIDKey(UUID)
}
return keyId, nil
}
func getManagedKeyNameOrUUID(data *framework.FieldData) (name string, UUID string, err error) {
getApiData := func(argName string) (string, error) {
arg, ok := input.apiData.GetOk(argName)
arg, ok := data.GetOk(argName)
if !ok {
return "", nil
}