api.Client: support isolated read-after-write (#12814)
- add new configuration option, ReadYourWrites, which enables a Client to provide cluster replication states to every request. A curated set of cluster replication states are stored in the replicationStateStore, and is shared across clones.
This commit is contained in:
parent
aa7de03ef5
commit
0b095588c6
121
api/client.go
121
api/client.go
|
@ -24,11 +24,12 @@ import (
|
|||
retryablehttp "github.com/hashicorp/go-retryablehttp"
|
||||
rootcerts "github.com/hashicorp/go-rootcerts"
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -49,6 +50,7 @@ const (
|
|||
EnvVaultMFA = "VAULT_MFA"
|
||||
EnvRateLimit = "VAULT_RATE_LIMIT"
|
||||
EnvHTTPProxy = "VAULT_HTTP_PROXY"
|
||||
HeaderIndex = "X-Vault-Index"
|
||||
)
|
||||
|
||||
// Deprecated values
|
||||
|
@ -133,8 +135,18 @@ type Config struct {
|
|||
// SRVLookup enables the client to lookup the host through DNS SRV lookup
|
||||
SRVLookup bool
|
||||
|
||||
// CloneHeaders ensures that the source client's headers are copied to its clone.
|
||||
// CloneHeaders ensures that the source client's headers are copied to
|
||||
// its clone.
|
||||
CloneHeaders bool
|
||||
|
||||
// ReadYourWrites ensures isolated read-after-write semantics by
|
||||
// providing discovered cluster replication states in each request.
|
||||
// The shared state is automatically propagated to all Client clones.
|
||||
//
|
||||
// Note: Careful consideration should be made prior to enabling this setting
|
||||
// since there will be a performance penalty paid upon each request.
|
||||
// This feature requires Enterprise server-side.
|
||||
ReadYourWrites bool
|
||||
}
|
||||
|
||||
// TLSConfig contains the parameters needed to configure TLS on the HTTP client
|
||||
|
@ -415,16 +427,17 @@ func parseRateLimit(val string) (rate float64, burst int, err error) {
|
|||
|
||||
// Client is the client to the Vault API. Create a client with NewClient.
|
||||
type Client struct {
|
||||
modifyLock sync.RWMutex
|
||||
addr *url.URL
|
||||
config *Config
|
||||
token string
|
||||
headers http.Header
|
||||
wrappingLookupFunc WrappingLookupFunc
|
||||
mfaCreds []string
|
||||
policyOverride bool
|
||||
requestCallbacks []RequestCallback
|
||||
responseCallbacks []ResponseCallback
|
||||
modifyLock sync.RWMutex
|
||||
addr *url.URL
|
||||
config *Config
|
||||
token string
|
||||
headers http.Header
|
||||
wrappingLookupFunc WrappingLookupFunc
|
||||
mfaCreds []string
|
||||
policyOverride bool
|
||||
requestCallbacks []RequestCallback
|
||||
responseCallbacks []ResponseCallback
|
||||
replicationStateStore *replicationStateStore
|
||||
}
|
||||
|
||||
// NewClient returns a new client for the given configuration.
|
||||
|
@ -498,6 +511,10 @@ func NewClient(c *Config) (*Client, error) {
|
|||
headers: make(http.Header),
|
||||
}
|
||||
|
||||
if c.ReadYourWrites {
|
||||
client.replicationStateStore = &replicationStateStore{}
|
||||
}
|
||||
|
||||
// Add the VaultRequest SSRF protection header
|
||||
client.headers[consts.RequestHeaderName] = []string{"true"}
|
||||
|
||||
|
@ -530,6 +547,7 @@ func (c *Client) CloneConfig() *Config {
|
|||
newConfig.OutputCurlString = c.config.OutputCurlString
|
||||
newConfig.SRVLookup = c.config.SRVLookup
|
||||
newConfig.CloneHeaders = c.config.CloneHeaders
|
||||
newConfig.ReadYourWrites = c.config.ReadYourWrites
|
||||
|
||||
// we specifically want a _copy_ of the client here, not a pointer to the original one
|
||||
newClient := *c.config.HttpClient
|
||||
|
@ -855,6 +873,32 @@ func (c *Client) CloneHeaders() bool {
|
|||
return c.config.CloneHeaders
|
||||
}
|
||||
|
||||
// SetReadYourWrites to prevent reading stale cluster replication state.
|
||||
func (c *Client) SetReadYourWrites(preventStaleReads bool) {
|
||||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
c.config.modifyLock.Lock()
|
||||
defer c.config.modifyLock.Unlock()
|
||||
|
||||
if preventStaleReads && c.replicationStateStore == nil {
|
||||
c.replicationStateStore = &replicationStateStore{}
|
||||
} else {
|
||||
c.replicationStateStore = nil
|
||||
}
|
||||
|
||||
c.config.ReadYourWrites = preventStaleReads
|
||||
}
|
||||
|
||||
// ReadYourWrites gets the configured value of ReadYourWrites
|
||||
func (c *Client) ReadYourWrites() bool {
|
||||
c.modifyLock.RLock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
c.config.modifyLock.RLock()
|
||||
defer c.config.modifyLock.RUnlock()
|
||||
|
||||
return c.config.ReadYourWrites
|
||||
}
|
||||
|
||||
// Clone creates a new client with the same configuration. Note that the same
|
||||
// underlying http.Client is used; modifying the client from more than one
|
||||
// goroutine at once may not be safe, so modify the client as needed and then
|
||||
|
@ -886,6 +930,7 @@ func (c *Client) Clone() (*Client, error) {
|
|||
AgentAddress: config.AgentAddress,
|
||||
SRVLookup: config.SRVLookup,
|
||||
CloneHeaders: config.CloneHeaders,
|
||||
ReadYourWrites: config.ReadYourWrites,
|
||||
}
|
||||
client, err := NewClient(newConfig)
|
||||
if err != nil {
|
||||
|
@ -896,6 +941,8 @@ func (c *Client) Clone() (*Client, error) {
|
|||
client.SetHeaders(c.Headers().Clone())
|
||||
}
|
||||
|
||||
client.replicationStateStore = c.replicationStateStore
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
|
@ -1001,6 +1048,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
|
|||
cb(r)
|
||||
}
|
||||
|
||||
if c.config.ReadYourWrites {
|
||||
c.replicationStateStore.requireState(r)
|
||||
}
|
||||
|
||||
if limiter != nil {
|
||||
limiter.Wait(ctx)
|
||||
}
|
||||
|
@ -1111,6 +1162,10 @@ START:
|
|||
for _, cb := range c.responseCallbacks {
|
||||
cb(result)
|
||||
}
|
||||
|
||||
if c.config.ReadYourWrites {
|
||||
c.replicationStateStore.recordState(result)
|
||||
}
|
||||
}
|
||||
if err := result.Error(); err != nil {
|
||||
return result, err
|
||||
|
@ -1152,7 +1207,7 @@ func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client {
|
|||
// by Vault in a response header.
|
||||
func RecordState(state *string) ResponseCallback {
|
||||
return func(resp *Response) {
|
||||
*state = resp.Header.Get("X-Vault-Index")
|
||||
*state = resp.Header.Get(HeaderIndex)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1162,7 +1217,7 @@ func RecordState(state *string) ResponseCallback {
|
|||
func RequireState(states ...string) RequestCallback {
|
||||
return func(req *Request) {
|
||||
for _, s := range states {
|
||||
req.Headers.Add("X-Vault-Index", s)
|
||||
req.Headers.Add(HeaderIndex, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1300,3 +1355,39 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
|
|||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// replicationStateStore is used to track cluster replication states
|
||||
// in order to ensure proper read-after-write semantics for a Client.
|
||||
type replicationStateStore struct {
|
||||
m sync.RWMutex
|
||||
store []string
|
||||
}
|
||||
|
||||
// recordState updates the store's replication states with the merger of all
|
||||
// states.
|
||||
func (w *replicationStateStore) recordState(resp *Response) {
|
||||
w.m.Lock()
|
||||
defer w.m.Unlock()
|
||||
newState := resp.Header.Get(HeaderIndex)
|
||||
if newState != "" {
|
||||
w.store = MergeReplicationStates(w.store, newState)
|
||||
}
|
||||
}
|
||||
|
||||
// requireState updates the Request with the store's current replication states.
|
||||
func (w *replicationStateStore) requireState(req *Request) {
|
||||
w.m.RLock()
|
||||
defer w.m.RUnlock()
|
||||
for _, s := range w.store {
|
||||
req.Headers.Add(HeaderIndex, s)
|
||||
}
|
||||
}
|
||||
|
||||
// states currently stored.
|
||||
func (w *replicationStateStore) states() []string {
|
||||
w.m.RLock()
|
||||
defer w.m.RUnlock()
|
||||
c := make([]string, len(w.store))
|
||||
copy(c, w.store)
|
||||
return c
|
||||
}
|
||||
|
|
|
@ -11,12 +11,15 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
)
|
||||
|
||||
|
@ -412,8 +415,7 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
type fields struct {
|
||||
}
|
||||
type fields struct{}
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
|
@ -433,6 +435,12 @@ func TestClone(t *testing.T) {
|
|||
"X-baz": []string{"qux"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "preventStaleReads",
|
||||
config: &Config{
|
||||
ReadYourWrites: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -512,6 +520,13 @@ func TestClone(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
if tt.config.ReadYourWrites && client1.replicationStateStore == nil {
|
||||
t.Fatalf("replicationStateStore is nil")
|
||||
}
|
||||
if !reflect.DeepEqual(client1.replicationStateStore, client2.replicationStateStore) {
|
||||
t.Fatalf("expected replicationStateStore %v, actual %v", client1.replicationStateStore,
|
||||
client2.replicationStateStore)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -646,3 +661,328 @@ func TestMergeReplicationStates(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicationStateStore_recordState(t *testing.T) {
|
||||
b64enc := func(s string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expected []string
|
||||
resp []*Response
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
resp: []*Response{
|
||||
{
|
||||
Response: &http.Response{
|
||||
Header: map[string][]string{
|
||||
HeaderIndex: {
|
||||
b64enc("v1:cid:1:0:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
b64enc("v1:cid:1:0:"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
resp: []*Response{
|
||||
{
|
||||
Response: &http.Response{
|
||||
Header: map[string][]string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
resp: []*Response{
|
||||
{
|
||||
Response: &http.Response{
|
||||
Header: map[string][]string{
|
||||
HeaderIndex: {
|
||||
b64enc("v1:cid:0:1:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Response: &http.Response{
|
||||
Header: map[string][]string{
|
||||
HeaderIndex: {
|
||||
b64enc("v1:cid:1:0:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
b64enc("v1:cid:0:1:"),
|
||||
b64enc("v1:cid:1:0:"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicates",
|
||||
resp: []*Response{
|
||||
{
|
||||
Response: &http.Response{
|
||||
Header: map[string][]string{
|
||||
HeaderIndex: {
|
||||
b64enc("v1:cid:1:0:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Response: &http.Response{
|
||||
Header: map[string][]string{
|
||||
HeaderIndex: {
|
||||
b64enc("v1:cid:1:0:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
b64enc("v1:cid:1:0:"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := &replicationStateStore{}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, r := range tt.resp {
|
||||
wg.Add(1)
|
||||
go func(r *Response) {
|
||||
defer wg.Done()
|
||||
w.recordState(r)
|
||||
}(r)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if !reflect.DeepEqual(tt.expected, w.store) {
|
||||
t.Errorf("recordState(): expected states %v, actual %v", tt.expected, w.store)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicationStateStore_requireState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
states []string
|
||||
req []*Request
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
states: []string{},
|
||||
req: []*Request{
|
||||
{
|
||||
Headers: make(http.Header),
|
||||
},
|
||||
},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "basic",
|
||||
states: []string{
|
||||
"v1:cid:0:1:",
|
||||
"v1:cid:1:0:",
|
||||
},
|
||||
req: []*Request{
|
||||
{
|
||||
Headers: make(http.Header),
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
"v1:cid:0:1:",
|
||||
"v1:cid:1:0:",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := &replicationStateStore{
|
||||
store: tt.states,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, r := range tt.req {
|
||||
wg.Add(1)
|
||||
go func(r *Request) {
|
||||
defer wg.Done()
|
||||
store.requireState(r)
|
||||
}(r)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
var actual []string
|
||||
for _, r := range tt.req {
|
||||
if values := r.Headers.Values(HeaderIndex); len(values) > 0 {
|
||||
actual = append(actual, values...)
|
||||
}
|
||||
}
|
||||
sort.Strings(actual)
|
||||
if !reflect.DeepEqual(tt.expected, actual) {
|
||||
t.Errorf("requireState(): expected states %v, actual %v", tt.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_ReadYourWrites(t *testing.T) {
|
||||
b64enc := func(s string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Set(HeaderIndex, strings.TrimLeft(req.URL.Path, "/"))
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
wantStates []string
|
||||
values [][]string
|
||||
clone bool
|
||||
}{
|
||||
{
|
||||
name: "multiple_duplicates",
|
||||
clone: false,
|
||||
handler: handler,
|
||||
wantStates: []string{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
},
|
||||
values: [][]string{
|
||||
{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
b64enc("v1:cid:0:2:"),
|
||||
},
|
||||
{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
b64enc("v1:cid:0:2:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic_clone",
|
||||
clone: true,
|
||||
handler: handler,
|
||||
wantStates: []string{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
},
|
||||
values: [][]string{
|
||||
{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
},
|
||||
{
|
||||
b64enc("v1:cid:0:3:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple_clone",
|
||||
clone: true,
|
||||
handler: handler,
|
||||
wantStates: []string{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
},
|
||||
values: [][]string{
|
||||
{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
b64enc("v1:cid:0:2:"),
|
||||
},
|
||||
{
|
||||
b64enc("v1:cid:0:3:"),
|
||||
b64enc("v1:cid:0:1:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple_duplicates_clone",
|
||||
clone: true,
|
||||
handler: handler,
|
||||
wantStates: []string{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
},
|
||||
values: [][]string{
|
||||
{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
b64enc("v1:cid:0:2:"),
|
||||
},
|
||||
{
|
||||
b64enc("v1:cid:0:4:"),
|
||||
b64enc("v1:cid:0:2:"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testRequest := func(client *Client, val string) {
|
||||
req := client.NewRequest("GET", "/"+val)
|
||||
req.Headers.Set(HeaderIndex, val)
|
||||
resp, err := client.RawRequestWithContext(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// validate that the server provided a valid header value in its response
|
||||
actual := resp.Header.Get(HeaderIndex)
|
||||
if actual != val {
|
||||
t.Errorf("expected header value %v, actual %v", val, actual)
|
||||
}
|
||||
}
|
||||
|
||||
config, ln := testHTTPServer(t, handler)
|
||||
defer ln.Close()
|
||||
|
||||
config.ReadYourWrites = true
|
||||
config.Address = fmt.Sprintf("http://%s", ln.Addr())
|
||||
parent, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < len(tt.values); i++ {
|
||||
var c *Client
|
||||
if tt.clone {
|
||||
c, err = parent.Clone()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
c = parent
|
||||
}
|
||||
|
||||
for _, val := range tt.values[i] {
|
||||
wg.Add(1)
|
||||
go func(val string) {
|
||||
defer wg.Done()
|
||||
testRequest(c, val)
|
||||
}(val)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) {
|
||||
t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:improvement
|
||||
api: Add configuration option for ensuring isolated read-after-write semantics for all Client requests.
|
||||
```
|
Loading…
Reference in New Issue