From 8928b30224ee6c46aedcd74fbdcd91890d129aad Mon Sep 17 00:00:00 2001 From: Marc Boudreau Date: Fri, 19 May 2023 15:45:22 -0400 Subject: [PATCH] Refactor Code Focused on DevTLS Mode into New Function (#20376) * refactor code focused on DevTLS mode into new function * add tests for configureDevTLS function * replace testcase comments with fields in testcase struct --- command/server.go | 129 ++++++++++++++++++++++------------------- command/server_test.go | 63 ++++++++++++++++++++ 2 files changed, 132 insertions(+), 60 deletions(-) diff --git a/command/server.go b/command/server.go index f950cfef9..5d2b144ce 100644 --- a/command/server.go +++ b/command/server.go @@ -930,6 +930,69 @@ func (c *ServerCommand) InitListeners(config *server.Config, disableClustering b return 0, lns, clusterAddrs, nil } +func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) { + var devStorageType string + + switch { + case c.flagDevConsul: + devStorageType = "consul" + case c.flagDevHA && c.flagDevTransactional: + devStorageType = "inmem_transactional_ha" + case !c.flagDevHA && c.flagDevTransactional: + devStorageType = "inmem_transactional" + case c.flagDevHA && !c.flagDevTransactional: + devStorageType = "inmem_ha" + default: + devStorageType = "inmem" + } + + var certDir string + var err error + var config *server.Config + var f func() + + if c.flagDevTLS { + if c.flagDevTLSCertDir != "" { + if _, err = os.Stat(c.flagDevTLSCertDir); err != nil { + return nil, nil, "", err + } + + certDir = c.flagDevTLSCertDir + } else { + if certDir, err = os.MkdirTemp("", "vault-tls"); err != nil { + return nil, nil, certDir, err + } + } + config, err = server.DevTLSConfig(devStorageType, certDir) + + f = func() { + if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil { + c.UI.Error(err.Error()) + } + + if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename)); err != nil { + c.UI.Error(err.Error()) + } + + if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename)); err != nil { + c.UI.Error(err.Error()) + } + + // Only delete temp directories we made. + if c.flagDevTLSCertDir == "" { + if err := os.Remove(certDir); err != nil { + c.UI.Error(err.Error()) + } + } + } + + } else { + config, err = server.DevConfig(devStorageType) + } + + return f, config, certDir, err +} + func (c *ServerCommand) Run(args []string) int { f := c.Flags() @@ -970,68 +1033,11 @@ func (c *ServerCommand) Run(args []string) int { // Load the configuration var config *server.Config - var err error var certDir string if c.flagDev { - var devStorageType string - switch { - case c.flagDevConsul: - devStorageType = "consul" - case c.flagDevHA && c.flagDevTransactional: - devStorageType = "inmem_transactional_ha" - case !c.flagDevHA && c.flagDevTransactional: - devStorageType = "inmem_transactional" - case c.flagDevHA && !c.flagDevTransactional: - devStorageType = "inmem_ha" - default: - devStorageType = "inmem" - } - - if c.flagDevTLS { - if c.flagDevTLSCertDir != "" { - _, err := os.Stat(c.flagDevTLSCertDir) - if err != nil { - c.UI.Error(err.Error()) - return 1 - } - - certDir = c.flagDevTLSCertDir - } else { - certDir, err = os.MkdirTemp("", "vault-tls") - if err != nil { - c.UI.Error(err.Error()) - return 1 - } - } - config, err = server.DevTLSConfig(devStorageType, certDir) - - defer func() { - err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)) - if err != nil { - c.UI.Error(err.Error()) - } - - err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename)) - if err != nil { - c.UI.Error(err.Error()) - } - - err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename)) - if err != nil { - c.UI.Error(err.Error()) - } - - // Only delete temp directories we made. - if c.flagDevTLSCertDir == "" { - err = os.Remove(certDir) - if err != nil { - c.UI.Error(err.Error()) - } - } - }() - - } else { - config, err = server.DevConfig(devStorageType) + df, cfg, dir, err := configureDevTLS(c) + if df != nil { + defer df() } if err != nil { @@ -1039,6 +1045,9 @@ func (c *ServerCommand) Run(args []string) int { return 1 } + config = cfg + certDir = dir + if c.flagDevListenAddr != "" { config.Listeners[0].Address = c.flagDevListenAddr } diff --git a/command/server_test.go b/command/server_test.go index 610bc31e0..bfe5b14dd 100644 --- a/command/server_test.go +++ b/command/server_test.go @@ -330,3 +330,66 @@ func TestServer_DevTLS(t *testing.T) { require.Equal(t, 0, retCode, output) require.Contains(t, output, `tls: "enabled"`) } + +// TestConfigureDevTLS verifies the various logic paths that flow through the +// configureDevTLS function. +func TestConfigureDevTLS(t *testing.T) { + testcases := []struct { + ServerCommand *ServerCommand + DeferFuncNotNil bool + ConfigNotNil bool + TLSDisable bool + CertPathEmpty bool + ErrNotNil bool + TestDescription string + }{ + { + ServerCommand: &ServerCommand{ + flagDevTLS: false, + }, + ConfigNotNil: true, + TLSDisable: true, + CertPathEmpty: true, + ErrNotNil: false, + TestDescription: "flagDev is false, nothing will be configured", + }, + { + ServerCommand: &ServerCommand{ + flagDevTLS: true, + flagDevTLSCertDir: "", + }, + DeferFuncNotNil: true, + ConfigNotNil: true, + ErrNotNil: false, + TestDescription: "flagDevTLSCertDir is empty", + }, + { + ServerCommand: &ServerCommand{ + flagDevTLS: true, + flagDevTLSCertDir: "@/#", + }, + CertPathEmpty: true, + ErrNotNil: true, + TestDescription: "flagDevTLSCertDir is set to something invalid", + }, + } + + for _, testcase := range testcases { + fun, cfg, certPath, err := configureDevTLS(testcase.ServerCommand) + if fun != nil { + // If a function is returned, call it right away to clean up + // files created in the temporary directory before anything else has + // a chance to fail this test. + fun() + } + + require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription) + require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription) + if testcase.ConfigNotNil { + require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription) + require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription) + } + require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription) + require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription) + } +}