api.Client: support isolated read-after-write (#12814)

- add new configuration option, ReadYourWrites, which enables a Client
  to provide cluster replication states to every request. A curated set
  of cluster replication states are stored in the replicationStateStore,
  and is shared across clones.
This commit is contained in:
Ben Ash 2021-10-14 14:51:31 -04:00 committed by GitHub
parent aa7de03ef5
commit 0b095588c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 451 additions and 17 deletions

View File

@ -24,11 +24,12 @@ import (
retryablehttp "github.com/hashicorp/go-retryablehttp"
rootcerts "github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/net/http2"
"golang.org/x/time/rate"
)
const (
@ -49,6 +50,7 @@ const (
EnvVaultMFA = "VAULT_MFA"
EnvRateLimit = "VAULT_RATE_LIMIT"
EnvHTTPProxy = "VAULT_HTTP_PROXY"
HeaderIndex = "X-Vault-Index"
)
// Deprecated values
@ -133,8 +135,18 @@ type Config struct {
// SRVLookup enables the client to lookup the host through DNS SRV lookup
SRVLookup bool
// CloneHeaders ensures that the source client's headers are copied to its clone.
// CloneHeaders ensures that the source client's headers are copied to
// its clone.
CloneHeaders bool
// ReadYourWrites ensures isolated read-after-write semantics by
// providing discovered cluster replication states in each request.
// The shared state is automatically propagated to all Client clones.
//
// Note: Careful consideration should be made prior to enabling this setting
// since there will be a performance penalty paid upon each request.
// This feature requires Enterprise server-side.
ReadYourWrites bool
}
// TLSConfig contains the parameters needed to configure TLS on the HTTP client
@ -415,16 +427,17 @@ func parseRateLimit(val string) (rate float64, burst int, err error) {
// Client is the client to the Vault API. Create a client with NewClient.
type Client struct {
modifyLock sync.RWMutex
addr *url.URL
config *Config
token string
headers http.Header
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
modifyLock sync.RWMutex
addr *url.URL
config *Config
token string
headers http.Header
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
replicationStateStore *replicationStateStore
}
// NewClient returns a new client for the given configuration.
@ -498,6 +511,10 @@ func NewClient(c *Config) (*Client, error) {
headers: make(http.Header),
}
if c.ReadYourWrites {
client.replicationStateStore = &replicationStateStore{}
}
// Add the VaultRequest SSRF protection header
client.headers[consts.RequestHeaderName] = []string{"true"}
@ -530,6 +547,7 @@ func (c *Client) CloneConfig() *Config {
newConfig.OutputCurlString = c.config.OutputCurlString
newConfig.SRVLookup = c.config.SRVLookup
newConfig.CloneHeaders = c.config.CloneHeaders
newConfig.ReadYourWrites = c.config.ReadYourWrites
// we specifically want a _copy_ of the client here, not a pointer to the original one
newClient := *c.config.HttpClient
@ -855,6 +873,32 @@ func (c *Client) CloneHeaders() bool {
return c.config.CloneHeaders
}
// SetReadYourWrites to prevent reading stale cluster replication state.
func (c *Client) SetReadYourWrites(preventStaleReads bool) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()
if preventStaleReads && c.replicationStateStore == nil {
c.replicationStateStore = &replicationStateStore{}
} else {
c.replicationStateStore = nil
}
c.config.ReadYourWrites = preventStaleReads
}
// ReadYourWrites gets the configured value of ReadYourWrites
func (c *Client) ReadYourWrites() bool {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()
return c.config.ReadYourWrites
}
// Clone creates a new client with the same configuration. Note that the same
// underlying http.Client is used; modifying the client from more than one
// goroutine at once may not be safe, so modify the client as needed and then
@ -886,6 +930,7 @@ func (c *Client) Clone() (*Client, error) {
AgentAddress: config.AgentAddress,
SRVLookup: config.SRVLookup,
CloneHeaders: config.CloneHeaders,
ReadYourWrites: config.ReadYourWrites,
}
client, err := NewClient(newConfig)
if err != nil {
@ -896,6 +941,8 @@ func (c *Client) Clone() (*Client, error) {
client.SetHeaders(c.Headers().Clone())
}
client.replicationStateStore = c.replicationStateStore
return client, nil
}
@ -1001,6 +1048,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
cb(r)
}
if c.config.ReadYourWrites {
c.replicationStateStore.requireState(r)
}
if limiter != nil {
limiter.Wait(ctx)
}
@ -1111,6 +1162,10 @@ START:
for _, cb := range c.responseCallbacks {
cb(result)
}
if c.config.ReadYourWrites {
c.replicationStateStore.recordState(result)
}
}
if err := result.Error(); err != nil {
return result, err
@ -1152,7 +1207,7 @@ func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client {
// by Vault in a response header.
func RecordState(state *string) ResponseCallback {
return func(resp *Response) {
*state = resp.Header.Get("X-Vault-Index")
*state = resp.Header.Get(HeaderIndex)
}
}
@ -1162,7 +1217,7 @@ func RecordState(state *string) ResponseCallback {
func RequireState(states ...string) RequestCallback {
return func(req *Request) {
for _, s := range states {
req.Headers.Add("X-Vault-Index", s)
req.Headers.Add(HeaderIndex, s)
}
}
}
@ -1300,3 +1355,39 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
}
return false, nil
}
// replicationStateStore is used to track cluster replication states
// in order to ensure proper read-after-write semantics for a Client.
type replicationStateStore struct {
m sync.RWMutex
store []string
}
// recordState updates the store's replication states with the merger of all
// states.
func (w *replicationStateStore) recordState(resp *Response) {
w.m.Lock()
defer w.m.Unlock()
newState := resp.Header.Get(HeaderIndex)
if newState != "" {
w.store = MergeReplicationStates(w.store, newState)
}
}
// requireState updates the Request with the store's current replication states.
func (w *replicationStateStore) requireState(req *Request) {
w.m.RLock()
defer w.m.RUnlock()
for _, s := range w.store {
req.Headers.Add(HeaderIndex, s)
}
}
// states currently stored.
func (w *replicationStateStore) states() []string {
w.m.RLock()
defer w.m.RUnlock()
c := make([]string, len(w.store))
copy(c, w.store)
return c
}

View File

@ -11,12 +11,15 @@ import (
"net/url"
"os"
"reflect"
"sort"
"strings"
"sync"
"testing"
"time"
"github.com/go-test/deep"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/consts"
)
@ -412,8 +415,7 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
}
func TestClone(t *testing.T) {
type fields struct {
}
type fields struct{}
tests := []struct {
name string
config *Config
@ -433,6 +435,12 @@ func TestClone(t *testing.T) {
"X-baz": []string{"qux"},
},
},
{
name: "preventStaleReads",
config: &Config{
ReadYourWrites: true,
},
},
}
for _, tt := range tests {
@ -512,6 +520,13 @@ func TestClone(t *testing.T) {
}
}
}
if tt.config.ReadYourWrites && client1.replicationStateStore == nil {
t.Fatalf("replicationStateStore is nil")
}
if !reflect.DeepEqual(client1.replicationStateStore, client2.replicationStateStore) {
t.Fatalf("expected replicationStateStore %v, actual %v", client1.replicationStateStore,
client2.replicationStateStore)
}
})
}
}
@ -646,3 +661,328 @@ func TestMergeReplicationStates(t *testing.T) {
})
}
}
func TestReplicationStateStore_recordState(t *testing.T) {
b64enc := func(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
tests := []struct {
name string
expected []string
resp []*Response
}{
{
name: "single",
resp: []*Response{
{
Response: &http.Response{
Header: map[string][]string{
HeaderIndex: {
b64enc("v1:cid:1:0:"),
},
},
},
},
},
expected: []string{
b64enc("v1:cid:1:0:"),
},
},
{
name: "empty",
resp: []*Response{
{
Response: &http.Response{
Header: map[string][]string{},
},
},
},
expected: nil,
},
{
name: "multiple",
resp: []*Response{
{
Response: &http.Response{
Header: map[string][]string{
HeaderIndex: {
b64enc("v1:cid:0:1:"),
},
},
},
},
{
Response: &http.Response{
Header: map[string][]string{
HeaderIndex: {
b64enc("v1:cid:1:0:"),
},
},
},
},
},
expected: []string{
b64enc("v1:cid:0:1:"),
b64enc("v1:cid:1:0:"),
},
},
{
name: "duplicates",
resp: []*Response{
{
Response: &http.Response{
Header: map[string][]string{
HeaderIndex: {
b64enc("v1:cid:1:0:"),
},
},
},
},
{
Response: &http.Response{
Header: map[string][]string{
HeaderIndex: {
b64enc("v1:cid:1:0:"),
},
},
},
},
},
expected: []string{
b64enc("v1:cid:1:0:"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &replicationStateStore{}
var wg sync.WaitGroup
for _, r := range tt.resp {
wg.Add(1)
go func(r *Response) {
defer wg.Done()
w.recordState(r)
}(r)
}
wg.Wait()
if !reflect.DeepEqual(tt.expected, w.store) {
t.Errorf("recordState(): expected states %v, actual %v", tt.expected, w.store)
}
})
}
}
func TestReplicationStateStore_requireState(t *testing.T) {
tests := []struct {
name string
states []string
req []*Request
expected []string
}{
{
name: "empty",
states: []string{},
req: []*Request{
{
Headers: make(http.Header),
},
},
expected: nil,
},
{
name: "basic",
states: []string{
"v1:cid:0:1:",
"v1:cid:1:0:",
},
req: []*Request{
{
Headers: make(http.Header),
},
},
expected: []string{
"v1:cid:0:1:",
"v1:cid:1:0:",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &replicationStateStore{
store: tt.states,
}
var wg sync.WaitGroup
for _, r := range tt.req {
wg.Add(1)
go func(r *Request) {
defer wg.Done()
store.requireState(r)
}(r)
}
wg.Wait()
var actual []string
for _, r := range tt.req {
if values := r.Headers.Values(HeaderIndex); len(values) > 0 {
actual = append(actual, values...)
}
}
sort.Strings(actual)
if !reflect.DeepEqual(tt.expected, actual) {
t.Errorf("requireState(): expected states %v, actual %v", tt.expected, actual)
}
})
}
}
func TestClient_ReadYourWrites(t *testing.T) {
b64enc := func(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set(HeaderIndex, strings.TrimLeft(req.URL.Path, "/"))
})
tests := []struct {
name string
handler http.Handler
wantStates []string
values [][]string
clone bool
}{
{
name: "multiple_duplicates",
clone: false,
handler: handler,
wantStates: []string{
b64enc("v1:cid:0:4:"),
},
values: [][]string{
{
b64enc("v1:cid:0:4:"),
b64enc("v1:cid:0:2:"),
},
{
b64enc("v1:cid:0:4:"),
b64enc("v1:cid:0:2:"),
},
},
},
{
name: "basic_clone",
clone: true,
handler: handler,
wantStates: []string{
b64enc("v1:cid:0:4:"),
},
values: [][]string{
{
b64enc("v1:cid:0:4:"),
},
{
b64enc("v1:cid:0:3:"),
},
},
},
{
name: "multiple_clone",
clone: true,
handler: handler,
wantStates: []string{
b64enc("v1:cid:0:4:"),
},
values: [][]string{
{
b64enc("v1:cid:0:4:"),
b64enc("v1:cid:0:2:"),
},
{
b64enc("v1:cid:0:3:"),
b64enc("v1:cid:0:1:"),
},
},
},
{
name: "multiple_duplicates_clone",
clone: true,
handler: handler,
wantStates: []string{
b64enc("v1:cid:0:4:"),
},
values: [][]string{
{
b64enc("v1:cid:0:4:"),
b64enc("v1:cid:0:2:"),
},
{
b64enc("v1:cid:0:4:"),
b64enc("v1:cid:0:2:"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testRequest := func(client *Client, val string) {
req := client.NewRequest("GET", "/"+val)
req.Headers.Set(HeaderIndex, val)
resp, err := client.RawRequestWithContext(context.Background(), req)
if err != nil {
t.Fatal(err)
}
// validate that the server provided a valid header value in its response
actual := resp.Header.Get(HeaderIndex)
if actual != val {
t.Errorf("expected header value %v, actual %v", val, actual)
}
}
config, ln := testHTTPServer(t, handler)
defer ln.Close()
config.ReadYourWrites = true
config.Address = fmt.Sprintf("http://%s", ln.Addr())
parent, err := NewClient(config)
if err != nil {
t.Fatal(err)
}
var wg sync.WaitGroup
for i := 0; i < len(tt.values); i++ {
var c *Client
if tt.clone {
c, err = parent.Clone()
if err != nil {
t.Fatal(err)
}
} else {
c = parent
}
for _, val := range tt.values[i] {
wg.Add(1)
go func(val string) {
defer wg.Done()
testRequest(c, val)
}(val)
}
}
wg.Wait()
if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) {
t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states())
}
})
}
}

3
changelog/12814.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
api: Add configuration option for ensuring isolated read-after-write semantics for all Client requests.
```