diff --git a/api/client.go b/api/client.go index 7c1798105..c6843348e 100644 --- a/api/client.go +++ b/api/client.go @@ -33,29 +33,30 @@ import ( ) const ( - EnvVaultAddress = "VAULT_ADDR" - EnvVaultAgentAddr = "VAULT_AGENT_ADDR" - EnvVaultCACert = "VAULT_CACERT" - EnvVaultCACertBytes = "VAULT_CACERT_BYTES" - EnvVaultCAPath = "VAULT_CAPATH" - EnvVaultClientCert = "VAULT_CLIENT_CERT" - EnvVaultClientKey = "VAULT_CLIENT_KEY" - EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT" - EnvVaultSRVLookup = "VAULT_SRV_LOOKUP" - EnvVaultSkipVerify = "VAULT_SKIP_VERIFY" - EnvVaultNamespace = "VAULT_NAMESPACE" - EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME" - EnvVaultWrapTTL = "VAULT_WRAP_TTL" - EnvVaultMaxRetries = "VAULT_MAX_RETRIES" - EnvVaultToken = "VAULT_TOKEN" - EnvVaultMFA = "VAULT_MFA" - EnvRateLimit = "VAULT_RATE_LIMIT" - EnvHTTPProxy = "VAULT_HTTP_PROXY" - EnvVaultProxyAddr = "VAULT_PROXY_ADDR" - HeaderIndex = "X-Vault-Index" - HeaderForward = "X-Vault-Forward" - HeaderInconsistent = "X-Vault-Inconsistent" - TLSErrorString = "This error usually means that the server is running with TLS disabled\n" + + EnvVaultAddress = "VAULT_ADDR" + EnvVaultAgentAddr = "VAULT_AGENT_ADDR" + EnvVaultCACert = "VAULT_CACERT" + EnvVaultCACertBytes = "VAULT_CACERT_BYTES" + EnvVaultCAPath = "VAULT_CAPATH" + EnvVaultClientCert = "VAULT_CLIENT_CERT" + EnvVaultClientKey = "VAULT_CLIENT_KEY" + EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT" + EnvVaultSRVLookup = "VAULT_SRV_LOOKUP" + EnvVaultSkipVerify = "VAULT_SKIP_VERIFY" + EnvVaultNamespace = "VAULT_NAMESPACE" + EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME" + EnvVaultWrapTTL = "VAULT_WRAP_TTL" + EnvVaultMaxRetries = "VAULT_MAX_RETRIES" + EnvVaultToken = "VAULT_TOKEN" + EnvVaultMFA = "VAULT_MFA" + EnvRateLimit = "VAULT_RATE_LIMIT" + EnvHTTPProxy = "VAULT_HTTP_PROXY" + EnvVaultProxyAddr = "VAULT_PROXY_ADDR" + EnvVaultDisableRedirects = "VAULT_DISABLE_REDIRECTS" + HeaderIndex = "X-Vault-Index" + HeaderForward = "X-Vault-Forward" + HeaderInconsistent = "X-Vault-Inconsistent" + TLSErrorString = "This error usually means that the server is running with TLS disabled\n" + "but the client is configured to use TLS. Please either enable TLS\n" + "on the server or run the client with -address set to an address\n" + "that uses the http protocol:\n\n" + @@ -176,6 +177,16 @@ type Config struct { // since there will be a performance penalty paid upon each request. // This feature requires Enterprise server-side. ReadYourWrites bool + + // DisableRedirects when set to true, will prevent the client from + // automatically following a (single) redirect response to its initial + // request. This behavior may be desirable if using Vault CLI on the server + // side. + // + // Note: Disabling redirect following behavior could cause issues with + // commands such as 'vault operator raft snapshot' as this redirects to the + // primary node. + DisableRedirects bool } // TLSConfig contains the parameters needed to configure TLS on the HTTP client @@ -340,6 +351,7 @@ func (c *Config) ReadEnvironment() error { var envSRVLookup bool var limit *rate.Limiter var envVaultProxy string + var envVaultDisableRedirects bool // Parse the environment variables if v := os.Getenv(EnvVaultAddress); v != "" { @@ -388,7 +400,7 @@ func (c *Config) ReadEnvironment() error { var err error envInsecure, err = strconv.ParseBool(v) if err != nil { - return fmt.Errorf("could not parse VAULT_SKIP_VERIFY") + return fmt.Errorf("could not parse %s", EnvVaultSkipVerify) } } if v := os.Getenv(EnvVaultSRVLookup); v != "" { @@ -412,6 +424,16 @@ func (c *Config) ReadEnvironment() error { envVaultProxy = v } + if v := os.Getenv(EnvVaultDisableRedirects); v != "" { + var err error + envVaultDisableRedirects, err = strconv.ParseBool(v) + if err != nil { + return fmt.Errorf("could not parse %s", EnvVaultDisableRedirects) + } + + c.DisableRedirects = envVaultDisableRedirects + } + // Configure the HTTP clients TLS configuration. t := &TLSConfig{ CACert: envCACert, @@ -1270,6 +1292,7 @@ func (c *Client) rawRequestWithContext(ctx context.Context, r *Request) (*Respon outputCurlString := c.config.OutputCurlString outputPolicy := c.config.OutputPolicy logger := c.config.Logger + disableRedirects := c.config.DisableRedirects c.config.modifyLock.RUnlock() c.modifyLock.RUnlock() @@ -1363,8 +1386,8 @@ START: return result, err } - // Check for a redirect, only allowing for a single redirect - if (resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307) && redirectCount == 0 { + // Check for a redirect, only allowing for a single redirect (if redirects aren't disabled) + if (resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307) && redirectCount == 0 && !disableRedirects { // Parse the updated location respLoc, err := resp.Location() if err != nil { @@ -1423,6 +1446,7 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo httpClient := c.config.HttpClient outputCurlString := c.config.OutputCurlString outputPolicy := c.config.OutputPolicy + disableRedirects := c.config.DisableRedirects // add headers if c.headers != nil { @@ -1495,8 +1519,8 @@ func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Respo return result, err } - // Check for a redirect, only allowing for a single redirect - if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Check for a redirect, only allowing for a single redirect, if redirects aren't disabled + if (resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307) && !disableRedirects { // Parse the updated location respLoc, err := resp.Location() if err != nil { diff --git a/api/client_test.go b/api/client_test.go index 2305d42fe..844dcadd9 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -209,6 +209,67 @@ func TestClientBadToken(t *testing.T) { } } +func TestClientDisableRedirects(t *testing.T) { + tests := map[string]struct { + statusCode int + expectedNumReqs int + disableRedirects bool + }{ + "Disabled redirects: Moved permanently": {statusCode: 301, expectedNumReqs: 1, disableRedirects: true}, + "Disabled redirects: Found": {statusCode: 302, expectedNumReqs: 1, disableRedirects: true}, + "Disabled redirects: Temporary Redirect": {statusCode: 307, expectedNumReqs: 1, disableRedirects: true}, + "Enable redirects: Moved permanently": {statusCode: 301, expectedNumReqs: 2, disableRedirects: false}, + } + + for name, tc := range tests { + test := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + numReqs := 0 + var config *Config + + respFunc := func(w http.ResponseWriter, req *http.Request) { + // Track how many requests the server has handled + numReqs++ + // Send back the relevant status code and generate a location + w.Header().Set("Location", fmt.Sprintf(config.Address+"/reqs/%v", numReqs)) + w.WriteHeader(test.statusCode) + } + + config, ln := testHTTPServer(t, http.HandlerFunc(respFunc)) + config.DisableRedirects = test.disableRedirects + defer ln.Close() + + client, err := NewClient(config) + if err != nil { + t.Fatalf("%s: error %v", name, err) + } + + req := client.NewRequest("GET", "/") + resp, err := client.rawRequestWithContext(context.Background(), req) + if err != nil { + t.Fatalf("%s: error %v", name, err) + } + + if numReqs != test.expectedNumReqs { + t.Fatalf("%s: expected %v request(s) but got %v", name, test.expectedNumReqs, numReqs) + } + + if resp.StatusCode != test.statusCode { + t.Fatalf("%s: expected status code %v got %v", name, test.statusCode, resp.StatusCode) + } + + location, err := resp.Location() + if err != nil { + t.Fatalf("%s error %v", name, err) + } + if req.URL.String() == location.String() { + t.Fatalf("%s: expected request URL %v to be different from redirect URL %v", name, req.URL, resp.Request.URL) + } + }) + } +} + func TestClientRedirect(t *testing.T) { primary := func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("test")) @@ -320,6 +381,7 @@ func TestClientEnvSettings(t *testing.T) { oldClientKey := os.Getenv(EnvVaultClientKey) oldSkipVerify := os.Getenv(EnvVaultSkipVerify) oldMaxRetries := os.Getenv(EnvVaultMaxRetries) + oldDisableRedirects := os.Getenv(EnvVaultDisableRedirects) os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem") os.Setenv(EnvVaultCACertBytes, string(caCertBytes)) @@ -328,6 +390,7 @@ func TestClientEnvSettings(t *testing.T) { os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem") os.Setenv(EnvVaultSkipVerify, "true") os.Setenv(EnvVaultMaxRetries, "5") + os.Setenv(EnvVaultDisableRedirects, "true") defer func() { os.Setenv(EnvVaultCACert, oldCACert) @@ -337,6 +400,7 @@ func TestClientEnvSettings(t *testing.T) { os.Setenv(EnvVaultClientKey, oldClientKey) os.Setenv(EnvVaultSkipVerify, oldSkipVerify) os.Setenv(EnvVaultMaxRetries, oldMaxRetries) + os.Setenv(EnvVaultDisableRedirects, oldDisableRedirects) }() config := DefaultConfig() @@ -354,6 +418,9 @@ func TestClientEnvSettings(t *testing.T) { if tlsConfig.InsecureSkipVerify != true { t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify) } + if config.DisableRedirects != true { + t.Fatalf("bad: expected disable redirects to be true: %v", config.DisableRedirects) + } } func TestClientDeprecatedEnvSettings(t *testing.T) { diff --git a/changelog/17352.txt b/changelog/17352.txt new file mode 100644 index 000000000..a5b5ca5ea --- /dev/null +++ b/changelog/17352.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: Support VAULT_DISABLE_REDIRECTS environment variable (and --disable-redirects flag) to disable default client behavior and prevent the client following any redirection responses. +``` \ No newline at end of file diff --git a/command/base.go b/command/base.go index c4dd042b1..3aa5bb749 100644 --- a/command/base.go +++ b/command/base.go @@ -40,19 +40,20 @@ type BaseCommand struct { flags *FlagSets flagsOnce sync.Once - flagAddress string - flagAgentAddress string - flagCACert string - flagCAPath string - flagClientCert string - flagClientKey string - flagNamespace string - flagNS string - flagPolicyOverride bool - flagTLSServerName string - flagTLSSkipVerify bool - flagWrapTTL time.Duration - flagUnlockKey string + flagAddress string + flagAgentAddress string + flagCACert string + flagCAPath string + flagClientCert string + flagClientKey string + flagNamespace string + flagNS string + flagPolicyOverride bool + flagTLSServerName string + flagTLSSkipVerify bool + flagDisableRedirects bool + flagWrapTTL time.Duration + flagUnlockKey string flagFormat string flagField string @@ -427,6 +428,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { "transmissions to and from the Vault server.", }) + f.BoolVar(&BoolVar{ + Name: flagNameDisableRedirects, + Target: &c.flagDisableRedirects, + Default: false, + EnvVar: api.EnvVaultDisableRedirects, + Usage: "Disable the default client behavior, which honors a single " + + "redirect response from a request", + }) + f.BoolVar(&BoolVar{ Name: "policy-override", Target: &c.flagPolicyOverride, diff --git a/command/commands.go b/command/commands.go index d4ce6b6ca..82b4919e0 100644 --- a/command/commands.go +++ b/command/commands.go @@ -126,6 +126,8 @@ const ( flagNameAllowedManagedKeys = "allowed-managed-keys" // flagNamePluginVersion selects what version of a plugin should be used. flagNamePluginVersion = "plugin-version" + // flagNameDisableRedirects is used to prevent the client from honoring a single redirect as a response to a request + flagNameDisableRedirects = "disable-redirects" ) var ( diff --git a/website/content/docs/commands/index.mdx b/website/content/docs/commands/index.mdx index e476a6876..1a5282c13 100644 --- a/website/content/docs/commands/index.mdx +++ b/website/content/docs/commands/index.mdx @@ -419,6 +419,12 @@ All requests will resolve the specified proxy; there is no way to exclude ~> Note: If both `VAULT_HTTP_PROXY` and `VAULT_PROXY_ADDR` environment variables are supplied, `VAULT_PROXY_ADDR` will be prioritized and preferred. +### `VAULT_DISABLE_REDIRECTS` + +Prevents the Vault client from following redirects. By default, the Vault client will automatically follow a single redirect. + +~> **Note:** Disabling redirect following behavior could cause issues with commands such as 'vault operator raft snapshot' as this command redirects the request to the cluster's primary node. + ## Flags There are different CLI flags that are available depending on subcommands. Some