989 lines
23 KiB
Go
989 lines
23 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-test/deep"
|
|
"github.com/hashicorp/go-hclog"
|
|
|
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
|
)
|
|
|
|
// init blanks the VAULT_ADDR and VAULT_TOKEN environment variables so that the
// tests in this file start from a known state regardless of the caller's shell.
func init() {
	for _, name := range []string{"VAULT_ADDR", "VAULT_TOKEN"} {
		os.Setenv(name, "")
	}
}
|
|
|
|
func TestDefaultConfig_envvar(t *testing.T) {
|
|
os.Setenv("VAULT_ADDR", "https://vault.mycompany.com")
|
|
defer os.Setenv("VAULT_ADDR", "")
|
|
|
|
config := DefaultConfig()
|
|
if config.Address != "https://vault.mycompany.com" {
|
|
t.Fatalf("bad: %s", config.Address)
|
|
}
|
|
|
|
os.Setenv("VAULT_TOKEN", "testing")
|
|
defer os.Setenv("VAULT_TOKEN", "")
|
|
|
|
client, err := NewClient(config)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
if token := client.Token(); token != "testing" {
|
|
t.Fatalf("bad: %s", token)
|
|
}
|
|
}
|
|
|
|
func TestClientDefaultHttpClient(t *testing.T) {
|
|
_, err := NewClient(&Config{
|
|
HttpClient: http.DefaultClient,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestClientNilConfig(t *testing.T) {
|
|
client, err := NewClient(nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if client == nil {
|
|
t.Fatal("expected a non-nil client")
|
|
}
|
|
}
|
|
|
|
func TestClientSetAddress(t *testing.T) {
|
|
client, err := NewClient(nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if client.addr.Host != "172.168.2.1:8300" {
|
|
t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
|
|
}
|
|
}
|
|
|
|
func TestClientToken(t *testing.T) {
|
|
tokenValue := "foo"
|
|
handler := func(w http.ResponseWriter, req *http.Request) {}
|
|
|
|
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
|
|
defer ln.Close()
|
|
|
|
client, err := NewClient(config)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
client.SetToken(tokenValue)
|
|
|
|
// Verify the token is set
|
|
if v := client.Token(); v != tokenValue {
|
|
t.Fatalf("bad: %s", v)
|
|
}
|
|
|
|
client.ClearToken()
|
|
|
|
if v := client.Token(); v != "" {
|
|
t.Fatalf("bad: %s", v)
|
|
}
|
|
}
|
|
|
|
func TestClientHostHeader(t *testing.T) {
|
|
handler := func(w http.ResponseWriter, req *http.Request) {
|
|
w.Write([]byte(req.Host))
|
|
}
|
|
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
|
|
defer ln.Close()
|
|
|
|
config.Address = strings.ReplaceAll(config.Address, "127.0.0.1", "localhost")
|
|
client, err := NewClient(config)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
// Set the token manually
|
|
client.SetToken("foo")
|
|
|
|
resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Copy the response
|
|
var buf bytes.Buffer
|
|
io.Copy(&buf, resp.Body)
|
|
|
|
// Verify we got the response from the primary
|
|
if buf.String() != strings.ReplaceAll(config.Address, "http://", "") {
|
|
t.Fatalf("Bad address: %s", buf.String())
|
|
}
|
|
}
|
|
|
|
func TestClientBadToken(t *testing.T) {
|
|
handler := func(w http.ResponseWriter, req *http.Request) {}
|
|
|
|
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
|
|
defer ln.Close()
|
|
|
|
client, err := NewClient(config)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
client.SetToken("foo")
|
|
_, err = client.RawRequest(client.NewRequest("PUT", "/"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
client.SetToken("foo\u007f")
|
|
_, err = client.RawRequest(client.NewRequest("PUT", "/"))
|
|
if err == nil || !strings.Contains(err.Error(), "printable") {
|
|
t.Fatalf("expected error due to bad token")
|
|
}
|
|
}
|
|
|
|
func TestClientRedirect(t *testing.T) {
|
|
primary := func(w http.ResponseWriter, req *http.Request) {
|
|
w.Write([]byte("test"))
|
|
}
|
|
config, ln := testHTTPServer(t, http.HandlerFunc(primary))
|
|
defer ln.Close()
|
|
|
|
standby := func(w http.ResponseWriter, req *http.Request) {
|
|
w.Header().Set("Location", config.Address)
|
|
w.WriteHeader(307)
|
|
}
|
|
config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
|
|
defer ln2.Close()
|
|
|
|
client, err := NewClient(config2)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
// Set the token manually
|
|
client.SetToken("foo")
|
|
|
|
// Do a raw "/" request
|
|
resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
// Copy the response
|
|
var buf bytes.Buffer
|
|
io.Copy(&buf, resp.Body)
|
|
|
|
// Verify we got the response from the primary
|
|
if buf.String() != "test" {
|
|
t.Fatalf("Bad: %s", buf.String())
|
|
}
|
|
}
|
|
|
|
func TestDefaulRetryPolicy(t *testing.T) {
|
|
cases := map[string]struct {
|
|
resp *http.Response
|
|
err error
|
|
expect bool
|
|
expectErr error
|
|
}{
|
|
"retry on error": {
|
|
err: fmt.Errorf("error"),
|
|
expect: true,
|
|
},
|
|
"don't retry connection failures": {
|
|
err: &url.Error{
|
|
Err: x509.UnknownAuthorityError{},
|
|
},
|
|
},
|
|
"don't retry on 200": {
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
},
|
|
},
|
|
"don't retry on 4xx": {
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusBadRequest,
|
|
},
|
|
},
|
|
"don't retry on 501": {
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusNotImplemented,
|
|
},
|
|
},
|
|
"retry on 500": {
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusInternalServerError,
|
|
},
|
|
expect: true,
|
|
},
|
|
"retry on 5xx": {
|
|
resp: &http.Response{
|
|
StatusCode: http.StatusGatewayTimeout,
|
|
},
|
|
expect: true,
|
|
},
|
|
}
|
|
|
|
for name, test := range cases {
|
|
t.Run(name, func(t *testing.T) {
|
|
retry, err := DefaultRetryPolicy(context.Background(), test.resp, test.err)
|
|
if retry != test.expect {
|
|
t.Fatalf("expected to retry request: '%t', but actual result was: '%t'", test.expect, retry)
|
|
}
|
|
if err != test.expectErr {
|
|
t.Fatalf("expected error from retry policy: '%s', but actual result was: '%s'", err, test.expectErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClientEnvSettings(t *testing.T) {
|
|
cwd, _ := os.Getwd()
|
|
oldCACert := os.Getenv(EnvVaultCACert)
|
|
oldCAPath := os.Getenv(EnvVaultCAPath)
|
|
oldClientCert := os.Getenv(EnvVaultClientCert)
|
|
oldClientKey := os.Getenv(EnvVaultClientKey)
|
|
oldSkipVerify := os.Getenv(EnvVaultSkipVerify)
|
|
oldMaxRetries := os.Getenv(EnvVaultMaxRetries)
|
|
os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem")
|
|
os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys")
|
|
os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem")
|
|
os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem")
|
|
os.Setenv(EnvVaultSkipVerify, "true")
|
|
os.Setenv(EnvVaultMaxRetries, "5")
|
|
defer os.Setenv(EnvVaultCACert, oldCACert)
|
|
defer os.Setenv(EnvVaultCAPath, oldCAPath)
|
|
defer os.Setenv(EnvVaultClientCert, oldClientCert)
|
|
defer os.Setenv(EnvVaultClientKey, oldClientKey)
|
|
defer os.Setenv(EnvVaultSkipVerify, oldSkipVerify)
|
|
defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries)
|
|
|
|
config := DefaultConfig()
|
|
if err := config.ReadEnvironment(); err != nil {
|
|
t.Fatalf("error reading environment: %v", err)
|
|
}
|
|
|
|
tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
|
|
if len(tlsConfig.RootCAs.Subjects()) == 0 {
|
|
t.Fatalf("bad: expected a cert pool with at least one subject")
|
|
}
|
|
if tlsConfig.GetClientCertificate == nil {
|
|
t.Fatalf("bad: expected client tls config to have a certificate getter")
|
|
}
|
|
if tlsConfig.InsecureSkipVerify != true {
|
|
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
|
|
}
|
|
}
|
|
|
|
func TestClientDeprecatedEnvSettings(t *testing.T) {
|
|
oldInsecure := os.Getenv(EnvVaultInsecure)
|
|
os.Setenv(EnvVaultInsecure, "true")
|
|
defer os.Setenv(EnvVaultInsecure, oldInsecure)
|
|
|
|
config := DefaultConfig()
|
|
if err := config.ReadEnvironment(); err != nil {
|
|
t.Fatalf("error reading environment: %v", err)
|
|
}
|
|
|
|
tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
|
|
if tlsConfig.InsecureSkipVerify != true {
|
|
t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
|
|
}
|
|
}
|
|
|
|
func TestClientEnvNamespace(t *testing.T) {
|
|
var seenNamespace string
|
|
handler := func(w http.ResponseWriter, req *http.Request) {
|
|
seenNamespace = req.Header.Get(consts.NamespaceHeaderName)
|
|
}
|
|
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
|
|
defer ln.Close()
|
|
|
|
oldVaultNamespace := os.Getenv(EnvVaultNamespace)
|
|
defer os.Setenv(EnvVaultNamespace, oldVaultNamespace)
|
|
os.Setenv(EnvVaultNamespace, "test")
|
|
|
|
client, err := NewClient(config)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
_, err = client.RawRequest(client.NewRequest("GET", "/"))
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
if seenNamespace != "test" {
|
|
t.Fatalf("Bad: %s", seenNamespace)
|
|
}
|
|
}
|
|
|
|
func TestParsingRateAndBurst(t *testing.T) {
|
|
var (
|
|
correctFormat = "400:400"
|
|
observedRate, observedBurst, err = parseRateLimit(correctFormat)
|
|
expectedRate, expectedBurst = float64(400), 400
|
|
)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if expectedRate != observedRate {
|
|
t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
|
|
}
|
|
if expectedBurst != observedBurst {
|
|
t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
|
|
}
|
|
}
|
|
|
|
func TestParsingRateOnly(t *testing.T) {
|
|
var (
|
|
correctFormat = "400"
|
|
observedRate, observedBurst, err = parseRateLimit(correctFormat)
|
|
expectedRate, expectedBurst = float64(400), 400
|
|
)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if expectedRate != observedRate {
|
|
t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
|
|
}
|
|
if expectedBurst != observedBurst {
|
|
t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
|
|
}
|
|
}
|
|
|
|
func TestParsingErrorCase(t *testing.T) {
|
|
incorrectFormat := "foobar"
|
|
_, _, err := parseRateLimit(incorrectFormat)
|
|
if err == nil {
|
|
t.Error("Expected error, found no error")
|
|
}
|
|
}
|
|
|
|
func TestClientTimeoutSetting(t *testing.T) {
|
|
oldClientTimeout := os.Getenv(EnvVaultClientTimeout)
|
|
os.Setenv(EnvVaultClientTimeout, "10")
|
|
defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)
|
|
config := DefaultConfig()
|
|
config.ReadEnvironment()
|
|
_, err := NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface, in the same spirit as http.HandlerFunc.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time assertion that roundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip implements http.RoundTripper by calling f with the request and
// returning whatever f returns.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}
|
|
|
|
func TestClientNonTransportRoundTripper(t *testing.T) {
|
|
client := &http.Client{
|
|
Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
|
|
}
|
|
|
|
_, err := NewClient(&Config{
|
|
HttpClient: client,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestClone(t *testing.T) {
|
|
type fields struct{}
|
|
tests := []struct {
|
|
name string
|
|
config *Config
|
|
headers *http.Header
|
|
}{
|
|
{
|
|
name: "default",
|
|
config: DefaultConfig(),
|
|
},
|
|
{
|
|
name: "cloneHeaders",
|
|
config: &Config{
|
|
CloneHeaders: true,
|
|
},
|
|
headers: &http.Header{
|
|
"X-foo": []string{"bar"},
|
|
"X-baz": []string{"qux"},
|
|
},
|
|
},
|
|
{
|
|
name: "preventStaleReads",
|
|
config: &Config{
|
|
ReadYourWrites: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
client1, err := NewClient(tt.config)
|
|
if err != nil {
|
|
t.Fatalf("NewClient failed: %v", err)
|
|
}
|
|
|
|
// Set all of the things that we provide setter methods for, which modify config values
|
|
err = client1.SetAddress("http://example.com:8080")
|
|
if err != nil {
|
|
t.Fatalf("SetAddress failed: %v", err)
|
|
}
|
|
|
|
clientTimeout := time.Until(time.Now().AddDate(0, 0, 1))
|
|
client1.SetClientTimeout(clientTimeout)
|
|
|
|
checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) {
|
|
return true, nil
|
|
}
|
|
client1.SetCheckRetry(checkRetry)
|
|
|
|
client1.SetLogger(hclog.NewNullLogger())
|
|
|
|
client1.SetLimiter(5.0, 10)
|
|
client1.SetMaxRetries(5)
|
|
client1.SetOutputCurlString(true)
|
|
client1.SetSRVLookup(true)
|
|
|
|
if tt.headers != nil {
|
|
client1.SetHeaders(*tt.headers)
|
|
}
|
|
|
|
client2, err := client1.Clone()
|
|
if err != nil {
|
|
t.Fatalf("Clone failed: %v", err)
|
|
}
|
|
|
|
if client1.Address() != client2.Address() {
|
|
t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address())
|
|
}
|
|
if client1.ClientTimeout() != client2.ClientTimeout() {
|
|
t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout())
|
|
}
|
|
if client1.CheckRetry() != nil && client2.CheckRetry() == nil {
|
|
t.Fatal("checkRetry functions don't match. client2 is nil.")
|
|
}
|
|
if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) {
|
|
t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter())
|
|
}
|
|
if client1.Limiter().Limit() != client2.Limiter().Limit() {
|
|
t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit())
|
|
}
|
|
if client1.Limiter().Burst() != client2.Limiter().Burst() {
|
|
t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst())
|
|
}
|
|
if client1.MaxRetries() != client2.MaxRetries() {
|
|
t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries())
|
|
}
|
|
if client1.OutputCurlString() != client2.OutputCurlString() {
|
|
t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString())
|
|
}
|
|
if client1.SRVLookup() != client2.SRVLookup() {
|
|
t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup())
|
|
}
|
|
if tt.config.CloneHeaders {
|
|
if !reflect.DeepEqual(client1.Headers(), client2.Headers()) {
|
|
t.Fatalf("Headers() don't match: %v vs %v", client1.Headers(), client2.Headers())
|
|
}
|
|
if client1.config.CloneHeaders != client2.config.CloneHeaders {
|
|
t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", client1.config.CloneHeaders, client2.config.CloneHeaders)
|
|
}
|
|
if tt.headers != nil {
|
|
if !reflect.DeepEqual(*tt.headers, client2.Headers()) {
|
|
t.Fatalf("expected headers %v, actual %v", *tt.headers, client2.Headers())
|
|
}
|
|
}
|
|
}
|
|
if tt.config.ReadYourWrites && client1.replicationStateStore == nil {
|
|
t.Fatalf("replicationStateStore is nil")
|
|
}
|
|
if !reflect.DeepEqual(client1.replicationStateStore, client2.replicationStateStore) {
|
|
t.Fatalf("expected replicationStateStore %v, actual %v", client1.replicationStateStore,
|
|
client2.replicationStateStore)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetHeadersRaceSafe(t *testing.T) {
|
|
client, err1 := NewClient(nil)
|
|
if err1 != nil {
|
|
t.Fatalf("NewClient failed: %v", err1)
|
|
}
|
|
|
|
start := make(chan interface{})
|
|
done := make(chan interface{})
|
|
|
|
testPairs := map[string]string{
|
|
"soda": "rootbeer",
|
|
"veggie": "carrots",
|
|
"fruit": "apples",
|
|
"color": "red",
|
|
"protein": "egg",
|
|
}
|
|
|
|
for key, value := range testPairs {
|
|
tmpKey := key
|
|
tmpValue := value
|
|
go func() {
|
|
<-start
|
|
// This test fails if here, you replace client.AddHeader(tmpKey, tmpValue) with:
|
|
// headerCopy := client.Header()
|
|
// headerCopy.AddHeader(tmpKey, tmpValue)
|
|
// client.SetHeader(headerCopy)
|
|
client.AddHeader(tmpKey, tmpValue)
|
|
done <- true
|
|
}()
|
|
}
|
|
|
|
// Start everyone at once.
|
|
close(start)
|
|
|
|
// Wait until everyone is done.
|
|
for i := 0; i < len(testPairs); i++ {
|
|
<-done
|
|
}
|
|
|
|
// Check that all the test pairs are in the resulting
|
|
// headers.
|
|
resultingHeaders := client.Headers()
|
|
for key, value := range testPairs {
|
|
if resultingHeaders.Get(key) != value {
|
|
t.Fatal("expected " + value + " for " + key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestMergeReplicationStates(t *testing.T) {
|
|
type testCase struct {
|
|
name string
|
|
old []string
|
|
new string
|
|
expected []string
|
|
}
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "empty-old",
|
|
old: nil,
|
|
new: "v1:cid:1:0:",
|
|
expected: []string{"v1:cid:1:0:"},
|
|
},
|
|
{
|
|
name: "old-smaller",
|
|
old: []string{"v1:cid:1:0:"},
|
|
new: "v1:cid:2:0:",
|
|
expected: []string{"v1:cid:2:0:"},
|
|
},
|
|
{
|
|
name: "old-bigger",
|
|
old: []string{"v1:cid:2:0:"},
|
|
new: "v1:cid:1:0:",
|
|
expected: []string{"v1:cid:2:0:"},
|
|
},
|
|
{
|
|
name: "mixed-single",
|
|
old: []string{"v1:cid:1:0:"},
|
|
new: "v1:cid:0:1:",
|
|
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
|
},
|
|
{
|
|
name: "mixed-single-alt",
|
|
old: []string{"v1:cid:0:1:"},
|
|
new: "v1:cid:1:0:",
|
|
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
|
},
|
|
{
|
|
name: "mixed-double",
|
|
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
|
new: "v1:cid:2:0:",
|
|
expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
|
|
},
|
|
{
|
|
name: "newer-both",
|
|
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
|
|
new: "v1:cid:2:1:",
|
|
expected: []string{"v1:cid:2:1:"},
|
|
},
|
|
}
|
|
|
|
b64enc := func(ss []string) []string {
|
|
var ret []string
|
|
for _, s := range ss {
|
|
ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
|
|
}
|
|
return ret
|
|
}
|
|
b64dec := func(ss []string) []string {
|
|
var ret []string
|
|
for _, s := range ss {
|
|
d, err := base64.StdEncoding.DecodeString(s)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ret = append(ret, string(d))
|
|
}
|
|
return ret
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
out := b64dec(MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
|
|
if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
|
|
t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestReplicationStateStore_recordState(t *testing.T) {
|
|
b64enc := func(s string) string {
|
|
return base64.StdEncoding.EncodeToString([]byte(s))
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
expected []string
|
|
resp []*Response
|
|
}{
|
|
{
|
|
name: "single",
|
|
resp: []*Response{
|
|
{
|
|
Response: &http.Response{
|
|
Header: map[string][]string{
|
|
HeaderIndex: {
|
|
b64enc("v1:cid:1:0:"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
expected: []string{
|
|
b64enc("v1:cid:1:0:"),
|
|
},
|
|
},
|
|
{
|
|
name: "empty",
|
|
resp: []*Response{
|
|
{
|
|
Response: &http.Response{
|
|
Header: map[string][]string{},
|
|
},
|
|
},
|
|
},
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "multiple",
|
|
resp: []*Response{
|
|
{
|
|
Response: &http.Response{
|
|
Header: map[string][]string{
|
|
HeaderIndex: {
|
|
b64enc("v1:cid:0:1:"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Response: &http.Response{
|
|
Header: map[string][]string{
|
|
HeaderIndex: {
|
|
b64enc("v1:cid:1:0:"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
expected: []string{
|
|
b64enc("v1:cid:0:1:"),
|
|
b64enc("v1:cid:1:0:"),
|
|
},
|
|
},
|
|
{
|
|
name: "duplicates",
|
|
resp: []*Response{
|
|
{
|
|
Response: &http.Response{
|
|
Header: map[string][]string{
|
|
HeaderIndex: {
|
|
b64enc("v1:cid:1:0:"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
Response: &http.Response{
|
|
Header: map[string][]string{
|
|
HeaderIndex: {
|
|
b64enc("v1:cid:1:0:"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
expected: []string{
|
|
b64enc("v1:cid:1:0:"),
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
w := &replicationStateStore{}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, r := range tt.resp {
|
|
wg.Add(1)
|
|
go func(r *Response) {
|
|
defer wg.Done()
|
|
w.recordState(r)
|
|
}(r)
|
|
}
|
|
wg.Wait()
|
|
|
|
if !reflect.DeepEqual(tt.expected, w.store) {
|
|
t.Errorf("recordState(): expected states %v, actual %v", tt.expected, w.store)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestReplicationStateStore_requireState(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
states []string
|
|
req []*Request
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "empty",
|
|
states: []string{},
|
|
req: []*Request{
|
|
{
|
|
Headers: make(http.Header),
|
|
},
|
|
},
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "basic",
|
|
states: []string{
|
|
"v1:cid:0:1:",
|
|
"v1:cid:1:0:",
|
|
},
|
|
req: []*Request{
|
|
{
|
|
Headers: make(http.Header),
|
|
},
|
|
},
|
|
expected: []string{
|
|
"v1:cid:0:1:",
|
|
"v1:cid:1:0:",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
store := &replicationStateStore{
|
|
store: tt.states,
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, r := range tt.req {
|
|
wg.Add(1)
|
|
go func(r *Request) {
|
|
defer wg.Done()
|
|
store.requireState(r)
|
|
}(r)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
var actual []string
|
|
for _, r := range tt.req {
|
|
if values := r.Headers.Values(HeaderIndex); len(values) > 0 {
|
|
actual = append(actual, values...)
|
|
}
|
|
}
|
|
sort.Strings(actual)
|
|
if !reflect.DeepEqual(tt.expected, actual) {
|
|
t.Errorf("requireState(): expected states %v, actual %v", tt.expected, actual)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClient_ReadYourWrites(t *testing.T) {
|
|
b64enc := func(s string) string {
|
|
return base64.StdEncoding.EncodeToString([]byte(s))
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
w.Header().Set(HeaderIndex, strings.TrimLeft(req.URL.Path, "/"))
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
handler http.Handler
|
|
wantStates []string
|
|
values [][]string
|
|
clone bool
|
|
}{
|
|
{
|
|
name: "multiple_duplicates",
|
|
clone: false,
|
|
handler: handler,
|
|
wantStates: []string{
|
|
b64enc("v1:cid:0:4:"),
|
|
},
|
|
values: [][]string{
|
|
{
|
|
b64enc("v1:cid:0:4:"),
|
|
b64enc("v1:cid:0:2:"),
|
|
},
|
|
{
|
|
b64enc("v1:cid:0:4:"),
|
|
b64enc("v1:cid:0:2:"),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "basic_clone",
|
|
clone: true,
|
|
handler: handler,
|
|
wantStates: []string{
|
|
b64enc("v1:cid:0:4:"),
|
|
},
|
|
values: [][]string{
|
|
{
|
|
b64enc("v1:cid:0:4:"),
|
|
},
|
|
{
|
|
b64enc("v1:cid:0:3:"),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple_clone",
|
|
clone: true,
|
|
handler: handler,
|
|
wantStates: []string{
|
|
b64enc("v1:cid:0:4:"),
|
|
},
|
|
values: [][]string{
|
|
{
|
|
b64enc("v1:cid:0:4:"),
|
|
b64enc("v1:cid:0:2:"),
|
|
},
|
|
{
|
|
b64enc("v1:cid:0:3:"),
|
|
b64enc("v1:cid:0:1:"),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple_duplicates_clone",
|
|
clone: true,
|
|
handler: handler,
|
|
wantStates: []string{
|
|
b64enc("v1:cid:0:4:"),
|
|
},
|
|
values: [][]string{
|
|
{
|
|
b64enc("v1:cid:0:4:"),
|
|
b64enc("v1:cid:0:2:"),
|
|
},
|
|
{
|
|
b64enc("v1:cid:0:4:"),
|
|
b64enc("v1:cid:0:2:"),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
testRequest := func(client *Client, val string) {
|
|
req := client.NewRequest("GET", "/"+val)
|
|
req.Headers.Set(HeaderIndex, val)
|
|
resp, err := client.RawRequestWithContext(context.Background(), req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// validate that the server provided a valid header value in its response
|
|
actual := resp.Header.Get(HeaderIndex)
|
|
if actual != val {
|
|
t.Errorf("expected header value %v, actual %v", val, actual)
|
|
}
|
|
}
|
|
|
|
config, ln := testHTTPServer(t, handler)
|
|
defer ln.Close()
|
|
|
|
config.ReadYourWrites = true
|
|
config.Address = fmt.Sprintf("http://%s", ln.Addr())
|
|
parent, err := NewClient(config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < len(tt.values); i++ {
|
|
var c *Client
|
|
if tt.clone {
|
|
c, err = parent.Clone()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
} else {
|
|
c = parent
|
|
}
|
|
|
|
for _, val := range tt.values[i] {
|
|
wg.Add(1)
|
|
go func(val string) {
|
|
defer wg.Done()
|
|
testRequest(c, val)
|
|
}(val)
|
|
}
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) {
|
|
t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states())
|
|
}
|
|
})
|
|
}
|
|
}
|