Allow configuring the possible salt lengths for RSA PSS signatures (#16549)
* accommodate salt lengths for RSA PSS * address feedback * generalise salt length to an int * fix error reporting * Revert "fix error reporting" This reverts commit 8adfc15fe3303b8fdf9f094ea246945ab1364077. * fix a faulty check * check for min/max salt lengths * stringly-typed HTTP param * unit tests for sign/verify HTTP requests also, add marshaling for both SDK and HTTP requests * randomly sample valid salt length * add changelog * add documentation
This commit is contained in:
parent
2fb4ed211d
commit
303f59dce3
|
@ -2,8 +2,11 @@ package transit
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/errutil"
|
||||
|
@ -131,6 +134,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
|
|||
Default: "asn1",
|
||||
Description: `The method by which to marshal the signature. The default is 'asn1' which is used by openssl and X.509. It can also be set to 'jws' which is used for JWT signatures; setting it to this will also cause the encoding of the signature to be url-safe base64 instead of using standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
|
||||
},
|
||||
|
||||
"salt_length": {
|
||||
Type: framework.TypeString,
|
||||
Default: "auto",
|
||||
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
|
||||
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
|
@ -217,6 +227,13 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
|
|||
Default: "asn1",
|
||||
Description: `The method by which to unmarshal the signature when verifying. The default is 'asn1' which is used by openssl and X.509; can also be set to 'jws' which is used for JWT signatures in which case the signature is also expected to be url-safe base64 encoding instead of standard base64 encoding. Currently only valid for ECDSA P-256 key types".`,
|
||||
},
|
||||
|
||||
"salt_length": {
|
||||
Type: framework.TypeString,
|
||||
Default: "auto",
|
||||
Description: `The salt length used to sign. Currently only applies to the RSA PSS signature scheme.
|
||||
Options are 'auto' (the default used by Golang, causing the salt to be as large as possible when signing), 'hash' (causes the salt length to equal the length of the hash used in the signature), or an integer between the minimum and the maximum permissible salt lengths for the given RSA key size. Defaults to 'auto'.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
|
@ -228,6 +245,33 @@ Options are 'pss' or 'pkcs1v15'. Defaults to 'pss'`,
|
|||
}
|
||||
}
|
||||
|
||||
func (b *backend) getSaltLength(d *framework.FieldData) (int, error) {
|
||||
rawSaltLength, ok := d.GetOk("salt_length")
|
||||
// This should only happen when something is wrong with the schema,
|
||||
// so this is a reasonable default.
|
||||
if !ok {
|
||||
return rsa.PSSSaltLengthAuto, nil
|
||||
}
|
||||
|
||||
rawSaltLengthStr := rawSaltLength.(string)
|
||||
lowerSaltLengthStr := strings.ToLower(rawSaltLengthStr)
|
||||
switch lowerSaltLengthStr {
|
||||
case "auto":
|
||||
return rsa.PSSSaltLengthAuto, nil
|
||||
case "hash":
|
||||
return rsa.PSSSaltLengthEqualsHash, nil
|
||||
default:
|
||||
saltLengthInt, err := strconv.Atoi(lowerSaltLengthStr)
|
||||
if err != nil {
|
||||
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length neither 'auto', 'hash', nor an int: %s", rawSaltLength)
|
||||
}
|
||||
if saltLengthInt < rsa.PSSSaltLengthEqualsHash {
|
||||
return rsa.PSSSaltLengthEqualsHash - 1, fmt.Errorf("salt length is invalid: %d", saltLengthInt)
|
||||
}
|
||||
return saltLengthInt, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
ver := d.Get("key_version").(int)
|
||||
|
@ -252,6 +296,10 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
|
||||
prehashed := d.Get("prehashed").(bool)
|
||||
sigAlgorithm := d.Get("signature_algorithm").(string)
|
||||
saltLength, err := b.getSaltLength(d)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
|
||||
|
@ -330,7 +378,12 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
}
|
||||
}
|
||||
|
||||
sig, err := p.Sign(ver, context, input, hashAlgorithm, sigAlgorithm, marshaling)
|
||||
sig, err := p.SignWithOptions(ver, context, input, &keysutil.SigningOptions{
|
||||
HashAlgorithm: hashAlgorithm,
|
||||
Marshaling: marshaling,
|
||||
SaltLength: saltLength,
|
||||
SigAlgorithm: sigAlgorithm,
|
||||
})
|
||||
if err != nil {
|
||||
if batchInputRaw != nil {
|
||||
response[i].Error = err.Error()
|
||||
|
@ -470,6 +523,10 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
|
|||
|
||||
prehashed := d.Get("prehashed").(bool)
|
||||
sigAlgorithm := d.Get("signature_algorithm").(string)
|
||||
saltLength, err := b.getSaltLength(d)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
|
||||
|
@ -533,7 +590,12 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
|
|||
}
|
||||
}
|
||||
|
||||
valid, err := p.VerifySignature(context, input, hashAlgorithm, sigAlgorithm, marshaling, sig)
|
||||
valid, err := p.VerifySignatureWithOptions(context, input, sig, &keysutil.SigningOptions{
|
||||
HashAlgorithm: hashAlgorithm,
|
||||
Marshaling: marshaling,
|
||||
SaltLength: saltLength,
|
||||
SigAlgorithm: sigAlgorithm,
|
||||
})
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case errutil.UserError:
|
||||
|
|
|
@ -700,3 +700,258 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
|
|||
outcome[1].valid = false
|
||||
verifyRequest(req, false, outcome, "bar", goodsig, true)
|
||||
}
|
||||
|
||||
func TestTransit_SignVerify_RSA_PSS(t *testing.T) {
|
||||
t.Run("2048", func(t *testing.T) {
|
||||
testTransit_SignVerify_RSA_PSS(t, 2048)
|
||||
})
|
||||
t.Run("3072", func(t *testing.T) {
|
||||
testTransit_SignVerify_RSA_PSS(t, 3072)
|
||||
})
|
||||
t.Run("4096", func(t *testing.T) {
|
||||
testTransit_SignVerify_RSA_PSS(t, 4096)
|
||||
})
|
||||
}
|
||||
|
||||
func testTransit_SignVerify_RSA_PSS(t *testing.T, bits int) {
|
||||
b, storage := createBackendWithSysView(t)
|
||||
|
||||
// First create a key
|
||||
req := &logical.Request{
|
||||
Storage: storage,
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "keys/foo",
|
||||
Data: map[string]interface{}{
|
||||
"type": fmt.Sprintf("rsa-%d", bits),
|
||||
},
|
||||
}
|
||||
_, err := b.HandleRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
signRequest := func(errExpected bool, postpath string) string {
|
||||
t.Helper()
|
||||
req.Path = "sign/foo" + postpath
|
||||
resp, err := b.HandleRequest(context.Background(), req)
|
||||
if err != nil && !errExpected {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
if errExpected {
|
||||
if !resp.IsError() {
|
||||
t.Fatalf("bad: should have gotten error response: %#v", *resp)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
if resp.IsError() {
|
||||
t.Fatalf("bad: got error response: %#v", *resp)
|
||||
}
|
||||
// Since we are reusing the same request, let's clear the salt length each time.
|
||||
delete(req.Data, "salt_length")
|
||||
|
||||
value, ok := resp.Data["signature"]
|
||||
if !ok {
|
||||
t.Fatalf("no signature key found in returned data, got resp data %#v", resp.Data)
|
||||
}
|
||||
return value.(string)
|
||||
}
|
||||
|
||||
verifyRequest := func(errExpected bool, postpath, sig string) {
|
||||
t.Helper()
|
||||
req.Path = "verify/foo" + postpath
|
||||
req.Data["signature"] = sig
|
||||
resp, err := b.HandleRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
if errExpected {
|
||||
return
|
||||
}
|
||||
t.Fatalf("got error: %v, sig was %v", err, sig)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
if resp.IsError() {
|
||||
if errExpected {
|
||||
return
|
||||
}
|
||||
t.Fatalf("bad: got error response: %#v", *resp)
|
||||
}
|
||||
value, ok := resp.Data["valid"]
|
||||
if !ok {
|
||||
t.Fatalf("no valid key found in returned data, got resp data %#v", resp.Data)
|
||||
}
|
||||
if !value.(bool) && !errExpected {
|
||||
t.Fatalf("verification failed; req was %#v, resp is %#v", *req, *resp)
|
||||
} else if value.(bool) && errExpected {
|
||||
t.Fatalf("expected error and didn't get one; req was %#v, resp is %#v", *req, *resp)
|
||||
}
|
||||
// Since we are reusing the same request, let's clear the signature each time.
|
||||
delete(req.Data, "signature")
|
||||
}
|
||||
|
||||
newReqData := func(hashAlgorithm string, marshalingName string) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"input": "dGhlIHF1aWNrIGJyb3duIGZveA==",
|
||||
"signature_algorithm": "pss",
|
||||
"hash_algorithm": hashAlgorithm,
|
||||
"marshaling_algorithm": marshalingName,
|
||||
}
|
||||
}
|
||||
|
||||
signAndVerifyRequest := func(hashAlgorithm string, marshalingName string, signSaltLength string, signErrExpected bool, verifySaltLength string, verifyErrExpected bool) {
|
||||
t.Log("\t\t\t", signSaltLength, "/", verifySaltLength)
|
||||
req.Data = newReqData(hashAlgorithm, marshalingName)
|
||||
|
||||
req.Data["salt_length"] = signSaltLength
|
||||
t.Log("\t\t\t\t", "sign req data:", req.Data)
|
||||
sig := signRequest(signErrExpected, "")
|
||||
|
||||
req.Data["salt_length"] = verifySaltLength
|
||||
t.Log("\t\t\t\t", "verify req data:", req.Data)
|
||||
verifyRequest(verifyErrExpected, "", sig)
|
||||
}
|
||||
|
||||
invalidSaltLengths := []string{"bar", "-2"}
|
||||
t.Log("invalidSaltLengths:", invalidSaltLengths)
|
||||
|
||||
autoSaltLengths := []string{"auto", "0"}
|
||||
t.Log("autoSaltLengths:", autoSaltLengths)
|
||||
|
||||
hashSaltLengths := []string{"hash", "-1"}
|
||||
t.Log("hashSaltLengths:", hashSaltLengths)
|
||||
|
||||
positiveSaltLengths := []string{"1"}
|
||||
t.Log("positiveSaltLengths:", positiveSaltLengths)
|
||||
|
||||
nonAutoSaltLengths := append(hashSaltLengths, positiveSaltLengths...)
|
||||
t.Log("nonAutoSaltLengths:", nonAutoSaltLengths)
|
||||
|
||||
validSaltLengths := append(autoSaltLengths, nonAutoSaltLengths...)
|
||||
t.Log("validSaltLengths:", validSaltLengths)
|
||||
|
||||
testCombinatorics := func(hashAlgorithm string, marshalingName string) {
|
||||
t.Log("\t\t", "valid", "/", "invalid salt lengths")
|
||||
for _, validSaltLength := range validSaltLengths {
|
||||
for _, invalidSaltLength := range invalidSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, invalidSaltLength, true)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "invalid", "/", "invalid salt lengths")
|
||||
for _, invalidSaltLength1 := range invalidSaltLengths {
|
||||
for _, invalidSaltLength2 := range invalidSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength1, true, invalidSaltLength2, true)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "invalid", "/", "valid salt lengths")
|
||||
for _, invalidSaltLength := range invalidSaltLengths {
|
||||
for _, validSaltLength := range validSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, invalidSaltLength, true, validSaltLength, true)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "valid", "/", "valid salt lengths")
|
||||
for _, validSaltLength := range validSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, validSaltLength, false, validSaltLength, false)
|
||||
}
|
||||
|
||||
t.Log("\t\t", "hash", "/", "hash salt lengths")
|
||||
for _, hashSaltLength1 := range hashSaltLengths {
|
||||
for _, hashSaltLength2 := range hashSaltLengths {
|
||||
if hashSaltLength1 != hashSaltLength2 {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength1, false, hashSaltLength2, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "hash", "/", "positive salt lengths")
|
||||
for _, hashSaltLength := range hashSaltLengths {
|
||||
for _, positiveSaltLength := range positiveSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, hashSaltLength, false, positiveSaltLength, true)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "positive", "/", "hash salt lengths")
|
||||
for _, positiveSaltLength := range positiveSaltLengths {
|
||||
for _, hashSaltLength := range hashSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, positiveSaltLength, false, hashSaltLength, true)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "auto", "/", "auto salt lengths")
|
||||
for _, autoSaltLength1 := range autoSaltLengths {
|
||||
for _, autoSaltLength2 := range autoSaltLengths {
|
||||
if autoSaltLength1 != autoSaltLength2 {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength1, false, autoSaltLength2, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "auto", "/", "non-auto salt lengths")
|
||||
for _, autoSaltLength := range autoSaltLengths {
|
||||
for _, nonAutoSaltLength := range nonAutoSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, autoSaltLength, false, nonAutoSaltLength, true)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("\t\t", "non-auto", "/", "auto salt lengths")
|
||||
for _, nonAutoSaltLength := range nonAutoSaltLengths {
|
||||
for _, autoSaltLength := range autoSaltLengths {
|
||||
signAndVerifyRequest(hashAlgorithm, marshalingName, nonAutoSaltLength, false, autoSaltLength, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testAutoSignAndVerify := func(hashAlgorithm string, marshalingName string) {
|
||||
t.Log("\t\t", "Make a signature with an implicit, automatic salt length")
|
||||
req.Data = newReqData(hashAlgorithm, marshalingName)
|
||||
t.Log("\t\t\t", "sign req data:", req.Data)
|
||||
sig := signRequest(false, "")
|
||||
|
||||
t.Log("\t\t", "Verify it with an implicit, automatic salt length")
|
||||
t.Log("\t\t\t", "verify req data:", req.Data)
|
||||
verifyRequest(false, "", sig)
|
||||
|
||||
t.Log("\t\t", "Verify it with an explicit, automatic salt length")
|
||||
for _, autoSaltLength := range autoSaltLengths {
|
||||
t.Log("\t\t\t", "auto", "/", autoSaltLength)
|
||||
req.Data["salt_length"] = autoSaltLength
|
||||
t.Log("\t\t\t\t", "verify req data:", req.Data)
|
||||
verifyRequest(false, "", sig)
|
||||
}
|
||||
|
||||
t.Log("\t\t", "Try to verify it with an explicit, incorrect salt length")
|
||||
for _, nonAutoSaltLength := range nonAutoSaltLengths {
|
||||
t.Log("\t\t\t", "auto", "/", nonAutoSaltLength)
|
||||
req.Data["salt_length"] = nonAutoSaltLength
|
||||
t.Log("\t\t\t\t", "verify req data:", req.Data)
|
||||
verifyRequest(true, "", sig)
|
||||
}
|
||||
|
||||
t.Log("\t\t", "Make a signature with an explicit, valid salt length & and verify it with an implicit, automatic salt length")
|
||||
for _, validSaltLength := range validSaltLengths {
|
||||
t.Log("\t\t\t", validSaltLength, "/", "auto")
|
||||
|
||||
req.Data = newReqData(hashAlgorithm, marshalingName)
|
||||
req.Data["salt_length"] = validSaltLength
|
||||
t.Log("\t\t\t", "sign req data:", req.Data)
|
||||
sig := signRequest(false, "")
|
||||
|
||||
t.Log("\t\t\t", "verify req data:", req.Data)
|
||||
verifyRequest(false, "", sig)
|
||||
}
|
||||
}
|
||||
|
||||
for hashAlgorithm := range keysutil.HashTypeMap {
|
||||
t.Log("Hash algorithm:", hashAlgorithm)
|
||||
for marshalingName := range keysutil.MarshalingTypeMap {
|
||||
t.Log("\t", "Marshaling type:", marshalingName)
|
||||
testCombinatorics(hashAlgorithm, marshalingName)
|
||||
testAutoSignAndVerify(hashAlgorithm, marshalingName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
secrets/transit: Allow configuring the possible salt lengths for RSA PSS signatures.
|
||||
```
|
|
@ -1,6 +1,7 @@
|
|||
package keysutil
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
|
@ -57,6 +58,18 @@ var (
|
|||
HashTypeSHA3512: sha3.New512,
|
||||
}
|
||||
|
||||
CryptoHashMap = map[HashType]crypto.Hash{
|
||||
HashTypeSHA1: crypto.SHA1,
|
||||
HashTypeSHA2224: crypto.SHA224,
|
||||
HashTypeSHA2256: crypto.SHA256,
|
||||
HashTypeSHA2384: crypto.SHA384,
|
||||
HashTypeSHA2512: crypto.SHA512,
|
||||
HashTypeSHA3224: crypto.SHA3_224,
|
||||
HashTypeSHA3256: crypto.SHA3_256,
|
||||
HashTypeSHA3384: crypto.SHA3_384,
|
||||
HashTypeSHA3512: crypto.SHA3_512,
|
||||
}
|
||||
|
||||
MarshalingTypeMap = map[string]MarshalingType{
|
||||
"asn1": MarshalingTypeASN1,
|
||||
"jws": MarshalingTypeJWS,
|
||||
|
|
|
@ -80,6 +80,13 @@ type BackupInfo struct {
|
|||
Version int `json:"version"`
|
||||
}
|
||||
|
||||
type SigningOptions struct {
|
||||
HashAlgorithm HashType
|
||||
Marshaling MarshalingType
|
||||
SaltLength int
|
||||
SigAlgorithm string
|
||||
}
|
||||
|
||||
type SigningResult struct {
|
||||
Signature string
|
||||
PublicKey []byte
|
||||
|
@ -1026,6 +1033,29 @@ func (p *Policy) HMACKey(version int) ([]byte, error) {
|
|||
}
|
||||
|
||||
func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) {
|
||||
return p.SignWithOptions(ver, context, input, &SigningOptions{
|
||||
HashAlgorithm: hashAlgorithm,
|
||||
Marshaling: marshaling,
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
SigAlgorithm: sigAlgorithm,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Policy) minRSAPSSSaltLength() int {
|
||||
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=247
|
||||
return rsa.PSSSaltLengthEqualsHash
|
||||
}
|
||||
|
||||
func (p *Policy) maxRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash) int {
|
||||
// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288
|
||||
return (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
|
||||
}
|
||||
|
||||
func (p *Policy) validRSAPSSSaltLength(priv *rsa.PrivateKey, hash crypto.Hash, saltLength int) bool {
|
||||
return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(priv, hash)
|
||||
}
|
||||
|
||||
func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) {
|
||||
if !p.Type.SigningSupported() {
|
||||
return nil, fmt.Errorf("message signing not supported for key type %v", p.Type)
|
||||
}
|
||||
|
@ -1049,6 +1079,11 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
|
|||
return nil, err
|
||||
}
|
||||
|
||||
hashAlgorithm := options.HashAlgorithm
|
||||
marshaling := options.Marshaling
|
||||
saltLength := options.SaltLength
|
||||
sigAlgorithm := options.SigAlgorithm
|
||||
|
||||
switch p.Type {
|
||||
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
|
||||
var curveBits int
|
||||
|
@ -1139,27 +1174,8 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
|
|||
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
|
||||
key := keyParams.RSAKey
|
||||
|
||||
var algo crypto.Hash
|
||||
switch hashAlgorithm {
|
||||
case HashTypeSHA1:
|
||||
algo = crypto.SHA1
|
||||
case HashTypeSHA2224:
|
||||
algo = crypto.SHA224
|
||||
case HashTypeSHA2256:
|
||||
algo = crypto.SHA256
|
||||
case HashTypeSHA2384:
|
||||
algo = crypto.SHA384
|
||||
case HashTypeSHA2512:
|
||||
algo = crypto.SHA512
|
||||
case HashTypeSHA3224:
|
||||
algo = crypto.SHA3_224
|
||||
case HashTypeSHA3256:
|
||||
algo = crypto.SHA3_256
|
||||
case HashTypeSHA3384:
|
||||
algo = crypto.SHA3_384
|
||||
case HashTypeSHA3512:
|
||||
algo = crypto.SHA3_512
|
||||
default:
|
||||
algo, ok := CryptoHashMap[hashAlgorithm]
|
||||
if !ok {
|
||||
return nil, errutil.InternalError{Err: "unsupported hash algorithm"}
|
||||
}
|
||||
|
||||
|
@ -1169,7 +1185,10 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
|
|||
|
||||
switch sigAlgorithm {
|
||||
case "pss":
|
||||
sig, err = rsa.SignPSS(rand.Reader, key, algo, input, nil)
|
||||
if !p.validRSAPSSSaltLength(key, algo, saltLength) {
|
||||
return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
|
||||
}
|
||||
sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1203,6 +1222,15 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si
|
|||
}
|
||||
|
||||
func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType, sig string) (bool, error) {
|
||||
return p.VerifySignatureWithOptions(context, input, sig, &SigningOptions{
|
||||
HashAlgorithm: hashAlgorithm,
|
||||
Marshaling: marshaling,
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
SigAlgorithm: sigAlgorithm,
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, options *SigningOptions) (bool, error) {
|
||||
if !p.Type.SigningSupported() {
|
||||
return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
|
||||
}
|
||||
|
@ -1235,6 +1263,11 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType,
|
|||
return false, errutil.UserError{Err: ErrTooOld}
|
||||
}
|
||||
|
||||
hashAlgorithm := options.HashAlgorithm
|
||||
marshaling := options.Marshaling
|
||||
saltLength := options.SaltLength
|
||||
sigAlgorithm := options.SigAlgorithm
|
||||
|
||||
var sigBytes []byte
|
||||
switch marshaling {
|
||||
case MarshalingTypeASN1:
|
||||
|
@ -1318,27 +1351,8 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType,
|
|||
|
||||
key := keyEntry.RSAKey
|
||||
|
||||
var algo crypto.Hash
|
||||
switch hashAlgorithm {
|
||||
case HashTypeSHA1:
|
||||
algo = crypto.SHA1
|
||||
case HashTypeSHA2224:
|
||||
algo = crypto.SHA224
|
||||
case HashTypeSHA2256:
|
||||
algo = crypto.SHA256
|
||||
case HashTypeSHA2384:
|
||||
algo = crypto.SHA384
|
||||
case HashTypeSHA2512:
|
||||
algo = crypto.SHA512
|
||||
case HashTypeSHA3224:
|
||||
algo = crypto.SHA3_224
|
||||
case HashTypeSHA3256:
|
||||
algo = crypto.SHA3_256
|
||||
case HashTypeSHA3384:
|
||||
algo = crypto.SHA3_384
|
||||
case HashTypeSHA3512:
|
||||
algo = crypto.SHA3_512
|
||||
default:
|
||||
algo, ok := CryptoHashMap[hashAlgorithm]
|
||||
if !ok {
|
||||
return false, errutil.InternalError{Err: "unsupported hash algorithm"}
|
||||
}
|
||||
|
||||
|
@ -1348,7 +1362,10 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType,
|
|||
|
||||
switch sigAlgorithm {
|
||||
case "pss":
|
||||
err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, nil)
|
||||
if !p.validRSAPSSSaltLength(key, algo, saltLength) {
|
||||
return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
|
||||
}
|
||||
err = rsa.VerifyPSS(&key.PublicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength})
|
||||
case "pkcs1v15":
|
||||
err = rsa.VerifyPKCS1v15(&key.PublicKey, algo, input, sigBytes)
|
||||
default:
|
||||
|
|
|
@ -8,14 +8,19 @@ import (
|
|||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
mathrand "math/rand"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/errutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/copystructure"
|
||||
|
@ -698,6 +703,26 @@ func generateTestKeys() (map[KeyType][]byte, error) {
|
|||
}
|
||||
keyMap[KeyType_RSA2048] = rsaKeyBytes
|
||||
|
||||
rsaKey, err = rsa.GenerateKey(rand.Reader, 3072)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyMap[KeyType_RSA3072] = rsaKeyBytes
|
||||
|
||||
rsaKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyMap[KeyType_RSA4096] = rsaKeyBytes
|
||||
|
||||
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -754,3 +779,179 @@ func BenchmarkSymmetric(b *testing.B) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func saltOptions(options SigningOptions, saltLength int) SigningOptions {
|
||||
return SigningOptions{
|
||||
HashAlgorithm: options.HashAlgorithm,
|
||||
Marshaling: options.Marshaling,
|
||||
SaltLength: saltLength,
|
||||
SigAlgorithm: options.SigAlgorithm,
|
||||
}
|
||||
}
|
||||
|
||||
func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
|
||||
tabs := strings.Repeat("\t", depth)
|
||||
t.Log(tabs, "Manually verifying signature with options:", options)
|
||||
|
||||
tabs = strings.Repeat("\t", depth+1)
|
||||
verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options)
|
||||
if err != nil {
|
||||
t.Fatal(tabs, "❌ Failed to manually verify signature:", err)
|
||||
}
|
||||
if !verified {
|
||||
t.Fatal(tabs, "❌ Failed to manually verify signature")
|
||||
}
|
||||
}
|
||||
|
||||
func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
|
||||
tabs := strings.Repeat("\t", depth)
|
||||
t.Log(tabs, "Automatically verifying signature with options:", options)
|
||||
|
||||
tabs = strings.Repeat("\t", depth+1)
|
||||
verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature)
|
||||
if err != nil {
|
||||
t.Fatal(tabs, "❌ Failed to automatically verify signature:", err)
|
||||
}
|
||||
if !verified {
|
||||
t.Fatal(tabs, "❌ Failed to automatically verify signature")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_RSA_PSS(t *testing.T) {
|
||||
t.Log("Testing RSA PSS")
|
||||
mathrand.Seed(time.Now().UnixNano())
|
||||
|
||||
var userError errutil.UserError
|
||||
ctx := context.Background()
|
||||
storage := &logical.InmemStorage{}
|
||||
// https://crypto.stackexchange.com/a/1222
|
||||
input := []byte("the ancients say the longer the salt, the more provable the security")
|
||||
sigAlgorithm := "pss"
|
||||
|
||||
tabs := make(map[int]string)
|
||||
for i := 1; i <= 6; i++ {
|
||||
tabs[i] = strings.Repeat("\t", i)
|
||||
}
|
||||
|
||||
test_RSA_PSS := func(p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, marshalingType MarshalingType) {
|
||||
unsaltedOptions := SigningOptions{
|
||||
HashAlgorithm: hashType,
|
||||
Marshaling: marshalingType,
|
||||
SigAlgorithm: sigAlgorithm,
|
||||
}
|
||||
cryptoHash := CryptoHashMap[hashType]
|
||||
minSaltLength := p.minRSAPSSSaltLength()
|
||||
maxSaltLength := p.maxRSAPSSSaltLength(rsaKey, cryptoHash)
|
||||
hash := cryptoHash.New()
|
||||
hash.Write(input)
|
||||
input = hash.Sum(nil)
|
||||
|
||||
// 1. Make an "automatic" signature with the given key size and hash algorithm,
|
||||
// but an automatically chosen salt length.
|
||||
t.Log(tabs[3], "Make an automatic signature")
|
||||
sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType)
|
||||
if err != nil {
|
||||
t.Fatal(tabs[4], "❌ Failed to automatically sign:", err)
|
||||
}
|
||||
|
||||
// 1.1 Verify this automatic signature using the *inferred* salt length.
|
||||
autoVerify(4, t, p, input, sig, unsaltedOptions)
|
||||
|
||||
// 1.2. Verify this automatic signature using the *correct, given* salt length.
|
||||
manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength))
|
||||
|
||||
// 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths.
|
||||
t.Log(tabs[4], "Test incorrect salt lengths")
|
||||
incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1}
|
||||
for _, saltLength := range incorrectSaltLengths {
|
||||
t.Log(tabs[5], "Salt length:", saltLength)
|
||||
saltedOptions := saltOptions(unsaltedOptions, saltLength)
|
||||
|
||||
verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
|
||||
if verified {
|
||||
t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Rule out boundary, invalid salt lengths.
|
||||
t.Log(tabs[3], "Test invalid salt lengths")
|
||||
invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1}
|
||||
for _, saltLength := range invalidSaltLengths {
|
||||
t.Log(tabs[4], "Salt length:", saltLength)
|
||||
saltedOptions := saltOptions(unsaltedOptions, saltLength)
|
||||
|
||||
// 2.1. Fail to sign.
|
||||
t.Log(tabs[5], "Try to make a manual signature")
|
||||
_, err := p.SignWithOptions(0, nil, input, &saltedOptions)
|
||||
if !errors.As(err, &userError) {
|
||||
t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
|
||||
}
|
||||
|
||||
// 2.2. Fail to verify.
|
||||
t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length")
|
||||
_, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
|
||||
if !errors.As(err, &userError) {
|
||||
t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. For three possible valid salt lengths...
|
||||
t.Log(tabs[3], "Test three possible valid salt lengths")
|
||||
midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength)
|
||||
validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength}
|
||||
for _, saltLength := range validSaltLengths {
|
||||
t.Log(tabs[4], "Salt length:", saltLength)
|
||||
saltedOptions := saltOptions(unsaltedOptions, saltLength)
|
||||
|
||||
// 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length.
|
||||
t.Log(tabs[5], "Make a manual signature")
|
||||
sig, err := p.SignWithOptions(0, nil, input, &saltedOptions)
|
||||
if err != nil {
|
||||
t.Fatal(tabs[6], "❌ Failed to manually sign:", err)
|
||||
}
|
||||
|
||||
// 3.2. Verify this manual signature using the *correct, given* salt length.
|
||||
manualVerify(6, t, p, input, sig, saltedOptions)
|
||||
|
||||
// 3.3. Verify this manual signature using the *inferred* salt length.
|
||||
autoVerify(6, t, p, input, sig, unsaltedOptions)
|
||||
}
|
||||
}
|
||||
|
||||
rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096}
|
||||
testKeys, err := generateTestKeys()
|
||||
if err != nil {
|
||||
t.Fatalf("error generating test keys: %s", err)
|
||||
}
|
||||
|
||||
// 1. For each standard RSA key size 2048, 3072, and 4096...
|
||||
for _, rsaKeyType := range rsaKeyTypes {
|
||||
t.Log("Key size: ", rsaKeyType)
|
||||
p := &Policy{
|
||||
Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size
|
||||
Type: rsaKeyType,
|
||||
}
|
||||
|
||||
rsaKeyBytes := testKeys[rsaKeyType]
|
||||
err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(tabs[1], "❌ Failed to import key:", err)
|
||||
}
|
||||
rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing test keys: %s", err)
|
||||
}
|
||||
rsaKey := rsaKeyAny.(*rsa.PrivateKey)
|
||||
|
||||
// 2. For each hash algorithm...
|
||||
for hashAlgorithm, hashType := range HashTypeMap {
|
||||
t.Log(tabs[1], "Hash algorithm:", hashAlgorithm)
|
||||
|
||||
// 3. For each marshaling type...
|
||||
for marshalingName, marshalingType := range MarshalingTypeMap {
|
||||
t.Log(tabs[2], "Marshaling type:", marshalingName)
|
||||
test_RSA_PSS(p, rsaKey, hashType, marshalingType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1157,6 +1157,12 @@ supports signing.
|
|||
also change the output encoding to URL-safe Base64 encoding instead of
|
||||
standard Base64-encoding.
|
||||
|
||||
- `salt_length` `(string: "auto")` – The salt length used to sign. This currently only applies to the RSA PSS signature scheme. Options are:
|
||||
|
||||
- `auto`: The default used by Golang (causing the salt to be as large as possible when signing)
|
||||
- `hash`: Causes the salt length to equal the length of the hash used in the signature
|
||||
- An integer between the minimum and the maximum permissible salt lengths for the given RSA key size.
|
||||
|
||||
### Sample Request
|
||||
|
||||
```shell-session
|
||||
|
@ -1328,6 +1334,12 @@ data.
|
|||
also expect the input encoding to URL-safe Base64 encoding instead of
|
||||
standard Base64-encoding.
|
||||
|
||||
- `salt_length` `(string: "auto")` – The salt length used to sign. This currently only applies to the RSA PSS signature scheme. Options are:
|
||||
|
||||
- `auto`: The default used by Golang (causing the salt to be as large as possible when signing)
|
||||
- `hash`: Causes the salt length to equal the length of the hash used in the signature
|
||||
- An integer between the minimum and the maximum permissible salt lengths for the given RSA key size.
|
||||
|
||||
### Sample Request
|
||||
|
||||
```shell-session
|
||||
|
|
Loading…
Reference in New Issue