diff --git a/builtin/logical/pki/ca_util.go b/builtin/logical/pki/ca_util.go index 04ec478aa..be1890984 100644 --- a/builtin/logical/pki/ca_util.go +++ b/builtin/logical/pki/ca_util.go @@ -1,15 +1,20 @@ package pki import ( + "context" + "crypto/ecdsa" + "crypto/rsa" "time" + "golang.org/x/crypto/ed25519" + "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/logical" ) -func (b *backend) getGenerationParams( - data *framework.FieldData, +func (b *backend) getGenerationParams(ctx context.Context, + data *framework.FieldData, mountPoint string, ) (exported bool, format string, role *roleEntry, errorResp *logical.Response) { exportedStr := data.Get("exported").(string) switch exportedStr { @@ -30,6 +35,8 @@ func (b *backend) getGenerationParams( return } + keyType := data.Get("key_type").(string) + keyBits := data.Get("key_bits").(int) if exportedStr == "kms" { _, okKeyType := data.Raw["key_type"] _, okKeyBits := data.Raw["key_bits"] @@ -39,12 +46,35 @@ func (b *backend) getGenerationParams( `invalid parameter for the kms path parameter, key_type nor key_bits arguments can be set in this mode`) return } + + keyId, err := getManagedKeyId(data) + if err != nil { + errorResp = logical.ErrorResponse("unable to determine managed key id") + return + } + // Determine key type and key bits from the managed public key + withManagedPKIKey(ctx, b, keyId, mountPoint, func(ctx context.Context, key logical.ManagedSigningKey) error { + pubKey, err := key.GetPublicKey(ctx) + if err != nil { + return err + } + switch pubKey.(type) { + case *rsa.PublicKey: + keyType = "rsa" + keyBits = pubKey.(*rsa.PublicKey).Size() * 8 + case *ecdsa.PublicKey: + keyType = "ec" + case *ed25519.PublicKey: + keyType = "ed25519" + } + return nil + }) } role = &roleEntry{ TTL: time.Duration(data.Get("ttl").(int)) * time.Second, - KeyType: data.Get("key_type").(string), - KeyBits: data.Get("key_bits").(int), + KeyType: keyType, + KeyBits: keyBits, SignatureBits: data.Get("signature_bits").(int), AllowLocalhost: true, AllowAnyName: true, diff --git a/builtin/logical/pki/managed_key_util.go b/builtin/logical/pki/managed_key_util.go index 8efcaefb6..c649828b9 100644 --- a/builtin/logical/pki/managed_key_util.go +++ b/builtin/logical/pki/managed_key_util.go @@ -31,3 +31,7 @@ func generateCSRBundle(_ context.Context, _ *backend, input *inputBundle, data * func parseCABundle(_ context.Context, _ *backend, _ *logical.Request, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) { return bundle.ToParsedCertBundle() } + +func withManagedPKIKey(_ context.Context, _ *backend, _ keyId, _ string, _ logical.ManagedSigningKeyConsumer) error { + return errEntOnly +} diff --git a/builtin/logical/pki/path_intermediate.go b/builtin/logical/pki/path_intermediate.go index fc578d6fe..77e4d4a54 100644 --- a/builtin/logical/pki/path_intermediate.go +++ b/builtin/logical/pki/path_intermediate.go @@ -72,7 +72,7 @@ endpoint.`, func (b *backend) pathGenerateIntermediate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { var err error - exported, format, role, errorResp := b.getGenerationParams(data) + exported, format, role, errorResp := b.getGenerationParams(ctx, data, req.MountPoint) if errorResp != nil { return errorResp, nil } diff --git a/builtin/logical/pki/path_root.go b/builtin/logical/pki/path_root.go index bf260be56..4cc607e12 100644 --- a/builtin/logical/pki/path_root.go +++ b/builtin/logical/pki/path_root.go @@ -155,7 +155,7 @@ func (b *backend) pathCAGenerateRoot(ctx context.Context, req *logical.Request, return resp, nil } - exported, format, role, errorResp := b.getGenerationParams(data) + exported, format, role, errorResp := b.getGenerationParams(ctx, data, req.MountPoint) if errorResp != nil { return errorResp, nil } diff --git a/builtin/logical/pki/util.go b/builtin/logical/pki/util.go index 566c49e22..d880f1d3f 100644 --- a/builtin/logical/pki/util.go +++ b/builtin/logical/pki/util.go @@ -4,6 +4,8 @@ import ( "fmt" "strings" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/errutil" ) @@ -28,9 +30,42 @@ func kmsRequested(input *inputBundle) bool { return exportedStr.(string) == "kms" } -func getManagedKeyNameOrUUID(input *inputBundle) (name string, UUID string, err error) { +type keyId interface { + String() string +} + +type ( + UUIDKey string + NameKey string +) + +func (u UUIDKey) String() string { + return string(u) +} + +func (n NameKey) String() string { + return string(n) +} + +// getManagedKeyId returns a NameKey or a UUIDKey, whichever was specified in the +// request API data. +func getManagedKeyId(data *framework.FieldData) (keyId, error) { + name, UUID, err := getManagedKeyNameOrUUID(data) + if err != nil { + return nil, err + } + + var keyId keyId = NameKey(name) + if len(UUID) > 0 { + keyId = UUIDKey(UUID) + } + + return keyId, nil +} + +func getManagedKeyNameOrUUID(data *framework.FieldData) (name string, UUID string, err error) { getApiData := func(argName string) (string, error) { - arg, ok := input.apiData.GetOk(argName) + arg, ok := data.GetOk(argName) if !ok { return "", nil }