// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package http import ( "bytes" "encoding/base64" "encoding/json" "fmt" "io" "math/rand" "net/http" "strings" "sync" "sync/atomic" "testing" "time" "golang.org/x/net/http2" cleanhttp "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/vault/api" credCert "github.com/hashicorp/vault/builtin/credential/cert" "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" ) func TestHTTP_Fallback_Bad_Address(t *testing.T) { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, }, ClusterAddr: "https://127.3.4.1:8382", } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) addrs := []string{ fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port), fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port), } for _, addr := range addrs { config := api.DefaultConfig() config.Address = addr config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig() client, err := api.NewClient(config) if err != nil { t.Fatal(err) } client.SetToken(cluster.RootToken) secret, err := client.Auth().Token().LookupSelf() if err != nil { t.Fatal(err) } if secret == nil { t.Fatal("secret is nil") } if secret.Data["id"].(string) != cluster.RootToken { t.Fatal("token mismatch") } } } func TestHTTP_Fallback_Disabled(t *testing.T) { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, }, ClusterAddr: "empty", } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) addrs := []string{ fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port), fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port), } for _, addr := range addrs { config := api.DefaultConfig() config.Address = addr config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig() client, err := api.NewClient(config) if err != nil { t.Fatal(err) } client.SetToken(cluster.RootToken) secret, err := client.Auth().Token().LookupSelf() if err != nil { t.Fatal(err) } if secret == nil { t.Fatal("secret is nil") } if secret.Data["id"].(string) != cluster.RootToken { t.Fatal("token mismatch") } } } // This function recreates the fuzzy testing from transit to pipe a large // number of requests from the standbys to the active node. func TestHTTP_Forwarding_Stress(t *testing.T) { testHTTP_Forwarding_Stress_Common(t, false, 50) testHTTP_Forwarding_Stress_Common(t, true, 50) } func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint32) { testPlaintext := "the quick brown fox" testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA==" coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, }, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) wg := sync.WaitGroup{} funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"} keys := []string{"test1", "test2", "test3"} hosts := []string{ fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[1].Listeners[0].Address.Port), fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[2].Listeners[0].Address.Port), } transport := &http.Transport{ TLSClientConfig: cores[0].TLSConfig(), } if err := http2.ConfigureTransport(transport); err != nil { t.Fatal(err) } client := &http.Client{ Transport: transport, CheckRedirect: func(*http.Request, []*http.Request) error { return fmt.Errorf("redirects not allowed in this test") }, } // core.Logger().Printf("[TRACE] mounting transit") req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/mounts/transit", cores[0].Listeners[0].Address.Port), bytes.NewBuffer([]byte("{\"type\": \"transit\"}"))) if err != nil { t.Fatal(err) } req.Header.Set(consts.AuthHeaderName, cluster.RootToken) _, err = client.Do(req) if err != nil { t.Fatal(err) } // core.Logger().Printf("[TRACE] done mounting transit") var totalOps *uint32 = new(uint32) var successfulOps *uint32 = new(uint32) var key1ver *int32 = new(int32) *key1ver = 1 var key2ver *int32 = new(int32) *key2ver = 1 var key3ver *int32 = new(int32) *key3ver = 1 var numWorkers *uint32 = new(uint32) *numWorkers = 50 var numWorkersStarted *uint32 = new(uint32) var waitLock sync.Mutex waitCond := sync.NewCond(&waitLock) // This is the goroutine loop doFuzzy := func(id int, parallel bool) { var myTotalOps uint32 var mySuccessfulOps uint32 var keyVer int32 = 1 // Check for panics, otherwise notify we're done defer func() { if err := recover(); err != nil { core.Logger().Error("got a panic", "error", err) t.Fail() } atomic.AddUint32(totalOps, myTotalOps) atomic.AddUint32(successfulOps, mySuccessfulOps) wg.Done() }() // Holds the latest encrypted value for each key latestEncryptedText := map[string]string{} client := &http.Client{ Transport: transport, } var chosenFunc, chosenKey, chosenHost string myRand := rand.New(rand.NewSource(int64(id) * 400)) doReq := func(method, url string, body io.Reader) (*http.Response, error) { req, err := http.NewRequest(method, url, body) if err != nil { return nil, err } req.Header.Set(consts.AuthHeaderName, cluster.RootToken) resp, err := client.Do(req) if err != nil { return nil, err } return resp, nil } doResp := func(resp *http.Response) (*api.Secret, error) { if resp == nil { return nil, fmt.Errorf("nil response") } defer resp.Body.Close() // Make sure we weren't redirected if resp.StatusCode > 300 && resp.StatusCode < 400 { return nil, fmt.Errorf("got status code %d, resp was %#v", resp.StatusCode, *resp) } result := &api.Response{Response: resp} err := result.Error() if err != nil { return nil, err } secret, err := api.ParseSecret(result.Body) if err != nil { return nil, err } return secret, nil } for _, chosenHost := range hosts { for _, chosenKey := range keys { // Try to write the key to make sure it exists _, err := doReq("POST", chosenHost+"keys/"+fmt.Sprintf("%s-%t", chosenKey, parallel), bytes.NewBuffer([]byte("{}"))) if err != nil { panic(err) } } } if !parallel { chosenHost = hosts[id%len(hosts)] chosenKey = fmt.Sprintf("key-%t-%d", parallel, id) _, err := doReq("POST", chosenHost+"keys/"+chosenKey, bytes.NewBuffer([]byte("{}"))) if err != nil { panic(err) } } atomic.AddUint32(numWorkersStarted, 1) waitCond.L.Lock() for atomic.LoadUint32(numWorkersStarted) != atomic.LoadUint32(numWorkers) { waitCond.Wait() } waitCond.L.Unlock() waitCond.Broadcast() core.Logger().Debug("Starting goroutine", "id", id) startTime := time.Now() for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { return } myTotalOps++ // Pick a function and a key chosenFunc = funcs[myRand.Int()%len(funcs)] if parallel { chosenKey = fmt.Sprintf("%s-%t", keys[myRand.Int()%len(keys)], parallel) chosenHost = hosts[myRand.Int()%len(hosts)] } switch chosenFunc { // Encrypt our plaintext and store the result case "encrypt": // core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id) resp, err := doReq("POST", chosenHost+"encrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"plaintext\": \"%s\"}", testPlaintextB64)))) if err != nil { panic(err) } secret, err := doResp(resp) if err != nil { panic(err) } latest := secret.Data["ciphertext"].(string) if latest == "" { panic(fmt.Errorf("bad ciphertext")) } latestEncryptedText[chosenKey] = secret.Data["ciphertext"].(string) mySuccessfulOps++ // Decrypt the ciphertext and compare the result case "decrypt": ct := latestEncryptedText[chosenKey] if ct == "" { mySuccessfulOps++ continue } // core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id) resp, err := doReq("POST", chosenHost+"decrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"ciphertext\": \"%s\"}", ct)))) if err != nil { panic(err) } secret, err := doResp(resp) if err != nil { // This could well happen since the min version is jumping around if strings.Contains(err.Error(), keysutil.ErrTooOld) { mySuccessfulOps++ continue } panic(err) } ptb64 := secret.Data["plaintext"].(string) pt, err := base64.StdEncoding.DecodeString(ptb64) if err != nil { panic(fmt.Errorf("got an error decoding base64 plaintext: %v", err)) } if string(pt) != testPlaintext { panic(fmt.Errorf("got bad plaintext back: %s", pt)) } mySuccessfulOps++ // Rotate to a new key version case "rotate": // core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id) _, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/rotate", bytes.NewBuffer([]byte("{}"))) if err != nil { panic(err) } if parallel { switch chosenKey { case "test1": atomic.AddInt32(key1ver, 1) case "test2": atomic.AddInt32(key2ver, 1) case "test3": atomic.AddInt32(key3ver, 1) } } else { keyVer++ } mySuccessfulOps++ // Change the min version, which also tests the archive functionality case "change_min_version": var latestVersion int32 = keyVer if parallel { switch chosenKey { case "test1": latestVersion = atomic.LoadInt32(key1ver) case "test2": latestVersion = atomic.LoadInt32(key2ver) case "test3": latestVersion = atomic.LoadInt32(key3ver) } } setVersion := (myRand.Int31() % latestVersion) + 1 // core.Logger().Printf("[TRACE] %s, %s, %d, new min version %d", chosenFunc, chosenKey, id, setVersion) _, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/config", bytes.NewBuffer([]byte(fmt.Sprintf("{\"min_decryption_version\": %d}", setVersion)))) if err != nil { panic(err) } mySuccessfulOps++ } } } atomic.StoreUint32(numWorkers, num) // Spawn some of these workers for 10 seconds for i := 0; i < int(atomic.LoadUint32(numWorkers)); i++ { wg.Add(1) // core.Logger().Printf("[TRACE] spawning %d", i) go doFuzzy(i+1, parallel) } // Wait for them all to finish wg.Wait() if *totalOps == 0 || *totalOps != *successfulOps { t.Fatalf("total/successful ops zero or mismatch: %d/%d; parallel: %t, num %d", *totalOps, *successfulOps, parallel, num) } t.Logf("total operations tried: %d, total successful: %d; parallel: %t, num %d", *totalOps, *successfulOps, parallel, num) } // This tests TLS connection state forwarding by ensuring that we can use a // client TLS to authenticate against the cert backend func TestHTTP_Forwarding_ClientTLS(t *testing.T) { coreConfig := &vault.CoreConfig{ CredentialBackends: map[string]logical.Factory{ "cert": credCert.Factory, }, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) transport := cleanhttp.DefaultTransport() transport.TLSClientConfig = cores[0].TLSConfig() if err := http2.ConfigureTransport(transport); err != nil { t.Fatal(err) } client := &http.Client{ Transport: transport, } req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/auth/cert", cores[0].Listeners[0].Address.Port), bytes.NewBuffer([]byte("{\"type\": \"cert\"}"))) if err != nil { t.Fatal(err) } req.Header.Set(consts.AuthHeaderName, cluster.RootToken) _, err = client.Do(req) if err != nil { t.Fatal(err) } type certConfig struct { Certificate string `json:"certificate"` Policies string `json:"policies"` } encodedCertConfig, err := json.Marshal(&certConfig{ Certificate: string(cluster.CACertPEM), Policies: "default", }) if err != nil { t.Fatal(err) } req, err = http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/auth/cert/certs/test", cores[0].Listeners[0].Address.Port), bytes.NewBuffer(encodedCertConfig)) if err != nil { t.Fatal(err) } req.Header.Set(consts.AuthHeaderName, cluster.RootToken) _, err = client.Do(req) if err != nil { t.Fatal(err) } addrs := []string{ fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port), fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port), } for i, addr := range addrs { // Ensure we can't possibly use lingering connections even though it should // be to a different address transport = cleanhttp.DefaultTransport() // i starts at zero but cores in addrs start at 1 transport.TLSClientConfig = cores[i+1].TLSConfig() if err := http2.ConfigureTransport(transport); err != nil { t.Fatal(err) } httpClient := &http.Client{ Transport: transport, CheckRedirect: func(*http.Request, []*http.Request) error { return fmt.Errorf("redirects not allowed in this test") }, } client, err := api.NewClient(&api.Config{ Address: addr, HttpClient: httpClient, }) if err != nil { t.Fatal(err) } secret, err := client.Logical().Write("auth/cert/login", nil) if err != nil { t.Fatal(err) } if secret == nil { t.Fatal("secret is nil") } if secret.Auth == nil { t.Fatal("auth is nil") } if secret.Auth.Policies == nil || len(secret.Auth.Policies) == 0 || secret.Auth.Policies[0] != "default" { t.Fatalf("bad policies: %#v", secret.Auth.Policies) } if secret.Auth.ClientToken == "" { t.Fatalf("bad client token: %#v", *secret.Auth) } client.SetToken(secret.Auth.ClientToken) secret, err = client.Auth().Token().LookupSelf() if err != nil { t.Fatal(err) } if secret == nil { t.Fatal("secret is nil") } if secret.Data == nil || len(secret.Data) == 0 { t.Fatal("secret data was empty") } } } func TestHTTP_Forwarding_HelpOperation(t *testing.T) { cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores vault.TestWaitActive(t, cores[0].Core) testHelp := func(client *api.Client) { help, err := client.Help("auth/token") if err != nil { t.Fatal(err) } if help == nil { t.Fatal("help was nil") } } testHelp(cores[0].Client) testHelp(cores[1].Client) } func TestHTTP_Forwarding_LocalOnly(t *testing.T) { cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ HandlerFunc: Handler, }) cluster.Start() defer cluster.Cleanup() cores := cluster.Cores vault.TestWaitActive(t, cores[0].Core) testLocalOnly := func(client *api.Client) { _, err := client.Logical().Read("sys/config/state/sanitized") if err == nil { t.Fatal("expected error") } } testLocalOnly(cores[1].Client) testLocalOnly(cores[2].Client) }