From 0b095588c65f2fb33c06048a55c4e296e19464b1 Mon Sep 17 00:00:00 2001 From: Ben Ash <32777270+benashz@users.noreply.github.com> Date: Thu, 14 Oct 2021 14:51:31 -0400 Subject: [PATCH] api.Client: support isolated read-after-write (#12814) - add new configuration option, ReadYourWrites, which enables a Client to provide cluster replication states to every request. A curated set of cluster replication states are stored in the replicationStateStore, and is shared across clones. --- api/client.go | 121 ++++++++++++++-- api/client_test.go | 344 +++++++++++++++++++++++++++++++++++++++++++- changelog/12814.txt | 3 + 3 files changed, 451 insertions(+), 17 deletions(-) create mode 100644 changelog/12814.txt diff --git a/api/client.go b/api/client.go index 34974d742..2d1c3b683 100644 --- a/api/client.go +++ b/api/client.go @@ -24,11 +24,12 @@ import ( retryablehttp "github.com/hashicorp/go-retryablehttp" rootcerts "github.com/hashicorp/go-rootcerts" "github.com/hashicorp/go-secure-stdlib/parseutil" + "golang.org/x/net/http2" + "golang.org/x/time/rate" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/logical" - "golang.org/x/net/http2" - "golang.org/x/time/rate" ) const ( @@ -49,6 +50,7 @@ const ( EnvVaultMFA = "VAULT_MFA" EnvRateLimit = "VAULT_RATE_LIMIT" EnvHTTPProxy = "VAULT_HTTP_PROXY" + HeaderIndex = "X-Vault-Index" ) // Deprecated values @@ -133,8 +135,18 @@ type Config struct { // SRVLookup enables the client to lookup the host through DNS SRV lookup SRVLookup bool - // CloneHeaders ensures that the source client's headers are copied to its clone. + // CloneHeaders ensures that the source client's headers are copied to + // its clone. CloneHeaders bool + + // ReadYourWrites ensures isolated read-after-write semantics by + // providing discovered cluster replication states in each request. + // The shared state is automatically propagated to all Client clones. + // + // Note: Careful consideration should be made prior to enabling this setting + // since there will be a performance penalty paid upon each request. + // This feature requires Enterprise server-side. + ReadYourWrites bool } // TLSConfig contains the parameters needed to configure TLS on the HTTP client @@ -415,16 +427,17 @@ func parseRateLimit(val string) (rate float64, burst int, err error) { // Client is the client to the Vault API. Create a client with NewClient. type Client struct { - modifyLock sync.RWMutex - addr *url.URL - config *Config - token string - headers http.Header - wrappingLookupFunc WrappingLookupFunc - mfaCreds []string - policyOverride bool - requestCallbacks []RequestCallback - responseCallbacks []ResponseCallback + modifyLock sync.RWMutex + addr *url.URL + config *Config + token string + headers http.Header + wrappingLookupFunc WrappingLookupFunc + mfaCreds []string + policyOverride bool + requestCallbacks []RequestCallback + responseCallbacks []ResponseCallback + replicationStateStore *replicationStateStore } // NewClient returns a new client for the given configuration. @@ -498,6 +511,10 @@ func NewClient(c *Config) (*Client, error) { headers: make(http.Header), } + if c.ReadYourWrites { + client.replicationStateStore = &replicationStateStore{} + } + // Add the VaultRequest SSRF protection header client.headers[consts.RequestHeaderName] = []string{"true"} @@ -530,6 +547,7 @@ func (c *Client) CloneConfig() *Config { newConfig.OutputCurlString = c.config.OutputCurlString newConfig.SRVLookup = c.config.SRVLookup newConfig.CloneHeaders = c.config.CloneHeaders + newConfig.ReadYourWrites = c.config.ReadYourWrites // we specifically want a _copy_ of the client here, not a pointer to the original one newClient := *c.config.HttpClient @@ -855,6 +873,32 @@ func (c *Client) CloneHeaders() bool { return c.config.CloneHeaders } +// SetReadYourWrites to prevent reading stale cluster replication state. +func (c *Client) SetReadYourWrites(preventStaleReads bool) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.config.modifyLock.Lock() + defer c.config.modifyLock.Unlock() + + if preventStaleReads && c.replicationStateStore == nil { + c.replicationStateStore = &replicationStateStore{} + } else { + c.replicationStateStore = nil + } + + c.config.ReadYourWrites = preventStaleReads +} + +// ReadYourWrites gets the configured value of ReadYourWrites +func (c *Client) ReadYourWrites() bool { + c.modifyLock.RLock() + defer c.modifyLock.RUnlock() + c.config.modifyLock.RLock() + defer c.config.modifyLock.RUnlock() + + return c.config.ReadYourWrites +} + // Clone creates a new client with the same configuration. Note that the same // underlying http.Client is used; modifying the client from more than one // goroutine at once may not be safe, so modify the client as needed and then @@ -886,6 +930,7 @@ func (c *Client) Clone() (*Client, error) { AgentAddress: config.AgentAddress, SRVLookup: config.SRVLookup, CloneHeaders: config.CloneHeaders, + ReadYourWrites: config.ReadYourWrites, } client, err := NewClient(newConfig) if err != nil { @@ -896,6 +941,8 @@ func (c *Client) Clone() (*Client, error) { client.SetHeaders(c.Headers().Clone()) } + client.replicationStateStore = c.replicationStateStore + return client, nil } @@ -1001,6 +1048,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon cb(r) } + if c.config.ReadYourWrites { + c.replicationStateStore.requireState(r) + } + if limiter != nil { limiter.Wait(ctx) } @@ -1111,6 +1162,10 @@ START: for _, cb := range c.responseCallbacks { cb(result) } + + if c.config.ReadYourWrites { + c.replicationStateStore.recordState(result) + } } if err := result.Error(); err != nil { return result, err @@ -1152,7 +1207,7 @@ func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client { // by Vault in a response header. func RecordState(state *string) ResponseCallback { return func(resp *Response) { - *state = resp.Header.Get("X-Vault-Index") + *state = resp.Header.Get(HeaderIndex) } } @@ -1162,7 +1217,7 @@ func RecordState(state *string) ResponseCallback { func RequireState(states ...string) RequestCallback { return func(req *Request) { for _, s := range states { - req.Headers.Add("X-Vault-Index", s) + req.Headers.Add(HeaderIndex, s) } } } @@ -1300,3 +1355,39 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo } return false, nil } + +// replicationStateStore is used to track cluster replication states +// in order to ensure proper read-after-write semantics for a Client. +type replicationStateStore struct { + m sync.RWMutex + store []string +} + +// recordState updates the store's replication states with the merger of all +// states. +func (w *replicationStateStore) recordState(resp *Response) { + w.m.Lock() + defer w.m.Unlock() + newState := resp.Header.Get(HeaderIndex) + if newState != "" { + w.store = MergeReplicationStates(w.store, newState) + } +} + +// requireState updates the Request with the store's current replication states. +func (w *replicationStateStore) requireState(req *Request) { + w.m.RLock() + defer w.m.RUnlock() + for _, s := range w.store { + req.Headers.Add(HeaderIndex, s) + } +} + +// states currently stored. +func (w *replicationStateStore) states() []string { + w.m.RLock() + defer w.m.RUnlock() + c := make([]string, len(w.store)) + copy(c, w.store) + return c +} diff --git a/api/client_test.go b/api/client_test.go index 3b306a927..f335a765b 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -11,12 +11,15 @@ import ( "net/url" "os" "reflect" + "sort" "strings" + "sync" "testing" "time" "github.com/go-test/deep" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/consts" ) @@ -412,8 +415,7 @@ func TestClientNonTransportRoundTripper(t *testing.T) { } func TestClone(t *testing.T) { - type fields struct { - } + type fields struct{} tests := []struct { name string config *Config @@ -433,6 +435,12 @@ func TestClone(t *testing.T) { "X-baz": []string{"qux"}, }, }, + { + name: "preventStaleReads", + config: &Config{ + ReadYourWrites: true, + }, + }, } for _, tt := range tests { @@ -512,6 +520,13 @@ func TestClone(t *testing.T) { } } } + if tt.config.ReadYourWrites && client1.replicationStateStore == nil { + t.Fatalf("replicationStateStore is nil") + } + if !reflect.DeepEqual(client1.replicationStateStore, client2.replicationStateStore) { + t.Fatalf("expected replicationStateStore %v, actual %v", client1.replicationStateStore, + client2.replicationStateStore) + } }) } } @@ -646,3 +661,328 @@ func TestMergeReplicationStates(t *testing.T) { }) } } + +func TestReplicationStateStore_recordState(t *testing.T) { + b64enc := func(s string) string { + return base64.StdEncoding.EncodeToString([]byte(s)) + } + + tests := []struct { + name string + expected []string + resp []*Response + }{ + { + name: "single", + resp: []*Response{ + { + Response: &http.Response{ + Header: map[string][]string{ + HeaderIndex: { + b64enc("v1:cid:1:0:"), + }, + }, + }, + }, + }, + expected: []string{ + b64enc("v1:cid:1:0:"), + }, + }, + { + name: "empty", + resp: []*Response{ + { + Response: &http.Response{ + Header: map[string][]string{}, + }, + }, + }, + expected: nil, + }, + { + name: "multiple", + resp: []*Response{ + { + Response: &http.Response{ + Header: map[string][]string{ + HeaderIndex: { + b64enc("v1:cid:0:1:"), + }, + }, + }, + }, + { + Response: &http.Response{ + Header: map[string][]string{ + HeaderIndex: { + b64enc("v1:cid:1:0:"), + }, + }, + }, + }, + }, + expected: []string{ + b64enc("v1:cid:0:1:"), + b64enc("v1:cid:1:0:"), + }, + }, + { + name: "duplicates", + resp: []*Response{ + { + Response: &http.Response{ + Header: map[string][]string{ + HeaderIndex: { + b64enc("v1:cid:1:0:"), + }, + }, + }, + }, + { + Response: &http.Response{ + Header: map[string][]string{ + HeaderIndex: { + b64enc("v1:cid:1:0:"), + }, + }, + }, + }, + }, + expected: []string{ + b64enc("v1:cid:1:0:"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &replicationStateStore{} + + var wg sync.WaitGroup + for _, r := range tt.resp { + wg.Add(1) + go func(r *Response) { + defer wg.Done() + w.recordState(r) + }(r) + } + wg.Wait() + + if !reflect.DeepEqual(tt.expected, w.store) { + t.Errorf("recordState(): expected states %v, actual %v", tt.expected, w.store) + } + }) + } +} + +func TestReplicationStateStore_requireState(t *testing.T) { + tests := []struct { + name string + states []string + req []*Request + expected []string + }{ + { + name: "empty", + states: []string{}, + req: []*Request{ + { + Headers: make(http.Header), + }, + }, + expected: nil, + }, + { + name: "basic", + states: []string{ + "v1:cid:0:1:", + "v1:cid:1:0:", + }, + req: []*Request{ + { + Headers: make(http.Header), + }, + }, + expected: []string{ + "v1:cid:0:1:", + "v1:cid:1:0:", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &replicationStateStore{ + store: tt.states, + } + + var wg sync.WaitGroup + for _, r := range tt.req { + wg.Add(1) + go func(r *Request) { + defer wg.Done() + store.requireState(r) + }(r) + } + + wg.Wait() + + var actual []string + for _, r := range tt.req { + if values := r.Headers.Values(HeaderIndex); len(values) > 0 { + actual = append(actual, values...) + } + } + sort.Strings(actual) + if !reflect.DeepEqual(tt.expected, actual) { + t.Errorf("requireState(): expected states %v, actual %v", tt.expected, actual) + } + }) + } +} + +func TestClient_ReadYourWrites(t *testing.T) { + b64enc := func(s string) string { + return base64.StdEncoding.EncodeToString([]byte(s)) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set(HeaderIndex, strings.TrimLeft(req.URL.Path, "/")) + }) + + tests := []struct { + name string + handler http.Handler + wantStates []string + values [][]string + clone bool + }{ + { + name: "multiple_duplicates", + clone: false, + handler: handler, + wantStates: []string{ + b64enc("v1:cid:0:4:"), + }, + values: [][]string{ + { + b64enc("v1:cid:0:4:"), + b64enc("v1:cid:0:2:"), + }, + { + b64enc("v1:cid:0:4:"), + b64enc("v1:cid:0:2:"), + }, + }, + }, + { + name: "basic_clone", + clone: true, + handler: handler, + wantStates: []string{ + b64enc("v1:cid:0:4:"), + }, + values: [][]string{ + { + b64enc("v1:cid:0:4:"), + }, + { + b64enc("v1:cid:0:3:"), + }, + }, + }, + { + name: "multiple_clone", + clone: true, + handler: handler, + wantStates: []string{ + b64enc("v1:cid:0:4:"), + }, + values: [][]string{ + { + b64enc("v1:cid:0:4:"), + b64enc("v1:cid:0:2:"), + }, + { + b64enc("v1:cid:0:3:"), + b64enc("v1:cid:0:1:"), + }, + }, + }, + { + name: "multiple_duplicates_clone", + clone: true, + handler: handler, + wantStates: []string{ + b64enc("v1:cid:0:4:"), + }, + values: [][]string{ + { + b64enc("v1:cid:0:4:"), + b64enc("v1:cid:0:2:"), + }, + { + b64enc("v1:cid:0:4:"), + b64enc("v1:cid:0:2:"), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testRequest := func(client *Client, val string) { + req := client.NewRequest("GET", "/"+val) + req.Headers.Set(HeaderIndex, val) + resp, err := client.RawRequestWithContext(context.Background(), req) + if err != nil { + t.Fatal(err) + } + + // validate that the server provided a valid header value in its response + actual := resp.Header.Get(HeaderIndex) + if actual != val { + t.Errorf("expected header value %v, actual %v", val, actual) + } + } + + config, ln := testHTTPServer(t, handler) + defer ln.Close() + + config.ReadYourWrites = true + config.Address = fmt.Sprintf("http://%s", ln.Addr()) + parent, err := NewClient(config) + if err != nil { + t.Fatal(err) + } + + var wg sync.WaitGroup + for i := 0; i < len(tt.values); i++ { + var c *Client + if tt.clone { + c, err = parent.Clone() + if err != nil { + t.Fatal(err) + } + } else { + c = parent + } + + for _, val := range tt.values[i] { + wg.Add(1) + go func(val string) { + defer wg.Done() + testRequest(c, val) + }(val) + } + } + + wg.Wait() + + if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) { + t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states()) + } + }) + } +} diff --git a/changelog/12814.txt b/changelog/12814.txt new file mode 100644 index 000000000..9d5b541d6 --- /dev/null +++ b/changelog/12814.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Add configuration option for ensuring isolated read-after-write semantics for all Client requests. +```