27bb03bbc0
* adding copyright header * fix fmt and a test
609 lines
16 KiB
Go
609 lines
16 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package http
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/http2"
|
|
|
|
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
|
"github.com/hashicorp/vault/api"
|
|
credCert "github.com/hashicorp/vault/builtin/credential/cert"
|
|
"github.com/hashicorp/vault/builtin/logical/transit"
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/hashicorp/vault/vault"
|
|
)
|
|
|
|
func TestHTTP_Fallback_Bad_Address(t *testing.T) {
|
|
coreConfig := &vault.CoreConfig{
|
|
LogicalBackends: map[string]logical.Factory{
|
|
"transit": transit.Factory,
|
|
},
|
|
ClusterAddr: "https://127.3.4.1:8382",
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
|
HandlerFunc: Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
cores := cluster.Cores
|
|
|
|
// make it easy to get access to the active
|
|
core := cores[0].Core
|
|
vault.TestWaitActive(t, core)
|
|
|
|
addrs := []string{
|
|
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
|
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
|
}
|
|
|
|
for _, addr := range addrs {
|
|
config := api.DefaultConfig()
|
|
config.Address = addr
|
|
config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig()
|
|
|
|
client, err := api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client.SetToken(cluster.RootToken)
|
|
|
|
secret, err := client.Auth().Token().LookupSelf()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret == nil {
|
|
t.Fatal("secret is nil")
|
|
}
|
|
if secret.Data["id"].(string) != cluster.RootToken {
|
|
t.Fatal("token mismatch")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHTTP_Fallback_Disabled(t *testing.T) {
|
|
coreConfig := &vault.CoreConfig{
|
|
LogicalBackends: map[string]logical.Factory{
|
|
"transit": transit.Factory,
|
|
},
|
|
ClusterAddr: "empty",
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
|
HandlerFunc: Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
cores := cluster.Cores
|
|
|
|
// make it easy to get access to the active
|
|
core := cores[0].Core
|
|
vault.TestWaitActive(t, core)
|
|
|
|
addrs := []string{
|
|
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
|
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
|
}
|
|
|
|
for _, addr := range addrs {
|
|
config := api.DefaultConfig()
|
|
config.Address = addr
|
|
config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig()
|
|
|
|
client, err := api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client.SetToken(cluster.RootToken)
|
|
|
|
secret, err := client.Auth().Token().LookupSelf()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret == nil {
|
|
t.Fatal("secret is nil")
|
|
}
|
|
if secret.Data["id"].(string) != cluster.RootToken {
|
|
t.Fatal("token mismatch")
|
|
}
|
|
}
|
|
}
|
|
|
|
// This function recreates the fuzzy testing from transit to pipe a large
|
|
// number of requests from the standbys to the active node.
|
|
func TestHTTP_Forwarding_Stress(t *testing.T) {
|
|
testHTTP_Forwarding_Stress_Common(t, false, 50)
|
|
testHTTP_Forwarding_Stress_Common(t, true, 50)
|
|
}
|
|
|
|
func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint32) {
|
|
testPlaintext := "the quick brown fox"
|
|
testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
|
|
|
coreConfig := &vault.CoreConfig{
|
|
LogicalBackends: map[string]logical.Factory{
|
|
"transit": transit.Factory,
|
|
},
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
|
HandlerFunc: Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
cores := cluster.Cores
|
|
|
|
// make it easy to get access to the active
|
|
core := cores[0].Core
|
|
vault.TestWaitActive(t, core)
|
|
|
|
wg := sync.WaitGroup{}
|
|
|
|
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
|
|
keys := []string{"test1", "test2", "test3"}
|
|
|
|
hosts := []string{
|
|
fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[1].Listeners[0].Address.Port),
|
|
fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[2].Listeners[0].Address.Port),
|
|
}
|
|
|
|
transport := &http.Transport{
|
|
TLSClientConfig: cores[0].TLSConfig(),
|
|
}
|
|
if err := http2.ConfigureTransport(transport); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
CheckRedirect: func(*http.Request, []*http.Request) error {
|
|
return fmt.Errorf("redirects not allowed in this test")
|
|
},
|
|
}
|
|
|
|
// core.Logger().Printf("[TRACE] mounting transit")
|
|
req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/mounts/transit", cores[0].Listeners[0].Address.Port),
|
|
bytes.NewBuffer([]byte("{\"type\": \"transit\"}")))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set(consts.AuthHeaderName, cluster.RootToken)
|
|
_, err = client.Do(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// core.Logger().Printf("[TRACE] done mounting transit")
|
|
|
|
var totalOps *uint32 = new(uint32)
|
|
var successfulOps *uint32 = new(uint32)
|
|
var key1ver *int32 = new(int32)
|
|
*key1ver = 1
|
|
var key2ver *int32 = new(int32)
|
|
*key2ver = 1
|
|
var key3ver *int32 = new(int32)
|
|
*key3ver = 1
|
|
var numWorkers *uint32 = new(uint32)
|
|
*numWorkers = 50
|
|
var numWorkersStarted *uint32 = new(uint32)
|
|
var waitLock sync.Mutex
|
|
waitCond := sync.NewCond(&waitLock)
|
|
|
|
// This is the goroutine loop
|
|
doFuzzy := func(id int, parallel bool) {
|
|
var myTotalOps uint32
|
|
var mySuccessfulOps uint32
|
|
var keyVer int32 = 1
|
|
// Check for panics, otherwise notify we're done
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
core.Logger().Error("got a panic", "error", err)
|
|
t.Fail()
|
|
}
|
|
atomic.AddUint32(totalOps, myTotalOps)
|
|
atomic.AddUint32(successfulOps, mySuccessfulOps)
|
|
wg.Done()
|
|
}()
|
|
|
|
// Holds the latest encrypted value for each key
|
|
latestEncryptedText := map[string]string{}
|
|
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
}
|
|
|
|
var chosenFunc, chosenKey, chosenHost string
|
|
|
|
myRand := rand.New(rand.NewSource(int64(id) * 400))
|
|
|
|
doReq := func(method, url string, body io.Reader) (*http.Response, error) {
|
|
req, err := http.NewRequest(method, url, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set(consts.AuthHeaderName, cluster.RootToken)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
doResp := func(resp *http.Response) (*api.Secret, error) {
|
|
if resp == nil {
|
|
return nil, fmt.Errorf("nil response")
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Make sure we weren't redirected
|
|
if resp.StatusCode > 300 && resp.StatusCode < 400 {
|
|
return nil, fmt.Errorf("got status code %d, resp was %#v", resp.StatusCode, *resp)
|
|
}
|
|
|
|
result := &api.Response{Response: resp}
|
|
err := result.Error()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
secret, err := api.ParseSecret(result.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return secret, nil
|
|
}
|
|
|
|
for _, chosenHost := range hosts {
|
|
for _, chosenKey := range keys {
|
|
// Try to write the key to make sure it exists
|
|
_, err := doReq("POST", chosenHost+"keys/"+fmt.Sprintf("%s-%t", chosenKey, parallel), bytes.NewBuffer([]byte("{}")))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if !parallel {
|
|
chosenHost = hosts[id%len(hosts)]
|
|
chosenKey = fmt.Sprintf("key-%t-%d", parallel, id)
|
|
|
|
_, err := doReq("POST", chosenHost+"keys/"+chosenKey, bytes.NewBuffer([]byte("{}")))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
atomic.AddUint32(numWorkersStarted, 1)
|
|
|
|
waitCond.L.Lock()
|
|
for atomic.LoadUint32(numWorkersStarted) != atomic.LoadUint32(numWorkers) {
|
|
waitCond.Wait()
|
|
}
|
|
waitCond.L.Unlock()
|
|
waitCond.Broadcast()
|
|
|
|
core.Logger().Debug("Starting goroutine", "id", id)
|
|
|
|
startTime := time.Now()
|
|
for {
|
|
// Stop after 10 seconds
|
|
if time.Now().Sub(startTime) > 10*time.Second {
|
|
return
|
|
}
|
|
|
|
myTotalOps++
|
|
|
|
// Pick a function and a key
|
|
chosenFunc = funcs[myRand.Int()%len(funcs)]
|
|
if parallel {
|
|
chosenKey = fmt.Sprintf("%s-%t", keys[myRand.Int()%len(keys)], parallel)
|
|
chosenHost = hosts[myRand.Int()%len(hosts)]
|
|
}
|
|
|
|
switch chosenFunc {
|
|
// Encrypt our plaintext and store the result
|
|
case "encrypt":
|
|
// core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
|
resp, err := doReq("POST", chosenHost+"encrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"plaintext\": \"%s\"}", testPlaintextB64))))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
secret, err := doResp(resp)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
latest := secret.Data["ciphertext"].(string)
|
|
if latest == "" {
|
|
panic(fmt.Errorf("bad ciphertext"))
|
|
}
|
|
latestEncryptedText[chosenKey] = secret.Data["ciphertext"].(string)
|
|
|
|
mySuccessfulOps++
|
|
|
|
// Decrypt the ciphertext and compare the result
|
|
case "decrypt":
|
|
ct := latestEncryptedText[chosenKey]
|
|
if ct == "" {
|
|
mySuccessfulOps++
|
|
continue
|
|
}
|
|
|
|
// core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
|
resp, err := doReq("POST", chosenHost+"decrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"ciphertext\": \"%s\"}", ct))))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
secret, err := doResp(resp)
|
|
if err != nil {
|
|
// This could well happen since the min version is jumping around
|
|
if strings.Contains(err.Error(), keysutil.ErrTooOld) {
|
|
mySuccessfulOps++
|
|
continue
|
|
}
|
|
panic(err)
|
|
}
|
|
|
|
ptb64 := secret.Data["plaintext"].(string)
|
|
pt, err := base64.StdEncoding.DecodeString(ptb64)
|
|
if err != nil {
|
|
panic(fmt.Errorf("got an error decoding base64 plaintext: %v", err))
|
|
}
|
|
if string(pt) != testPlaintext {
|
|
panic(fmt.Errorf("got bad plaintext back: %s", pt))
|
|
}
|
|
|
|
mySuccessfulOps++
|
|
|
|
// Rotate to a new key version
|
|
case "rotate":
|
|
// core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
|
_, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/rotate", bytes.NewBuffer([]byte("{}")))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if parallel {
|
|
switch chosenKey {
|
|
case "test1":
|
|
atomic.AddInt32(key1ver, 1)
|
|
case "test2":
|
|
atomic.AddInt32(key2ver, 1)
|
|
case "test3":
|
|
atomic.AddInt32(key3ver, 1)
|
|
}
|
|
} else {
|
|
keyVer++
|
|
}
|
|
|
|
mySuccessfulOps++
|
|
|
|
// Change the min version, which also tests the archive functionality
|
|
case "change_min_version":
|
|
var latestVersion int32 = keyVer
|
|
if parallel {
|
|
switch chosenKey {
|
|
case "test1":
|
|
latestVersion = atomic.LoadInt32(key1ver)
|
|
case "test2":
|
|
latestVersion = atomic.LoadInt32(key2ver)
|
|
case "test3":
|
|
latestVersion = atomic.LoadInt32(key3ver)
|
|
}
|
|
}
|
|
|
|
setVersion := (myRand.Int31() % latestVersion) + 1
|
|
|
|
// core.Logger().Printf("[TRACE] %s, %s, %d, new min version %d", chosenFunc, chosenKey, id, setVersion)
|
|
|
|
_, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/config", bytes.NewBuffer([]byte(fmt.Sprintf("{\"min_decryption_version\": %d}", setVersion))))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
mySuccessfulOps++
|
|
}
|
|
}
|
|
}
|
|
|
|
atomic.StoreUint32(numWorkers, num)
|
|
|
|
// Spawn some of these workers for 10 seconds
|
|
for i := 0; i < int(atomic.LoadUint32(numWorkers)); i++ {
|
|
wg.Add(1)
|
|
// core.Logger().Printf("[TRACE] spawning %d", i)
|
|
go doFuzzy(i+1, parallel)
|
|
}
|
|
|
|
// Wait for them all to finish
|
|
wg.Wait()
|
|
|
|
if *totalOps == 0 || *totalOps != *successfulOps {
|
|
t.Fatalf("total/successful ops zero or mismatch: %d/%d; parallel: %t, num %d", *totalOps, *successfulOps, parallel, num)
|
|
}
|
|
t.Logf("total operations tried: %d, total successful: %d; parallel: %t, num %d", *totalOps, *successfulOps, parallel, num)
|
|
}
|
|
|
|
// This tests TLS connection state forwarding by ensuring that we can use a
|
|
// client TLS to authenticate against the cert backend
|
|
func TestHTTP_Forwarding_ClientTLS(t *testing.T) {
|
|
coreConfig := &vault.CoreConfig{
|
|
CredentialBackends: map[string]logical.Factory{
|
|
"cert": credCert.Factory,
|
|
},
|
|
}
|
|
|
|
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
|
HandlerFunc: Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
cores := cluster.Cores
|
|
|
|
// make it easy to get access to the active
|
|
core := cores[0].Core
|
|
vault.TestWaitActive(t, core)
|
|
|
|
transport := cleanhttp.DefaultTransport()
|
|
transport.TLSClientConfig = cores[0].TLSConfig()
|
|
if err := http2.ConfigureTransport(transport); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/auth/cert", cores[0].Listeners[0].Address.Port),
|
|
bytes.NewBuffer([]byte("{\"type\": \"cert\"}")))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set(consts.AuthHeaderName, cluster.RootToken)
|
|
_, err = client.Do(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
type certConfig struct {
|
|
Certificate string `json:"certificate"`
|
|
Policies string `json:"policies"`
|
|
}
|
|
encodedCertConfig, err := json.Marshal(&certConfig{
|
|
Certificate: string(cluster.CACertPEM),
|
|
Policies: "default",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req, err = http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/auth/cert/certs/test", cores[0].Listeners[0].Address.Port),
|
|
bytes.NewBuffer(encodedCertConfig))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set(consts.AuthHeaderName, cluster.RootToken)
|
|
_, err = client.Do(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
addrs := []string{
|
|
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
|
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
|
}
|
|
|
|
for i, addr := range addrs {
|
|
// Ensure we can't possibly use lingering connections even though it should
|
|
// be to a different address
|
|
transport = cleanhttp.DefaultTransport()
|
|
// i starts at zero but cores in addrs start at 1
|
|
transport.TLSClientConfig = cores[i+1].TLSConfig()
|
|
if err := http2.ConfigureTransport(transport); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: transport,
|
|
CheckRedirect: func(*http.Request, []*http.Request) error {
|
|
return fmt.Errorf("redirects not allowed in this test")
|
|
},
|
|
}
|
|
client, err := api.NewClient(&api.Config{
|
|
Address: addr,
|
|
HttpClient: httpClient,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
secret, err := client.Logical().Write("auth/cert/login", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret == nil {
|
|
t.Fatal("secret is nil")
|
|
}
|
|
if secret.Auth == nil {
|
|
t.Fatal("auth is nil")
|
|
}
|
|
if secret.Auth.Policies == nil || len(secret.Auth.Policies) == 0 || secret.Auth.Policies[0] != "default" {
|
|
t.Fatalf("bad policies: %#v", secret.Auth.Policies)
|
|
}
|
|
if secret.Auth.ClientToken == "" {
|
|
t.Fatalf("bad client token: %#v", *secret.Auth)
|
|
}
|
|
client.SetToken(secret.Auth.ClientToken)
|
|
secret, err = client.Auth().Token().LookupSelf()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret == nil {
|
|
t.Fatal("secret is nil")
|
|
}
|
|
if secret.Data == nil || len(secret.Data) == 0 {
|
|
t.Fatal("secret data was empty")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHTTP_Forwarding_HelpOperation(t *testing.T) {
|
|
cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{
|
|
HandlerFunc: Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
cores := cluster.Cores
|
|
|
|
vault.TestWaitActive(t, cores[0].Core)
|
|
|
|
testHelp := func(client *api.Client) {
|
|
help, err := client.Help("auth/token")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if help == nil {
|
|
t.Fatal("help was nil")
|
|
}
|
|
}
|
|
|
|
testHelp(cores[0].Client)
|
|
testHelp(cores[1].Client)
|
|
}
|
|
|
|
func TestHTTP_Forwarding_LocalOnly(t *testing.T) {
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
cores := cluster.Cores
|
|
|
|
vault.TestWaitActive(t, cores[0].Core)
|
|
|
|
testLocalOnly := func(client *api.Client) {
|
|
_, err := client.Logical().Read("sys/config/state/sanitized")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
}
|
|
|
|
testLocalOnly(cores[1].Client)
|
|
testLocalOnly(cores[2].Client)
|
|
}
|