From e899e2adfa9638e064d8e4205aa3202cd831534f Mon Sep 17 00:00:00 2001 From: Ben Ash <32777270+benashz@users.noreply.github.com> Date: Mon, 19 Jul 2021 17:15:31 -0400 Subject: [PATCH] Add ability to optionally clone an api.Client's headers (#12117) --- api/client.go | 29 +++++++++ api/client_test.go | 143 +++++++++++++++++++++++++++++--------------- changelog/12117.txt | 3 + 3 files changed, 126 insertions(+), 49 deletions(-) create mode 100644 changelog/12117.txt diff --git a/api/client.go b/api/client.go index b7282dbaf..870dee776 100644 --- a/api/client.go +++ b/api/client.go @@ -125,6 +125,9 @@ type Config struct { // SRVLookup enables the client to lookup the host through DNS SRV lookup SRVLookup bool + + // CloneHeaders ensures that the source client's headers are copied to its clone. + CloneHeaders bool } // TLSConfig contains the parameters needed to configure TLS on the HTTP client @@ -504,6 +507,7 @@ func (c *Client) CloneConfig() *Config { newConfig.Limiter = c.config.Limiter newConfig.OutputCurlString = c.config.OutputCurlString newConfig.SRVLookup = c.config.SRVLookup + newConfig.CloneHeaders = c.config.CloneHeaders // we specifically want a _copy_ of the client here, not a pointer to the original one newClient := *c.config.HttpClient @@ -809,6 +813,26 @@ func (c *Client) SetLogger(logger retryablehttp.LeveledLogger) { c.config.Logger = logger } +// SetCloneHeaders to allow headers to be copied whenever the client is cloned. +func (c *Client) SetCloneHeaders(cloneHeaders bool) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.config.modifyLock.Lock() + defer c.config.modifyLock.Unlock() + + c.config.CloneHeaders = cloneHeaders +} + +// CloneHeaders gets the configured CloneHeaders value. +func (c *Client) CloneHeaders() bool { + c.modifyLock.RLock() + defer c.modifyLock.RUnlock() + c.config.modifyLock.RLock() + defer c.config.modifyLock.RUnlock() + + return c.config.CloneHeaders +} + // Clone creates a new client with the same configuration. Note that the same // underlying http.Client is used; modifying the client from more than one // goroutine at once may not be safe, so modify the client as needed and then @@ -839,12 +863,17 @@ func (c *Client) Clone() (*Client, error) { OutputCurlString: config.OutputCurlString, AgentAddress: config.AgentAddress, SRVLookup: config.SRVLookup, + CloneHeaders: config.CloneHeaders, } client, err := NewClient(newConfig) if err != nil { return nil, err } + if config.CloneHeaders { + client.SetHeaders(c.Headers().Clone()) + } + return client, nil } diff --git a/api/client_test.go b/api/client_test.go index 4fe356e1b..474db04e1 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "os" + "reflect" "strings" "testing" "time" @@ -409,63 +410,107 @@ func TestClientNonTransportRoundTripper(t *testing.T) { } func TestClone(t *testing.T) { - client1, err := NewClient(DefaultConfig()) - if err != nil { - t.Fatalf("NewClient failed: %v", err) + type fields struct { + } + tests := []struct { + name string + config *Config + headers *http.Header + }{ + { + name: "default", + config: DefaultConfig(), + }, + { + name: "cloneHeaders", + config: &Config{ + CloneHeaders: true, + }, + headers: &http.Header{ + "X-foo": []string{"bar"}, + "X-baz": []string{"qux"}, + }, + }, } - // Set all of the things that we provide setter methods for, which modify config values - err = client1.SetAddress("http://example.com:8080") - if err != nil { - t.Fatalf("SetAddress failed: %v", err) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client1, err := NewClient(tt.config) + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } - clientTimeout := time.Until(time.Now().AddDate(0, 0, 1)) - client1.SetClientTimeout(clientTimeout) + // Set all of the things that we provide setter methods for, which modify config values + err = client1.SetAddress("http://example.com:8080") + if err != nil { + t.Fatalf("SetAddress failed: %v", err) + } - checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) { - return true, nil - } - client1.SetCheckRetry(checkRetry) + clientTimeout := time.Until(time.Now().AddDate(0, 0, 1)) + client1.SetClientTimeout(clientTimeout) - client1.SetLogger(hclog.NewNullLogger()) + checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) { + return true, nil + } + client1.SetCheckRetry(checkRetry) - client1.SetLimiter(5.0, 10) - client1.SetMaxRetries(5) - client1.SetOutputCurlString(true) - client1.SetSRVLookup(true) + client1.SetLogger(hclog.NewNullLogger()) - client2, err := client1.Clone() - if err != nil { - t.Fatalf("Clone failed: %v", err) - } + client1.SetLimiter(5.0, 10) + client1.SetMaxRetries(5) + client1.SetOutputCurlString(true) + client1.SetSRVLookup(true) - if client1.Address() != client2.Address() { - t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address()) - } - if client1.ClientTimeout() != client2.ClientTimeout() { - t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout()) - } - if client1.CheckRetry() != nil && client2.CheckRetry() == nil { - t.Fatal("checkRetry functions don't match. client2 is nil.") - } - if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) { - t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter()) - } - if client1.Limiter().Limit() != client2.Limiter().Limit() { - t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit()) - } - if client1.Limiter().Burst() != client2.Limiter().Burst() { - t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst()) - } - if client1.MaxRetries() != client2.MaxRetries() { - t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries()) - } - if client1.OutputCurlString() != client2.OutputCurlString() { - t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString()) - } - if client1.SRVLookup() != client2.SRVLookup() { - t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup()) + if tt.headers != nil { + client1.SetHeaders(*tt.headers) + } + + client2, err := client1.Clone() + if err != nil { + t.Fatalf("Clone failed: %v", err) + } + + if client1.Address() != client2.Address() { + t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address()) + } + if client1.ClientTimeout() != client2.ClientTimeout() { + t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout()) + } + if client1.CheckRetry() != nil && client2.CheckRetry() == nil { + t.Fatal("checkRetry functions don't match. client2 is nil.") + } + if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) { + t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter()) + } + if client1.Limiter().Limit() != client2.Limiter().Limit() { + t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit()) + } + if client1.Limiter().Burst() != client2.Limiter().Burst() { + t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst()) + } + if client1.MaxRetries() != client2.MaxRetries() { + t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries()) + } + if client1.OutputCurlString() != client2.OutputCurlString() { + t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString()) + } + if client1.SRVLookup() != client2.SRVLookup() { + t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup()) + } + if tt.config.CloneHeaders { + if !reflect.DeepEqual(client1.Headers(), client2.Headers()) { + t.Fatalf("Headers() don't match: %v vs %v", client1.Headers(), client2.Headers()) + } + if client1.config.CloneHeaders != client2.config.CloneHeaders { + t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", client1.config.CloneHeaders, client2.config.CloneHeaders) + } + if tt.headers != nil { + if !reflect.DeepEqual(*tt.headers, client2.Headers()) { + t.Fatalf("expected headers %v, actual %v", *tt.headers, client2.Headers()) + } + } + } + }) } } diff --git a/changelog/12117.txt b/changelog/12117.txt new file mode 100644 index 000000000..081a587e6 --- /dev/null +++ b/changelog/12117.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Allow cloning `api.Client` HTTP headers via `api.Config.CloneHeaders` or `api.Client.SetCloneHeaders`. +```