Redo API client locking (#4551)
* Redo API client locking This assigns local values when in critical paths, allowing a single API client to much more quickly and safely pipeline requests. Additionally, in order to take that paradigm all the way it changes how timeouts are set. It now uses a context value set on the request instead of configuring the timeout in the http client per request, which was also potentially quite racy. Trivially tested with VAULT_CLIENT_TIMEOUT=2 vault write pki/root/generate/internal key_type=rsa key_bits=8192
This commit is contained in:
parent
94ae5d2567
commit
35cb9bc517
|
@ -388,11 +388,12 @@ func (c *Client) SetAddress(addr string) error {
|
|||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
var err error
|
||||
if c.addr, err = url.Parse(addr); err != nil {
|
||||
parsedAddr, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to set address: {{err}}", err)
|
||||
}
|
||||
|
||||
c.addr = parsedAddr
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -411,7 +412,8 @@ func (c *Client) SetLimiter(rateLimit float64, burst int) {
|
|||
c.modifyLock.RLock()
|
||||
c.config.modifyLock.Lock()
|
||||
defer c.config.modifyLock.Unlock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
c.modifyLock.RUnlock()
|
||||
|
||||
c.config.Limiter = rate.NewLimiter(rate.Limit(rateLimit), burst)
|
||||
}
|
||||
|
||||
|
@ -544,14 +546,20 @@ func (c *Client) SetPolicyOverride(override bool) {
|
|||
// doesn't need to be called externally.
|
||||
func (c *Client) NewRequest(method, requestPath string) *Request {
|
||||
c.modifyLock.RLock()
|
||||
defer c.modifyLock.RUnlock()
|
||||
addr := c.addr
|
||||
token := c.token
|
||||
mfaCreds := c.mfaCreds
|
||||
wrappingLookupFunc := c.wrappingLookupFunc
|
||||
headers := c.headers
|
||||
policyOverride := c.policyOverride
|
||||
c.modifyLock.RUnlock()
|
||||
|
||||
// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV
|
||||
// record and take the highest match; this is not designed for high-availability, just discovery
|
||||
var host string = c.addr.Host
|
||||
if c.addr.Port() == "" {
|
||||
var host string = addr.Host
|
||||
if addr.Port() == "" {
|
||||
// Internet Draft specifies that the SRV record is ignored if a port is given
|
||||
_, addrs, err := net.LookupSRV("http", "tcp", c.addr.Hostname())
|
||||
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())
|
||||
if err == nil && len(addrs) > 0 {
|
||||
host = fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)
|
||||
}
|
||||
|
@ -560,12 +568,12 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
|
|||
req := &Request{
|
||||
Method: method,
|
||||
URL: &url.URL{
|
||||
User: c.addr.User,
|
||||
Scheme: c.addr.Scheme,
|
||||
User: addr.User,
|
||||
Scheme: addr.Scheme,
|
||||
Host: host,
|
||||
Path: path.Join(c.addr.Path, requestPath),
|
||||
Path: path.Join(addr.Path, requestPath),
|
||||
},
|
||||
ClientToken: c.token,
|
||||
ClientToken: token,
|
||||
Params: make(map[string][]string),
|
||||
}
|
||||
|
||||
|
@ -579,21 +587,19 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
|
|||
lookupPath = requestPath
|
||||
}
|
||||
|
||||
req.MFAHeaderVals = c.mfaCreds
|
||||
req.MFAHeaderVals = mfaCreds
|
||||
|
||||
if c.wrappingLookupFunc != nil {
|
||||
req.WrapTTL = c.wrappingLookupFunc(method, lookupPath)
|
||||
if wrappingLookupFunc != nil {
|
||||
req.WrapTTL = wrappingLookupFunc(method, lookupPath)
|
||||
} else {
|
||||
req.WrapTTL = DefaultWrappingLookupFunc(method, lookupPath)
|
||||
}
|
||||
if c.config.Timeout != 0 {
|
||||
c.config.HttpClient.Timeout = c.config.Timeout
|
||||
}
|
||||
if c.headers != nil {
|
||||
req.Headers = c.headers
|
||||
|
||||
if headers != nil {
|
||||
req.Headers = headers
|
||||
}
|
||||
|
||||
req.PolicyOverride = c.policyOverride
|
||||
req.PolicyOverride = policyOverride
|
||||
|
||||
return req
|
||||
}
|
||||
|
@ -602,18 +608,23 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
|
|||
// a Vault server not configured with this client. This is an advanced operation
|
||||
// that generally won't need to be called externally.
|
||||
func (c *Client) RawRequest(r *Request) (*Response, error) {
|
||||
|
||||
c.modifyLock.RLock()
|
||||
c.config.modifyLock.RLock()
|
||||
defer c.config.modifyLock.RUnlock()
|
||||
|
||||
if c.config.Limiter != nil {
|
||||
c.config.Limiter.Wait(context.Background())
|
||||
}
|
||||
|
||||
token := c.token
|
||||
|
||||
c.config.modifyLock.RLock()
|
||||
limiter := c.config.Limiter
|
||||
maxRetries := c.config.MaxRetries
|
||||
backoff := c.config.Backoff
|
||||
httpClient := c.config.HttpClient
|
||||
timeout := c.config.Timeout
|
||||
c.config.modifyLock.RUnlock()
|
||||
|
||||
c.modifyLock.RUnlock()
|
||||
|
||||
if limiter != nil {
|
||||
limiter.Wait(context.Background())
|
||||
}
|
||||
|
||||
// Sanity check the token before potentially erroring from the API
|
||||
idx := strings.IndexFunc(token, func(c rune) bool {
|
||||
return !unicode.IsPrint(c)
|
||||
|
@ -632,16 +643,23 @@ START:
|
|||
return nil, fmt.Errorf("nil request created")
|
||||
}
|
||||
|
||||
backoff := c.config.Backoff
|
||||
// Set the timeout, if any
|
||||
var cancelFunc context.CancelFunc
|
||||
if timeout != 0 {
|
||||
var ctx context.Context
|
||||
ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
|
||||
req.Request = req.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
if backoff == nil {
|
||||
backoff = retryablehttp.LinearJitterBackoff
|
||||
}
|
||||
|
||||
client := &retryablehttp.Client{
|
||||
HTTPClient: c.config.HttpClient,
|
||||
HTTPClient: httpClient,
|
||||
RetryWaitMin: 1000 * time.Millisecond,
|
||||
RetryWaitMax: 1500 * time.Millisecond,
|
||||
RetryMax: c.config.MaxRetries,
|
||||
RetryMax: maxRetries,
|
||||
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
||||
Backoff: backoff,
|
||||
ErrorHandler: retryablehttp.PassthroughErrorHandler,
|
||||
|
@ -649,6 +667,9 @@ START:
|
|||
|
||||
var result *Response
|
||||
resp, err := client.Do(req)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
}
|
||||
if resp != nil {
|
||||
result = &Response{Response: resp}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -244,22 +243,10 @@ func TestClientTimeoutSetting(t *testing.T) {
|
|||
defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)
|
||||
config := DefaultConfig()
|
||||
config.ReadEnvironment()
|
||||
client, err := NewClient(config)
|
||||
_, err := NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = client.NewRequest("PUT", "/")
|
||||
if client.config.HttpClient.Timeout != time.Second*10 {
|
||||
t.Fatalf("error setting client timeout using env variable")
|
||||
}
|
||||
|
||||
// Setting custom client timeout for a new request
|
||||
client.SetClientTimeout(time.Second * 20)
|
||||
_ = client.NewRequest("PUT", "/")
|
||||
if client.config.HttpClient.Timeout != time.Second*20 {
|
||||
t.Fatalf("error setting client timeout using SetClientTimeout")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
|
Loading…
Reference in New Issue