diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 19fcd4fdc..4e1af6da3 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -56,7 +56,7 @@ func TestBackend_RSAKey(t *testing.T) { intdata := map[string]interface{}{} reqdata := map[string]interface{}{} - testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, intdata, reqdata)...) + testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, ecCACert, intdata, reqdata)...) logicaltest.Test(t, testCase) } @@ -86,7 +86,7 @@ func TestBackend_ECKey(t *testing.T) { intdata := map[string]interface{}{} reqdata := map[string]interface{}{} - testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, intdata, reqdata)...) + testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, rsaCACert, intdata, reqdata)...) logicaltest.Test(t, testCase) } @@ -480,6 +480,7 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s "max_path_length": 1, }, }, + logicaltest.TestStep{ Operation: logical.WriteOperation, Path: "config/ca/sign", @@ -533,7 +534,7 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s // Generates steps to test out CA configuration -- certificates + CRL expiry, // and ensure that the certificates are readable after storing them -func generateCATestingSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep { +func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep { ret := []logicaltest.TestStep{ logicaltest.TestStep{ Operation: logical.WriteOperation, @@ -613,6 +614,18 @@ func generateCATestingSteps(t *testing.T, caCert, caKey string, intdata, reqdata // Now test uploading when the private key is already stored, such // as when uploading a CA signed as the result of a generated CSR + // First we test the wrong one, to ensure that the key comparator is + // working correctly + logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "config/ca/set", + Data: map[string]interface{}{ + "pem_bundle": otherCaCert, + }, + ErrorOk: true, + }, + + // Now, the right one logicaltest.TestStep{ Operation: logical.WriteOperation, Path: "config/ca/set", diff --git a/builtin/logical/pki/path_config_ca.go b/builtin/logical/pki/path_config_ca.go index f395391bc..ba14c80d2 100644 --- a/builtin/logical/pki/path_config_ca.go +++ b/builtin/logical/pki/path_config_ca.go @@ -3,7 +3,6 @@ package pki import ( "encoding/base64" "fmt" - "reflect" "time" "github.com/hashicorp/vault/helper/certutil" @@ -446,8 +445,12 @@ func (b *backend) pathCASetWrite( } // If true, the stored private key corresponds to the cert's // public key, so fill it in - //panic(fmt.Sprintf("\nparsedCB.PrivateKey.Public().: %#v\nparsedBundle.Certificate.PublicKey")) - if reflect.DeepEqual(parsedCB.PrivateKey.Public(), parsedBundle.Certificate.PublicKey) { + equal, err := certutil.ComparePublicKeys(parsedCB.PrivateKey.Public(), parsedBundle.Certificate.PublicKey) + if err != nil { + return logical.ErrorResponse( + "stored public key does not match the public key on the certificate"), nil + } + if equal { parsedBundle.PrivateKey = parsedCB.PrivateKey parsedBundle.PrivateKeyType = parsedCB.PrivateKeyType parsedBundle.PrivateKeyBytes = parsedCB.PrivateKeyBytes diff --git a/helper/certutil/helpers.go b/helper/certutil/helpers.go index 49ad055fd..c283ee912 100644 --- a/helper/certutil/helpers.go +++ b/helper/certutil/helpers.go @@ -183,6 +183,7 @@ func ParsePEMBundle(pemBundle string) (*ParsedCertBundle, error) { return parsedBundle, nil } +// GeneratePrivateKey generates a private key with the specified type and key bits func GeneratePrivateKey(keyType string, keyBits int, emb EmbeddedParsedPrivateKeyContainer) error { var err error result := &EmbeddedParsedPrivateKey{} @@ -226,6 +227,7 @@ func GeneratePrivateKey(keyType string, keyBits int, emb EmbeddedParsedPrivateKe return nil } +// GenerateSerialNumber generates a serial number suitable for a certificate func GenerateSerialNumber() (*big.Int, error) { serial, err := rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil)) if err != nil { @@ -233,3 +235,47 @@ func GenerateSerialNumber() (*big.Int, error) { } return serial, nil } + +// ComparePublicKeys compares two public keys and returns true if they match +func ComparePublicKeys(key1Iface, key2Iface crypto.PublicKey) (bool, error) { + switch key1Iface.(type) { + case *rsa.PublicKey: + key1 := key1Iface.(*rsa.PublicKey) + key2, ok := key2Iface.(*rsa.PublicKey) + if !ok { + return false, fmt.Errorf("key types do not match: %T and %T", key1Iface, key2Iface) + } + if key1.N.Cmp(key2.N) != 0 || + key1.E != key2.E { + return false, nil + } + return true, nil + + case *ecdsa.PublicKey: + key1 := key1Iface.(*ecdsa.PublicKey) + key2, ok := key2Iface.(*ecdsa.PublicKey) + if !ok { + return false, fmt.Errorf("key types do not match: %T and %T", key1Iface, key2Iface) + } + if key1.X.Cmp(key2.X) != 0 || + key1.Y.Cmp(key2.Y) != 0 { + return false, nil + } + key1Params := key1.Params() + key2Params := key2.Params() + if key1Params.P.Cmp(key2Params.P) != 0 || + key1Params.N.Cmp(key2Params.N) != 0 || + key1Params.B.Cmp(key2Params.B) != 0 || + key1Params.Gx.Cmp(key2Params.Gx) != 0 || + key1Params.Gy.Cmp(key2Params.Gy) != 0 || + key1Params.BitSize != key2Params.BitSize { + return false, nil + } + return true, nil + + default: + return false, fmt.Errorf("cannot compare key with type %T", key1Iface) + } + + return false, fmt.Errorf("undefined error comparing public keys") +}