Refactor Code Focused on DevTLS Mode into New Function (#20376)

* refactor code focused on DevTLS mode into new function

* add tests for configureDevTLS function

* replace testcase comments with fields in testcase struct
This commit is contained in:
Marc Boudreau 2023-05-19 15:45:22 -04:00 committed by GitHub
parent 4330265469
commit 8928b30224
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 132 additions and 60 deletions

View file

@ -930,6 +930,69 @@ func (c *ServerCommand) InitListeners(config *server.Config, disableClustering b
return 0, lns, clusterAddrs, nil return 0, lns, clusterAddrs, nil
} }
func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) {
var devStorageType string
switch {
case c.flagDevConsul:
devStorageType = "consul"
case c.flagDevHA && c.flagDevTransactional:
devStorageType = "inmem_transactional_ha"
case !c.flagDevHA && c.flagDevTransactional:
devStorageType = "inmem_transactional"
case c.flagDevHA && !c.flagDevTransactional:
devStorageType = "inmem_ha"
default:
devStorageType = "inmem"
}
var certDir string
var err error
var config *server.Config
var f func()
if c.flagDevTLS {
if c.flagDevTLSCertDir != "" {
if _, err = os.Stat(c.flagDevTLSCertDir); err != nil {
return nil, nil, "", err
}
certDir = c.flagDevTLSCertDir
} else {
if certDir, err = os.MkdirTemp("", "vault-tls"); err != nil {
return nil, nil, certDir, err
}
}
config, err = server.DevTLSConfig(devStorageType, certDir)
f = func() {
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil {
c.UI.Error(err.Error())
}
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename)); err != nil {
c.UI.Error(err.Error())
}
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename)); err != nil {
c.UI.Error(err.Error())
}
// Only delete temp directories we made.
if c.flagDevTLSCertDir == "" {
if err := os.Remove(certDir); err != nil {
c.UI.Error(err.Error())
}
}
}
} else {
config, err = server.DevConfig(devStorageType)
}
return f, config, certDir, err
}
func (c *ServerCommand) Run(args []string) int { func (c *ServerCommand) Run(args []string) int {
f := c.Flags() f := c.Flags()
@ -970,68 +1033,11 @@ func (c *ServerCommand) Run(args []string) int {
// Load the configuration // Load the configuration
var config *server.Config var config *server.Config
var err error
var certDir string var certDir string
if c.flagDev { if c.flagDev {
var devStorageType string df, cfg, dir, err := configureDevTLS(c)
switch { if df != nil {
case c.flagDevConsul: defer df()
devStorageType = "consul"
case c.flagDevHA && c.flagDevTransactional:
devStorageType = "inmem_transactional_ha"
case !c.flagDevHA && c.flagDevTransactional:
devStorageType = "inmem_transactional"
case c.flagDevHA && !c.flagDevTransactional:
devStorageType = "inmem_ha"
default:
devStorageType = "inmem"
}
if c.flagDevTLS {
if c.flagDevTLSCertDir != "" {
_, err := os.Stat(c.flagDevTLSCertDir)
if err != nil {
c.UI.Error(err.Error())
return 1
}
certDir = c.flagDevTLSCertDir
} else {
certDir, err = os.MkdirTemp("", "vault-tls")
if err != nil {
c.UI.Error(err.Error())
return 1
}
}
config, err = server.DevTLSConfig(devStorageType, certDir)
defer func() {
err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename))
if err != nil {
c.UI.Error(err.Error())
}
err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename))
if err != nil {
c.UI.Error(err.Error())
}
err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename))
if err != nil {
c.UI.Error(err.Error())
}
// Only delete temp directories we made.
if c.flagDevTLSCertDir == "" {
err = os.Remove(certDir)
if err != nil {
c.UI.Error(err.Error())
}
}
}()
} else {
config, err = server.DevConfig(devStorageType)
} }
if err != nil { if err != nil {
@ -1039,6 +1045,9 @@ func (c *ServerCommand) Run(args []string) int {
return 1 return 1
} }
config = cfg
certDir = dir
if c.flagDevListenAddr != "" { if c.flagDevListenAddr != "" {
config.Listeners[0].Address = c.flagDevListenAddr config.Listeners[0].Address = c.flagDevListenAddr
} }

View file

@ -330,3 +330,66 @@ func TestServer_DevTLS(t *testing.T) {
require.Equal(t, 0, retCode, output) require.Equal(t, 0, retCode, output)
require.Contains(t, output, `tls: "enabled"`) require.Contains(t, output, `tls: "enabled"`)
} }
// TestConfigureDevTLS verifies the various logic paths that flow through the
// configureDevTLS function.
func TestConfigureDevTLS(t *testing.T) {
testcases := []struct {
ServerCommand *ServerCommand
DeferFuncNotNil bool
ConfigNotNil bool
TLSDisable bool
CertPathEmpty bool
ErrNotNil bool
TestDescription string
}{
{
ServerCommand: &ServerCommand{
flagDevTLS: false,
},
ConfigNotNil: true,
TLSDisable: true,
CertPathEmpty: true,
ErrNotNil: false,
TestDescription: "flagDev is false, nothing will be configured",
},
{
ServerCommand: &ServerCommand{
flagDevTLS: true,
flagDevTLSCertDir: "",
},
DeferFuncNotNil: true,
ConfigNotNil: true,
ErrNotNil: false,
TestDescription: "flagDevTLSCertDir is empty",
},
{
ServerCommand: &ServerCommand{
flagDevTLS: true,
flagDevTLSCertDir: "@/#",
},
CertPathEmpty: true,
ErrNotNil: true,
TestDescription: "flagDevTLSCertDir is set to something invalid",
},
}
for _, testcase := range testcases {
fun, cfg, certPath, err := configureDevTLS(testcase.ServerCommand)
if fun != nil {
// If a function is returned, call it right away to clean up
// files created in the temporary directory before anything else has
// a chance to fail this test.
fun()
}
require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription)
require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription)
if testcase.ConfigNotNil {
require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription)
require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription)
}
require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription)
require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription)
}
}