open-vault/vault/cluster_test.go

375 lines
10 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package vault
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/hashicorp/vault/vault/cluster"
)
// clusterTestPausePeriod is how long the cluster tests in this file sleep to
// let cluster state settle, e.g. after a step-down or seal and before the
// cluster listeners are probed again.
var clusterTestPausePeriod = 2 * time.Second
// TestClusterFetching runs cluster setup on an unsealed test core and checks
// that Core.Cluster then returns a record with a non-empty name and ID.
func TestClusterFetching(t *testing.T) {
	core, _, _ := TestCoreUnsealed(t)
	ctx := context.Background()

	if setupErr := core.setupCluster(ctx); setupErr != nil {
		t.Fatal(setupErr)
	}

	info, err := core.Cluster(ctx)
	if err != nil {
		t.Fatal(err)
	}

	// The record must exist and carry both identifying fields.
	populated := info != nil && info.Name != "" && info.ID != ""
	if !populated {
		t.Fatalf("cluster information missing: cluster: %#v", info)
	}
}
// TestClusterHAFetching builds a core on top of in-memory HA storage,
// initializes and unseals it, waits for it to become active, and then checks
// that Core.Cluster returns a record with a non-empty name and ID.
func TestClusterHAFetching(t *testing.T) {
	const redirectAddr = "http://127.0.0.1:8200"
	testLogger := logging.NewVaultLogger(log.Trace)

	storage, err := inmem.NewInmemHA(nil, testLogger)
	if err != nil {
		t.Fatal(err)
	}
	haStorage, err := inmem.NewInmemHA(nil, testLogger)
	if err != nil {
		t.Fatal(err)
	}

	conf := &CoreConfig{
		Physical:     storage,
		HAPhysical:   haStorage.(physical.HABackend),
		RedirectAddr: redirectAddr,
		DisableMlock: true,
	}
	core, err := NewCore(conf)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer core.Shutdown()

	// Feed every key returned by init back into the core to unseal it.
	unsealKeys, _ := TestCoreInit(t, core)
	for i := range unsealKeys {
		_, unsealErr := TestCoreUnseal(core, TestKeyCopy(unsealKeys[i]))
		if unsealErr != nil {
			t.Fatalf("unseal err: %s", unsealErr)
		}
	}

	// The core must report itself unsealed after the unseal loop.
	if core.Sealed() {
		t.Fatal("should not be sealed")
	}

	// Block until the core has become active.
	TestWaitActive(t, core)

	info, err := core.Cluster(context.Background())
	if err != nil {
		t.Fatal(err)
	}

	// The record must exist and carry both identifying fields.
	if missing := info == nil || info.Name == "" || info.ID == ""; missing {
		t.Fatalf("cluster information missing: cluster:%#v", info)
	}
}
// TestCluster_ListenForRequests checks that the first core's cluster listener
// addresses accept request-forwarding TLS connections while that core is
// active, and stop accepting them after a step-down and after a seal.
func TestCluster_ListenForRequests(t *testing.T) {
	// Make this nicer for tests
	// NOTE(review): this overrides a package-level variable and does not
	// restore it when the test finishes.
	manualStepDownSleepPeriod = 5 * time.Second

	cluster := NewTestCluster(t, nil, &TestClusterOptions{
		KeepStandbysSealed: true,
	})
	cluster.Start()
	defer cluster.Cleanup()
	cores := cluster.Cores

	// Wait for core to become active
	TestWaitActive(t, cores[0].Core)

	clusterListener := cores[0].getClusterListener()
	clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
	addrs := cores[0].getClusterListener().Addrs()

	// Use this to have a valid config after sealing since ClusterTLSConfig returns nil
	//
	// checkListenersFunc dials every cluster address once. With
	// expectFail=false each dial must succeed and negotiate the
	// request-forwarding ALPN over TLS 1.2 or 1.3; with expectFail=true each
	// dial must fail.
	checkListenersFunc := func(expectFail bool) {
		dialer := clusterListener.GetDialerFunc(context.Background(), consts.RequestForwardingALPN)

		// checkOne probes a single address. It is its own closure so the
		// deferred Close fires per connection instead of accumulating until
		// the whole test returns; the deferred call also runs when t.Fatal
		// aborts the test.
		checkOne := func(idx int) {
			clnAddr := addrs[idx]
			netConn, err := dialer(clnAddr.String(), 0)
			if err != nil {
				if expectFail {
					t.Logf("testing %s unsuccessful as expected", clnAddr)
					return
				}
				t.Fatalf("error: %v\ncluster listener is %s", err, clnAddr)
			}
			// Release the socket once this probe is done. The close error is
			// not interesting: the assertions below are what the test is about.
			defer netConn.Close()

			if expectFail {
				t.Fatalf("testing %s not unsuccessful as expected", clnAddr)
			}

			conn := netConn.(*tls.Conn)
			err = conn.Handshake()
			if err != nil {
				t.Fatal(err)
			}

			connState := conn.ConnectionState()
			switch {
			case connState.Version != tls.VersionTLS12 && connState.Version != tls.VersionTLS13:
				t.Fatal("version mismatch")
			case connState.NegotiatedProtocol != consts.RequestForwardingALPN || !connState.NegotiatedProtocolIsMutual:
				t.Fatal("bad protocol negotiation")
			}
			t.Logf("testing %s successful", clnAddr)
		}

		for i := range cores[0].Listeners {
			checkOne(i)
		}
	}

	time.Sleep(clusterTestPausePeriod)
	checkListenersFunc(false)

	err := cores[0].StepDown(context.Background(), &logical.Request{
		Operation:   logical.UpdateOperation,
		Path:        "sys/step-down",
		ClientToken: cluster.RootToken,
	})
	if err != nil {
		t.Fatal(err)
	}

	// StepDown doesn't wait during actual preSeal so give time for listeners
	// to close
	time.Sleep(clusterTestPausePeriod)
	checkListenersFunc(true)

	// After this period it should be active again
	TestWaitActive(t, cores[0].Core)
	cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core})
	checkListenersFunc(false)

	err = cores[0].Core.Seal(cluster.RootToken)
	if err != nil {
		t.Fatal(err)
	}
	time.Sleep(clusterTestPausePeriod)

	// After sealing it should be inactive again
	checkListenersFunc(true)
}
// TestCluster_ForwardRequests runs the shared request-forwarding scenario
// twice: once with nil cluster options (subtest "tcpLayer") and once over an
// in-memory cluster layer (subtest "inmemLayer").
func TestCluster_ForwardRequests(t *testing.T) {
	// Make this nicer for tests
	manualStepDownSleepPeriod = 5 * time.Second

	runTCP := func(t *testing.T) {
		testCluster_ForwardRequestsCommon(t, nil)
	}

	// Same scenario again, this time with an in-memory network.
	runInmem := func(t *testing.T) {
		loggerOpts := &log.LoggerOptions{
			Name:  "inmem-cluster",
			Level: log.Trace,
			Mutex: &sync.Mutex{},
		}
		layers, err := cluster.NewInmemLayerCluster("inmem-cluster", 3, log.New(loggerOpts))
		if err != nil {
			t.Fatal(err)
		}

		opts := &TestClusterOptions{ClusterLayers: layers}
		testCluster_ForwardRequestsCommon(t, opts)
	}

	t.Run("tcpLayer", runTCP)
	t.Run("inmemLayer", runInmem)
}
// testCluster_ForwardRequestsCommon builds a test cluster from clusterOpts,
// gives each of its first three cores a distinctive HTTP endpoint, starts the
// cluster, and then checks that requests can be forwarded to whichever core
// is active while leadership is rotated between the cores.
func testCluster_ForwardRequestsCommon(t *testing.T, clusterOpts *TestClusterOptions) {
	testCluster := NewTestCluster(t, nil, clusterOpts)
	nodes := testCluster.Cores

	// Every core answers on its own path with its own status code and body,
	// so a forwarded response identifies the core that produced it.
	endpoints := []struct {
		route  string
		status int
		body   string
	}{
		{"/core1", 201, "core1"},
		{"/core2", 202, "core2"},
		{"/core3", 203, "core3"},
	}
	for i := range endpoints {
		ep := endpoints[i] // per-iteration copy captured by the handler below
		nodes[i].Handler.(*http.ServeMux).HandleFunc(ep.route, func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Add("Content-Type", "application/json")
			w.WriteHeader(ep.status)
			w.Write([]byte(ep.body))
		})
	}

	testCluster.Start()
	defer testCluster.Cleanup()
	rootToken := testCluster.RootToken

	// Wait for core to become active
	TestWaitActiveForwardingReady(t, nodes[0].Core)

	// Test forwarding a request. Since we're going directly from core to core
	// with no fallback we know that if it worked, request handling is working
	for _, standbyIdx := range []int{1, 2} {
		testCluster_ForwardRequests(t, nodes[standbyIdx], rootToken, "core1")
	}

	// Now we do a bunch of round-robining. The point is to make sure that as
	// nodes come and go, we can always successfully forward to the active
	// node. Each entry hands leadership from one core index to another and
	// names the endpoint the new active core answers with.
	handoffs := []struct {
		from, to int
		coreID   string
	}{
		{0, 1, "core2"},
		{1, 2, "core3"},
		{2, 0, "core1"},
		{0, 1, "core2"},
		{1, 2, "core3"},
	}
	for _, h := range handoffs {
		testCluster_Forwarding(t, testCluster, h.from, h.to, rootToken, h.coreID)
	}
}
// testCluster_Forwarding hands leadership from cluster.Cores[oldLeaderCoreIdx]
// to cluster.Cores[newLeaderCoreIdx] and then verifies that every other core
// can forward a request to the new leader.
//
// rootToken is used as the client token for the step-down requests and is
// passed on to the forwarding check; remoteCoreID is the endpoint identifier
// (e.g. "core2") the new leader is expected to answer with.
//
// The test fails if the old leader's step-down errors, if the remaining cores
// do not all report the new leader's cluster address within five seconds, or
// if any forwarded request fails.
func testCluster_Forwarding(t *testing.T, cluster *TestCluster, oldLeaderCoreIdx, newLeaderCoreIdx int, rootToken, remoteCoreID string) {
	t.Logf("new leaderidx will be %d, stepping down other cores to make it so", newLeaderCoreIdx)

	// Derive the counts from the cluster instead of assuming three cores.
	coreCount := len(cluster.Cores)
	standbyCount := coreCount - 1

	// stepDown issues a sys/step-down request against the core at idx.
	stepDown := func(idx int) error {
		return cluster.Cores[idx].StepDown(context.Background(), &logical.Request{
			Operation:   logical.UpdateOperation,
			Path:        "sys/step-down",
			ClientToken: rootToken,
		})
	}

	if err := stepDown(oldLeaderCoreIdx); err != nil {
		t.Fatal(err)
	}
	time.Sleep(clusterTestPausePeriod)

	// Also step down every core that is neither the old nor the intended new
	// leader. The result is intentionally ignored.
	// NOTE(review): presumably because such a core may not hold leadership at
	// this point; confirm before turning this into a hard failure.
	for i := 0; i < coreCount; i++ {
		if i == oldLeaderCoreIdx || i == newLeaderCoreIdx {
			continue
		}
		_ = stepDown(i)
		time.Sleep(clusterTestPausePeriod)
	}

	TestWaitActiveForwardingReady(t, cluster.Cores[newLeaderCoreIdx].Core)

	// standbysReady counts the non-leader cores whose recorded leader cluster
	// address matches the new leader's.
	standbysReady := func() int {
		n := 0
		for i := 0; i < coreCount; i++ {
			if i == newLeaderCoreIdx {
				continue
			}
			// Comma-ok form: if nothing has been stored yet the load yields a
			// nil interface, which must count as "not ready" rather than
			// panic the test.
			leaderParams, _ := cluster.Cores[i].clusterLeaderParams.Load().(*ClusterLeaderParams)
			if leaderParams != nil && leaderParams.LeaderClusterAddr == cluster.Cores[newLeaderCoreIdx].ClusterAddr() {
				n++
			}
		}
		return n
	}

	// Poll every 100ms for up to five seconds.
	deadline := time.Now().Add(5 * time.Second)
	ready := 0
	for time.Now().Before(deadline) {
		if ready = standbysReady(); ready == standbyCount {
			break
		}
		time.Sleep(100 * time.Millisecond)
	}
	if ready != standbyCount {
		t.Fatal("standbys have not discovered the new active node in time")
	}

	for i, core := range cluster.Cores {
		if i != newLeaderCoreIdx {
			testCluster_ForwardRequests(t, core, rootToken, remoteCoreID)
		}
	}
}
// testCluster_ForwardRequests sends one PUT request through c.ForwardRequest
// and checks that the response matches the endpoint named by remoteCoreID.
//
// It fails the test if c is not a standby, if c reports itself as leader, if
// c.ActiveNodeReplicationState is still zero after five seconds of retries,
// or if the forwarded response has an unexpected content type, body, or (for
// "core1"/"core2"/"core3") status code.
func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, rootToken, remoteCoreID string) {
	t.Helper()

	isStandby, err := c.Standby()
	switch {
	case err != nil:
		t.Fatal(err)
	case !isStandby:
		t.Fatal("expected core to be standby")
	}

	// We need to call Leader as that refreshes the connection info
	leading, _, _, err := c.Leader()
	switch {
	case err != nil:
		t.Fatal(err)
	case leading:
		t.Fatal("core should not be leader")
	}

	// Retry until the core reports a non-zero active-node replication state.
	corehelpers.RetryUntil(t, 5*time.Second, func() error {
		if state := c.ActiveNodeReplicationState(); state == 0 {
			return fmt.Errorf("heartbeats have not yet returned a valid active node replication state: %d", state)
		}
		return nil
	})

	payload := []byte(`{ "foo": "bar", "zip": "zap" }`)
	targetURL := "https://pushit.real.good:9281/" + remoteCoreID
	req, err := http.NewRequest(http.MethodPut, targetURL, bytes.NewReader(payload))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Add(consts.AuthHeaderName, rootToken)

	// NOTE(review): plain-string context key; presumably it has to match what
	// the forwarding code reads, so confirm before changing it.
	ctx := context.WithValue(req.Context(), "original_request_path", req.URL.Path)
	req = req.WithContext(ctx)

	statusCode, header, respBytes, err := c.ForwardRequest(req)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if header == nil {
		t.Fatal("err: expected at least a content-type header")
	}
	if contentType := header.Get("Content-Type"); contentType != "application/json" {
		t.Fatalf("bad content-type: %s", contentType)
	}

	got := string(respBytes)
	if got != remoteCoreID {
		t.Fatalf("expected %s, got %s", remoteCoreID, got)
	}

	// Each known endpoint is expected to come back with its own status code;
	// any other body is not status-checked.
	statusOK := true
	switch got {
	case "core1":
		statusOK = statusCode == 201
	case "core2":
		statusOK = statusCode == 202
	case "core3":
		statusOK = statusCode == 203
	}
	if !statusOK {
		t.Fatal("bad response")
	}
}