Allow configuring the possible salt lengths for RSA PSS signatures (#16549)

* accommodate salt lengths for RSA PSS

* address feedback

* generalise salt length to an int

* fix error reporting

* Revert "fix error reporting"

This reverts commit 8adfc15fe3303b8fdf9f094ea246945ab1364077.

* fix a faulty check

* check for min/max salt lengths

* stringly-typed HTTP param

* unit tests for sign/verify HTTP requests

also, add marshaling for both SDK and HTTP requests

* randomly sample valid salt length

* add changelog

* add documentation
This commit is contained in:
Trishank Karthik Kuppusamy 2022-08-31 12:27:03 -04:00 committed by GitHub
parent 2fb4ed211d
commit 303f59dce3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 609 additions and 46 deletions

View File

@ -2,8 +2,11 @@ package transit
import (
"context"
"crypto/rsa"
"encoding/base64"
"fmt"
"strconv"
"strings"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
@ -131,6 +134,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
Default: "asn1",
Description: `The method by which to marshal the signature. The default is 'asn1' which is used by openssl and X.509. It can also be set to 'jws' which is used for JWT signatures; setting it to this will also cause the encoding of the signature to be url-safe base64 instead of using standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
},
"salt_length": {
Type: framework.TypeString,
Default: "auto",
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -217,6 +227,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
Default: "asn1",
Description: `The method by which to unmarshal the signature when verifying. The default is 'asn1' which is used by openssl and X.509; can also be set to 'jws' which is used for JWT signatures in which case the signature is also expected to be url-safe base64 encoding instead of standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
},
"salt_length": {
Type: framework.TypeString,
Default: "auto",
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@ -228,6 +245,33 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
}
}
func (b *backend) getSaltLength(d *framework.FieldData) (int, error) {
rawSaltLength, ok := d.GetOk("salt_length")
// This should only happen when something is wrong with the schema,
// so this is a reasonable default.
if !ok {
return rsa.PSSSaltLengthAuto, nil
}
rawSaltLengthStr := rawSaltLength.(string)
lowerSaltLengthStr := strings.ToLower(rawSaltLengthStr)
switch lowerSaltLengthStr {
case "auto":
return rsa.PSSSaltLengthAuto, nil
case "hash":
return rsa.PSSSaltLengthEqualsHash, nil
default:
saltLengthInt, err := strconv.Atoi(lowerSaltLengthStr)
if err != nil {
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length neither 'auto', 'hash', nor an int: %s", rawSaltLength)
}
if saltLengthInt < rsa.PSSSaltLengthEqualsHash {
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length is invalid: %d", saltLengthInt)
}
return saltLengthInt, nil
}
}
func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
ver := d.Get("key_version").(int)
@ -252,6 +296,10 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
prehashed := d.Get("prehashed").(bool)
sigAlgorithm := d.Get("signature_algorithm").(string)
saltLength, err := b.getSaltLength(d)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
// Get the policy
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
@ -330,7 +378,12 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
}
}
sig, err := p.Sign(ver, context, input, hashAlgorithm, sigAlgorithm, marshaling)
sig, err := p.SignWithOptions(ver, context, input, &keysutil.SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: saltLength,
SigAlgorithm: sigAlgorithm,
})
if err != nil {
if batchInputRaw != nil {
response[i].Error = err.Error()
@ -470,6 +523,10 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
prehashed := d.Get("prehashed").(bool)
sigAlgorithm := d.Get("signature_algorithm").(string)
saltLength, err := b.getSaltLength(d)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
// Get the policy
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
@ -533,7 +590,12 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
}
}
valid, err := p.VerifySignature(context, input, hashAlgorithm, sigAlgorithm, marshaling, sig)
valid, err := p.VerifySignatureWithOptions(context, input, sig, &keysutil.SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: saltLength,
SigAlgorithm: sigAlgorithm,
})
if err != nil {
switch err.(type) {
case errutil.UserError:

View File

@ -700,3 +700,258 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
outcome[1].valid = false
verifyRequest(req, false, outcome, "bar", goodsig, true)
}
func TestTransit_SignVerify_RSA_PSS(t *testing.T) {
t.Run("2048", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 2048)
})
t.Run("3072", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 3072)
})
t.Run("4096", func(t *testing.T) {
testTransit_SignVerify_RSA_PSS(t, 4096)
})
}
func testTransit_SignVerify_RSA_PSS(t *testing.T, bits int) {
b, storage := createBackendWithSysView(t)
// First create a key
req := &logical.Request{
Storage: storage,
Operation: logical.UpdateOperation,
Path: "keys/foo",
Data: map[string]interface{}{
"type": fmt.Sprintf("rsa-%d", bits),
},
}
_, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
signRequest := func(errExpected bool, postpath string) string {
t.Helper()
req.Path = "sign/foo" + postpath
resp, err := b.HandleRequest(context.Background(), req)
if err != nil && !errExpected {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if errExpected {
if !resp.IsError() {
t.Fatalf("bad: should have gotten error response: %#v", *resp)
}
return ""
}
if resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
}
// Since we are reusing the same request, let's clear the salt length each time.
delete(req.Data, "salt_length")
value, ok := resp.Data["signature"]
if !ok {
t.Fatalf("no signature key found in returned data, got resp data %#v", resp.Data)
}
return value.(string)
}
verifyRequest := func(errExpected bool, postpath, sig string) {
t.Helper()
req.Path = "verify/foo" + postpath
req.Data["signature"] = sig
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
if errExpected {
return
}
t.Fatalf("got error: %v, sig was %v", err, sig)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if resp.IsError() {
if errExpected {
return
}
t.Fatalf("bad: got error response: %#v", *resp)
}
value, ok := resp.Data["valid"]
if !ok {
t.Fatalf("no valid key found in returned data, got resp data %#v", resp.Data)
}
if !value.(bool) && !errExpected {
t.Fatalf("verification failed; req was %#v, resp is %#v", *req, *resp)
} else if value.(bool) && errExpected {
t.Fatalf("expected error and didn't get one; req was %#v, resp is %#v", *req, *resp)
}
// Since we are reusing the same request, let's clear the signature each time.
delete(req.Data, "signature")
}
newReqData := func(hashAlgorithm string, marshalingName string) map[string]interface{} {
return map[string]interface{}{
"input": "dGhlIHF1aWNrIGJyb3duIGZveA==",
"signature_algorithm": "pss",
"hash_algorithm": hashAlgorithm,
"marshaling_algorithm": marshalingName,
}
}
signAndVerifyRequest := func(hashAlgorithm string, marshalingName string, signSaltLength string, signErrExpected bool, verifySaltLength string, verifyErrExpected bool) {
t.Log("\t\t\t", signSaltLength, "/", verifySaltLength)
req.Data = newReqData(hashAlgorithm, marshalingName)
req.Data["salt_length"] = signSaltLength
t.Log("\t\t\t\t", "sign req data:", req.Data)
sig := signRequest(signErrExpected, "")
req.Data["salt_length"] = verifySaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(verifyErrExpected, "", sig)
}
invalidSaltLengths := []string{"bar", "-2"}
t.Log("invalidSaltLengths:", invalidSaltLengths)
autoSaltLengths := []string{"auto", "0"}
t.Log("autoSaltLengths:", autoSaltLengths)
hashSaltLengths := []string{"hash", "-1"}
t.Log("hashSaltLengths:", hashSaltLengths)
positiveSaltLengths := []string{"1"}
t.Log("positiveSaltLengths:", positiveSaltLengths)
nonAutoSaltLengths := append(hashSaltLengths, positiveSaltLengths...)
t.Log("nonAutoSaltLengths:", nonAutoSaltLengths)
validSaltLengths := append(autoSaltLengths, nonAutoSaltLengths...)
t.Log("validSaltLengths:", validSaltLengths)
testCombinatorics := func(hashAlgorithm string, marshalingName string) {
t.Log("\t\t", "valid", "/", "invalid salt lengths")
for _, validSaltLength := range validSaltLengths {
for _, invalidSaltLength := range invalidSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, invalidSaltLength, true)
}
}
t.Log("\t\t", "invalid", "/", "invalid salt lengths")
for _, invalidSaltLength1 := range invalidSaltLengths {
for _, invalidSaltLength2 := range invalidSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength1, true, invalidSaltLength2, true)
}
}
t.Log("\t\t", "invalid", "/", "valid salt lengths")
for _, invalidSaltLength := range invalidSaltLengths {
for _, validSaltLength := range validSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength, true, validSaltLength, true)
}
}
t.Log("\t\t", "valid", "/", "valid salt lengths")
for _, validSaltLength := range validSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, validSaltLength, false)
}
t.Log("\t\t", "hash", "/", "hash salt lengths")
for _, hashSaltLength1 := range hashSaltLengths {
for _, hashSaltLength2 := range hashSaltLengths {
if hashSaltLength1 != hashSaltLength2 {
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength1, false, hashSaltLength2, false)
}
}
}
t.Log("\t\t", "hash", "/", "positive salt lengths")
for _, hashSaltLength := range hashSaltLengths {
for _, positiveSaltLength := range positiveSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength, false, positiveSaltLength, true)
}
}
t.Log("\t\t", "positive", "/", "hash salt lengths")
for _, positiveSaltLength := range positiveSaltLengths {
for _, hashSaltLength := range hashSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, positiveSaltLength, false, hashSaltLength, true)
}
}
t.Log("\t\t", "auto", "/", "auto salt lengths")
for _, autoSaltLength1 := range autoSaltLengths {
for _, autoSaltLength2 := range autoSaltLengths {
if autoSaltLength1 != autoSaltLength2 {
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength1, false, autoSaltLength2, false)
}
}
}
t.Log("\t\t", "auto", "/", "non-auto salt lengths")
for _, autoSaltLength := range autoSaltLengths {
for _, nonAutoSaltLength := range nonAutoSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength, false, nonAutoSaltLength, true)
}
}
t.Log("\t\t", "non-auto", "/", "auto salt lengths")
for _, nonAutoSaltLength := range nonAutoSaltLengths {
for _, autoSaltLength := range autoSaltLengths {
signAndVerifyRequest(hashAlgorithm, marshalingName, nonAutoSaltLength, false, autoSaltLength, false)
}
}
}
testAutoSignAndVerify := func(hashAlgorithm string, marshalingName string) {
t.Log("\t\t", "Make a signature with an implicit, automatic salt length")
req.Data = newReqData(hashAlgorithm, marshalingName)
t.Log("\t\t\t", "sign req data:", req.Data)
sig := signRequest(false, "")
t.Log("\t\t", "Verify it with an implicit, automatic salt length")
t.Log("\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)
t.Log("\t\t", "Verify it with an explicit, automatic salt length")
for _, autoSaltLength := range autoSaltLengths {
t.Log("\t\t\t", "auto", "/", autoSaltLength)
req.Data["salt_length"] = autoSaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)
}
t.Log("\t\t", "Try to verify it with an explicit, incorrect salt length")
for _, nonAutoSaltLength := range nonAutoSaltLengths {
t.Log("\t\t\t", "auto", "/", nonAutoSaltLength)
req.Data["salt_length"] = nonAutoSaltLength
t.Log("\t\t\t\t", "verify req data:", req.Data)
verifyRequest(true, "", sig)
}
t.Log("\t\t", "Make a signature with an explicit, valid salt length & and verify it with an implicit, automatic salt length")
for _, validSaltLength := range validSaltLengths {
t.Log("\t\t\t", validSaltLength, "/", "auto")
req.Data = newReqData(hashAlgorithm, marshalingName)
req.Data["salt_length"] = validSaltLength
t.Log("\t\t\t", "sign req data:", req.Data)
sig := signRequest(false, "")
t.Log("\t\t\t", "verify req data:", req.Data)
verifyRequest(false, "", sig)
}
}
for hashAlgorithm := range keysutil.HashTypeMap {
t.Log("Hash algorithm:", hashAlgorithm)
for marshalingName := range keysutil.MarshalingTypeMap {
t.Log("\t", "Marshaling type:", marshalingName)
testCombinatorics(hashAlgorithm, marshalingName)
testAutoSignAndVerify(hashAlgorithm, marshalingName)
}
}
}

