open-vault/api/client_test.go

1439 lines
35 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package api
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"os"
"reflect"
"sort"
"strings"
"sync"
"testing"
"time"
"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
)
// init blanks out the Vault address and token environment variables so that
// the caller's shell configuration cannot influence the tests in this file.
func init() {
	for _, name := range []string{"VAULT_ADDR", "VAULT_TOKEN"} {
		os.Setenv(name, "")
	}
}
// TestDefaultConfig_envvar checks that DefaultConfig picks up VAULT_ADDR and
// that NewClient picks up VAULT_TOKEN from the environment.
func TestDefaultConfig_envvar(t *testing.T) {
	const (
		wantAddr  = "https://vault.mycompany.com"
		wantToken = "testing"
	)

	os.Setenv("VAULT_ADDR", wantAddr)
	defer os.Setenv("VAULT_ADDR", "")

	cfg := DefaultConfig()
	if got := cfg.Address; got != wantAddr {
		t.Fatalf("bad: %s", got)
	}

	os.Setenv("VAULT_TOKEN", wantToken)
	defer os.Setenv("VAULT_TOKEN", "")

	c, err := NewClient(cfg)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if got := c.Token(); got != wantToken {
		t.Fatalf("bad: %s", got)
	}
}
func TestClientDefaultHttpClient(t *testing.T) {
_, err := NewClient(&Config{
HttpClient: http.DefaultClient,
})
if err != nil {
t.Fatal(err)
}
}
// TestClientNilConfig checks that NewClient(nil) succeeds and returns a
// usable (non-nil) client.
func TestClientNilConfig(t *testing.T) {
	c, err := NewClient(nil)
	switch {
	case err != nil:
		t.Fatal(err)
	case c == nil:
		t.Fatal("expected a non-nil client")
	}
}
// TestClientDefaultHttpClient_unixSocket checks that a unix:// VAULT_AGENT_ADDR
// is turned into an http address whose host is the socket path.
func TestClientDefaultHttpClient_unixSocket(t *testing.T) {
	const agentAddrVar = "VAULT_AGENT_ADDR"
	os.Setenv(agentAddrVar, "unix:///var/run/vault.sock")
	defer os.Setenv(agentAddrVar, "")

	c, err := NewClient(nil)
	if err != nil {
		t.Fatal(err)
	}
	if c == nil {
		t.Fatal("expected a non-nil client")
	}
	if scheme := c.addr.Scheme; scheme != "http" {
		t.Fatalf("bad: %s", scheme)
	}
	if host := c.addr.Host; host != "/var/run/vault.sock" {
		t.Fatalf("bad: %s", host)
	}
}
// TestClientSetAddress switches a client between a TCP address and a Unix
// socket address and checks the parsed address and transport after each
// switch.
func TestClientSetAddress(t *testing.T) {
	c, err := NewClient(nil)
	if err != nil {
		t.Fatal(err)
	}
	setAddr := func(addr string) {
		t.Helper()
		if err := c.SetAddress(addr); err != nil {
			t.Fatal(err)
		}
	}

	// Start with TCP address using HTTP
	setAddr("http://172.168.2.1:8300")
	if got := c.addr.Host; got != "172.168.2.1:8300" {
		t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", got)
	}

	// Switch from the TCP address to a Unix socket address: the scheme becomes
	// http, the socket path moves into Host, and a custom dialer is installed.
	setAddr("unix:///var/run/vault.sock")
	if got := c.addr.Scheme; got != "http" {
		t.Fatalf("bad: expected: 'http' actual: %q", got)
	}
	if got := c.addr.Host; got != "/var/run/vault.sock" {
		t.Fatalf("bad: expected: '/var/run/vault.sock' actual: %q", got)
	}
	if got := c.addr.Path; got != "" {
		t.Fatalf("bad: expected '' actual: %q", got)
	}
	transport := c.config.HttpClient.Transport.(*http.Transport)
	if transport.DialContext == nil {
		t.Fatal("bad: expected DialContext to not be nil")
	}

	// Switch back from the Unix socket to the TCP address.
	setAddr("http://172.168.2.1:8300")
	if got := c.addr.Host; got != "172.168.2.1:8300" {
		t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", got)
	}
	if got := c.addr.Scheme; got != "http" {
		t.Fatalf("bad: expected: 'http' actual: %q", got)
	}
}
// TestClientToken checks that SetToken stores a token and ClearToken empties it.
func TestClientToken(t *testing.T) {
	const tokenValue = "foo"
	noop := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {})
	config, ln := testHTTPServer(t, noop)
	defer ln.Close()

	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Verify the token is set.
	client.SetToken(tokenValue)
	if got := client.Token(); got != tokenValue {
		t.Fatalf("bad: %s", got)
	}

	// Verify clearing leaves no token behind.
	client.ClearToken()
	if got := client.Token(); got != "" {
		t.Fatalf("bad: %s", got)
	}
}
// TestClientHostHeader checks that the Host header the server receives equals
// the host:port part of the client's configured address. The address is
// rewritten from 127.0.0.1 to localhost before the client is built, and the
// handler echoes req.Host back in the body.
func TestClientHostHeader(t *testing.T) {
	handler := func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte(req.Host))
	}
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
	defer ln.Close()
	config.Address = strings.ReplaceAll(config.Address, "127.0.0.1", "localhost")
	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// Set the token manually
	client.SetToken("foo")
	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	if err != nil {
		t.Fatal(err)
	}
	// Release the connection once the body has been consumed.
	defer resp.Body.Close()

	// Copy the response
	var buf bytes.Buffer
	if _, err := io.Copy(&buf, resp.Body); err != nil {
		t.Fatalf("err reading response body: %s", err)
	}
	// The echoed Host must equal the configured address without its scheme.
	if buf.String() != strings.ReplaceAll(config.Address, "http://", "") {
		t.Fatalf("Bad address: %s", buf.String())
	}
}
// TestClientBadToken checks that a printable token is sent without error and
// that a token containing a non-printable character (DEL, U+007F) makes the
// request fail with an error mentioning "printable".
func TestClientBadToken(t *testing.T) {
	handler := func(w http.ResponseWriter, req *http.Request) {}
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
	defer ln.Close()
	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// closeResp releases a response body if one was returned at all.
	closeResp := func(r *Response) {
		if r != nil && r.Response != nil && r.Body != nil {
			r.Body.Close()
		}
	}

	client.SetToken("foo")
	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	closeResp(resp)
	if err != nil {
		t.Fatal(err)
	}

	client.SetToken("foo\u007f")
	resp, err = client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	closeResp(resp)
	if err == nil || !strings.Contains(err.Error(), "printable") {
		t.Fatalf("expected error due to bad token")
	}
}
// TestClientDisableRedirects checks redirect handling against a server that
// answers every request with the given 3xx status and a Location unique to
// that request.
//
// With DisableRedirects set, the client is expected to return the 3xx
// response after a single request. Without it, the 301 case expects two
// requests to reach the server.
func TestClientDisableRedirects(t *testing.T) {
	tests := map[string]struct {
		statusCode       int
		expectedNumReqs  int
		disableRedirects bool
	}{
		"Disabled redirects: Moved permanently":  {statusCode: 301, expectedNumReqs: 1, disableRedirects: true},
		"Disabled redirects: Found":              {statusCode: 302, expectedNumReqs: 1, disableRedirects: true},
		"Disabled redirects: Temporary Redirect": {statusCode: 307, expectedNumReqs: 1, disableRedirects: true},
		"Enable redirects: Moved permanently":    {statusCode: 301, expectedNumReqs: 2, disableRedirects: false},
	}
	for name, tc := range tests {
		test := tc
		name := name
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			numReqs := 0
			// config is assigned once the server exists. The handler reads it
			// only while serving, which happens after the assignment below.
			var config *Config
			respFunc := func(w http.ResponseWriter, req *http.Request) {
				// Track how many requests the server has handled
				numReqs++
				// Send back the relevant status code and a Location that differs per request.
				w.Header().Set("Location", fmt.Sprintf("%s/reqs/%d", config.Address, numReqs))
				w.WriteHeader(test.statusCode)
			}
			cfg, ln := testHTTPServer(t, http.HandlerFunc(respFunc))
			defer ln.Close()
			config = cfg
			config.DisableRedirects = test.disableRedirects
			client, err := NewClient(config)
			if err != nil {
				t.Fatalf("%s: error %v", name, err)
			}
			req := client.NewRequest(http.MethodGet, "/")
			resp, err := client.rawRequestWithContext(context.Background(), req)
			if err != nil {
				t.Fatalf("%s: error %v", name, err)
			}
			defer resp.Body.Close()
			if numReqs != test.expectedNumReqs {
				t.Fatalf("%s: expected %v request(s) but got %v", name, test.expectedNumReqs, numReqs)
			}
			if resp.StatusCode != test.statusCode {
				t.Fatalf("%s: expected status code %v got %v", name, test.statusCode, resp.StatusCode)
			}
			location, err := resp.Location()
			if err != nil {
				t.Fatalf("%s error %v", name, err)
			}
			if req.URL.String() == location.String() {
				t.Fatalf("%s: expected request URL %v to be different from redirect URL %v", name, req.URL, location)
			}
		})
	}
}
// TestClientRedirect sends a request to a "standby" server that always answers
// 307 with a Location pointing at the "primary" server. It checks that the
// body returned to the caller is the primary's ("test").
func TestClientRedirect(t *testing.T) {
	primary := func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("test"))
	}
	config, ln := testHTTPServer(t, http.HandlerFunc(primary))
	defer ln.Close()
	standby := func(w http.ResponseWriter, req *http.Request) {
		w.Header().Set("Location", config.Address)
		w.WriteHeader(http.StatusTemporaryRedirect)
	}
	config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
	defer ln2.Close()
	client, err := NewClient(config2)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// Set the token manually
	client.SetToken("foo")
	// Do a raw "/" request
	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer resp.Body.Close()
	// Copy the response
	var buf bytes.Buffer
	if _, err := io.Copy(&buf, resp.Body); err != nil {
		t.Fatalf("err reading response body: %s", err)
	}
	// Verify we got the response from the primary
	if buf.String() != "test" {
		t.Fatalf("Bad: %s", buf.String())
	}
}
// TestDefaulRetryPolicy table-tests DefaultRetryPolicy. The cases cover a
// generic error (retry), a url.Error wrapping an x509 unknown-authority error
// (no retry), 200/4xx/501 (no retry) and 500/504 (retry). No case expects the
// policy itself to return an error.
func TestDefaulRetryPolicy(t *testing.T) {
	cases := map[string]struct {
		resp      *http.Response
		err       error
		expect    bool
		expectErr error
	}{
		"retry on error": {
			err:    fmt.Errorf("error"),
			expect: true,
		},
		"don't retry connection failures": {
			err: &url.Error{
				Err: x509.UnknownAuthorityError{},
			},
		},
		"don't retry on 200": {
			resp: &http.Response{
				StatusCode: http.StatusOK,
			},
		},
		"don't retry on 4xx": {
			resp: &http.Response{
				StatusCode: http.StatusBadRequest,
			},
		},
		"don't retry on 501": {
			resp: &http.Response{
				StatusCode: http.StatusNotImplemented,
			},
		},
		"retry on 500": {
			resp: &http.Response{
				StatusCode: http.StatusInternalServerError,
			},
			expect: true,
		},
		"retry on 5xx": {
			resp: &http.Response{
				StatusCode: http.StatusGatewayTimeout,
			},
			expect: true,
		},
	}
	for name, test := range cases {
		t.Run(name, func(t *testing.T) {
			retry, err := DefaultRetryPolicy(context.Background(), test.resp, test.err)
			if retry != test.expect {
				t.Fatalf("expected to retry request: '%t', but actual result was: '%t'", test.expect, retry)
			}
			if err != test.expectErr {
				// %v rather than %q: either side may be a nil error.
				t.Fatalf("expected error from retry policy: %v, but actual result was: %v", test.expectErr, err)
			}
		})
	}
}
// TestClientEnvSettings sets the CA, client-certificate, skip-verify,
// max-retries and disable-redirects environment variables (pointing the file
// paths at test-fixtures/keys). It then checks the TLS settings and
// DisableRedirects that ReadEnvironment produces.
func TestClientEnvSettings(t *testing.T) {
	cwd, err := os.Getwd()
	if err != nil {
		t.Fatalf("error getting working directory: %v", err)
	}
	keysDir := cwd + "/test-fixtures/keys"
	certPath := keysDir + "/cert.pem"
	caCertBytes, err := os.ReadFile(certPath)
	if err != nil {
		t.Fatalf("error reading %q cert file: %v", certPath, err)
	}

	// Override each variable and restore its previous value when the test
	// returns. Deferring inside the loop is intentional: every restore must run
	// at function exit.
	overrides := map[string]string{
		EnvVaultCACert:           certPath,
		EnvVaultCACertBytes:      string(caCertBytes),
		EnvVaultCAPath:           keysDir,
		EnvVaultClientCert:       certPath,
		EnvVaultClientKey:        keysDir + "/key.pem",
		EnvVaultSkipVerify:       "true",
		EnvVaultMaxRetries:       "5",
		EnvVaultDisableRedirects: "true",
	}
	for name, value := range overrides {
		old := os.Getenv(name)
		os.Setenv(name, value)
		defer os.Setenv(name, old)
	}

	config := DefaultConfig()
	if err := config.ReadEnvironment(); err != nil {
		t.Fatalf("error reading environment: %v", err)
	}
	tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
	if len(tlsConfig.RootCAs.Subjects()) == 0 {
		t.Fatalf("bad: expected a cert pool with at least one subject")
	}
	if tlsConfig.GetClientCertificate == nil {
		t.Fatalf("bad: expected client tls config to have a certificate getter")
	}
	if !tlsConfig.InsecureSkipVerify {
		t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
	}
	if !config.DisableRedirects {
		t.Fatalf("bad: expected disable redirects to be true: %v", config.DisableRedirects)
	}
}
// TestClientDeprecatedEnvSettings checks that the deprecated EnvVaultInsecure
// variable still enables InsecureSkipVerify on the transport's TLS config.
func TestClientDeprecatedEnvSettings(t *testing.T) {
	prev := os.Getenv(EnvVaultInsecure)
	os.Setenv(EnvVaultInsecure, "true")
	defer os.Setenv(EnvVaultInsecure, prev)

	cfg := DefaultConfig()
	if err := cfg.ReadEnvironment(); err != nil {
		t.Fatalf("error reading environment: %v", err)
	}
	transport := cfg.HttpClient.Transport.(*http.Transport)
	if skip := transport.TLSClientConfig.InsecureSkipVerify; !skip {
		t.Fatalf("bad: %v", skip)
	}
}
// TestClientEnvNamespace checks that a namespace supplied through
// EnvVaultNamespace is sent in the NamespaceHeaderName request header.
func TestClientEnvNamespace(t *testing.T) {
	var seenNamespace string
	handler := func(w http.ResponseWriter, req *http.Request) {
		seenNamespace = req.Header.Get(NamespaceHeaderName)
	}
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
	defer ln.Close()
	oldVaultNamespace := os.Getenv(EnvVaultNamespace)
	defer os.Setenv(EnvVaultNamespace, oldVaultNamespace)
	os.Setenv(EnvVaultNamespace, "test")
	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	resp, err := client.RawRequest(client.NewRequest(http.MethodGet, "/"))
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// Only the request headers matter here; release the response.
	resp.Body.Close()
	if seenNamespace != "test" {
		t.Fatalf("Bad: %s", seenNamespace)
	}
}
// TestParsingRateAndBurst checks that a "rate:burst" string is split into its
// rate and burst components.
func TestParsingRateAndBurst(t *testing.T) {
	const input = "400:400"
	wantRate, wantBurst := float64(400), 400

	gotRate, gotBurst, err := parseRateLimit(input)
	if err != nil {
		t.Error(err)
	}
	if gotRate != wantRate {
		t.Errorf("Expected rate %v but found %v", wantRate, gotRate)
	}
	if gotBurst != wantBurst {
		t.Errorf("Expected burst %v but found %v", wantBurst, gotBurst)
	}
}
// TestParsingRateOnly checks that a bare rate (no ":burst" part) yields a burst
// equal to the rate.
func TestParsingRateOnly(t *testing.T) {
	const input = "400"
	wantRate, wantBurst := float64(400), 400

	gotRate, gotBurst, err := parseRateLimit(input)
	if err != nil {
		t.Error(err)
	}
	if gotRate != wantRate {
		t.Errorf("Expected rate %v but found %v", wantRate, gotRate)
	}
	if gotBurst != wantBurst {
		t.Errorf("Expected burst %v but found %v", wantBurst, gotBurst)
	}
}
// TestParsingErrorCase checks that a non-numeric rate limit string is rejected.
func TestParsingErrorCase(t *testing.T) {
	if _, _, err := parseRateLimit("foobar"); err == nil {
		t.Error("Expected error, found no error")
	}
}
// TestClientTimeoutSetting sets EnvVaultClientTimeout to "10". It checks that
// ReadEnvironment accepts the value, stores it as a ten-second Config.Timeout,
// and that NewClient accepts the resulting config.
func TestClientTimeoutSetting(t *testing.T) {
	oldClientTimeout := os.Getenv(EnvVaultClientTimeout)
	os.Setenv(EnvVaultClientTimeout, "10")
	defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)
	config := DefaultConfig()
	if err := config.ReadEnvironment(); err != nil {
		t.Fatalf("error reading environment: %v", err)
	}
	// NOTE(review): assumes a bare number is interpreted as seconds.
	if want := 10 * time.Second; config.Timeout != want {
		t.Fatalf("expected client timeout %v, actual %v", want, config.Timeout)
	}
	_, err := NewClient(config)
	if err != nil {
		t.Fatal(err)
	}
}
// roundTripperFunc adapts a plain function to the http.RoundTripper interface,
// in the same spirit as http.HandlerFunc on the server side.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip hands req to the wrapped function and returns its result unchanged.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}
// TestClientNonTransportRoundTripper checks that NewClient accepts an
// http.Client whose Transport is not an *http.Transport.
func TestClientNonTransportRoundTripper(t *testing.T) {
	rt := roundTripperFunc(http.DefaultTransport.RoundTrip)
	cfg := &Config{HttpClient: &http.Client{Transport: rt}}
	if _, err := NewClient(cfg); err != nil {
		t.Fatal(err)
	}
}
// TestClientNonTransportRoundTripperUnixAddress checks that NewClient returns
// an error when a unix:// address is combined with an http.Client whose
// Transport is not an *http.Transport. (Presumably this is because the socket
// dialer can only be installed on an *http.Transport.)
func TestClientNonTransportRoundTripperUnixAddress(t *testing.T) {
	rt := roundTripperFunc(http.DefaultTransport.RoundTrip)
	cfg := &Config{
		HttpClient: &http.Client{Transport: rt},
		Address:    "unix:///var/run/vault.sock",
	}
	if _, err := NewClient(cfg); err == nil {
		t.Fatal("bad: expected error got nil")
	}
}
// TestClone configures a parent client through its setter methods, clones it,
// and compares the clone against the parent. The comparisons cover:
//   - address, timeout, retry hook, limiter, max retries and SRV lookup must match;
//   - OutputCurlString must NOT be carried over;
//   - headers are compared only when CloneHeaders is set;
//   - the token must match only when CloneToken is set, and be empty otherwise;
//   - the replicationStateStore must be deeply equal.
func TestClone(t *testing.T) {
	tests := []struct {
		name    string
		config  *Config
		headers *http.Header
		token   string
	}{
		{
			name:   "default",
			config: DefaultConfig(),
		},
		{
			name: "cloneHeaders",
			config: &Config{
				CloneHeaders: true,
			},
			headers: &http.Header{
				"X-foo": []string{"bar"},
				"X-baz": []string{"qux"},
			},
		},
		{
			name: "preventStaleReads",
			config: &Config{
				ReadYourWrites: true,
			},
		},
		{
			name: "cloneToken",
			config: &Config{
				CloneToken: true,
			},
			token: "cloneToken",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parent, err := NewClient(tt.config)
			if err != nil {
				t.Fatalf("NewClient failed: %v", err)
			}
			// Set all of the things that we provide setter methods for, which modify config values
			err = parent.SetAddress("http://example.com:8080")
			if err != nil {
				t.Fatalf("SetAddress failed: %v", err)
			}
			clientTimeout := time.Until(time.Now().AddDate(0, 0, 1))
			parent.SetClientTimeout(clientTimeout)
			checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) {
				return true, nil
			}
			parent.SetCheckRetry(checkRetry)
			parent.SetLogger(hclog.NewNullLogger())
			parent.SetLimiter(5.0, 10)
			parent.SetMaxRetries(5)
			parent.SetOutputCurlString(true)
			parent.SetOutputPolicy(true)
			parent.SetSRVLookup(true)
			if tt.headers != nil {
				parent.SetHeaders(*tt.headers)
			}
			if tt.token != "" {
				parent.SetToken(tt.token)
			}
			clone, err := parent.Clone()
			if err != nil {
				t.Fatalf("Clone failed: %v", err)
			}
			if parent.Address() != clone.Address() {
				t.Fatalf("addresses don't match: %v vs %v", parent.Address(), clone.Address())
			}
			if parent.ClientTimeout() != clone.ClientTimeout() {
				t.Fatalf("timeouts don't match: %v vs %v", parent.ClientTimeout(), clone.ClientTimeout())
			}
			if parent.CheckRetry() != nil && clone.CheckRetry() == nil {
				t.Fatal("checkRetry functions don't match. clone is nil.")
			}
			if (parent.Limiter() != nil && clone.Limiter() == nil) || (parent.Limiter() == nil && clone.Limiter() != nil) {
				t.Fatalf("limiters don't match: %v vs %v", parent.Limiter(), clone.Limiter())
			}
			// SetLimiter was called above, so both limiters are expected to be non-nil here.
			if parent.Limiter().Limit() != clone.Limiter().Limit() {
				t.Fatalf("limiter limits don't match: %v vs %v", parent.Limiter().Limit(), clone.Limiter().Limit())
			}
			if parent.Limiter().Burst() != clone.Limiter().Burst() {
				t.Fatalf("limiter bursts don't match: %v vs %v", parent.Limiter().Burst(), clone.Limiter().Burst())
			}
			if parent.MaxRetries() != clone.MaxRetries() {
				t.Fatalf("maxRetries don't match: %v vs %v", parent.MaxRetries(), clone.MaxRetries())
			}
			if parent.OutputCurlString() == clone.OutputCurlString() {
				t.Fatalf("outputCurlString was copied over when it shouldn't have been: %v and %v", parent.OutputCurlString(), clone.OutputCurlString())
			}
			if parent.SRVLookup() != clone.SRVLookup() {
				t.Fatalf("SRVLookup doesn't match: %v vs %v", parent.SRVLookup(), clone.SRVLookup())
			}
			if tt.config.CloneHeaders {
				if !reflect.DeepEqual(parent.Headers(), clone.Headers()) {
					t.Fatalf("Headers() don't match: %v vs %v", parent.Headers(), clone.Headers())
				}
				if parent.config.CloneHeaders != clone.config.CloneHeaders {
					t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", parent.config.CloneHeaders, clone.config.CloneHeaders)
				}
				if tt.headers != nil {
					if !reflect.DeepEqual(*tt.headers, clone.Headers()) {
						t.Fatalf("expected headers %v, actual %v", *tt.headers, clone.Headers())
					}
				}
			}
			if tt.config.ReadYourWrites && parent.replicationStateStore == nil {
				t.Fatalf("replicationStateStore is nil")
			}
			if tt.config.CloneToken {
				if tt.token == "" {
					t.Fatalf("test requires a non-empty token")
				}
				if parent.config.CloneToken != clone.config.CloneToken {
					t.Fatalf("config.CloneToken doesn't match: %v vs %v", parent.config.CloneToken, clone.config.CloneToken)
				}
				if parent.token != clone.token {
					t.Fatalf("tokens do not match: %v vs %v", parent.token, clone.token)
				}
			} else {
				// assumes `VAULT_TOKEN` is unset or has an empty value.
				expected := ""
				if clone.token != expected {
					t.Fatalf("expected clone's token %q, actual %q", expected, clone.token)
				}
			}
			if !reflect.DeepEqual(parent.replicationStateStore, clone.replicationStateStore) {
				t.Fatalf("expected replicationStateStore %v, actual %v", parent.replicationStateStore,
					clone.replicationStateStore)
			}
		})
	}
}
// TestSetHeadersRaceSafe calls AddHeader on one client from several
// goroutines released at the same moment. It then checks that every header
// ends up present. Run it with -race to also surface data races.
func TestSetHeadersRaceSafe(t *testing.T) {
	client, err := NewClient(nil)
	if err != nil {
		t.Fatalf("NewClient failed: %v", err)
	}
	start := make(chan struct{})
	done := make(chan struct{})
	testPairs := map[string]string{
		"soda":    "rootbeer",
		"veggie":  "carrots",
		"fruit":   "apples",
		"color":   "red",
		"protein": "egg",
	}
	for key, value := range testPairs {
		tmpKey := key
		tmpValue := value
		go func() {
			<-start
			// This test fails if client.AddHeader(tmpKey, tmpValue) is replaced
			// with a non-atomic read-modify-write such as:
			//   headerCopy := client.Headers()
			//   headerCopy.Add(tmpKey, tmpValue)
			//   client.SetHeaders(headerCopy)
			client.AddHeader(tmpKey, tmpValue)
			done <- struct{}{}
		}()
	}
	// Start everyone at once.
	close(start)
	// Wait until every goroutine has reported in.
	for range testPairs {
		<-done
	}
	// Check that all the test pairs are in the resulting headers.
	resultingHeaders := client.Headers()
	for key, value := range testPairs {
		if resultingHeaders.Get(key) != value {
			t.Fatal("expected " + value + " for " + key)
		}
	}
}
// TestMergeReplicationStates table-tests MergeReplicationStates. States are
// written in plain "v1:cid:A:B:" form for readability. They are base64-encoded
// before the call and decoded afterwards for comparison.
func TestMergeReplicationStates(t *testing.T) {
	type testCase struct {
		name     string
		old      []string
		new      string
		expected []string
	}
	testCases := []testCase{
		{
			name:     "empty-old",
			old:      nil,
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:1:0:"},
		},
		{
			name:     "old-smaller",
			old:      []string{"v1:cid:1:0:"},
			new:      "v1:cid:2:0:",
			expected: []string{"v1:cid:2:0:"},
		},
		{
			name:     "old-bigger",
			old:      []string{"v1:cid:2:0:"},
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:2:0:"},
		},
		{
			name:     "mixed-single",
			old:      []string{"v1:cid:1:0:"},
			new:      "v1:cid:0:1:",
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
		{
			name:     "mixed-single-alt",
			old:      []string{"v1:cid:0:1:"},
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
		{
			name:     "mixed-double",
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			new:      "v1:cid:2:0:",
			expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
		},
		{
			name:     "newer-both",
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			new:      "v1:cid:2:1:",
			expected: []string{"v1:cid:2:1:"},
		},
	}
	b64enc := func(ss []string) []string {
		var ret []string
		for _, s := range ss {
			ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
		}
		return ret
	}
	// b64dec takes the *current* subtest's t. Calling Fatal on the parent test
	// from inside a subtest is invalid: FailNow must run on the goroutine
	// executing that test.
	b64dec := func(t *testing.T, ss []string) []string {
		t.Helper()
		var ret []string
		for _, s := range ss {
			d, err := base64.StdEncoding.DecodeString(s)
			if err != nil {
				t.Fatal(err)
			}
			ret = append(ret, string(d))
		}
		return ret
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			merged := MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new)))
			out := b64dec(t, merged)
			if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
				t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
			}
		})
	}
}
// TestReplicationStateStore_recordState feeds one or more responses carrying
// HeaderIndex values to recordState concurrently. It then checks the store's
// contents.
func TestReplicationStateStore_recordState(t *testing.T) {
	b64enc := func(s string) string {
		return base64.StdEncoding.EncodeToString([]byte(s))
	}
	// respWithStates builds a Response whose header lists each given state,
	// base64-encoded, under HeaderIndex. With no states the header map is empty.
	respWithStates := func(states ...string) *Response {
		hdr := http.Header{}
		for _, s := range states {
			hdr[HeaderIndex] = append(hdr[HeaderIndex], b64enc(s))
		}
		return &Response{Response: &http.Response{Header: hdr}}
	}
	tests := []struct {
		name     string
		expected []string
		resp     []*Response
	}{
		{
			name:     "single",
			resp:     []*Response{respWithStates("v1:cid:1:0:")},
			expected: []string{b64enc("v1:cid:1:0:")},
		},
		{
			name:     "empty",
			resp:     []*Response{respWithStates()},
			expected: nil,
		},
		{
			name: "multiple",
			resp: []*Response{
				respWithStates("v1:cid:0:1:"),
				respWithStates("v1:cid:1:0:"),
			},
			expected: []string{
				b64enc("v1:cid:0:1:"),
				b64enc("v1:cid:1:0:"),
			},
		},
		{
			name: "duplicates",
			resp: []*Response{
				respWithStates("v1:cid:1:0:"),
				respWithStates("v1:cid:1:0:"),
			},
			expected: []string{b64enc("v1:cid:1:0:")},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rs := &replicationStateStore{}
			var wg sync.WaitGroup
			wg.Add(len(tt.resp))
			for _, r := range tt.resp {
				go func(resp *Response) {
					defer wg.Done()
					rs.recordState(resp)
				}(r)
			}
			wg.Wait()
			if !reflect.DeepEqual(tt.expected, rs.store) {
				t.Errorf("recordState(): expected states %v, actual %v", tt.expected, rs.store)
			}
		})
	}
}
// TestReplicationStateStore_requireState has a store pre-loaded with states
// apply them to requests concurrently. It then checks the sorted union of the
// HeaderIndex values that ended up on those requests.
func TestReplicationStateStore_requireState(t *testing.T) {
	freshRequest := func() *Request {
		return &Request{Headers: make(http.Header)}
	}
	tests := []struct {
		name     string
		states   []string
		req      []*Request
		expected []string
	}{
		{
			name:     "empty",
			states:   []string{},
			req:      []*Request{freshRequest()},
			expected: nil,
		},
		{
			name:     "basic",
			states:   []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			req:      []*Request{freshRequest()},
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rs := &replicationStateStore{store: tt.states}
			var wg sync.WaitGroup
			wg.Add(len(tt.req))
			for _, r := range tt.req {
				go func(req *Request) {
					defer wg.Done()
					rs.requireState(req)
				}(r)
			}
			wg.Wait()

			// Appending zero values leaves actual nil, matching the "empty" case.
			var actual []string
			for _, r := range tt.req {
				actual = append(actual, r.Headers.Values(HeaderIndex)...)
			}
			sort.Strings(actual)
			if !reflect.DeepEqual(tt.expected, actual) {
				t.Errorf("requireState(): expected states %v, actual %v", tt.expected, actual)
			}
		})
	}
}
// TestClient_ReadYourWrites sends concurrent requests with ReadYourWrites
// enabled. Requests go either through one parent client or through clones of
// it. The handler echoes the request path back as the HeaderIndex response
// header. The test then checks the states the parent's replicationStateStore
// reports.
func TestClient_ReadYourWrites(t *testing.T) {
	b64enc := func(s string) string {
		return base64.StdEncoding.EncodeToString([]byte(s))
	}
	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		w.Header().Set(HeaderIndex, strings.TrimLeft(req.URL.Path, "/"))
	})
	tests := []struct {
		name       string
		handler    http.Handler
		wantStates []string
		values     [][]string
		clone      bool
	}{
		{
			name:    "multiple_duplicates",
			clone:   false,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
			},
		},
		{
			name:    "basic_clone",
			clone:   true,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
				},
				{
					b64enc("v1:cid:0:3:"),
				},
			},
		},
		{
			name:    "multiple_clone",
			clone:   true,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
				{
					b64enc("v1:cid:0:3:"),
					b64enc("v1:cid:0:1:"),
				},
			},
		},
		{
			name:    "multiple_duplicates_clone",
			clone:   true,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// testRequest runs on goroutines spawned below, so it must report
			// failures with t.Error. t.Fatal/FailNow may only be called from the
			// goroutine running the test.
			testRequest := func(client *Client, val string) {
				req := client.NewRequest(http.MethodGet, "/"+val)
				req.Headers.Set(HeaderIndex, val)
				resp, err := client.RawRequestWithContext(context.Background(), req)
				if err != nil {
					t.Error(err)
					return
				}
				defer resp.Body.Close()
				// validate that the server provided a valid header value in its response
				actual := resp.Header.Get(HeaderIndex)
				if actual != val {
					t.Errorf("expected header value %v, actual %v", val, actual)
				}
			}
			config, ln := testHTTPServer(t, tt.handler)
			defer ln.Close()
			config.ReadYourWrites = true
			config.Address = fmt.Sprintf("http://%s", ln.Addr())
			parent, err := NewClient(config)
			if err != nil {
				t.Fatal(err)
			}
			var wg sync.WaitGroup
			for i := 0; i < len(tt.values); i++ {
				var c *Client
				if tt.clone {
					c, err = parent.Clone()
					if err != nil {
						t.Fatal(err)
					}
				} else {
					c = parent
				}
				for _, val := range tt.values[i] {
					wg.Add(1)
					go func(val string) {
						defer wg.Done()
						testRequest(c, val)
					}(val)
				}
			}
			wg.Wait()
			if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) {
				t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states())
			}
		})
	}
}
func TestClient_SetReadYourWrites(t *testing.T) {
tests := []struct {
name string
config *Config
calls []bool
}{
{
name: "false",
config: &Config{},
calls: []bool{false},
},
{
name: "true",
config: &Config{},
calls: []bool{true},
},
{
name: "multi-false",
config: &Config{},
calls: []bool{false, false},
},
{
name: "multi-true",
config: &Config{},
calls: []bool{true, true},
},
{
name: "multi-mix",
config: &Config{},
calls: []bool{false, true, false, true},
},
}
assertSetReadYourRights := func(t *testing.T, c *Client, v bool, s *replicationStateStore) {
t.Helper()
c.SetReadYourWrites(v)
if c.config.ReadYourWrites != v {
t.Fatalf("expected config.ReadYourWrites %#v, actual %#v", v, c.config.ReadYourWrites)
}
if !reflect.DeepEqual(s, c.replicationStateStore) {
t.Fatalf("expected replicationStateStore %#v, actual %#v", s, c.replicationStateStore)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
config: tt.config,
}
for i, v := range tt.calls {
var expectStateStore *replicationStateStore
if v {
if c.replicationStateStore == nil {
c.replicationStateStore = &replicationStateStore{
store: []string{},
}
}
c.replicationStateStore.store = append(c.replicationStateStore.store,
fmt.Sprintf("%s-%d", tt.name, i))
expectStateStore = c.replicationStateStore
}
assertSetReadYourRights(t, c, v, expectStateStore)
}
})
}
}
// TestClient_SetCloneToken applies a sequence of SetCloneToken calls. Before
// each call it checks that CloneToken still reports the previously set value
// (false before the first call); after each call it checks the new value.
func TestClient_SetCloneToken(t *testing.T) {
	tests := []struct {
		name  string
		calls []bool
	}{
		{name: "false", calls: []bool{false}},
		{name: "true", calls: []bool{true}},
		{name: "multi", calls: []bool{true, false, true}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Client{config: &Config{}}
			prev := false
			for _, want := range tt.calls {
				if got := c.CloneToken(); got != prev {
					t.Fatalf("expected %v, actual %v", prev, got)
				}
				c.SetCloneToken(want)
				if got := c.CloneToken(); got != want {
					t.Fatalf("SetCloneToken(): expected %v, actual %v", want, got)
				}
				prev = want
			}
		})
	}
}
// TestClientWithNamespace checks that WithNamespace overrides the namespace
// header for a single call (including clearing it with "") without modifying
// the namespace configured on the original client.
func TestClientWithNamespace(t *testing.T) {
	var ns string
	handler := func(w http.ResponseWriter, req *http.Request) {
		ns = req.Header.Get(NamespaceHeaderName)
	}
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
	defer ln.Close()
	// set up a client with a namespace
	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	// send builds a GET request from the base client, sends it through c (the
	// base client or a namespace-scoped copy of it), and releases the body.
	send := func(c *Client) error {
		resp, err := c.rawRequestWithContext(
			context.Background(),
			client.NewRequest(http.MethodGet, "/"))
		if resp != nil && resp.Response != nil && resp.Body != nil {
			resp.Body.Close()
		}
		return err
	}

	ogNS := "test"
	client.SetNamespace(ogNS)
	if err := send(client); err != nil {
		t.Fatalf("err: %s", err)
	}
	if ns != ogNS {
		t.Fatalf("Expected namespace: %q, got %q", ogNS, ns)
	}

	// make a call with a temporary namespace
	newNS := "new-namespace"
	if err := send(client.WithNamespace(newNS)); err != nil {
		t.Fatalf("err: %s", err)
	}
	if ns != newNS {
		t.Fatalf("Expected new namespace: %q, got %q", newNS, ns)
	}

	// ensure client has not been modified
	if err := send(client); err != nil {
		t.Fatalf("err: %s", err)
	}
	if ns != ogNS {
		t.Fatalf("Expected original namespace: %q, got %q", ogNS, ns)
	}

	// make call with empty ns
	if err := send(client.WithNamespace("")); err != nil {
		t.Fatalf("err: %s", err)
	}
	if ns != "" {
		t.Fatalf("Expected no namespace, got %q", ns)
	}

	// ensure client has not been modified
	if client.Namespace() != ogNS {
		t.Fatalf("Expected original namespace: %q, got %q", ogNS, client.Namespace())
	}
}
// TestVaultProxy checks the proxy that DefaultConfig's transport resolves for
// a request under combinations of EnvHTTPProxy, EnvVaultProxyAddr and
// NO_PROXY.
//
// In every case a proxy must be resolved, including when NO_PROXY lists the
// request host. When both Vault variables are set, EnvVaultProxyAddr must win.
// Variables are set only when the case supplies a non-empty value, and are
// restored when the subtest ends.
func TestVaultProxy(t *testing.T) {
	const NoProxy string = "NO_PROXY"
	tests := map[string]struct {
		vaultHttpProxy           string
		vaultProxyAddr           string
		noProxy                  string
		requestUrl               string
		expectedResolvedProxyUrl string
	}{
		"VAULT_HTTP_PROXY used when NO_PROXY env var doesn't include request host": {
			vaultHttpProxy: "https://hashicorp.com",
			vaultProxyAddr: "",
			noProxy:        "terraform.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_HTTP_PROXY used when NO_PROXY env var includes request host": {
			vaultHttpProxy: "https://hashicorp.com",
			vaultProxyAddr: "",
			noProxy:        "terraform.io,vaultproject.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_PROXY_ADDR used when NO_PROXY env var doesn't include request host": {
			vaultHttpProxy: "",
			vaultProxyAddr: "https://hashicorp.com",
			noProxy:        "terraform.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_PROXY_ADDR used when NO_PROXY env var includes request host": {
			vaultHttpProxy: "",
			vaultProxyAddr: "https://hashicorp.com",
			noProxy:        "terraform.io,vaultproject.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_PROXY_ADDR used when VAULT_HTTP_PROXY env var also supplied": {
			vaultHttpProxy:           "https://hashicorp.com",
			vaultProxyAddr:           "https://terraform.io",
			noProxy:                  "",
			requestUrl:               "https://vaultproject.io",
			expectedResolvedProxyUrl: "https://terraform.io",
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			if tc.vaultHttpProxy != "" {
				oldVaultHttpProxy := os.Getenv(EnvHTTPProxy)
				os.Setenv(EnvHTTPProxy, tc.vaultHttpProxy)
				defer os.Setenv(EnvHTTPProxy, oldVaultHttpProxy)
			}
			if tc.vaultProxyAddr != "" {
				oldVaultProxyAddr := os.Getenv(EnvVaultProxyAddr)
				os.Setenv(EnvVaultProxyAddr, tc.vaultProxyAddr)
				defer os.Setenv(EnvVaultProxyAddr, oldVaultProxyAddr)
			}
			if tc.noProxy != "" {
				oldNoProxy := os.Getenv(NoProxy)
				os.Setenv(NoProxy, tc.noProxy)
				defer os.Setenv(NoProxy, oldNoProxy)
			}
			c := DefaultConfig()
			if c.Error != nil {
				t.Fatalf("Expected no error reading config, found error %v", c.Error)
			}
			r, err := http.NewRequest("GET", tc.requestUrl, nil)
			if err != nil {
				t.Fatalf("Expected no error building request, found error %v", err)
			}
			transport := c.HttpClient.Transport.(*http.Transport)
			if transport.Proxy == nil {
				t.Fatalf("Expected a proxy function to be configured on the transport")
			}
			proxyUrl, err := transport.Proxy(r)
			if err != nil {
				t.Fatalf("Expected no error resolving proxy, found error %v", err)
			}
			if proxyUrl == nil || proxyUrl.String() == "" {
				t.Fatalf("Expected proxy to be resolved but no proxy returned")
			}
			if tc.expectedResolvedProxyUrl != "" && proxyUrl.String() != tc.expectedResolvedProxyUrl {
				t.Fatalf("Expected resolved proxy URL to be %v but was %v", tc.expectedResolvedProxyUrl, proxyUrl.String())
			}
		})
	}
}
// TestParseAddressWithUnixSocket checks how Config.ParseAddress handles a
// unix:// address. The scheme must become http, the socket path must move
// into Host, Path must be empty, and a DialContext must be installed on the
// config's transport.
func TestParseAddressWithUnixSocket(t *testing.T) {
	address := "unix:///var/run/vault.sock"
	config := DefaultConfig()
	u, err := config.ParseAddress(address)
	if err != nil {
		t.Fatalf("Error not expected: %v", err)
	}
	if u.Scheme != "http" {
		t.Fatalf("Scheme not changed to http: got %q", u.Scheme)
	}
	if u.Host != "/var/run/vault.sock" {
		t.Fatalf("Host not changed to socket name: got %q", u.Host)
	}
	if u.Path != "" {
		t.Fatalf("Path expected to be blank: got %q", u.Path)
	}
	if config.HttpClient.Transport.(*http.Transport).DialContext == nil {
		t.Fatal("DialContext function not set in config.HttpClient.Transport")
	}
}