open-vault/api/client_test.go
Marc Boudreau 03d75a7b60
Improving Handling of Unix Domain Socket Addresses (#11904)
* Removed redundant checks for same env var in ReadEnvironment, extracted Unix domain socket logic to function, and made use of this logic in SetAddress.  Adjusted unit tests to verify proper Unix domain socket handling.

* Adding case to revert from Unix domain socket dial function back to TCP

* Adding changelog file

* Only adjust DialContext if RoundTripper is an http.Transport

* Switching from read lock to normal lock

* only reset transport DialContext when setting different address type

* made ParseAddress a method on Config

* Adding additional tests to cover transitions to/from TCP to Unix

* Moved Config type method ParseAddress closer to type's other methods.

* make release note more end-user focused

* adopt review feedback to add comment about holding a lock
2022-06-21 15:16:58 -07:00

1369 lines
33 KiB
Go

package api
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/url"
"os"
"reflect"
"sort"
"strings"
"sync"
"testing"
"time"
"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/consts"
)
func init() {
	// Clear the variables that influence client construction so every test
	// starts from a known state regardless of the invoking shell.
	for _, name := range []string{"VAULT_ADDR", "VAULT_TOKEN"} {
		os.Setenv(name, "")
	}
}
// TestDefaultConfig_envvar checks that VAULT_ADDR feeds DefaultConfig and
// VAULT_TOKEN feeds the token of a client built from that config.
func TestDefaultConfig_envvar(t *testing.T) {
	const wantAddr = "https://vault.mycompany.com"
	os.Setenv("VAULT_ADDR", wantAddr)
	defer os.Setenv("VAULT_ADDR", "")

	cfg := DefaultConfig()
	if got := cfg.Address; got != wantAddr {
		t.Fatalf("bad: %s", got)
	}

	const wantToken = "testing"
	os.Setenv("VAULT_TOKEN", wantToken)
	defer os.Setenv("VAULT_TOKEN", "")

	c, err := NewClient(cfg)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if got := c.Token(); got != wantToken {
		t.Fatalf("bad: %s", got)
	}
}
func TestClientDefaultHttpClient(t *testing.T) {
_, err := NewClient(&Config{
HttpClient: http.DefaultClient,
})
if err != nil {
t.Fatal(err)
}
}
// TestClientNilConfig checks that NewClient(nil) succeeds and yields a client.
func TestClientNilConfig(t *testing.T) {
	c, err := NewClient(nil)
	switch {
	case err != nil:
		t.Fatal(err)
	case c == nil:
		t.Fatal("expected a non-nil client")
	}
}
// TestClientDefaultHttpClient_unixSocket checks that a unix:// VAULT_AGENT_ADDR
// is turned into an http address whose host is the socket path.
func TestClientDefaultHttpClient_unixSocket(t *testing.T) {
	os.Setenv("VAULT_AGENT_ADDR", "unix:///var/run/vault.sock")
	defer os.Setenv("VAULT_AGENT_ADDR", "")

	c, err := NewClient(nil)
	if err != nil {
		t.Fatal(err)
	}
	if c == nil {
		t.Fatal("expected a non-nil client")
	}

	u := c.addr
	if u.Scheme != "http" {
		t.Fatalf("bad: %s", u.Scheme)
	}
	if u.Host != "/var/run/vault.sock" {
		t.Fatalf("bad: %s", u.Host)
	}
}
// TestClientSetAddress walks SetAddress through TCP -> Unix socket -> TCP
// transitions and checks the parsed address (and, for the socket, the
// transport's DialContext) after each step.
func TestClientSetAddress(t *testing.T) {
	client, err := NewClient(nil)
	if err != nil {
		t.Fatal(err)
	}

	const (
		tcpAddr    = "http://172.168.2.1:8300"
		tcpHost    = "172.168.2.1:8300"
		socketAddr = "unix:///var/run/vault.sock"
		socketPath = "/var/run/vault.sock"
	)

	// Begin with a plain TCP address over HTTP.
	if err := client.SetAddress(tcpAddr); err != nil {
		t.Fatal(err)
	}
	if host := client.addr.Host; host != tcpHost {
		t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", host)
	}

	// TCP -> Unix socket: scheme becomes http, host becomes the socket path,
	// the path is empty and the transport must carry a DialContext.
	if err := client.SetAddress(socketAddr); err != nil {
		t.Fatal(err)
	}
	switch addr := client.addr; {
	case addr.Scheme != "http":
		t.Fatalf("bad: expected: 'http' actual: %q", addr.Scheme)
	case addr.Host != socketPath:
		t.Fatalf("bad: expected: '/var/run/vault.sock' actual: %q", addr.Host)
	case addr.Path != "":
		t.Fatalf("bad: expected '' actual: %q", addr.Path)
	}
	if client.config.HttpClient.Transport.(*http.Transport).DialContext == nil {
		t.Fatal("bad: expected DialContext to not be nil")
	}

	// Unix socket -> TCP.
	if err := client.SetAddress(tcpAddr); err != nil {
		t.Fatal(err)
	}
	switch addr := client.addr; {
	case addr.Host != tcpHost:
		t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", addr.Host)
	case addr.Scheme != "http":
		t.Fatalf("bad: expected: 'http' actual: %q", addr.Scheme)
	}
}
// TestClientToken checks that SetToken and ClearToken are reflected by Token.
func TestClientToken(t *testing.T) {
	const tokenValue = "foo"

	config, ln := testHTTPServer(t, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
	defer ln.Close()

	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	client.SetToken(tokenValue)
	if got := client.Token(); got != tokenValue {
		t.Fatalf("bad: %s", got)
	}

	client.ClearToken()
	if got := client.Token(); got != "" {
		t.Fatalf("bad: %s", got)
	}
}
// TestClientHostHeader checks that the Host header received by the server
// matches the host:port in the client's configured address.
func TestClientHostHeader(t *testing.T) {
	// The handler echoes back the Host header it received.
	handler := func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte(req.Host))
	}
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
	defer ln.Close()

	// Address the server by hostname rather than the literal 127.0.0.1.
	config.Address = strings.ReplaceAll(config.Address, "127.0.0.1", "localhost")
	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Set the token manually
	client.SetToken("foo")

	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	var buf bytes.Buffer
	if _, err := io.Copy(&buf, resp.Body); err != nil {
		t.Fatalf("err reading response body: %s", err)
	}

	// The echoed Host must equal the configured address minus its scheme.
	if want := strings.ReplaceAll(config.Address, "http://", ""); buf.String() != want {
		t.Fatalf("Bad address: %s", buf.String())
	}
}
// TestClientBadToken checks two cases. A request with a well-formed token
// succeeds. Once the token contains U+007F (DEL), the request fails with an
// error mentioning "printable".
func TestClientBadToken(t *testing.T) {
	handler := func(w http.ResponseWriter, req *http.Request) {}
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
	defer ln.Close()

	client, err := NewClient(config)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	client.SetToken("foo")
	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	if err != nil {
		t.Fatal(err)
	}
	// Only success matters here; close the body so the connection is released.
	resp.Body.Close()

	client.SetToken("foo\u007f")
	resp, err = client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	if resp != nil && resp.Response != nil {
		resp.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), "printable") {
		t.Fatalf("expected error due to bad token, got: %v", err)
	}
}
// TestClientRedirect checks that a request sent to a standby that answers with
// a 307 redirect is followed to the primary, whose body is returned.
func TestClientRedirect(t *testing.T) {
	// The primary serves the real response body.
	primary := func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("test"))
	}
	config, ln := testHTTPServer(t, http.HandlerFunc(primary))
	defer ln.Close()

	// The standby redirects every request to the primary.
	standby := func(w http.ResponseWriter, req *http.Request) {
		w.Header().Set("Location", config.Address)
		w.WriteHeader(http.StatusTemporaryRedirect)
	}
	config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
	defer ln2.Close()

	client, err := NewClient(config2)
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Set the token manually
	client.SetToken("foo")

	// Do a raw "/" request against the standby.
	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer resp.Body.Close()

	var buf bytes.Buffer
	if _, err := io.Copy(&buf, resp.Body); err != nil {
		t.Fatalf("err: %s", err)
	}

	// Verify we got the response from the primary
	if buf.String() != "test" {
		t.Fatalf("Bad: %s", buf.String())
	}
}
// TestDefaulRetryPolicy checks which responses and transport errors
// DefaultRetryPolicy considers retryable, and that it returns no error.
func TestDefaulRetryPolicy(t *testing.T) {
	cases := map[string]struct {
		resp      *http.Response
		err       error
		expect    bool
		expectErr error
	}{
		"retry on error": {
			err:    fmt.Errorf("error"),
			expect: true,
		},
		"don't retry connection failures": {
			err: &url.Error{
				Err: x509.UnknownAuthorityError{},
			},
		},
		"don't retry on 200": {
			resp: &http.Response{
				StatusCode: http.StatusOK,
			},
		},
		"don't retry on 4xx": {
			resp: &http.Response{
				StatusCode: http.StatusBadRequest,
			},
		},
		"don't retry on 501": {
			resp: &http.Response{
				StatusCode: http.StatusNotImplemented,
			},
		},
		"retry on 500": {
			resp: &http.Response{
				StatusCode: http.StatusInternalServerError,
			},
			expect: true,
		},
		"retry on 5xx": {
			resp: &http.Response{
				StatusCode: http.StatusGatewayTimeout,
			},
			expect: true,
		},
	}

	for name, test := range cases {
		t.Run(name, func(t *testing.T) {
			retry, err := DefaultRetryPolicy(context.Background(), test.resp, test.err)
			if retry != test.expect {
				t.Fatalf("expected to retry request: '%t', but actual result was: '%t'", test.expect, retry)
			}
			// Expected first, actual second; %v renders a nil error as <nil>
			// instead of %!s(<nil>).
			if err != test.expectErr {
				t.Fatalf("expected error from retry policy: '%v', but actual result was: '%v'", test.expectErr, err)
			}
		})
	}
}
// TestClientEnvSettings sets the CA, client-cert, skip-verify and max-retry
// environment variables. It then checks that ReadEnvironment configures the
// transport's TLS settings accordingly.
func TestClientEnvSettings(t *testing.T) {
	cwd, err := os.Getwd()
	if err != nil {
		t.Fatalf("error getting working directory: %v", err)
	}
	keysDir := cwd + "/test-fixtures/keys"
	certPath := keysDir + "/cert.pem"
	keyPath := keysDir + "/key.pem"

	caCertBytes, err := os.ReadFile(certPath)
	if err != nil {
		t.Fatalf("error reading %q cert file: %v", certPath, err)
	}

	env := map[string]string{
		EnvVaultCACert:      certPath,
		EnvVaultCACertBytes: string(caCertBytes),
		EnvVaultCAPath:      keysDir,
		EnvVaultClientCert:  certPath,
		EnvVaultClientKey:   keyPath,
		EnvVaultSkipVerify:  "true",
		EnvVaultMaxRetries:  "5",
	}
	for name, value := range env {
		// Deferring inside the loop is intentional: each variable is restored
		// when the test returns, and os.Getenv is evaluated now, so the
		// previous value is captured before it is overwritten.
		defer os.Setenv(name, os.Getenv(name))
		os.Setenv(name, value)
	}

	config := DefaultConfig()
	if err := config.ReadEnvironment(); err != nil {
		t.Fatalf("error reading environment: %v", err)
	}

	tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
	if len(tlsConfig.RootCAs.Subjects()) == 0 {
		t.Fatalf("bad: expected a cert pool with at least one subject")
	}
	if tlsConfig.GetClientCertificate == nil {
		t.Fatalf("bad: expected client tls config to have a certificate getter")
	}
	if tlsConfig.InsecureSkipVerify != true {
		t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
	}
}
// TestClientDeprecatedEnvSettings checks that setting EnvVaultInsecure to
// "true" results in InsecureSkipVerify on the transport's TLS config.
func TestClientDeprecatedEnvSettings(t *testing.T) {
	prev := os.Getenv(EnvVaultInsecure)
	os.Setenv(EnvVaultInsecure, "true")
	defer os.Setenv(EnvVaultInsecure, prev)

	config := DefaultConfig()
	if err := config.ReadEnvironment(); err != nil {
		t.Fatalf("error reading environment: %v", err)
	}

	transport := config.HttpClient.Transport.(*http.Transport)
	if skip := transport.TLSClientConfig.InsecureSkipVerify; !skip {
		t.Fatalf("bad: %v", skip)
	}
}
func TestClientEnvNamespace(t *testing.T) {
var seenNamespace string
handler := func(w http.ResponseWriter, req *http.Request) {
seenNamespace = req.Header.Get(consts.NamespaceHeaderName)
}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
oldVaultNamespace := os.Getenv(EnvVaultNamespace)
defer os.Setenv(EnvVaultNamespace, oldVaultNamespace)
os.Setenv(EnvVaultNamespace, "test")
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
_, err = client.RawRequest(client.NewRequest(http.MethodGet, "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
if seenNamespace != "test" {
t.Fatalf("Bad: %s", seenNamespace)
}
}
// TestParsingRateAndBurst checks parsing of the "rate:burst" form.
func TestParsingRateAndBurst(t *testing.T) {
	const input = "400:400"
	wantRate, wantBurst := float64(400), 400

	gotRate, gotBurst, err := parseRateLimit(input)
	if err != nil {
		t.Error(err)
	}
	if gotRate != wantRate {
		t.Errorf("Expected rate %v but found %v", wantRate, gotRate)
	}
	if gotBurst != wantBurst {
		t.Errorf("Expected burst %v but found %v", wantBurst, gotBurst)
	}
}
// TestParsingRateOnly checks that a bare rate is accepted and that the burst
// comes back equal to the rate.
func TestParsingRateOnly(t *testing.T) {
	const input = "400"
	wantRate, wantBurst := float64(400), 400

	gotRate, gotBurst, err := parseRateLimit(input)
	if err != nil {
		t.Error(err)
	}
	if gotRate != wantRate {
		t.Errorf("Expected rate %v but found %v", wantRate, gotRate)
	}
	if gotBurst != wantBurst {
		t.Errorf("Expected burst %v but found %v", wantBurst, gotBurst)
	}
}
// TestParsingErrorCase checks that a non-numeric rate limit is rejected.
func TestParsingErrorCase(t *testing.T) {
	if _, _, err := parseRateLimit("foobar"); err == nil {
		t.Error("Expected error, found no error")
	}
}
// TestClientTimeoutSetting checks that EnvVaultClientTimeout is read into the
// config's Timeout, and that a client can be built from that config.
func TestClientTimeoutSetting(t *testing.T) {
	oldClientTimeout := os.Getenv(EnvVaultClientTimeout)
	os.Setenv(EnvVaultClientTimeout, "10")
	defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)

	config := DefaultConfig()
	// Previously this error was discarded, hiding parse failures.
	if err := config.ReadEnvironment(); err != nil {
		t.Fatalf("error reading environment: %v", err)
	}
	// NOTE(review): assumes a unit-less value is parsed as seconds.
	if want := 10 * time.Second; config.Timeout != want {
		t.Fatalf("bad: expected timeout %v, actual %v", want, config.Timeout)
	}

	if _, err := NewClient(config); err != nil {
		t.Fatal(err)
	}
}
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface, in the same way http.HandlerFunc adapts handlers. Tests use it to
// supply a Transport that is not an *http.Transport.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by calling f with the request.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}
// TestClientNonTransportRoundTripper checks that NewClient accepts an
// http.Client whose Transport is not an *http.Transport when no Unix socket
// address is configured.
func TestClientNonTransportRoundTripper(t *testing.T) {
	hc := &http.Client{Transport: roundTripperFunc(http.DefaultTransport.RoundTrip)}
	if _, err := NewClient(&Config{HttpClient: hc}); err != nil {
		t.Fatal(err)
	}
}
func TestClientNonTransportRoundTripperUnixAddress(t *testing.T) {
client := &http.Client{
Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
}
_, err := NewClient(&Config{
HttpClient: client,
Address: "unix:///var/run/vault.sock",
})
if err == nil {
t.Fatal("bad: expected error got nil")
}
}
// TestClone configures a parent client through its setters, clones it, and
// compares the two. Address, timeout, retry policy, limiter, max retries and
// SRV lookup must match. OutputCurlString must not be copied. Headers, token
// and replication state are checked according to the case's config flags.
func TestClone(t *testing.T) {
	tests := []struct {
		name    string
		config  *Config
		headers *http.Header
		token   string
	}{
		{
			name:   "default",
			config: DefaultConfig(),
		},
		{
			name: "cloneHeaders",
			config: &Config{
				CloneHeaders: true,
			},
			headers: &http.Header{
				"X-foo": []string{"bar"},
				"X-baz": []string{"qux"},
			},
		},
		{
			name: "preventStaleReads",
			config: &Config{
				ReadYourWrites: true,
			},
		},
		{
			name: "cloneToken",
			config: &Config{
				CloneToken: true,
			},
			token: "cloneToken",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parent, err := NewClient(tt.config)
			if err != nil {
				t.Fatalf("NewClient failed: %v", err)
			}

			// Set all of the things that we provide setter methods for, which modify config values
			err = parent.SetAddress("http://example.com:8080")
			if err != nil {
				t.Fatalf("SetAddress failed: %v", err)
			}

			clientTimeout := time.Until(time.Now().AddDate(0, 0, 1))
			parent.SetClientTimeout(clientTimeout)

			checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) {
				return true, nil
			}
			parent.SetCheckRetry(checkRetry)

			parent.SetLogger(hclog.NewNullLogger())

			parent.SetLimiter(5.0, 10)
			parent.SetMaxRetries(5)
			parent.SetOutputCurlString(true)
			parent.SetOutputPolicy(true)
			parent.SetSRVLookup(true)

			if tt.headers != nil {
				parent.SetHeaders(*tt.headers)
			}

			if tt.token != "" {
				parent.SetToken(tt.token)
			}

			clone, err := parent.Clone()
			if err != nil {
				t.Fatalf("Clone failed: %v", err)
			}

			if parent.Address() != clone.Address() {
				t.Fatalf("addresses don't match: %v vs %v", parent.Address(), clone.Address())
			}
			if parent.ClientTimeout() != clone.ClientTimeout() {
				t.Fatalf("timeouts don't match: %v vs %v", parent.ClientTimeout(), clone.ClientTimeout())
			}
			if parent.CheckRetry() != nil && clone.CheckRetry() == nil {
				t.Fatal("checkRetry functions don't match. clone is nil.")
			}
			if (parent.Limiter() != nil && clone.Limiter() == nil) || (parent.Limiter() == nil && clone.Limiter() != nil) {
				t.Fatalf("limiters don't match: %v vs %v", parent.Limiter(), clone.Limiter())
			}
			if parent.Limiter().Limit() != clone.Limiter().Limit() {
				t.Fatalf("limiter limits don't match: %v vs %v", parent.Limiter().Limit(), clone.Limiter().Limit())
			}
			if parent.Limiter().Burst() != clone.Limiter().Burst() {
				t.Fatalf("limiter bursts don't match: %v vs %v", parent.Limiter().Burst(), clone.Limiter().Burst())
			}
			if parent.MaxRetries() != clone.MaxRetries() {
				t.Fatalf("maxRetries don't match: %v vs %v", parent.MaxRetries(), clone.MaxRetries())
			}
			// OutputCurlString is deliberately expected to differ after cloning.
			if parent.OutputCurlString() == clone.OutputCurlString() {
				t.Fatalf("outputCurlString was copied over when it shouldn't have been: %v and %v", parent.OutputCurlString(), clone.OutputCurlString())
			}
			if parent.SRVLookup() != clone.SRVLookup() {
				t.Fatalf("SRVLookup doesn't match: %v vs %v", parent.SRVLookup(), clone.SRVLookup())
			}
			if tt.config.CloneHeaders {
				if !reflect.DeepEqual(parent.Headers(), clone.Headers()) {
					t.Fatalf("Headers() don't match: %v vs %v", parent.Headers(), clone.Headers())
				}
				if parent.config.CloneHeaders != clone.config.CloneHeaders {
					t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", parent.config.CloneHeaders, clone.config.CloneHeaders)
				}
				if tt.headers != nil {
					if !reflect.DeepEqual(*tt.headers, clone.Headers()) {
						t.Fatalf("expected headers %v, actual %v", *tt.headers, clone.Headers())
					}
				}
			}
			if tt.config.ReadYourWrites && parent.replicationStateStore == nil {
				t.Fatalf("replicationStateStore is nil")
			}
			if tt.config.CloneToken {
				if tt.token == "" {
					t.Fatalf("test requires a non-empty token")
				}
				if parent.config.CloneToken != clone.config.CloneToken {
					t.Fatalf("config.CloneToken doesn't match: %v vs %v", parent.config.CloneToken, clone.config.CloneToken)
				}
				if parent.token != clone.token {
					t.Fatalf("tokens do not match: %v vs %v", parent.token, clone.token)
				}
			} else {
				// assumes `VAULT_TOKEN` is unset or has an empty value.
				expected := ""
				if clone.token != expected {
					t.Fatalf("expected clone's token %q, actual %q", expected, clone.token)
				}
			}
			if !reflect.DeepEqual(parent.replicationStateStore, clone.replicationStateStore) {
				t.Fatalf("expected replicationStateStore %v, actual %v", parent.replicationStateStore,
					clone.replicationStateStore)
			}
		})
	}
}
// TestSetHeadersRaceSafe adds several headers concurrently via AddHeader and
// checks that none of them are lost.
func TestSetHeadersRaceSafe(t *testing.T) {
	client, err := NewClient(nil)
	if err != nil {
		t.Fatalf("NewClient failed: %v", err)
	}

	// Pure signaling channels: start releases all workers at once, done
	// receives one value per finished worker.
	start := make(chan struct{})
	done := make(chan struct{})

	testPairs := map[string]string{
		"soda":    "rootbeer",
		"veggie":  "carrots",
		"fruit":   "apples",
		"color":   "red",
		"protein": "egg",
	}

	for key, value := range testPairs {
		key, value := key, value // per-iteration copies for the goroutine (needed before Go 1.22)
		go func() {
			<-start
			// Per the original note, this test fails if client.AddHeader(key, value)
			// is replaced by the read-modify-write sequence:
			//   headerCopy := client.Headers()
			//   headerCopy.Add(key, value)
			//   client.SetHeaders(headerCopy)
			client.AddHeader(key, value)
			done <- struct{}{}
		}()
	}

	// Start everyone at once.
	close(start)

	// Wait until everyone is done.
	for range testPairs {
		<-done
	}

	// Check that all the test pairs are in the resulting headers.
	resultingHeaders := client.Headers()
	for key, want := range testPairs {
		if got := resultingHeaders.Get(key); got != want {
			t.Fatalf("expected %q for %q, got %q", want, key, got)
		}
	}
}
// TestMergeReplicationStates feeds base64-encoded replication states through
// MergeReplicationStates and compares the decoded result to the expected set.
func TestMergeReplicationStates(t *testing.T) {
	type testCase struct {
		name     string
		old      []string
		new      string
		expected []string
	}

	testCases := []testCase{
		{
			name:     "empty-old",
			old:      nil,
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:1:0:"},
		},
		{
			name:     "old-smaller",
			old:      []string{"v1:cid:1:0:"},
			new:      "v1:cid:2:0:",
			expected: []string{"v1:cid:2:0:"},
		},
		{
			name:     "old-bigger",
			old:      []string{"v1:cid:2:0:"},
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:2:0:"},
		},
		{
			name:     "mixed-single",
			old:      []string{"v1:cid:1:0:"},
			new:      "v1:cid:0:1:",
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
		{
			name:     "mixed-single-alt",
			old:      []string{"v1:cid:0:1:"},
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
		{
			name:     "mixed-double",
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			new:      "v1:cid:2:0:",
			expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
		},
		{
			name:     "newer-both",
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			new:      "v1:cid:2:1:",
			expected: []string{"v1:cid:2:1:"},
		},
	}

	b64enc := func(ss []string) []string {
		var ret []string
		for _, s := range ss {
			ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
		}
		return ret
	}
	// b64dec takes the calling (sub)test's t: t.Fatal must run on the
	// goroutine of the test that owns t, so the parent's t cannot be used
	// from inside t.Run.
	b64dec := func(t *testing.T, ss []string) []string {
		t.Helper()
		var ret []string
		for _, s := range ss {
			d, err := base64.StdEncoding.DecodeString(s)
			if err != nil {
				t.Fatal(err)
			}
			ret = append(ret, string(d))
		}
		return ret
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			out := b64dec(t, MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
			if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
				t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
			}
		})
	}
}
// TestReplicationStateStore_recordState records the HeaderIndex values of
// several responses concurrently and compares the store's contents.
func TestReplicationStateStore_recordState(t *testing.T) {
	b64enc := func(s string) string {
		return base64.StdEncoding.EncodeToString([]byte(s))
	}
	// respWith builds a Response whose HeaderIndex header carries the given
	// states base64-encoded; with no states the header map is empty.
	respWith := func(states ...string) *Response {
		header := make(http.Header)
		for _, s := range states {
			header[HeaderIndex] = append(header[HeaderIndex], b64enc(s))
		}
		return &Response{Response: &http.Response{Header: header}}
	}

	tests := []struct {
		name     string
		expected []string
		resp     []*Response
	}{
		{
			name:     "single",
			resp:     []*Response{respWith("v1:cid:1:0:")},
			expected: []string{b64enc("v1:cid:1:0:")},
		},
		{
			name:     "empty",
			resp:     []*Response{respWith()},
			expected: nil,
		},
		{
			name:     "multiple",
			resp:     []*Response{respWith("v1:cid:0:1:"), respWith("v1:cid:1:0:")},
			expected: []string{b64enc("v1:cid:0:1:"), b64enc("v1:cid:1:0:")},
		},
		{
			name:     "duplicates",
			resp:     []*Response{respWith("v1:cid:1:0:"), respWith("v1:cid:1:0:")},
			expected: []string{b64enc("v1:cid:1:0:")},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := &replicationStateStore{}

			var wg sync.WaitGroup
			wg.Add(len(tt.resp))
			for _, r := range tt.resp {
				go func(resp *Response) {
					defer wg.Done()
					store.recordState(resp)
				}(r)
			}
			wg.Wait()

			if !reflect.DeepEqual(tt.expected, store.store) {
				t.Errorf("recordState(): expected states %v, actual %v", tt.expected, store.store)
			}
		})
	}
}
// TestReplicationStateStore_requireState applies a store's states to requests
// concurrently, then checks the sorted HeaderIndex values set on them.
func TestReplicationStateStore_requireState(t *testing.T) {
	newReq := func() *Request {
		return &Request{Headers: make(http.Header)}
	}

	tests := []struct {
		name     string
		states   []string
		req      []*Request
		expected []string
	}{
		{
			name:     "empty",
			states:   []string{},
			req:      []*Request{newReq()},
			expected: nil,
		},
		{
			name:     "basic",
			states:   []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			req:      []*Request{newReq()},
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := &replicationStateStore{store: tt.states}

			var wg sync.WaitGroup
			wg.Add(len(tt.req))
			for _, r := range tt.req {
				go func(req *Request) {
					defer wg.Done()
					store.requireState(req)
				}(r)
			}
			wg.Wait()

			// Appending zero values leaves actual nil, matching a nil expectation.
			var actual []string
			for _, r := range tt.req {
				actual = append(actual, r.Headers.Values(HeaderIndex)...)
			}
			sort.Strings(actual)

			if !reflect.DeepEqual(tt.expected, actual) {
				t.Errorf("requireState(): expected states %v, actual %v", tt.expected, actual)
			}
		})
	}
}
// TestClient_ReadYourWrites issues concurrent requests with ReadYourWrites
// enabled, optionally through clones of the parent client. It then compares
// the parent's replication states against wantStates.
func TestClient_ReadYourWrites(t *testing.T) {
	b64enc := func(s string) string {
		return base64.StdEncoding.EncodeToString([]byte(s))
	}

	// The handler echoes the request path (without the leading slash) back in
	// the HeaderIndex response header.
	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		w.Header().Set(HeaderIndex, strings.TrimLeft(req.URL.Path, "/"))
	})

	tests := []struct {
		name       string
		handler    http.Handler
		wantStates []string
		values     [][]string
		clone      bool
	}{
		{
			name:    "multiple_duplicates",
			clone:   false,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
			},
		},
		{
			name:    "basic_clone",
			clone:   true,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
				},
				{
					b64enc("v1:cid:0:3:"),
				},
			},
		},
		{
			name:    "multiple_clone",
			clone:   true,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
				{
					b64enc("v1:cid:0:3:"),
					b64enc("v1:cid:0:1:"),
				},
			},
		},
		{
			name:    "multiple_duplicates_clone",
			clone:   true,
			handler: handler,
			wantStates: []string{
				b64enc("v1:cid:0:4:"),
			},
			values: [][]string{
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
				{
					b64enc("v1:cid:0:4:"),
					b64enc("v1:cid:0:2:"),
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// testRequest runs on worker goroutines. It must report failures
			// with t.Error/t.Errorf and return: t.Fatal (FailNow) is only valid
			// on the goroutine running the test.
			testRequest := func(client *Client, val string) {
				req := client.NewRequest(http.MethodGet, "/"+val)
				req.Headers.Set(HeaderIndex, val)
				resp, err := client.RawRequestWithContext(context.Background(), req)
				if err != nil {
					t.Error(err)
					return
				}
				defer resp.Body.Close()

				// validate that the server provided a valid header value in its response
				actual := resp.Header.Get(HeaderIndex)
				if actual != val {
					t.Errorf("expected header value %v, actual %v", val, actual)
				}
			}

			config, ln := testHTTPServer(t, handler)
			defer ln.Close()

			config.ReadYourWrites = true
			config.Address = fmt.Sprintf("http://%s", ln.Addr())
			parent, err := NewClient(config)
			if err != nil {
				t.Fatal(err)
			}

			var wg sync.WaitGroup
			for i := 0; i < len(tt.values); i++ {
				var c *Client
				if tt.clone {
					c, err = parent.Clone()
					if err != nil {
						t.Fatal(err)
					}
				} else {
					c = parent
				}

				for _, val := range tt.values[i] {
					wg.Add(1)
					go func(val string) {
						defer wg.Done()
						testRequest(c, val)
					}(val)
				}
			}

			wg.Wait()

			if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) {
				t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states())
			}
		})
	}
}
// TestClient_SetReadYourWrites toggles SetReadYourWrites through a sequence of
// values. After each call it checks the config flag and the replication state
// store.
func TestClient_SetReadYourWrites(t *testing.T) {
	tests := []struct {
		name   string
		config *Config
		calls  []bool
	}{
		{name: "false", config: &Config{}, calls: []bool{false}},
		{name: "true", config: &Config{}, calls: []bool{true}},
		{name: "multi-false", config: &Config{}, calls: []bool{false, false}},
		{name: "multi-true", config: &Config{}, calls: []bool{true, true}},
		{name: "multi-mix", config: &Config{}, calls: []bool{false, true, false, true}},
	}

	// assertSetReadYourWrites calls SetReadYourWrites(v) on c and verifies the
	// resulting config flag and replication state store.
	assertSetReadYourWrites := func(t *testing.T, c *Client, v bool, want *replicationStateStore) {
		t.Helper()
		c.SetReadYourWrites(v)
		if got := c.config.ReadYourWrites; got != v {
			t.Fatalf("expected config.ReadYourWrites %#v, actual %#v", v, got)
		}
		if !reflect.DeepEqual(want, c.replicationStateStore) {
			t.Fatalf("expected replicationStateStore %#v, actual %#v", want, c.replicationStateStore)
		}
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Client{config: tt.config}
			for i, v := range tt.calls {
				var want *replicationStateStore
				if v {
					if c.replicationStateStore == nil {
						c.replicationStateStore = &replicationStateStore{store: []string{}}
					}
					// Append a per-call marker so the expected store is
					// distinguishable from a freshly created, empty one.
					c.replicationStateStore.store = append(c.replicationStateStore.store,
						fmt.Sprintf("%s-%d", tt.name, i))
					want = c.replicationStateStore
				}
				assertSetReadYourWrites(t, c, v, want)
			}
		})
	}
}
// TestClient_SetCloneToken checks that CloneToken always reports the value
// most recently passed to SetCloneToken (false initially).
func TestClient_SetCloneToken(t *testing.T) {
	tests := []struct {
		name  string
		calls []bool
	}{
		{name: "false", calls: []bool{false}},
		{name: "true", calls: []bool{true}},
		{name: "multi", calls: []bool{true, false, true}},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Client{config: &Config{}}

			want := false
			for _, v := range tt.calls {
				if got := c.CloneToken(); got != want {
					t.Fatalf("expected %v, actual %v", want, got)
				}

				want = v
				c.SetCloneToken(want)
				if got := c.CloneToken(); got != want {
					t.Fatalf("SetCloneToken(): expected %v, actual %v", want, got)
				}
			}
		})
	}
}
func TestClientWithNamespace(t *testing.T) {
var ns string
handler := func(w http.ResponseWriter, req *http.Request) {
ns = req.Header.Get(consts.NamespaceHeaderName)
}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
// set up a client with a namespace
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
ogNS := "test"
client.SetNamespace(ogNS)
_, err = client.rawRequestWithContext(
context.Background(),
client.NewRequest(http.MethodGet, "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
if ns != ogNS {
t.Fatalf("Expected namespace: \"%s\", got \"%s\"", ogNS, ns)
}
// make a call with a temporary namespace
newNS := "new-namespace"
_, err = client.WithNamespace(newNS).rawRequestWithContext(
context.Background(),
client.NewRequest(http.MethodGet, "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
if ns != newNS {
t.Fatalf("Expected new namespace: \"%s\", got \"%s\"", newNS, ns)
}
// ensure client has not been modified
_, err = client.rawRequestWithContext(
context.Background(),
client.NewRequest(http.MethodGet, "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
if ns != ogNS {
t.Fatalf("Expected original namespace: \"%s\", got \"%s\"", ogNS, ns)
}
// make call with empty ns
_, err = client.WithNamespace("").rawRequestWithContext(
context.Background(),
client.NewRequest(http.MethodGet, "/"))
if err != nil {
t.Fatalf("err: %s", err)
}
if ns != "" {
t.Fatalf("Expected no namespace, got \"%s\"", ns)
}
// ensure client has not been modified
if client.Namespace() != ogNS {
t.Fatalf("Expected original namespace: \"%s\", got \"%s\"", ogNS, client.Namespace())
}
}
// TestVaultProxy checks which proxy DefaultConfig's transport resolves for a
// request under various VAULT_HTTP_PROXY / VAULT_PROXY_ADDR / NO_PROXY
// settings. Every case expects a proxy, including cases where NO_PROXY lists
// the request host.
func TestVaultProxy(t *testing.T) {
	const noProxyEnv = "NO_PROXY"

	tests := map[string]struct {
		vaultHttpProxy           string
		vaultProxyAddr           string
		noProxy                  string
		requestUrl               string
		expectedResolvedProxyUrl string
	}{
		"VAULT_HTTP_PROXY used when NO_PROXY env var doesn't include request host": {
			vaultHttpProxy: "https://hashicorp.com",
			vaultProxyAddr: "",
			noProxy:        "terraform.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_HTTP_PROXY used when NO_PROXY env var includes request host": {
			vaultHttpProxy: "https://hashicorp.com",
			vaultProxyAddr: "",
			noProxy:        "terraform.io,vaultproject.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_PROXY_ADDR used when NO_PROXY env var doesn't include request host": {
			vaultHttpProxy: "",
			vaultProxyAddr: "https://hashicorp.com",
			noProxy:        "terraform.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_PROXY_ADDR used when NO_PROXY env var includes request host": {
			vaultHttpProxy: "",
			vaultProxyAddr: "https://hashicorp.com",
			noProxy:        "terraform.io,vaultproject.io",
			requestUrl:     "https://vaultproject.io",
		},
		"VAULT_PROXY_ADDR used when VAULT_HTTP_PROXY env var also supplied": {
			vaultHttpProxy:           "https://hashicorp.com",
			vaultProxyAddr:           "https://terraform.io",
			noProxy:                  "",
			requestUrl:               "https://vaultproject.io",
			expectedResolvedProxyUrl: "https://terraform.io",
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			// Override only the variables the case sets; each override is
			// restored when the subtest returns.
			if tc.vaultHttpProxy != "" {
				oldVaultHttpProxy := os.Getenv(EnvHTTPProxy)
				os.Setenv(EnvHTTPProxy, tc.vaultHttpProxy)
				defer os.Setenv(EnvHTTPProxy, oldVaultHttpProxy)
			}

			if tc.vaultProxyAddr != "" {
				oldVaultProxyAddr := os.Getenv(EnvVaultProxyAddr)
				os.Setenv(EnvVaultProxyAddr, tc.vaultProxyAddr)
				defer os.Setenv(EnvVaultProxyAddr, oldVaultProxyAddr)
			}

			if tc.noProxy != "" {
				oldNoProxy := os.Getenv(noProxyEnv)
				os.Setenv(noProxyEnv, tc.noProxy)
				defer os.Setenv(noProxyEnv, oldNoProxy)
			}

			c := DefaultConfig()
			if c.Error != nil {
				t.Fatalf("Expected no error reading config, found error %v", c.Error)
			}

			r, err := http.NewRequest("GET", tc.requestUrl, nil)
			if err != nil {
				t.Fatalf("Expected no error building request, found error %v", err)
			}
			proxyUrl, err := c.HttpClient.Transport.(*http.Transport).Proxy(r)
			if err != nil {
				t.Fatalf("Expected no error resolving proxy, found error %v", err)
			}
			if proxyUrl == nil || proxyUrl.String() == "" {
				t.Fatalf("Expected proxy to be resolved but no proxy returned")
			}
			if tc.expectedResolvedProxyUrl != "" && proxyUrl.String() != tc.expectedResolvedProxyUrl {
				t.Fatalf("Expected resolved proxy URL to be %v but was %v", tc.expectedResolvedProxyUrl, proxyUrl.String())
			}
		})
	}
}
// TestParseAddressWithUnixSocket checks what Config.ParseAddress does with a
// unix:// address. The result must be an http URL whose host is the socket
// path and whose path is empty. The config's *http.Transport must also gain a
// DialContext.
func TestParseAddressWithUnixSocket(t *testing.T) {
	address := "unix:///var/run/vault.sock"
	config := DefaultConfig()

	u, err := config.ParseAddress(address)
	if err != nil {
		t.Fatalf("Error not expected: %v", err)
	}
	if u.Scheme != "http" {
		t.Fatalf("Scheme not changed to http: %q", u.Scheme)
	}
	if u.Host != "/var/run/vault.sock" {
		t.Fatalf("Host not changed to socket name: %q", u.Host)
	}
	if u.Path != "" {
		t.Fatalf("Path expected to be blank: %q", u.Path)
	}

	// Checked assertion: report a clear failure instead of panicking if the
	// default transport type ever changes.
	transport, ok := config.HttpClient.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("expected config.HttpClient.Transport to be *http.Transport, got %T", config.HttpClient.Transport)
	}
	if transport.DialContext == nil {
		t.Fatal("DialContext function not set in config.HttpClient.Transport")
	}
}