3
changelog/16549.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
secrets/transit: Allow configuring the possible salt lengths for RSA PSS signatures.
```

View File

@ -1,6 +1,7 @@
package keysutil
import (
"crypto"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
@ -57,6 +58,18 @@ var (
HashTypeSHA3512: sha3.New512,
}
CryptoHashMap = map[HashType]crypto.Hash{
HashTypeSHA1: crypto.SHA1,
HashTypeSHA2224: crypto.SHA224,
HashTypeSHA2256: crypto.SHA256,
HashTypeSHA2384: crypto.SHA384,
HashTypeSHA2512: crypto.SHA512,
HashTypeSHA3224: crypto.SHA3_224,
HashTypeSHA3256: crypto.SHA3_256,
HashTypeSHA3384: crypto.SHA3_384,
HashTypeSHA3512: crypto.SHA3_512,
}
MarshalingTypeMap = map[string]MarshalingType{
"asn1": MarshalingTypeASN1,
"jws": MarshalingTypeJWS,

View File

@ -80,6 +80,13 @@ type BackupInfo struct {
Version int `json:"version"`
}
type SigningOptions struct {
HashAlgorithm HashType
Marshaling MarshalingType
SaltLength int
SigAlgorithm string
}
type SigningResult struct {
Signature string
PublicKey []byte
@ -1026,6 +1033,29 @@ func (p *Policy) HMACKey(version int) ([]byte, error) {
}
func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) {
return p.SignWithOptions(ver, context, input, &SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: rsa.PSSSaltLengthAuto,
SigAlgorithm: sigAlgorithm,
})
}
func (p *Policy) minRSAPSSSaltLength() int {
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=247
return rsa.PSSSaltLengthEqualsHash
}
func (p *Policy) maxRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash) int {
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288
return (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
}
func (p *Policy) validRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash, saltLength int) bool {
return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(priv, hash)
}
func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) {
if !p.Type.SigningSupported() {
return nil, fmt.Errorf("message signing not supported for key type %v", p.Type)
}
@ -1049,6 +1079,11 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
return nil, err
}
hashAlgorithm := options.HashAlgorithm
marshaling := options.Marshaling
saltLength := options.SaltLength
sigAlgorithm := options.SigAlgorithm
switch p.Type {
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
var curveBits int
@ -1139,27 +1174,8 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
key := keyParams.RSAKey
var algo crypto.Hash
switch hashAlgorithm {
case HashTypeSHA1:
algo = crypto.SHA1
case HashTypeSHA2224:
algo = crypto.SHA224
case HashTypeSHA2256:
algo = crypto.SHA256
case HashTypeSHA2384:
algo = crypto.SHA384
case HashTypeSHA2512:
algo = crypto.SHA512
case HashTypeSHA3224:
algo = crypto.SHA3_224
case HashTypeSHA3256:
algo = crypto.SHA3_256
case HashTypeSHA3384:
algo = crypto.SHA3_384
case HashTypeSHA3512:
algo = crypto.SHA3_512
default:
algo, ok := CryptoHashMap[hashAlgorithm]
if !ok {
return nil, errutil.InternalError{Err: "unsupported hash algorithm"}
}
@ -1169,7 +1185,10 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
switch sigAlgorithm {
case "pss":
sig, err = rsa.SignPSS(rand.Reader, key, algo, input, nil)
if !p.validRSAPSSSaltLength(key, algo, saltLength) {
return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
}
sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength})
if err != nil {
return nil, err
}
@ -1203,6 +1222,15 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
}
func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType, sig string) (bool, error) {
return p.VerifySignatureWithOptions(context, input, sig, &SigningOptions{
HashAlgorithm: hashAlgorithm,
Marshaling: marshaling,
SaltLength: rsa.PSSSaltLengthAuto,
SigAlgorithm: sigAlgorithm,
})
}
func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, options *SigningOptions) (bool, error) {
if !p.Type.SigningSupported() {
return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
}
@ -1235,6 +1263,11 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType,
return false, errutil.UserError{Err: ErrTooOld}
}
hashAlgorithm := options.HashAlgorithm
marshaling := options.Marshaling
saltLength := options.SaltLength
sigAlgorithm := options.SigAlgorithm
var sigBytes []byte
switch marshaling {
case MarshalingTypeASN1:
@ -1318,27 +1351,8 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType,
key := keyEntry.RSAKey
var algo crypto.Hash
switch hashAlgorithm {
case HashTypeSHA1:
algo = crypto.SHA1
case HashTypeSHA2224:
algo = crypto.SHA224
case HashTypeSHA2256:
algo = crypto.SHA256
case HashTypeSHA2384:
algo = crypto.SHA384
case HashTypeSHA2512:
algo = crypto.SHA512
case HashTypeSHA3224:
algo = crypto.SHA3_224
case HashTypeSHA3256:
algo = crypto.SHA3_256
case HashTypeSHA3384:
algo = crypto.SHA3_384
case HashTypeSHA3512:
algo = crypto.SHA3_512
default:
algo, ok := CryptoHashMap[hashAlgorithm]
if !ok {
return false, errutil.InternalError{Err: "unsupported hash algorithm"}
}
@ -1348,7 +1362,10 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType,
switch sigAlgorithm {
case "pss":
err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, nil)
if !p.validRSAPSSSaltLength(key, algo, saltLength) {
return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
}
err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength})
case "pkcs1v15":
err = rsa.VerifyPKCS1v15(&key.PublicKey, algo, input, sigBytes)
default:

View File

@ -8,14 +8,19 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"errors"
"fmt"
mathrand "math/rand"
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"
"golang.org/x/crypto/ed25519"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/copystructure"
@ -698,6 +703,26 @@ func generateTestKeys() (map[KeyType][]byte, error) {
}
keyMap[KeyType_RSA2048] = rsaKeyBytes
rsaKey, err = rsa.GenerateKey(rand.Reader, 3072)
if err != nil {
return nil, err
}
rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
if err != nil {
return nil, err
}
keyMap[KeyType_RSA3072] = rsaKeyBytes
rsaKey, err = rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
if err != nil {
return nil, err
}
keyMap[KeyType_RSA4096] = rsaKeyBytes
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
@ -754,3 +779,179 @@ func BenchmarkSymmetric(b *testing.B) {
}
}
}
func saltOptions(options SigningOptions, saltLength int) SigningOptions {
return SigningOptions{
HashAlgorithm: options.HashAlgorithm,
Marshaling: options.Marshaling,
SaltLength: saltLength,
SigAlgorithm: options.SigAlgorithm,
}
}
func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
tabs := strings.Repeat("\t", depth)
t.Log(tabs, "Manually verifying signature with options:", options)
tabs = strings.Repeat("\t", depth+1)
verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options)
if err != nil {
t.Fatal(tabs, "❌ Failed to manually verify signature:", err)
}
if !verified {
t.Fatal(tabs, "❌ Failed to manually verify signature")
}
}
func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
tabs := strings.Repeat("\t", depth)
t.Log(tabs, "Automatically verifying signature with options:", options)
tabs = strings.Repeat("\t", depth+1)
verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature)
if err != nil {
t.Fatal(tabs, "❌ Failed to automatically verify signature:", err)
}
if !verified {
t.Fatal(tabs, "❌ Failed to automatically verify signature")
}
}
func Test_RSA_PSS(t *testing.T) {
t.Log("Testing RSA PSS")
mathrand.Seed(time.Now().UnixNano())
var userError errutil.UserError
ctx := context.Background()
storage := &logical.InmemStorage{}
// https://crypto.stackexchange.com/a/1222
input := []byte("the ancients say the longer the salt, the more provable the security")
sigAlgorithm := "pss"
tabs := make(map[int]string)
for i := 1; i <= 6; i++ {
tabs[i] = strings.Repeat("\t", i)
}
test_RSA_PSS := func(p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, marshalingType MarshalingType) {
unsaltedOptions := SigningOptions{
HashAlgorithm: hashType,
Marshaling: marshalingType,
SigAlgorithm: sigAlgorithm,
}
cryptoHash := CryptoHashMap[hashType]
minSaltLength := p.minRSAPSSSaltLength()
maxSaltLength := p.maxRSAPSSSaltLength(rsaKey, cryptoHash)
hash := cryptoHash.New()
hash.Write(input)
input = hash.Sum(nil)
// 1. Make an "automatic" signature with the given key size and hash algorithm,
// but an automatically chosen salt length.
t.Log(tabs[3], "Make an automatic signature")
sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType)
if err != nil {
t.Fatal(tabs[4], "❌ Failed to automatically sign:", err)
}
// 1.1 Verify this automatic signature using the *inferred* salt length.
autoVerify(4, t, p, input, sig, unsaltedOptions)
// 1.2. Verify this automatic signature using the *correct, given* salt length.
manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength))
// 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths.
t.Log(tabs[4], "Test incorrect salt lengths")
incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1}
for _, saltLength := range incorrectSaltLengths {
t.Log(tabs[5], "Salt length:", saltLength)
saltedOptions := saltOptions(unsaltedOptions, saltLength)
verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
if verified {
t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err)
}
}
// 2. Rule out boundary, invalid salt lengths.
t.Log(tabs[3], "Test invalid salt lengths")
invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1}
for _, saltLength := range invalidSaltLengths {
t.Log(tabs[4], "Salt length:", saltLength)
saltedOptions := saltOptions(unsaltedOptions, saltLength)
// 2.1. Fail to sign.
t.Log(tabs[5], "Try to make a manual signature")
_, err := p.SignWithOptions(0, nil, input, &saltedOptions)
if !errors.As(err, &userError) {
t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
}
// 2.2. Fail to verify.
t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length")
_, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
if !errors.As(err, &userError) {
t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
}
}
// 3. For three possible valid salt lengths...
t.Log(tabs[3], "Test three possible valid salt lengths")
midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength)
validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength}
for _, saltLength := range validSaltLengths {
t.Log(tabs[4], "Salt length:", saltLength)
saltedOptions := saltOptions(unsaltedOptions, saltLength)
// 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length.
t.Log(tabs[5], "Make a manual signature")
sig, err := p.SignWithOptions(0, nil, input, &saltedOptions)
if err != nil {
t.Fatal(tabs[6], "❌ Failed to manually sign:", err)
}
// 3.2. Verify this manual signature using the *correct, given* salt length.
manualVerify(6, t, p, input, sig, saltedOptions)
// 3.3. Verify this manual signature using the *inferred* salt length.
autoVerify(6, t, p, input, sig, unsaltedOptions)
}
}
rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096}
testKeys, err := generateTestKeys()
if err != nil {
t.Fatalf("error generating test keys: %s", err)
}
// 1. For each standard RSA key size 2048, 3072, and 4096...
for _, rsaKeyType := range rsaKeyTypes {
t.Log("Key size: ", rsaKeyType)
p := &Policy{
Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size
Type: rsaKeyType,
}
rsaKeyBytes := testKeys[rsaKeyType]
err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader)
if err != nil {
t.Fatal(tabs[1], "❌ Failed to import key:", err)
}
rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes)
if err != nil {
t.Fatalf("error parsing test keys: %s", err)
}
rsaKey := rsaKeyAny.(*rsa.PrivateKey)
// 2. For each hash algorithm...
for hashAlgorithm, hashType := range HashTypeMap {
t.Log(tabs[1], "Hash algorithm:", hashAlgorithm)
// 3. For each marshaling type...
for marshalingName, marshalingType := range MarshalingTypeMap {
t.Log(tabs[2], "Marshaling type:", marshalingName)
test_RSA_PSS(p, rsaKey, hashType, marshalingType)
}
}
}
}

View File

@ -1157,6 +1157,12 @@ supports signing.
also change the output encoding to URL-safe Base64 encoding instead of
standard Base64-encoding.
- `salt_length` `(string: "auto")`  The salt length used to sign. This currently only applies to the RSA PSS signature scheme. Options are:
- `auto`: The default used by Golang (causing the salt to be as large as possible when signing)
- `hash`: Causes the salt length to equal the length of the hash used in the signature
- An integer between the minimum and the maximum permissible salt lengths for the given RSA key size.
### Sample Request
```shell-session
@ -1328,6 +1334,12 @@ data.
also expect the input encoding to URL-safe Base64 encoding instead of
standard Base64-encoding.
- `salt_length` `(string: "auto")`  The salt length used to sign. This currently only applies to the RSA PSS signature scheme. Options are:
- `auto`: The default used by Golang (causing the salt to be as large as possible when signing)
- `hash`: Causes the salt length to equal the length of the hash used in the signature
- An integer between the minimum and the maximum permissible salt lengths for the given RSA key size.
### Sample Request
```shell-session