// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package testcluster

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/sdk/helper/xor"
)

// SealNode seals the node at nodeIdx and waits for it to report itself
// sealed. Note that OSS standbys will not accept seal requests, and ent perf
// standbys may fail them as well if they haven't yet been able to get
// "elected" as perf standbys.
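//
// For example, a test will typically seal the active node (illustrative
// sketch; ctx, cluster, and t come from the caller's test harness):
//
//	leaderIdx, err := LeaderNode(ctx, cluster)
//	if err != nil {
//		t.Fatal(err)
//	}
//	if err := SealNode(ctx, cluster, leaderIdx); err != nil {
//		t.Fatal(err)
//	}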
func SealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	err := client.Sys().SealWithContext(ctx)
	if err != nil {
		return err
	}

	return NodeSealed(ctx, cluster, nodeIdx)
}
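
// SealAllNodes seals every node in the cluster in turn, waiting for each node
// to report itself sealed before moving on to the next.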
func SealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := SealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}
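
// UnsealNode submits the cluster's barrier (or recovery) key shares to the
// node at nodeIdx and then waits for that node to report itself healthy.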
func UnsealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	for _, key := range cluster.GetBarrierOrRecoveryKeys() {
		_, err := client.Sys().UnsealWithContext(ctx, hex.EncodeToString(key))
		if err != nil {
			return err
		}
	}

	return NodeHealthy(ctx, cluster, nodeIdx)
}
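
// UnsealAllNodes unseals every node in the cluster in turn.
//
// Together with SealAllNodes this supports a full seal/unseal cycle in tests
// (illustrative sketch; ctx, cluster, and t come from the caller's test
// harness, and the standby caveat noted on SealNode applies):
//
//	if err := SealAllNodes(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}
//	if err := UnsealAllNodes(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}
//	if _, err := WaitForActiveNode(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}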
func UnsealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := UnsealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}
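
// NodeSealed polls the node at nodeIdx until it reports itself sealed, or
// until ctx is done.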
func NodeSealed(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case !health.Sealed:
			err = fmt.Errorf("unsealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is not sealed: %v", nodeIdx, err)
}
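
// WaitForNCoresSealed polls every node in the cluster and returns nil once at
// least n of them report themselves sealed; otherwise it returns an error
// aggregating the per-node failures.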
func WaitForNCoresSealed(ctx context.Context, cluster VaultCluster, n int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	errs := make(chan error)
	for i := range cluster.Nodes() {
		go func(i int) {
			var err error
			for ctx.Err() == nil {
				err = NodeSealed(ctx, cluster, i)
				if err == nil {
					errs <- nil
					return
				}
				time.Sleep(100 * time.Millisecond)
			}
			if err == nil {
				err = ctx.Err()
			}
			errs <- err
		}(i)
	}

	var merr *multierror.Error
	var sealed int
	for range cluster.Nodes() {
		err := <-errs
		if err != nil {
			merr = multierror.Append(merr, err)
		} else {
			sealed++
			if sealed == n {
				return nil
			}
		}
	}

	return fmt.Errorf("%d cores were not sealed, errs: %v", n, merr.ErrorOrNil())
}
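
// NodeHealthy polls the node at nodeIdx until it answers a health check and
// reports itself unsealed, or until ctx is done.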
func NodeHealthy(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case health == nil:
			err = fmt.Errorf("nil response to health check")
		case health.Sealed:
			err = fmt.Errorf("sealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is unhealthy: %v", nodeIdx, err)
}
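
// LeaderNode returns the index of the node that currently reports itself as
// the active (leader) node, or an error if no node does.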
func LeaderNode(ctx context.Context, cluster VaultCluster) (int, error) {
	for i, node := range cluster.Nodes() {
		client := node.APIClient()
		ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
		resp, err := client.Sys().LeaderWithContext(ctx)
		cancel()
		if err != nil || resp == nil || !resp.IsSelf {
			continue
		}
		return i, nil
	}
	return -1, fmt.Errorf("no leader found")
}
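
// WaitForActiveNode polls the cluster until some node reports itself as the
// active node and returns that node's index, or returns ctx.Err() once ctx
// is done.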
func WaitForActiveNode(ctx context.Context, cluster VaultCluster) (int, error) {
	for ctx.Err() == nil {
		if idx, _ := LeaderNode(ctx, cluster); idx != -1 {
			return idx, nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return -1, ctx.Err()
}
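
// WaitForActiveNodeAndPerfStandbys waits until the cluster has exactly one
// active node and len(Nodes())-1 performance standbys that are receiving WALs
// from it. It verifies this by mounting a temporary local KV engine on the
// first node, writing to it repeatedly, and watching each node's leader
// status until every node reports the expected role.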
func WaitForActiveNodeAndPerfStandbys(ctx context.Context, cluster VaultCluster) error {
	logger := cluster.NamedLogger("WaitForActiveNodeAndPerfStandbys")
	// This WaitForActiveNode was added because after a Raft cluster is sealed
	// and then unsealed, when it comes up it may have a different leader than
	// Core0, making this helper fail.
	// A sleep before calling WaitForActiveNodeAndPerfStandbys seems to sort
	// things out, but so apparently does this. We should be able to eliminate
	// this call to WaitForActiveNode by reworking the logic in this method.
	if _, err := WaitForActiveNode(ctx, cluster); err != nil {
		return err
	}

	if len(cluster.Nodes()) == 1 {
		return nil
	}

	expectedStandbys := len(cluster.Nodes()) - 1

	mountPoint, err := uuid.GenerateUUID()
	if err != nil {
		return err
	}
	leaderClient := cluster.Nodes()[0].APIClient()

	for ctx.Err() == nil {
		err = leaderClient.Sys().MountWithContext(ctx, mountPoint, &api.MountInput{
			Type:  "kv",
			Local: true,
		})
		if err == nil {
			break
		}
		time.Sleep(1 * time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to mount KV engine: %v", err)
	}
	path := mountPoint + "/waitforactivenodeandperfstandbys"
	var standbys, actives int64
	errchan := make(chan error, len(cluster.Nodes()))
	for i := range cluster.Nodes() {
		go func(coreNo int) {
			node := cluster.Nodes()[coreNo]
			client := node.APIClient()
			val := 1
			var err error
			defer func() {
				errchan <- err
			}()

			var lastWAL uint64
			for ctx.Err() == nil {
				_, err = leaderClient.Logical().WriteWithContext(ctx, path, map[string]interface{}{
					"bar": val,
				})
				val++
				time.Sleep(250 * time.Millisecond)
				if err != nil {
					continue
				}
				var leader *api.LeaderResponse
				leader, err = client.Sys().LeaderWithContext(ctx)
				if err != nil {
					continue
				}
				switch {
				case leader.IsSelf:
					logger.Trace("waiting for core", "core", coreNo, "isLeader", true)
					atomic.AddInt64(&actives, 1)
					return
				case leader.PerfStandby && leader.PerfStandbyLastRemoteWAL > 0:
					switch {
					case lastWAL == 0:
						lastWAL = leader.PerfStandbyLastRemoteWAL
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
					case lastWAL < leader.PerfStandbyLastRemoteWAL:
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
						atomic.AddInt64(&standbys, 1)
						return
					}
				}
			}
		}(i)
	}

	errs := make([]error, 0, len(cluster.Nodes()))
	for range cluster.Nodes() {
		errs = append(errs, <-errchan)
	}
	if actives != 1 || int(standbys) != expectedStandbys {
		return fmt.Errorf("expected 1 active core and %d standbys, got %d active and %d standbys, errs: %v",
			expectedStandbys, actives, standbys, errs)
	}

	for ctx.Err() == nil {
		err = leaderClient.Sys().UnmountWithContext(ctx, mountPoint)
		if err == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to unmount KV engine on primary: %v", err)
	}
	return nil
}
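
// GenerateRootKind selects which kind of token GenerateRoot produces: a
// regular root token, a DR operation token, or a recovery operation token.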
type GenerateRootKind int

const (
	GenerateRootRegular GenerateRootKind = iota
	GenerateRootDR
	GenerateRecovery
)
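
// GenerateRoot drives the generate-root workflow against the cluster's first
// node, feeding in the barrier or recovery key shares and decoding the
// resulting token with the operation's OTP. For example (illustrative sketch;
// cluster and t come from the caller's test harness):
//
//	rootToken, err := GenerateRoot(cluster, GenerateRootRegular)
//	if err != nil {
//		t.Fatal(err)
//	}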
func GenerateRoot(cluster VaultCluster, kind GenerateRootKind) (string, error) {
	// If recovery keys are supported, use those to perform root token generation instead.
	keys := cluster.GetBarrierOrRecoveryKeys()

	client := cluster.Nodes()[0].APIClient()

	var err error
	var status *api.GenerateRootStatusResponse
	switch kind {
	case GenerateRootRegular:
		status, err = client.Sys().GenerateRootInit("", "")
	case GenerateRootDR:
		status, err = client.Sys().GenerateDROperationTokenInit("", "")
	case GenerateRecovery:
		status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
	}
	if err != nil {
		return "", err
	}
	if status.Required > len(keys) {
		return "", fmt.Errorf("not enough keys: need %d, have %d", status.Required, len(keys))
	}

	otp := status.OTP

	for i, key := range keys {
		if i >= status.Required {
			break
		}

		strKey := base64.StdEncoding.EncodeToString(key)
		switch kind {
		case GenerateRootRegular:
			status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
		case GenerateRootDR:
			status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
		case GenerateRecovery:
			status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
		}
		if err != nil {
			return "", err
		}
	}
	if !status.Complete {
		return "", fmt.Errorf("generate root operation did not end successfully")
	}

	tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
	if err != nil {
		return "", err
	}
	tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
	if err != nil {
		return "", err
	}
	return string(tokenBytes), nil
}