// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package command

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net/http"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/hashicorp/go-hclog"
	vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
	logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
	"github.com/hashicorp/vault/api"
	credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
	"github.com/hashicorp/vault/command/agent"
	proxyConfig "github.com/hashicorp/vault/command/proxy/config"
	"github.com/hashicorp/vault/helper/useragent"
	vaulthttp "github.com/hashicorp/vault/http"
	"github.com/hashicorp/vault/sdk/helper/logging"
	"github.com/hashicorp/vault/sdk/logical"
	"github.com/hashicorp/vault/vault"
	"github.com/mitchellh/cli"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func testProxyCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *ProxyCommand) {
	tb.Helper()

	ui := cli.NewMockUi()
	return ui, &ProxyCommand{
		BaseCommand: &BaseCommand{
			UI: ui,
		},
		ShutdownCh: MakeShutdownCh(),
		SighupCh:   MakeSighupCh(),
		logger:     logger,
		startedCh:  make(chan struct{}, 5),
		reloadedCh: make(chan struct{}, 5),
	}
}

// TestProxy_ExitAfterAuth tests the exit_after_auth flag, provided both
// as config and via -exit-after-auth.
func TestProxy_ExitAfterAuth(t *testing.T) {
	t.Run("via_config", func(t *testing.T) {
		testProxyExitAfterAuth(t, false)
	})
	t.Run("via_flag", func(t *testing.T) {
		testProxyExitAfterAuth(t, true)
	})
}

func testProxyExitAfterAuth(t *testing.T, viaFlag bool) {
	logger := logging.NewVaultLogger(hclog.Trace)
	coreConfig := &vault.CoreConfig{
		Logger: logger,
		CredentialBackends: map[string]logical.Factory{
			"jwt": vaultjwt.Factory,
		},
	}

	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	vault.TestWaitActive(t, cluster.Cores[0].Core)
	client := cluster.Cores[0].Client

	// Setup Vault
	err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
		Type: "jwt",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
		"bound_issuer":           "https://team-vault.auth0.com/",
		"jwt_validation_pubkeys": agent.TestECDSAPubKey,
		"jwt_supported_algs":     "ES256",
	})
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
		"role_type":       "jwt",
		"bound_subject":   "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
		"bound_audiences": "https://vault.plugin.auth.jwt.test",
		"user_claim":      "https://vault/user",
		"groups_claim":    "https://vault/groups",
		"policies":        "test",
		"period":          "3s",
	})
	if err != nil {
		t.Fatal(err)
	}

	dir := t.TempDir()
	inf, err := os.CreateTemp(dir, "auth.jwt.test.")
	if err != nil {
		t.Fatal(err)
	}
	in := inf.Name()
	inf.Close()
	// We remove these files in this test since we don't need the files, we just need
	// a non-conflicting file name for the config.
	os.Remove(in)
	t.Logf("input: %s", in)

	sink1f, err := os.CreateTemp(dir, "sink1.jwt.test.")
	if err != nil {
		t.Fatal(err)
	}
	sink1 := sink1f.Name()
	sink1f.Close()
	os.Remove(sink1)
	t.Logf("sink1: %s", sink1)

	sink2f, err := os.CreateTemp(dir, "sink2.jwt.test.")
	if err != nil {
		t.Fatal(err)
	}
	sink2 := sink2f.Name()
	sink2f.Close()
	os.Remove(sink2)
	t.Logf("sink2: %s", sink2)

	conff, err := os.CreateTemp(dir, "conf.jwt.test.")
	if err != nil {
		t.Fatal(err)
	}
	conf := conff.Name()
	conff.Close()
	os.Remove(conf)
	t.Logf("config: %s", conf)

	jwtToken, _ := agent.GetTestJWT(t)
	if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
		t.Fatal(err)
	} else {
		logger.Trace("wrote test jwt", "path", in)
	}

	exitAfterAuthTemplText := "exit_after_auth = true"
	if viaFlag {
		exitAfterAuthTemplText = ""
	}

	config := `
%s

auto_auth {
	method {
		type = "jwt"
		config = {
			role = "test"
			path = "%s"
		}
	}

	sink {
		type = "file"
		config = {
			path = "%s"
		}
	}

	sink "file" {
		config = {
			path = "%s"
		}
	}
}
`

	config = fmt.Sprintf(config, exitAfterAuthTemplText, in, sink1, sink2)
	if err := os.WriteFile(conf, []byte(config), 0o600); err != nil {
		t.Fatal(err)
	} else {
		logger.Trace("wrote test config", "path", conf)
	}

	doneCh := make(chan struct{})
	go func() {
		ui, cmd := testProxyCommand(t, logger)
		cmd.client = client

		args := []string{"-config", conf}
		if viaFlag {
			args = append(args, "-exit-after-auth")
		}

		code := cmd.Run(args)
		if code != 0 {
			t.Errorf("expected %d to be %d", code, 0)
			t.Logf("output from proxy:\n%s", ui.OutputWriter.String())
			t.Logf("error from proxy:\n%s", ui.ErrorWriter.String())
		}
		close(doneCh)
	}()

	select {
	case <-doneCh:
		break
	case <-time.After(1 * time.Minute):
		t.Fatal("timeout reached while waiting for proxy to exit")
	}

	sink1Bytes, err := os.ReadFile(sink1)
	if err != nil {
		t.Fatal(err)
	}
	if len(sink1Bytes) == 0 {
		t.Fatal("got no output from sink 1")
	}

	sink2Bytes, err := os.ReadFile(sink2)
	if err != nil {
		t.Fatal(err)
	}
	if len(sink2Bytes) == 0 {
		t.Fatal("got no output from sink 2")
	}

	if string(sink1Bytes) != string(sink2Bytes) {
		t.Fatal("sink 1/2 values don't match")
	}
}

// TestProxy_AutoAuth_UserAgent tests that the User-Agent sent
// to Vault by Vault Proxy is correct when performing Auto-Auth.
// Uses the custom handler userAgentHandler (defined above) so
// that Vault validates the User-Agent on requests sent by Proxy.
func TestProxy_AutoAuth_UserAgent(t *testing.T) {
	logger := logging.NewVaultLogger(hclog.Trace)
	var h userAgentHandler
	cluster := vault.NewTestCluster(t, &vault.CoreConfig{
		Logger: logger,
		CredentialBackends: map[string]logical.Factory{
			"approle": credAppRole.Factory,
		},
	}, &vault.TestClusterOptions{
		NumCores: 1,
		HandlerFunc: vaulthttp.HandlerFunc(
			func(properties *vault.HandlerProperties) http.Handler {
				h.props = properties
				h.userAgentToCheckFor = useragent.ProxyAutoAuthString()
				h.requestMethodToCheck = "PUT"
				h.pathToCheck = "auth/approle/login"
				h.t = t
				return &h
			}),
	})
	cluster.Start()
	defer cluster.Cleanup()

	serverClient := cluster.Cores[0].Client

	// Enable the approle auth method
	req := serverClient.NewRequest("POST", "/v1/sys/auth/approle")
	req.BodyBytes = []byte(`{
		"type": "approle"
	}`)
	request(t, serverClient, req, 204)

	// Create a named role
	req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role")
	req.BodyBytes = []byte(`{
		"secret_id_num_uses": "10",
		"secret_id_ttl": "1m",
		"token_max_ttl": "1m",
		"token_num_uses": "10",
		"token_ttl": "1m",
		"policies": "default"
	}`)
	request(t, serverClient, req, 204)

	// Fetch the RoleID of the named role
	req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id")
	body := request(t, serverClient, req, 200)
	data := body["data"].(map[string]interface{})
	roleID := data["role_id"].(string)

	// Get a SecretID issued against the named role
	req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id")
	body = request(t, serverClient, req, 200)
	data = body["data"].(map[string]interface{})
	secretID := data["secret_id"].(string)

	// Write the RoleID and SecretID to temp files
	roleIDPath := makeTempFile(t, "role_id.txt", roleID+"\n")
	secretIDPath := makeTempFile(t, "secret_id.txt", secretID+"\n")
	defer os.Remove(roleIDPath)
	defer os.Remove(secretIDPath)

	sinkf, err := os.CreateTemp("", "sink.test.")
	if err != nil {
		t.Fatal(err)
	}
	sink := sinkf.Name()
	sinkf.Close()
	os.Remove(sink)

	autoAuthConfig := fmt.Sprintf(`
auto_auth {
	method "approle" {
		mount_path = "auth/approle"
		config = {
			role_id_file_path = "%s"
			secret_id_file_path = "%s"
		}
	}

	sink "file" {
		config = {
			path = "%s"
		}
	}
}`, roleIDPath, secretIDPath, sink)

	listenAddr := generateListenerAddress(t)
	listenConfig := fmt.Sprintf(`
listener "tcp" {
	address = "%s"
	tls_disable = true
}
`, listenAddr)

	config := fmt.Sprintf(`
vault {
	address = "%s"
	tls_skip_verify = true
}
api_proxy {
	use_auto_auth_token = true
}
%s
%s
`, serverClient.Address(), listenConfig, autoAuthConfig)
	configPath := makeTempFile(t, "config.hcl", config)
	defer os.Remove(configPath)

	// Unset the environment variable so that proxy picks up the right test
	// cluster address
	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
	os.Unsetenv(api.EnvVaultAddress)

	// Start proxy
	_, cmd := testProxyCommand(t, logger)
	cmd.startedCh = make(chan struct{})

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		cmd.Run([]string{"-config", configPath})
		wg.Done()
	}()

	select {
	case <-cmd.startedCh:
	case <-time.After(5 * time.Second):
		t.Errorf("timeout")
	}

	// Validate that the auto-auth token has been correctly obtained
	// and works for LookupSelf
	conf := api.DefaultConfig()
	conf.Address = "http://" + listenAddr
	proxyClient, err := api.NewClient(conf)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	proxyClient.SetToken("")
	err = proxyClient.SetAddress("http://" + listenAddr)
	if err != nil {
		t.Fatal(err)
	}

	// Wait for the token to be sent to sinks and be available to be used
	time.Sleep(5 * time.Second)
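	// This lookup-self request is sent with an empty client token; because
	// use_auto_auth_token is enabled in the api_proxy stanza, the proxy should
	// attach the auto-auth token before forwarding the request to Vault.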
proxyClient.NewRequest("GET", "/v1/auth/token/lookup-self") body = request(t, proxyClient, req, 200) close(cmd.ShutdownCh) wg.Wait() } // TestProxy_APIProxyWithoutCache_UserAgent tests that the User-Agent sent // to Vault by Vault Proxy is correct using the API proxy without // the cache configured. Uses the custom handler // userAgentHandler struct defined in this test package, so that Vault validates the // User-Agent on requests sent by Proxy. func TestProxy_APIProxyWithoutCache_UserAgent(t *testing.T) { logger := logging.NewVaultLogger(hclog.Trace) userAgentForProxiedClient := "proxied-client" var h userAgentHandler cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ NumCores: 1, HandlerFunc: vaulthttp.HandlerFunc( func(properties *vault.HandlerProperties) http.Handler { h.props = properties h.userAgentToCheckFor = useragent.ProxyStringWithProxiedUserAgent(userAgentForProxiedClient) h.pathToCheck = "/v1/auth/token/lookup-self" h.requestMethodToCheck = "GET" h.t = t return &h }), }) cluster.Start() defer cluster.Cleanup() serverClient := cluster.Cores[0].Client // Unset the environment variable so that proxy picks up the right test // cluster address defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) os.Unsetenv(api.EnvVaultAddress) listenAddr := generateListenerAddress(t) listenConfig := fmt.Sprintf(` listener "tcp" { address = "%s" tls_disable = true } `, listenAddr) config := fmt.Sprintf(` vault { address = "%s" tls_skip_verify = true } %s `, serverClient.Address(), listenConfig) configPath := makeTempFile(t, "config.hcl", config) defer os.Remove(configPath) // Start the proxy _, cmd := testProxyCommand(t, logger) cmd.startedCh = make(chan struct{}) wg := &sync.WaitGroup{} wg.Add(1) go func() { cmd.Run([]string{"-config", configPath}) wg.Done() }() select { case <-cmd.startedCh: case <-time.After(5 * time.Second): t.Errorf("timeout") } proxyClient, err := api.NewClient(api.DefaultConfig()) if err != nil { t.Fatal(err) } proxyClient.AddHeader("User-Agent", userAgentForProxiedClient) proxyClient.SetToken(serverClient.Token()) proxyClient.SetMaxRetries(0) err = proxyClient.SetAddress("http://" + listenAddr) if err != nil { t.Fatal(err) } _, err = proxyClient.Auth().Token().LookupSelf() if err != nil { t.Fatal(err) } close(cmd.ShutdownCh) wg.Wait() } // TestProxy_APIProxyWithCache_UserAgent tests that the User-Agent sent // to Vault by Vault Proxy is correct using the API proxy with // the cache configured. Uses the custom handler // userAgentHandler struct defined in this test package, so that Vault validates the // User-Agent on requests sent by Proxy. 
func TestProxy_APIProxyWithCache_UserAgent(t *testing.T) {
	logger := logging.NewVaultLogger(hclog.Trace)
	userAgentForProxiedClient := "proxied-client"
	var h userAgentHandler
	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
		NumCores: 1,
		HandlerFunc: vaulthttp.HandlerFunc(
			func(properties *vault.HandlerProperties) http.Handler {
				h.props = properties
				h.userAgentToCheckFor = useragent.ProxyStringWithProxiedUserAgent(userAgentForProxiedClient)
				h.pathToCheck = "/v1/auth/token/lookup-self"
				h.requestMethodToCheck = "GET"
				h.t = t
				return &h
			}),
	})
	cluster.Start()
	defer cluster.Cleanup()

	serverClient := cluster.Cores[0].Client

	// Unset the environment variable so that proxy picks up the right test
	// cluster address
	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
	os.Unsetenv(api.EnvVaultAddress)

	listenAddr := generateListenerAddress(t)
	listenConfig := fmt.Sprintf(`
listener "tcp" {
	address = "%s"
	tls_disable = true
}
`, listenAddr)

	cacheConfig := `
cache {
}`

	config := fmt.Sprintf(`
vault {
	address = "%s"
	tls_skip_verify = true
}
%s
%s
`, serverClient.Address(), listenConfig, cacheConfig)
	configPath := makeTempFile(t, "config.hcl", config)
	defer os.Remove(configPath)

	// Start the proxy
	_, cmd := testProxyCommand(t, logger)
	cmd.startedCh = make(chan struct{})

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		cmd.Run([]string{"-config", configPath})
		wg.Done()
	}()

	select {
	case <-cmd.startedCh:
	case <-time.After(5 * time.Second):
		t.Errorf("timeout")
	}

	proxyClient, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}
	proxyClient.AddHeader("User-Agent", userAgentForProxiedClient)
	proxyClient.SetToken(serverClient.Token())
	proxyClient.SetMaxRetries(0)
	err = proxyClient.SetAddress("http://" + listenAddr)
	if err != nil {
		t.Fatal(err)
	}

	_, err = proxyClient.Auth().Token().LookupSelf()
	if err != nil {
		t.Fatal(err)
	}

	close(cmd.ShutdownCh)
	wg.Wait()
}

// TestProxy_Cache_DynamicSecret Tests that the cache successfully caches a dynamic secret
// going through the Proxy.
func TestProxy_Cache_DynamicSecret(t *testing.T) {
	logger := logging.NewVaultLogger(hclog.Trace)

	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	serverClient := cluster.Cores[0].Client

	// Unset the environment variable so that proxy picks up the right test
	// cluster address
	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
	os.Unsetenv(api.EnvVaultAddress)

	cacheConfig := `
cache {
}
`
	listenAddr := generateListenerAddress(t)
	listenConfig := fmt.Sprintf(`
listener "tcp" {
	address = "%s"
	tls_disable = true
}
`, listenAddr)

	config := fmt.Sprintf(`
vault {
	address = "%s"
	tls_skip_verify = true
}
%s
%s
`, serverClient.Address(), cacheConfig, listenConfig)
	configPath := makeTempFile(t, "config.hcl", config)
	defer os.Remove(configPath)

	// Start proxy
	_, cmd := testProxyCommand(t, logger)
	cmd.startedCh = make(chan struct{})

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		cmd.Run([]string{"-config", configPath})
		wg.Done()
	}()

	select {
	case <-cmd.startedCh:
	case <-time.After(5 * time.Second):
		t.Errorf("timeout")
	}

	proxyClient, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}
	proxyClient.SetToken(serverClient.Token())
	proxyClient.SetMaxRetries(0)
	err = proxyClient.SetAddress("http://" + listenAddr)
	if err != nil {
		t.Fatal(err)
	}

	renewable := true
	tokenCreateRequest := &api.TokenCreateRequest{
		Policies:  []string{"default"},
		TTL:       "30m",
		Renewable: &renewable,
	}
	// Creating an orphan token was the simplest way I could find to trigger the
	// caching behaviour: the create response returns Auth, is renewable, and the
	// token isn't managed elsewhere (since it's an orphan).
	secret, err := proxyClient.Auth().Token().CreateOrphan(tokenCreateRequest)
	if err != nil {
		t.Fatal(err)
	}
	if secret == nil || secret.Auth == nil {
		t.Fatalf("secret not as expected: %v", secret)
	}

	token := secret.Auth.ClientToken

	secret, err = proxyClient.Auth().Token().CreateOrphan(tokenCreateRequest)
	if err != nil {
		t.Fatal(err)
	}
	if secret == nil || secret.Auth == nil {
		t.Fatalf("secret not as expected: %v", secret)
	}

	token2 := secret.Auth.ClientToken

	if token != token2 {
		t.Fatalf("token create response not cached when it should have been, as tokens differ")
	}

	close(cmd.ShutdownCh)
	wg.Wait()
}

// TestProxy_ApiProxy_Retry Tests the retry functionality of Vault Proxy's API Proxy
func TestProxy_ApiProxy_Retry(t *testing.T) {
	//----------------------------------------------------
	// Start the server and proxy
	//----------------------------------------------------
	logger := logging.NewVaultLogger(hclog.Trace)
	var h handler
	cluster := vault.NewTestCluster(t, &vault.CoreConfig{
		Logger: logger,
		CredentialBackends: map[string]logical.Factory{
			"approle": credAppRole.Factory,
		},
		LogicalBackends: map[string]logical.Factory{
			"kv": logicalKv.Factory,
		},
	}, &vault.TestClusterOptions{
		NumCores: 1,
		HandlerFunc: vaulthttp.HandlerFunc(func(properties *vault.HandlerProperties) http.Handler {
			h.props = properties
			h.t = t
			return &h
		}),
	})
	cluster.Start()
	defer cluster.Cleanup()

	vault.TestWaitActive(t, cluster.Cores[0].Core)
	serverClient := cluster.Cores[0].Client

	// Unset the environment variable so that proxy picks up the right test
	// cluster address
	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
	os.Unsetenv(api.EnvVaultAddress)

	_, err := serverClient.Logical().Write("secret/foo", map[string]interface{}{
		"bar": "baz",
	})
	if err != nil {
		t.Fatal(err)
	}

	intRef := func(i int) *int {
		return &i
	}

	// start test cases here
	testCases := map[string]struct {
		retries     *int
		expectError bool
	}{
		"none": {
			retries:     intRef(-1),
			expectError: true,
		},
		"one": {
			retries:     intRef(1),
			expectError: true,
		},
		"two": {
			retries:     intRef(2),
			expectError: false,
		},
		"missing": {
			retries:     nil,
			expectError: false,
		},
		"default": {
			retries:     intRef(0),
			expectError: false,
		},
	}

	for tcname, tc := range testCases {
		t.Run(tcname, func(t *testing.T) {
			h.failCount = 2

			cacheConfig := `
cache {
}
`
			listenAddr := generateListenerAddress(t)
			listenConfig := fmt.Sprintf(`
listener "tcp" {
	address = "%s"
	tls_disable = true
}
`, listenAddr)

			var retryConf string
			if tc.retries != nil {
				retryConf = fmt.Sprintf("retry { num_retries = %d }", *tc.retries)
			}

			config := fmt.Sprintf(`
vault {
	address = "%s"
	%s
	tls_skip_verify = true
}
%s
%s
`, serverClient.Address(), retryConf, cacheConfig, listenConfig)
			configPath := makeTempFile(t, "config.hcl", config)
			defer os.Remove(configPath)

			_, cmd := testProxyCommand(t, logger)
			cmd.startedCh = make(chan struct{})

			wg := &sync.WaitGroup{}
			wg.Add(1)
			go func() {
				cmd.Run([]string{"-config", configPath})
				wg.Done()
			}()

			select {
			case <-cmd.startedCh:
			case <-time.After(5 * time.Second):
				t.Errorf("timeout")
			}

			client, err := api.NewClient(api.DefaultConfig())
			if err != nil {
				t.Fatal(err)
			}
			client.SetToken(serverClient.Token())
			client.SetMaxRetries(0)
			err = client.SetAddress("http://" + listenAddr)
			if err != nil {
				t.Fatal(err)
			}
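			// Read the secret through the proxy. The test handler is primed with
			// h.failCount = 2, so the first two attempts fail and the read only
			// succeeds when enough retries are configured (or the default is used).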
client.Logical().Read("secret/foo") switch { case (err != nil || secret == nil) && tc.expectError: case (err == nil || secret != nil) && !tc.expectError: default: t.Fatalf("%s expectError=%v error=%v secret=%v", tcname, tc.expectError, err, secret) } if secret != nil && secret.Data["foo"] != nil { val := secret.Data["foo"].(map[string]interface{}) if !reflect.DeepEqual(val, map[string]interface{}{"bar": "baz"}) { t.Fatalf("expected key 'foo' to yield bar=baz, got: %v", val) } } time.Sleep(time.Second) close(cmd.ShutdownCh) wg.Wait() }) } } // TestProxy_Metrics tests that metrics are being properly reported. func TestProxy_Metrics(t *testing.T) { // Start a vault server logger := logging.NewVaultLogger(hclog.Trace) cluster := vault.NewTestCluster(t, &vault.CoreConfig{ Logger: logger, }, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, }) cluster.Start() defer cluster.Cleanup() vault.TestWaitActive(t, cluster.Cores[0].Core) serverClient := cluster.Cores[0].Client // Create a config file listenAddr := generateListenerAddress(t) config := fmt.Sprintf(` cache {} listener "tcp" { address = "%s" tls_disable = true } `, listenAddr) configPath := makeTempFile(t, "config.hcl", config) defer os.Remove(configPath) ui, cmd := testProxyCommand(t, logger) cmd.client = serverClient cmd.startedCh = make(chan struct{}) wg := &sync.WaitGroup{} wg.Add(1) go func() { code := cmd.Run([]string{"-config", configPath}) if code != 0 { t.Errorf("non-zero return code when running proxy: %d", code) t.Logf("STDOUT from proxy:\n%s", ui.OutputWriter.String()) t.Logf("STDERR from proxy:\n%s", ui.ErrorWriter.String()) } wg.Done() }() select { case <-cmd.startedCh: case <-time.After(5 * time.Second): t.Errorf("timeout") } // defer proxy shutdown defer func() { cmd.ShutdownCh <- struct{}{} wg.Wait() }() conf := api.DefaultConfig() conf.Address = "http://" + listenAddr proxyClient, err := api.NewClient(conf) if err != nil { t.Fatalf("err: %s", err) } req := proxyClient.NewRequest("GET", "/proxy/v1/metrics") body := request(t, proxyClient, req, 200) keys := []string{} for k := range body { keys = append(keys, k) } require.ElementsMatch(t, keys, []string{ "Counters", "Samples", "Timestamp", "Gauges", "Points", }) } // TestProxy_QuitAPI Tests the /proxy/v1/quit API that can be enabled for the proxy. 
func TestProxy_QuitAPI(t *testing.T) {
	logger := logging.NewVaultLogger(hclog.Error)
	cluster := vault.NewTestCluster(t, &vault.CoreConfig{
		Logger: logger,
		CredentialBackends: map[string]logical.Factory{
			"approle": credAppRole.Factory,
		},
		LogicalBackends: map[string]logical.Factory{
			"kv": logicalKv.Factory,
		},
	}, &vault.TestClusterOptions{
		NumCores: 1,
	})
	cluster.Start()
	defer cluster.Cleanup()

	vault.TestWaitActive(t, cluster.Cores[0].Core)
	serverClient := cluster.Cores[0].Client

	// Unset the environment variable so that proxy picks up the right test
	// cluster address
	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
	err := os.Unsetenv(api.EnvVaultAddress)
	if err != nil {
		t.Fatal(err)
	}

	listenAddr := generateListenerAddress(t)
	listenAddr2 := generateListenerAddress(t)
	config := fmt.Sprintf(`
vault {
	address = "%s"
	tls_skip_verify = true
}

listener "tcp" {
	address = "%s"
	tls_disable = true
}

listener "tcp" {
	address = "%s"
	tls_disable = true
	proxy_api {
		enable_quit = true
	}
}

cache {}
`, serverClient.Address(), listenAddr, listenAddr2)
	configPath := makeTempFile(t, "config.hcl", config)
	defer os.Remove(configPath)

	_, cmd := testProxyCommand(t, logger)
	cmd.startedCh = make(chan struct{})

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		cmd.Run([]string{"-config", configPath})
		wg.Done()
	}()

	select {
	case <-cmd.startedCh:
	case <-time.After(5 * time.Second):
		t.Errorf("timeout")
	}

	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		t.Fatal(err)
	}
	client.SetToken(serverClient.Token())
	client.SetMaxRetries(0)
	err = client.SetAddress("http://" + listenAddr)
	if err != nil {
		t.Fatal(err)
	}

	// First try on listener 1 where the API should be disabled.
	resp, err := client.RawRequest(client.NewRequest(http.MethodPost, "/proxy/v1/quit"))
	if err == nil {
		t.Fatalf("expected error")
	}
	if resp != nil && resp.StatusCode != http.StatusNotFound {
		t.Fatalf("expected %d but got: %d", http.StatusNotFound, resp.StatusCode)
	}

	// Now try on listener 2 where the quit API should be enabled.
	err = client.SetAddress("http://" + listenAddr2)
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.RawRequest(client.NewRequest(http.MethodPost, "/proxy/v1/quit"))
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}

	select {
	case <-cmd.ShutdownCh:
	case <-time.After(5 * time.Second):
		t.Errorf("timeout")
	}
	wg.Wait()
}

// TestProxy_LogFile_CliOverridesConfig tests that the CLI values
// override the config for log files
func TestProxy_LogFile_CliOverridesConfig(t *testing.T) {
	// Create basic config
	configFile := populateTempFile(t, "proxy-config.hcl", BasicHclConfig)
	cfg, err := proxyConfig.LoadConfigFile(configFile.Name())
	if err != nil {
		t.Fatal("Cannot load config to test update/merge", err)
	}

	// Sanity check that the config value is the current value
	assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile)

	// Initialize the command and parse any flags
	cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
	f := cmd.Flags()
	// Simulate the flag being specified
	err = f.Parse([]string{"-log-file=/foo/bar/test.log"})
	if err != nil {
		t.Fatal(err)
	}

	// Update the config based on the inputs.
	cmd.applyConfigOverrides(f, cfg)

	assert.NotEqual(t, "TMPDIR/juan.log", cfg.LogFile)
	assert.NotEqual(t, "/squiggle/logs.txt", cfg.LogFile)
	assert.Equal(t, "/foo/bar/test.log", cfg.LogFile)
}

// TestProxy_LogFile_Config tests log file config when loaded from config
func TestProxy_LogFile_Config(t *testing.T) {
	configFile := populateTempFile(t, "proxy-config.hcl", BasicHclConfig)

	cfg, err := proxyConfig.LoadConfigFile(configFile.Name())
	if err != nil {
		t.Fatal("Cannot load config to test update/merge", err)
	}

	// Sanity check that the config value is the current value
	assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "sanity check on log config failed")
	assert.Equal(t, 2, cfg.LogRotateMaxFiles)
	assert.Equal(t, 1048576, cfg.LogRotateBytes)

	// Parse the cli flags (but we pass in an empty slice)
	cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
	f := cmd.Flags()
	err = f.Parse([]string{})
	if err != nil {
		t.Fatal(err)
	}

	// Should change nothing...
	cmd.applyConfigOverrides(f, cfg)

	assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "actual config check")
	assert.Equal(t, 2, cfg.LogRotateMaxFiles)
	assert.Equal(t, 1048576, cfg.LogRotateBytes)
}

// TestProxy_Config_NewLogger_Default Tests defaults for log level and
// specifically cmd.newLogger()
func TestProxy_Config_NewLogger_Default(t *testing.T) {
	cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
	cmd.config = proxyConfig.NewConfig()
	logger, err := cmd.newLogger()

	assert.NoError(t, err)
	assert.NotNil(t, logger)
	assert.Equal(t, hclog.Info.String(), logger.GetLevel().String())
}

// TestProxy_Config_ReloadLogLevel Tests reloading updates the log
// level as expected.
func TestProxy_Config_ReloadLogLevel(t *testing.T) {
	cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
	var err error
	tempDir := t.TempDir()

	// Load an initial config
	hcl := strings.ReplaceAll(BasicHclConfig, "TMPDIR", tempDir)
	configFile := populateTempFile(t, "proxy-config.hcl", hcl)
	cmd.config, err = proxyConfig.LoadConfigFile(configFile.Name())
	if err != nil {
		t.Fatal("Cannot load config to test update/merge", err)
	}

	// Tweak the loaded config to make sure we can put log files into a temp dir
	// and systemd log attempts work fine; this would usually happen during Run.
	cmd.logWriter = os.Stdout
	cmd.logger, err = cmd.newLogger()
	if err != nil {
		t.Fatal("logger required for systemd log messages", err)
	}

	// Sanity check
	assert.Equal(t, "warn", cmd.config.LogLevel)

	// Load a new config
	hcl = strings.ReplaceAll(BasicHclConfig2, "TMPDIR", tempDir)
	configFile = populateTempFile(t, "proxy-config.hcl", hcl)
	err = cmd.reloadConfig([]string{configFile.Name()})
	assert.NoError(t, err)
	assert.Equal(t, "debug", cmd.config.LogLevel)
}

// TestProxy_Config_ReloadTls Tests that the TLS certs for the listener are
// correctly reloaded.
func TestProxy_Config_ReloadTls(t *testing.T) {
	var wg sync.WaitGroup
	wd, err := os.Getwd()
	if err != nil {
		t.Fatal("unable to get current working directory")
	}
	workingDir := filepath.Join(wd, "/proxy/test-fixtures/reload")
	fooCert := "reload_foo.pem"
	fooKey := "reload_foo.key"

	barCert := "reload_bar.pem"
	barKey := "reload_bar.key"

	reloadCert := "reload_cert.pem"
	reloadKey := "reload_key.pem"
	caPem := "reload_ca.pem"

	tempDir := t.TempDir()

	// Set up initial 'foo' certs
	inBytes, err := os.ReadFile(filepath.Join(workingDir, fooCert))
	if err != nil {
		t.Fatal("unable to read cert required for test", fooCert, err)
	}
	err = os.WriteFile(filepath.Join(tempDir, reloadCert), inBytes, 0o777)
	if err != nil {
		t.Fatal("unable to write temp cert required for test", reloadCert, err)
	}

	inBytes, err = os.ReadFile(filepath.Join(workingDir, fooKey))
	if err != nil {
		t.Fatal("unable to read cert key required for test", fooKey, err)
	}
	err = os.WriteFile(filepath.Join(tempDir, reloadKey), inBytes, 0o777)
	if err != nil {
		t.Fatal("unable to write temp cert key required for test", reloadKey, err)
	}

	inBytes, err = os.ReadFile(filepath.Join(workingDir, caPem))
	if err != nil {
		t.Fatal("unable to read CA pem required for test", caPem, err)
	}
	certPool := x509.NewCertPool()
	ok := certPool.AppendCertsFromPEM(inBytes)
	if !ok {
		t.Fatal("not ok when appending CA cert")
	}

	replacedHcl := strings.ReplaceAll(BasicHclConfig, "TMPDIR", tempDir)
	configFile := populateTempFile(t, "proxy-config.hcl", replacedHcl)

	// Set up Proxy
	logger := logging.NewVaultLogger(hclog.Trace)
	ui, cmd := testProxyCommand(t, logger)

	var output string
	var code int
	wg.Add(1)
	args := []string{"-config", configFile.Name()}
	go func() {
		if code = cmd.Run(args); code != 0 {
			output = ui.ErrorWriter.String() + ui.OutputWriter.String()
		}
		wg.Done()
	}()

	testCertificateName := func(cn string) error {
		conn, err := tls.Dial("tcp", "127.0.0.1:8100", &tls.Config{
			RootCAs: certPool,
		})
		if err != nil {
			return err
		}
		defer conn.Close()
		if err = conn.Handshake(); err != nil {
			return err
		}
		servName := conn.ConnectionState().PeerCertificates[0].Subject.CommonName
		if servName != cn {
			return fmt.Errorf("expected %s, got %s", cn, servName)
		}
		return nil
	}

	// Start
	select {
	case <-cmd.startedCh:
	case <-time.After(5 * time.Second):
		t.Fatalf("timeout")
	}

	if err := testCertificateName("foo.example.com"); err != nil {
		t.Fatalf("certificate name didn't check out: %s", err)
	}

	// Swap out certs
	inBytes, err = os.ReadFile(filepath.Join(workingDir, barCert))
	if err != nil {
		t.Fatal("unable to read cert required for test", barCert, err)
	}
	err = os.WriteFile(filepath.Join(tempDir, reloadCert), inBytes, 0o777)
	if err != nil {
		t.Fatal("unable to write temp cert required for test", reloadCert, err)
	}

	inBytes, err = os.ReadFile(filepath.Join(workingDir, barKey))
	if err != nil {
		t.Fatal("unable to read cert key required for test", barKey, err)
	}
	err = os.WriteFile(filepath.Join(tempDir, reloadKey), inBytes, 0o777)
	if err != nil {
		t.Fatal("unable to write temp cert key required for test", reloadKey, err)
	}

	// Reload
	cmd.SighupCh <- struct{}{}
	select {
	case <-cmd.reloadedCh:
	case <-time.After(5 * time.Second):
		t.Fatalf("timeout")
	}

	if err := testCertificateName("bar.example.com"); err != nil {
		t.Fatalf("certificate name didn't check out: %s", err)
	}

	// Shut down
	cmd.ShutdownCh <- struct{}{}
	wg.Wait()

	if code != 0 {
		t.Fatalf("got a non-zero exit status: %d, stdout/stderr: %s", code, output)
	}
}