OSS parts of the new client controlled consistency feature (#10974)
This commit is contained in:
parent 5502d43f6e
commit c1ddfbb538
@@ -384,6 +384,8 @@ type Client struct {
	wrappingLookupFunc WrappingLookupFunc
	mfaCreds           []string
	policyOverride     bool
	requestCallbacks   []RequestCallback
	responseCallbacks  []ResponseCallback
}

// NewClient returns a new client for the given configuration.

@@ -866,6 +868,10 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon

	c.modifyLock.RUnlock()

	for _, cb := range c.requestCallbacks {
		cb(r)
	}

	if limiter != nil {
		limiter.Wait(ctx)
	}

@@ -907,7 +913,7 @@ START:
	}

	if checkRetry == nil {
		checkRetry = retryablehttp.DefaultRetryPolicy
		checkRetry = DefaultRetryPolicy
	}

	client := &retryablehttp.Client{

@@ -968,9 +974,91 @@ START:
		goto START
	}

	if result != nil {
		for _, cb := range c.responseCallbacks {
			cb(result)
		}
	}
	if err := result.Error(); err != nil {
		return result, err
	}

	return result, nil
}

type RequestCallback func(*Request)
type ResponseCallback func(*Response)

// WithRequestCallbacks makes a shallow clone of Client, modifies it to use
// the given callbacks, and returns it. Each of the callbacks will be invoked
// on every outgoing request. A client may be used to issue requests
// concurrently; any locking needed by callbacks invoked concurrently is the
// callback's responsibility.
func (c *Client) WithRequestCallbacks(callbacks ...RequestCallback) *Client {
	c2 := *c
	c2.modifyLock = sync.RWMutex{}
	c2.requestCallbacks = callbacks
	return &c2
}

// WithResponseCallbacks makes a shallow clone of Client, modifies it to use
// the given callbacks, and returns it. Each of the callbacks will be invoked
// on every received response. A client may be used to issue requests
// concurrently; any locking needed by callbacks invoked concurrently is the
// callback's responsibility.
func (c *Client) WithResponseCallbacks(callbacks ...ResponseCallback) *Client {
	c2 := *c
	c2.modifyLock = sync.RWMutex{}
	c2.responseCallbacks = callbacks
	return &c2
}

// RecordState returns a response callback that will record the state returned
// by Vault in a response header.
func RecordState(state *string) ResponseCallback {
	return func(resp *Response) {
		*state = resp.Header.Get("X-Vault-Index")
	}
}

// RequireState returns a request callback that will add a request header to
// specify the state we require of Vault. This state was obtained from a
// response header seen previously, probably captured with RecordState.
func RequireState(states ...string) RequestCallback {
	return func(req *Request) {
		for _, s := range states {
			req.Headers.Add("X-Vault-Index", s)
		}
	}
}

// ForwardInconsistent returns a request callback that will add a request
// header which says: if the state required isn't present on the node receiving
// this request, forward it to the active node. This should be used in
// conjunction with RequireState.
func ForwardInconsistent() RequestCallback {
	return func(req *Request) {
		req.Headers.Set("X-Vault-Inconsistent", "forward-active-node")
	}
}

// ForwardAlways returns a request callback which adds a header telling any
// performance standbys handling the request to forward it to the active node.
// This feature must be enabled in Vault's configuration.
func ForwardAlways() RequestCallback {
	return func(req *Request) {
		req.Headers.Set("X-Vault-Forward", "active-node")
	}
}

// DefaultRetryPolicy is the default retry policy used by new Client objects.
// It is the same as retryablehttp.DefaultRetryPolicy except that it also retries
// 412 responses, which are returned by Vault when an X-Vault-Index header isn't
// satisfied.
func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
	retry, err := retryablehttp.DefaultRetryPolicy(ctx, resp, err)
	if err != nil || retry {
		return retry, err
	}
	return resp.StatusCode == 412, nil
}
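Taken together, these helpers let a caller pin a follow-up read to a write it has just performed. The sketch below is illustrative only and not part of this commit: the mount path, payload, and error handling are placeholders; only api.RecordState, api.RequireState, api.ForwardInconsistent, and the With*Callbacks methods come from the change above.

package main

import (
	"fmt"
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Capture the X-Vault-Index state returned by the write.
	var state string
	wc := client.WithResponseCallbacks(api.RecordState(&state))
	if _, err := wc.Logical().Write("secret/data/app", map[string]interface{}{
		"data": map[string]interface{}{"password": "placeholder"},
	}); err != nil {
		log.Fatal(err)
	}

	// Require that state on the follow-up read; if the node receiving the
	// request hasn't caught up, ask it to forward to the active node.
	rc := client.WithRequestCallbacks(api.RequireState(state), api.ForwardInconsistent())
	secret, err := rc.Logical().Read("secret/data/app")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(secret.Data)
}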
@@ -416,6 +416,30 @@ func (c *AgentCommand) Run(args []string) int {
		}
	}

	enforceConsistency := cache.EnforceConsistencyNever
	whenInconsistent := cache.WhenInconsistentFail
	if config.Cache != nil {
		switch config.Cache.EnforceConsistency {
		case "always":
			enforceConsistency = cache.EnforceConsistencyAlways
		case "never", "":
		default:
			c.UI.Error(fmt.Sprintf("Unknown cache setting for enforce_consistency: %q", config.Cache.EnforceConsistency))
			return 1
		}

		switch config.Cache.WhenInconsistent {
		case "retry":
			whenInconsistent = cache.WhenInconsistentRetry
		case "forward":
			whenInconsistent = cache.WhenInconsistentForward
		case "fail", "":
		default:
			c.UI.Error(fmt.Sprintf("Unknown cache setting for when_inconsistent: %q", config.Cache.WhenInconsistent))
			return 1
		}
	}

	// Warn if cache _and_ cert auto-auth is enabled but certificates were not
	// provided in the auto_auth.method["cert"].config stanza.
	if config.Cache != nil && (config.AutoAuth != nil && config.AutoAuth.Method != nil && config.AutoAuth.Method.Type == "cert") {

@@ -437,20 +461,16 @@ func (c *AgentCommand) Run(args []string) int {
		c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
	}

	// Inform any tests that the server is ready
	select {
	case c.startedCh <- struct{}{}:
	default:
	}

	// Parse agent listener configurations
	if config.Cache != nil && len(config.Listeners) != 0 {
		cacheLogger := c.logger.Named("cache")

		// Create the API proxier
		apiProxy, err := cache.NewAPIProxy(&cache.APIProxyConfig{
			Client: client,
			Logger: cacheLogger.Named("apiproxy"),
			Client:                 client,
			Logger:                 cacheLogger.Named("apiproxy"),
			EnforceConsistency:     enforceConsistency,
			WhenInconsistentAction: whenInconsistent,
		})
		if err != nil {
			c.UI.Error(fmt.Sprintf("Error creating API proxy: %v", err))

@@ -547,6 +567,12 @@ func (c *AgentCommand) Run(args []string) int {
		defer c.cleanupGuard.Do(listenerCloseFunc)
	}

	// Inform any tests that the server is ready
	select {
	case c.startedCh <- struct{}{}:
	default:
	}

	// Listen for signals
	// TODO: implement support for SIGHUP reloading of configuration
	// signal.Notify(c.signalCh)
@@ -3,21 +3,49 @@ package cache

import (
	"context"
	"fmt"
	"sync"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-retryablehttp"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/http"
	"github.com/hashicorp/vault/sdk/helper/strutil"
	"github.com/hashicorp/vault/vault"
)

type EnforceConsistency int

const (
	EnforceConsistencyNever EnforceConsistency = iota
	EnforceConsistencyAlways
)

type WhenInconsistentAction int

const (
	WhenInconsistentFail WhenInconsistentAction = iota
	WhenInconsistentRetry
	WhenInconsistentForward
)

// APIProxy is an implementation of the proxier interface that is used to
// forward the request to Vault and get the response.
type APIProxy struct {
	client *api.Client
	logger hclog.Logger
	client                 *api.Client
	logger                 hclog.Logger
	enforceConsistency     EnforceConsistency
	whenInconsistentAction WhenInconsistentAction
	l                      sync.RWMutex
	lastIndexStates        []string
}

var _ Proxier = &APIProxy{}

type APIProxyConfig struct {
	Client *api.Client
	Logger hclog.Logger
	Client                 *api.Client
	Logger                 hclog.Logger
	EnforceConsistency     EnforceConsistency
	WhenInconsistentAction WhenInconsistentAction
}

func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {

@@ -25,11 +53,65 @@ func NewAPIProxy(config *APIProxyConfig) (Proxier, error) {
		return nil, fmt.Errorf("nil API client")
	}
	return &APIProxy{
		client: config.Client,
		logger: config.Logger,
		client:                 config.Client,
		logger:                 config.Logger,
		enforceConsistency:     config.EnforceConsistency,
		whenInconsistentAction: config.WhenInconsistentAction,
	}, nil
}

// compareStates returns 1 if s1 is newer or identical, -1 if s1 is older, and 0
// if neither s1 nor s2 is strictly greater. An error is returned if s1 or s2
// are invalid or from different clusters.
func compareStates(s1, s2 string) (int, error) {
	w1, err := vault.ParseRequiredState(s1, nil)
	if err != nil {
		return 0, err
	}
	w2, err := vault.ParseRequiredState(s2, nil)
	if err != nil {
		return 0, err
	}

	if w1.ClusterID != w2.ClusterID {
		return 0, fmt.Errorf("don't know how to compare states with different ClusterIDs")
	}

	switch {
	case w1.LocalIndex >= w2.LocalIndex && w1.ReplicatedIndex >= w2.ReplicatedIndex:
		return 1, nil
	// We've already handled the case where both are equal above, so really we're
	// asking here if one or both are lesser.
	case w1.LocalIndex <= w2.LocalIndex && w1.ReplicatedIndex <= w2.ReplicatedIndex:
		return -1, nil
	}

	return 0, nil
}

func mergeStates(old []string, new string) []string {
	if len(old) == 0 || len(old) > 2 {
		return []string{new}
	}

	var ret []string
	for _, o := range old {
		c, err := compareStates(o, new)
		if err != nil {
			return []string{new}
		}
		switch c {
		case 1:
			ret = append(ret, o)
		case -1:
			ret = append(ret, new)
		case 0:
			ret = append(ret, o, new)
		}
	}
	return strutil.RemoveDuplicates(ret, false)
}

func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
	client, err := ap.client.Clone()
	if err != nil {

@@ -51,6 +133,34 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse,
		fwReq.Params = query
	}

	var newState string
	manageState := ap.enforceConsistency == EnforceConsistencyAlways &&
		req.Request.Header.Get(http.VaultIndexHeaderName) == "" &&
		req.Request.Header.Get(http.VaultForwardHeaderName) == "" &&
		req.Request.Header.Get(http.VaultInconsistentHeaderName) == ""

	if manageState {
		client = client.WithResponseCallbacks(api.RecordState(&newState))
		ap.l.RLock()
		lastStates := ap.lastIndexStates
		ap.l.RUnlock()
		if len(lastStates) != 0 {
			client = client.WithRequestCallbacks(api.RequireState(lastStates...))
			switch ap.whenInconsistentAction {
			case WhenInconsistentFail:
				// In this mode we want to delegate handling of inconsistency
				// failures to the external client talking to Agent.
				client.SetCheckRetry(retryablehttp.DefaultRetryPolicy)
			case WhenInconsistentRetry:
				// In this mode we want to handle retries due to inconsistency
				// internally. This is the default api.Client behaviour, so
				// we needn't do anything.
			case WhenInconsistentForward:
				fwReq.Headers.Set(http.VaultInconsistentHeaderName, http.VaultInconsistentForward)
			}
		}
	}

	// Make the request to Vault and get the response
	ap.logger.Info("forwarding request", "method", req.Request.Method, "path", req.Request.URL.Path)

@@ -60,6 +170,20 @@ func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse,
		return nil, err
	}

	if newState != "" {
		ap.l.Lock()
		// We want to be using the "newest" states seen, but newer isn't well
		// defined here. There can be two states S1 and S2 which aren't strictly ordered:
		// S1 could have a newer localindex and S2 could have a newer replicatedindex. So
		// we need to merge them, but we can't, because we wouldn't be able to "sign" the
		// resulting header: we don't have access to the HMAC key that Vault uses to do so.
		// So instead we compare any of the 0-2 saved states we have to the new header,
		// keeping the newest 1-2 of these, and sending them to Vault to evaluate.
		ap.lastIndexStates = mergeStates(ap.lastIndexStates, newState)
		ap.l.Unlock()
	}

	// Before error checking from the request call, we'd want to initialize a SendResponse to
	// potentially return
	sendResponse, newErr := NewSendResponse(resp, nil)
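For orientation, here is a minimal sketch of constructing the proxier directly with the new options, mirroring the agent wiring shown earlier. It is not part of this commit: the import paths, logger setup, and chosen modes are assumptions; only NewAPIProxy, APIProxyConfig, and the consistency constants come from the diff above.

package main

import (
	"log"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/command/agent/cache"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	// Track X-Vault-Index states for every proxied request and retry
	// internally when the receiving node hasn't caught up yet.
	proxier, err := cache.NewAPIProxy(&cache.APIProxyConfig{
		Client:                 client,
		Logger:                 hclog.Default().Named("apiproxy"),
		EnforceConsistency:     cache.EnforceConsistencyAlways,
		WhenInconsistentAction: cache.WhenInconsistentRetry,
	})
	if err != nil {
		log.Fatal(err)
	}
	_ = proxier // in the agent this is wired into the listener/cache stack
}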
@@ -1,6 +1,8 @@
package cache

import (
	"encoding/base64"
	"github.com/go-test/deep"
	"net/http"
	"testing"

@@ -93,3 +95,85 @@ func TestAPIProxy_queryParams(t *testing.T) {
		t.Fatalf("expected standby to return 200, got: %v", resp.Response.StatusCode)
	}
}

func TestMergeStates(t *testing.T) {
	type testCase struct {
		name     string
		old      []string
		new      string
		expected []string
	}

	var testCases = []testCase{
		{
			name:     "empty-old",
			old:      nil,
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:1:0:"},
		},
		{
			name:     "old-smaller",
			old:      []string{"v1:cid:1:0:"},
			new:      "v1:cid:2:0:",
			expected: []string{"v1:cid:2:0:"},
		},
		{
			name:     "old-bigger",
			old:      []string{"v1:cid:2:0:"},
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:2:0:"},
		},
		{
			name:     "mixed-single",
			old:      []string{"v1:cid:1:0:"},
			new:      "v1:cid:0:1:",
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
		{
			name:     "mixed-single-alt",
			old:      []string{"v1:cid:0:1:"},
			new:      "v1:cid:1:0:",
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
		},
		{
			name:     "mixed-double",
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			new:      "v1:cid:2:0:",
			expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
		},
		{
			name:     "newer-both",
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
			new:      "v1:cid:2:1:",
			expected: []string{"v1:cid:2:1:"},
		},
	}

	b64enc := func(ss []string) []string {
		var ret []string
		for _, s := range ss {
			ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
		}
		return ret
	}
	b64dec := func(ss []string) []string {
		var ret []string
		for _, s := range ss {
			d, err := base64.StdEncoding.DecodeString(s)
			if err != nil {
				t.Fatal(err)
			}
			ret = append(ret, string(d))
		}
		return ret
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			out := b64dec(mergeStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
			if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
				t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
			}
		})
	}
}
@@ -48,6 +48,8 @@ type Cache struct {
	UseAutoAuthTokenRaw interface{} `hcl:"use_auto_auth_token"`
	UseAutoAuthToken    bool        `hcl:"-"`
	ForceAutoAuthToken  bool        `hcl:"-"`
	EnforceConsistency  string      `hcl:"enforce_consistency"`
	WhenInconsistent    string      `hcl:"when_inconsistent"`
}

// AutoAuth is the configured authentication method and sinks
@@ -742,3 +742,32 @@ func TestLoadConfigFile_Vault_Retry_Empty(t *testing.T) {
		t.Fatal(diff)
	}
}

func TestLoadConfigFile_EnforceConsistency(t *testing.T) {
	config, err := LoadConfig("./test-fixtures/config-consistency.hcl")
	if err != nil {
		t.Fatal(err)
	}

	expected := &Config{
		SharedConfig: &configutil.SharedConfig{
			Listeners: []*configutil.Listener{
				{
					Type:       "tcp",
					Address:    "127.0.0.1:8300",
					TLSDisable: true,
				},
			},
			PidFile: "",
		},
		Cache: &Cache{
			EnforceConsistency: "always",
			WhenInconsistent:   "retry",
		},
	}

	config.Listeners[0].RawConfig = nil
	if diff := deep.Equal(config, expected); diff != nil {
		t.Fatal(diff)
	}
}
@@ -0,0 +1,9 @@
cache {
	enforce_consistency = "always"
	when_inconsistent   = "retry"
}

listener "tcp" {
	address     = "127.0.0.1:8300"
	tls_disable = true
}
@@ -41,7 +41,7 @@ func (s *StoragePacker) View() logical.Storage {
}

// GetBucket returns a bucket for a given key
func (s *StoragePacker) GetBucket(key string) (*Bucket, error) {
func (s *StoragePacker) GetBucket(ctx context.Context, key string) (*Bucket, error) {
	if key == "" {
		return nil, fmt.Errorf("missing bucket key")
	}

@@ -51,7 +51,7 @@ func (s *StoragePacker) GetBucket(key string) (*Bucket, error) {
	defer lock.RUnlock()

	// Read from storage
	storageEntry, err := s.view.Get(context.Background(), key)
	storageEntry, err := s.view.Get(ctx, key)
	if err != nil {
		return nil, errwrap.Wrapf("failed to read packed storage entry: {{err}}", err)
	}

@@ -126,8 +126,8 @@ func (s *StoragePacker) BucketKey(itemID string) string {
}

// DeleteItem removes the item from the respective bucket
func (s *StoragePacker) DeleteItem(_ context.Context, itemID string) error {
	return s.DeleteMultipleItems(context.Background(), nil, []string{itemID})
func (s *StoragePacker) DeleteItem(ctx context.Context, itemID string) error {
	return s.DeleteMultipleItems(ctx, nil, []string{itemID})
}

func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Logger, itemIDs []string) error {

@@ -171,7 +171,7 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo
	idx := 0
	for bucketKey, itemsToRemove := range byBucket {
		// Read bucket from storage
		storageEntry, err := s.view.Get(context.Background(), bucketKey)
		storageEntry, err := s.view.Get(ctx, bucketKey)
		if err != nil {
			return errwrap.Wrapf("failed to read packed storage value: {{err}}", err)
		}

@@ -278,7 +278,7 @@ func (s *StoragePacker) GetItem(itemID string) (*Item, error) {
	bucketKey := s.BucketKey(itemID)

	// Fetch the bucket entry
	bucket, err := s.GetBucket(bucketKey)
	bucket, err := s.GetBucket(context.Background(), bucketKey)
	if err != nil {
		return nil, errwrap.Wrapf("failed to read packed storage item: {{err}}", err)
	}

@@ -297,7 +297,7 @@ func (s *StoragePacker) GetItem(itemID string) (*Item, error) {
}

// PutItem stores the given item in its respective bucket
func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
func (s *StoragePacker) PutItem(ctx context.Context, item *Item) error {
	defer metrics.MeasureSince([]string{"storage_packer", "put_item"}, time.Now())

	if item == nil {

@@ -323,7 +323,7 @@ func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
	defer lock.Unlock()

	// Check if there is an existing bucket for a given key
	storageEntry, err := s.view.Get(context.Background(), bucketKey)
	storageEntry, err := s.view.Get(ctx, bucketKey)
	if err != nil {
		return errwrap.Wrapf("failed to read packed storage bucket entry: {{err}}", err)
	}

@@ -354,7 +354,7 @@ func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
		}
	}

	return s.putBucket(context.Background(), bucket)
	return s.putBucket(ctx, bucket)
}

// NewStoragePacker creates a new storage packer for a given view
@@ -188,7 +188,6 @@ func FileBackendSetup(conf *vault.CoreConfig, opts *vault.TestClusterOptions) {
}

func RaftBackendSetup(conf *vault.CoreConfig, opts *vault.TestClusterOptions) {
	conf.DisablePerformanceStandby = true
	opts.KeepStandbysSealed = true
	opts.PhysicalFactory = MakeRaftBackend
	opts.SetupFunc = func(t testing.T, c *vault.TestCluster) {
@@ -57,6 +57,12 @@ const (
	// soft-mandatory Sentinel policies.
	PolicyOverrideHeaderName = "X-Vault-Policy-Override"

	VaultIndexHeaderName        = "X-Vault-Index"
	VaultInconsistentHeaderName = "X-Vault-Inconsistent"
	VaultForwardHeaderName      = "X-Vault-Forward"
	VaultInconsistentForward    = "forward-active-node"
	VaultInconsistentFail       = "fail"

	// DefaultMaxRequestSize is the default maximum accepted request size. This
	// is to prevent a denial of service attack where no Content-Length is
	// provided and the server is fed ever more data until it exhausts memory.

@@ -666,12 +672,52 @@ func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
	return data, nil
}

// forwardBasedOnHeaders returns true if the request headers specify that
// we should forward to the active node - either unconditionally or because
// a specified state isn't present locally.
func forwardBasedOnHeaders(core *vault.Core, r *http.Request) (bool, error) {
	rawForward := r.Header.Get(VaultForwardHeaderName)
	if rawForward != "" {
		if !core.AllowForwardingViaHeader() {
			return false, fmt.Errorf("forwarding via header %s disabled in configuration", VaultForwardHeaderName)
		}
		if rawForward == "active-node" {
			return true, nil
		}
		return false, nil
	}

	rawInconsistent := r.Header.Get(VaultInconsistentHeaderName)
	if rawInconsistent == "" {
		return false, nil
	}

	switch rawInconsistent {
	case VaultInconsistentForward:
		if !core.AllowForwardingViaHeader() {
			return false, fmt.Errorf("forwarding via header %s=%s disabled in configuration",
				VaultInconsistentHeaderName, VaultInconsistentForward)
		}
	default:
		return false, nil
	}

	return core.MissingRequiredState(r.Header.Values(VaultIndexHeaderName)), nil
}

// handleRequestForwarding determines whether to forward a request or not,
// falling back on the older behavior of redirecting the client
func handleRequestForwarding(core *vault.Core, handler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// If we are a performance standby we can handle the request.
		if core.PerfStandby() {
		// Note if the client requested forwarding
		shouldForward, err := forwardBasedOnHeaders(core, r)
		if err != nil {
			respondError(w, http.StatusBadRequest, err)
			return
		}

		// If we are a performance standby we can maybe handle the request.
		if core.PerfStandby() && !shouldForward {
			ns, err := namespace.FromContext(r.Context())
			if err != nil {
				respondError(w, http.StatusBadRequest, err)
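For reference, these are the same headers the client-side helpers emit on the wire. The sketch below is illustrative only and not part of this commit: the address, token, path, and state value are placeholders; only the header names and the "forward-active-node" value come from the constants above.

package main

import (
	"fmt"
	"net/http"
)

// buildConsistentRead shows the raw headers behind the feature. addr, token,
// and savedState are placeholders supplied by the caller.
func buildConsistentRead(addr, token, savedState string) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodGet, addr+"/v1/secret/data/app", nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("X-Vault-Token", token)
	// State previously captured from an X-Vault-Index response header.
	req.Header.Add("X-Vault-Index", savedState)
	// If the receiving standby hasn't reached that state, ask it to forward
	// the request to the active node instead of returning 412.
	req.Header.Set("X-Vault-Inconsistent", "forward-active-node")
	return req, nil
}

func main() {
	req, _ := buildConsistentRead("https://127.0.0.1:8200", "s.placeholder", "placeholder-state")
	fmt.Println(req.Header)
}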
@@ -219,6 +219,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
	if err != nil || status != 0 {
		return nil, nil, status, err
	}

	rawRequired := r.Header.Values(VaultIndexHeaderName)
	if len(rawRequired) != 0 && core.MissingRequiredState(rawRequired) {
		return nil, nil, http.StatusPreconditionFailed, fmt.Errorf("required index state not present")
	}

	req, err = requestAuth(core, r, req)
	if err != nil {
		if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {

@@ -453,13 +459,13 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw
			return
		default:
			// Build and return the proper response if everything is fine.
			respondLogical(w, r, req, resp, injectDataIntoTopLevel)
			respondLogical(core, w, r, req, resp, injectDataIntoTopLevel)
			return
		}
	})
}

func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request, resp *logical.Response, injectDataIntoTopLevel bool) {
func respondLogical(core *vault.Core, w http.ResponseWriter, r *http.Request, req *logical.Request, resp *logical.Response, injectDataIntoTopLevel bool) {
	var httpResp *logical.HTTPResponse
	var ret interface{}

@@ -509,6 +515,8 @@ func respondLogical(w http.ResponseWriter, r *http.Request, req *logical.Request
		}
	}

	adjustResponse(core, w, req)

	// Respond
	respondOk(w, ret)
	return

@@ -369,7 +369,7 @@ func TestLogical_RespondWithStatusCode(t *testing.T) {
	}

	w := httptest.NewRecorder()
	respondLogical(w, nil, nil, resp404, false)
	respondLogical(nil, w, nil, nil, resp404, false)

	if w.Code != 404 {
		t.Fatalf("Bad Status code: %d", w.Code)
@@ -28,6 +28,8 @@ var (
	additionalRoutes = func(mux *http.ServeMux, core *vault.Core) {}

	nonVotersAllowed = false

	adjustResponse = func(core *vault.Core, w http.ResponseWriter, req *logical.Request) {}
)

func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
@@ -77,8 +77,8 @@ type FSM struct {
	logger      log.Logger
	noopRestore bool

	// applyDelay is used to simulate a slow apply in tests
	applyDelay time.Duration
	// applyCallback is used to control the pace of applies in tests
	applyCallback func()

	db *bolt.DB

@@ -131,8 +131,12 @@ func (f *FSM) getDB() *bolt.DB {
// SetFSMDelay adds a delay to the FSM apply. This is used in tests to simulate
// a slow apply.
func (r *RaftBackend) SetFSMDelay(delay time.Duration) {
	r.SetFSMApplyCallback(func() { time.Sleep(delay) })
}

func (r *RaftBackend) SetFSMApplyCallback(f func()) {
	r.fsm.l.Lock()
	r.fsm.applyDelay = delay
	r.fsm.applyCallback = f
	r.fsm.l.Unlock()
}

@@ -469,8 +473,8 @@ func (f *FSM) ApplyBatch(logs []*raft.Log) []interface{} {
	f.l.RLock()
	defer f.l.RUnlock()

	if f.applyDelay > 0 {
		time.Sleep(f.applyDelay)
	if f.applyCallback != nil {
		f.applyCallback()
	}

	err = f.db.Update(func(tx *bolt.Tx) error {
@@ -263,6 +263,16 @@ func NewRaftBackend(conf map[string]string, logger log.Logger) (physical.Backend
		return nil, fmt.Errorf("failed to create fsm: %v", err)
	}

	if delayRaw, ok := conf["apply_delay"]; ok {
		delay, err := time.ParseDuration(delayRaw)
		if err != nil {
			return nil, fmt.Errorf("apply_delay does not parse as a duration: %w", err)
		}
		fsm.applyCallback = func() {
			time.Sleep(delay)
		}
	}

	// Build an all in-memory setup for dev mode, otherwise prepare a full
	// disk-based setup.
	var log raft.LogStore
@@ -1,6 +1,7 @@
package logical

import (
	"context"
	"fmt"
	"net/http"
	"strings"

@@ -54,6 +55,30 @@ const (
	ClientTokenFromAuthzHeader
)

type WALState struct {
	ClusterID       string
	LocalIndex      uint64
	ReplicatedIndex uint64
}

const indexStateCtxKey = "index_state"

// IndexStateContext returns a context with an added value holding the index
// state that should be populated on writes.
func IndexStateContext(ctx context.Context, state *WALState) context.Context {
	return context.WithValue(ctx, indexStateCtxKey, state)
}

// IndexStateFromContext is a helper to look up if the provided context contains
// an index state pointer.
func IndexStateFromContext(ctx context.Context) *WALState {
	s, ok := ctx.Value(indexStateCtxKey).(*WALState)
	if !ok {
		return nil
	}
	return s
}

// Request is a struct that stores the parameters and context of a request
// being made to Vault. It is used to abstract the details of the higher level
// request protocol from the handlers.

@@ -179,6 +204,11 @@ type Request struct {
	// ResponseWriter if set can be used to stream a response value to the http
	// request that generated this logical.Request object.
	ResponseWriter *HTTPResponseWriter `json:"-" sentinel:""`

	// responseState is used internally to propagate the state that should appear
	// in response headers; it's attached to the request rather than the response
	// because not all requests yield non-nil responses.
	responseState *WALState
}

// Clone returns a deep copy of the request by using copystructure

@@ -243,6 +273,14 @@ func (r *Request) SetLastRemoteWAL(last uint64) {
	r.lastRemoteWAL = last
}

func (r *Request) ResponseState() *WALState {
	return r.responseState
}

func (r *Request) SetResponseState(w *WALState) {
	r.responseState = w
}

func (r *Request) TokenEntry() *TokenEntry {
	return r.tokenEntry
}
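A hedged sketch of the producer side of this context plumbing: a storage-layer component that knows its WAL position could publish it through the context so the HTTP layer can emit it as an X-Vault-Index header. recordWALPosition and its arguments are hypothetical; only IndexStateFromContext and WALState come from the additions above.

package example

import (
	"context"

	"github.com/hashicorp/vault/sdk/logical"
)

// recordWALPosition fills in the WALState pointer placed in the context by
// IndexStateContext, if one is present.
func recordWALPosition(ctx context.Context, localIndex, replicatedIndex uint64) {
	if state := logical.IndexStateFromContext(ctx); state != nil {
		state.LocalIndex = localIndex
		state.ReplicatedIndex = replicatedIndex
	}
}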
@@ -280,6 +280,18 @@ func (c *Core) setupCluster(ctx context.Context) error {
		}
	}

	c.clusterID.Store(cluster.ID)
	return nil
}

func (c *Core) loadCluster(ctx context.Context) error {
	cluster, err := c.Cluster(ctx)
	if err != nil {
		c.logger.Error("failed to get cluster details", "error", err)
		return err
	}

	c.clusterID.Store(cluster.ID)
	return nil
}
@@ -3,10 +3,14 @@ package vault

import (
	"context"
	"crypto/ecdsa"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"

@@ -16,6 +20,8 @@ import (
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

@@ -48,6 +54,7 @@ import (
	"github.com/hashicorp/vault/vault/quotas"
	vaultseal "github.com/hashicorp/vault/vault/seal"
	"github.com/patrickmn/go-cache"
	uberAtomic "go.uber.org/atomic"
	"google.golang.org/grpc"
)

@@ -72,6 +79,8 @@ const (
	// clusters that they need to perform a rekey operation synchronously; this
	// isn't keyring-canary to avoid ignoring it when ignoring core/keyring
	coreKeyringCanaryPath = "core/canary-keyring"

	indexHeaderHMACKeyPath = "core/index-header-hmac-key"
)

var (

@@ -108,6 +117,7 @@ var (
	PerformanceMerkleRoot = merkleRootImpl
	DRMerkleRoot          = merkleRootImpl
	LastRemoteWAL         = lastRemoteWALImpl
	LastRemoteUpstreamWAL = lastRemoteUpstreamWALImpl
	WaitUntilWALShipped   = waitUntilWALShippedImpl
)

@@ -391,6 +401,8 @@ type Core struct {
	//
	// Name
	clusterName string
	// ID
	clusterID uberAtomic.String
	// Specific cipher suites to use for clustering, if any
	clusterCipherSuites []uint16
	// Used to modify cluster parameters

@@ -546,6 +558,8 @@ type Core struct {

	// number of workers to use for lease revocation in the expiration manager
	numExpirationWorkers int

	IndexHeaderHMACKey uberAtomic.Value
}

// CoreConfig is used to parameterize a core

@@ -2201,6 +2215,10 @@ func lastRemoteWALImpl(c *Core) uint64 {
	return 0
}

func lastRemoteUpstreamWALImpl(c *Core) uint64 {
	return 0
}

func (c *Core) PhysicalSealConfigs(ctx context.Context) (*SealConfig, *SealConfig, error) {
	pe, err := c.physical.Get(ctx, barrierSealConfigPath)
	if err != nil {

@@ -2542,6 +2560,10 @@ func (c *Core) IsDRSecondary() bool {
	return c.ReplicationState().HasState(consts.ReplicationDRSecondary)
}

func (c *Core) IsPerfSecondary() bool {
	return c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary)
}

func (c *Core) AddLogger(logger log.Logger) {
	c.allLoggersLock.Lock()
	defer c.allLoggersLock.Unlock()

@@ -2705,3 +2727,48 @@ func (c *Core) KeyRotateGracePeriod() time.Duration {
func (c *Core) SetKeyRotateGracePeriod(t time.Duration) {
	atomic.StoreInt64(c.keyRotateGracePeriod, int64(t))
}

func ParseRequiredState(raw string, hmacKey []byte) (*logical.WALState, error) {
	cooked, err := base64.StdEncoding.DecodeString(raw)
	if err != nil {
		return nil, err
	}
	s := string(cooked)

	lastIndex := strings.LastIndexByte(s, ':')
	if lastIndex == -1 {
		return nil, fmt.Errorf("invalid state header format")
	}
	state, stateHMACRaw := s[:lastIndex], s[lastIndex+1:]
	stateHMAC, err := hex.DecodeString(stateHMACRaw)
	if err != nil {
		return nil, fmt.Errorf("invalid state header HMAC: %v, %w", stateHMACRaw, err)
	}

	if len(hmacKey) != 0 {
		hm := hmac.New(sha256.New, hmacKey)
		hm.Write([]byte(state))
		if !hmac.Equal(hm.Sum(nil), stateHMAC) {
			return nil, fmt.Errorf("invalid state header HMAC (mismatch)")
		}
	}

	pieces := strings.Split(state, ":")
	if len(pieces) != 4 || pieces[0] != "v1" || pieces[1] == "" {
		return nil, fmt.Errorf("invalid state header format")
	}
	localIndex, err := strconv.ParseUint(pieces[2], 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid state header format")
	}
	replicatedIndex, err := strconv.ParseUint(pieces[3], 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid state header format")
	}

	return &logical.WALState{
		ClusterID:       pieces[1],
		LocalIndex:      localIndex,
		ReplicatedIndex: replicatedIndex,
	}, nil
}
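As a worked example of the format ParseRequiredState expects: the X-Vault-Index value is base64 of "v1:<cluster ID>:<local index>:<replicated index>:<hex HMAC of the preceding fields>". The sketch below is illustrative only; the cluster ID, indexes, and HMAC key are made up, and only the format and the SHA-256 HMAC check mirror the function above.

package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
)

func main() {
	// Placeholder key and state; in Vault the key would come from
	// core/index-header-hmac-key.
	key := []byte("example-index-header-hmac-key")
	state := "v1:88a39b57-5a82-4e5c-a239-5e2e27d8a6e5:123:77"

	hm := hmac.New(sha256.New, key)
	hm.Write([]byte(state))
	signed := state + ":" + hex.EncodeToString(hm.Sum(nil))

	header := base64.StdEncoding.EncodeToString([]byte(signed))
	fmt.Println("X-Vault-Index:", header)

	// vault.ParseRequiredState(header, key) would recover ClusterID,
	// LocalIndex=123, and ReplicatedIndex=77 from this value.
}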
@@ -158,3 +158,11 @@ func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction
func (c *Core) namespaceByPath(path string) *namespace.Namespace {
	return namespace.RootNamespace
}

func (c *Core) AllowForwardingViaHeader() bool {
	return false
}

func (c *Core) MissingRequiredState(raw []string) bool {
	return false
}
@@ -52,6 +52,7 @@ func (e extendedSystemViewImpl) ForwardGenericRequest(ctx context.Context, req *
	// Forward the request if allowed
	if couldForward(e.core) {
		ctx = namespace.ContextWithNamespace(ctx, e.mountEntry.Namespace())
		ctx = logical.IndexStateContext(ctx, &logical.WALState{})
		ctx = context.WithValue(ctx, ctxKeyForwardedRequestMountAccessor{}, e.mountEntry.Accessor)
		return forward(ctx, e.core, req)
	}
@@ -27,7 +27,7 @@ const (

var (
	caseSensitivityKey = "casesensitivity"
	sendGroupUpgrade = func(*IdentityStore, *identity.Group) (bool, error) { return false, nil }
	sendGroupUpgrade = func(context.Context, *IdentityStore, *identity.Group) (bool, error) { return false, nil }
	parseExtraEntityFromBucket = func(context.Context, *IdentityStore, *identity.Entity) (bool, error) { return false, nil }
	addExtraEntityDataToResponse = func(*identity.Entity, map[string]interface{}) {}
)

@@ -210,7 +210,7 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
	}

	// Get the storage bucket entry
	bucket, err := i.entityPacker.GetBucket(key)
	bucket, err := i.entityPacker.GetBucket(ctx, key)
	if err != nil {
		i.logger.Error("failed to refresh entities", "key", key, "error", err)
		return

@@ -273,7 +273,7 @@ func (i *IdentityStore) Invalidate(ctx context.Context, key string) {
	}

	// Get the storage bucket entry
	bucket, err := i.groupPacker.GetBucket(key)
	bucket, err := i.groupPacker.GetBucket(ctx, key)
	if err != nil {
		i.logger.Error("failed to refresh group", "key", key, "error", err)
		return
@@ -569,6 +569,8 @@ func (i *IdentityStore) handleEntityBatchDelete() framework.OperationFunc {
	}
}

// handleEntityDeleteCommon deletes an entity by removing it from groups of
// which it's a member and then, if update is true, deleting the entity itself.
func (i *IdentityStore) handleEntityDeleteCommon(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, update bool) error {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
@@ -89,7 +89,7 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error {
	i.logger.Debug("groups collected", "num_existing", len(existing))

	for _, key := range existing {
		bucket, err := i.groupPacker.GetBucket(groupBucketsPrefix + key)
		bucket, err := i.groupPacker.GetBucket(ctx, groupBucketsPrefix+key)
		if err != nil {
			return err
		}

@@ -124,7 +124,7 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error {
			}
			continue
		}
		nsCtx := namespace.ContextWithNamespace(context.Background(), ns)
		nsCtx := namespace.ContextWithNamespace(ctx, ns)

		// Ensure that there are no groups with duplicate names
		groupByName, err := i.MemDBGroupByName(nsCtx, group.Name, false)

@@ -212,7 +212,7 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error {
				return
			}

			bucket, err := i.entityPacker.GetBucket(storagepacker.StoragePackerBucketsPrefix + key)
			bucket, err := i.entityPacker.GetBucket(ctx, storagepacker.StoragePackerBucketsPrefix+key)
			if err != nil {
				errs <- err
				continue

@@ -292,7 +292,7 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error {
			}
			continue
		}
		nsCtx := namespace.ContextWithNamespace(context.Background(), ns)
		nsCtx := namespace.ContextWithNamespace(ctx, ns)

		// Ensure that there are no entities with duplicate names
		entityByName, err := i.MemDBEntityByName(nsCtx, entity.Name, false)

@@ -1437,7 +1437,7 @@ func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, gr
		Message: groupAsAny,
	}

	sent, err := sendGroupUpgrade(i, group)
	sent, err := sendGroupUpgrade(ctx, i, group)
	if err != nil {
		return err
	}
@@ -457,6 +457,8 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp
		return nil, logical.CodedError(403, "namespaces feature not enabled")
	}

	var walState = &logical.WALState{}
	ctx = logical.IndexStateContext(ctx, walState)
	var auth *logical.Auth
	if c.router.LoginPath(ctx, req.Path) {
		resp, auth, err = c.handleLoginRequest(ctx, req)

@@ -564,6 +566,26 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp
		}
	}

	if walState.LocalIndex != 0 || walState.ReplicatedIndex != 0 {
		walState.ClusterID = c.clusterID.Load()
		if walState.LocalIndex == 0 {
			if c.perfStandby {
				walState.LocalIndex = LastRemoteWAL(c)
			} else {
				walState.LocalIndex = LastWAL(c)
			}
		}
		if walState.ReplicatedIndex == 0 {
			if c.perfStandby {
				walState.ReplicatedIndex = LastRemoteUpstreamWAL(c)
			} else {
				walState.ReplicatedIndex = LastRemoteWAL(c)
			}
		}

		req.SetResponseState(walState)
	}

	return
}
|
|||
return nil
|
||||
}
|
||||
|
||||
// wrapInCubbyhole is invoked when a caller asks for response wrapping.
|
||||
// On success, return (nil, nil) and mutates resp. On failure, returns
|
||||
// either a response describing the failure or an error.
|
||||
func (c *Core) wrapInCubbyhole(ctx context.Context, req *logical.Request, resp *logical.Response, auth *logical.Auth) (*logical.Response, error) {
|
||||
if c.perfStandby {
|
||||
return forwardWrapRequest(ctx, c, req, resp, auth)
|
||||
|
|
|