512 lines
13 KiB
Go
512 lines
13 KiB
Go
package cert
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/logical"
|
|
"github.com/hashicorp/vault/logical/framework"
|
|
logicaltest "github.com/hashicorp/vault/logical/testing"
|
|
"github.com/mitchellh/mapstructure"
|
|
)
|
|
|
|
func testFactory(t *testing.T) logical.Backend {
|
|
b, err := Factory(&logical.BackendConfig{
|
|
System: &logical.StaticSystemView{
|
|
DefaultLeaseTTLVal: 300 * time.Second,
|
|
MaxLeaseTTLVal: 1800 * time.Second,
|
|
},
|
|
StorageView: &logical.InmemStorage{},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("error: %s", err)
|
|
}
|
|
return b
|
|
}
|
|
|
|
// Test the certificates being registered to the backend
|
|
func TestBackend_CertWrites(t *testing.T) {
|
|
// CA cert
|
|
ca1, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
// Non CA Cert
|
|
ca2, err := ioutil.ReadFile("test-fixtures/keys/cert.pem")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
// Non CA cert without TLS web client authentication
|
|
ca3, err := ioutil.ReadFile("test-fixtures/noclientauthcert.pem")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
tc := logicaltest.TestCase{
|
|
AcceptanceTest: true,
|
|
Backend: testFactory(t),
|
|
Steps: []logicaltest.TestStep{
|
|
testAccStepCert(t, "aaa", ca1, "foo", false),
|
|
testAccStepCert(t, "bbb", ca2, "foo", false),
|
|
testAccStepCert(t, "ccc", ca3, "foo", true),
|
|
},
|
|
}
|
|
tc.Steps = append(tc.Steps, testAccStepListCerts(t, []string{"aaa", "bbb"})...)
|
|
logicaltest.Test(t, tc)
|
|
}
|
|
|
|
// Test a client trusted by a CA
|
|
func TestBackend_basic_CA(t *testing.T) {
|
|
connState := testConnState(t, "test-fixtures/keys/cert.pem",
|
|
"test-fixtures/keys/key.pem", "test-fixtures/root/rootcacert.pem")
|
|
ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
logicaltest.Test(t, logicaltest.TestCase{
|
|
AcceptanceTest: true,
|
|
Backend: testFactory(t),
|
|
Steps: []logicaltest.TestStep{
|
|
testAccStepCert(t, "web", ca, "foo", false),
|
|
testAccStepLogin(t, connState),
|
|
testAccStepCertLease(t, "web", ca, "foo"),
|
|
testAccStepCertTTL(t, "web", ca, "foo"),
|
|
testAccStepLogin(t, connState),
|
|
testAccStepCertNoLease(t, "web", ca, "foo"),
|
|
testAccStepLoginDefaultLease(t, connState),
|
|
},
|
|
})
|
|
}
|
|
|
|
// Test CRL behavior
|
|
func TestBackend_CRLs(t *testing.T) {
|
|
connState := testConnState(t, "test-fixtures/keys/cert.pem",
|
|
"test-fixtures/keys/key.pem", "test-fixtures/root/rootcacert.pem")
|
|
ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
crl, err := ioutil.ReadFile("test-fixtures/root/root.crl")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
logicaltest.Test(t, logicaltest.TestCase{
|
|
AcceptanceTest: true,
|
|
Backend: testFactory(t),
|
|
Steps: []logicaltest.TestStep{
|
|
testAccStepCertNoLease(t, "web", ca, "foo"),
|
|
testAccStepLoginDefaultLease(t, connState),
|
|
testAccStepAddCRL(t, crl, connState),
|
|
testAccStepReadCRL(t, connState),
|
|
testAccStepLoginInvalid(t, connState),
|
|
testAccStepDeleteCRL(t, connState),
|
|
testAccStepLoginDefaultLease(t, connState),
|
|
},
|
|
})
|
|
}
|
|
|
|
// Test a self-signed client that is trusted
|
|
func TestBackend_basic_singleCert(t *testing.T) {
|
|
connState := testConnState(t, "test-fixtures/keys/cert.pem",
|
|
"test-fixtures/keys/key.pem", "test-fixtures/root/rootcacert.pem")
|
|
ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
logicaltest.Test(t, logicaltest.TestCase{
|
|
AcceptanceTest: true,
|
|
Backend: testFactory(t),
|
|
Steps: []logicaltest.TestStep{
|
|
testAccStepCert(t, "web", ca, "foo", false),
|
|
testAccStepLogin(t, connState),
|
|
},
|
|
})
|
|
}
|
|
|
|
// Test an untrusted self-signed client
|
|
func TestBackend_untrusted(t *testing.T) {
|
|
connState := testConnState(t, "test-fixtures/keys/cert.pem",
|
|
"test-fixtures/keys/key.pem", "test-fixtures/root/rootcacert.pem")
|
|
logicaltest.Test(t, logicaltest.TestCase{
|
|
AcceptanceTest: true,
|
|
Backend: testFactory(t),
|
|
Steps: []logicaltest.TestStep{
|
|
testAccStepLoginInvalid(t, connState),
|
|
},
|
|
})
|
|
}
|
|
|
|
func testAccStepAddCRL(t *testing.T, crl []byte, connState tls.ConnectionState) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "crls/test",
|
|
ConnState: &connState,
|
|
Data: map[string]interface{}{
|
|
"crl": crl,
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepReadCRL(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.ReadOperation,
|
|
Path: "crls/test",
|
|
ConnState: &connState,
|
|
Check: func(resp *logical.Response) error {
|
|
crlInfo := CRLInfo{}
|
|
err := mapstructure.Decode(resp.Data, &crlInfo)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
if len(crlInfo.Serials) != 1 {
|
|
t.Fatalf("bad: expected CRL with length 1, got %d", len(crlInfo.Serials))
|
|
}
|
|
if _, ok := crlInfo.Serials["637101449987587619778072672905061040630001617053"]; !ok {
|
|
t.Fatalf("bad: expected serial number not found in CRL")
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepDeleteCRL(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.DeleteOperation,
|
|
Path: "crls/test",
|
|
ConnState: &connState,
|
|
}
|
|
}
|
|
|
|
func testAccStepLogin(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "login",
|
|
Unauthenticated: true,
|
|
ConnState: &connState,
|
|
Check: func(resp *logical.Response) error {
|
|
if resp.Auth.TTL != 1000*time.Second {
|
|
t.Fatalf("bad lease length: %#v", resp.Auth)
|
|
}
|
|
|
|
fn := logicaltest.TestCheckAuth([]string{"default", "foo"})
|
|
return fn(resp)
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepLoginDefaultLease(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "login",
|
|
Unauthenticated: true,
|
|
ConnState: &connState,
|
|
Check: func(resp *logical.Response) error {
|
|
if resp.Auth.TTL != 300*time.Second {
|
|
t.Fatalf("bad lease length: %#v", resp.Auth)
|
|
}
|
|
|
|
fn := logicaltest.TestCheckAuth([]string{"default", "foo"})
|
|
return fn(resp)
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepLoginInvalid(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "login",
|
|
Unauthenticated: true,
|
|
ConnState: &connState,
|
|
Check: func(resp *logical.Response) error {
|
|
if resp.Auth != nil {
|
|
return fmt.Errorf("should not be authorized: %#v", resp)
|
|
}
|
|
return nil
|
|
},
|
|
ErrorOk: true,
|
|
}
|
|
}
|
|
|
|
func testAccStepListCerts(
|
|
t *testing.T, certs []string) []logicaltest.TestStep {
|
|
return []logicaltest.TestStep{
|
|
logicaltest.TestStep{
|
|
Operation: logical.ListOperation,
|
|
Path: "certs",
|
|
Check: func(resp *logical.Response) error {
|
|
if resp == nil {
|
|
return fmt.Errorf("nil response")
|
|
}
|
|
if resp.Data == nil {
|
|
return fmt.Errorf("nil data")
|
|
}
|
|
if resp.Data["keys"] == interface{}(nil) {
|
|
return fmt.Errorf("nil keys")
|
|
}
|
|
keys := resp.Data["keys"].([]string)
|
|
if !reflect.DeepEqual(keys, certs) {
|
|
return fmt.Errorf("mismatch: keys is %#v, certs is %#v", keys, certs)
|
|
}
|
|
return nil
|
|
},
|
|
}, logicaltest.TestStep{
|
|
Operation: logical.ListOperation,
|
|
Path: "certs/",
|
|
Check: func(resp *logical.Response) error {
|
|
if resp == nil {
|
|
return fmt.Errorf("nil response")
|
|
}
|
|
if resp.Data == nil {
|
|
return fmt.Errorf("nil data")
|
|
}
|
|
if resp.Data["keys"] == interface{}(nil) {
|
|
return fmt.Errorf("nil keys")
|
|
}
|
|
keys := resp.Data["keys"].([]string)
|
|
if !reflect.DeepEqual(keys, certs) {
|
|
return fmt.Errorf("mismatch: keys is %#v, certs is %#v", keys, certs)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepCert(
|
|
t *testing.T, name string, cert []byte, policies string, expectError bool) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "certs/" + name,
|
|
ErrorOk: expectError,
|
|
Data: map[string]interface{}{
|
|
"certificate": string(cert),
|
|
"policies": policies,
|
|
"display_name": name,
|
|
"lease": 1000,
|
|
},
|
|
Check: func(resp *logical.Response) error {
|
|
if resp == nil && expectError {
|
|
return fmt.Errorf("expected error but received nil")
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepCertLease(
|
|
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "certs/" + name,
|
|
Data: map[string]interface{}{
|
|
"certificate": string(cert),
|
|
"policies": policies,
|
|
"display_name": name,
|
|
"lease": 1000,
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepCertTTL(
|
|
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "certs/" + name,
|
|
Data: map[string]interface{}{
|
|
"certificate": string(cert),
|
|
"policies": policies,
|
|
"display_name": name,
|
|
"ttl": "1000s",
|
|
},
|
|
}
|
|
}
|
|
|
|
func testAccStepCertNoLease(
|
|
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
|
|
return logicaltest.TestStep{
|
|
Operation: logical.UpdateOperation,
|
|
Path: "certs/" + name,
|
|
Data: map[string]interface{}{
|
|
"certificate": string(cert),
|
|
"policies": policies,
|
|
"display_name": name,
|
|
},
|
|
}
|
|
}
|
|
|
|
func testConnState(t *testing.T, certPath, keyPath, rootCertPath string) tls.ConnectionState {
|
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
rootCAs, err := api.LoadCACert(rootCertPath)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
listenConf := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
ClientAuth: tls.RequestClientCert,
|
|
InsecureSkipVerify: false,
|
|
RootCAs: rootCAs,
|
|
}
|
|
dialConf := new(tls.Config)
|
|
*dialConf = *listenConf
|
|
list, err := tls.Listen("tcp", "127.0.0.1:0", listenConf)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
defer list.Close()
|
|
|
|
go func() {
|
|
addr := list.Addr().String()
|
|
conn, err := tls.Dial("tcp", addr, dialConf)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Write ping
|
|
conn.Write([]byte("ping"))
|
|
}()
|
|
|
|
serverConn, err := list.Accept()
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
defer serverConn.Close()
|
|
|
|
// Read the pign
|
|
buf := make([]byte, 4)
|
|
serverConn.Read(buf)
|
|
|
|
// Grab the current state
|
|
connState := serverConn.(*tls.Conn).ConnectionState()
|
|
return connState
|
|
}
|
|
|
|
func Test_Renew(t *testing.T) {
|
|
storage := &logical.InmemStorage{}
|
|
|
|
lb, err := Factory(&logical.BackendConfig{
|
|
System: &logical.StaticSystemView{
|
|
DefaultLeaseTTLVal: 300 * time.Second,
|
|
MaxLeaseTTLVal: 1800 * time.Second,
|
|
},
|
|
StorageView: storage,
|
|
})
|
|
if err != nil {
|
|
t.Fatal("error: %s", err)
|
|
}
|
|
|
|
b := lb.(*backend)
|
|
connState := testConnState(t, "test-fixtures/keys/cert.pem",
|
|
"test-fixtures/keys/key.pem", "test-fixtures/root/rootcacert.pem")
|
|
ca, err := ioutil.ReadFile("test-fixtures/root/rootcacert.pem")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
req := &logical.Request{
|
|
Connection: &logical.Connection{
|
|
ConnState: &connState,
|
|
},
|
|
Storage: storage,
|
|
Auth: &logical.Auth{},
|
|
}
|
|
|
|
fd := &framework.FieldData{
|
|
Raw: map[string]interface{}{
|
|
"name": "test",
|
|
"certificate": ca,
|
|
"policies": "foo,bar",
|
|
},
|
|
Schema: pathCerts(b).Fields,
|
|
}
|
|
|
|
resp, err := b.pathCertWrite(req, fd)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
resp, err = b.pathLogin(req, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Auth.InternalData = resp.Auth.InternalData
|
|
req.Auth.Metadata = resp.Auth.Metadata
|
|
req.Auth.LeaseOptions = resp.Auth.LeaseOptions
|
|
req.Auth.Policies = resp.Auth.Policies
|
|
req.Auth.IssueTime = time.Now()
|
|
|
|
// Normal renewal
|
|
resp, err = b.pathLoginRenew(req, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp == nil {
|
|
t.Fatal("got nil response from renew")
|
|
}
|
|
if resp.IsError() {
|
|
t.Fatalf("got error: %#v", *resp)
|
|
}
|
|
|
|
// Change the policies -- this should fail
|
|
fd.Raw["policies"] = "zip,zap"
|
|
resp, err = b.pathCertWrite(req, fd)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
resp, err = b.pathLoginRenew(req, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp == nil {
|
|
t.Fatal("got nil response from renew")
|
|
}
|
|
if !resp.IsError() {
|
|
t.Fatal("expected error")
|
|
}
|
|
|
|
// Put the policies back, this shold be okay
|
|
fd.Raw["policies"] = "bar,foo"
|
|
resp, err = b.pathCertWrite(req, fd)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
resp, err = b.pathLoginRenew(req, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp == nil {
|
|
t.Fatal("got nil response from renew")
|
|
}
|
|
if resp.IsError() {
|
|
t.Fatal("got error: %#v", *resp)
|
|
}
|
|
|
|
// Delete CA, make sure we can't renew
|
|
resp, err = b.pathCertDelete(req, fd)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
resp, err = b.pathLoginRenew(req, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp == nil {
|
|
t.Fatal("got nil response from renew")
|
|
}
|
|
if !resp.IsError() {
|
|
t.Fatal("expected error")
|
|
}
|
|
}
|