open-vault/plugins/database/mongodb/cert_helpers_test.go

232 lines
5.0 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mongodb
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"strings"
"testing"
"time"
)
// certBuilder accumulates the settings that certOpt functions apply before
// newCert generates the certificate.
type certBuilder struct {
// tmpl is the template of the certificate being built; options such as
// commonName and dns modify it in place.
tmpl *x509.Certificate
// parentTmpl is the issuer's template, set by the parent option.
parentTmpl *x509.Certificate
// selfSign makes newCert sign the certificate with its own key and use
// tmpl as the issuer.
selfSign bool
// parentKey is the issuer's RSA key, set by the parent option.
parentKey *rsa.PrivateKey
// isCA selects CA key usages instead of leaf key usages in newCert.
isCA bool
}
// certOpt customises a certBuilder. newCert applies the options in the order
// they are passed and fails the test if any of them returns an error.
type certOpt func(*certBuilder) error
// commonName returns an option that replaces the subject common name on the
// certificate template.
func commonName(cn string) certOpt {
	return func(b *certBuilder) error {
		b.tmpl.Subject.CommonName = cn
		return nil
	}
}
// parent returns an option that records the given certificate as the issuer:
// its template becomes the parent template and its RSA key the signing key.
// newCert overrides both when selfSign is also requested.
func parent(parent certificate) certOpt {
	return func(b *certBuilder) error {
		b.parentTmpl, b.parentKey = parent.template, parent.privKey.privKey
		return nil
	}
}
// isCA returns an option that records whether the certificate should be built
// as a certificate authority.
func isCA(isCA bool) certOpt {
	return func(b *certBuilder) error {
		b.isCA = isCA
		return nil
	}
}
// selfSign returns an option that marks the certificate as self-signed, so
// newCert signs it with its own freshly generated key.
func selfSign() certOpt {
	return func(b *certBuilder) error {
		b.selfSign = true
		return nil
	}
}
// dns returns an option that sets the DNS subject alternative names on the
// certificate template. The slice is stored as given, not copied.
func dns(dns ...string) certOpt {
	return func(b *certBuilder) error {
		b.tmpl.DNSNames = dns
		return nil
	}
}
// newCert builds a short-lived, RSA-backed X.509 certificate for use in tests.
//
// It starts from a default template (random serial number, timestamp-derived
// common name, valid from one hour ago until one hour from now) and applies
// opts in order. A fresh RSA key is generated for the certificate and its
// subject key ID is derived from that key.
//
// The certificate must have a signer: pass selfSign() to sign it with its own
// key, or parent(...) to sign it with another certificate's key. If both are
// given, selfSign wins. If neither is given (or the parent lacks a template or
// key) the test fails immediately rather than letting x509.CreateCertificate
// dereference a nil issuer.
//
// Key usages are decided after the options run: isCA(true) yields a CA
// certificate (cert/CRL signing, no extended key usage); otherwise a leaf
// certificate usable for both server and client authentication is produced.
//
// Any failure aborts the test via t.Fatalf.
func newCert(t *testing.T, opts ...certOpt) (cert certificate) {
	t.Helper()

	// Use a single timestamp so both validity bounds share the same reference.
	now := time.Now()
	builder := certBuilder{
		tmpl: &x509.Certificate{
			SerialNumber: makeSerial(t),
			Subject: pkix.Name{
				CommonName: makeCommonName(),
			},
			NotBefore:             now.Add(-1 * time.Hour),
			NotAfter:              now.Add(1 * time.Hour),
			IsCA:                  false,
			BasicConstraintsValid: true,
		},
	}

	for _, opt := range opts {
		if err := opt(&builder); err != nil {
			t.Fatalf("Failed to set up certificate builder: %s", err)
		}
	}

	// Fail fast, before the comparatively slow key generation, when there is
	// nothing to sign with. Without this check a nil issuer or nil signing key
	// reaches x509.CreateCertificate and panics.
	if !builder.selfSign && (builder.parentTmpl == nil || builder.parentKey == nil) {
		t.Fatalf("Unable to generate certificate: no signer configured; use selfSign() or parent()")
	}

	key := newPrivateKey(t)
	tmpl := builder.tmpl
	tmpl.SubjectKeyId = getSubjKeyID(t, key.privKey)

	issuer, signingKey := builder.parentTmpl, builder.parentKey
	if builder.selfSign {
		issuer, signingKey = tmpl, key.privKey
	}

	if builder.isCA {
		tmpl.IsCA = true
		tmpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageCRLSign
		tmpl.ExtKeyUsage = nil
	} else {
		tmpl.KeyUsage = x509.KeyUsageDigitalSignature |
			x509.KeyUsageKeyEncipherment |
			x509.KeyUsageKeyAgreement
		tmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}
	}

	certBytes, err := x509.CreateCertificate(rand.Reader, tmpl, issuer, key.privKey.Public(), signingKey)
	if err != nil {
		t.Fatalf("Unable to generate certificate: %s", err)
	}

	certPem := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	})

	// Round-trip through tls.X509KeyPair to confirm the certificate and key
	// match and to obtain a ready-to-use tls.Certificate.
	tlsCert, err := tls.X509KeyPair(certPem, key.pem)
	if err != nil {
		t.Fatalf("Unable to parse X509 key pair: %s", err)
	}

	return certificate{
		template: tmpl,
		privKey:  key,
		tlsCert:  tlsCert,
		rawCert:  certBytes,
		pem:      certPem,
		isCA:     builder.isCA,
	}
}
// ////////////////////////////////////////////////////////////////////////////
// Private Key
// ////////////////////////////////////////////////////////////////////////////
// keyWrapper pairs an RSA private key with its PEM encoding.
type keyWrapper struct {
// privKey is the generated RSA private key.
privKey *rsa.PrivateKey
// pem is privKey encoded as a PKCS #1 "RSA PRIVATE KEY" PEM block.
pem []byte
}
// newPrivateKey generates a 2048-bit RSA key and returns it together with its
// PKCS #1 PEM encoding. The test is failed if key generation errors.
func newPrivateKey(t *testing.T) (key keyWrapper) {
	t.Helper()

	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Unable to generate key for cert: %s", err)
	}

	block := pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(rsaKey),
	}
	return keyWrapper{
		privKey: rsaKey,
		pem:     pem.EncodeToMemory(&block),
	}
}
// ////////////////////////////////////////////////////////////////////////////
// Certificate
// ////////////////////////////////////////////////////////////////////////////
// certificate bundles a generated test certificate with its key and the
// encodings produced by newCert.
type certificate struct {
// privKey is the certificate's own RSA key and its PEM encoding.
privKey keyWrapper
// template is the x509 template the certificate was created from.
template *x509.Certificate
// tlsCert is the key pair as parsed by tls.X509KeyPair.
tlsCert tls.Certificate
// rawCert holds the bytes returned by x509.CreateCertificate.
rawCert []byte
// pem is rawCert wrapped in a "CERTIFICATE" PEM block.
pem []byte
// isCA records whether the certificate was built as a CA.
isCA bool
}
// CombinedPEM returns the PEM material for the certificate. For a CA this is
// the certificate PEM alone; for any other certificate it is the private key
// PEM followed by the certificate PEM, separated by a newline.
func (c certificate) CombinedPEM() []byte {
	if !c.isCA {
		parts := [][]byte{c.privKey.pem, c.pem}
		return bytes.Join(parts, []byte("\n"))
	}
	return c.pem
}
// ////////////////////////////////////////////////////////////////////////////
// Helpers
// ////////////////////////////////////////////////////////////////////////////
// makeSerial returns a cryptographically random serial number in the range
// [0, 2^128). The test is failed if the random source errors.
func makeSerial(t *testing.T) *big.Int {
	t.Helper()

	limit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, limit)
	if err != nil {
		t.Fatalf("Unable to generate serial number: %s", err)
	}
	return serial
}
// getSubjKeyID derives a subject key identifier for the given key: the SHA-1
// digest of the PKIX (DER) encoding of its public key. The test is failed if
// the key is nil or the public key cannot be marshalled.
//
// Pulled from sdk/helper/certutil & slightly modified for test usage.
func getSubjKeyID(t *testing.T, privateKey crypto.Signer) []byte {
	t.Helper()

	if privateKey == nil {
		t.Fatalf("passed-in private key is nil")
	}

	der, err := x509.MarshalPKIXPublicKey(privateKey.Public())
	if err != nil {
		t.Fatalf("error marshalling public key: %s", err)
	}

	hasher := sha1.New()
	hasher.Write(der) // hash.Hash writes never return an error
	return hasher.Sum(nil)
}
// makeCommonName returns a common name derived from the current local time,
// formatted as YYYYMMDDThhmmss followed by three millisecond digits.
func makeCommonName() (cn string) {
	stamp := time.Now().Format("20060102T150405.000")
	return strings.Replace(stamp, ".", "", -1)
}