Refactor Code Focused on DevTLS Mode into New Function (#20376)
* refactor code focused on DevTLS mode into new function * add tests for configureDevTLS function * replace testcase comments with fields in testcase struct
This commit is contained in:
parent
4330265469
commit
8928b30224
|
@ -930,6 +930,69 @@ func (c *ServerCommand) InitListeners(config *server.Config, disableClustering b
|
|||
return 0, lns, clusterAddrs, nil
|
||||
}
|
||||
|
||||
func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) {
|
||||
var devStorageType string
|
||||
|
||||
switch {
|
||||
case c.flagDevConsul:
|
||||
devStorageType = "consul"
|
||||
case c.flagDevHA && c.flagDevTransactional:
|
||||
devStorageType = "inmem_transactional_ha"
|
||||
case !c.flagDevHA && c.flagDevTransactional:
|
||||
devStorageType = "inmem_transactional"
|
||||
case c.flagDevHA && !c.flagDevTransactional:
|
||||
devStorageType = "inmem_ha"
|
||||
default:
|
||||
devStorageType = "inmem"
|
||||
}
|
||||
|
||||
var certDir string
|
||||
var err error
|
||||
var config *server.Config
|
||||
var f func()
|
||||
|
||||
if c.flagDevTLS {
|
||||
if c.flagDevTLSCertDir != "" {
|
||||
if _, err = os.Stat(c.flagDevTLSCertDir); err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
certDir = c.flagDevTLSCertDir
|
||||
} else {
|
||||
if certDir, err = os.MkdirTemp("", "vault-tls"); err != nil {
|
||||
return nil, nil, certDir, err
|
||||
}
|
||||
}
|
||||
config, err = server.DevTLSConfig(devStorageType, certDir)
|
||||
|
||||
f = func() {
|
||||
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
|
||||
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename)); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
|
||||
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename)); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
|
||||
// Only delete temp directories we made.
|
||||
if c.flagDevTLSCertDir == "" {
|
||||
if err := os.Remove(certDir); err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
config, err = server.DevConfig(devStorageType)
|
||||
}
|
||||
|
||||
return f, config, certDir, err
|
||||
}
|
||||
|
||||
func (c *ServerCommand) Run(args []string) int {
|
||||
f := c.Flags()
|
||||
|
||||
|
@ -970,68 +1033,11 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
|
||||
// Load the configuration
|
||||
var config *server.Config
|
||||
var err error
|
||||
var certDir string
|
||||
if c.flagDev {
|
||||
var devStorageType string
|
||||
switch {
|
||||
case c.flagDevConsul:
|
||||
devStorageType = "consul"
|
||||
case c.flagDevHA && c.flagDevTransactional:
|
||||
devStorageType = "inmem_transactional_ha"
|
||||
case !c.flagDevHA && c.flagDevTransactional:
|
||||
devStorageType = "inmem_transactional"
|
||||
case c.flagDevHA && !c.flagDevTransactional:
|
||||
devStorageType = "inmem_ha"
|
||||
default:
|
||||
devStorageType = "inmem"
|
||||
}
|
||||
|
||||
if c.flagDevTLS {
|
||||
if c.flagDevTLSCertDir != "" {
|
||||
_, err := os.Stat(c.flagDevTLSCertDir)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
|
||||
certDir = c.flagDevTLSCertDir
|
||||
} else {
|
||||
certDir, err = os.MkdirTemp("", "vault-tls")
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
return 1
|
||||
}
|
||||
}
|
||||
config, err = server.DevTLSConfig(devStorageType, certDir)
|
||||
|
||||
defer func() {
|
||||
err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename))
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
|
||||
err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename))
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
|
||||
err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename))
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
|
||||
// Only delete temp directories we made.
|
||||
if c.flagDevTLSCertDir == "" {
|
||||
err = os.Remove(certDir)
|
||||
if err != nil {
|
||||
c.UI.Error(err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
} else {
|
||||
config, err = server.DevConfig(devStorageType)
|
||||
df, cfg, dir, err := configureDevTLS(c)
|
||||
if df != nil {
|
||||
defer df()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -1039,6 +1045,9 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
config = cfg
|
||||
certDir = dir
|
||||
|
||||
if c.flagDevListenAddr != "" {
|
||||
config.Listeners[0].Address = c.flagDevListenAddr
|
||||
}
|
||||
|
|
|
@ -330,3 +330,66 @@ func TestServer_DevTLS(t *testing.T) {
|
|||
require.Equal(t, 0, retCode, output)
|
||||
require.Contains(t, output, `tls: "enabled"`)
|
||||
}
|
||||
|
||||
// TestConfigureDevTLS verifies the various logic paths that flow through the
|
||||
// configureDevTLS function.
|
||||
func TestConfigureDevTLS(t *testing.T) {
|
||||
testcases := []struct {
|
||||
ServerCommand *ServerCommand
|
||||
DeferFuncNotNil bool
|
||||
ConfigNotNil bool
|
||||
TLSDisable bool
|
||||
CertPathEmpty bool
|
||||
ErrNotNil bool
|
||||
TestDescription string
|
||||
}{
|
||||
{
|
||||
ServerCommand: &ServerCommand{
|
||||
flagDevTLS: false,
|
||||
},
|
||||
ConfigNotNil: true,
|
||||
TLSDisable: true,
|
||||
CertPathEmpty: true,
|
||||
ErrNotNil: false,
|
||||
TestDescription: "flagDev is false, nothing will be configured",
|
||||
},
|
||||
{
|
||||
ServerCommand: &ServerCommand{
|
||||
flagDevTLS: true,
|
||||
flagDevTLSCertDir: "",
|
||||
},
|
||||
DeferFuncNotNil: true,
|
||||
ConfigNotNil: true,
|
||||
ErrNotNil: false,
|
||||
TestDescription: "flagDevTLSCertDir is empty",
|
||||
},
|
||||
{
|
||||
ServerCommand: &ServerCommand{
|
||||
flagDevTLS: true,
|
||||
flagDevTLSCertDir: "@/#",
|
||||
},
|
||||
CertPathEmpty: true,
|
||||
ErrNotNil: true,
|
||||
TestDescription: "flagDevTLSCertDir is set to something invalid",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testcase := range testcases {
|
||||
fun, cfg, certPath, err := configureDevTLS(testcase.ServerCommand)
|
||||
if fun != nil {
|
||||
// If a function is returned, call it right away to clean up
|
||||
// files created in the temporary directory before anything else has
|
||||
// a chance to fail this test.
|
||||
fun()
|
||||
}
|
||||
|
||||
require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription)
|
||||
require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription)
|
||||
if testcase.ConfigNotNil {
|
||||
require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription)
|
||||
require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription)
|
||||
}
|
||||
require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription)
|
||||
require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue