OSS parts of the new client controlled consistency feature (#10974)

This commit is contained in:
Nick Cabatoff 2021-02-24 06:58:10 -05:00 committed by GitHub
parent 5502d43f6e
commit c1ddfbb538
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 754 additions and 44 deletions

View File

@ -384,6 +384,8 @@ type Client struct {
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
}
// NewClient returns a new client for the given configuration.
@ -866,6 +868,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
c.modifyLock.RUnlock()
for _, cb := range c.requestCallbacks {
cb(r)
}
if limiter != nil {
limiter.Wait(ctx)
}
@ -907,7 +913,7 @@ START:
}
if checkRetry == nil {
checkRetry = retryablehttp.DefaultRetryPolicy
checkRetry = DefaultRetryPolicy
}
client := &retryablehttp.Client{
@ -968,9 +974,91 @@ START:
goto START
}
if result != nil {
for _, cb := range c.responseCallbacks {
cb(result)
}
}
if err := result.Error(); err != nil {
return result, err
}
return result, nil
}
type RequestCallback func(*Request)
type ResponseCallback func(*Response)
// WithRequestCallbacks makes a shallow clone of Client, modifies it to use
// the given callbacks, and returns it. Each of the callbacks will be invoked
// on every outgoing request. A client may be used to issue requests
// concurrently; any locking needed by callbacks invoked concurrently is the
// callback's responsibility.
func (c *Client) WithRequestCallbacks(callbacks ...RequestCallback) *Client {
c2 := *c
c2.modifyLock = sync.RWMutex{}
c2.requestCallbacks = callbacks
return &c2
}
// WithResponseCallbacks makes a shallow clone of Client, modifies it to use
// the given callbacks, and returns it. Each of the callbacks will be invoked
// on every received response. A client may be used to issue requests
// concurrently; any locking needed by callbacks invoked concurrently is the
// callback's responsibility.
func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client {
c2 := *c
c2.modifyLock = sync.RWMutex{}
c2.responseCallbacks = callbacks
return &c2
}
// RecordState returns a response callback that will record the state returned
// by Vault in a response header.
func RecordState(state *string) ResponseCallback {
return func(resp *Response) {
*state = resp.Header.Get("X-Vault-Index")
}
}
// RequireState returns a request callback that will add a request header to
// specify the state we require of Vault. This state was obtained from a
// response header seen previous, probably captured with RecordState.
func RequireState(states ...string) RequestCallback {
return func(req *Request) {
for _, s := range states {
req.Headers.Add("X-Vault-Index", s)
}
}
}
// ForwardInconsistent returns a request callback that will add a request
// header which says: if the state required isn't present on the node receiving
// this request, forward it to the active node. This should be used in
// conjunction with RequireState.
func ForwardInconsistent() RequestCallback {
return func(req *Request) {
req.Headers.Set("X-Vault-Inconsistent", "forward-active-node")
}
}
// ForwardAlways returns a request callback which adds a header telling any
// performance standbys handling the request to forward it to the active node.
// This feature must be enabled in Vault's configuration.
func ForwardAlways() RequestCallback {
return func(req *Request) {
req.Headers.Set("X-Vault-Forward", "active-node")
}
}
// DefaultRetryPolicy is the default retry policy used by new Client objects.
// It is the same as retryablehttp.DefaultRetryPolicy except that it also retries
// 412 requests, which are returned by Vault when a X-Vault-Index header isn't
// satisfied.
func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
retry, err := retryablehttp.DefaultRetryPolicy(ctx, resp, err)
if err != nil || retry {
return retry, err
}
return resp.StatusCode == 412, nil
}

View File

@ -416,6 +416,30 @@ func (c *AgentCommand) Run(args []string) int {
}
}
enforceConsistency := cache.EnforceConsistencyNever
whenInconsistent := cache.WhenInconsistentFail
if config.Cache != nil {
switch config.Cache.EnforceConsistency {
case "always":
enforceConsistency = cache.EnforceConsistencyAlways
case "never", "":
default:
c.UI.Error(fmt.Sprintf("Unknown cache setting for enforce_consistency: %q", config.Cache.EnforceConsistency))
return 1
}
switch config.Cache.WhenInconsistent {
case "retry":
whenInconsistent = cache.WhenInconsistentRetry
case "forward":
whenInconsistent = cache.WhenInconsistentForward
case "fail", "":
default:
c.UI.Error(fmt.Sprintf("Unknown cache setting for when_inconsistent: %q", config.Cache.WhenInconsistent))
return 1
}
}
// Warn if cache _and_ cert auto-auth is enabled but certificates were not
// provided in the auto_auth.method["cert"].config stanza.
if config.Cache != nil && (config.AutoAuth != nil && config.AutoAuth.Method != nil && config.AutoAuth.Method.Type == "cert") {
@ -437,20 +461,16 @@ func (c *AgentCommand) Run(args []string) int {
c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
}
// Inform any tests that the server is ready
select {
case c.startedCh <- struct{}{}:
default:
}
// Parse agent listener configurations
if config.Cache != nil && len(config.Listeners) != 0 {
cacheLogger := c.logger.Named("cache")
// Create the API proxier
apiProxy, err := cache.NewAPIProxy(&cache.APIProxyConfig{
Client: client,
Logger: cacheLogger.Named("apiproxy"),
Client: client,
Logger: cacheLogger.Named("apiproxy"),
EnforceConsistency: enforceConsistency,
WhenInconsistentAction: whenInconsistent,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err))
@ -547,6 +567,12 @@ func (c *AgentCommand) Run(args []string) int {
defer c.cleanupGuard.Do(listenerCloseFunc)
}
// Inform any tests that the server is ready
select {
case c.startedCh <- struct{}{}:
default:
}
// Listen for signals
// TODO: implement support for SIGHUP reloading of configuration
// signal.Notify(c.signalCh)

View File

@ -3,21 +3,49 @@ package cache
import (
"context"
"fmt"
"sync"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-retryablehttp"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/vault"
)
type EnforceConsistency int
const (
EnforceConsistencyNever EnforceConsistency = iota
EnforceConsistencyAlways
)
type WhenInconsistentAction int
const (
WhenInconsistentFail WhenInconsistentAction = iota
WhenInconsistentRetry
WhenInconsistentForward
)
// APIProxy is an implementation of the proxier interface that is used to
// forward the request to Vault and get the response.
type APIProxy struct {
client *api.Client
logger hclog.Logger
client *api.Client
logger hclog.Logger
enforceConsistency EnforceConsistency
whenInconsistentAction WhenInconsistentAction
l sync.RWMutex
lastIndexStates []string
}
var _ Proxier = &APIProxy{}
type APIProxyConfig struct {
Client *api.Client
Logger hclog.Logger
Client *api.Client
Logger hclog.Logger
EnforceConsistency EnforceConsistency
WhenInconsistentAction WhenInconsistentAction
}
func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {
@ -25,11 +53,65 @@ func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {
return nil, fmt.Errorf("nil API client")
}
return &APIProxy{
client: config.Client,
logger: config.Logger,
client: config.Client,
logger: config.Logger,
enforceConsistency: config.EnforceConsistency,
whenInconsistentAction: config.WhenInconsistentAction,
}, nil
}
// compareStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
// if neither s1 or s2 is strictly greater. An error is returned if s1 or s2
// are invalid or from different clusters.
func compareStates(s1, s2 string) (int, error) {
w1, err := vault.ParseRequiredState(s1, nil)
if err != nil {
return 0, err
}
w2, err := vault.ParseRequiredState(s2, nil)
if err != nil {
return 0, err
}
if w1.ClusterID != w2.ClusterID {
return 0, fmt.Errorf("don't know how to compare states with different ClusterIDs")
}
switch {
case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
return 1, nil
// We've already handled the case where both are equal above, so really we're
// asking here if one or both are lesser.
case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
return -1, nil
}
return 0, nil
}
func mergeStates(old []string, new string) []string {
if len(old) == 0 || len(old) > 2 {
return []string{new}
}
var ret []string
for _, o := range old {
c, err := compareStates(o, new)
if err != nil {
return []string{new}
}
switch c {
case 1:
ret = append(ret, o)
case -1:
ret = append(ret, new)
case 0:
ret = append(ret, o, new)
}
}
return strutil.RemoveDuplicates(ret, false)
}
func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
client, err := ap.client.Clone()
if err != nil {
@ -51,6 +133,34 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse,
fwReq.Params = query
}
var newState string
manageState := ap.enforceConsistency == EnforceConsistencyAlways &&
req.Request.Header.Get(http.VaultIndexHeaderName) == "" &&
req.Request.Header.Get(http.VaultForwardHeaderName) == "" &&
req.Request.Header.Get(http.VaultInconsistentHeaderName) == ""
if manageState {
client = client.WithResponseCallbacks(api.RecordState(&newState))
ap.l.RLock()
lastStates := ap.lastIndexStates
ap.l.RUnlock()
if len(lastStates) != 0 {
client = client.WithRequestCallbacks(api.RequireState(lastStates...))
switch ap.whenInconsistentAction {
case WhenInconsistentFail:
// In this mode we want to delegate handling of inconsistency
// failures to the external client talking to Agent.
client.SetCheckRetry(retryablehttp.DefaultRetryPolicy)
case WhenInconsistentRetry:
// In this mode we want to handle retries due to inconsistency
// internally. This is the default api.Client behaviour so
// we needn't do anything.
case WhenInconsistentForward:
fwReq.Headers.Set(http.VaultInconsistentHeaderName, http.VaultInconsistentForward)
}
}
}
// Make the request to Vault and get the response
ap.logger.Info("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path)
@ -60,6 +170,20 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse,
return nil, err
}
if newState != "" {
ap.l.Lock()
// We want to be using the "newest" states seen, but newer isn't well
// defined here. There can be two states S1 and S2 which aren't strictly ordered:
// S1 could have a newer localindex and S2 could have a newer replicatedindex. So
// we need to merge them. But we can't merge them because we wouldn't be able to
// "sign" the resulting header because we don't have access to the HMAC key that
// Vault uses to do so. So instead we compare any of the 0-2 saved states
// we have to the new header, keeping the newest 1-2 of these, and sending
// them to Vault to evaluate.
ap.lastIndexStates = mergeStates(ap.lastIndexStates, newState)
ap.l.Unlock()
}
// Before error checking from the request call, we'd want to initialize a SendResponse to
// potentially return
sendResponse, newErr := NewSendResponse(resp, nil)

View File

@ -1,6 +1,8 @@
package cache
import (
"encoding/base64"
"github.com/go-test/deep"
"net/http"
"testing"
@ -93,3 +95,85 @@ func TestAPIProxy_queryParams(t *testing.T) {
t.Fatalf("exptected standby to return 200, got: %v", resp.Response.StatusCode)
}
}
func TestMergeStates(t *testing.T) {
type testCase struct {
name string
old []string
new string
expected []string
}
var testCases = []testCase{
{
name: "empty-old",
old: nil,
new: "v1:cid:1:0:",
expected: []string{"v1:cid:1:0:"},
},
{
name: "old-smaller",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "old-bigger",
old: []string{"v1:cid:2:0:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:2:0:"},
},
{
name: "mixed-single",
old: []string{"v1:cid:1:0:"},
new: "v1:cid:0:1:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-single-alt",
old: []string{"v1:cid:0:1:"},
new: "v1:cid:1:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
},
{
name: "mixed-double",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:0:",
expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
},
{
name: "newer-both",
old: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
new: "v1:cid:2:1:",
expected: []string{"v1:cid:2:1:"},
},
}
b64enc := func(ss []string) []string {
var ret []string
for _, s := range ss {
ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
}
return ret
}
b64dec := func(ss []string) []string {
var ret []string
for _, s := range ss {
d, err := base64.StdEncoding.DecodeString(s)
if err != nil {
t.Fatal(err)
}
ret = append(ret, string(d))
}
return ret
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
out := b64dec(mergeStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
}
})
}
}

View File

@ -48,6 +48,8 @@ type Cache struct {
UseAutoAuthTokenRaw interface{} `hcl:"use_auto_auth_token"`
UseAutoAuthToken bool `hcl:"-"`
ForceAutoAuthToken bool `hcl:"-"`
EnforceConsistency string `hcl:"enforce_consistency"`
WhenInconsistent string `hcl:"when_inconsistent"`
}
// AutoAuth is the configured authentication method and sinks

View File

@ -742,3 +742,32 @@ func TestLoadConfigFile_Vault_Retry_Empty(t *testing.T) {
t.Fatal(diff)
}
}
func TestLoadConfigFile_EnforceConsistency(t *testing.T) {
config, err := LoadConfig("./test-fixtures/config-consistency.hcl")
if err != nil {
t.Fatal(err)
}
expected := &Config{
SharedConfig: &configutil.SharedConfig{
Listeners: []*configutil.Listener{
{
Type: "tcp",
Address: "127.0.0.1:8300",
TLSDisable: true,
},
},
PidFile: "",
},
Cache: &Cache{
EnforceConsistency: "always",
WhenInconsistent: "retry",
},
}
config.Listeners[0].RawConfig = nil
if diff := deep.Equal(config, expected); diff != nil {
t.Fatal(diff)
}
}

View File

@ -0,0 +1,9 @@
cache {
enforce_consistency = "always"
when_inconsistent = "retry"
}
listener "tcp" {
address = "127.0.0.1:8300"
tls_disable = true
}

View File

