From 303f59dce3a4b79fce5c6d9855b308e0904b83f4 Mon Sep 17 00:00:00 2001 From: Trishank Karthik Kuppusamy Date: Wed, 31 Aug 2022 12:27:03 -0400 Subject: [PATCH] Allow configuring the possible salt lengths for RSA PSS signatures (#16549) * accommodate salt lengths for RSA PSS * address feedback * generalise salt length to an int * fix error reporting * Revert "fix error reporting" This reverts commit 8adfc15fe3303b8fdf9f094ea246945ab1364077. * fix a faulty check * check for min/max salt lengths * stringly-typed HTTP param * unit tests for sign/verify HTTP requests also, add marshaling for both SDK and HTTP requests * randomly sample valid salt length * add changelog * add documentation --- builtin/logical/transit/path_sign_verify.go | 66 ++++- .../logical/transit/path_sign_verify_test.go | 255 ++++++++++++++++++ changelog/16549.txt | 3 + sdk/helper/keysutil/consts.go | 13 + sdk/helper/keysutil/policy.go | 105 +++++--- sdk/helper/keysutil/policy_test.go | 201 ++++++++++++++ website/content/api-docs/secret/transit.mdx | 12 + 7 files changed, 609 insertions(+), 46 deletions(-) create mode 100644 changelog/16549.txt diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index ade69530d..1f0a9f3cb 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -2,8 +2,11 @@ package transit import ( "context" + "crypto/rsa" "encoding/base64" "fmt" + "strconv" + "strings" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/errutil" @@ -131,6 +134,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`, Default: "asn1", Description: `The method by which to marshal the signature. The default is 'asn1' which is used by openssl and X.509. It can also be set to 'jws' which is used for JWT signatures; setting it to this will also cause the encoding of the signature to be url-safe base64 instead of using standard base64 encoding. Currently only valid for ECDSA P-256 key types".`, }, + + "salt_length": { + Type: framework.TypeString, + Default: "auto", + Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme. +Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -217,6 +227,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`, Default: "asn1", Description: `The method by which to unmarshal the signature when verifying. The default is 'asn1' which is used by openssl and X.509; can also be set to 'jws' which is used for JWT signatures in which case the signature is also expected to be url-safe base64 encoding instead of standard base64 encoding. Currently only valid for ECDSA P-256 key types".`, }, + + "salt_length": { + Type: framework.TypeString, + Default: "auto", + Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme. +Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -228,6 +245,33 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`, } } +func (b *backend) getSaltLength(d *framework.FieldData) (int, error) { + rawSaltLength, ok := d.GetOk("salt_length") + // This should only happen when something is wrong with the schema, + // so this is a reasonable default. + if !ok { + return rsa.PSSSaltLengthAuto, nil + } + + rawSaltLengthStr := rawSaltLength.(string) + lowerSaltLengthStr := strings.ToLower(rawSaltLengthStr) + switch lowerSaltLengthStr { + case "auto": + return rsa.PSSSaltLengthAuto, nil + case "hash": + return rsa.PSSSaltLengthEqualsHash, nil + default: + saltLengthInt, err := strconv.Atoi(lowerSaltLengthStr) + if err != nil { + return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length neither 'auto', 'hash', nor an int: %s", rawSaltLength) + } + if saltLengthInt < rsa.PSSSaltLengthEqualsHash { + return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length is invalid: %d", saltLengthInt) + } + return saltLengthInt, nil + } +} + func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) ver := d.Get("key_version").(int) @@ -252,6 +296,10 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr prehashed := d.Get("prehashed").(bool) sigAlgorithm := d.Get("signature_algorithm").(string) + saltLength, err := b.getSaltLength(d) + if err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } // Get the policy p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ @@ -330,7 +378,12 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr } } - sig, err := p.Sign(ver, context, input, hashAlgorithm, sigAlgorithm, marshaling) + sig, err := p.SignWithOptions(ver, context, input, &keysutil.SigningOptions{ + HashAlgorithm: hashAlgorithm, + Marshaling: marshaling, + SaltLength: saltLength, + SigAlgorithm: sigAlgorithm, + }) if err != nil { if batchInputRaw != nil { response[i].Error = err.Error() @@ -470,6 +523,10 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * prehashed := d.Get("prehashed").(bool) sigAlgorithm := d.Get("signature_algorithm").(string) + saltLength, err := b.getSaltLength(d) + if err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } // Get the policy p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ @@ -533,7 +590,12 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * } } - valid, err := p.VerifySignature(context, input, hashAlgorithm, sigAlgorithm, marshaling, sig) + valid, err := p.VerifySignatureWithOptions(context, input, sig, &keysutil.SigningOptions{ + HashAlgorithm: hashAlgorithm, + Marshaling: marshaling, + SaltLength: saltLength, + SigAlgorithm: sigAlgorithm, + }) if err != nil { switch err.(type) { case errutil.UserError: diff --git a/builtin/logical/transit/path_sign_verify_test.go b/builtin/logical/transit/path_sign_verify_test.go index 072f8a265..7436675c4 100644 --- a/builtin/logical/transit/path_sign_verify_test.go +++ b/builtin/logical/transit/path_sign_verify_test.go @@ -700,3 +700,258 @@ func TestTransit_SignVerify_ED25519(t *testing.T) { outcome[1].valid = false verifyRequest(req, false, outcome, "bar", goodsig, true) } + +func TestTransit_SignVerify_RSA_PSS(t *testing.T) { + t.Run("2048", func(t *testing.T) { + testTransit_SignVerify_RSA_PSS(t, 2048) + }) + t.Run("3072", func(t *testing.T) { + testTransit_SignVerify_RSA_PSS(t, 3072) + }) + t.Run("4096", func(t *testing.T) { + testTransit_SignVerify_RSA_PSS(t, 4096) + }) +} + +func testTransit_SignVerify_RSA_PSS(t *testing.T, bits int) { + b, storage := createBackendWithSysView(t) + + // First create a key + req := &logical.Request{ + Storage: storage, + Operation: logical.UpdateOperation, + Path: "keys/foo", + Data: map[string]interface{}{ + "type": fmt.Sprintf("rsa-%d", bits), + }, + } + _, err := b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + + signRequest := func(errExpected bool, postpath string) string { + t.Helper() + req.Path = "sign/foo" + postpath + resp, err := b.HandleRequest(context.Background(), req) + if err != nil && !errExpected { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if errExpected { + if !resp.IsError() { + t.Fatalf("bad: should have gotten error response: %#v", *resp) + } + return "" + } + if resp.IsError() { + t.Fatalf("bad: got error response: %#v", *resp) + } + // Since we are reusing the same request, let's clear the salt length each time. + delete(req.Data, "salt_length") + + value, ok := resp.Data["signature"] + if !ok { + t.Fatalf("no signature key found in returned data, got resp data %#v", resp.Data) + } + return value.(string) + } + + verifyRequest := func(errExpected bool, postpath, sig string) { + t.Helper() + req.Path = "verify/foo" + postpath + req.Data["signature"] = sig + resp, err := b.HandleRequest(context.Background(), req) + if err != nil { + if errExpected { + return + } + t.Fatalf("got error: %v, sig was %v", err, sig) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if resp.IsError() { + if errExpected { + return + } + t.Fatalf("bad: got error response: %#v", *resp) + } + value, ok := resp.Data["valid"] + if !ok { + t.Fatalf("no valid key found in returned data, got resp data %#v", resp.Data) + } + if !value.(bool) && !errExpected { + t.Fatalf("verification failed; req was %#v, resp is %#v", *req, *resp) + } else if value.(bool) && errExpected { + t.Fatalf("expected error and didn't get one; req was %#v, resp is %#v", *req, *resp) + } + // Since we are reusing the same request, let's clear the signature each time. + delete(req.Data, "signature") + } + + newReqData := func(hashAlgorithm string, marshalingName string) map[string]interface{} { + return map[string]interface{}{ + "input": "dGhlIHF1aWNrIGJyb3duIGZveA==", + "signature_algorithm": "pss", + "hash_algorithm": hashAlgorithm, + "marshaling_algorithm": marshalingName, + } + } + + signAndVerifyRequest := func(hashAlgorithm string, marshalingName string, signSaltLength string, signErrExpected bool, verifySaltLength string, verifyErrExpected bool) { + t.Log("\t\t\t", signSaltLength, "/", verifySaltLength) + req.Data = newReqData(hashAlgorithm, marshalingName) + + req.Data["salt_length"] = signSaltLength + t.Log("\t\t\t\t", "sign req data:", req.Data) + sig := signRequest(signErrExpected, "") + + req.Data["salt_length"] = verifySaltLength + t.Log("\t\t\t\t", "verify req data:", req.Data) + verifyRequest(verifyErrExpected, "", sig) + } + + invalidSaltLengths := []string{"bar", "-2"} + t.Log("invalidSaltLengths:", invalidSaltLengths) + + autoSaltLengths := []string{"auto", "0"} + t.Log("autoSaltLengths:", autoSaltLengths) + + hashSaltLengths := []string{"hash", "-1"} + t.Log("hashSaltLengths:", hashSaltLengths) + + positiveSaltLengths := []string{"1"} + t.Log("positiveSaltLengths:", positiveSaltLengths) + + nonAutoSaltLengths := append(hashSaltLengths, positiveSaltLengths...) + t.Log("nonAutoSaltLengths:", nonAutoSaltLengths) + + validSaltLengths := append(autoSaltLengths, nonAutoSaltLengths...) + t.Log("validSaltLengths:", validSaltLengths) + + testCombinatorics := func(hashAlgorithm string, marshalingName string) { + t.Log("\t\t", "valid", "/", "invalid salt lengths") + for _, validSaltLength := range validSaltLengths { + for _, invalidSaltLength := range invalidSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, invalidSaltLength, true) + } + } + + t.Log("\t\t", "invalid", "/", "invalid salt lengths") + for _, invalidSaltLength1 := range invalidSaltLengths { + for _, invalidSaltLength2 := range invalidSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength1, true, invalidSaltLength2, true) + } + } + + t.Log("\t\t", "invalid", "/", "valid salt lengths") + for _, invalidSaltLength := range invalidSaltLengths { + for _, validSaltLength := range validSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength, true, validSaltLength, true) + } + } + + t.Log("\t\t", "valid", "/", "valid salt lengths") + for _, validSaltLength := range validSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, validSaltLength, false) + } + + t.Log("\t\t", "hash", "/", "hash salt lengths") + for _, hashSaltLength1 := range hashSaltLengths { + for _, hashSaltLength2 := range hashSaltLengths { + if hashSaltLength1 != hashSaltLength2 { + signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength1, false, hashSaltLength2, false) + } + } + } + + t.Log("\t\t", "hash", "/", "positive salt lengths") + for _, hashSaltLength := range hashSaltLengths { + for _, positiveSaltLength := range positiveSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength, false, positiveSaltLength, true) + } + } + + t.Log("\t\t", "positive", "/", "hash salt lengths") + for _, positiveSaltLength := range positiveSaltLengths { + for _, hashSaltLength := range hashSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, positiveSaltLength, false, hashSaltLength, true) + } + } + + t.Log("\t\t", "auto", "/", "auto salt lengths") + for _, autoSaltLength1 := range autoSaltLengths { + for _, autoSaltLength2 := range autoSaltLengths { + if autoSaltLength1 != autoSaltLength2 { + signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength1, false, autoSaltLength2, false) + } + } + } + + t.Log("\t\t", "auto", "/", "non-auto salt lengths") + for _, autoSaltLength := range autoSaltLengths { + for _, nonAutoSaltLength := range nonAutoSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength, false, nonAutoSaltLength, true) + } + } + + t.Log("\t\t", "non-auto", "/", "auto salt lengths") + for _, nonAutoSaltLength := range nonAutoSaltLengths { + for _, autoSaltLength := range autoSaltLengths { + signAndVerifyRequest(hashAlgorithm, marshalingName, nonAutoSaltLength, false, autoSaltLength, false) + } + } + } + + testAutoSignAndVerify := func(hashAlgorithm string, marshalingName string) { + t.Log("\t\t", "Make a signature with an implicit, automatic salt length") + req.Data = newReqData(hashAlgorithm, marshalingName) + t.Log("\t\t\t", "sign req data:", req.Data) + sig := signRequest(false, "") + + t.Log("\t\t", "Verify it with an implicit, automatic salt length") + t.Log("\t\t\t", "verify req data:", req.Data) + verifyRequest(false, "", sig) + + t.Log("\t\t", "Verify it with an explicit, automatic salt length") + for _, autoSaltLength := range autoSaltLengths { + t.Log("\t\t\t", "auto", "/", autoSaltLength) + req.Data["salt_length"] = autoSaltLength + t.Log("\t\t\t\t", "verify req data:", req.Data) + verifyRequest(false, "", sig) + } + + t.Log("\t\t", "Try to verify it with an explicit, incorrect salt length") + for _, nonAutoSaltLength := range nonAutoSaltLengths { + t.Log("\t\t\t", "auto", "/", nonAutoSaltLength) + req.Data["salt_length"] = nonAutoSaltLength + t.Log("\t\t\t\t", "verify req data:", req.Data) + verifyRequest(true, "", sig) + } + + t.Log("\t\t", "Make a signature with an explicit, valid salt length & and verify it with an implicit, automatic salt length") + for _, validSaltLength := range validSaltLengths { + t.Log("\t\t\t", validSaltLength, "/", "auto") + + req.Data = newReqData(hashAlgorithm, marshalingName) + req.Data["salt_length"] = validSaltLength + t.Log("\t\t\t", "sign req data:", req.Data) + sig := signRequest(false, "") + + t.Log("\t\t\t", "verify req data:", req.Data) + verifyRequest(false, "", sig) + } + } + + for hashAlgorithm := range keysutil.HashTypeMap { + t.Log("Hash algorithm:", hashAlgorithm) + for marshalingName := range keysutil.MarshalingTypeMap { + t.Log("\t", "Marshaling type:", marshalingName) + testCombinatorics(hashAlgorithm, marshalingName) + testAutoSignAndVerify(hashAlgorithm, marshalingName) + } + } +} diff --git a/changelog/16549.txt b/changelog/16549.txt new file mode 100644 index 000000000..101d1f924 --- /dev/null +++ b/changelog/16549.txt @@ -0,0 +1,3 @@ +```release-note:improvement +secrets/transit: Allow configuring the possible salt lengths for RSA PSS signatures. +``` \ No newline at end of file diff --git a/sdk/helper/keysutil/consts.go b/sdk/helper/keysutil/consts.go index 2a83ab849..cbb812351 100644 --- a/sdk/helper/keysutil/consts.go +++ b/sdk/helper/keysutil/consts.go @@ -1,6 +1,7 @@ package keysutil import ( + "crypto" "crypto/sha1" "crypto/sha256" "crypto/sha512" @@ -57,6 +58,18 @@ var ( HashTypeSHA3512: sha3.New512, } + CryptoHashMap = map[HashType]crypto.Hash{ + HashTypeSHA1: crypto.SHA1, + HashTypeSHA2224: crypto.SHA224, + HashTypeSHA2256: crypto.SHA256, + HashTypeSHA2384: crypto.SHA384, + HashTypeSHA2512: crypto.SHA512, + HashTypeSHA3224: crypto.SHA3_224, + HashTypeSHA3256: crypto.SHA3_256, + HashTypeSHA3384: crypto.SHA3_384, + HashTypeSHA3512: crypto.SHA3_512, + } + MarshalingTypeMap = map[string]MarshalingType{ "asn1": MarshalingTypeASN1, "jws": MarshalingTypeJWS, diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index 914b87988..86b5e0e49 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -80,6 +80,13 @@ type BackupInfo struct { Version int `json:"version"` } +type SigningOptions struct { + HashAlgorithm HashType + Marshaling MarshalingType + SaltLength int + SigAlgorithm string +} + type SigningResult struct { Signature string PublicKey []byte @@ -1026,6 +1033,29 @@ func (p *Policy) HMACKey(version int) ([]byte, error) { } func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) { + return p.SignWithOptions(ver, context, input, &SigningOptions{ + HashAlgorithm: hashAlgorithm, + Marshaling: marshaling, + SaltLength: rsa.PSSSaltLengthAuto, + SigAlgorithm: sigAlgorithm, + }) +} + +func (p *Policy) minRSAPSSSaltLength() int { + // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=247 + return rsa.PSSSaltLengthEqualsHash +} + +func (p *Policy) maxRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash) int { + // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288 + return (priv.N.BitLen()-1+7)/8 - 2 - hash.Size() +} + +func (p *Policy) validRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash, saltLength int) bool { + return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(priv, hash) +} + +func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) { if !p.Type.SigningSupported() { return nil, fmt.Errorf("message signing not supported for key type %v", p.Type) } @@ -1049,6 +1079,11 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si return nil, err } + hashAlgorithm := options.HashAlgorithm + marshaling := options.Marshaling + saltLength := options.SaltLength + sigAlgorithm := options.SigAlgorithm + switch p.Type { case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521: var curveBits int @@ -1139,27 +1174,8 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096: key := keyParams.RSAKey - var algo crypto.Hash - switch hashAlgorithm { - case HashTypeSHA1: - algo = crypto.SHA1 - case HashTypeSHA2224: - algo = crypto.SHA224 - case HashTypeSHA2256: - algo = crypto.SHA256 - case HashTypeSHA2384: - algo = crypto.SHA384 - case HashTypeSHA2512: - algo = crypto.SHA512 - case HashTypeSHA3224: - algo = crypto.SHA3_224 - case HashTypeSHA3256: - algo = crypto.SHA3_256 - case HashTypeSHA3384: - algo = crypto.SHA3_384 - case HashTypeSHA3512: - algo = crypto.SHA3_512 - default: + algo, ok := CryptoHashMap[hashAlgorithm] + if !ok { return nil, errutil.InternalError{Err: "unsupported hash algorithm"} } @@ -1169,7 +1185,10 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si switch sigAlgorithm { case "pss": - sig, err = rsa.SignPSS(rand.Reader, key, algo, input, nil) + if !p.validRSAPSSSaltLength(key, algo, saltLength) { + return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)} + } + sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength}) if err != nil { return nil, err } @@ -1203,6 +1222,15 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si } func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType, sig string) (bool, error) { + return p.VerifySignatureWithOptions(context, input, sig, &SigningOptions{ + HashAlgorithm: hashAlgorithm, + Marshaling: marshaling, + SaltLength: rsa.PSSSaltLengthAuto, + SigAlgorithm: sigAlgorithm, + }) +} + +func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, options *SigningOptions) (bool, error) { if !p.Type.SigningSupported() { return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)} } @@ -1235,6 +1263,11 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, return false, errutil.UserError{Err: ErrTooOld} } + hashAlgorithm := options.HashAlgorithm + marshaling := options.Marshaling + saltLength := options.SaltLength + sigAlgorithm := options.SigAlgorithm + var sigBytes []byte switch marshaling { case MarshalingTypeASN1: @@ -1318,27 +1351,8 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, key := keyEntry.RSAKey - var algo crypto.Hash - switch hashAlgorithm { - case HashTypeSHA1: - algo = crypto.SHA1 - case HashTypeSHA2224: - algo = crypto.SHA224 - case HashTypeSHA2256: - algo = crypto.SHA256 - case HashTypeSHA2384: - algo = crypto.SHA384 - case HashTypeSHA2512: - algo = crypto.SHA512 - case HashTypeSHA3224: - algo = crypto.SHA3_224 - case HashTypeSHA3256: - algo = crypto.SHA3_256 - case HashTypeSHA3384: - algo = crypto.SHA3_384 - case HashTypeSHA3512: - algo = crypto.SHA3_512 - default: + algo, ok := CryptoHashMap[hashAlgorithm] + if !ok { return false, errutil.InternalError{Err: "unsupported hash algorithm"} } @@ -1348,7 +1362,10 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, switch sigAlgorithm { case "pss": - err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, nil) + if !p.validRSAPSSSaltLength(key, algo, saltLength) { + return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)} + } + err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength}) case "pkcs1v15": err = rsa.VerifyPKCS1v15(&key.PublicKey, algo, input, sigBytes) default: diff --git a/sdk/helper/keysutil/policy_test.go b/sdk/helper/keysutil/policy_test.go index a2d9206a8..0d111beb3 100644 --- a/sdk/helper/keysutil/policy_test.go +++ b/sdk/helper/keysutil/policy_test.go @@ -8,14 +8,19 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" + "fmt" + mathrand "math/rand" "reflect" "strconv" + "strings" "sync" "testing" "time" "golang.org/x/crypto/ed25519" + "github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" "github.com/mitchellh/copystructure" @@ -698,6 +703,26 @@ func generateTestKeys() (map[KeyType][]byte, error) { } keyMap[KeyType_RSA2048] = rsaKeyBytes + rsaKey, err = rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + return nil, err + } + rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) + if err != nil { + return nil, err + } + keyMap[KeyType_RSA3072] = rsaKeyBytes + + rsaKey, err = rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, err + } + rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) + if err != nil { + return nil, err + } + keyMap[KeyType_RSA4096] = rsaKeyBytes + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, err @@ -754,3 +779,179 @@ func BenchmarkSymmetric(b *testing.B) { } } } + +func saltOptions(options SigningOptions, saltLength int) SigningOptions { + return SigningOptions{ + HashAlgorithm: options.HashAlgorithm, + Marshaling: options.Marshaling, + SaltLength: saltLength, + SigAlgorithm: options.SigAlgorithm, + } +} + +func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { + tabs := strings.Repeat("\t", depth) + t.Log(tabs, "Manually verifying signature with options:", options) + + tabs = strings.Repeat("\t", depth+1) + verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options) + if err != nil { + t.Fatal(tabs, "❌ Failed to manually verify signature:", err) + } + if !verified { + t.Fatal(tabs, "❌ Failed to manually verify signature") + } +} + +func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { + tabs := strings.Repeat("\t", depth) + t.Log(tabs, "Automatically verifying signature with options:", options) + + tabs = strings.Repeat("\t", depth+1) + verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature) + if err != nil { + t.Fatal(tabs, "❌ Failed to automatically verify signature:", err) + } + if !verified { + t.Fatal(tabs, "❌ Failed to automatically verify signature") + } +} + +func Test_RSA_PSS(t *testing.T) { + t.Log("Testing RSA PSS") + mathrand.Seed(time.Now().UnixNano()) + + var userError errutil.UserError + ctx := context.Background() + storage := &logical.InmemStorage{} + // https://crypto.stackexchange.com/a/1222 + input := []byte("the ancients say the longer the salt, the more provable the security") + sigAlgorithm := "pss" + + tabs := make(map[int]string) + for i := 1; i <= 6; i++ { + tabs[i] = strings.Repeat("\t", i) + } + + test_RSA_PSS := func(p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, marshalingType MarshalingType) { + unsaltedOptions := SigningOptions{ + HashAlgorithm: hashType, + Marshaling: marshalingType, + SigAlgorithm: sigAlgorithm, + } + cryptoHash := CryptoHashMap[hashType] + minSaltLength := p.minRSAPSSSaltLength() + maxSaltLength := p.maxRSAPSSSaltLength(rsaKey, cryptoHash) + hash := cryptoHash.New() + hash.Write(input) + input = hash.Sum(nil) + + // 1. Make an "automatic" signature with the given key size and hash algorithm, + // but an automatically chosen salt length. + t.Log(tabs[3], "Make an automatic signature") + sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType) + if err != nil { + t.Fatal(tabs[4], "❌ Failed to automatically sign:", err) + } + + // 1.1 Verify this automatic signature using the *inferred* salt length. + autoVerify(4, t, p, input, sig, unsaltedOptions) + + // 1.2. Verify this automatic signature using the *correct, given* salt length. + manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength)) + + // 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths. + t.Log(tabs[4], "Test incorrect salt lengths") + incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1} + for _, saltLength := range incorrectSaltLengths { + t.Log(tabs[5], "Salt length:", saltLength) + saltedOptions := saltOptions(unsaltedOptions, saltLength) + + verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) + if verified { + t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err) + } + } + + // 2. Rule out boundary, invalid salt lengths. + t.Log(tabs[3], "Test invalid salt lengths") + invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1} + for _, saltLength := range invalidSaltLengths { + t.Log(tabs[4], "Salt length:", saltLength) + saltedOptions := saltOptions(unsaltedOptions, saltLength) + + // 2.1. Fail to sign. + t.Log(tabs[5], "Try to make a manual signature") + _, err := p.SignWithOptions(0, nil, input, &saltedOptions) + if !errors.As(err, &userError) { + t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) + } + + // 2.2. Fail to verify. + t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length") + _, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) + if !errors.As(err, &userError) { + t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) + } + } + + // 3. For three possible valid salt lengths... + t.Log(tabs[3], "Test three possible valid salt lengths") + midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength) + validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength} + for _, saltLength := range validSaltLengths { + t.Log(tabs[4], "Salt length:", saltLength) + saltedOptions := saltOptions(unsaltedOptions, saltLength) + + // 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length. + t.Log(tabs[5], "Make a manual signature") + sig, err := p.SignWithOptions(0, nil, input, &saltedOptions) + if err != nil { + t.Fatal(tabs[6], "❌ Failed to manually sign:", err) + } + + // 3.2. Verify this manual signature using the *correct, given* salt length. + manualVerify(6, t, p, input, sig, saltedOptions) + + // 3.3. Verify this manual signature using the *inferred* salt length. + autoVerify(6, t, p, input, sig, unsaltedOptions) + } + } + + rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096} + testKeys, err := generateTestKeys() + if err != nil { + t.Fatalf("error generating test keys: %s", err) + } + + // 1. For each standard RSA key size 2048, 3072, and 4096... + for _, rsaKeyType := range rsaKeyTypes { + t.Log("Key size: ", rsaKeyType) + p := &Policy{ + Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size + Type: rsaKeyType, + } + + rsaKeyBytes := testKeys[rsaKeyType] + err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader) + if err != nil { + t.Fatal(tabs[1], "❌ Failed to import key:", err) + } + rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes) + if err != nil { + t.Fatalf("error parsing test keys: %s", err) + } + rsaKey := rsaKeyAny.(*rsa.PrivateKey) + + // 2. For each hash algorithm... + for hashAlgorithm, hashType := range HashTypeMap { + t.Log(tabs[1], "Hash algorithm:", hashAlgorithm) + + // 3. For each marshaling type... + for marshalingName, marshalingType := range MarshalingTypeMap { + t.Log(tabs[2], "Marshaling type:", marshalingName) + test_RSA_PSS(p, rsaKey, hashType, marshalingType) + } + } + } +} diff --git a/website/content/api-docs/secret/transit.mdx b/website/content/api-docs/secret/transit.mdx index 4f426b096..65c49612e 100644 --- a/website/content/api-docs/secret/transit.mdx +++ b/website/content/api-docs/secret/transit.mdx @@ -1157,6 +1157,12 @@ supports signing. also change the output encoding to URL-safe Base64 encoding instead of standard Base64-encoding. +- `salt_length` `(string: "auto")` – The salt length used to sign. This currently only applies to the RSA PSS signature scheme. Options are: + + - `auto`: The default used by Golang (causing the salt to be as large as possible when signing) + - `hash`: Causes the salt length to equal the length of the hash used in the signature + - An integer between the minimum and the maximum permissible salt lengths for the given RSA key size. + ### Sample Request ```shell-session @@ -1328,6 +1334,12 @@ data. also expect the input encoding to URL-safe Base64 encoding instead of standard Base64-encoding. +- `salt_length` `(string: "auto")` – The salt length used to sign. This currently only applies to the RSA PSS signature scheme. Options are: + + - `auto`: The default used by Golang (causing the salt to be as large as possible when signing) + - `hash`: Causes the salt length to equal the length of the hash used in the signature + - An integer between the minimum and the maximum permissible salt lengths for the given RSA key size. + ### Sample Request ```shell-session