diff --git a/api/client.go b/api/client.go index b6c497d0b..475899ee4 100644 --- a/api/client.go +++ b/api/client.go @@ -139,6 +139,9 @@ type Config struct { // its clone. CloneHeaders bool + // CloneToken from parent. + CloneToken bool + // ReadYourWrites ensures isolated read-after-write semantics by // providing discovered cluster replication states in each request. // The shared state is automatically propagated to all Client clones. @@ -547,6 +550,7 @@ func (c *Client) CloneConfig() *Config { newConfig.OutputCurlString = c.config.OutputCurlString newConfig.SRVLookup = c.config.SRVLookup newConfig.CloneHeaders = c.config.CloneHeaders + newConfig.CloneToken = c.config.CloneToken newConfig.ReadYourWrites = c.config.ReadYourWrites // we specifically want a _copy_ of the client here, not a pointer to the original one @@ -873,6 +877,26 @@ func (c *Client) CloneHeaders() bool { return c.config.CloneHeaders } +// SetCloneToken from parent +func (c *Client) SetCloneToken(cloneToken bool) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.config.modifyLock.Lock() + defer c.config.modifyLock.Unlock() + + c.config.CloneToken = cloneToken +} + +// CloneToken gets the configured CloneToken value. +func (c *Client) CloneToken() bool { + c.modifyLock.RLock() + defer c.modifyLock.RUnlock() + c.config.modifyLock.RLock() + defer c.config.modifyLock.RUnlock() + + return c.config.CloneToken +} + // SetReadYourWrites to prevent reading stale cluster replication state. func (c *Client) SetReadYourWrites(preventStaleReads bool) { c.modifyLock.Lock() @@ -932,6 +956,7 @@ func (c *Client) Clone() (*Client, error) { AgentAddress: config.AgentAddress, SRVLookup: config.SRVLookup, CloneHeaders: config.CloneHeaders, + CloneToken: config.CloneToken, ReadYourWrites: config.ReadYourWrites, } client, err := NewClient(newConfig) @@ -943,6 +968,10 @@ func (c *Client) Clone() (*Client, error) { client.SetHeaders(c.Headers().Clone()) } + if config.CloneToken { + client.SetToken(c.token) + } + client.replicationStateStore = c.replicationStateStore return client, nil diff --git a/api/client_test.go b/api/client_test.go index 89d238c3e..30c6cad9d 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -420,6 +420,7 @@ func TestClone(t *testing.T) { name string config *Config headers *http.Header + token string }{ { name: "default", @@ -441,91 +442,119 @@ func TestClone(t *testing.T) { ReadYourWrites: true, }, }, + { + name: "cloneToken", + config: &Config{ + CloneToken: true, + }, + token: "cloneToken", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client1, err := NewClient(tt.config) + parent, err := NewClient(tt.config) if err != nil { t.Fatalf("NewClient failed: %v", err) } // Set all of the things that we provide setter methods for, which modify config values - err = client1.SetAddress("http://example.com:8080") + err = parent.SetAddress("http://example.com:8080") if err != nil { t.Fatalf("SetAddress failed: %v", err) } clientTimeout := time.Until(time.Now().AddDate(0, 0, 1)) - client1.SetClientTimeout(clientTimeout) + parent.SetClientTimeout(clientTimeout) checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) { return true, nil } - client1.SetCheckRetry(checkRetry) + parent.SetCheckRetry(checkRetry) - client1.SetLogger(hclog.NewNullLogger()) + parent.SetLogger(hclog.NewNullLogger()) - client1.SetLimiter(5.0, 10) - client1.SetMaxRetries(5) - client1.SetOutputCurlString(true) - client1.SetSRVLookup(true) + parent.SetLimiter(5.0, 10) + parent.SetMaxRetries(5) + parent.SetOutputCurlString(true) + parent.SetSRVLookup(true) if tt.headers != nil { - client1.SetHeaders(*tt.headers) + parent.SetHeaders(*tt.headers) } - client2, err := client1.Clone() + if tt.token != "" { + parent.SetToken(tt.token) + } + + clone, err := parent.Clone() if err != nil { t.Fatalf("Clone failed: %v", err) } - if client1.Address() != client2.Address() { - t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address()) + if parent.Address() != clone.Address() { + t.Fatalf("addresses don't match: %v vs %v", parent.Address(), clone.Address()) } - if client1.ClientTimeout() != client2.ClientTimeout() { - t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout()) + if parent.ClientTimeout() != clone.ClientTimeout() { + t.Fatalf("timeouts don't match: %v vs %v", parent.ClientTimeout(), clone.ClientTimeout()) } - if client1.CheckRetry() != nil && client2.CheckRetry() == nil { - t.Fatal("checkRetry functions don't match. client2 is nil.") + if parent.CheckRetry() != nil && clone.CheckRetry() == nil { + t.Fatal("checkRetry functions don't match. clone is nil.") } - if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) { - t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter()) + if (parent.Limiter() != nil && clone.Limiter() == nil) || (parent.Limiter() == nil && clone.Limiter() != nil) { + t.Fatalf("limiters don't match: %v vs %v", parent.Limiter(), clone.Limiter()) } - if client1.Limiter().Limit() != client2.Limiter().Limit() { - t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit()) + if parent.Limiter().Limit() != clone.Limiter().Limit() { + t.Fatalf("limiter limits don't match: %v vs %v", parent.Limiter().Limit(), clone.Limiter().Limit()) } - if client1.Limiter().Burst() != client2.Limiter().Burst() { - t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst()) + if parent.Limiter().Burst() != clone.Limiter().Burst() { + t.Fatalf("limiter bursts don't match: %v vs %v", parent.Limiter().Burst(), clone.Limiter().Burst()) } - if client1.MaxRetries() != client2.MaxRetries() { - t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries()) + if parent.MaxRetries() != clone.MaxRetries() { + t.Fatalf("maxRetries don't match: %v vs %v", parent.MaxRetries(), clone.MaxRetries()) } - if client1.OutputCurlString() != client2.OutputCurlString() { - t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString()) + if parent.OutputCurlString() != clone.OutputCurlString() { + t.Fatalf("outputCurlString doesn't match: %v vs %v", parent.OutputCurlString(), clone.OutputCurlString()) } - if client1.SRVLookup() != client2.SRVLookup() { - t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup()) + if parent.SRVLookup() != clone.SRVLookup() { + t.Fatalf("SRVLookup doesn't match: %v vs %v", parent.SRVLookup(), clone.SRVLookup()) } if tt.config.CloneHeaders { - if !reflect.DeepEqual(client1.Headers(), client2.Headers()) { - t.Fatalf("Headers() don't match: %v vs %v", client1.Headers(), client2.Headers()) + if !reflect.DeepEqual(parent.Headers(), clone.Headers()) { + t.Fatalf("Headers() don't match: %v vs %v", parent.Headers(), clone.Headers()) } - if client1.config.CloneHeaders != client2.config.CloneHeaders { - t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", client1.config.CloneHeaders, client2.config.CloneHeaders) + if parent.config.CloneHeaders != clone.config.CloneHeaders { + t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", parent.config.CloneHeaders, clone.config.CloneHeaders) } if tt.headers != nil { - if !reflect.DeepEqual(*tt.headers, client2.Headers()) { - t.Fatalf("expected headers %v, actual %v", *tt.headers, client2.Headers()) + if !reflect.DeepEqual(*tt.headers, clone.Headers()) { + t.Fatalf("expected headers %v, actual %v", *tt.headers, clone.Headers()) } } } - if tt.config.ReadYourWrites && client1.replicationStateStore == nil { + if tt.config.ReadYourWrites && parent.replicationStateStore == nil { t.Fatalf("replicationStateStore is nil") } - if !reflect.DeepEqual(client1.replicationStateStore, client2.replicationStateStore) { - t.Fatalf("expected replicationStateStore %v, actual %v", client1.replicationStateStore, - client2.replicationStateStore) + if tt.config.CloneToken { + if tt.token == "" { + t.Fatalf("test requires a non-empty token") + } + if parent.config.CloneToken != clone.config.CloneToken { + t.Fatalf("config.CloneToken doesn't match: %v vs %v", parent.config.CloneToken, clone.config.CloneToken) + } + if parent.token != clone.token { + t.Fatalf("tokens do not match: %v vs %v", parent.token, clone.token) + } + } else { + // assumes `VAULT_TOKEN` is unset or has an empty value. + expected := "" + if clone.token != expected { + t.Fatalf("expected clone's token %q, actual %q", expected, clone.token) + } + } + if !reflect.DeepEqual(parent.replicationStateStore, clone.replicationStateStore) { + t.Fatalf("expected replicationStateStore %v, actual %v", parent.replicationStateStore, + clone.replicationStateStore) } }) } @@ -1052,3 +1081,45 @@ func TestClient_SetReadYourWrites(t *testing.T) { }) } } + +func TestClient_SetCloneToken(t *testing.T) { + tests := []struct { + name string + calls []bool + }{ + { + name: "false", + calls: []bool{false}, + }, + { + name: "true", + calls: []bool{true}, + }, + { + name: "multi", + calls: []bool{true, false, true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: &Config{}, + } + + var expected bool + for _, v := range tt.calls { + actual := c.CloneToken() + if expected != actual { + t.Fatalf("expected %v, actual %v", expected, actual) + } + + expected = v + c.SetCloneToken(expected) + actual = c.CloneToken() + if actual != expected { + t.Fatalf("SetCloneToken(): expected %v, actual %v", expected, actual) + } + } + }) + } +} diff --git a/changelog/13515.txt b/changelog/13515.txt new file mode 100644 index 000000000..7c34dc9e8 --- /dev/null +++ b/changelog/13515.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Allow cloning `api.Client` tokens via `api.Config.CloneToken` or `api.Client.SetCloneToken()`. +```