@ -41,7 +41,7 @@ func (s *StoragePacker) View() logical.Storage {
}
// GetBucket returns a bucket for a given key
func (s *StoragePacker) GetBucket(key string) (*Bucket, error) {
func (s *StoragePacker) GetBucket(ctx context.Context, key string) (*Bucket, error) {
if key == "" {
return nil, fmt.Errorf("missing bucket key")
}
@ -51,7 +51,7 @@ func (s *StoragePacker) GetBucket(key string) (*Bucket, error) {
defer lock.RUnlock()
// Read from storage
storageEntry, err := s.view.Get(context.Background(), key)
storageEntry, err := s.view.Get(ctx, key)
if err != nil {
return nil, errwrap.Wrapf("failed to read packed storage entry: {{err}}", err)
}
@ -126,8 +126,8 @@ func (s *StoragePacker) BucketKey(itemID string) string {
}
// DeleteItem removes the item from the respective bucket
func (s *StoragePacker) DeleteItem(_ context.Context, itemID string) error {
return s.DeleteMultipleItems(context.Background(), nil, []string{itemID})
func (s *StoragePacker) DeleteItem(ctx context.Context, itemID string) error {
return s.DeleteMultipleItems(ctx, nil, []string{itemID})
}
func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Logger, itemIDs []string) error {
@ -171,7 +171,7 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo
idx := 0
for bucketKey, itemsToRemove := range byBucket {
// Read bucket from storage
storageEntry, err := s.view.Get(context.Background(), bucketKey)
storageEntry, err := s.view.Get(ctx, bucketKey)
if err != nil {
return errwrap.Wrapf("failed to read packed storage value: {{err}}", err)
}
@ -278,7 +278,7 @@ func (s *StoragePacker) GetItem(itemID string) (*Item, error) {
bucketKey := s.BucketKey(itemID)
// Fetch the bucket entry
bucket, err := s.GetBucket(bucketKey)
bucket, err := s.GetBucket(context.Background(), bucketKey)
if err != nil {
return nil, errwrap.Wrapf("failed to read packed storage item: {{err}}", err)
}
@ -297,7 +297,7 @@ func (s *StoragePacker) GetItem(itemID string) (*Item, error) {
}
// PutItem stores the given item in its respective bucket
func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
func (s *StoragePacker) PutItem(ctx context.Context, item *Item) error {
defer metrics.MeasureSince([]string{"storage_packer", "put_item"}, time.Now())
if item == nil {
@ -323,7 +323,7 @@ func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
defer lock.Unlock()
// Check if there is an existing bucket for a given key
storageEntry, err := s.view.Get(context.Background(), bucketKey)
storageEntry, err := s.view.Get(ctx, bucketKey)
if err != nil {
return errwrap.Wrapf("failed to read packed storage bucket entry: {{err}}", err)
}
@ -354,7 +354,7 @@ func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
}
}
return s.putBucket(context.Background(), bucket)
return s.putBucket(ctx, bucket)
}
// NewStoragePacker creates a new storage packer for a given view

View File

@ -188,7 +188,6 @@ func FileBackendSetup(conf *vault.CoreConfig, opts *vault.TestClusterOptions) {
}
func RaftBackendSetup(conf *vault.CoreConfig, opts *vault.TestClusterOptions) {
conf.DisablePerformanceStandby = true
opts.KeepStandbysSealed = true
opts.PhysicalFactory = MakeRaftBackend
opts.SetupFunc = func(t testing.T, c *vault.TestCluster) {

View File

@ -57,6 +57,12 @@ const (
// soft-mandatory Sentinel policies.
PolicyOverrideHeaderName = "X-Vault-Policy-Override"
VaultIndexHeaderName = "X-Vault-Index"
VaultInconsistentHeaderName = "X-Vault-Inconsistent"
VaultForwardHeaderName = "X-Vault-Forward"
VaultInconsistentForward = "forward-active-node"
VaultInconsistentFail = "fail"
// DefaultMaxRequestSize is the default maximum accepted request size. This
// is to prevent a denial of service attack where no Content-Length is
// provided and the server is fed ever more data until it exhausts memory.
@ -666,12 +672,52 @@ func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
return data, nil
}
// forwardBasedOnHeaders returns true if the request headers specify that
// we should forward to the active node - either unconditionally or because
// a specified state isn't present locally.
func forwardBasedOnHeaders(core *vault.Core, r *http.Request) (bool, error) {
rawForward := r.Header.Get(VaultForwardHeaderName)
if rawForward != "" {
if !core.AllowForwardingViaHeader() {
return false, fmt.Errorf("forwarding via header %s disabled in configuration", VaultForwardHeaderName)
}
if rawForward == "active-node" {
return true, nil
}
return false, nil
}
rawInconsistent := r.Header.Get(VaultInconsistentHeaderName)
if rawInconsistent == "" {
return false, nil
}
switch rawInconsistent {
case VaultInconsistentForward:
if !core.AllowForwardingViaHeader() {
return false, fmt.Errorf("forwarding via header %s=%s disabled in configuration",
VaultInconsistentHeaderName, VaultInconsistentForward)
}
default:
return false, nil
}
return core.MissingRequiredState(r.Header.Values(VaultIndexHeaderName)), nil
}
// handleRequestForwarding determines whether to forward a request or not,
// falling back on the older behavior of redirecting the client
func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// If we are a performance standby we can handle the request.
if core.PerfStandby() {
// Note if the client requested forwarding
shouldForward, err := forwardBasedOnHeaders(core, r)
if err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
// If we are a performance standby we can maybe handle the request.
if core.PerfStandby() && !shouldForward {
ns, err := namespace.FromContext(r.Context())
if err != nil {
respondError(w, http.StatusBadRequest, err)

View File

@ -219,6 +219,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
if err != nil || status != 0 {
return nil, nil, status, err
}
rawRequired := r.Header.Values(VaultIndexHeaderName)
if len(rawRequired) != 0 && core.MissingRequiredState(rawRequired) {
return nil, nil, http.StatusPreconditionFailed, fmt.Errorf("required index state not present")
}
req, err = requestAuth(core, r, req)
if err != nil {
if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
@ -453,13 +459,13 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw
return
default:
// Build and return the proper response if everything is fine.
respondLogical(w, r, req, resp, injectDataIntoTopLevel)
respondLogical(core, w, r, req, resp, injectDataIntoTopLevel)
return
}
})
}
func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request, resp *logical.Response, injectDataIntoTopLevel bool) {
func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, req *logical.Request, resp *logical.Response, injectDataIntoTopLevel bool) {
var httpResp *logical.HTTPResponse
var ret interface{}
@ -509,6 +515,8 @@ func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request
}
}
adjustResponse(core, w, req)
// Respond
respondOk(w, ret)
return

View File

@ -369,7 +369,7 @@ func TestLogical_RespondWithStatusCode(t *testing.T) {
}
w := httptest.NewRecorder()
respondLogical(w, nil, nil, resp404, false)
respondLogical(nil, w, nil, nil, resp404, false)
if w.Code != 404 {
t.Fatalf("Bad Status code: %d", w.Code)

View File

@ -28,6 +28,8 @@ var (
additionalRoutes = func(mux *http.ServeMux, core *vault.Core) {}
nonVotersAllowed = false
adjustResponse = func(core *vault.Core, w http.ResponseWriter, req *logical.Request) {}
)
func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {

View File

@ -77,8 +77,8 @@ type FSM struct {
logger log.Logger
noopRestore bool
// applyDelay is used to simulate a slow apply in tests
applyDelay time.Duration
// applyCallback is used to control the pace of applies in tests
applyCallback func()
db *bolt.DB
@ -131,8 +131,12 @@ func (f *FSM) getDB() *bolt.DB {
// SetFSMDelay adds a delay to the FSM apply. This is used in tests to simulate
// a slow apply.
func (r *RaftBackend) SetFSMDelay(delay time.Duration) {
r.SetFSMApplyCallback(func() { time.Sleep(delay) })
}
func (r *RaftBackend) SetFSMApplyCallback(f func()) {
r.fsm.l.Lock()
r.fsm.applyDelay = delay
r.fsm.applyCallback = f
r.fsm.l.Unlock()
}
@ -469,8 +473,8 @@ func (f *FSM) ApplyBatch(logs []*raft.Log) []interface{} {
f.l.RLock()
defer f.l.RUnlock()
if f.applyDelay > 0 {
time.Sleep(f.applyDelay)
if f.applyCallback != nil {
f.applyCallback()
}
err = f.db.Update(func(tx *bolt.Tx) error {

View File

@ -263,6 +263,16 @@ func NewRaftBackend(conf map[string]string, logger log.Logger) (physical.Backend
return nil, fmt.Errorf("failed to create fsm: %v", err)
}
if delayRaw, ok := conf["apply_delay"]; ok {
delay, err := time.ParseDuration(delayRaw)
if err != nil {
return nil, fmt.Errorf("apply_delay does not parse as a duration: %w", err)
}
fsm.applyCallback = func() {
time.Sleep(delay)
}
}
// Build an all in-memory setup for dev mode, otherwise prepare a full
// disk-based setup.
var log raft.LogStore

View File

@ -1,6 +1,7 @@
package logical
import (
"context"
"fmt"
"net/http"
"strings"
@ -54,6 +55,30 @@ const (
ClientTokenFromAuthzHeader
)
type WALState struct {
ClusterID string
LocalIndex uint64
ReplicatedIndex uint64
}
const indexStateCtxKey = "index_state"
// IndexStateContext returns a context with an added value holding the index
// state that should be populated on writes.
func IndexStateContext(ctx context.Context, state *WALState) context.Context {
return context.WithValue(ctx, indexStateCtxKey, state)
}
// IndexStateFromContext is a helper to look up if the provided context contains
// an index state pointer.
func IndexStateFromContext(ctx context.Context) *WALState {
s, ok := ctx.Value(indexStateCtxKey).(*WALState)
if !ok {
return nil
}
return s
}
// Request is a struct that stores the parameters and context of a request
// being made to Vault. It is used to abstract the details of the higher level
// request protocol from the handlers.
@ -179,6 +204,11 @@ type Request struct {
// ResponseWriter if set can be used to stream a response value to the http
// request that generated this logical.Request object.
ResponseWriter *HTTPResponseWriter `json:"-" sentinel:""`
// responseState is used internally to propagate the state that should appear
// in response headers; it's attached to the request rather than the response
// because not all requests yields non-nil responses.
responseState *WALState
}
// Clone returns a deep copy of the request by using copystructure
@ -243,6 +273,14 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
r.lastRemoteWAL = last
}
func (r *Request) ResponseState() *WALState {
return r.responseState
}
func (r *Request) SetResponseState(w *WALState) {
r.responseState = w
}
func (r *Request) TokenEntry() *TokenEntry {
return r.tokenEntry
}

View File

@ -280,6 +280,18 @@ func (c *Core) setupCluster(ctx context.Context) error {
}
}
c.clusterID.Store(cluster.ID)
return nil
}
func (c *Core) loadCluster(ctx context.Context) error {
cluster, err := c.Cluster(ctx)
if err != nil {
c.logger.Error("failed to get cluster details", "error", err)
return err
}
c.clusterID.Store(cluster.ID)
return nil
}

View File

@ -3,10 +3,14 @@ package vault
import (
"context"
"crypto/ecdsa"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@ -16,6 +20,8 @@ import (
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@ -48,6 +54,7 @@ import (
"github.com/hashicorp/vault/vault/quotas"
vaultseal "github.com/hashicorp/vault/vault/seal"
"github.com/patrickmn/go-cache"
uberAtomic "go.uber.org/atomic"
"google.golang.org/grpc"
)
@ -72,6 +79,8 @@ const (
// clusters that they need to perform a rekey operation synchronously; this
// isn't keyring-canary to avoid ignoring it when ignoring core/keyring
coreKeyringCanaryPath = "core/canary-keyring"
indexHeaderHMACKeyPath = "core/index-header-hmac-key"
)
var (
@ -108,6 +117,7 @@ var (
PerformanceMerkleRoot = merkleRootImpl
DRMerkleRoot = merkleRootImpl
LastRemoteWAL = lastRemoteWALImpl
LastRemoteUpstreamWAL = lastRemoteUpstreamWALImpl
WaitUntilWALShipped = waitUntilWALShippedImpl
)
@ -391,6 +401,8 @@ type Core struct {
//
// Name
clusterName string
// ID
clusterID uberAtomic.String
// Specific cipher suites to use for clustering, if any
clusterCipherSuites []uint16
// Used to modify cluster parameters
@ -546,6 +558,8 @@ type Core struct {
// number of workers to use for lease revocation in the expiration manager
numExpirationWorkers int
IndexHeaderHMACKey uberAtomic.Value
}
// CoreConfig is used to parameterize a core
@ -2201,6 +2215,10 @@ func lastRemoteWALImpl(c *Core) uint64 {
return 0
}
func lastRemoteUpstreamWALImpl(c *Core) uint64 {
return 0
}
func (c *Core) PhysicalSealConfigs(ctx context.Context) (*SealConfig, *SealConfig, error) {
pe, err := c.physical.Get(ctx, barrierSealConfigPath)
if err != nil {
@ -2542,6 +2560,10 @@ func (c *Core) IsDRSecondary() bool {
return c.ReplicationState().HasState(consts.ReplicationDRSecondary)
}
func (c *Core) IsPerfSecondary() bool {
return c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary)
}
func (c *Core) AddLogger(logger log.Logger) {
c.allLoggersLock.Lock()
defer c.allLoggersLock.Unlock()
@ -2705,3 +2727,48 @@ func (c *Core) KeyRotateGracePeriod() time.Duration {
func (c *Core) SetKeyRotateGracePeriod(t time.Duration) {
atomic.StoreInt64(c.keyRotateGracePeriod, int64(t))
}
func ParseRequiredState(raw string, hmacKey []byte) (*logical.WALState, error) {
cooked, err := base64.StdEncoding.DecodeString(raw)
if err != nil {
return nil, err
}
s := string(cooked)
lastIndex := strings.LastIndexByte(s, ':')
if lastIndex == -1 {
return nil, fmt.Errorf("invalid state header format")
}
state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:]
stateHMAC, err := hex.DecodeString(stateHMACRaw)
if err != nil {
return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err)
}
if len(hmacKey) != 0 {
hm := hmac.New(sha256.New, hmacKey)
hm.Write([]byte(state))
if !hmac.Equal(hm.Sum(nil), stateHMAC) {
return nil, fmt.Errorf("invalid state header HMAC (mismatch)")
}
}
pieces := strings.Split(state, ":")
if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" {
return nil, fmt.Errorf("invalid state header format")
}
localIndex, err := strconv.ParseUint(pieces[2], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid state header format")
}
replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid state header format")
}
return &logical.WALState{
ClusterID: pieces[1],
LocalIndex: localIndex,
ReplicatedIndex: replicatedIndex,
}, nil
}

View File

@ -158,3 +158,11 @@ func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction
func (c *Core) namespaceByPath(path string) *namespace.Namespace {
return namespace.RootNamespace
}
func (c *Core) AllowForwardingViaHeader() bool {
return false
}
func (c *Core) MissingRequiredState(raw []string) bool {
return false
}

View File

@ -52,6 +52,7 @@ func (e extendedSystemViewImpl) ForwardGenericRequest(ctx context.Context, req *
// Forward the request if allowed
if couldForward(e.core) {
ctx = namespace.ContextWithNamespace(ctx, e.mountEntry.Namespace())
ctx = logical.IndexStateContext(ctx, &logical.WALState{})
ctx = context.WithValue(ctx, ctxKeyForwardedRequestMountAccessor{}, e.mountEntry.Accessor)
return forward(ctx, e.core, req)
}

View File

@ -27,7 +27,7 @@ const (
var (
caseSensitivityKey = "casesensitivity"
sendGroupUpgrade = func(*IdentityStore, *identity.Group) (bool, error) { return false, nil }
sendGroupUpgrade = func(context.Context, *IdentityStore, *identity.Group) (bool, error) { return false, nil }
parseExtraEntityFromBucket = func(context.Context, *IdentityStore, *identity.Entity) (bool, error) { return false, nil }
addExtraEntityDataToResponse = func(*identity.Entity, map[string]interface{}) {}
)
@ -210,7 +210,7 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
}
// Get the storage bucket entry
bucket, err := i.entityPacker.GetBucket(key)
bucket, err := i.entityPacker.GetBucket(ctx, key)
if err != nil {
i.logger.Error("failed to refresh entities", "key", key, "error", err)
return
@ -273,7 +273,7 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
}
// Get the storage bucket entry
bucket, err := i.groupPacker.GetBucket(key)
bucket, err := i.groupPacker.GetBucket(ctx, key)
if err != nil {
i.logger.Error("failed to refresh group", "key", key, "error", err)
return

View File

@ -569,6 +569,8 @@ func (i *IdentityStore) handleEntityBatchDelete() framework.OperationFunc {
}
}
// handleEntityDeleteCommon deletes an entity by removing it from groups of
// which it's a member and then, if update is true, deleting the entity itself.
func (i *IdentityStore) handleEntityDeleteCommon(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, update bool) error {
ns, err := namespace.FromContext(ctx)
if err != nil {

View File

@ -89,7 +89,7 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error {
i.logger.Debug("groups collected", "num_existing", len(existing))
for _, key := range existing {
bucket, err := i.groupPacker.GetBucket(groupBucketsPrefix + key)
bucket, err := i.groupPacker.GetBucket(ctx, groupBucketsPrefix+key)
if err != nil {
return err
}
@ -124,7 +124,7 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error {
}
continue
}
nsCtx := namespace.ContextWithNamespace(context.Background(), ns)
nsCtx := namespace.ContextWithNamespace(ctx, ns)
// Ensure that there are no groups with duplicate names
groupByName, err := i.MemDBGroupByName(nsCtx, group.Name, false)
@ -212,7 +212,7 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error {
return
}
bucket, err := i.entityPacker.GetBucket(storagepacker.StoragePackerBucketsPrefix + key)
bucket, err := i.entityPacker.GetBucket(ctx, storagepacker.StoragePackerBucketsPrefix+key)
if err != nil {
errs <- err
continue
@ -292,7 +292,7 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error {
}
continue
}
nsCtx := namespace.ContextWithNamespace(context.Background(), ns)
nsCtx := namespace.ContextWithNamespace(ctx, ns)
// Ensure that there are no entities with duplicate names
entityByName, err := i.MemDBEntityByName(nsCtx, entity.Name, false)
@ -1437,7 +1437,7 @@ func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, gr
Message: groupAsAny,
}
sent, err := sendGroupUpgrade(i, group)
sent, err := sendGroupUpgrade(ctx, i, group)
if err != nil {
return err
}

View File

@ -457,6 +457,8 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp
return nil, logical.CodedError(403, "namespaces feature not enabled")
}
var walState = &logical.WALState{}
ctx = logical.IndexStateContext(ctx, walState)
var auth *logical.Auth
if c.router.LoginPath(ctx, req.Path) {
resp, auth, err = c.handleLoginRequest(ctx, req)
@ -564,6 +566,26 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp
}
}
if walState.LocalIndex != 0 || walState.ReplicatedIndex != 0 {
walState.ClusterID = c.clusterID.Load()
if walState.LocalIndex == 0 {
if c.perfStandby {
walState.LocalIndex = LastRemoteWAL(c)
} else {
walState.LocalIndex = LastWAL(c)
}
}
if walState.ReplicatedIndex == 0 {
if c.perfStandby {
walState.ReplicatedIndex = LastRemoteUpstreamWAL(c)
} else {
walState.ReplicatedIndex = LastRemoteWAL(c)
}
}
req.SetResponseState(walState)
}
return
}

View File

@ -76,6 +76,9 @@ func (c *Core) ensureWrappingKey(ctx context.Context) error {
return nil
}
// wrapInCubbyhole is invoked when a caller asks for response wrapping.
// On success, return (nil, nil) and mutates resp. On failure, returns
// either a response describing the failure or an error.
func (c *Core) wrapInCubbyhole(ctx context.Context, req *logical.Request, resp *logical.Response, auth *logical.Auth) (*logical.Response, error) {
if c.perfStandby {
return forwardWrapRequest(ctx, c, req, resp, auth)

View File

@ -384,6 +384,8 @@ type Client struct {
wrappingLookupFunc WrappingLookupFunc
mfaCreds []string
policyOverride bool
requestCallbacks []RequestCallback
responseCallbacks []ResponseCallback
}
// NewClient returns a new client for the given configuration.
@ -866,6 +868,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon
c.modifyLock.RUnlock()
for _, cb := range c.requestCallbacks {
cb(r)
}
if limiter != nil {
limiter.Wait(ctx)
}
@ -907,7 +913,7 @@ START:
}
if checkRetry == nil {
checkRetry = retryablehttp.DefaultRetryPolicy
checkRetry = DefaultRetryPolicy
}
client := &retryablehttp.Client{
@ -968,9 +974,91 @@ START:
goto START
}
if result != nil {
for _, cb := range c.responseCallbacks {
cb(result)
}
}
if err := result.Error(); err != nil {
return result, err
}
return result, nil
}
type RequestCallback func(*Request)
type ResponseCallback func(*Response)
// WithRequestCallbacks makes a shallow clone of Client, modifies it to use
// the given callbacks, and returns it. Each of the callbacks will be invoked
// on every outgoing request. A client may be used to issue requests
// concurrently; any locking needed by callbacks invoked concurrently is the
// callback's responsibility.
func (c *Client) WithRequestCallbacks(callbacks ...RequestCallback) *Client {
c2 := *c
c2.modifyLock = sync.RWMutex{}
c2.requestCallbacks = callbacks
return &c2
}
// WithResponseCallbacks makes a shallow clone of Client, modifies it to use
// the given callbacks, and returns it. Each of the callbacks will be invoked
// on every received response. A client may be used to issue requests
// concurrently; any locking needed by callbacks invoked concurrently is the
// callback's responsibility.
func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client {
c2 := *c
c2.modifyLock = sync.RWMutex{}
c2.responseCallbacks = callbacks
return &c2
}
// RecordState returns a response callback that will record the state returned
// by Vault in a response header.
func RecordState(state *string) ResponseCallback {
return func(resp *Response) {
*state = resp.Header.Get("X-Vault-Index")
}
}
// RequireState returns a request callback that will add a request header to
// specify the state we require of Vault. This state was obtained from a
// response header seen previous, probably captured with RecordState.
func RequireState(states ...string) RequestCallback {
return func(req *Request) {
for _, s := range states {
req.Headers.Add("X-Vault-Index", s)
}
}
}
// ForwardInconsistent returns a request callback that will add a request
// header which says: if the state required isn't present on the node receiving
// this request, forward it to the active node. This should be used in
// conjunction with RequireState.
func ForwardInconsistent() RequestCallback {
return func(req *Request) {
req.Headers.Set("X-Vault-Inconsistent", "forward-active-node")
}
}
// ForwardAlways returns a request callback which adds a header telling any
// performance standbys handling the request to forward it to the active node.
// This feature must be enabled in Vault's configuration.
func ForwardAlways() RequestCallback {
return func(req *Request) {
req.Headers.Set("X-Vault-Forward", "active-node")
}
}
// DefaultRetryPolicy is the default retry policy used by new Client objects.
// It is the same as retryablehttp.DefaultRetryPolicy except that it also retries
// 412 requests, which are returned by Vault when a X-Vault-Index header isn't
// satisfied.
func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
retry, err := retryablehttp.DefaultRetryPolicy(ctx, resp, err)
if err != nil || retry {
return retry, err
}
return resp.StatusCode == 412, nil
}

View File

@ -1,6 +1,7 @@
package logical
import (
"context"
"fmt"
"net/http"
"strings"
@ -54,6 +55,30 @@ const (
ClientTokenFromAuthzHeader
)
type WALState struct {
ClusterID string
LocalIndex uint64
ReplicatedIndex uint64
}
const indexStateCtxKey = "index_state"
// IndexStateContext returns a context with an added value holding the index
// state that should be populated on writes.
func IndexStateContext(ctx context.Context, state *WALState) context.Context {
return context.WithValue(ctx, indexStateCtxKey, state)
}
// IndexStateFromContext is a helper to look up if the provided context contains
// an index state pointer.
func IndexStateFromContext(ctx context.Context) *WALState {
s, ok := ctx.Value(indexStateCtxKey).(*WALState)
if !ok {
return nil
}
return s
}
// Request is a struct that stores the parameters and context of a request
// being made to Vault. It is used to abstract the details of the higher level
// request protocol from the handlers.
@ -179,6 +204,11 @@ type Request struct {
// ResponseWriter if set can be used to stream a response value to the http
// request that generated this logical.Request object.
ResponseWriter *HTTPResponseWriter `json:"-" sentinel:""`
// responseState is used internally to propagate the state that should appear
// in response headers; it's attached to the request rather than the response
// because not all requests yields non-nil responses.
responseState *WALState
}
// Clone returns a deep copy of the request by using copystructure
@ -243,6 +273,14 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
r.lastRemoteWAL = last
}
func (r *Request) ResponseState() *WALState {
return r.responseState
}
func (r *Request) SetResponseState(w *WALState) {
r.responseState = w
}
func (r *Request) TokenEntry() *TokenEntry {
return r.tokenEntry
}