diff --git a/changelog/12508.txt b/changelog/12508.txt new file mode 100644 index 000000000..52c9e7c83 --- /dev/null +++ b/changelog/12508.txt @@ -0,0 +1,3 @@ +```release-note:improvement +cli: add new http option : -header which enable sending arbitrary headers with the cli +``` diff --git a/command/base.go b/command/base.go index d80c2b566..558ec4993 100644 --- a/command/base.go +++ b/command/base.go @@ -58,6 +58,8 @@ type BaseCommand struct { flagMFA []string + flagHeader map[string]string + tokenHelper token.TokenHelper client *api.Client @@ -154,6 +156,23 @@ func (c *BaseCommand) Client() (*api.Client, error) { client.SetPolicyOverride(c.flagPolicyOverride) } + if c.flagHeader != nil { + + var forbiddenHeaders []string + for key, val := range c.flagHeader { + + if strings.HasPrefix(key, "X-Vault-") { + forbiddenHeaders = append(forbiddenHeaders, key) + continue + } + client.AddHeader(key, val) + } + + if len(forbiddenHeaders) > 0 { + return nil, fmt.Errorf("failed to setup Headers[%s]: Header starting by 'X-Vault-' are for internal usage only", strings.Join(forbiddenHeaders, ", ")) + } + } + c.client = client return client, nil @@ -365,6 +384,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { Usage: "Key to unlock a namespace API lock.", }) + f.StringMapVar(&StringMapVar{ + Name: "header", + Target: &c.flagHeader, + Completion: complete.PredictAnything, + Usage: "Key-value pair provided as key=value to provide http header added to any request done by the CLI." + + "Trying to add headers starting with 'X-Vault-' is forbidden and will make the command fail " + + "This can be specified multiple times.", + }) + } if bit&(FlagSetOutputField|FlagSetOutputFormat) != 0 { diff --git a/command/base_test.go b/command/base_test.go new file mode 100644 index 000000000..b3f75e0eb --- /dev/null +++ b/command/base_test.go @@ -0,0 +1,69 @@ +package command + +import ( + "net/http" + "reflect" + "testing" +) + +func getDefaultCliHeaders(t *testing.T) http.Header { + bc := &BaseCommand{} + cli, err := bc.Client() + if err != nil { + t.Fatal(err) + } + return cli.Headers() +} + +func TestClient_FlagHeader(t *testing.T) { + defaultHeaders := getDefaultCliHeaders(t) + + cases := []struct { + Input map[string]string + Valid bool + }{ + { + map[string]string{}, + true, + }, + { + map[string]string{"foo": "bar", "header2": "value2"}, + true, + }, + { + map[string]string{"X-Vault-foo": "bar", "header2": "value2"}, + false, + }, + } + + for _, tc := range cases { + expectedHeaders := defaultHeaders.Clone() + for key, val := range tc.Input { + expectedHeaders.Add(key, val) + } + + bc := &BaseCommand{flagHeader: tc.Input} + cli, err := bc.Client() + + if err == nil && !tc.Valid { + t.Errorf("No error for input[%#v], but not valid", tc.Input) + continue + } + + if err != nil { + if tc.Valid { + t.Errorf("Error[%v] with input[%#v], but valid", err, tc.Input) + } + continue + } + + if cli == nil { + t.Error("client should not be nil") + } + + actualHeaders := cli.Headers() + if !reflect.DeepEqual(expectedHeaders, actualHeaders) { + t.Errorf("expected [%#v] but got [%#v]", expectedHeaders, actualHeaders) + } + } +}