Refactor Code Focused on DevTLS Mode into New Function (#20376)
* refactor code focused on DevTLS mode into new function * add tests for configureDevTLS function * replace testcase comments with fields in testcase struct
This commit is contained in:
parent
4330265469
commit
8928b30224
|
@ -930,6 +930,69 @@ func (c *ServerCommand) InitListeners(config *server.Config, disableClustering b
|
||||||
return 0, lns, clusterAddrs, nil
|
return 0, lns, clusterAddrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) {
|
||||||
|
var devStorageType string
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case c.flagDevConsul:
|
||||||
|
devStorageType = "consul"
|
||||||
|
case c.flagDevHA && c.flagDevTransactional:
|
||||||
|
devStorageType = "inmem_transactional_ha"
|
||||||
|
case !c.flagDevHA && c.flagDevTransactional:
|
||||||
|
devStorageType = "inmem_transactional"
|
||||||
|
case c.flagDevHA && !c.flagDevTransactional:
|
||||||
|
devStorageType = "inmem_ha"
|
||||||
|
default:
|
||||||
|
devStorageType = "inmem"
|
||||||
|
}
|
||||||
|
|
||||||
|
var certDir string
|
||||||
|
var err error
|
||||||
|
var config *server.Config
|
||||||
|
var f func()
|
||||||
|
|
||||||
|
if c.flagDevTLS {
|
||||||
|
if c.flagDevTLSCertDir != "" {
|
||||||
|
if _, err = os.Stat(c.flagDevTLSCertDir); err != nil {
|
||||||
|
return nil, nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
certDir = c.flagDevTLSCertDir
|
||||||
|
} else {
|
||||||
|
if certDir, err = os.MkdirTemp("", "vault-tls"); err != nil {
|
||||||
|
return nil, nil, certDir, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config, err = server.DevTLSConfig(devStorageType, certDir)
|
||||||
|
|
||||||
|
f = func() {
|
||||||
|
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil {
|
||||||
|
c.UI.Error(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename)); err != nil {
|
||||||
|
c.UI.Error(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename)); err != nil {
|
||||||
|
c.UI.Error(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only delete temp directories we made.
|
||||||
|
if c.flagDevTLSCertDir == "" {
|
||||||
|
if err := os.Remove(certDir); err != nil {
|
||||||
|
c.UI.Error(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
config, err = server.DevConfig(devStorageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, config, certDir, err
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ServerCommand) Run(args []string) int {
|
func (c *ServerCommand) Run(args []string) int {
|
||||||
f := c.Flags()
|
f := c.Flags()
|
||||||
|
|
||||||
|
@ -970,68 +1033,11 @@ func (c *ServerCommand) Run(args []string) int {
|
||||||
|
|
||||||
// Load the configuration
|
// Load the configuration
|
||||||
var config *server.Config
|
var config *server.Config
|
||||||
var err error
|
|
||||||
var certDir string
|
var certDir string
|
||||||
if c.flagDev {
|
if c.flagDev {
|
||||||
var devStorageType string
|
df, cfg, dir, err := configureDevTLS(c)
|
||||||
switch {
|
if df != nil {
|
||||||
case c.flagDevConsul:
|
defer df()
|
||||||
devStorageType = "consul"
|
|
||||||
case c.flagDevHA && c.flagDevTransactional:
|
|
||||||
devStorageType = "inmem_transactional_ha"
|
|
||||||
case !c.flagDevHA && c.flagDevTransactional:
|
|
||||||
devStorageType = "inmem_transactional"
|
|
||||||
case c.flagDevHA && !c.flagDevTransactional:
|
|
||||||
devStorageType = "inmem_ha"
|
|
||||||
default:
|
|
||||||
devStorageType = "inmem"
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.flagDevTLS {
|
|
||||||
if c.flagDevTLSCertDir != "" {
|
|
||||||
_, err := os.Stat(c.flagDevTLSCertDir)
|
|
||||||
if err != nil {
|
|
||||||
c.UI.Error(err.Error())
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
certDir = c.flagDevTLSCertDir
|
|
||||||
} else {
|
|
||||||
certDir, err = os.MkdirTemp("", "vault-tls")
|
|
||||||
if err != nil {
|
|
||||||
c.UI.Error(err.Error())
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config, err = server.DevTLSConfig(devStorageType, certDir)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename))
|
|
||||||
if err != nil {
|
|
||||||
c.UI.Error(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename))
|
|
||||||
if err != nil {
|
|
||||||
c.UI.Error(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename))
|
|
||||||
if err != nil {
|
|
||||||
c.UI.Error(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only delete temp directories we made.
|
|
||||||
if c.flagDevTLSCertDir == "" {
|
|
||||||
err = os.Remove(certDir)
|
|
||||||
if err != nil {
|
|
||||||
c.UI.Error(err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
} else {
|
|
||||||
config, err = server.DevConfig(devStorageType)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1039,6 +1045,9 @@ func (c *ServerCommand) Run(args []string) int {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config = cfg
|
||||||
|
certDir = dir
|
||||||
|
|
||||||
if c.flagDevListenAddr != "" {
|
if c.flagDevListenAddr != "" {
|
||||||
config.Listeners[0].Address = c.flagDevListenAddr
|
config.Listeners[0].Address = c.flagDevListenAddr
|
||||||
}
|
}
|
||||||
|
|
|
@ -330,3 +330,66 @@ func TestServer_DevTLS(t *testing.T) {
|
||||||
require.Equal(t, 0, retCode, output)
|
require.Equal(t, 0, retCode, output)
|
||||||
require.Contains(t, output, `tls: "enabled"`)
|
require.Contains(t, output, `tls: "enabled"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestConfigureDevTLS verifies the various logic paths that flow through the
|
||||||
|
// configureDevTLS function.
|
||||||
|
func TestConfigureDevTLS(t *testing.T) {
|
||||||
|
testcases := []struct {
|
||||||
|
ServerCommand *ServerCommand
|
||||||
|
DeferFuncNotNil bool
|
||||||
|
ConfigNotNil bool
|
||||||
|
TLSDisable bool
|
||||||
|
CertPathEmpty bool
|
||||||
|
ErrNotNil bool
|
||||||
|
TestDescription string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
ServerCommand: &ServerCommand{
|
||||||
|
flagDevTLS: false,
|
||||||
|
},
|
||||||
|
ConfigNotNil: true,
|
||||||
|
TLSDisable: true,
|
||||||
|
CertPathEmpty: true,
|
||||||
|
ErrNotNil: false,
|
||||||
|
TestDescription: "flagDev is false, nothing will be configured",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ServerCommand: &ServerCommand{
|
||||||
|
flagDevTLS: true,
|
||||||
|
flagDevTLSCertDir: "",
|
||||||
|
},
|
||||||
|
DeferFuncNotNil: true,
|
||||||
|
ConfigNotNil: true,
|
||||||
|
ErrNotNil: false,
|
||||||
|
TestDescription: "flagDevTLSCertDir is empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ServerCommand: &ServerCommand{
|
||||||
|
flagDevTLS: true,
|
||||||
|
flagDevTLSCertDir: "@/#",
|
||||||
|
},
|
||||||
|
CertPathEmpty: true,
|
||||||
|
ErrNotNil: true,
|
||||||
|
TestDescription: "flagDevTLSCertDir is set to something invalid",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testcase := range testcases {
|
||||||
|
fun, cfg, certPath, err := configureDevTLS(testcase.ServerCommand)
|
||||||
|
if fun != nil {
|
||||||
|
// If a function is returned, call it right away to clean up
|
||||||
|
// files created in the temporary directory before anything else has
|
||||||
|
// a chance to fail this test.
|
||||||
|
fun()
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription)
|
||||||
|
require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription)
|
||||||
|
if testcase.ConfigNotNil {
|
||||||
|
require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription)
|
||||||
|
require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription)
|
||||||
|
}
|
||||||
|
require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription)
|
||||||
|
require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue