Move public key comparison logic to its own function

This commit is contained in:
Jeff Mitchell 2015-11-17 10:01:42 -05:00
parent 4681d027c0
commit 26c8cf874d
3 changed files with 68 additions and 6 deletions

View file

@ -56,7 +56,7 @@ func TestBackend_RSAKey(t *testing.T) {
intdata := map[string]interface{}{} intdata := map[string]interface{}{}
reqdata := map[string]interface{}{} reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, intdata, reqdata)...) testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, ecCACert, intdata, reqdata)...)
logicaltest.Test(t, testCase) logicaltest.Test(t, testCase)
} }
@ -86,7 +86,7 @@ func TestBackend_ECKey(t *testing.T) {
intdata := map[string]interface{}{} intdata := map[string]interface{}{}
reqdata := map[string]interface{}{} reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, intdata, reqdata)...) testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, rsaCACert, intdata, reqdata)...)
logicaltest.Test(t, testCase) logicaltest.Test(t, testCase)
} }
@ -480,6 +480,7 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s
"max_path_length": 1, "max_path_length": 1,
}, },
}, },
logicaltest.TestStep{ logicaltest.TestStep{
Operation: logical.WriteOperation, Operation: logical.WriteOperation,
Path: "config/ca/sign", Path: "config/ca/sign",
@ -533,7 +534,7 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s
// Generates steps to test out CA configuration -- certificates + CRL expiry, // Generates steps to test out CA configuration -- certificates + CRL expiry,
// and ensure that the certificates are readable after storing them // and ensure that the certificates are readable after storing them
func generateCATestingSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep { func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
ret := []logicaltest.TestStep{ ret := []logicaltest.TestStep{
logicaltest.TestStep{ logicaltest.TestStep{
Operation: logical.WriteOperation, Operation: logical.WriteOperation,
@ -613,6 +614,18 @@ func generateCATestingSteps(t *testing.T, caCert, caKey string, intdata, reqdata
// Now test uploading when the private key is already stored, such // Now test uploading when the private key is already stored, such
// as when uploading a CA signed as the result of a generated CSR // as when uploading a CA signed as the result of a generated CSR
// First we test the wrong one, to ensure that the key comparator is
// working correctly
logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "config/ca/set",
Data: map[string]interface{}{
"pem_bundle": otherCaCert,
},
ErrorOk: true,
},
// Now, the right one
logicaltest.TestStep{ logicaltest.TestStep{
Operation: logical.WriteOperation, Operation: logical.WriteOperation,
Path: "config/ca/set", Path: "config/ca/set",

View file

@ -3,7 +3,6 @@ package pki
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"reflect"
"time" "time"
"github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/certutil"
@ -446,8 +445,12 @@ func (b *backend) pathCASetWrite(
} }
// If true, the stored private key corresponds to the cert's // If true, the stored private key corresponds to the cert's
// public key, so fill it in // public key, so fill it in
//panic(fmt.Sprintf("\nparsedCB.PrivateKey.Public().: %#v\nparsedBundle.Certificate.PublicKey")) equal, err := certutil.ComparePublicKeys(parsedCB.PrivateKey.Public(), parsedBundle.Certificate.PublicKey)
if reflect.DeepEqual(parsedCB.PrivateKey.Public(), parsedBundle.Certificate.PublicKey) { if err != nil {
return logical.ErrorResponse(
"stored public key does not match the public key on the certificate"), nil
}
if equal {
parsedBundle.PrivateKey = parsedCB.PrivateKey parsedBundle.PrivateKey = parsedCB.PrivateKey
parsedBundle.PrivateKeyType = parsedCB.PrivateKeyType parsedBundle.PrivateKeyType = parsedCB.PrivateKeyType
parsedBundle.PrivateKeyBytes = parsedCB.PrivateKeyBytes parsedBundle.PrivateKeyBytes = parsedCB.PrivateKeyBytes

View file

@ -183,6 +183,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) {
return parsedBundle, nil return parsedBundle, nil
} }
// GeneratePrivateKey generates a private key with the specified type and key bits
func GeneratePrivateKey(keyType string, keyBits int, emb EmbeddedParsedPrivateKeyContainer) error { func GeneratePrivateKey(keyType string, keyBits int, emb EmbeddedParsedPrivateKeyContainer) error {
var err error var err error
result := &EmbeddedParsedPrivateKey{} result := &EmbeddedParsedPrivateKey{}
@ -226,6 +227,7 @@ func GeneratePrivateKey(keyType string, keyBits int, emb EmbeddedParsedPrivateKe
return nil return nil
} }
// GenerateSerialNumber generates a serial number suitable for a certificate
func GenerateSerialNumber() (*big.Int, error) { func GenerateSerialNumber() (*big.Int, error) {
serial, err := rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil)) serial, err := rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil))
if err != nil { if err != nil {
@ -233,3 +235,47 @@ func GenerateSerialNumber() (*big.Int, error) {
} }
return serial, nil return serial, nil
} }
// ComparePublicKeys compares two public keys and returns true if they match
func ComparePublicKeys(key1Iface, key2Iface crypto.PublicKey) (bool, error) {
switch key1Iface.(type) {
case *rsa.PublicKey:
key1 := key1Iface.(*rsa.PublicKey)
key2, ok := key2Iface.(*rsa.PublicKey)
if !ok {
return false, fmt.Errorf("key types do not match: %T and %T", key1Iface, key2Iface)
}
if key1.N.Cmp(key2.N) != 0 ||
key1.E != key2.E {
return false, nil
}
return true, nil
case *ecdsa.PublicKey:
key1 := key1Iface.(*ecdsa.PublicKey)
key2, ok := key2Iface.(*ecdsa.PublicKey)
if !ok {
return false, fmt.Errorf("key types do not match: %T and %T", key1Iface, key2Iface)
}
if key1.X.Cmp(key2.X) != 0 ||
key1.Y.Cmp(key2.Y) != 0 {
return false, nil
}
key1Params := key1.Params()
key2Params := key2.Params()
if key1Params.P.Cmp(key2Params.P) != 0 ||
key1Params.N.Cmp(key2Params.N) != 0 ||
key1Params.B.Cmp(key2Params.B) != 0 ||
key1Params.Gx.Cmp(key2Params.Gx) != 0 ||
key1Params.Gy.Cmp(key2Params.Gy) != 0 ||
key1Params.BitSize != key2Params.BitSize {
return false, nil
}
return true, nil
default:
return false, fmt.Errorf("cannot compare key with type %T", key1Iface)
}
return false, fmt.Errorf("undefined error comparing public keys")
}