open-vault/builtin/credential/cert/backend_test.go

149 lines
3.6 KiB
Go
Raw Normal View History

package cert
import (
	"crypto/tls"
	"fmt"
	"io"
	"io/ioutil"
	"testing"
	"time"

	"github.com/hashicorp/vault/logical"
	logicaltest "github.com/hashicorp/vault/logical/testing"
)
// TestBackend_basic_CA checks that a client whose certificate was issued by
// a registered root CA is able to log in through the cert backend.
func TestBackend_basic_CA(t *testing.T) {
	const (
		clientCertPath = "../../../test/key/ourdomain.cer"
		clientKeyPath  = "../../../test/key/ourdomain.key"
		rootCAPath     = "../../../test/ca/root.cer"
	)

	state := testConnState(t, clientCertPath, clientKeyPath)

	rootPEM, readErr := ioutil.ReadFile(rootCAPath)
	if readErr != nil {
		t.Fatalf("err: %v", readErr)
	}

	tc := logicaltest.TestCase{
		Backend: Backend(),
		Steps: []logicaltest.TestStep{
			// Trust the root CA, then log in with a leaf it signed.
			testAccStepCert(t, "web", rootPEM, "foo"),
			testAccStepLogin(t, state),
		},
	}
	logicaltest.Test(t, tc)
}
// TestBackend_basic_singleCert checks that a self-signed client certificate
// can log in once that very certificate has been registered as trusted.
func TestBackend_basic_singleCert(t *testing.T) {
	const (
		selfSignedCert = "../../../test/unsigned/cert.pem"
		selfSignedKey  = "../../../test/unsigned/key.pem"
	)

	state := testConnState(t, selfSignedCert, selfSignedKey)

	// The self-signed certificate doubles as its own trust anchor.
	trusted, readErr := ioutil.ReadFile(selfSignedCert)
	if readErr != nil {
		t.Fatalf("err: %v", readErr)
	}

	tc := logicaltest.TestCase{
		Backend: Backend(),
		Steps: []logicaltest.TestStep{
			testAccStepCert(t, "web", trusted, "foo"),
			testAccStepLogin(t, state),
		},
	}
	logicaltest.Test(t, tc)
}
// TestBackend_untrusted checks that a self-signed client certificate that
// was never registered with the backend does not yield an authenticated login.
func TestBackend_untrusted(t *testing.T) {
	state := testConnState(t,
		"../../../test/unsigned/cert.pem",
		"../../../test/unsigned/key.pem",
	)

	// No certs/ entry is written, so nothing trusts this client.
	tc := logicaltest.TestCase{
		Backend: Backend(),
		Steps:   []logicaltest.TestStep{testAccStepLoginInvalid(t, state)},
	}
	logicaltest.Test(t, tc)
}
// testAccStepLogin returns a step that performs an unauthenticated write to
// "login" over connState. The step expects:
//   - the response to carry Auth with a lease of 1000s (matching the "lease"
//     value written by testAccStepCert), and
//   - the policies checked by logicaltest.TestCheckAuth to be exactly ["foo"].
//
// Failures are reported by returning an error from Check rather than by
// calling t.Fatalf, which is the contract of the Check callback. The t
// parameter is kept for signature compatibility with the other step helpers.
func testAccStepLogin(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep {
	return logicaltest.TestStep{
		Operation:       logical.WriteOperation,
		Path:            "login",
		Unauthenticated: true,
		ConnState:       &connState,
		Check: func(resp *logical.Response) error {
			// A rejected login may produce no response or no Auth; report
			// that as a failure instead of panicking on a nil dereference.
			if resp == nil || resp.Auth == nil {
				return fmt.Errorf("expected auth in response, got: %#v", resp)
			}
			if resp.Auth.Lease != 1000*time.Second {
				return fmt.Errorf("bad lease length: %#v", resp.Auth)
			}
			return logicaltest.TestCheckAuth([]string{"foo"})(resp)
		},
	}
}
// testAccStepLoginInvalid returns a step that performs an unauthenticated
// write to "login" over connState and fails if the response grants Auth.
//
// ErrorOk is set so that the request itself may fail. In that case the
// response can be nil, which is treated as "not authorized" rather than being
// dereferenced. The t parameter is kept for signature compatibility with the
// other step helpers.
func testAccStepLoginInvalid(t *testing.T, connState tls.ConnectionState) logicaltest.TestStep {
	return logicaltest.TestStep{
		Operation:       logical.WriteOperation,
		Path:            "login",
		Unauthenticated: true,
		ConnState:       &connState,
		Check: func(resp *logical.Response) error {
			if resp != nil && resp.Auth != nil {
				return fmt.Errorf("should not be authorized: %#v", resp)
			}
			return nil
		},
		ErrorOk: true,
	}
}
func testAccStepCert(
t *testing.T, name string, cert []byte, policies string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
2015-04-24 17:31:57 +00:00
Path: "certs/" + name,
Data: map[string]interface{}{
"certificate": string(cert),
2015-04-24 17:31:57 +00:00
"policies": policies,
"display_name": name,
"lease": 1000,
},
}
}
// testConnState performs a real TLS handshake over a loopback listener and
// returns the server side's resulting tls.ConnectionState.
//
// The key pair loaded from certPath/keyPath is used both as the server
// certificate and as the client certificate. The server requests a client
// certificate (tls.RequestClientCert), and the client skips server
// verification (InsecureSkipVerify).
//
// The test fails if the key pair cannot be loaded, the listener cannot be
// opened, or either side of the connection fails.
func testConnState(t *testing.T, certPath, keyPath string) tls.ConnectionState {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	conf := &tls.Config{
		Certificates:       []tls.Certificate{cert},
		ClientAuth:         tls.RequestClientCert,
		InsecureSkipVerify: true,
	}
	list, err := tls.Listen("tcp", "127.0.0.1:0", conf)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer list.Close()

	// The client runs on its own goroutine, where t.Fatalf must not be
	// called (FailNow only works on the test goroutine). It reports its
	// outcome on a buffered channel so the send never blocks.
	//
	// serverDone is closed (via defer, so also on t.Fatalf) once the server
	// is finished. The client waits for it before closing its end, so the
	// close cannot race with the server's read.
	clientErr := make(chan error, 1)
	serverDone := make(chan struct{})
	defer close(serverDone)
	go func() {
		conn, err := tls.Dial("tcp", list.Addr().String(), conf)
		if err != nil {
			clientErr <- err
			return
		}
		defer conn.Close()
		// Write ping
		_, err = conn.Write([]byte("ping"))
		clientErr <- err
		<-serverDone
	}()

	serverConn, err := list.Accept()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer serverConn.Close()

	// Read the whole ping. This also drives the server side of the TLS
	// handshake, so an error here means the connection state is unusable.
	buf := make([]byte, 4)
	if _, err := io.ReadFull(serverConn, buf); err != nil {
		t.Fatalf("err: %v", err)
	}
	if err := <-clientErr; err != nil {
		t.Fatalf("err: %v", err)
	}

	// Grab the current state
	return serverConn.(*tls.Conn).ConnectionState()
}