From d648e6511e254fee02ba2a8b42652d748e20673e Mon Sep 17 00:00:00 2001 From: Conor Mongey Date: Thu, 7 Jan 2021 23:48:53 +0000 Subject: [PATCH] Move header methods from config to client --- api/api.go | 49 ++++++++++++++++++++++++++++++++++++++++++------- api/api_test.go | 26 ++++++++++++++++++++------ 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/api/api.go b/api/api.go index 04882fe51..dbecaa5af 100644 --- a/api/api.go +++ b/api/api.go @@ -14,6 +14,7 @@ import ( "os" "strconv" "strings" + "sync" "time" "github.com/hashicorp/go-cleanhttp" @@ -314,8 +315,6 @@ type Config struct { Namespace string TLSConfig TLSConfig - - Header http.Header } // TLSConfig is used to generate a TLSClientConfig that's useful for talking to @@ -550,9 +549,48 @@ func (c *Config) GenerateEnv() []string { // Client provides a client to the Consul API type Client struct { + modifyLock sync.RWMutex + headers http.Header + config Config } +// Headers gets the current set of headers used for requests. This returns a +// copy; to modify it call AddHeader or SetHeaders. +func (c *Client) Headers() http.Header { + c.modifyLock.RLock() + defer c.modifyLock.RUnlock() + + if c.headers == nil { + return nil + } + + ret := make(http.Header) + for k, v := range c.headers { + for _, val := range v { + ret[k] = append(ret[k], val) + } + } + + return ret +} + +// AddHeader allows a single header key/value pair to be added +// in a race-safe fashion. +func (c *Client) AddHeader(key, value string) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.headers.Add(key, value) +} + +// SetHeaders clears all previous headers and uses only the given +// ones going forward. +func (c *Client) SetHeaders(headers http.Header) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.headers = headers +} + // NewClient returns a new client func NewClient(config *Config) (*Client, error) { // bootstrap the config @@ -642,7 +680,7 @@ func NewClient(config *Config) (*Client, error) { config.Token = defConfig.Token } - return &Client{config: *config}, nil + return &Client{config: *config, headers: make(http.Header)}, nil } // NewHttpClient returns an http client configured with the given Transport and TLS @@ -855,12 +893,9 @@ func (c *Client) newRequest(method, path string) *request { Path: path, }, params: make(map[string][]string), - header: make(http.Header), + header: c.Headers(), } - if c.config.Header != nil { - r.header = c.config.Header - } if c.config.Datacenter != "" { r.params.Set("dc", c.config.Datacenter) } diff --git a/api/api_test.go b/api/api_test.go index 45ec02730..7b593b22e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -810,17 +810,31 @@ func TestAPI_SetWriteOptions(t *testing.T) { func TestAPI_Headers(t *testing.T) { t.Parallel() - c, s := makeClientWithConfig(t, func(c *Config) { - c.Header = http.Header{ - "Hello": []string{"World"}, - } - }, nil) + c, s := makeClient(t) defer s.Stop() + if len(c.Headers()) != 0 { + t.Fatalf("expected headers to be empty: %v", c.Headers()) + } + + c.AddHeader("Hello", "World") r := c.newRequest("GET", "/v1/kv/foo") if r.header.Get("Hello") != "World" { - t.Fatalf("bad: %v", r.header) + t.Fatalf("Hello header not set : %v", r.header) + } + + c.SetHeaders(http.Header{ + "Auth": []string{"Token"}, + }) + + r = c.newRequest("GET", "/v1/kv/foo") + if r.header.Get("Hello") != "" { + t.Fatalf("Hello header should not be set: %v", r.header) + } + + if r.header.Get("Auth") != "Token" { + t.Fatalf("Auth header not set: %v", r.header) } }