OSS side changes for PKI HSM type handling fix (#14364)
This commit is contained in:
parent
003d8fb1fe
commit
f753db2783
|
@ -1,15 +1,20 @@
|
|||
package pki
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
func (b *backend) getGenerationParams(
|
||||
data *framework.FieldData,
|
||||
func (b *backend) getGenerationParams(ctx context.Context,
|
||||
data *framework.FieldData, mountPoint string,
|
||||
) (exported bool, format string, role *roleEntry, errorResp *logical.Response) {
|
||||
exportedStr := data.Get("exported").(string)
|
||||
switch exportedStr {
|
||||
|
@ -30,6 +35,8 @@ func (b *backend) getGenerationParams(
|
|||
return
|
||||
}
|
||||
|
||||
keyType := data.Get("key_type").(string)
|
||||
keyBits := data.Get("key_bits").(int)
|
||||
if exportedStr == "kms" {
|
||||
_, okKeyType := data.Raw["key_type"]
|
||||
_, okKeyBits := data.Raw["key_bits"]
|
||||
|
@ -39,12 +46,35 @@ func (b *backend) getGenerationParams(
|
|||
`invalid parameter for the kms path parameter, key_type nor key_bits arguments can be set in this mode`)
|
||||
return
|
||||
}
|
||||
|
||||
keyId, err := getManagedKeyId(data)
|
||||
if err != nil {
|
||||
errorResp = logical.ErrorResponse("unable to determine managed key id")
|
||||
return
|
||||
}
|
||||
// Determine key type and key bits from the managed public key
|
||||
withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error {
|
||||
pubKey, err := key.GetPublicKey(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
keyType = "rsa"
|
||||
keyBits = pubKey.(*rsa.PublicKey).Size() * 8
|
||||
case *ecdsa.PublicKey:
|
||||
keyType = "ec"
|
||||
case *ed25519.PublicKey:
|
||||
keyType = "ed25519"
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
role = &roleEntry{
|
||||
TTL: time.Duration(data.Get("ttl").(int)) * time.Second,
|
||||
KeyType: data.Get("key_type").(string),
|
||||
KeyBits: data.Get("key_bits").(int),
|
||||
KeyType: keyType,
|
||||
KeyBits: keyBits,
|
||||
SignatureBits: data.Get("signature_bits").(int),
|
||||
AllowLocalhost: true,
|
||||
AllowAnyName: true,
|
||||
|
|
|
@ -31,3 +31,7 @@ func generateCSRBundle(_ context.Context, _ *backend, input *inputBundle, data *
|
|||
func parseCABundle(_ context.Context, _ *backend, _ *logical.Request, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
|
||||
return bundle.ToParsedCertBundle()
|
||||
}
|
||||
|
||||
func withManagedPKIKey(_ context.Context, _ *backend, _ keyId, _ string, _ logical.ManagedSigningKeyConsumer) error {
|
||||
return errEntOnly
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ endpoint.`,
|
|||
func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
var err error
|
||||
|
||||
exported, format, role, errorResp := b.getGenerationParams(data)
|
||||
exported, format, role, errorResp := b.getGenerationParams(ctx, data, req.MountPoint)
|
||||
if errorResp != nil {
|
||||
return errorResp, nil
|
||||
}
|
||||
|
|
|
@ -155,7 +155,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request,
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
exported, format, role, errorResp := b.getGenerationParams(data)
|
||||
exported, format, role, errorResp := b.getGenerationParams(ctx, data, req.MountPoint)
|
||||
if errorResp != nil {
|
||||
return errorResp, nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/errutil"
|
||||
)
|
||||
|
||||
|
@ -28,9 +30,42 @@ func kmsRequested(input *inputBundle) bool {
|
|||
return exportedStr.(string) == "kms"
|
||||
}
|
||||
|
||||
func getManagedKeyNameOrUUID(input *inputBundle) (name string, UUID string, err error) {
|
||||
type keyId interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
type (
|
||||
UUIDKey string
|
||||
NameKey string
|
||||
)
|
||||
|
||||
func (u UUIDKey) String() string {
|
||||
return string(u)
|
||||
}
|
||||
|
||||
func (n NameKey) String() string {
|
||||
return string(n)
|
||||
}
|
||||
|
||||
// getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the
|
||||
// request API data.
|
||||
func getManagedKeyId(data *framework.FieldData) (keyId, error) {
|
||||
name, UUID, err := getManagedKeyNameOrUUID(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var keyId keyId = NameKey(name)
|
||||
if len(UUID) > 0 {
|
||||
keyId = UUIDKey(UUID)
|
||||
}
|
||||
|
||||
return keyId, nil
|
||||
}
|
||||
|
||||
func getManagedKeyNameOrUUID(data *framework.FieldData) (name string, UUID string, err error) {
|
||||
getApiData := func(argName string) (string, error) {
|
||||
arg, ok := input.apiData.GetOk(argName)
|
||||
arg, ok := data.GetOk(argName)
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue