396 lines
10 KiB
Go
396 lines
10 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
//go:build !race && !hsm && !fips_140_3
|
|
|
|
// NOTE: we can't use this with HSM. We can't set testing mode on and it's not
|
|
// safe to use env vars since that provides an attack vector in the real world.
|
|
//
|
|
// The server tests have a go-metrics/exp manager race condition :(.
|
|
|
|
package command
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
|
|
"github.com/mitchellh/cli"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func init() {
|
|
if signed := os.Getenv("VAULT_LICENSE_CI"); signed != "" {
|
|
os.Setenv(EnvVaultLicense, signed)
|
|
}
|
|
}
|
|
|
|
func testBaseHCL(tb testing.TB, listenerExtras string) string {
|
|
tb.Helper()
|
|
|
|
return strings.TrimSpace(fmt.Sprintf(`
|
|
disable_mlock = true
|
|
listener "tcp" {
|
|
address = "127.0.0.1:%d"
|
|
tls_disable = "true"
|
|
%s
|
|
}
|
|
`, 0, listenerExtras))
|
|
}
|
|
|
|
const (
|
|
goodListenerTimeouts = `http_read_header_timeout = 12
|
|
http_read_timeout = "34s"
|
|
http_write_timeout = "56m"
|
|
http_idle_timeout = "78h"`
|
|
|
|
badListenerReadHeaderTimeout = `http_read_header_timeout = "12km"`
|
|
badListenerReadTimeout = `http_read_timeout = "34日"`
|
|
badListenerWriteTimeout = `http_write_timeout = "56lbs"`
|
|
badListenerIdleTimeout = `http_idle_timeout = "78gophers"`
|
|
|
|
inmemHCL = `
|
|
backend "inmem_ha" {
|
|
advertise_addr = "http://127.0.0.1:8200"
|
|
}
|
|
`
|
|
haInmemHCL = `
|
|
ha_backend "inmem_ha" {
|
|
redirect_addr = "http://127.0.0.1:8200"
|
|
}
|
|
`
|
|
|
|
badHAInmemHCL = `
|
|
ha_backend "inmem" {}
|
|
`
|
|
|
|
reloadHCL = `
|
|
backend "inmem" {}
|
|
disable_mlock = true
|
|
listener "tcp" {
|
|
address = "127.0.0.1:8203"
|
|
tls_cert_file = "TMPDIR/reload_cert.pem"
|
|
tls_key_file = "TMPDIR/reload_key.pem"
|
|
}
|
|
`
|
|
cloudHCL = `
|
|
cloud {
|
|
resource_id = "organization/bc58b3d0-2eab-4ab8-abf4-f61d3c9975ff/project/1c78e888-2142-4000-8918-f933bbbc7690/hashicorp.example.resource/example"
|
|
client_id = "J2TtcSYOyPUkPV2z0mSyDtvitxLVjJmu"
|
|
client_secret = "N9JtHZyOnHrIvJZs82pqa54vd4jnkyU3xCcqhFXuQKJZZuxqxxbP1xCfBZVB82vY"
|
|
}
|
|
`
|
|
)
|
|
|
|
func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) {
|
|
tb.Helper()
|
|
|
|
ui := cli.NewMockUi()
|
|
return ui, &ServerCommand{
|
|
BaseCommand: &BaseCommand{
|
|
UI: ui,
|
|
},
|
|
ShutdownCh: MakeShutdownCh(),
|
|
SighupCh: MakeSighupCh(),
|
|
SigUSR2Ch: MakeSigUSR2Ch(),
|
|
PhysicalBackends: map[string]physical.Factory{
|
|
"inmem": physInmem.NewInmem,
|
|
"inmem_ha": physInmem.NewInmemHA,
|
|
},
|
|
|
|
// These prevent us from random sleep guessing...
|
|
startedCh: make(chan struct{}, 5),
|
|
reloadedCh: make(chan struct{}, 5),
|
|
licenseReloadedCh: make(chan error),
|
|
}
|
|
}
|
|
|
|
func TestServer_ReloadListener(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
wd, _ := os.Getwd()
|
|
wd += "/server/test-fixtures/reload/"
|
|
|
|
td, err := ioutil.TempDir("", "vault-test-")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer os.RemoveAll(td)
|
|
|
|
wg := &sync.WaitGroup{}
|
|
// Setup initial certs
|
|
inBytes, _ := ioutil.ReadFile(wd + "reload_foo.pem")
|
|
ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0o777)
|
|
inBytes, _ = ioutil.ReadFile(wd + "reload_foo.key")
|
|
ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0o777)
|
|
|
|
relhcl := strings.ReplaceAll(reloadHCL, "TMPDIR", td)
|
|
ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0o777)
|
|
|
|
inBytes, _ = ioutil.ReadFile(wd + "reload_ca.pem")
|
|
certPool := x509.NewCertPool()
|
|
ok := certPool.AppendCertsFromPEM(inBytes)
|
|
if !ok {
|
|
t.Fatal("not ok when appending CA cert")
|
|
}
|
|
|
|
ui, cmd := testServerCommand(t)
|
|
_ = ui
|
|
|
|
wg.Add(1)
|
|
args := []string{"-config", td + "/reload.hcl"}
|
|
go func() {
|
|
if code := cmd.Run(args); code != 0 {
|
|
output := ui.ErrorWriter.String() + ui.OutputWriter.String()
|
|
t.Errorf("got a non-zero exit status: %s", output)
|
|
}
|
|
wg.Done()
|
|
}()
|
|
|
|
testCertificateName := func(cn string) error {
|
|
conn, err := tls.Dial("tcp", "127.0.0.1:8203", &tls.Config{
|
|
RootCAs: certPool,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
if err = conn.Handshake(); err != nil {
|
|
return err
|
|
}
|
|
servName := conn.ConnectionState().PeerCertificates[0].Subject.CommonName
|
|
if servName != cn {
|
|
return fmt.Errorf("expected %s, got %s", cn, servName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
select {
|
|
case <-cmd.startedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if err := testCertificateName("foo.example.com"); err != nil {
|
|
t.Fatalf("certificate name didn't check out: %s", err)
|
|
}
|
|
|
|
relhcl = strings.ReplaceAll(reloadHCL, "TMPDIR", td)
|
|
inBytes, _ = ioutil.ReadFile(wd + "reload_bar.pem")
|
|
ioutil.WriteFile(td+"/reload_cert.pem", inBytes, 0o777)
|
|
inBytes, _ = ioutil.ReadFile(wd + "reload_bar.key")
|
|
ioutil.WriteFile(td+"/reload_key.pem", inBytes, 0o777)
|
|
ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0o777)
|
|
|
|
cmd.SighupCh <- struct{}{}
|
|
select {
|
|
case <-cmd.reloadedCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if err := testCertificateName("bar.example.com"); err != nil {
|
|
t.Fatalf("certificate name didn't check out: %s", err)
|
|
}
|
|
|
|
cmd.ShutdownCh <- struct{}{}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
name string
|
|
contents string
|
|
exp string
|
|
code int
|
|
args []string
|
|
}{
|
|
{
|
|
"common_ha",
|
|
testBaseHCL(t, "") + inmemHCL,
|
|
"(HA available)",
|
|
0,
|
|
[]string{"-test-verify-only"},
|
|
},
|
|
{
|
|
"separate_ha",
|
|
testBaseHCL(t, "") + inmemHCL + haInmemHCL,
|
|
"HA Storage:",
|
|
0,
|
|
[]string{"-test-verify-only"},
|
|
},
|
|
{
|
|
"bad_separate_ha",
|
|
testBaseHCL(t, "") + inmemHCL + badHAInmemHCL,
|
|
"Specified HA storage does not support HA",
|
|
1,
|
|
[]string{"-test-verify-only"},
|
|
},
|
|
{
|
|
"good_listener_timeout_config",
|
|
testBaseHCL(t, goodListenerTimeouts) + inmemHCL,
|
|
"",
|
|
0,
|
|
[]string{"-test-server-config"},
|
|
},
|
|
{
|
|
"bad_listener_read_header_timeout_config",
|
|
testBaseHCL(t, badListenerReadHeaderTimeout) + inmemHCL,
|
|
"unknown unit \"km\" in duration \"12km\"",
|
|
1,
|
|
[]string{"-test-server-config"},
|
|
},
|
|
{
|
|
"bad_listener_read_timeout_config",
|
|
testBaseHCL(t, badListenerReadTimeout) + inmemHCL,
|
|
"unknown unit \"\\xe6\\x97\\xa5\" in duration",
|
|
1,
|
|
[]string{"-test-server-config"},
|
|
},
|
|
{
|
|
"bad_listener_write_timeout_config",
|
|
testBaseHCL(t, badListenerWriteTimeout) + inmemHCL,
|
|
"unknown unit \"lbs\" in duration \"56lbs\"",
|
|
1,
|
|
[]string{"-test-server-config"},
|
|
},
|
|
{
|
|
"bad_listener_idle_timeout_config",
|
|
testBaseHCL(t, badListenerIdleTimeout) + inmemHCL,
|
|
"unknown unit \"gophers\" in duration \"78gophers\"",
|
|
1,
|
|
[]string{"-test-server-config"},
|
|
},
|
|
{
|
|
"environment_variables_logged",
|
|
testBaseHCL(t, "") + inmemHCL,
|
|
"Environment Variables",
|
|
0,
|
|
[]string{"-test-verify-only"},
|
|
},
|
|
{
|
|
"cloud_config",
|
|
testBaseHCL(t, "") + inmemHCL + cloudHCL,
|
|
"HCP Organization: bc58b3d0-2eab-4ab8-abf4-f61d3c9975ff",
|
|
0,
|
|
[]string{"-test-verify-only"},
|
|
},
|
|
{
|
|
"recovery_mode",
|
|
testBaseHCL(t, "") + inmemHCL,
|
|
"",
|
|
0,
|
|
[]string{"-test-verify-only", "-recovery"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
tc := tc
|
|
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ui, cmd := testServerCommand(t)
|
|
|
|
f, err := os.CreateTemp(t.TempDir(), "")
|
|
require.NoErrorf(t, err, "error creating temp dir: %v", err)
|
|
|
|
_, err = f.WriteString(tc.contents)
|
|
require.NoErrorf(t, err, "cannot write temp file contents")
|
|
|
|
err = f.Close()
|
|
require.NoErrorf(t, err, "unable to close temp file")
|
|
|
|
args := append(tc.args, "-config", f.Name())
|
|
code := cmd.Run(args)
|
|
output := ui.ErrorWriter.String() + ui.OutputWriter.String()
|
|
require.Equal(t, tc.code, code, "expected %d to be %d: %s", code, tc.code, output)
|
|
require.Contains(t, output, tc.exp, "expected %q to contain %q", output, tc.exp)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestServer_DevTLS verifies that a vault server starts up correctly with the -dev-tls flag
|
|
func TestServer_DevTLS(t *testing.T) {
|
|
ui, cmd := testServerCommand(t)
|
|
args := []string{"-dev-tls", "-dev-listen-address=127.0.0.1:0", "-test-server-config"}
|
|
retCode := cmd.Run(args)
|
|
output := ui.ErrorWriter.String() + ui.OutputWriter.String()
|
|
require.Equal(t, 0, retCode, output)
|
|
require.Contains(t, output, `tls: "enabled"`)
|
|
}
|
|
|
|
// TestConfigureDevTLS verifies the various logic paths that flow through the
|
|
// configureDevTLS function.
|
|
func TestConfigureDevTLS(t *testing.T) {
|
|
testcases := []struct {
|
|
ServerCommand *ServerCommand
|
|
DeferFuncNotNil bool
|
|
ConfigNotNil bool
|
|
TLSDisable bool
|
|
CertPathEmpty bool
|
|
ErrNotNil bool
|
|
TestDescription string
|
|
}{
|
|
{
|
|
ServerCommand: &ServerCommand{
|
|
flagDevTLS: false,
|
|
},
|
|
ConfigNotNil: true,
|
|
TLSDisable: true,
|
|
CertPathEmpty: true,
|
|
ErrNotNil: false,
|
|
TestDescription: "flagDev is false, nothing will be configured",
|
|
},
|
|
{
|
|
ServerCommand: &ServerCommand{
|
|
flagDevTLS: true,
|
|
flagDevTLSCertDir: "",
|
|
},
|
|
DeferFuncNotNil: true,
|
|
ConfigNotNil: true,
|
|
ErrNotNil: false,
|
|
TestDescription: "flagDevTLSCertDir is empty",
|
|
},
|
|
{
|
|
ServerCommand: &ServerCommand{
|
|
flagDevTLS: true,
|
|
flagDevTLSCertDir: "@/#",
|
|
},
|
|
CertPathEmpty: true,
|
|
ErrNotNil: true,
|
|
TestDescription: "flagDevTLSCertDir is set to something invalid",
|
|
},
|
|
}
|
|
|
|
for _, testcase := range testcases {
|
|
fun, cfg, certPath, err := configureDevTLS(testcase.ServerCommand)
|
|
if fun != nil {
|
|
// If a function is returned, call it right away to clean up
|
|
// files created in the temporary directory before anything else has
|
|
// a chance to fail this test.
|
|
fun()
|
|
}
|
|
|
|
require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription)
|
|
require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription)
|
|
if testcase.ConfigNotNil {
|
|
require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription)
|
|
require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription)
|
|
}
|
|
require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription)
|
|
require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription)
|
|
}
|
|
}
|