e5bee669e4
Co-authored-by: Hamid Ghaf <83242695+hghaf099@users.noreply.github.com>
1275 lines
32 KiB
Go
1275 lines
32 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package command
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
|
|
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
|
|
"github.com/hashicorp/vault/api"
|
|
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
|
|
"github.com/hashicorp/vault/command/agent"
|
|
proxyConfig "github.com/hashicorp/vault/command/proxy/config"
|
|
"github.com/hashicorp/vault/helper/useragent"
|
|
vaulthttp "github.com/hashicorp/vault/http"
|
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
|
"github.com/hashicorp/vault/sdk/logical"
|
|
"github.com/hashicorp/vault/vault"
|
|
"github.com/mitchellh/cli"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func testProxyCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *ProxyCommand) {
|
|
tb.Helper()
|
|
|
|
ui := cli.NewMockUi()
|
|
return ui, &ProxyCommand{
|
|
BaseCommand: &BaseCommand{
|
|
UI: ui,
|
|
},
|
|
ShutdownCh: MakeShutdownCh(),
|
|
SighupCh: MakeSighupCh(),
|
|
logger: logger,
|
|
startedCh: make(chan struct{}, 5),
|
|
reloadedCh: make(chan struct{}, 5),
|
|
}
|
|
}
|
|
|
|
// TestProxy_ExitAfterAuth tests the exit_after_auth flag, provided both
|
|
// as config and via -exit-after-auth.
|
|
func TestProxy_ExitAfterAuth(t *testing.T) {
|
|
t.Run("via_config", func(t *testing.T) {
|
|
testProxyExitAfterAuth(t, false)
|
|
})
|
|
|
|
t.Run("via_flag", func(t *testing.T) {
|
|
testProxyExitAfterAuth(t, true)
|
|
})
|
|
}
|
|
|
|
func testProxyExitAfterAuth(t *testing.T, viaFlag bool) {
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
coreConfig := &vault.CoreConfig{
|
|
Logger: logger,
|
|
CredentialBackends: map[string]logical.Factory{
|
|
"jwt": vaultjwt.Factory,
|
|
},
|
|
}
|
|
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
|
HandlerFunc: vaulthttp.Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
|
|
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
|
client := cluster.Cores[0].Client
|
|
|
|
// Setup Vault
|
|
err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
|
|
Type: "jwt",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
|
|
"bound_issuer": "https://team-vault.auth0.com/",
|
|
"jwt_validation_pubkeys": agent.TestECDSAPubKey,
|
|
"jwt_supported_algs": "ES256",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
|
|
"role_type": "jwt",
|
|
"bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
|
|
"bound_audiences": "https://vault.plugin.auth.jwt.test",
|
|
"user_claim": "https://vault/user",
|
|
"groups_claim": "https://vault/groups",
|
|
"policies": "test",
|
|
"period": "3s",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
dir := t.TempDir()
|
|
inf, err := os.CreateTemp(dir, "auth.jwt.test.")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
in := inf.Name()
|
|
inf.Close()
|
|
// We remove these files in this test since we don't need the files, we just need
|
|
// a non-conflicting file name for the config.
|
|
os.Remove(in)
|
|
t.Logf("input: %s", in)
|
|
|
|
sink1f, err := os.CreateTemp(dir, "sink1.jwt.test.")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
sink1 := sink1f.Name()
|
|
sink1f.Close()
|
|
os.Remove(sink1)
|
|
t.Logf("sink1: %s", sink1)
|
|
|
|
sink2f, err := os.CreateTemp(dir, "sink2.jwt.test.")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
sink2 := sink2f.Name()
|
|
sink2f.Close()
|
|
os.Remove(sink2)
|
|
t.Logf("sink2: %s", sink2)
|
|
|
|
conff, err := os.CreateTemp(dir, "conf.jwt.test.")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
conf := conff.Name()
|
|
conff.Close()
|
|
os.Remove(conf)
|
|
t.Logf("config: %s", conf)
|
|
|
|
jwtToken, _ := agent.GetTestJWT(t)
|
|
if err := os.WriteFile(in, []byte(jwtToken), 0o600); err != nil {
|
|
t.Fatal(err)
|
|
} else {
|
|
logger.Trace("wrote test jwt", "path", in)
|
|
}
|
|
|
|
exitAfterAuthTemplText := "exit_after_auth = true"
|
|
if viaFlag {
|
|
exitAfterAuthTemplText = ""
|
|
}
|
|
|
|
config := `
|
|
%s
|
|
|
|
auto_auth {
|
|
method {
|
|
type = "jwt"
|
|
config = {
|
|
role = "test"
|
|
path = "%s"
|
|
}
|
|
}
|
|
|
|
sink {
|
|
type = "file"
|
|
config = {
|
|
path = "%s"
|
|
}
|
|
}
|
|
|
|
sink "file" {
|
|
config = {
|
|
path = "%s"
|
|
}
|
|
}
|
|
}
|
|
`
|
|
|
|
config = fmt.Sprintf(config, exitAfterAuthTemplText, in, sink1, sink2)
|
|
if err := os.WriteFile(conf, []byte(config), 0o600); err != nil {
|
|
t.Fatal(err)
|
|
} else {
|
|
logger.Trace("wrote test config", "path", conf)
|
|
}
|
|
|
|
doneCh := make(chan struct{})
|
|
go func() {
|
|
ui, cmd := testProxyCommand(t, logger)
|
|
cmd.client = client
|
|
|
|
args := []string{"-config", conf}
|
|
if viaFlag {
|
|
args = append(args, "-exit-after-auth")
|
|
}
|
|
|
|
code := cmd.Run(args)
|
|
if code != 0 {
|
|
t.Errorf("expected %d to be %d", code, 0)
|
|
t.Logf("output from proxy:\n%s", ui.OutputWriter.String())
|
|
t.Logf("error from proxy:\n%s", ui.ErrorWriter.String())
|
|
}
|
|
close(doneCh)
|
|
}()
|
|
|
|
select {
|
|
case <-doneCh:
|
|
break
|
|
case <-time.After(1 * time.Minute):
|
|
t.Fatal("timeout reached while waiting for proxy to exit")
|
|
}
|
|
|
|
sink1Bytes, err := os.ReadFile(sink1)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(sink1Bytes) == 0 {
|
|
t.Fatal("got no output from sink 1")
|
|
}
|
|
|
|
sink2Bytes, err := os.ReadFile(sink2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(sink2Bytes) == 0 {
|
|
t.Fatal("got no output from sink 2")
|
|
}
|
|
|
|
if string(sink1Bytes) != string(sink2Bytes) {
|
|
t.Fatal("sink 1/2 values don't match")
|
|
}
|
|
}
|
|
|
|
// TestProxy_AutoAuth_UserAgent tests that the User-Agent sent
|
|
// to Vault by Vault Proxy is correct when performing Auto-Auth.
|
|
// Uses the custom handler userAgentHandler (defined above) so
|
|
// that Vault validates the User-Agent on requests sent by Proxy.
|
|
func TestProxy_AutoAuth_UserAgent(t *testing.T) {
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
var h userAgentHandler
|
|
cluster := vault.NewTestCluster(t, &vault.CoreConfig{
|
|
Logger: logger,
|
|
CredentialBackends: map[string]logical.Factory{
|
|
"approle": credAppRole.Factory,
|
|
},
|
|
}, &vault.TestClusterOptions{
|
|
NumCores: 1,
|
|
HandlerFunc: vaulthttp.HandlerFunc(
|
|
func(properties *vault.HandlerProperties) http.Handler {
|
|
h.props = properties
|
|
h.userAgentToCheckFor = useragent.ProxyAutoAuthString()
|
|
h.requestMethodToCheck = "PUT"
|
|
h.pathToCheck = "auth/approle/login"
|
|
h.t = t
|
|
return &h
|
|
}),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
|
|
serverClient := cluster.Cores[0].Client
|
|
|
|
// Enable the approle auth method
|
|
req := serverClient.NewRequest("POST", "/v1/sys/auth/approle")
|
|
req.BodyBytes = []byte(`{
|
|
"type": "approle"
|
|
}`)
|
|
request(t, serverClient, req, 204)
|
|
|
|
// Create a named role
|
|
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role")
|
|
req.BodyBytes = []byte(`{
|
|
"secret_id_num_uses": "10",
|
|
"secret_id_ttl": "1m",
|
|
"token_max_ttl": "1m",
|
|
"token_num_uses": "10",
|
|
"token_ttl": "1m",
|
|
"policies": "default"
|
|
}`)
|
|
request(t, serverClient, req, 204)
|
|
|
|
// Fetch the RoleID of the named role
|
|
req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id")
|
|
body := request(t, serverClient, req, 200)
|
|
data := body["data"].(map[string]interface{})
|
|
roleID := data["role_id"].(string)
|
|
|
|
// Get a SecretID issued against the named role
|
|
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id")
|
|
body = request(t, serverClient, req, 200)
|
|
data = body["data"].(map[string]interface{})
|
|
secretID := data["secret_id"].(string)
|
|
|
|
// Write the RoleID and SecretID to temp files
|
|
roleIDPath := makeTempFile(t, "role_id.txt", roleID+"\n")
|
|
secretIDPath := makeTempFile(t, "secret_id.txt", secretID+"\n")
|
|
defer os.Remove(roleIDPath)
|
|
defer os.Remove(secretIDPath)
|
|
|
|
sinkf, err := os.CreateTemp("", "sink.test.")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
sink := sinkf.Name()
|
|
sinkf.Close()
|
|
os.Remove(sink)
|
|
|
|
autoAuthConfig := fmt.Sprintf(`
|
|
auto_auth {
|
|
method "approle" {
|
|
mount_path = "auth/approle"
|
|
config = {
|
|
role_id_file_path = "%s"
|
|
secret_id_file_path = "%s"
|
|
}
|
|
}
|
|
|
|
sink "file" {
|
|
config = {
|
|
path = "%s"
|
|
}
|
|
}
|
|
}`, roleIDPath, secretIDPath, sink)
|
|
|
|
listenAddr := generateListenerAddress(t)
|
|
listenConfig := fmt.Sprintf(`
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
}
|
|
`, listenAddr)
|
|
|
|
config := fmt.Sprintf(`
|
|
vault {
|
|
address = "%s"
|
|
tls_skip_verify = true
|
|
}
|
|
api_proxy {
|
|
use_auto_auth_token = true
|
|
}
|
|
%s
|
|
%s
|
|
`, serverClient.Address(), listenConfig, autoAuthConfig)
|
|
configPath := makeTempFile(t, "config.hcl", config)
|
|
defer os.Remove(configPath)
|
|
|
|
// Unset the environment variable so that proxy picks up the right test
|
|
// cluster address
|
|
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
|
os.Unsetenv(api.EnvVaultAddress)
|
|
|
|
// Start proxy
|
|
_, cmd := testProxyCommand(t, logger)
|
|
cmd.startedCh = make(chan struct{})
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
cmd.Run([]string{"-config", configPath})
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
|
|
// Validate that the auto-auth token has been correctly attained
|
|
// and works for LookupSelf
|
|
conf := api.DefaultConfig()
|
|
conf.Address = "http://" + listenAddr
|
|
proxyClient, err := api.NewClient(conf)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
proxyClient.SetToken("")
|
|
err = proxyClient.SetAddress("http://" + listenAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Wait for the token to be sent to syncs and be available to be used
|
|
time.Sleep(5 * time.Second)
|
|
|
|
req = proxyClient.NewRequest("GET", "/v1/auth/token/lookup-self")
|
|
body = request(t, proxyClient, req, 200)
|
|
|
|
close(cmd.ShutdownCh)
|
|
wg.Wait()
|
|
}
|
|
|
|
// TestProxy_APIProxyWithoutCache_UserAgent tests that the User-Agent sent
|
|
// to Vault by Vault Proxy is correct using the API proxy without
|
|
// the cache configured. Uses the custom handler
|
|
// userAgentHandler struct defined in this test package, so that Vault validates the
|
|
// User-Agent on requests sent by Proxy.
|
|
func TestProxy_APIProxyWithoutCache_UserAgent(t *testing.T) {
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
userAgentForProxiedClient := "proxied-client"
|
|
var h userAgentHandler
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
NumCores: 1,
|
|
HandlerFunc: vaulthttp.HandlerFunc(
|
|
func(properties *vault.HandlerProperties) http.Handler {
|
|
h.props = properties
|
|
h.userAgentToCheckFor = useragent.ProxyStringWithProxiedUserAgent(userAgentForProxiedClient)
|
|
h.pathToCheck = "/v1/auth/token/lookup-self"
|
|
h.requestMethodToCheck = "GET"
|
|
h.t = t
|
|
return &h
|
|
}),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
|
|
serverClient := cluster.Cores[0].Client
|
|
|
|
// Unset the environment variable so that proxy picks up the right test
|
|
// cluster address
|
|
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
|
os.Unsetenv(api.EnvVaultAddress)
|
|
|
|
listenAddr := generateListenerAddress(t)
|
|
listenConfig := fmt.Sprintf(`
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
}
|
|
`, listenAddr)
|
|
|
|
config := fmt.Sprintf(`
|
|
vault {
|
|
address = "%s"
|
|
tls_skip_verify = true
|
|
}
|
|
%s
|
|
`, serverClient.Address(), listenConfig)
|
|
configPath := makeTempFile(t, "config.hcl", config)
|
|
defer os.Remove(configPath)
|
|
|
|
// Start the proxy
|
|
_, cmd := testProxyCommand(t, logger)
|
|
cmd.startedCh = make(chan struct{})
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
cmd.Run([]string{"-config", configPath})
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
|
|
proxyClient, err := api.NewClient(api.DefaultConfig())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxyClient.AddHeader("User-Agent", userAgentForProxiedClient)
|
|
proxyClient.SetToken(serverClient.Token())
|
|
proxyClient.SetMaxRetries(0)
|
|
err = proxyClient.SetAddress("http://" + listenAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = proxyClient.Auth().Token().LookupSelf()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
close(cmd.ShutdownCh)
|
|
wg.Wait()
|
|
}
|
|
|
|
// TestProxy_APIProxyWithCache_UserAgent tests that the User-Agent sent
|
|
// to Vault by Vault Proxy is correct using the API proxy with
|
|
// the cache configured. Uses the custom handler
|
|
// userAgentHandler struct defined in this test package, so that Vault validates the
|
|
// User-Agent on requests sent by Proxy.
|
|
func TestProxy_APIProxyWithCache_UserAgent(t *testing.T) {
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
userAgentForProxiedClient := "proxied-client"
|
|
var h userAgentHandler
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
NumCores: 1,
|
|
HandlerFunc: vaulthttp.HandlerFunc(
|
|
func(properties *vault.HandlerProperties) http.Handler {
|
|
h.props = properties
|
|
h.userAgentToCheckFor = useragent.ProxyStringWithProxiedUserAgent(userAgentForProxiedClient)
|
|
h.pathToCheck = "/v1/auth/token/lookup-self"
|
|
h.requestMethodToCheck = "GET"
|
|
h.t = t
|
|
return &h
|
|
}),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
|
|
serverClient := cluster.Cores[0].Client
|
|
|
|
// Unset the environment variable so that proxy picks up the right test
|
|
// cluster address
|
|
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
|
os.Unsetenv(api.EnvVaultAddress)
|
|
|
|
listenAddr := generateListenerAddress(t)
|
|
listenConfig := fmt.Sprintf(`
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
}
|
|
`, listenAddr)
|
|
|
|
cacheConfig := `
|
|
cache {
|
|
}`
|
|
|
|
config := fmt.Sprintf(`
|
|
vault {
|
|
address = "%s"
|
|
tls_skip_verify = true
|
|
}
|
|
%s
|
|
%s
|
|
`, serverClient.Address(), listenConfig, cacheConfig)
|
|
configPath := makeTempFile(t, "config.hcl", config)
|
|
defer os.Remove(configPath)
|
|
|
|
// Start the proxy
|
|
_, cmd := testProxyCommand(t, logger)
|
|
cmd.startedCh = make(chan struct{})
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
cmd.Run([]string{"-config", configPath})
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
|
|
proxyClient, err := api.NewClient(api.DefaultConfig())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxyClient.AddHeader("User-Agent", userAgentForProxiedClient)
|
|
proxyClient.SetToken(serverClient.Token())
|
|
proxyClient.SetMaxRetries(0)
|
|
err = proxyClient.SetAddress("http://" + listenAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = proxyClient.Auth().Token().LookupSelf()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
close(cmd.ShutdownCh)
|
|
wg.Wait()
|
|
}
|
|
|
|
// TestProxy_Cache_DynamicSecret Tests that the cache successfully caches a dynamic secret
|
|
// going through the Proxy,
|
|
func TestProxy_Cache_DynamicSecret(t *testing.T) {
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
|
|
HandlerFunc: vaulthttp.Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
|
|
serverClient := cluster.Cores[0].Client
|
|
|
|
// Unset the environment variable so that proxy picks up the right test
|
|
// cluster address
|
|
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
|
os.Unsetenv(api.EnvVaultAddress)
|
|
|
|
cacheConfig := `
|
|
cache {
|
|
}
|
|
`
|
|
listenAddr := generateListenerAddress(t)
|
|
listenConfig := fmt.Sprintf(`
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
}
|
|
`, listenAddr)
|
|
|
|
config := fmt.Sprintf(`
|
|
vault {
|
|
address = "%s"
|
|
tls_skip_verify = true
|
|
}
|
|
%s
|
|
%s
|
|
`, serverClient.Address(), cacheConfig, listenConfig)
|
|
configPath := makeTempFile(t, "config.hcl", config)
|
|
defer os.Remove(configPath)
|
|
|
|
// Start proxy
|
|
_, cmd := testProxyCommand(t, logger)
|
|
cmd.startedCh = make(chan struct{})
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
cmd.Run([]string{"-config", configPath})
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
|
|
proxyClient, err := api.NewClient(api.DefaultConfig())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
proxyClient.SetToken(serverClient.Token())
|
|
proxyClient.SetMaxRetries(0)
|
|
err = proxyClient.SetAddress("http://" + listenAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
renewable := true
|
|
tokenCreateRequest := &api.TokenCreateRequest{
|
|
Policies: []string{"default"},
|
|
TTL: "30m",
|
|
Renewable: &renewable,
|
|
}
|
|
|
|
// This was the simplest test I could find to trigger the caching behaviour,
|
|
// i.e. the most concise I could make the test that I can tell
|
|
// creating an orphan token returns Auth, is renewable, and isn't a token
|
|
// that's managed elsewhere (since it's an orphan)
|
|
secret, err := proxyClient.Auth().Token().CreateOrphan(tokenCreateRequest)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret == nil || secret.Auth == nil {
|
|
t.Fatalf("secret not as expected: %v", secret)
|
|
}
|
|
|
|
token := secret.Auth.ClientToken
|
|
|
|
secret, err = proxyClient.Auth().Token().CreateOrphan(tokenCreateRequest)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if secret == nil || secret.Auth == nil {
|
|
t.Fatalf("secret not as expected: %v", secret)
|
|
}
|
|
|
|
token2 := secret.Auth.ClientToken
|
|
|
|
if token != token2 {
|
|
t.Fatalf("token create response not cached when it should have been, as tokens differ")
|
|
}
|
|
|
|
close(cmd.ShutdownCh)
|
|
wg.Wait()
|
|
}
|
|
|
|
// TestProxy_ApiProxy_Retry Tests the retry functionalities of Vault Proxy's API Proxy
|
|
func TestProxy_ApiProxy_Retry(t *testing.T) {
|
|
//----------------------------------------------------
|
|
// Start the server and proxy
|
|
//----------------------------------------------------
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
var h handler
|
|
cluster := vault.NewTestCluster(t,
|
|
&vault.CoreConfig{
|
|
Logger: logger,
|
|
CredentialBackends: map[string]logical.Factory{
|
|
"approle": credAppRole.Factory,
|
|
},
|
|
LogicalBackends: map[string]logical.Factory{
|
|
"kv": logicalKv.Factory,
|
|
},
|
|
},
|
|
&vault.TestClusterOptions{
|
|
NumCores: 1,
|
|
HandlerFunc: vaulthttp.HandlerFunc(func(properties *vault.HandlerProperties) http.Handler {
|
|
h.props = properties
|
|
h.t = t
|
|
return &h
|
|
}),
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
|
|
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
|
serverClient := cluster.Cores[0].Client
|
|
|
|
// Unset the environment variable so that proxy picks up the right test
|
|
// cluster address
|
|
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
|
os.Unsetenv(api.EnvVaultAddress)
|
|
|
|
_, err := serverClient.Logical().Write("secret/foo", map[string]interface{}{
|
|
"bar": "baz",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
intRef := func(i int) *int {
|
|
return &i
|
|
}
|
|
// start test cases here
|
|
testCases := map[string]struct {
|
|
retries *int
|
|
expectError bool
|
|
}{
|
|
"none": {
|
|
retries: intRef(-1),
|
|
expectError: true,
|
|
},
|
|
"one": {
|
|
retries: intRef(1),
|
|
expectError: true,
|
|
},
|
|
"two": {
|
|
retries: intRef(2),
|
|
expectError: false,
|
|
},
|
|
"missing": {
|
|
retries: nil,
|
|
expectError: false,
|
|
},
|
|
"default": {
|
|
retries: intRef(0),
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for tcname, tc := range testCases {
|
|
t.Run(tcname, func(t *testing.T) {
|
|
h.failCount = 2
|
|
|
|
cacheConfig := `
|
|
cache {
|
|
}
|
|
`
|
|
listenAddr := generateListenerAddress(t)
|
|
listenConfig := fmt.Sprintf(`
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
}
|
|
`, listenAddr)
|
|
|
|
var retryConf string
|
|
if tc.retries != nil {
|
|
retryConf = fmt.Sprintf("retry { num_retries = %d }", *tc.retries)
|
|
}
|
|
|
|
config := fmt.Sprintf(`
|
|
vault {
|
|
address = "%s"
|
|
%s
|
|
tls_skip_verify = true
|
|
}
|
|
%s
|
|
%s
|
|
`, serverClient.Address(), retryConf, cacheConfig, listenConfig)
|
|
configPath := makeTempFile(t, "config.hcl", config)
|
|
defer os.Remove(configPath)
|
|
|
|
_, cmd := testProxyCommand(t, logger)
|
|
cmd.startedCh = make(chan struct{})
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
cmd.Run([]string{"-config", configPath})
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
|
|
client, err := api.NewClient(api.DefaultConfig())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client.SetToken(serverClient.Token())
|
|
client.SetMaxRetries(0)
|
|
err = client.SetAddress("http://" + listenAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
secret, err := client.Logical().Read("secret/foo")
|
|
switch {
|
|
case (err != nil || secret == nil) && tc.expectError:
|
|
case (err == nil || secret != nil) && !tc.expectError:
|
|
default:
|
|
t.Fatalf("%s expectError=%v error=%v secret=%v", tcname, tc.expectError, err, secret)
|
|
}
|
|
if secret != nil && secret.Data["foo"] != nil {
|
|
val := secret.Data["foo"].(map[string]interface{})
|
|
if !reflect.DeepEqual(val, map[string]interface{}{"bar": "baz"}) {
|
|
t.Fatalf("expected key 'foo' to yield bar=baz, got: %v", val)
|
|
}
|
|
}
|
|
time.Sleep(time.Second)
|
|
|
|
close(cmd.ShutdownCh)
|
|
wg.Wait()
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestProxy_Metrics tests that metrics are being properly reported.
|
|
func TestProxy_Metrics(t *testing.T) {
|
|
// Start a vault server
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
cluster := vault.NewTestCluster(t,
|
|
&vault.CoreConfig{
|
|
Logger: logger,
|
|
},
|
|
&vault.TestClusterOptions{
|
|
HandlerFunc: vaulthttp.Handler,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
|
serverClient := cluster.Cores[0].Client
|
|
|
|
// Create a config file
|
|
listenAddr := generateListenerAddress(t)
|
|
config := fmt.Sprintf(`
|
|
cache {}
|
|
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
}
|
|
`, listenAddr)
|
|
configPath := makeTempFile(t, "config.hcl", config)
|
|
defer os.Remove(configPath)
|
|
|
|
ui, cmd := testProxyCommand(t, logger)
|
|
cmd.client = serverClient
|
|
cmd.startedCh = make(chan struct{})
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
code := cmd.Run([]string{"-config", configPath})
|
|
if code != 0 {
|
|
t.Errorf("non-zero return code when running proxy: %d", code)
|
|
t.Logf("STDOUT from proxy:\n%s", ui.OutputWriter.String())
|
|
t.Logf("STDERR from proxy:\n%s", ui.ErrorWriter.String())
|
|
}
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
|
|
// defer proxy shutdown
|
|
defer func() {
|
|
cmd.ShutdownCh <- struct{}{}
|
|
wg.Wait()
|
|
}()
|
|
|
|
conf := api.DefaultConfig()
|
|
conf.Address = "http://" + listenAddr
|
|
proxyClient, err := api.NewClient(conf)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
req := proxyClient.NewRequest("GET", "/proxy/v1/metrics")
|
|
body := request(t, proxyClient, req, 200)
|
|
keys := []string{}
|
|
for k := range body {
|
|
keys = append(keys, k)
|
|
}
|
|
require.ElementsMatch(t, keys, []string{
|
|
"Counters",
|
|
"Samples",
|
|
"Timestamp",
|
|
"Gauges",
|
|
"Points",
|
|
})
|
|
}
|
|
|
|
// TestProxy_QuitAPI Tests the /proxy/v1/quit API that can be enabled for the proxy.
|
|
func TestProxy_QuitAPI(t *testing.T) {
|
|
logger := logging.NewVaultLogger(hclog.Error)
|
|
cluster := vault.NewTestCluster(t,
|
|
&vault.CoreConfig{
|
|
Logger: logger,
|
|
CredentialBackends: map[string]logical.Factory{
|
|
"approle": credAppRole.Factory,
|
|
},
|
|
LogicalBackends: map[string]logical.Factory{
|
|
"kv": logicalKv.Factory,
|
|
},
|
|
},
|
|
&vault.TestClusterOptions{
|
|
NumCores: 1,
|
|
})
|
|
cluster.Start()
|
|
defer cluster.Cleanup()
|
|
|
|
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
|
serverClient := cluster.Cores[0].Client
|
|
|
|
// Unset the environment variable so that proxy picks up the right test
|
|
// cluster address
|
|
defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
|
|
err := os.Unsetenv(api.EnvVaultAddress)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
listenAddr := generateListenerAddress(t)
|
|
listenAddr2 := generateListenerAddress(t)
|
|
config := fmt.Sprintf(`
|
|
vault {
|
|
address = "%s"
|
|
tls_skip_verify = true
|
|
}
|
|
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
}
|
|
|
|
listener "tcp" {
|
|
address = "%s"
|
|
tls_disable = true
|
|
proxy_api {
|
|
enable_quit = true
|
|
}
|
|
}
|
|
|
|
cache {}
|
|
`, serverClient.Address(), listenAddr, listenAddr2)
|
|
|
|
configPath := makeTempFile(t, "config.hcl", config)
|
|
defer os.Remove(configPath)
|
|
|
|
_, cmd := testProxyCommand(t, logger)
|
|
cmd.startedCh = make(chan struct{})
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
cmd.Run([]string{"-config", configPath})
|
|
wg.Done()
|
|
}()
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
client, err := api.NewClient(api.DefaultConfig())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client.SetToken(serverClient.Token())
|
|
client.SetMaxRetries(0)
|
|
err = client.SetAddress("http://" + listenAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// First try on listener 1 where the API should be disabled.
|
|
resp, err := client.RawRequest(client.NewRequest(http.MethodPost, "/proxy/v1/quit"))
|
|
if err == nil {
|
|
t.Fatalf("expected error")
|
|
}
|
|
if resp != nil && resp.StatusCode != http.StatusNotFound {
|
|
t.Fatalf("expected %d but got: %d", http.StatusNotFound, resp.StatusCode)
|
|
}
|
|
|
|
// Now try on listener 2 where the quit API should be enabled.
|
|
err = client.SetAddress("http://" + listenAddr2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = client.RawRequest(client.NewRequest(http.MethodPost, "/proxy/v1/quit"))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
|
|
select {
|
|
case <-cmd.ShutdownCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Errorf("timeout")
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// TestProxy_LogFile_CliOverridesConfig tests that the CLI values
|
|
// override the config for log files
|
|
func TestProxy_LogFile_CliOverridesConfig(t *testing.T) {
|
|
// Create basic config
|
|
configFile := populateTempFile(t, "proxy-config.hcl", BasicHclConfig)
|
|
cfg, err := proxyConfig.LoadConfigFile(configFile.Name())
|
|
if err != nil {
|
|
t.Fatal("Cannot load config to test update/merge", err)
|
|
}
|
|
|
|
// Sanity check that the config value is the current value
|
|
assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile)
|
|
|
|
// Initialize the command and parse any flags
|
|
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
|
|
f := cmd.Flags()
|
|
// Simulate the flag being specified
|
|
err = f.Parse([]string{"-log-file=/foo/bar/test.log"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Update the config based on the inputs.
|
|
cmd.applyConfigOverrides(f, cfg)
|
|
|
|
assert.NotEqual(t, "TMPDIR/juan.log", cfg.LogFile)
|
|
assert.NotEqual(t, "/squiggle/logs.txt", cfg.LogFile)
|
|
assert.Equal(t, "/foo/bar/test.log", cfg.LogFile)
|
|
}
|
|
|
|
// TestProxy_LogFile_Config tests log file config when loaded from config
|
|
func TestProxy_LogFile_Config(t *testing.T) {
|
|
configFile := populateTempFile(t, "proxy-config.hcl", BasicHclConfig)
|
|
|
|
cfg, err := proxyConfig.LoadConfigFile(configFile.Name())
|
|
if err != nil {
|
|
t.Fatal("Cannot load config to test update/merge", err)
|
|
}
|
|
|
|
// Sanity check that the config value is the current value
|
|
assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "sanity check on log config failed")
|
|
assert.Equal(t, 2, cfg.LogRotateMaxFiles)
|
|
assert.Equal(t, 1048576, cfg.LogRotateBytes)
|
|
|
|
// Parse the cli flags (but we pass in an empty slice)
|
|
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
|
|
f := cmd.Flags()
|
|
err = f.Parse([]string{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Should change nothing...
|
|
cmd.applyConfigOverrides(f, cfg)
|
|
|
|
assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "actual config check")
|
|
assert.Equal(t, 2, cfg.LogRotateMaxFiles)
|
|
assert.Equal(t, 1048576, cfg.LogRotateBytes)
|
|
}
|
|
|
|
// TestProxy_Config_NewLogger_Default Tests defaults for log level and
|
|
// specifically cmd.newLogger()
|
|
func TestProxy_Config_NewLogger_Default(t *testing.T) {
|
|
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
|
|
cmd.config = proxyConfig.NewConfig()
|
|
logger, err := cmd.newLogger()
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, logger)
|
|
assert.Equal(t, hclog.Info.String(), logger.GetLevel().String())
|
|
}
|
|
|
|
// TestProxy_Config_ReloadLogLevel Tests reloading updates the log
|
|
// level as expected.
|
|
func TestProxy_Config_ReloadLogLevel(t *testing.T) {
|
|
cmd := &ProxyCommand{BaseCommand: &BaseCommand{}}
|
|
var err error
|
|
tempDir := t.TempDir()
|
|
|
|
// Load an initial config
|
|
hcl := strings.ReplaceAll(BasicHclConfig, "TMPDIR", tempDir)
|
|
configFile := populateTempFile(t, "proxy-config.hcl", hcl)
|
|
cmd.config, err = proxyConfig.LoadConfigFile(configFile.Name())
|
|
if err != nil {
|
|
t.Fatal("Cannot load config to test update/merge", err)
|
|
}
|
|
|
|
// Tweak the loaded config to make sure we can put log files into a temp dir
|
|
// and systemd log attempts work fine, this would usually happen during Run.
|
|
cmd.logWriter = os.Stdout
|
|
cmd.logger, err = cmd.newLogger()
|
|
if err != nil {
|
|
t.Fatal("logger required for systemd log messages", err)
|
|
}
|
|
|
|
// Sanity check
|
|
assert.Equal(t, "warn", cmd.config.LogLevel)
|
|
|
|
// Load a new config
|
|
hcl = strings.ReplaceAll(BasicHclConfig2, "TMPDIR", tempDir)
|
|
configFile = populateTempFile(t, "proxy-config.hcl", hcl)
|
|
err = cmd.reloadConfig([]string{configFile.Name()})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "debug", cmd.config.LogLevel)
|
|
}
|
|
|
|
// TestProxy_Config_ReloadTls Tests that the TLS certs for the listener are
|
|
// correctly reloaded.
|
|
func TestProxy_Config_ReloadTls(t *testing.T) {
|
|
var wg sync.WaitGroup
|
|
wd, err := os.Getwd()
|
|
if err != nil {
|
|
t.Fatal("unable to get current working directory")
|
|
}
|
|
workingDir := filepath.Join(wd, "/proxy/test-fixtures/reload")
|
|
fooCert := "reload_foo.pem"
|
|
fooKey := "reload_foo.key"
|
|
|
|
barCert := "reload_bar.pem"
|
|
barKey := "reload_bar.key"
|
|
|
|
reloadCert := "reload_cert.pem"
|
|
reloadKey := "reload_key.pem"
|
|
caPem := "reload_ca.pem"
|
|
|
|
tempDir := t.TempDir()
|
|
|
|
// Set up initial 'foo' certs
|
|
inBytes, err := os.ReadFile(filepath.Join(workingDir, fooCert))
|
|
if err != nil {
|
|
t.Fatal("unable to read cert required for test", fooCert, err)
|
|
}
|
|
err = os.WriteFile(filepath.Join(tempDir, reloadCert), inBytes, 0o777)
|
|
if err != nil {
|
|
t.Fatal("unable to write temp cert required for test", reloadCert, err)
|
|
}
|
|
|
|
inBytes, err = os.ReadFile(filepath.Join(workingDir, fooKey))
|
|
if err != nil {
|
|
t.Fatal("unable to read cert key required for test", fooKey, err)
|
|
}
|
|
err = os.WriteFile(filepath.Join(tempDir, reloadKey), inBytes, 0o777)
|
|
if err != nil {
|
|
t.Fatal("unable to write temp cert key required for test", reloadKey, err)
|
|
}
|
|
|
|
inBytes, err = os.ReadFile(filepath.Join(workingDir, caPem))
|
|
if err != nil {
|
|
t.Fatal("unable to read CA pem required for test", caPem, err)
|
|
}
|
|
certPool := x509.NewCertPool()
|
|
ok := certPool.AppendCertsFromPEM(inBytes)
|
|
if !ok {
|
|
t.Fatal("not ok when appending CA cert")
|
|
}
|
|
|
|
replacedHcl := strings.ReplaceAll(BasicHclConfig, "TMPDIR", tempDir)
|
|
configFile := populateTempFile(t, "proxy-config.hcl", replacedHcl)
|
|
|
|
// Set up Proxy
|
|
logger := logging.NewVaultLogger(hclog.Trace)
|
|
ui, cmd := testProxyCommand(t, logger)
|
|
|
|
var output string
|
|
var code int
|
|
wg.Add(1)
|
|
args := []string{"-config", configFile.Name()}
|
|
go func() {
|
|
if code = cmd.Run(args); code != 0 {
|
|
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
|
|
}
|
|
wg.Done()
|
|
}()
|
|
|
|
testCertificateName := func(cn string) error {
|
|
conn, err := tls.Dial("tcp", "127.0.0.1:8100", &tls.Config{
|
|
RootCAs: certPool,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
if err = conn.Handshake(); err != nil {
|
|
return err
|
|
}
|
|
servName := conn.ConnectionState().PeerCertificates[0].Subject.CommonName
|
|
if servName != cn {
|
|
return fmt.Errorf("expected %s, got %s", cn, servName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Start
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if err := testCertificateName("foo.example.com"); err != nil {
|
|
t.Fatalf("certificate name didn't check out: %s", err)
|
|
}
|
|
|
|
// Swap out certs
|
|
inBytes, err = os.ReadFile(filepath.Join(workingDir, barCert))
|
|
if err != nil {
|
|
t.Fatal("unable to read cert required for test", barCert, err)
|
|
}
|
|
err = os.WriteFile(filepath.Join(tempDir, reloadCert), inBytes, 0o777)
|
|
if err != nil {
|
|
t.Fatal("unable to write temp cert required for test", reloadCert, err)
|
|
}
|
|
|
|
inBytes, err = os.ReadFile(filepath.Join(workingDir, barKey))
|
|
if err != nil {
|
|
t.Fatal("unable to read cert key required for test", barKey, err)
|
|
}
|
|
err = os.WriteFile(filepath.Join(tempDir, reloadKey), inBytes, 0o777)
|
|
if err != nil {
|
|
t.Fatal("unable to write temp cert key required for test", reloadKey, err)
|
|
}
|
|
|
|
// Reload
|
|
cmd.SighupCh <- struct{}{}
|
|
select {
|
|
case <-cmd.reloadedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if err := testCertificateName("bar.example.com"); err != nil {
|
|
t.Fatalf("certificate name didn't check out: %s", err)
|
|
}
|
|
|
|
// Shut down
|
|
cmd.ShutdownCh <- struct{}{}
|
|
wg.Wait()
|
|
|
|
if code != 0 {
|
|
t.Fatalf("got a non-zero exit status: %d, stdout/stderr: %s", code, output)
|
|
}
|
|
}
|