Redo API client locking (#4551)

* Redo API client locking

This assigns local values when in critical paths, allowing a single API
client to much more quickly and safely pipeline requests.

Additionally, in order to take that paradigm all the way it changes how
timeouts are set. It now uses a context value set on the request instead
of configuring the timeout in the http client per request, which was
also potentially quite racy.

Trivially tested with
VAULT_CLIENT_TIMEOUT=2 vault write pki/root/generate/internal key_type=rsa key_bits=8192
This commit is contained in:
Jeff Mitchell 2018-05-25 14:38:06 -04:00 committed by GitHub
parent 94ae5d2567
commit 35cb9bc517
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 45 deletions

View File

@ -388,11 +388,12 @@ func (c *Client) SetAddress(addr string) error {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
var err error
if c.addr, err = url.Parse(addr); err != nil {
parsedAddr, err := url.Parse(addr)
if err != nil {
return errwrap.Wrapf("failed to set address: {{err}}", err)
}
c.addr = parsedAddr
return nil
}
@ -411,7 +412,8 @@ func (c *Client) SetLimiter(rateLimit float64, burst int) {
c.modifyLock.RLock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()
defer c.modifyLock.RUnlock()
c.modifyLock.RUnlock()
c.config.Limiter = rate.NewLimiter(rate.Limit(rateLimit), burst)
}
@ -544,14 +546,20 @@ func (c *Client) SetPolicyOverride(override bool) {
// doesn't need to be called externally.
func (c *Client) NewRequest(method, requestPath string) *Request {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
addr := c.addr
token := c.token
mfaCreds := c.mfaCreds
wrappingLookupFunc := c.wrappingLookupFunc
headers := c.headers
policyOverride := c.policyOverride
c.modifyLock.RUnlock()
// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV
// record and take the highest match; this is not designed for high-availability, just discovery
var host string = c.addr.Host
if c.addr.Port() == "" {
var host string = addr.Host
if addr.Port() == "" {
// Internet Draft specifies that the SRV record is ignored if a port is given
_, addrs, err := net.LookupSRV("http", "tcp", c.addr.Hostname())
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())
if err == nil && len(addrs) > 0 {
host = fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)
}
@ -560,12 +568,12 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
req := &Request{
Method: method,
URL: &url.URL{
User: c.addr.User,
Scheme: c.addr.Scheme,
User: addr.User,
Scheme: addr.Scheme,
Host: host,
Path: path.Join(c.addr.Path, requestPath),
Path: path.Join(addr.Path, requestPath),
},
ClientToken: c.token,
ClientToken: token,
Params: make(map[string][]string),
}
@ -579,21 +587,19 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
lookupPath = requestPath
}
req.MFAHeaderVals = c.mfaCreds
req.MFAHeaderVals = mfaCreds
if c.wrappingLookupFunc != nil {
req.WrapTTL = c.wrappingLookupFunc(method, lookupPath)
if wrappingLookupFunc != nil {
req.WrapTTL = wrappingLookupFunc(method, lookupPath)
} else {
req.WrapTTL = DefaultWrappingLookupFunc(method, lookupPath)
}
if c.config.Timeout != 0 {
c.config.HttpClient.Timeout = c.config.Timeout
}
if c.headers != nil {
req.Headers = c.headers
if headers != nil {
req.Headers = headers
}
req.PolicyOverride = c.policyOverride
req.PolicyOverride = policyOverride
return req
}
@ -602,18 +608,23 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequest(r *Request) (*Response, error) {
c.modifyLock.RLock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()
if c.config.Limiter != nil {
c.config.Limiter.Wait(context.Background())
}
token := c.token
c.config.modifyLock.RLock()
limiter := c.config.Limiter
maxRetries := c.config.MaxRetries
backoff := c.config.Backoff
httpClient := c.config.HttpClient
timeout := c.config.Timeout
c.config.modifyLock.RUnlock()
c.modifyLock.RUnlock()
if limiter != nil {
limiter.Wait(context.Background())
}
// Sanity check the token before potentially erroring from the API
idx := strings.IndexFunc(token, func(c rune) bool {
return !unicode.IsPrint(c)
@ -632,16 +643,23 @@ START:
return nil, fmt.Errorf("nil request created")
}
backoff := c.config.Backoff
// Set the timeout, if any
var cancelFunc context.CancelFunc
if timeout != 0 {
var ctx context.Context
ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
req.Request = req.Request.WithContext(ctx)
}
if backoff == nil {
backoff = retryablehttp.LinearJitterBackoff
}
client := &retryablehttp.Client{
HTTPClient: c.config.HttpClient,
HTTPClient: httpClient,
RetryWaitMin: 1000 * time.Millisecond,
RetryWaitMax: 1500 * time.Millisecond,
RetryMax: c.config.MaxRetries,
RetryMax: maxRetries,
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: backoff,
ErrorHandler: retryablehttp.PassthroughErrorHandler,
@ -649,6 +667,9 @@ START:
var result *Response
resp, err := client.Do(req)
if cancelFunc != nil {
cancelFunc()
}
if resp != nil {
result = &Response{Response: resp}
}

View File

@ -7,7 +7,6 @@ import (
"os"
"strings"
"testing"
"time"
)
func init() {
@ -244,22 +243,10 @@ func TestClientTimeoutSetting(t *testing.T) {
defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)
config := DefaultConfig()
config.ReadEnvironment()
client, err := NewClient(config)
_, err := NewClient(config)
if err != nil {
t.Fatal(err)
}
_ = client.NewRequest("PUT", "/")
if client.config.HttpClient.Timeout != time.Second*10 {
t.Fatalf("error setting client timeout using env variable")
}
// Setting custom client timeout for a new request
client.SetClientTimeout(time.Second * 20)
_ = client.NewRequest("PUT", "/")
if client.config.HttpClient.Timeout != time.Second*20 {
t.Fatalf("error setting client timeout using SetClientTimeout")
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)