1003 lines
27 KiB
Go
1003 lines
27 KiB
Go
package testhelpers
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/armon/go-metrics"
|
|
raftlib "github.com/hashicorp/raft"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/helper/metricsutil"
|
|
"github.com/hashicorp/vault/helper/namespace"
|
|
"github.com/hashicorp/vault/physical/raft"
|
|
"github.com/hashicorp/vault/sdk/helper/xor"
|
|
"github.com/hashicorp/vault/vault"
|
|
"github.com/mitchellh/go-testing-interface"
|
|
)
|
|
|
|
type GenerateRootKind int
|
|
|
|
const (
|
|
GenerateRootRegular GenerateRootKind = iota
|
|
GenerateRootDR
|
|
GenerateRecovery
|
|
)
|
|
|
|
// GenerateRoot generates a root token on the target cluster.
|
|
func GenerateRoot(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) string {
|
|
t.Helper()
|
|
token, err := GenerateRootWithError(t, cluster, kind)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return token
|
|
}
|
|
|
|
func GenerateRootWithError(t testing.T, cluster *vault.TestCluster, kind GenerateRootKind) (string, error) {
|
|
t.Helper()
|
|
// If recovery keys supported, use those to perform root token generation instead
|
|
var keys [][]byte
|
|
if cluster.Cores[0].SealAccess().RecoveryKeySupported() {
|
|
keys = cluster.RecoveryKeys
|
|
} else {
|
|
keys = cluster.BarrierKeys
|
|
}
|
|
client := cluster.Cores[0].Client
|
|
|
|
var err error
|
|
var status *api.GenerateRootStatusResponse
|
|
switch kind {
|
|
case GenerateRootRegular:
|
|
status, err = client.Sys().GenerateRootInit("", "")
|
|
case GenerateRootDR:
|
|
status, err = client.Sys().GenerateDROperationTokenInit("", "")
|
|
case GenerateRecovery:
|
|
status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if status.Required > len(keys) {
|
|
return "", fmt.Errorf("need more keys than have, need %d have %d", status.Required, len(keys))
|
|
}
|
|
|
|
otp := status.OTP
|
|
|
|
for i, key := range keys {
|
|
if i >= status.Required {
|
|
break
|
|
}
|
|
|
|
strKey := base64.StdEncoding.EncodeToString(key)
|
|
switch kind {
|
|
case GenerateRootRegular:
|
|
status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
|
|
case GenerateRootDR:
|
|
status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
|
|
case GenerateRecovery:
|
|
status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
if !status.Complete {
|
|
return "", errors.New("generate root operation did not end successfully")
|
|
}
|
|
|
|
tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(tokenBytes), nil
|
|
}
|
|
|
|
// RandomWithPrefix is used to generate a unique name with a prefix, for
|
|
// randomizing names in acceptance tests
|
|
func RandomWithPrefix(name string) string {
|
|
return fmt.Sprintf("%s-%d", name, rand.New(rand.NewSource(time.Now().UnixNano())).Int())
|
|
}
|
|
|
|
func EnsureCoresSealed(t testing.T, c *vault.TestCluster) {
|
|
t.Helper()
|
|
for _, core := range c.Cores {
|
|
EnsureCoreSealed(t, core)
|
|
}
|
|
}
|
|
|
|
func EnsureCoreSealed(t testing.T, core *vault.TestClusterCore) {
|
|
t.Helper()
|
|
core.Seal(t)
|
|
timeout := time.Now().Add(60 * time.Second)
|
|
for {
|
|
if time.Now().After(timeout) {
|
|
t.Fatal("timeout waiting for core to seal")
|
|
}
|
|
if core.Core.Sealed() {
|
|
break
|
|
}
|
|
time.Sleep(250 * time.Millisecond)
|
|
}
|
|
}
|
|
|
|
func EnsureCoresUnsealed(t testing.T, c *vault.TestCluster) {
|
|
t.Helper()
|
|
for i, core := range c.Cores {
|
|
err := AttemptUnsealCore(c, core)
|
|
if err != nil {
|
|
t.Fatalf("failed to unseal core %d: %v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func EnsureCoreUnsealed(t testing.T, c *vault.TestCluster, core *vault.TestClusterCore) {
|
|
t.Helper()
|
|
err := AttemptUnsealCore(c, core)
|
|
if err != nil {
|
|
t.Fatalf("failed to unseal core: %v", err)
|
|
}
|
|
}
|
|
|
|
func AttemptUnsealCores(c *vault.TestCluster) error {
|
|
for i, core := range c.Cores {
|
|
err := AttemptUnsealCore(c, core)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to unseal core %d: %v", i, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func AttemptUnsealCore(c *vault.TestCluster, core *vault.TestClusterCore) error {
|
|
if !core.Sealed() {
|
|
return nil
|
|
}
|
|
|
|
core.SealAccess().ClearCaches(context.Background())
|
|
if err := core.UnsealWithStoredKeys(context.Background()); err != nil {
|
|
return err
|
|
}
|
|
|
|
client := core.Client
|
|
client.Sys().ResetUnsealProcess()
|
|
for j := 0; j < len(c.BarrierKeys); j++ {
|
|
statusResp, err := client.Sys().Unseal(base64.StdEncoding.EncodeToString(c.BarrierKeys[j]))
|
|
if err != nil {
|
|
// Sometimes when we get here it's already unsealed on its own
|
|
// and then this fails for DR secondaries so check again
|
|
if core.Sealed() {
|
|
return err
|
|
} else {
|
|
return nil
|
|
}
|
|
}
|
|
if statusResp == nil {
|
|
return fmt.Errorf("nil status response during unseal")
|
|
}
|
|
if !statusResp.Sealed {
|
|
break
|
|
}
|
|
}
|
|
if core.Sealed() {
|
|
return fmt.Errorf("core is still sealed")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func EnsureStableActiveNode(t testing.T, cluster *vault.TestCluster) {
|
|
t.Helper()
|
|
deriveStableActiveCore(t, cluster)
|
|
}
|
|
|
|
func DeriveStableActiveCore(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore {
|
|
t.Helper()
|
|
return deriveStableActiveCore(t, cluster)
|
|
}
|
|
|
|
func deriveStableActiveCore(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore {
|
|
t.Helper()
|
|
activeCore := DeriveActiveCore(t, cluster)
|
|
minDuration := time.NewTimer(3 * time.Second)
|
|
|
|
for i := 0; i < 60; i++ {
|
|
leaderResp, err := activeCore.Client.Sys().Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !leaderResp.IsSelf {
|
|
minDuration.Reset(3 * time.Second)
|
|
}
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
|
|
select {
|
|
case <-minDuration.C:
|
|
default:
|
|
if stopped := minDuration.Stop(); stopped {
|
|
t.Fatal("unstable active node")
|
|
}
|
|
// Drain the value
|
|
<-minDuration.C
|
|
}
|
|
|
|
return activeCore
|
|
}
|
|
|
|
func DeriveActiveCore(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore {
|
|
t.Helper()
|
|
for i := 0; i < 60; i++ {
|
|
for _, core := range cluster.Cores {
|
|
leaderResp, err := core.Client.Sys().Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if leaderResp.IsSelf {
|
|
return core
|
|
}
|
|
}
|
|
time.Sleep(1 * time.Second)
|
|
}
|
|
t.Fatal("could not derive the active core")
|
|
return nil
|
|
}
|
|
|
|
func DeriveStandbyCores(t testing.T, cluster *vault.TestCluster) []*vault.TestClusterCore {
|
|
t.Helper()
|
|
cores := make([]*vault.TestClusterCore, 0, 2)
|
|
for _, core := range cluster.Cores {
|
|
leaderResp, err := core.Client.Sys().Leader()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !leaderResp.IsSelf {
|
|
cores = append(cores, core)
|
|
}
|
|
}
|
|
|
|
return cores
|
|
}
|
|
|
|
func WaitForNCoresUnsealed(t testing.T, cluster *vault.TestCluster, n int) {
|
|
t.Helper()
|
|
for i := 0; i < 30; i++ {
|
|
unsealed := 0
|
|
for _, core := range cluster.Cores {
|
|
if !core.Core.Sealed() {
|
|
unsealed++
|
|
}
|
|
}
|
|
|
|
if unsealed >= n {
|
|
return
|
|
}
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
t.Fatalf("%d cores were not unsealed", n)
|
|
}
|
|
|
|
func SealCores(t testing.T, cluster *vault.TestCluster) {
|
|
t.Helper()
|
|
for _, core := range cluster.Cores {
|
|
if err := core.Shutdown(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
timeout := time.Now().Add(3 * time.Second)
|
|
for {
|
|
if time.Now().After(timeout) {
|
|
t.Fatal("timeout waiting for core to seal")
|
|
}
|
|
if core.Sealed() {
|
|
break
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
}
|
|
}
|
|
|
|
func WaitForNCoresSealed(t testing.T, cluster *vault.TestCluster, n int) {
|
|
t.Helper()
|
|
for i := 0; i < 60; i++ {
|
|
sealed := 0
|
|
for _, core := range cluster.Cores {
|
|
if core.Core.Sealed() {
|
|
sealed++
|
|
}
|
|
}
|
|
|
|
if sealed >= n {
|
|
return
|
|
}
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
t.Fatalf("%d cores were not sealed", n)
|
|
}
|
|
|
|
func WaitForActiveNode(t testing.T, cluster *vault.TestCluster) *vault.TestClusterCore {
|
|
t.Helper()
|
|
for i := 0; i < 60; i++ {
|
|
for _, core := range cluster.Cores {
|
|
if standby, _ := core.Core.Standby(); !standby {
|
|
return core
|
|
}
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
t.Fatalf("node did not become active")
|
|
return nil
|
|
}
|
|
|
|
func WaitForStandbyNode(t testing.T, core *vault.TestClusterCore) {
|
|
t.Helper()
|
|
for i := 0; i < 30; i++ {
|
|
if isLeader, _, clusterAddr, _ := core.Core.Leader(); isLeader != true && clusterAddr != "" {
|
|
return
|
|
}
|
|
if core.Core.ActiveNodeReplicationState() == 0 {
|
|
return
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
t.Fatalf("node did not become standby")
|
|
}
|
|
|
|
func RekeyCluster(t testing.T, cluster *vault.TestCluster, recovery bool) [][]byte {
|
|
t.Helper()
|
|
cluster.Logger.Info("rekeying cluster", "recovery", recovery)
|
|
client := cluster.Cores[0].Client
|
|
|
|
initFunc := client.Sys().RekeyInit
|
|
if recovery {
|
|
initFunc = client.Sys().RekeyRecoveryKeyInit
|
|
}
|
|
init, err := initFunc(&api.RekeyInitRequest{
|
|
SecretShares: 5,
|
|
SecretThreshold: 3,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var statusResp *api.RekeyUpdateResponse
|
|
keys := cluster.BarrierKeys
|
|
if cluster.Cores[0].Core.SealAccess().RecoveryKeySupported() {
|
|
keys = cluster.RecoveryKeys
|
|
}
|
|
|
|
updateFunc := client.Sys().RekeyUpdate
|
|
if recovery {
|
|
updateFunc = client.Sys().RekeyRecoveryKeyUpdate
|
|
}
|
|
for j := 0; j < len(keys); j++ {
|
|
statusResp, err = updateFunc(base64.StdEncoding.EncodeToString(keys[j]), init.Nonce)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if statusResp == nil {
|
|
t.Fatal("nil status response during unseal")
|
|
}
|
|
if statusResp.Complete {
|
|
break
|
|
}
|
|
}
|
|
cluster.Logger.Info("cluster rekeyed", "recovery", recovery)
|
|
|
|
if cluster.Cores[0].Core.SealAccess().RecoveryKeySupported() && !recovery {
|
|
return nil
|
|
}
|
|
if len(statusResp.KeysB64) != 5 {
|
|
t.Fatal("wrong number of keys")
|
|
}
|
|
|
|
newKeys := make([][]byte, 5)
|
|
for i, key := range statusResp.KeysB64 {
|
|
newKeys[i], err = base64.StdEncoding.DecodeString(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
return newKeys
|
|
}
|
|
|
|
// TestRaftServerAddressProvider is a ServerAddressProvider that uses the
|
|
// ClusterAddr() of each node to provide raft addresses.
|
|
//
|
|
// Note that TestRaftServerAddressProvider should only be used in cases where
|
|
// cores that are part of a raft configuration have already had
|
|
// startClusterListener() called (via either unsealing or raft joining).
|
|
type TestRaftServerAddressProvider struct {
|
|
Cluster *vault.TestCluster
|
|
}
|
|
|
|
func (p *TestRaftServerAddressProvider) ServerAddr(id raftlib.ServerID) (raftlib.ServerAddress, error) {
|
|
for _, core := range p.Cluster.Cores {
|
|
if core.NodeID == string(id) {
|
|
parsed, err := url.Parse(core.ClusterAddr())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return raftlib.ServerAddress(parsed.Host), nil
|
|
}
|
|
}
|
|
|
|
return "", errors.New("could not find cluster addr")
|
|
}
|
|
|
|
func RaftClusterJoinNodes(t testing.T, cluster *vault.TestCluster) {
|
|
addressProvider := &TestRaftServerAddressProvider{Cluster: cluster}
|
|
|
|
atomic.StoreUint32(&vault.TestingUpdateClusterAddr, 1)
|
|
|
|
leader := cluster.Cores[0]
|
|
|
|
// Seal the leader so we can install an address provider
|
|
{
|
|
EnsureCoreSealed(t, leader)
|
|
leader.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(addressProvider)
|
|
cluster.UnsealCore(t, leader)
|
|
vault.TestWaitActive(t, leader.Core)
|
|
}
|
|
|
|
leaderInfos := []*raft.LeaderJoinInfo{
|
|
{
|
|
LeaderAPIAddr: leader.Client.Address(),
|
|
TLSConfig: leader.TLSConfig(),
|
|
},
|
|
}
|
|
|
|
// Join followers
|
|
for i := 1; i < len(cluster.Cores); i++ {
|
|
core := cluster.Cores[i]
|
|
core.UnderlyingRawStorage.(*raft.RaftBackend).SetServerAddressProvider(addressProvider)
|
|
_, err := core.JoinRaftCluster(namespace.RootContext(context.Background()), leaderInfos, false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
cluster.UnsealCore(t, core)
|
|
}
|
|
|
|
WaitForNCoresUnsealed(t, cluster, len(cluster.Cores))
|
|
}
|
|
|
|
// HardcodedServerAddressProvider is a ServerAddressProvider that uses
|
|
// a hardcoded map of raft node addresses.
|
|
//
|
|
// It is useful in cases where the raft configuration is known ahead of time,
|
|
// but some of the cores have not yet had startClusterListener() called (via
|
|
// either unsealing or raft joining), and thus do not yet have a ClusterAddr()
|
|
// assigned.
|
|
type HardcodedServerAddressProvider struct {
|
|
Entries map[raftlib.ServerID]raftlib.ServerAddress
|
|
}
|
|
|
|
func (p *HardcodedServerAddressProvider) ServerAddr(id raftlib.ServerID) (raftlib.ServerAddress, error) {
|
|
if addr, ok := p.Entries[id]; ok {
|
|
return addr, nil
|
|
}
|
|
return "", errors.New("could not find cluster addr")
|
|
}
|
|
|
|
// NewHardcodedServerAddressProvider is a convenience function that makes a
|
|
// ServerAddressProvider from a given cluster address base port.
|
|
func NewHardcodedServerAddressProvider(numCores, baseClusterPort int) raftlib.ServerAddressProvider {
|
|
entries := make(map[raftlib.ServerID]raftlib.ServerAddress)
|
|
|
|
for i := 0; i < numCores; i++ {
|
|
id := fmt.Sprintf("core-%d", i)
|
|
addr := fmt.Sprintf("127.0.0.1:%d", baseClusterPort+i)
|
|
entries[raftlib.ServerID(id)] = raftlib.ServerAddress(addr)
|
|
}
|
|
|
|
return &HardcodedServerAddressProvider{
|
|
entries,
|
|
}
|
|
}
|
|
|
|
// VerifyRaftConfiguration checks that we have a valid raft configuration, i.e.
|
|
// the correct number of servers, having the correct NodeIDs, and exactly one
|
|
// leader.
|
|
func VerifyRaftConfiguration(core *vault.TestClusterCore, numCores int) error {
|
|
backend := core.UnderlyingRawStorage.(*raft.RaftBackend)
|
|
ctx := namespace.RootContext(context.Background())
|
|
config, err := backend.GetConfiguration(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
servers := config.Servers
|
|
if len(servers) != numCores {
|
|
return fmt.Errorf("Found %d servers, not %d", len(servers), numCores)
|
|
}
|
|
|
|
leaders := 0
|
|
for i, s := range servers {
|
|
if s.NodeID != fmt.Sprintf("core-%d", i) {
|
|
return fmt.Errorf("Found unexpected node ID %q", s.NodeID)
|
|
}
|
|
if s.Leader {
|
|
leaders++
|
|
}
|
|
}
|
|
|
|
if leaders != 1 {
|
|
return fmt.Errorf("Found %d leaders", leaders)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func RaftAppliedIndex(core *vault.TestClusterCore) uint64 {
|
|
return core.UnderlyingRawStorage.(*raft.RaftBackend).AppliedIndex()
|
|
}
|
|
|
|
func WaitForRaftApply(t testing.T, core *vault.TestClusterCore, index uint64) {
|
|
t.Helper()
|
|
|
|
backend := core.UnderlyingRawStorage.(*raft.RaftBackend)
|
|
for i := 0; i < 30; i++ {
|
|
if backend.AppliedIndex() >= index {
|
|
return
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
t.Fatalf("node did not apply index")
|
|
}
|
|
|
|
// AwaitLeader waits for one of the cluster's nodes to become leader.
|
|
func AwaitLeader(t testing.T, cluster *vault.TestCluster) (int, error) {
|
|
timeout := time.Now().Add(60 * time.Second)
|
|
for {
|
|
if time.Now().After(timeout) {
|
|
break
|
|
}
|
|
|
|
for i, core := range cluster.Cores {
|
|
if core.Core.Sealed() {
|
|
continue
|
|
}
|
|
|
|
isLeader, _, _, _ := core.Leader()
|
|
if isLeader {
|
|
return i, nil
|
|
}
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
return 0, fmt.Errorf("timeout waiting leader")
|
|
}
|
|
|
|
func GenerateDebugLogs(t testing.T, client *api.Client) chan struct{} {
|
|
t.Helper()
|
|
|
|
stopCh := make(chan struct{})
|
|
|
|
go func() {
|
|
ticker := time.NewTicker(time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
err := client.Sys().Mount("foo", &api.MountInput{
|
|
Type: "kv",
|
|
Options: map[string]string{
|
|
"version": "1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = client.Sys().Unmount("foo")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return stopCh
|
|
}
|
|
|
|
// VerifyRaftPeers verifies that the raft configuration contains a given set of peers.
|
|
// The `expected` contains a map of expected peers. Existing entries are deleted
|
|
// from the map by removing entries whose keys are in the raft configuration.
|
|
// Remaining entries result in an error return so that the caller can poll for
|
|
// an expected configuration.
|
|
func VerifyRaftPeers(t testing.T, client *api.Client, expected map[string]bool) error {
|
|
t.Helper()
|
|
|
|
resp, err := client.Logical().Read("sys/storage/raft/configuration")
|
|
if err != nil {
|
|
t.Fatalf("error reading raft config: %v", err)
|
|
}
|
|
|
|
if resp == nil || resp.Data == nil {
|
|
t.Fatal("missing response data")
|
|
}
|
|
|
|
config, ok := resp.Data["config"].(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("missing config in response data")
|
|
}
|
|
|
|
servers, ok := config["servers"].([]interface{})
|
|
if !ok {
|
|
t.Fatal("missing servers in response data config")
|
|
}
|
|
|
|
// Iterate through the servers and remove the node found in the response
|
|
// from the expected collection
|
|
for _, s := range servers {
|
|
server := s.(map[string]interface{})
|
|
delete(expected, server["node_id"].(string))
|
|
}
|
|
|
|
// If the collection is non-empty, it means that the peer was not found in
|
|
// the response.
|
|
if len(expected) != 0 {
|
|
return fmt.Errorf("failed to read configuration successfully, expected peers not found in configuration list: %v", expected)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestMetricSinkProvider(gaugeInterval time.Duration) func(string) (*metricsutil.ClusterMetricSink, *metricsutil.MetricsHelper) {
|
|
return func(clusterName string) (*metricsutil.ClusterMetricSink, *metricsutil.MetricsHelper) {
|
|
inm := metrics.NewInmemSink(1000000*time.Hour, 2000000*time.Hour)
|
|
clusterSink := metricsutil.NewClusterMetricSink(clusterName, inm)
|
|
clusterSink.GaugeInterval = gaugeInterval
|
|
return clusterSink, metricsutil.NewMetricsHelper(inm, false)
|
|
}
|
|
}
|
|
|
|
func SysMetricsReq(client *api.Client, cluster *vault.TestCluster, unauth bool) (*SysMetricsJSON, error) {
|
|
r := client.NewRequest("GET", "/v1/sys/metrics")
|
|
if !unauth {
|
|
r.Headers.Set("X-Vault-Token", cluster.RootToken)
|
|
}
|
|
var data SysMetricsJSON
|
|
resp, err := client.RawRequestWithContext(context.Background(), r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
bodyBytes, err := ioutil.ReadAll(resp.Response.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if err := json.Unmarshal(bodyBytes, &data); err != nil {
|
|
return nil, errors.New("failed to unmarshal:" + err.Error())
|
|
}
|
|
return &data, nil
|
|
}
|
|
|
|
type SysMetricsJSON struct {
|
|
Gauges []gaugeJSON `json:"Gauges"`
|
|
Counters []counterJSON `json:"Counters"`
|
|
|
|
// note: this is referred to as a "Summary" type in our telemetry docs, but
|
|
// the field name in the JSON is "Samples"
|
|
Summaries []summaryJSON `json:"Samples"`
|
|
}
|
|
|
|
type baseInfoJSON struct {
|
|
Name string `json:"Name"`
|
|
Labels map[string]interface{} `json:"Labels"`
|
|
}
|
|
|
|
type gaugeJSON struct {
|
|
baseInfoJSON
|
|
Value int `json:"Value"`
|
|
}
|
|
|
|
type counterJSON struct {
|
|
baseInfoJSON
|
|
Count int `json:"Count"`
|
|
Rate float64 `json:"Rate"`
|
|
Sum int `json:"Sum"`
|
|
Min int `json:"Min"`
|
|
Max int `json:"Max"`
|
|
Mean float64 `json:"Mean"`
|
|
Stddev float64 `json:"Stddev"`
|
|
}
|
|
|
|
type summaryJSON struct {
|
|
baseInfoJSON
|
|
Count int `json:"Count"`
|
|
Rate float64 `json:"Rate"`
|
|
Sum float64 `json:"Sum"`
|
|
Min float64 `json:"Min"`
|
|
Max float64 `json:"Max"`
|
|
Mean float64 `json:"Mean"`
|
|
Stddev float64 `json:"Stddev"`
|
|
}
|
|
|
|
// SetNonRootToken sets a token on :client: with a fairly generic policy.
|
|
// This is useful if a test needs to examine differing behavior based on if a
|
|
// root token is passed with the request.
|
|
func SetNonRootToken(client *api.Client) error {
|
|
policy := `path "*" { capabilities = ["create", "update", "read"] }`
|
|
if err := client.Sys().PutPolicy("policy", policy); err != nil {
|
|
return fmt.Errorf("error putting policy: %v", err)
|
|
}
|
|
|
|
secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{
|
|
Policies: []string{"policy"},
|
|
TTL: "30m",
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error creating token secret: %v", err)
|
|
}
|
|
|
|
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
|
|
return fmt.Errorf("missing token auth data")
|
|
}
|
|
|
|
client.SetToken(secret.Auth.ClientToken)
|
|
return nil
|
|
}
|
|
|
|
// RetryUntilAtCadence runs f until it returns a nil result or the timeout is reached.
|
|
// If a nil result hasn't been obtained by timeout, calls t.Fatal.
|
|
func RetryUntilAtCadence(t testing.T, timeout, sleepTime time.Duration, f func() error) {
|
|
t.Helper()
|
|
deadline := time.Now().Add(timeout)
|
|
var err error
|
|
for time.Now().Before(deadline) {
|
|
if err = f(); err == nil {
|
|
return
|
|
}
|
|
time.Sleep(sleepTime)
|
|
}
|
|
t.Fatalf("did not complete before deadline, err: %v", err)
|
|
}
|
|
|
|
// RetryUntil runs f until it returns a nil result or the timeout is reached.
|
|
// If a nil result hasn't been obtained by timeout, calls t.Fatal.
|
|
func RetryUntil(t testing.T, timeout time.Duration, f func() error) {
|
|
t.Helper()
|
|
deadline := time.Now().Add(timeout)
|
|
var err error
|
|
for time.Now().Before(deadline) {
|
|
if err = f(); err == nil {
|
|
return
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
t.Fatalf("did not complete before deadline, err: %v", err)
|
|
}
|
|
|
|
// CreateEntityAndAlias clones an existing client and creates an entity/alias.
|
|
// It returns the cloned client, entityID, and aliasID.
|
|
func CreateEntityAndAlias(t testing.T, client *api.Client, mountAccessor, entityName, aliasName string) (*api.Client, string, string) {
|
|
t.Helper()
|
|
userClient, err := client.Clone()
|
|
if err != nil {
|
|
t.Fatalf("failed to clone the client:%v", err)
|
|
}
|
|
userClient.SetToken(client.Token())
|
|
|
|
resp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity", map[string]interface{}{
|
|
"name": entityName,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to create an entity:%v", err)
|
|
}
|
|
entityID := resp.Data["id"].(string)
|
|
|
|
aliasResp, err := client.Logical().WriteWithContext(context.Background(), "identity/entity-alias", map[string]interface{}{
|
|
"name": aliasName,
|
|
"canonical_id": entityID,
|
|
"mount_accessor": mountAccessor,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to create an entity alias:%v", err)
|
|
}
|
|
aliasID := aliasResp.Data["id"].(string)
|
|
if aliasID == "" {
|
|
t.Fatal("Alias ID not present in response")
|
|
}
|
|
_, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("auth/userpass/users/%s", aliasName), map[string]interface{}{
|
|
"password": "testpassword",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to configure userpass backend: %v", err)
|
|
}
|
|
|
|
return userClient, entityID, aliasID
|
|
}
|
|
|
|
// SetupTOTPMount enables the totp secrets engine by mounting it. This requires
|
|
// that the test cluster has a totp backend available.
|
|
func SetupTOTPMount(t testing.T, client *api.Client) {
|
|
t.Helper()
|
|
// Mount the TOTP backend
|
|
mountInfo := &api.MountInput{
|
|
Type: "totp",
|
|
}
|
|
if err := client.Sys().Mount("totp", mountInfo); err != nil {
|
|
t.Fatalf("failed to mount totp backend: %v", err)
|
|
}
|
|
}
|
|
|
|
// SetupTOTPMethod configures the TOTP secrets engine with a provided config map.
|
|
func SetupTOTPMethod(t testing.T, client *api.Client, config map[string]interface{}) string {
|
|
t.Helper()
|
|
|
|
resp1, err := client.Logical().Write("identity/mfa/method/totp", config)
|
|
|
|
if err != nil || (resp1 == nil) {
|
|
t.Fatalf("bad: resp: %#v\n err: %v", resp1, err)
|
|
}
|
|
|
|
methodID := resp1.Data["method_id"].(string)
|
|
if methodID == "" {
|
|
t.Fatalf("method ID is empty")
|
|
}
|
|
|
|
return methodID
|
|
}
|
|
|
|
// SetupMFALoginEnforcement configures a single enforcement method using the
|
|
// provided config map. "name" field is required in the config map.
|
|
func SetupMFALoginEnforcement(t testing.T, client *api.Client, config map[string]interface{}) {
|
|
t.Helper()
|
|
enfName, ok := config["name"]
|
|
if !ok {
|
|
t.Fatalf("couldn't find name in login-enforcement config")
|
|
}
|
|
_, err := client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("identity/mfa/login-enforcement/%s", enfName), config)
|
|
if err != nil {
|
|
t.Fatalf("failed to configure MFAEnforcementConfig: %v", err)
|
|
}
|
|
}
|
|
|
|
// SetupUserpassMountAccessor sets up userpass auth and returns its mount
|
|
// accessor. This requires that the test cluster has a "userpass" auth method
|
|
// available.
|
|
func SetupUserpassMountAccessor(t testing.T, client *api.Client) string {
|
|
t.Helper()
|
|
// Enable Userpass authentication
|
|
err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
|
|
Type: "userpass",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to enable userpass auth: %v", err)
|
|
}
|
|
|
|
auths, err := client.Sys().ListAuthWithContext(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("failed to list auth methods: %v", err)
|
|
}
|
|
if auths == nil || auths["userpass/"] == nil {
|
|
t.Fatalf("failed to get userpass mount accessor")
|
|
}
|
|
|
|
return auths["userpass/"].Accessor
|
|
}
|
|
|
|
// RegisterEntityInTOTPEngine registers an entity with a methodID and returns
|
|
// the generated name.
|
|
func RegisterEntityInTOTPEngine(t testing.T, client *api.Client, entityID, methodID string) string {
|
|
t.Helper()
|
|
totpGenName := fmt.Sprintf("%s-%s", entityID, methodID)
|
|
secret, err := client.Logical().WriteWithContext(context.Background(), "identity/mfa/method/totp/admin-generate", map[string]interface{}{
|
|
"entity_id": entityID,
|
|
"method_id": methodID,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to generate a TOTP secret on an entity: %v", err)
|
|
}
|
|
totpURL := secret.Data["url"].(string)
|
|
if totpURL == "" {
|
|
t.Fatalf("failed to get TOTP url in secret response: %+v", secret)
|
|
}
|
|
_, err = client.Logical().WriteWithContext(context.Background(), fmt.Sprintf("totp/keys/%s", totpGenName), map[string]interface{}{
|
|
"url": totpURL,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to register a TOTP URL: %v", err)
|
|
}
|
|
enfPath := fmt.Sprintf("identity/mfa/login-enforcement/%s", methodID[0:4])
|
|
_, err = client.Logical().WriteWithContext(context.Background(), enfPath, map[string]interface{}{
|
|
"name": methodID[0:4],
|
|
"identity_entity_ids": []string{entityID},
|
|
"mfa_method_ids": []string{methodID},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to create login enforcement")
|
|
}
|
|
|
|
return totpGenName
|
|
}
|
|
|
|
// GetTOTPCodeFromEngine requests a TOTP code from the specified enginePath.
|
|
func GetTOTPCodeFromEngine(t testing.T, client *api.Client, enginePath string) string {
|
|
t.Helper()
|
|
totpPath := fmt.Sprintf("totp/code/%s", enginePath)
|
|
secret, err := client.Logical().ReadWithContext(context.Background(), totpPath)
|
|
if err != nil {
|
|
t.Fatalf("failed to create totp passcode: %v", err)
|
|
}
|
|
if secret == nil || secret.Data == nil {
|
|
t.Fatalf("bad secret returned from %s", totpPath)
|
|
}
|
|
return secret.Data["code"].(string)
|
|
}
|
|
|
|
// SetupLoginMFATOTP setups up a TOTP MFA using some basic configuration and
|
|
// returns all relevant information to the client.
|
|
func SetupLoginMFATOTP(t testing.T, client *api.Client, methodName string, waitPeriod int) (*api.Client, string, string) {
|
|
t.Helper()
|
|
// Mount the totp secrets engine
|
|
SetupTOTPMount(t, client)
|
|
|
|
// Create a mount accessor to associate with an entity
|
|
mountAccessor := SetupUserpassMountAccessor(t, client)
|
|
|
|
// Create a test entity and alias
|
|
entityClient, entityID, _ := CreateEntityAndAlias(t, client, mountAccessor, "entity1", "testuser1")
|
|
|
|
// Configure a default TOTP method
|
|
totpConfig := map[string]interface{}{
|
|
"issuer": "yCorp",
|
|
"period": waitPeriod,
|
|
"algorithm": "SHA256",
|
|
"digits": 6,
|
|
"skew": 1,
|
|
"key_size": 20,
|
|
"qr_size": 200,
|
|
"max_validation_attempts": 5,
|
|
"method_name": methodName,
|
|
}
|
|
methodID := SetupTOTPMethod(t, client, totpConfig)
|
|
|
|
// Configure a default login enforcement
|
|
enforcementConfig := map[string]interface{}{
|
|
"auth_method_types": []string{"userpass"},
|
|
"name": methodID[0:4],
|
|
"mfa_method_ids": []string{methodID},
|
|
}
|
|
|
|
SetupMFALoginEnforcement(t, client, enforcementConfig)
|
|
return entityClient, entityID, methodID
|
|
}
|
|
|
|
func SkipUnlessEnvVarsSet(t testing.T, envVars []string) {
|
|
t.Helper()
|
|
|
|
for _, i := range envVars {
|
|
if os.Getenv(i) == "" {
|
|
t.Skipf("%s must be set for this test to run", strings.Join(envVars, " "))
|
|
}
|
|
}
|
|
}
|