From c1ddfbb5389889c8557d548439dd8e30aafbc42b Mon Sep 17 00:00:00 2001 From: Nick Cabatoff Date: Wed, 24 Feb 2021 06:58:10 -0500 Subject: [PATCH] OSS parts of the new client controlled consistency feature (#10974) --- api/client.go | 90 +++++++++++- command/agent.go | 42 ++++-- command/agent/cache/api_proxy.go | 136 +++++++++++++++++- command/agent/cache/api_proxy_test.go | 84 +++++++++++ command/agent/config/config.go | 2 + command/agent/config/config_test.go | 29 ++++ .../test-fixtures/config-consistency.hcl | 9 ++ helper/storagepacker/storagepacker.go | 18 +-- helper/testhelpers/teststorage/teststorage.go | 1 - http/handler.go | 50 ++++++- http/logical.go | 12 +- http/logical_test.go | 2 +- http/util.go | 2 + physical/raft/fsm.go | 14 +- physical/raft/raft.go | 10 ++ sdk/logical/request.go | 38 +++++ vault/cluster.go | 12 ++ vault/core.go | 67 +++++++++ vault/core_util.go | 8 ++ vault/dynamic_system_view.go | 1 + vault/identity_store.go | 6 +- vault/identity_store_entities.go | 2 + vault/identity_store_util.go | 10 +- vault/request_handling.go | 22 +++ vault/wrapping.go | 3 + .../github.com/hashicorp/vault/api/client.go | 90 +++++++++++- .../hashicorp/vault/sdk/logical/request.go | 38 +++++ 27 files changed, 754 insertions(+), 44 deletions(-) create mode 100644 command/agent/config/test-fixtures/config-consistency.hcl diff --git a/api/client.go b/api/client.go index d286b7393..52e109b7b 100644 --- a/api/client.go +++ b/api/client.go @@ -384,6 +384,8 @@ type Client struct { wrappingLookupFunc WrappingLookupFunc mfaCreds []string policyOverride bool + requestCallbacks []RequestCallback + responseCallbacks []ResponseCallback } // NewClient returns a new client for the given configuration. @@ -866,6 +868,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon c.modifyLock.RUnlock() + for _, cb := range c.requestCallbacks { + cb(r) + } + if limiter != nil { limiter.Wait(ctx) } @@ -907,7 +913,7 @@ START: } if checkRetry == nil { - checkRetry = retryablehttp.DefaultRetryPolicy + checkRetry = DefaultRetryPolicy } client := &retryablehttp.Client{ @@ -968,9 +974,91 @@ START: goto START } + if result != nil { + for _, cb := range c.responseCallbacks { + cb(result) + } + } if err := result.Error(); err != nil { return result, err } return result, nil } + +type RequestCallback func(*Request) +type ResponseCallback func(*Response) + +// WithRequestCallbacks makes a shallow clone of Client, modifies it to use +// the given callbacks, and returns it. Each of the callbacks will be invoked +// on every outgoing request. A client may be used to issue requests +// concurrently; any locking needed by callbacks invoked concurrently is the +// callback's responsibility. +func (c *Client) WithRequestCallbacks(callbacks ...RequestCallback) *Client { + c2 := *c + c2.modifyLock = sync.RWMutex{} + c2.requestCallbacks = callbacks + return &c2 +} + +// WithResponseCallbacks makes a shallow clone of Client, modifies it to use +// the given callbacks, and returns it. Each of the callbacks will be invoked +// on every received response. A client may be used to issue requests +// concurrently; any locking needed by callbacks invoked concurrently is the +// callback's responsibility. +func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client { + c2 := *c + c2.modifyLock = sync.RWMutex{} + c2.responseCallbacks = callbacks + return &c2 +} + +// RecordState returns a response callback that will record the state returned +// by Vault in a response header. +func RecordState(state *string) ResponseCallback { + return func(resp *Response) { + *state = resp.Header.Get("X-Vault-Index") + } +} + +// RequireState returns a request callback that will add a request header to +// specify the state we require of Vault. This state was obtained from a +// response header seen previous, probably captured with RecordState. +func RequireState(states ...string) RequestCallback { + return func(req *Request) { + for _, s := range states { + req.Headers.Add("X-Vault-Index", s) + } + } +} + +// ForwardInconsistent returns a request callback that will add a request +// header which says: if the state required isn't present on the node receiving +// this request, forward it to the active node. This should be used in +// conjunction with RequireState. +func ForwardInconsistent() RequestCallback { + return func(req *Request) { + req.Headers.Set("X-Vault-Inconsistent", "forward-active-node") + } +} + +// ForwardAlways returns a request callback which adds a header telling any +// performance standbys handling the request to forward it to the active node. +// This feature must be enabled in Vault's configuration. +func ForwardAlways() RequestCallback { + return func(req *Request) { + req.Headers.Set("X-Vault-Forward", "active-node") + } +} + +// DefaultRetryPolicy is the default retry policy used by new Client objects. +// It is the same as retryablehttp.DefaultRetryPolicy except that it also retries +// 412 requests, which are returned by Vault when a X-Vault-Index header isn't +// satisfied. +func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { + retry, err := retryablehttp.DefaultRetryPolicy(ctx, resp, err) + if err != nil || retry { + return retry, err + } + return resp.StatusCode == 412, nil +} diff --git a/command/agent.go b/command/agent.go index cf413004a..8b3b53127 100644 --- a/command/agent.go +++ b/command/agent.go @@ -416,6 +416,30 @@ func (c *AgentCommand) Run(args []string) int { } } + enforceConsistency := cache.EnforceConsistencyNever + whenInconsistent := cache.WhenInconsistentFail + if config.Cache != nil { + switch config.Cache.EnforceConsistency { + case "always": + enforceConsistency = cache.EnforceConsistencyAlways + case "never", "": + default: + c.UI.Error(fmt.Sprintf("Unknown cache setting for enforce_consistency: %q", config.Cache.EnforceConsistency)) + return 1 + } + + switch config.Cache.WhenInconsistent { + case "retry": + whenInconsistent = cache.WhenInconsistentRetry + case "forward": + whenInconsistent = cache.WhenInconsistentForward + case "fail", "": + default: + c.UI.Error(fmt.Sprintf("Unknown cache setting for when_inconsistent: %q", config.Cache.WhenInconsistent)) + return 1 + } + } + // Warn if cache _and_ cert auto-auth is enabled but certificates were not // provided in the auto_auth.method["cert"].config stanza. if config.Cache != nil && (config.AutoAuth != nil && config.AutoAuth.Method != nil && config.AutoAuth.Method.Type == "cert") { @@ -437,20 +461,16 @@ func (c *AgentCommand) Run(args []string) int { c.UI.Output("==> Vault agent started! Log data will stream in below:\n") } - // Inform any tests that the server is ready - select { - case c.startedCh <- struct{}{}: - default: - } - // Parse agent listener configurations if config.Cache != nil && len(config.Listeners) != 0 { cacheLogger := c.logger.Named("cache") // Create the API proxier apiProxy, err := cache.NewAPIProxy(&cache.APIProxyConfig{ - Client: client, - Logger: cacheLogger.Named("apiproxy"), + Client: client, + Logger: cacheLogger.Named("apiproxy"), + EnforceConsistency: enforceConsistency, + WhenInconsistentAction: whenInconsistent, }) if err != nil { c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err)) @@ -547,6 +567,12 @@ func (c *AgentCommand) Run(args []string) int { defer c.cleanupGuard.Do(listenerCloseFunc) } + // Inform any tests that the server is ready + select { + case c.startedCh <- struct{}{}: + default: + } + // Listen for signals // TODO: implement support for SIGHUP reloading of configuration // signal.Notify(c.signalCh) diff --git a/command/agent/cache/api_proxy.go b/command/agent/cache/api_proxy.go index 580c14f75..5ba7076d1 100644 --- a/command/agent/cache/api_proxy.go +++ b/command/agent/cache/api_proxy.go @@ -3,21 +3,49 @@ package cache import ( "context" "fmt" + "sync" hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/vault" +) + +type EnforceConsistency int + +const ( + EnforceConsistencyNever EnforceConsistency = iota + EnforceConsistencyAlways +) + +type WhenInconsistentAction int + +const ( + WhenInconsistentFail WhenInconsistentAction = iota + WhenInconsistentRetry + WhenInconsistentForward ) // APIProxy is an implementation of the proxier interface that is used to // forward the request to Vault and get the response. type APIProxy struct { - client *api.Client - logger hclog.Logger + client *api.Client + logger hclog.Logger + enforceConsistency EnforceConsistency + whenInconsistentAction WhenInconsistentAction + l sync.RWMutex + lastIndexStates []string } +var _ Proxier = &APIProxy{} + type APIProxyConfig struct { - Client *api.Client - Logger hclog.Logger + Client *api.Client + Logger hclog.Logger + EnforceConsistency EnforceConsistency + WhenInconsistentAction WhenInconsistentAction } func NewAPIProxy(config *APIProxyConfig) (Proxier, error) { @@ -25,11 +53,65 @@ func NewAPIProxy(config *APIProxyConfig) (Proxier, error) { return nil, fmt.Errorf("nil API client") } return &APIProxy{ - client: config.Client, - logger: config.Logger, + client: config.Client, + logger: config.Logger, + enforceConsistency: config.EnforceConsistency, + whenInconsistentAction: config.WhenInconsistentAction, }, nil } +// compareStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0 +// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2 +// are invalid or from different clusters. +func compareStates(s1, s2 string) (int, error) { + w1, err := vault.ParseRequiredState(s1, nil) + if err != nil { + return 0, err + } + w2, err := vault.ParseRequiredState(s2, nil) + if err != nil { + return 0, err + } + + if w1.ClusterID != w2.ClusterID { + return 0, fmt.Errorf("don't know how to compare states with different ClusterIDs") + } + + switch { + case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex: + return 1, nil + // We've already handled the case where both are equal above, so really we're + // asking here if one or both are lesser. + case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex: + return -1, nil + } + + return 0, nil +} + +func mergeStates(old []string, new string) []string { + if len(old) == 0 || len(old) > 2 { + return []string{new} + } + + var ret []string + for _, o := range old { + c, err := compareStates(o, new) + if err != nil { + return []string{new} + } + switch c { + case 1: + ret = append(ret, o) + case -1: + ret = append(ret, new) + case 0: + ret = append(ret, o, new) + } + } + return strutil.RemoveDuplicates(ret, false) +} + func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) { client, err := ap.client.Clone() if err != nil { @@ -51,6 +133,34 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, fwReq.Params = query } + var newState string + manageState := ap.enforceConsistency == EnforceConsistencyAlways && + req.Request.Header.Get(http.VaultIndexHeaderName) == "" && + req.Request.Header.Get(http.VaultForwardHeaderName) == "" && + req.Request.Header.Get(http.VaultInconsistentHeaderName) == "" + + if manageState { + client = client.WithResponseCallbacks(api.RecordState(&newState)) + ap.l.RLock() + lastStates := ap.lastIndexStates + ap.l.RUnlock() + if len(lastStates) != 0 { + client = client.WithRequestCallbacks(api.RequireState(lastStates...)) + switch ap.whenInconsistentAction { + case WhenInconsistentFail: + // In this mode we want to delegate handling of inconsistency + // failures to the external client talking to Agent. + client.SetCheckRetry(retryablehttp.DefaultRetryPolicy) + case WhenInconsistentRetry: + // In this mode we want to handle retries due to inconsistency + // internally. This is the default api.Client behaviour so + // we needn't do anything. + case WhenInconsistentForward: + fwReq.Headers.Set(http.VaultInconsistentHeaderName, http.VaultInconsistentForward) + } + } + } + // Make the request to Vault and get the response ap.logger.Info("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path) @@ -60,6 +170,20 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, return nil, err } + if newState != "" { + ap.l.Lock() + // We want to be using the "newest" states seen, but newer isn't well + // defined here. There can be two states S1 and S2 which aren't strictly ordered: + // S1 could have a newer localindex and S2 could have a newer replicatedindex. So + // we need to merge them. But we can't merge them because we wouldn't be able to + // "sign" the resulting header because we don't have access to the HMAC key that + // Vault uses to do so. So instead we compare any of the 0-2 saved states + // we have to the new header, keeping the newest 1-2 of these, and sending + // them to Vault to evaluate. + ap.lastIndexStates = mergeStates(ap.lastIndexStates, newState) + ap.l.Unlock() + } + // Before error checking from the request call, we'd want to initialize a SendResponse to // potentially return sendResponse, newErr := NewSendResponse(resp, nil) diff --git a/command/agent/cache/api_proxy_test.go b/command/agent/cache/api_proxy_test.go index b90f579c3..26efc0d9e 100644 --- a/command/agent/cache/api_proxy_test.go +++ b/command/agent/cache/api_proxy_test.go @@ -1,6 +1,8 @@ package cache import ( + "encoding/base64" + "github.com/go-test/deep" "net/http" "testing" @@ -93,3 +95,85 @@ func TestAPIProxy_queryParams(t *testing.T) { t.Fatalf("exptected standby to return 200, got: %v", resp.Response.StatusCode) } } + +func TestMergeStates(t *testing.T) { + type testCase struct { + name string + old []string + new string + expected []string + } + + var testCases = []testCase{ + { + name: "empty-old", + old: nil, + new: "v1:cid:1:0:", + expected: []string{"v1:cid:1:0:"}, + }, + { + name: "old-smaller", + old: []string{"v1:cid:1:0:"}, + new: "v1:cid:2:0:", + expected: []string{"v1:cid:2:0:"}, + }, + { + name: "old-bigger", + old: []string{"v1:cid:2:0:"}, + new: "v1:cid:1:0:", + expected: []string{"v1:cid:2:0:"}, + }, + { + name: "mixed-single", + old: []string{"v1:cid:1:0:"}, + new: "v1:cid:0:1:", + expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + }, + { + name: "mixed-single-alt", + old: []string{"v1:cid:0:1:"}, + new: "v1:cid:1:0:", + expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + }, + { + name: "mixed-double", + old: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + new: "v1:cid:2:0:", + expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"}, + }, + { + name: "newer-both", + old: []string{"v1:cid:0:1:", "v1:cid:1:0:"}, + new: "v1:cid:2:1:", + expected: []string{"v1:cid:2:1:"}, + }, + } + + b64enc := func(ss []string) []string { + var ret []string + for _, s := range ss { + ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s))) + } + return ret + } + b64dec := func(ss []string) []string { + var ret []string + for _, s := range ss { + d, err := base64.StdEncoding.DecodeString(s) + if err != nil { + t.Fatal(err) + } + ret = append(ret, string(d)) + } + return ret + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + out := b64dec(mergeStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new)))) + if diff := deep.Equal(out, tc.expected); len(diff) != 0 { + t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff) + } + }) + } +} diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 8e3f3f7de..32f99b55c 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -48,6 +48,8 @@ type Cache struct { UseAutoAuthTokenRaw interface{} `hcl:"use_auto_auth_token"` UseAutoAuthToken bool `hcl:"-"` ForceAutoAuthToken bool `hcl:"-"` + EnforceConsistency string `hcl:"enforce_consistency"` + WhenInconsistent string `hcl:"when_inconsistent"` } // AutoAuth is the configured authentication method and sinks diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 960c374fb..8a502d7f9 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -742,3 +742,32 @@ func TestLoadConfigFile_Vault_Retry_Empty(t *testing.T) { t.Fatal(diff) } } + +func TestLoadConfigFile_EnforceConsistency(t *testing.T) { + config, err := LoadConfig("./test-fixtures/config-consistency.hcl") + if err != nil { + t.Fatal(err) + } + + expected := &Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:8300", + TLSDisable: true, + }, + }, + PidFile: "", + }, + Cache: &Cache{ + EnforceConsistency: "always", + WhenInconsistent: "retry", + }, + } + + config.Listeners[0].RawConfig = nil + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } +} diff --git a/command/agent/config/test-fixtures/config-consistency.hcl b/command/agent/config/test-fixtures/config-consistency.hcl new file mode 100644 index 000000000..d57e05573 --- /dev/null +++ b/command/agent/config/test-fixtures/config-consistency.hcl @@ -0,0 +1,9 @@ +cache { + enforce_consistency = "always" + when_inconsistent = "retry" +} + +listener "tcp" { + address = "127.0.0.1:8300" + tls_disable = true +} diff --git a/helper/storagepacker/storagepacker.go b/helper/storagepacker/storagepacker.go index f15150824..c0eea6522 100644 --- a/helper/storagepacker/storagepacker.go +++ b/helper/storagepacker/storagepacker.go @@ -41,7 +41,7 @@ func (s *StoragePacker) View() logical.Storage { } // GetBucket returns a bucket for a given key -func (s *StoragePacker) GetBucket(key string) (*Bucket, error) { +func (s *StoragePacker) GetBucket(ctx context.Context, key string) (*Bucket, error) { if key == "" { return nil, fmt.Errorf("missing bucket key") } @@ -51,7 +51,7 @@ func (s *StoragePacker) GetBucket(key string) (*Bucket, error) { defer lock.RUnlock() // Read from storage - storageEntry, err := s.view.Get(context.Background(), key) + storageEntry, err := s.view.Get(ctx, key) if err != nil { return nil, errwrap.Wrapf("failed to read packed storage entry: {{err}}", err) } @@ -126,8 +126,8 @@ func (s *StoragePacker) BucketKey(itemID string) string { } // DeleteItem removes the item from the respective bucket -func (s *StoragePacker) DeleteItem(_ context.Context, itemID string) error { - return s.DeleteMultipleItems(context.Background(), nil, []string{itemID}) +func (s *StoragePacker) DeleteItem(ctx context.Context, itemID string) error { + return s.DeleteMultipleItems(ctx, nil, []string{itemID}) } func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Logger, itemIDs []string) error { @@ -171,7 +171,7 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo idx := 0 for bucketKey, itemsToRemove := range byBucket { // Read bucket from storage - storageEntry, err := s.view.Get(context.Background(), bucketKey) + storageEntry, err := s.view.Get(ctx, bucketKey) if err != nil { return errwrap.Wrapf("failed to read packed storage value: {{err}}", err) } @@ -278,7 +278,7 @@ func (s *StoragePacker) GetItem(itemID string) (*Item, error) { bucketKey := s.BucketKey(itemID) // Fetch the bucket entry - bucket, err := s.GetBucket(bucketKey) + bucket, err := s.GetBucket(context.Background(), bucketKey) if err != nil { return nil, errwrap.Wrapf("failed to read packed storage item: {{err}}", err) } @@ -297,7 +297,7 @@ func (s *StoragePacker) GetItem(itemID string) (*Item, error) { } // PutItem stores the given item in its respective bucket -func (s *StoragePacker) PutItem(_ context.Context, item *Item) error { +func (s *StoragePacker) PutItem(ctx context.Context, item *Item) error { defer metrics.MeasureSince([]string{"storage_packer", "put_item"}, time.Now()) if item == nil { @@ -323,7 +323,7 @@ func (s *StoragePacker) PutItem(_ context.Context, item *Item) error { defer lock.Unlock() // Check if there is an existing bucket for a given key - storageEntry, err := s.view.Get(context.Background(), bucketKey) + storageEntry, err := s.view.Get(ctx, bucketKey) if err != nil { return errwrap.Wrapf("failed to read packed storage bucket entry: {{err}}", err) } @@ -354,7 +354,7 @@ func (s *StoragePacker) PutItem(_ context.Context, item *Item) error { } } - return s.putBucket(context.Background(), bucket) + return s.putBucket(ctx, bucket) } // NewStoragePacker creates a new storage packer for a given view diff --git a/helper/testhelpers/teststorage/teststorage.go b/helper/testhelpers/teststorage/teststorage.go index 9b880a4d6..8ca7441d8 100644 --- a/helper/testhelpers/teststorage/teststorage.go +++ b/helper/testhelpers/teststorage/teststorage.go @@ -188,7 +188,6 @@ func FileBackendSetup(conf *vault.CoreConfig, opts *vault.TestClusterOptions) { } func RaftBackendSetup(conf *vault.CoreConfig, opts *vault.TestClusterOptions) { - conf.DisablePerformanceStandby = true opts.KeepStandbysSealed = true opts.PhysicalFactory = MakeRaftBackend opts.SetupFunc = func(t testing.T, c *vault.TestCluster) { diff --git a/http/handler.go b/http/handler.go index 75e8e9c57..ba46fb584 100644 --- a/http/handler.go +++ b/http/handler.go @@ -57,6 +57,12 @@ const ( // soft-mandatory Sentinel policies. PolicyOverrideHeaderName = "X-Vault-Policy-Override" + VaultIndexHeaderName = "X-Vault-Index" + VaultInconsistentHeaderName = "X-Vault-Inconsistent" + VaultForwardHeaderName = "X-Vault-Forward" + VaultInconsistentForward = "forward-active-node" + VaultInconsistentFail = "fail" + // DefaultMaxRequestSize is the default maximum accepted request size. This // is to prevent a denial of service attack where no Content-Length is // provided and the server is fed ever more data until it exhausts memory. @@ -666,12 +672,52 @@ func parseFormRequest(r *http.Request) (map[string]interface{}, error) { return data, nil } +// forwardBasedOnHeaders returns true if the request headers specify that +// we should forward to the active node - either unconditionally or because +// a specified state isn't present locally. +func forwardBasedOnHeaders(core *vault.Core, r *http.Request) (bool, error) { + rawForward := r.Header.Get(VaultForwardHeaderName) + if rawForward != "" { + if !core.AllowForwardingViaHeader() { + return false, fmt.Errorf("forwarding via header %s disabled in configuration", VaultForwardHeaderName) + } + if rawForward == "active-node" { + return true, nil + } + return false, nil + } + + rawInconsistent := r.Header.Get(VaultInconsistentHeaderName) + if rawInconsistent == "" { + return false, nil + } + + switch rawInconsistent { + case VaultInconsistentForward: + if !core.AllowForwardingViaHeader() { + return false, fmt.Errorf("forwarding via header %s=%s disabled in configuration", + VaultInconsistentHeaderName, VaultInconsistentForward) + } + default: + return false, nil + } + + return core.MissingRequiredState(r.Header.Values(VaultIndexHeaderName)), nil +} + // handleRequestForwarding determines whether to forward a request or not, // falling back on the older behavior of redirecting the client func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // If we are a performance standby we can handle the request. - if core.PerfStandby() { + // Note if the client requested forwarding + shouldForward, err := forwardBasedOnHeaders(core, r) + if err != nil { + respondError(w, http.StatusBadRequest, err) + return + } + + // If we are a performance standby we can maybe handle the request. + if core.PerfStandby() && !shouldForward { ns, err := namespace.FromContext(r.Context()) if err != nil { respondError(w, http.StatusBadRequest, err) diff --git a/http/logical.go b/http/logical.go index c06ae552b..0891a442a 100644 --- a/http/logical.go +++ b/http/logical.go @@ -219,6 +219,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques if err != nil || status != 0 { return nil, nil, status, err } + + rawRequired := r.Header.Values(VaultIndexHeaderName) + if len(rawRequired) != 0 && core.MissingRequiredState(rawRequired) { + return nil, nil, http.StatusPreconditionFailed, fmt.Errorf("required index state not present") + } + req, err = requestAuth(core, r, req) if err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { @@ -453,13 +459,13 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw return default: // Build and return the proper response if everything is fine. - respondLogical(w, r, req, resp, injectDataIntoTopLevel) + respondLogical(core, w, r, req, resp, injectDataIntoTopLevel) return } }) } -func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request, resp *logical.Response, injectDataIntoTopLevel bool) { +func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, req *logical.Request, resp *logical.Response, injectDataIntoTopLevel bool) { var httpResp *logical.HTTPResponse var ret interface{} @@ -509,6 +515,8 @@ func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request } } + adjustResponse(core, w, req) + // Respond respondOk(w, ret) return diff --git a/http/logical_test.go b/http/logical_test.go index ce6d32bf4..fd41df8aa 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -369,7 +369,7 @@ func TestLogical_RespondWithStatusCode(t *testing.T) { } w := httptest.NewRecorder() - respondLogical(w, nil, nil, resp404, false) + respondLogical(nil, w, nil, nil, resp404, false) if w.Code != 404 { t.Fatalf("Bad Status code: %d", w.Code) diff --git a/http/util.go b/http/util.go index f16a16236..0ba6a9edc 100644 --- a/http/util.go +++ b/http/util.go @@ -28,6 +28,8 @@ var ( additionalRoutes = func(mux *http.ServeMux, core *vault.Core) {} nonVotersAllowed = false + + adjustResponse = func(core *vault.Core, w http.ResponseWriter, req *logical.Request) {} ) func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler { diff --git a/physical/raft/fsm.go b/physical/raft/fsm.go index 3fe4d4472..a586ccca4 100644 --- a/physical/raft/fsm.go +++ b/physical/raft/fsm.go @@ -77,8 +77,8 @@ type FSM struct { logger log.Logger noopRestore bool - // applyDelay is used to simulate a slow apply in tests - applyDelay time.Duration + // applyCallback is used to control the pace of applies in tests + applyCallback func() db *bolt.DB @@ -131,8 +131,12 @@ func (f *FSM) getDB() *bolt.DB { // SetFSMDelay adds a delay to the FSM apply. This is used in tests to simulate // a slow apply. func (r *RaftBackend) SetFSMDelay(delay time.Duration) { + r.SetFSMApplyCallback(func() { time.Sleep(delay) }) +} + +func (r *RaftBackend) SetFSMApplyCallback(f func()) { r.fsm.l.Lock() - r.fsm.applyDelay = delay + r.fsm.applyCallback = f r.fsm.l.Unlock() } @@ -469,8 +473,8 @@ func (f *FSM) ApplyBatch(logs []*raft.Log) []interface{} { f.l.RLock() defer f.l.RUnlock() - if f.applyDelay > 0 { - time.Sleep(f.applyDelay) + if f.applyCallback != nil { + f.applyCallback() } err = f.db.Update(func(tx *bolt.Tx) error { diff --git a/physical/raft/raft.go b/physical/raft/raft.go index 5481552e3..52cd04982 100644 --- a/physical/raft/raft.go +++ b/physical/raft/raft.go @@ -263,6 +263,16 @@ func NewRaftBackend(conf map[string]string, logger log.Logger) (physical.Backend return nil, fmt.Errorf("failed to create fsm: %v", err) } + if delayRaw, ok := conf["apply_delay"]; ok { + delay, err := time.ParseDuration(delayRaw) + if err != nil { + return nil, fmt.Errorf("apply_delay does not parse as a duration: %w", err) + } + fsm.applyCallback = func() { + time.Sleep(delay) + } + } + // Build an all in-memory setup for dev mode, otherwise prepare a full // disk-based setup. var log raft.LogStore diff --git a/sdk/logical/request.go b/sdk/logical/request.go index e4cd6ad36..1a103a57a 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -1,6 +1,7 @@ package logical import ( + "context" "fmt" "net/http" "strings" @@ -54,6 +55,30 @@ const ( ClientTokenFromAuthzHeader ) +type WALState struct { + ClusterID string + LocalIndex uint64 + ReplicatedIndex uint64 +} + +const indexStateCtxKey = "index_state" + +// IndexStateContext returns a context with an added value holding the index +// state that should be populated on writes. +func IndexStateContext(ctx context.Context, state *WALState) context.Context { + return context.WithValue(ctx, indexStateCtxKey, state) +} + +// IndexStateFromContext is a helper to look up if the provided context contains +// an index state pointer. +func IndexStateFromContext(ctx context.Context) *WALState { + s, ok := ctx.Value(indexStateCtxKey).(*WALState) + if !ok { + return nil + } + return s +} + // Request is a struct that stores the parameters and context of a request // being made to Vault. It is used to abstract the details of the higher level // request protocol from the handlers. @@ -179,6 +204,11 @@ type Request struct { // ResponseWriter if set can be used to stream a response value to the http // request that generated this logical.Request object. ResponseWriter *HTTPResponseWriter `json:"-" sentinel:""` + + // responseState is used internally to propagate the state that should appear + // in response headers; it's attached to the request rather than the response + // because not all requests yields non-nil responses. + responseState *WALState } // Clone returns a deep copy of the request by using copystructure @@ -243,6 +273,14 @@ func (r *Request) SetLastRemoteWAL(last uint64) { r.lastRemoteWAL = last } +func (r *Request) ResponseState() *WALState { + return r.responseState +} + +func (r *Request) SetResponseState(w *WALState) { + r.responseState = w +} + func (r *Request) TokenEntry() *TokenEntry { return r.tokenEntry } diff --git a/vault/cluster.go b/vault/cluster.go index 8344fe6d7..649dd0d3d 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -280,6 +280,18 @@ func (c *Core) setupCluster(ctx context.Context) error { } } + c.clusterID.Store(cluster.ID) + return nil +} + +func (c *Core) loadCluster(ctx context.Context) error { + cluster, err := c.Cluster(ctx) + if err != nil { + c.logger.Error("failed to get cluster details", "error", err) + return err + } + + c.clusterID.Store(cluster.ID) return nil } diff --git a/vault/core.go b/vault/core.go index 3168f8820..45de53c4c 100644 --- a/vault/core.go +++ b/vault/core.go @@ -3,10 +3,14 @@ package vault import ( "context" "crypto/ecdsa" + "crypto/hmac" "crypto/rand" + "crypto/sha256" "crypto/subtle" "crypto/tls" "crypto/x509" + "encoding/base64" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -16,6 +20,8 @@ import ( "net/url" "os" "path/filepath" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -48,6 +54,7 @@ import ( "github.com/hashicorp/vault/vault/quotas" vaultseal "github.com/hashicorp/vault/vault/seal" "github.com/patrickmn/go-cache" + uberAtomic "go.uber.org/atomic" "google.golang.org/grpc" ) @@ -72,6 +79,8 @@ const ( // clusters that they need to perform a rekey operation synchronously; this // isn't keyring-canary to avoid ignoring it when ignoring core/keyring coreKeyringCanaryPath = "core/canary-keyring" + + indexHeaderHMACKeyPath = "core/index-header-hmac-key" ) var ( @@ -108,6 +117,7 @@ var ( PerformanceMerkleRoot = merkleRootImpl DRMerkleRoot = merkleRootImpl LastRemoteWAL = lastRemoteWALImpl + LastRemoteUpstreamWAL = lastRemoteUpstreamWALImpl WaitUntilWALShipped = waitUntilWALShippedImpl ) @@ -391,6 +401,8 @@ type Core struct { // // Name clusterName string + // ID + clusterID uberAtomic.String // Specific cipher suites to use for clustering, if any clusterCipherSuites []uint16 // Used to modify cluster parameters @@ -546,6 +558,8 @@ type Core struct { // number of workers to use for lease revocation in the expiration manager numExpirationWorkers int + + IndexHeaderHMACKey uberAtomic.Value } // CoreConfig is used to parameterize a core @@ -2201,6 +2215,10 @@ func lastRemoteWALImpl(c *Core) uint64 { return 0 } +func lastRemoteUpstreamWALImpl(c *Core) uint64 { + return 0 +} + func (c *Core) PhysicalSealConfigs(ctx context.Context) (*SealConfig, *SealConfig, error) { pe, err := c.physical.Get(ctx, barrierSealConfigPath) if err != nil { @@ -2542,6 +2560,10 @@ func (c *Core) IsDRSecondary() bool { return c.ReplicationState().HasState(consts.ReplicationDRSecondary) } +func (c *Core) IsPerfSecondary() bool { + return c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) +} + func (c *Core) AddLogger(logger log.Logger) { c.allLoggersLock.Lock() defer c.allLoggersLock.Unlock() @@ -2705,3 +2727,48 @@ func (c *Core) KeyRotateGracePeriod() time.Duration { func (c *Core) SetKeyRotateGracePeriod(t time.Duration) { atomic.StoreInt64(c.keyRotateGracePeriod, int64(t)) } + +func ParseRequiredState(raw string, hmacKey []byte) (*logical.WALState, error) { + cooked, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, err + } + s := string(cooked) + + lastIndex := strings.LastIndexByte(s, ':') + if lastIndex == -1 { + return nil, fmt.Errorf("invalid state header format") + } + state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:] + stateHMAC, err := hex.DecodeString(stateHMACRaw) + if err != nil { + return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err) + } + + if len(hmacKey) != 0 { + hm := hmac.New(sha256.New, hmacKey) + hm.Write([]byte(state)) + if !hmac.Equal(hm.Sum(nil), stateHMAC) { + return nil, fmt.Errorf("invalid state header HMAC (mismatch)") + } + } + + pieces := strings.Split(state, ":") + if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" { + return nil, fmt.Errorf("invalid state header format") + } + localIndex, err := strconv.ParseUint(pieces[2], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid state header format") + } + replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid state header format") + } + + return &logical.WALState{ + ClusterID: pieces[1], + LocalIndex: localIndex, + ReplicatedIndex: replicatedIndex, + }, nil +} diff --git a/vault/core_util.go b/vault/core_util.go index 4296c462b..57d76f0d0 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -158,3 +158,11 @@ func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction func (c *Core) namespaceByPath(path string) *namespace.Namespace { return namespace.RootNamespace } + +func (c *Core) AllowForwardingViaHeader() bool { + return false +} + +func (c *Core) MissingRequiredState(raw []string) bool { + return false +} diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index bf48428e6..41b132bd1 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -52,6 +52,7 @@ func (e extendedSystemViewImpl) ForwardGenericRequest(ctx context.Context, req * // Forward the request if allowed if couldForward(e.core) { ctx = namespace.ContextWithNamespace(ctx, e.mountEntry.Namespace()) + ctx = logical.IndexStateContext(ctx, &logical.WALState{}) ctx = context.WithValue(ctx, ctxKeyForwardedRequestMountAccessor{}, e.mountEntry.Accessor) return forward(ctx, e.core, req) } diff --git a/vault/identity_store.go b/vault/identity_store.go index 4b844d58b..b7e5422ac 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -27,7 +27,7 @@ const ( var ( caseSensitivityKey = "casesensitivity" - sendGroupUpgrade = func(*IdentityStore, *identity.Group) (bool, error) { return false, nil } + sendGroupUpgrade = func(context.Context, *IdentityStore, *identity.Group) (bool, error) { return false, nil } parseExtraEntityFromBucket = func(context.Context, *IdentityStore, *identity.Entity) (bool, error) { return false, nil } addExtraEntityDataToResponse = func(*identity.Entity, map[string]interface{}) {} ) @@ -210,7 +210,7 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) { } // Get the storage bucket entry - bucket, err := i.entityPacker.GetBucket(key) + bucket, err := i.entityPacker.GetBucket(ctx, key) if err != nil { i.logger.Error("failed to refresh entities", "key", key, "error", err) return @@ -273,7 +273,7 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) { } // Get the storage bucket entry - bucket, err := i.groupPacker.GetBucket(key) + bucket, err := i.groupPacker.GetBucket(ctx, key) if err != nil { i.logger.Error("failed to refresh group", "key", key, "error", err) return diff --git a/vault/identity_store_entities.go b/vault/identity_store_entities.go index b68f69889..50f4c5b33 100644 --- a/vault/identity_store_entities.go +++ b/vault/identity_store_entities.go @@ -569,6 +569,8 @@ func (i *IdentityStore) handleEntityBatchDelete() framework.OperationFunc { } } +// handleEntityDeleteCommon deletes an entity by removing it from groups of +// which it's a member and then, if update is true, deleting the entity itself. func (i *IdentityStore) handleEntityDeleteCommon(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, update bool) error { ns, err := namespace.FromContext(ctx) if err != nil { diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 57a27172a..fdcfd9071 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -89,7 +89,7 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error { i.logger.Debug("groups collected", "num_existing", len(existing)) for _, key := range existing { - bucket, err := i.groupPacker.GetBucket(groupBucketsPrefix + key) + bucket, err := i.groupPacker.GetBucket(ctx, groupBucketsPrefix+key) if err != nil { return err } @@ -124,7 +124,7 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error { } continue } - nsCtx := namespace.ContextWithNamespace(context.Background(), ns) + nsCtx := namespace.ContextWithNamespace(ctx, ns) // Ensure that there are no groups with duplicate names groupByName, err := i.MemDBGroupByName(nsCtx, group.Name, false) @@ -212,7 +212,7 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error { return } - bucket, err := i.entityPacker.GetBucket(storagepacker.StoragePackerBucketsPrefix + key) + bucket, err := i.entityPacker.GetBucket(ctx, storagepacker.StoragePackerBucketsPrefix+key) if err != nil { errs <- err continue @@ -292,7 +292,7 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error { } continue } - nsCtx := namespace.ContextWithNamespace(context.Background(), ns) + nsCtx := namespace.ContextWithNamespace(ctx, ns) // Ensure that there are no entities with duplicate names entityByName, err := i.MemDBEntityByName(nsCtx, entity.Name, false) @@ -1437,7 +1437,7 @@ func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, gr Message: groupAsAny, } - sent, err := sendGroupUpgrade(i, group) + sent, err := sendGroupUpgrade(ctx, i, group) if err != nil { return err } diff --git a/vault/request_handling.go b/vault/request_handling.go index a1b1e4414..13c33ad7d 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -457,6 +457,8 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp return nil, logical.CodedError(403, "namespaces feature not enabled") } + var walState = &logical.WALState{} + ctx = logical.IndexStateContext(ctx, walState) var auth *logical.Auth if c.router.LoginPath(ctx, req.Path) { resp, auth, err = c.handleLoginRequest(ctx, req) @@ -564,6 +566,26 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp } } + if walState.LocalIndex != 0 || walState.ReplicatedIndex != 0 { + walState.ClusterID = c.clusterID.Load() + if walState.LocalIndex == 0 { + if c.perfStandby { + walState.LocalIndex = LastRemoteWAL(c) + } else { + walState.LocalIndex = LastWAL(c) + } + } + if walState.ReplicatedIndex == 0 { + if c.perfStandby { + walState.ReplicatedIndex = LastRemoteUpstreamWAL(c) + } else { + walState.ReplicatedIndex = LastRemoteWAL(c) + } + } + + req.SetResponseState(walState) + } + return } diff --git a/vault/wrapping.go b/vault/wrapping.go index 73e33c595..b2558ac45 100644 --- a/vault/wrapping.go +++ b/vault/wrapping.go @@ -76,6 +76,9 @@ func (c *Core) ensureWrappingKey(ctx context.Context) error { return nil } +// wrapInCubbyhole is invoked when a caller asks for response wrapping. +// On success, return (nil, nil) and mutates resp. On failure, returns +// either a response describing the failure or an error. func (c *Core) wrapInCubbyhole(ctx context.Context, req *logical.Request, resp *logical.Response, auth *logical.Auth) (*logical.Response, error) { if c.perfStandby { return forwardWrapRequest(ctx, c, req, resp, auth) diff --git a/vendor/github.com/hashicorp/vault/api/client.go b/vendor/github.com/hashicorp/vault/api/client.go index d286b7393..52e109b7b 100644 --- a/vendor/github.com/hashicorp/vault/api/client.go +++ b/vendor/github.com/hashicorp/vault/api/client.go @@ -384,6 +384,8 @@ type Client struct { wrappingLookupFunc WrappingLookupFunc mfaCreds []string policyOverride bool + requestCallbacks []RequestCallback + responseCallbacks []ResponseCallback } // NewClient returns a new client for the given configuration. @@ -866,6 +868,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon c.modifyLock.RUnlock() + for _, cb := range c.requestCallbacks { + cb(r) + } + if limiter != nil { limiter.Wait(ctx) } @@ -907,7 +913,7 @@ START: } if checkRetry == nil { - checkRetry = retryablehttp.DefaultRetryPolicy + checkRetry = DefaultRetryPolicy } client := &retryablehttp.Client{ @@ -968,9 +974,91 @@ START: goto START } + if result != nil { + for _, cb := range c.responseCallbacks { + cb(result) + } + } if err := result.Error(); err != nil { return result, err } return result, nil } + +type RequestCallback func(*Request) +type ResponseCallback func(*Response) + +// WithRequestCallbacks makes a shallow clone of Client, modifies it to use +// the given callbacks, and returns it. Each of the callbacks will be invoked +// on every outgoing request. A client may be used to issue requests +// concurrently; any locking needed by callbacks invoked concurrently is the +// callback's responsibility. +func (c *Client) WithRequestCallbacks(callbacks ...RequestCallback) *Client { + c2 := *c + c2.modifyLock = sync.RWMutex{} + c2.requestCallbacks = callbacks + return &c2 +} + +// WithResponseCallbacks makes a shallow clone of Client, modifies it to use +// the given callbacks, and returns it. Each of the callbacks will be invoked +// on every received response. A client may be used to issue requests +// concurrently; any locking needed by callbacks invoked concurrently is the +// callback's responsibility. +func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client { + c2 := *c + c2.modifyLock = sync.RWMutex{} + c2.responseCallbacks = callbacks + return &c2 +} + +// RecordState returns a response callback that will record the state returned +// by Vault in a response header. +func RecordState(state *string) ResponseCallback { + return func(resp *Response) { + *state = resp.Header.Get("X-Vault-Index") + } +} + +// RequireState returns a request callback that will add a request header to +// specify the state we require of Vault. This state was obtained from a +// response header seen previous, probably captured with RecordState. +func RequireState(states ...string) RequestCallback { + return func(req *Request) { + for _, s := range states { + req.Headers.Add("X-Vault-Index", s) + } + } +} + +// ForwardInconsistent returns a request callback that will add a request +// header which says: if the state required isn't present on the node receiving +// this request, forward it to the active node. This should be used in +// conjunction with RequireState. +func ForwardInconsistent() RequestCallback { + return func(req *Request) { + req.Headers.Set("X-Vault-Inconsistent", "forward-active-node") + } +} + +// ForwardAlways returns a request callback which adds a header telling any +// performance standbys handling the request to forward it to the active node. +// This feature must be enabled in Vault's configuration. +func ForwardAlways() RequestCallback { + return func(req *Request) { + req.Headers.Set("X-Vault-Forward", "active-node") + } +} + +// DefaultRetryPolicy is the default retry policy used by new Client objects. +// It is the same as retryablehttp.DefaultRetryPolicy except that it also retries +// 412 requests, which are returned by Vault when a X-Vault-Index header isn't +// satisfied. +func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { + retry, err := retryablehttp.DefaultRetryPolicy(ctx, resp, err) + if err != nil || retry { + return retry, err + } + return resp.StatusCode == 412, nil +} diff --git a/vendor/github.com/hashicorp/vault/sdk/logical/request.go b/vendor/github.com/hashicorp/vault/sdk/logical/request.go index e4cd6ad36..1a103a57a 100644 --- a/vendor/github.com/hashicorp/vault/sdk/logical/request.go +++ b/vendor/github.com/hashicorp/vault/sdk/logical/request.go @@ -1,6 +1,7 @@ package logical import ( + "context" "fmt" "net/http" "strings" @@ -54,6 +55,30 @@ const ( ClientTokenFromAuthzHeader ) +type WALState struct { + ClusterID string + LocalIndex uint64 + ReplicatedIndex uint64 +} + +const indexStateCtxKey = "index_state" + +// IndexStateContext returns a context with an added value holding the index +// state that should be populated on writes. +func IndexStateContext(ctx context.Context, state *WALState) context.Context { + return context.WithValue(ctx, indexStateCtxKey, state) +} + +// IndexStateFromContext is a helper to look up if the provided context contains +// an index state pointer. +func IndexStateFromContext(ctx context.Context) *WALState { + s, ok := ctx.Value(indexStateCtxKey).(*WALState) + if !ok { + return nil + } + return s +} + // Request is a struct that stores the parameters and context of a request // being made to Vault. It is used to abstract the details of the higher level // request protocol from the handlers. @@ -179,6 +204,11 @@ type Request struct { // ResponseWriter if set can be used to stream a response value to the http // request that generated this logical.Request object. ResponseWriter *HTTPResponseWriter `json:"-" sentinel:""` + + // responseState is used internally to propagate the state that should appear + // in response headers; it's attached to the request rather than the response + // because not all requests yields non-nil responses. + responseState *WALState } // Clone returns a deep copy of the request by using copystructure @@ -243,6 +273,14 @@ func (r *Request) SetLastRemoteWAL(last uint64) { r.lastRemoteWAL = last } +func (r *Request) ResponseState() *WALState { + return r.responseState +} + +func (r *Request) SetResponseState(w *WALState) { + r.responseState = w +} + func (r *Request) TokenEntry() *TokenEntry { return r.tokenEntry }