diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 44d901456..3ead8960d 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -2022,3 +2022,257 @@ func TestTransitPKICSR(t *testing.T) { t.Logf("root: %v", rootCertPEM) t.Logf("leaf: %v", leafCertPEM) } + +func TestTransit_ReadPublicKeyImported(t *testing.T) { + testTransit_ReadPublicKeyImported(t, "rsa-2048") + testTransit_ReadPublicKeyImported(t, "ecdsa-p256") +} + +func testTransit_ReadPublicKeyImported(t *testing.T, keyType string) { + generateKeys(t) + b, s := createBackendWithStorage(t) + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get key + privateKey := getKey(t, keyType) + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatalf("failed to extract the public key: %s", err) + } + + // Import key + importReq := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + }, + } + importResp, err := b.HandleRequest(context.Background(), importReq) + if err != nil || (importResp != nil && importResp.IsError()) { + t.Fatalf("failed to import public key. err: %s\nresp: %#v", err, importResp) + } + + // Read key + readReq := &logical.Request{ + Operation: logical.ReadOperation, + Path: "keys/" + keyID, + Storage: s, + } + + readResp, err := b.HandleRequest(context.Background(), readReq) + if err != nil || (readResp != nil && readResp.IsError()) { + t.Fatalf("failed to read key. err: %s\nresp: %#v", err, readResp) + } +} + +func TestTransit_SignWithImportedPublicKey(t *testing.T) { + testTransit_SignWithImportedPublicKey(t, "rsa-2048") + testTransit_SignWithImportedPublicKey(t, "ecdsa-p256") +} + +func testTransit_SignWithImportedPublicKey(t *testing.T, keyType string) { + generateKeys(t) + b, s := createBackendWithStorage(t) + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get key + privateKey := getKey(t, keyType) + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatalf("failed to extract the public key: %s", err) + } + + // Import key + importReq := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + }, + } + importResp, err := b.HandleRequest(context.Background(), importReq) + if err != nil || (importResp != nil && importResp.IsError()) { + t.Fatalf("failed to import public key. err: %s\nresp: %#v", err, importResp) + } + + // Sign text + signReq := &logical.Request{ + Path: "sign/" + keyID, + Operation: logical.UpdateOperation, + Storage: s, + Data: map[string]interface{}{ + "plaintext": base64.StdEncoding.EncodeToString([]byte(testPlaintext)), + }, + } + + _, err = b.HandleRequest(context.Background(), signReq) + if err == nil { + t.Fatalf("expected error, should have failed to sign input") + } +} + +func TestTransit_VerifyWithImportedPublicKey(t *testing.T) { + generateKeys(t) + keyType := "rsa-2048" + b, s := createBackendWithStorage(t) + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get key + privateKey := getKey(t, keyType) + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatal(err) + } + + // Retrieve public wrapping key + wrappingKey, err := b.getWrappingKey(context.Background(), s) + if err != nil || wrappingKey == nil { + t.Fatalf("failed to retrieve public wrapping key: %s", err) + } + + privWrappingKey := wrappingKey.Keys[strconv.Itoa(wrappingKey.LatestVersion)].RSAKey + pubWrappingKey := &privWrappingKey.PublicKey + + // generate ciphertext + importBlob := wrapTargetKeyForImport(t, pubWrappingKey, privateKey, keyType, "SHA256") + + // Import private key + importReq := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "ciphertext": importBlob, + "type": keyType, + }, + } + importResp, err := b.HandleRequest(context.Background(), importReq) + if err != nil || (importResp != nil && importResp.IsError()) { + t.Fatalf("failed to import key. err: %s\nresp: %#v", err, importResp) + } + + // Sign text + signReq := &logical.Request{ + Storage: s, + Path: "sign/" + keyID, + Operation: logical.UpdateOperation, + Data: map[string]interface{}{ + "plaintext": base64.StdEncoding.EncodeToString([]byte(testPlaintext)), + }, + } + + signResp, err := b.HandleRequest(context.Background(), signReq) + if err != nil || (signResp != nil && signResp.IsError()) { + t.Fatalf("failed to sign plaintext. err: %s\nresp: %#v", err, signResp) + } + + // Get signature + signature := signResp.Data["signature"].(string) + + // Import new key as public key + importPubReq := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", "public-key-rsa"), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + }, + } + importPubResp, err := b.HandleRequest(context.Background(), importPubReq) + if err != nil || (importPubResp != nil && importPubResp.IsError()) { + t.Fatalf("failed to import public key. err: %s\nresp: %#v", err, importPubResp) + } + + // Verify signed text + verifyReq := &logical.Request{ + Path: "verify/public-key-rsa", + Operation: logical.UpdateOperation, + Storage: s, + Data: map[string]interface{}{ + "input": base64.StdEncoding.EncodeToString([]byte(testPlaintext)), + "signature": signature, + }, + } + + verifyResp, err := b.HandleRequest(context.Background(), verifyReq) + if err != nil || (importResp != nil && verifyResp.IsError()) { + t.Fatalf("failed to verify signed data. err: %s\nresp: %#v", err, importResp) + } +} + +func TestTransit_ExportPublicKeyImported(t *testing.T) { + testTransit_ExportPublicKeyImported(t, "rsa-2048") + testTransit_ExportPublicKeyImported(t, "ecdsa-p256") +} + +func testTransit_ExportPublicKeyImported(t *testing.T, keyType string) { + generateKeys(t) + b, s := createBackendWithStorage(t) + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get key + privateKey := getKey(t, keyType) + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatalf("failed to extract the public key: %s", err) + } + + // Import key + importReq := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + "exportable": true, + }, + } + importResp, err := b.HandleRequest(context.Background(), importReq) + if err != nil || (importResp != nil && importResp.IsError()) { + t.Fatalf("failed to import public key. err: %s\nresp: %#v", err, importResp) + } + + // Export key + exportReq := &logical.Request{ + Operation: logical.ReadOperation, + Path: fmt.Sprintf("export/signing-key/%s/latest", keyID), + Storage: s, + } + + exportResp, err := b.HandleRequest(context.Background(), exportReq) + if err != nil || (exportResp != nil && exportResp.IsError()) { + t.Fatalf("failed to export key. err: %v\nresp: %#v", err, exportResp) + } + + responseKeys, exist := exportResp.Data["keys"] + if !exist { + t.Fatal("expected response data to hold a 'keys' field") + } + + exportedKeyBytes := responseKeys.(map[string]string)["1"] + exportedKeyBlock, _ := pem.Decode([]byte(exportedKeyBytes)) + publicKeyBlock, _ := pem.Decode(publicKeyBytes) + + if !reflect.DeepEqual(publicKeyBlock.Bytes, exportedKeyBlock.Bytes) { + t.Fatal("exported key bytes should have matched with imported key") + } +} diff --git a/builtin/logical/transit/path_encrypt_test.go b/builtin/logical/transit/path_encrypt_test.go index 131517fda..b886d4fef 100644 --- a/builtin/logical/transit/path_encrypt_test.go +++ b/builtin/logical/transit/path_encrypt_test.go @@ -6,12 +6,14 @@ package transit import ( "context" "encoding/json" + "fmt" "reflect" "strings" "testing" "github.com/hashicorp/vault/sdk/helper/keysutil" + uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/logical" "github.com/mitchellh/mapstructure" ) @@ -944,3 +946,48 @@ func TestShouldWarnAboutNonceUsage(t *testing.T) { } } } + +func TestTransit_EncryptWithRSAPublicKey(t *testing.T) { + generateKeys(t) + b, s := createBackendWithStorage(t) + keyType := "rsa-2048" + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get key + privateKey := getKey(t, keyType) + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatal(err) + } + + // Import key + req := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + }, + } + _, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("failed to import public key: %s", err) + } + + req = &logical.Request{ + Operation: logical.CreateOperation, + Path: fmt.Sprintf("encrypt/%s", keyID), + Storage: s, + Data: map[string]interface{}{ + "plaintext": "bXkgc2VjcmV0IGRhdGE=", + }, + } + _, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } +} diff --git a/builtin/logical/transit/path_export.go b/builtin/logical/transit/path_export.go index 96a000ec3..a3e6fc6d2 100644 --- a/builtin/logical/transit/path_export.go +++ b/builtin/logical/transit/path_export.go @@ -7,7 +7,6 @@ import ( "context" "crypto/ecdsa" "crypto/elliptic" - "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/pem" @@ -169,7 +168,11 @@ func getExportKey(policy *keysutil.Policy, key *keysutil.KeyEntry, exportType st return strings.TrimSpace(base64.StdEncoding.EncodeToString(key.Key)), nil case keysutil.KeyType_RSA2048, keysutil.KeyType_RSA3072, keysutil.KeyType_RSA4096: - return encodeRSAPrivateKey(key.RSAKey), nil + rsaKey, err := encodeRSAPrivateKey(key) + if err != nil { + return "", err + } + return rsaKey, nil } case exportTypeSigningKey: @@ -194,23 +197,41 @@ func getExportKey(policy *keysutil.Policy, key *keysutil.KeyEntry, exportType st return strings.TrimSpace(base64.StdEncoding.EncodeToString(key.Key)), nil case keysutil.KeyType_RSA2048, keysutil.KeyType_RSA3072, keysutil.KeyType_RSA4096: - return encodeRSAPrivateKey(key.RSAKey), nil + rsaKey, err := encodeRSAPrivateKey(key) + if err != nil { + return "", err + } + return rsaKey, nil } } return "", fmt.Errorf("unknown key type %v", policy.Type) } -func encodeRSAPrivateKey(key *rsa.PrivateKey) string { +func encodeRSAPrivateKey(key *keysutil.KeyEntry) (string, error) { // When encoding PKCS1, the PEM header should be `RSA PRIVATE KEY`. When Go // has PKCS8 encoding support, we may want to change this. - derBytes := x509.MarshalPKCS1PrivateKey(key) - pemBlock := &pem.Block{ - Type: "RSA PRIVATE KEY", + var blockType string + var derBytes []byte + var err error + if !key.IsPrivateKeyMissing() { + blockType = "RSA PRIVATE KEY" + derBytes = x509.MarshalPKCS1PrivateKey(key.RSAKey) + } else { + blockType = "PUBLIC KEY" + derBytes, err = x509.MarshalPKIXPublicKey(key.RSAPublicKey) + if err != nil { + return "", err + } + } + + pemBlock := pem.Block{ + Type: blockType, Bytes: derBytes, } - pemBytes := pem.EncodeToMemory(pemBlock) - return string(pemBytes) + + pemBytes := pem.EncodeToMemory(&pemBlock) + return string(pemBytes), nil } func keyEntryToECPrivateKey(k *keysutil.KeyEntry, curve elliptic.Curve) (string, error) { @@ -218,27 +239,46 @@ func keyEntryToECPrivateKey(k *keysutil.KeyEntry, curve elliptic.Curve) (string, return "", errors.New("nil KeyEntry provided") } - privKey := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: curve, - X: k.EC_X, - Y: k.EC_Y, - }, - D: k.EC_D, - } - ecder, err := x509.MarshalECPrivateKey(privKey) - if err != nil { - return "", err - } - if ecder == nil { - return "", errors.New("no data returned when marshalling to private key") + pubKey := ecdsa.PublicKey{ + Curve: curve, + X: k.EC_X, + Y: k.EC_Y, } - block := pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: ecder, + var blockType string + var derBytes []byte + var err error + if !k.IsPrivateKeyMissing() { + blockType = "EC PRIVATE KEY" + privKey := &ecdsa.PrivateKey{ + PublicKey: pubKey, + D: k.EC_D, + } + derBytes, err = x509.MarshalECPrivateKey(privKey) + if err != nil { + return "", err + } + if derBytes == nil { + return "", errors.New("no data returned when marshalling to private key") + } + } else { + blockType = "PUBLIC KEY" + derBytes, err = x509.MarshalPKIXPublicKey(&pubKey) + if err != nil { + return "", err + } + + if derBytes == nil { + return "", errors.New("no data returned when marshalling to public key") + } } - return strings.TrimSpace(string(pem.EncodeToMemory(&block))), nil + + pemBlock := pem.Block{ + Type: blockType, + Bytes: derBytes, + } + + return strings.TrimSpace(string(pem.EncodeToMemory(&pemBlock))), nil } const pathExportHelpSyn = `Export named encryption or signing key` diff --git a/builtin/logical/transit/path_import.go b/builtin/logical/transit/path_import.go index 97e62fa2f..540c9c361 100644 --- a/builtin/logical/transit/path_import.go +++ b/builtin/logical/transit/path_import.go @@ -59,6 +59,10 @@ ephemeral AES key. Can be one of "SHA1", "SHA224", "SHA256" (default), "SHA384", Description: `The base64-encoded ciphertext of the keys. The AES key should be encrypted using OAEP with the wrapping key and then concatenated with the import key, wrapped by the AES key.`, }, + "public_key": { + Type: framework.TypeString, + Description: `The plaintext PEM public key to be imported. If "ciphertext" is set, this field is ignored.`, + }, "allow_rotation": { Type: framework.TypeBool, Description: "True if the imported key may be rotated within Vault; false otherwise.", @@ -128,12 +132,27 @@ func (b *backend) pathImportVersion() *framework.Path { Description: `The base64-encoded ciphertext of the keys. The AES key should be encrypted using OAEP with the wrapping key and then concatenated with the import key, wrapped by the AES key.`, }, + "public_key": { + Type: framework.TypeString, + Description: `The plaintext public key to be imported. If "ciphertext" is set, this field is ignored.`, + }, "hash_function": { Type: framework.TypeString, Default: "SHA256", Description: `The hash function used as a random oracle in the OAEP wrapping of the user-generated, ephemeral AES key. Can be one of "SHA1", "SHA224", "SHA256" (default), "SHA384", or "SHA512"`, }, + "bump_version": { + Type: framework.TypeBool, + Default: true, + Description: `By default, each operation will create a new key version. +If set to 'false', will try to update the 'Latest' version of the key, unless changed in the "version" field.`, + }, + "version": { + Type: framework.TypeInt, + Description: `Key version to be updated, if left empty 'Latest' version will be updated. +If "bump_version" is set to 'true', this field is ignored.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.UpdateOperation: b.pathImportVersionWrite, @@ -147,11 +166,9 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d * name := d.Get("name").(string) derived := d.Get("derived").(bool) keyType := d.Get("type").(string) - hashFnStr := d.Get("hash_function").(string) exportable := d.Get("exportable").(bool) allowPlaintextBackup := d.Get("allow_plaintext_backup").(bool) autoRotatePeriod := time.Second * time.Duration(d.Get("auto_rotate_period").(int)) - ciphertextString := d.Get("ciphertext").(string) allowRotation := d.Get("allow_rotation").(bool) // Ensure the caller didn't supply "convergent_encryption" as a field, since it's not supported on import. @@ -163,6 +180,12 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d * return nil, errors.New("allow_rotation must be set to true if auto-rotation is enabled") } + // Ensure that at least on `key` field has been set + isCiphertextSet, err := checkKeyFieldsSet(d) + if err != nil { + return nil, err + } + polReq := keysutil.PolicyRequest{ Storage: req.Storage, Name: name, @@ -171,6 +194,7 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d * AllowPlaintextBackup: allowPlaintextBackup, AutoRotatePeriod: autoRotatePeriod, AllowImportedKeyRotation: allowRotation, + IsPrivateKey: isCiphertextSet, } switch strings.ToLower(keyType) { @@ -200,11 +224,6 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d * return logical.ErrorResponse(fmt.Sprintf("unknown key type: %v", keyType)), logical.ErrInvalidRequest } - hashFn, err := parseHashFn(hashFnStr) - if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - } - p, _, err := b.GetPolicy(ctx, polReq, b.GetRandomReader()) if err != nil { return nil, err @@ -217,14 +236,9 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d * return nil, errors.New("the import path cannot be used with an existing key; use import-version to rotate an existing imported key") } - ciphertext, err := base64.StdEncoding.DecodeString(ciphertextString) + key, resp, err := b.extractKeyFromFields(ctx, req, d, polReq.KeyType, isCiphertextSet) if err != nil { - return nil, err - } - - key, err := b.decryptImportedKey(ctx, req.Storage, ciphertext, hashFn) - if err != nil { - return nil, err + return resp, err } err = b.lm.ImportPolicy(ctx, polReq, key, b.GetRandomReader()) @@ -237,20 +251,19 @@ func (b *backend) pathImportWrite(ctx context.Context, req *logical.Request, d * func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - hashFnStr := d.Get("hash_function").(string) - ciphertextString := d.Get("ciphertext").(string) + bumpVersion := d.Get("bump_version").(bool) + + isCiphertextSet, err := checkKeyFieldsSet(d) + if err != nil { + return nil, err + } polReq := keysutil.PolicyRequest{ - Storage: req.Storage, - Name: name, - Upsert: false, + Storage: req.Storage, + Name: name, + Upsert: false, + IsPrivateKey: isCiphertextSet, } - - hashFn, err := parseHashFn(hashFnStr) - if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest - } - p, _, err := b.GetPolicy(ctx, polReq, b.GetRandomReader()) if err != nil { return nil, err @@ -270,15 +283,26 @@ func (b *backend) pathImportVersionWrite(ctx context.Context, req *logical.Reque } defer p.Unlock() - ciphertext, err := base64.StdEncoding.DecodeString(ciphertextString) - if err != nil { - return nil, err + // Get param version if set else LatestVersion + versionToUpdate := p.LatestVersion + if version, ok := d.Raw["version"]; ok { + versionToUpdate = version.(int) } - importKey, err := b.decryptImportedKey(ctx, req.Storage, ciphertext, hashFn) + + key, resp, err := b.extractKeyFromFields(ctx, req, d, p.Type, isCiphertextSet) if err != nil { - return nil, err + return resp, err + } + + if bumpVersion { + err = p.ImportPublicOrPrivate(ctx, req.Storage, key, isCiphertextSet, b.GetRandomReader()) + } else { + // Check if given version can be updated given input + err := p.KeyVersionCanBeUpdated(versionToUpdate, isCiphertextSet) + if err == nil { + err = p.ImportPrivateKeyForVersion(ctx, req.Storage, versionToUpdate, key) + } } - err = p.Import(ctx, req.Storage, importKey, b.GetRandomReader()) if err != nil { return nil, err } @@ -336,6 +360,36 @@ func (b *backend) decryptImportedKey(ctx context.Context, storage logical.Storag return importKey, nil } +func (b *backend) extractKeyFromFields(ctx context.Context, req *logical.Request, d *framework.FieldData, keyType keysutil.KeyType, isPrivateKey bool) ([]byte, *logical.Response, error) { + var key []byte + if isPrivateKey { + hashFnStr := d.Get("hash_function").(string) + hashFn, err := parseHashFn(hashFnStr) + if err != nil { + return key, logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } + + ciphertextString := d.Get("ciphertext").(string) + ciphertext, err := base64.StdEncoding.DecodeString(ciphertextString) + if err != nil { + return key, nil, err + } + + key, err = b.decryptImportedKey(ctx, req.Storage, ciphertext, hashFn) + if err != nil { + return key, nil, err + } + } else { + publicKeyString := d.Get("public_key").(string) + if !keyType.ImportPublicKeySupported() { + return key, nil, errors.New("provided type does not support public_key import") + } + key = []byte(publicKeyString) + } + + return key, nil, nil +} + func parseHashFn(hashFn string) (hash.Hash, error) { switch strings.ToUpper(hashFn) { case "SHA1": @@ -353,6 +407,29 @@ func parseHashFn(hashFn string) (hash.Hash, error) { } } +// checkKeyFieldsSet: Checks which key fields are set. If both are set, an error is returned +func checkKeyFieldsSet(d *framework.FieldData) (bool, error) { + ciphertextSet := isFieldSet("ciphertext", d) + publicKeySet := isFieldSet("publicKey", d) + + if ciphertextSet && publicKeySet { + return false, errors.New("only one of the following fields, ciphertext and public_key, can be set") + } else if ciphertextSet { + return true, nil + } else { + return false, nil + } +} + +func isFieldSet(fieldName string, d *framework.FieldData) bool { + _, fieldSet := d.Raw[fieldName] + if !fieldSet { + return false + } + + return true +} + const ( pathImportWriteSyn = "Imports an externally-generated key into a new transit key" pathImportWriteDesc = "This path is used to import an externally-generated " + diff --git a/builtin/logical/transit/path_import_test.go b/builtin/logical/transit/path_import_test.go index 67b7a9ce4..a31d151c1 100644 --- a/builtin/logical/transit/path_import_test.go +++ b/builtin/logical/transit/path_import_test.go @@ -5,6 +5,7 @@ package transit import ( "context" + "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" @@ -12,6 +13,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/base64" + "encoding/pem" "fmt" "strconv" "sync" @@ -427,6 +429,70 @@ func TestTransit_Import(t *testing.T) { } }, ) + + t.Run( + "import public key ed25519", + func(t *testing.T) { + keyType := "ed25519" + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get keys + privateKey := getKey(t, keyType) + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatal(err) + } + + // Import key + req := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + }, + } + _, err = b.HandleRequest(context.Background(), req) + if err == nil { + t.Fatalf("invalid public_key import incorrectly succeeeded") + } + }) + + t.Run( + "import public key ecdsa", + func(t *testing.T) { + keyType := "ecdsa-p256" + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get keys + privateKey := getKey(t, keyType) + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatal(err) + } + + // Import key + req := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + }, + } + _, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("failed to import public key: %s", err) + } + }) } func TestTransit_ImportVersion(t *testing.T) { @@ -573,6 +639,53 @@ func TestTransit_ImportVersion(t *testing.T) { } }, ) + + t.Run( + "import rsa public key and update version with private counterpart", + func(t *testing.T) { + keyType := "rsa-2048" + keyID, err := uuid.GenerateUUID() + if err != nil { + t.Fatalf("failed to generate key ID: %s", err) + } + + // Get keys + privateKey := getKey(t, keyType) + importBlob := wrapTargetKeyForImport(t, pubWrappingKey, privateKey, keyType, "SHA256") + publicKeyBytes, err := getPublicKey(privateKey, keyType) + if err != nil { + t.Fatal(err) + } + + // Import RSA public key + req := &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import", keyID), + Data: map[string]interface{}{ + "public_key": publicKeyBytes, + "type": keyType, + }, + } + _, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("failed to import public key: %s", err) + } + + // Update version - import RSA private key + req = &logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: fmt.Sprintf("keys/%s/import_version", keyID), + Data: map[string]interface{}{ + "ciphertext": importBlob, + }, + } + _, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatalf("failed to update key: %s", err) + } + }) } func wrapTargetKeyForImport(t *testing.T, wrappingKey *rsa.PublicKey, targetKey interface{}, targetKeyType string, hashFnName string) string { @@ -663,3 +776,40 @@ func generateKey(keyType string) (interface{}, error) { return nil, fmt.Errorf("failed to generate unsupported key type: %s", keyType) } } + +func getPublicKey(privateKey crypto.PrivateKey, keyType string) ([]byte, error) { + var publicKey crypto.PublicKey + var publicKeyBytes []byte + switch keyType { + case "rsa-2048", "rsa-3072", "rsa-4096": + publicKey = privateKey.(*rsa.PrivateKey).Public() + case "ecdsa-p256", "ecdsa-p384", "ecdsa-p521": + publicKey = privateKey.(*ecdsa.PrivateKey).Public() + case "ed25519": + publicKey = privateKey.(ed25519.PrivateKey).Public() + default: + return publicKeyBytes, fmt.Errorf("failed to get public key from %s key", keyType) + } + + publicKeyBytes, err := publicKeyToBytes(publicKey) + if err != nil { + return publicKeyBytes, err + } + + return publicKeyBytes, nil +} + +func publicKeyToBytes(publicKey crypto.PublicKey) ([]byte, error) { + var publicKeyBytesPem []byte + publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return publicKeyBytesPem, fmt.Errorf("failed to marshal public key: %s", err) + } + + pemBlock := &pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + } + + return pem.EncodeToMemory(pemBlock), nil +} diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 8634d8741..65eec260a 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -5,6 +5,7 @@ package transit import ( "context" + "crypto" "crypto/elliptic" "crypto/x509" "encoding/base64" @@ -408,9 +409,15 @@ func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *f key.Name = "rsa-4096" } + var publicKey crypto.PublicKey + publicKey = v.RSAPublicKey + if !v.IsPrivateKeyMissing() { + publicKey = v.RSAKey.Public() + } + // Encode the RSA public key in PEM format to return over the // API - derBytes, err := x509.MarshalPKIXPublicKey(v.RSAKey.Public()) + derBytes, err := x509.MarshalPKIXPublicKey(publicKey) if err != nil { return nil, fmt.Errorf("error marshaling RSA public key: %w", err) } diff --git a/changelog/17934.txt b/changelog/17934.txt new file mode 100644 index 000000000..7f087a915 --- /dev/null +++ b/changelog/17934.txt @@ -0,0 +1,3 @@ +```release-note:improvement +secrets/transit: Add support to import public keys in transit engine and allow encryption and verification of signed data +``` diff --git a/sdk/helper/keysutil/lock_manager.go b/sdk/helper/keysutil/lock_manager.go index 306dd1693..6d2881e0d 100644 --- a/sdk/helper/keysutil/lock_manager.go +++ b/sdk/helper/keysutil/lock_manager.go @@ -63,6 +63,9 @@ type PolicyRequest struct { // AllowImportedKeyRotation indicates whether an imported key may be rotated by Vault AllowImportedKeyRotation bool + // Indicates whether a private or public key is imported/upserted + IsPrivateKey bool + // The UUID of the managed key, if using one ManagedKeyUUID string } @@ -511,7 +514,7 @@ func (lm *LockManager) ImportPolicy(ctx context.Context, req PolicyRequest, key } } - err = p.Import(ctx, req.Storage, key, rand) + err = p.ImportPublicOrPrivate(ctx, req.Storage, key, req.IsPrivateKey, rand) if err != nil { return fmt.Errorf("error importing key: %s", err) } diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index d5620e31c..750a63926 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -167,6 +167,14 @@ func (kt KeyType) AssociatedDataSupported() bool { return false } +func (kt KeyType) ImportPublicKeySupported() bool { + switch kt { + case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521: + return true + } + return false +} + func (kt KeyType) String() string { switch kt { case KeyType_AES128_GCM96: @@ -218,7 +226,8 @@ type KeyEntry struct { EC_Y *big.Int `json:"ec_y"` EC_D *big.Int `json:"ec_d"` - RSAKey *rsa.PrivateKey `json:"rsa_key"` + RSAKey *rsa.PrivateKey `json:"rsa_key"` + RSAPublicKey *rsa.PublicKey `json:"rsa_public_key"` // The public key in an appropriate format for the type of key FormattedPublicKey string `json:"public_key"` @@ -234,6 +243,14 @@ type KeyEntry struct { ManagedKeyUUID string `json:"managed_key_id,omitempty"` } +func (ke *KeyEntry) IsPrivateKeyMissing() bool { + if ke.RSAKey != nil || ke.EC_D != nil || len(ke.Key) != 0 { + return false + } + + return true +} + // deprecatedKeyEntryMap is used to allow JSON marshal/unmarshal type deprecatedKeyEntryMap map[int]KeyEntry @@ -969,6 +986,9 @@ func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factori return "", err } key := keyEntry.RSAKey + if key == nil { + return "", errutil.InternalError{Err: fmt.Sprintf("cannot decrypt ciphertext, key version does not have a private counterpart")} + } plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil) if err != nil { return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA decrypt the ciphertext: %v", err)} @@ -1043,13 +1063,13 @@ func (p *Policy) minRSAPSSSaltLength() int { return rsa.PSSSaltLengthEqualsHash } -func (p *Policy) maxRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash) int { +func (p *Policy) maxRSAPSSSaltLength(keyBitLen int, hash crypto.Hash) int { // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288 - return (priv.N.BitLen()-1+7)/8 - 2 - hash.Size() + return (keyBitLen-1+7)/8 - 2 - hash.Size() } -func (p *Policy) validRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash, saltLength int) bool { - return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(priv, hash) +func (p *Policy) validRSAPSSSaltLength(keyBitLen int, hash crypto.Hash, saltLength int) bool { + return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(keyBitLen, hash) } func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) { @@ -1076,6 +1096,11 @@ func (p *Policy) SignWithOptions(ver int, context, input []byte, options *Signin return nil, err } + // Before signing, check if key has its private part, if not return error + if keyParams.IsPrivateKeyMissing() { + return nil, errutil.UserError{Err: "requested version for signing does not contain a private part"} + } + hashAlgorithm := options.HashAlgorithm marshaling := options.Marshaling saltLength := options.SaltLength @@ -1182,7 +1207,7 @@ func (p *Policy) SignWithOptions(ver int, context, input []byte, options *Signin switch sigAlgorithm { case "pss": - if !p.validRSAPSSSaltLength(key, algo, saltLength) { + if !p.validRSAPSSSaltLength(key.N.BitLen(), algo, saltLength) { return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)} } sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength}) @@ -1357,8 +1382,6 @@ func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, o return false, err } - key := keyEntry.RSAKey - algo, ok := CryptoHashMap[hashAlgorithm] if !ok { return false, errutil.InternalError{Err: "unsupported hash algorithm"} @@ -1370,12 +1393,20 @@ func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, o switch sigAlgorithm { case "pss": - if !p.validRSAPSSSaltLength(key, algo, saltLength) { + publicKey := keyEntry.RSAPublicKey + if !keyEntry.IsPrivateKeyMissing() { + publicKey = &keyEntry.RSAKey.PublicKey + } + if !p.validRSAPSSSaltLength(publicKey.N.BitLen(), algo, saltLength) { return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)} } - err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength}) + err = rsa.VerifyPSS(publicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength}) case "pkcs1v15": - err = rsa.VerifyPKCS1v15(&key.PublicKey, algo, input, sigBytes) + publicKey := keyEntry.RSAPublicKey + if !keyEntry.IsPrivateKeyMissing() { + publicKey = &keyEntry.RSAKey.PublicKey + } + err = rsa.VerifyPKCS1v15(publicKey, algo, input, sigBytes) default: return false, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)} } @@ -1396,6 +1427,10 @@ func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, o } func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte, randReader io.Reader) error { + return p.ImportPublicOrPrivate(ctx, storage, key, true, randReader) +} + +func (p *Policy) ImportPublicOrPrivate(ctx context.Context, storage logical.Storage, key []byte, isPrivateKey bool, randReader io.Reader) error { now := time.Now() entry := KeyEntry{ CreationTime: now, @@ -1422,91 +1457,42 @@ func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte p.KeySize = len(key) } } else { - parsedPrivateKey, err := x509.ParsePKCS8PrivateKey(key) - if err != nil { - if strings.Contains(err.Error(), "unknown elliptic curve") { - var edErr error - parsedPrivateKey, edErr = ParsePKCS8Ed25519PrivateKey(key) - if edErr != nil { - return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %v\n - original error: %w", edErr, err) - } + var parsedKey any + var err error + if isPrivateKey { + parsedKey, err = x509.ParsePKCS8PrivateKey(key) + if err != nil { + if strings.Contains(err.Error(), "unknown elliptic curve") { + var edErr error + parsedKey, edErr = ParsePKCS8Ed25519PrivateKey(key) + if edErr != nil { + return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err) + } - // Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded! - } else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) { - var rsaErr error - parsedPrivateKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key) - if rsaErr != nil { - return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err) - } + // Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded! + } else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) { + var rsaErr error + parsedKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key) + if rsaErr != nil { + return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err) + } - // Parsing as RSA-PSS in PKCS8 succeeded! - } else { - return fmt.Errorf("error parsing asymmetric key: %s", err) + // Parsing as RSA-PSS in PKCS8 succeeded! + } else { + return fmt.Errorf("error parsing asymmetric key: %s", err) + } + } + } else { + pemBlock, _ := pem.Decode(key) + parsedKey, err = x509.ParsePKIXPublicKey(pemBlock.Bytes) + if err != nil { + return fmt.Errorf("error parsing public key: %s", err) } } - switch parsedPrivateKey.(type) { - case *ecdsa.PrivateKey: - if p.Type != KeyType_ECDSA_P256 && p.Type != KeyType_ECDSA_P384 && p.Type != KeyType_ECDSA_P521 { - return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey) - } - - ecdsaKey := parsedPrivateKey.(*ecdsa.PrivateKey) - curve := elliptic.P256() - if p.Type == KeyType_ECDSA_P384 { - curve = elliptic.P384() - } else if p.Type == KeyType_ECDSA_P521 { - curve = elliptic.P521() - } - - if ecdsaKey.Curve != curve { - return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name) - } - - entry.EC_D = ecdsaKey.D - entry.EC_X = ecdsaKey.X - entry.EC_Y = ecdsaKey.Y - derBytes, err := x509.MarshalPKIXPublicKey(ecdsaKey.Public()) - if err != nil { - return errwrap.Wrapf("error marshaling public key: {{err}}", err) - } - pemBlock := &pem.Block{ - Type: "PUBLIC KEY", - Bytes: derBytes, - } - pemBytes := pem.EncodeToMemory(pemBlock) - if pemBytes == nil || len(pemBytes) == 0 { - return fmt.Errorf("error PEM-encoding public key") - } - entry.FormattedPublicKey = string(pemBytes) - case ed25519.PrivateKey: - if p.Type != KeyType_ED25519 { - return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey) - } - privateKey := parsedPrivateKey.(ed25519.PrivateKey) - - entry.Key = privateKey - publicKey := privateKey.Public().(ed25519.PublicKey) - entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey) - case *rsa.PrivateKey: - if p.Type != KeyType_RSA2048 && p.Type != KeyType_RSA3072 && p.Type != KeyType_RSA4096 { - return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey) - } - - keyBytes := 256 - if p.Type == KeyType_RSA3072 { - keyBytes = 384 - } else if p.Type == KeyType_RSA4096 { - keyBytes = 512 - } - rsaKey := parsedPrivateKey.(*rsa.PrivateKey) - if rsaKey.Size() != keyBytes { - return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size()) - } - - entry.RSAKey = rsaKey - default: - return fmt.Errorf("invalid key type: expected %s, got %T", p.Type, parsedPrivateKey) + err = entry.parseFromKey(p.Type, parsedKey) + if err != nil { + return err } } @@ -2021,8 +2007,13 @@ func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value if err != nil { return "", err } - key := keyEntry.RSAKey - ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, &key.PublicKey, plaintext, nil) + var publicKey *rsa.PublicKey + if keyEntry.RSAKey != nil { + publicKey = &keyEntry.RSAKey.PublicKey + } else { + publicKey = keyEntry.RSAPublicKey + } + ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil) if err != nil { return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA encrypt the plaintext: %v", err)} } @@ -2067,3 +2058,163 @@ func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value return encoded, nil } + +func (p *Policy) KeyVersionCanBeUpdated(keyVersion int, isPrivateKey bool) error { + keyEntry, err := p.safeGetKeyEntry(keyVersion) + if err != nil { + return err + } + + if !p.Type.ImportPublicKeySupported() { + return errors.New("provided type does not support importing key versions") + } + + isPrivateKeyMissing := keyEntry.IsPrivateKeyMissing() + if isPrivateKeyMissing && !isPrivateKey { + return errors.New("cannot add a public key to a key version that already has a public key set") + } + + if !isPrivateKeyMissing { + return errors.New("private key imported, key version cannot be updated") + } + + return nil +} + +func (p *Policy) ImportPrivateKeyForVersion(ctx context.Context, storage logical.Storage, keyVersion int, key []byte) error { + keyEntry, err := p.safeGetKeyEntry(keyVersion) + if err != nil { + return err + } + + // Parse key + parsedPrivateKey, err := x509.ParsePKCS8PrivateKey(key) + if err != nil { + return fmt.Errorf("error parsing asymmetric key: %s", err) + } + + switch parsedPrivateKey.(type) { + case *ecdsa.PrivateKey: + ecdsaKey := parsedPrivateKey.(*ecdsa.PrivateKey) + pemBlock, _ := pem.Decode([]byte(keyEntry.FormattedPublicKey)) + publicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes) + if err != nil { + return fmt.Errorf("failed to parse key entry public key: %v", err) + } + if !publicKey.(*ecdsa.PublicKey).Equal(ecdsaKey.PublicKey) { + return fmt.Errorf("cannot import key, key pair does not match") + } + case *rsa.PrivateKey: + rsaKey := parsedPrivateKey.(*rsa.PrivateKey) + if !rsaKey.PublicKey.Equal(keyEntry.RSAPublicKey) { + return fmt.Errorf("cannot import key, key pair does not match") + } + } + + err = keyEntry.parseFromKey(p.Type, parsedPrivateKey) + if err != nil { + return err + } + + p.Keys[strconv.Itoa(keyVersion)] = keyEntry + + return p.Persist(ctx, storage) +} + +func (ke *KeyEntry) parseFromKey(PolKeyType KeyType, parsedKey any) error { + switch parsedKey.(type) { + case *ecdsa.PrivateKey, *ecdsa.PublicKey: + if PolKeyType != KeyType_ECDSA_P256 && PolKeyType != KeyType_ECDSA_P384 && PolKeyType != KeyType_ECDSA_P521 { + return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey) + } + + curve := elliptic.P256() + if PolKeyType == KeyType_ECDSA_P384 { + curve = elliptic.P384() + } else if PolKeyType == KeyType_ECDSA_P521 { + curve = elliptic.P521() + } + + var derBytes []byte + var err error + ecdsaKey, ok := parsedKey.(*ecdsa.PrivateKey) + if ok { + + if ecdsaKey.Curve != curve { + return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name) + } + + ke.EC_D = ecdsaKey.D + ke.EC_X = ecdsaKey.X + ke.EC_Y = ecdsaKey.Y + + derBytes, err = x509.MarshalPKIXPublicKey(ecdsaKey.Public()) + if err != nil { + return errwrap.Wrapf("error marshaling public key: {{err}}", err) + } + } else { + ecdsaKey := parsedKey.(*ecdsa.PublicKey) + + if ecdsaKey.Curve != curve { + return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name) + } + + ke.EC_X = ecdsaKey.X + ke.EC_Y = ecdsaKey.Y + + derBytes, err = x509.MarshalPKIXPublicKey(ecdsaKey) + if err != nil { + return errwrap.Wrapf("error marshaling public key: {{err}}", err) + } + } + + pemBlock := &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derBytes, + } + pemBytes := pem.EncodeToMemory(pemBlock) + if pemBytes == nil || len(pemBytes) == 0 { + return fmt.Errorf("error PEM-encoding public key") + } + ke.FormattedPublicKey = string(pemBytes) + case ed25519.PrivateKey: + if PolKeyType != KeyType_ED25519 { + return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey) + } + privateKey := parsedKey.(ed25519.PrivateKey) + + ke.Key = privateKey + publicKey := privateKey.Public().(ed25519.PublicKey) + ke.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey) + case *rsa.PrivateKey, *rsa.PublicKey: + if PolKeyType != KeyType_RSA2048 && PolKeyType != KeyType_RSA3072 && PolKeyType != KeyType_RSA4096 { + return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey) + } + + keyBytes := 256 + if PolKeyType == KeyType_RSA3072 { + keyBytes = 384 + } else if PolKeyType == KeyType_RSA4096 { + keyBytes = 512 + } + + rsaKey, ok := parsedKey.(*rsa.PrivateKey) + if ok { + if rsaKey.Size() != keyBytes { + return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size()) + } + ke.RSAKey = rsaKey + ke.RSAPublicKey = nil + } else { + rsaKey := parsedKey.(*rsa.PublicKey) + if rsaKey.Size() != keyBytes { + return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size()) + } + ke.RSAPublicKey = rsaKey + } + default: + return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey) + } + + return nil +} diff --git a/sdk/helper/keysutil/policy_test.go b/sdk/helper/keysutil/policy_test.go index daf19a825..f5e4d35eb 100644 --- a/sdk/helper/keysutil/policy_test.go +++ b/sdk/helper/keysutil/policy_test.go @@ -846,7 +846,7 @@ func Test_RSA_PSS(t *testing.T) { } cryptoHash := CryptoHashMap[hashType] minSaltLength := p.minRSAPSSSaltLength() - maxSaltLength := p.maxRSAPSSSaltLength(rsaKey, cryptoHash) + maxSaltLength := p.maxRSAPSSSaltLength(rsaKey.N.BitLen(), cryptoHash) hash := cryptoHash.New() hash.Write(input) input = hash.Sum(nil) diff --git a/website/content/api-docs/secret/transit.mdx b/website/content/api-docs/secret/transit.mdx index ac77b9cc0..302198ffa 100644 --- a/website/content/api-docs/secret/transit.mdx +++ b/website/content/api-docs/secret/transit.mdx @@ -109,6 +109,7 @@ $ curl \ This endpoint imports existing key material into a new transit-managed encryption key. To import key material into an existing key, see the `import_version/` endpoint. +// TODO: Has to be updated. | Method | Path | | :----- | :--------------------------- | @@ -125,7 +126,8 @@ returned by Vault and the encryption of the import key material under the provided AES key. The wrapped AES key should be the first 512 bytes of the ciphertext, and the encrypted key material should be the remaining bytes. See the BYOK section of the [Transit secrets engine documentation](/vault/docs/secrets/transit#bring-your-own-key-byok) -for more information on constructing the ciphertext. +for more information on constructing the ciphertext. If `public_key` is set, +this field is not required. - `hash_function` `(string: "SHA256")` - The hash function used for the RSA-OAEP step of creating the ciphertext. Supported hash functions are: @@ -151,6 +153,9 @@ the hash function defaults to SHA256. - `rsa-3072` - RSA with bit size of 3072 (asymmetric) - `rsa-4096` - RSA with bit size of 4096 (asymmetric) +- `public_key` `(string: "", optional)` - A plaintext PEM public key to be imported. +If `ciphertext` is set, this field is ignored. + - `allow_rotation` `(bool: false)` - If set, the imported key can be rotated within Vault by using the `rotate` endpoint. @@ -198,6 +203,7 @@ $ curl \ ## Import Key Version This endpoint imports new key material into an existing imported key. +// TODO: Has to be updated. | Method | Path | | :----- | :----------------------------------- | @@ -219,12 +225,23 @@ provided AES key. The wrapped AES key should be the first 512 bytes of the ciphertext, and the encrypted key material should be the remaining bytes. See the BYOK section of the [Transit secrets engine documentation](/vault/docs/secrets/transit#bring-your-own-key-byok) for more information on constructing the ciphertext. +// TODO: Update text - `hash_function` `(string: "SHA256")` - The hash function used for the RSA-OAEP step of creating the ciphertext. Supported hash functions are: `SHA1`, `SHA224`, `SHA256`, `SHA384`, and `SHA512`. If not specified, the hash function defaults to SHA256. +- `public_key` `(string: "", optional)` - A plaintext PEM public key to be imported. + If `ciphertext` is set, this field is ignored. + +- `bump_version` - By default, each operator will create a new key version. +If set to "false", will try to update the latest version of the key, +unless changed in parameter `version`. + +- `version` - Key version to be updated, if left empty "Latest" version will be updated. +If `bump_version` is set to "true", this field is ignored. + ### Sample Payload ```json