diff --git a/command/agent.go b/command/agent.go index 2d6dcd208..75ae5acdc 100644 --- a/command/agent.go +++ b/command/agent.go @@ -211,7 +211,7 @@ func (c *AgentCommand) Run(args []string) int { c.UI.Info("No auto_auth block found in config, the automatic authentication feature will not be started") } - c.updateConfig(f, config) // This only needs to happen on start-up to aggregate config from flags and env vars + c.applyConfigOverrides(f, config) // This only needs to happen on start-up to aggregate config from flags and env vars c.config = config l, err := c.newLogger() @@ -961,16 +961,16 @@ func (c *AgentCommand) Run(args []string) int { return exitCode } -// updateConfig ensures that the config object accurately reflects the desired +// applyConfigOverrides ensures that the config object accurately reflects the desired // settings as configured by the user. It applies the relevant config setting based // on the precedence (env var overrides file config, cli overrides env var). // It mutates the config object supplied. -func (c *AgentCommand) updateConfig(f *FlagSets, config *agentConfig.Config) { +func (c *AgentCommand) applyConfigOverrides(f *FlagSets, config *agentConfig.Config) { if config.Vault == nil { config.Vault = &agentConfig.Vault{} } - f.updateLogConfig(config.SharedConfig) + f.applyLogConfigOverrides(config.SharedConfig) f.Visit(func(fl *flag.Flag) { if fl.Name == flagNameAgentExitAfterAuth { @@ -1228,16 +1228,6 @@ func (c *AgentCommand) newLogger() (log.InterceptLogger, error) { errors = multierror.Append(errors, err) } - logRotateBytes, err := parseutil.ParseInt(c.config.LogRotateBytes) - if err != nil { - errors = multierror.Append(errors, err) - } - - logRotateMaxFiles, err := parseutil.ParseInt(c.config.LogRotateMaxFiles) - if err != nil { - errors = multierror.Append(errors, err) - } - if errors != nil { return nil, errors } @@ -1248,8 +1238,8 @@ func (c *AgentCommand) newLogger() (log.InterceptLogger, error) { LogFormat: logFormat, LogFilePath: c.config.LogFile, LogRotateDuration: logRotateDuration, - LogRotateBytes: int(logRotateBytes), - LogRotateMaxFiles: int(logRotateMaxFiles), + LogRotateBytes: c.config.LogRotateBytes, + LogRotateMaxFiles: c.config.LogRotateMaxFiles, } l, err := logging.Setup(logCfg, c.logWriter) diff --git a/command/agent_test.go b/command/agent_test.go index 8b918a6e6..e9d737628 100644 --- a/command/agent_test.go +++ b/command/agent_test.go @@ -38,6 +38,8 @@ const ( BasicHclConfig = ` log_file = "TMPDIR/juan.log" log_level="warn" +log_rotate_max_files=2 +log_rotate_bytes=1048576 vault { address = "http://127.0.0.1:8200" retry { @@ -54,6 +56,8 @@ listener "tcp" { BasicHclConfig2 = ` log_file = "TMPDIR/juan.log" log_level="debug" +log_rotate_max_files=-1 +log_rotate_bytes=1048576 vault { address = "http://127.0.0.1:8200" retry { @@ -2110,7 +2114,7 @@ func TestAgent_LogFile_CliOverridesConfig(t *testing.T) { } // Update the config based on the inputs. - cmd.updateConfig(f, cfg) + cmd.applyConfigOverrides(f, cfg) assert.NotEqual(t, "TMPDIR/juan.log", cfg.LogFile) assert.NotEqual(t, "/squiggle/logs.txt", cfg.LogFile) @@ -2127,6 +2131,8 @@ func TestAgent_LogFile_Config(t *testing.T) { // Sanity check that the config value is the current value assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "sanity check on log config failed") + assert.Equal(t, 2, cfg.LogRotateMaxFiles) + assert.Equal(t, 1048576, cfg.LogRotateBytes) // Parse the cli flags (but we pass in an empty slice) cmd := &AgentCommand{BaseCommand: &BaseCommand{}} @@ -2136,9 +2142,12 @@ func TestAgent_LogFile_Config(t *testing.T) { t.Fatal(err) } - cmd.updateConfig(f, cfg) + // Should change nothing... + cmd.applyConfigOverrides(f, cfg) assert.Equal(t, "TMPDIR/juan.log", cfg.LogFile, "actual config check") + assert.Equal(t, 2, cfg.LogRotateMaxFiles) + assert.Equal(t, 1048576, cfg.LogRotateBytes) } func TestAgent_Config_NewLogger_Default(t *testing.T) { diff --git a/command/base_flags.go b/command/base_flags.go index 5ec0af3cb..4865d7163 100644 --- a/command/base_flags.go +++ b/command/base_flags.go @@ -249,14 +249,14 @@ func (i *intValue) Set(s string) error { return err } if v >= math.MinInt && v <= math.MaxInt { - *i.target = int(v) + *i.target = v return nil } return fmt.Errorf("Incorrect conversion of a 64-bit integer to a lower bit size. Value %d is not within bounds for int32", v) } -func (i *intValue) Get() interface{} { return int(*i.target) } -func (i *intValue) String() string { return strconv.Itoa(int(*i.target)) } +func (i *intValue) Get() interface{} { return *i.target } +func (i *intValue) String() string { return strconv.Itoa(*i.target) } func (i *intValue) Example() string { return "int" } func (i *intValue) Hidden() bool { return i.hidden } diff --git a/command/log_flags.go b/command/log_flags.go index 8b5e8fef7..c00086740 100644 --- a/command/log_flags.go +++ b/command/log_flags.go @@ -3,7 +3,7 @@ package command import ( "flag" "os" - "strings" + "strconv" "github.com/hashicorp/vault/internalshared/configutil" "github.com/posener/complete" @@ -15,20 +15,18 @@ type logFlags struct { flagLogLevel string flagLogFormat string flagLogFile string - flagLogRotateBytes string + flagLogRotateBytes int flagLogRotateDuration string - flagLogRotateMaxFiles string + flagLogRotateMaxFiles int } -type provider = func(key string) (string, bool) - // valuesProvider has the intention of providing a way to supply a func with a // way to retrieve values for flags and environment variables without having to -// directly call a specific implementation. The reasoning for its existence is -// to facilitate testing. +// directly call a specific implementation. +// The reasoning for its existence is to facilitate testing. type valuesProvider struct { - flagProvider provider - envVarProvider provider + flagProvider func(string) (flag.Value, bool) + envVarProvider func(string) (string, bool) } // addLogFlags will add the set of 'log' related flags to a flag set. @@ -65,7 +63,7 @@ func (f *FlagSet) addLogFlags(l *logFlags) { Usage: "Path to the log file that Vault should use for logging", }) - f.StringVar(&StringVar{ + f.IntVar(&IntVar{ Name: flagNameLogRotateBytes, Target: &l.flagLogRotateBytes, Usage: "Number of bytes that should be written to a log before it needs to be rotated. " + @@ -79,23 +77,34 @@ func (f *FlagSet) addLogFlags(l *logFlags) { "Must be a duration value such as 30s", }) - f.StringVar(&StringVar{ + f.IntVar(&IntVar{ Name: flagNameLogRotateMaxFiles, Target: &l.flagLogRotateMaxFiles, Usage: "The maximum number of older log file archives to keep", }) } -// getValue will attempt to find the flag with the corresponding flag name (key) -// and return the value along with a bool representing whether of not the flag had been found/set. -func (f *FlagSets) getValue(flagName string) (string, bool) { - var result string +// envVarValue attempts to get a named value from the environment variables. +// The value will be returned as a string along with a boolean value indiciating +// to the caller whether the named env var existed. +func envVarValue(key string) (string, bool) { + if key == "" { + return "", false + } + return os.LookupEnv(key) +} + +// flagValue attempts to find the named flag in a set of FlagSets. +// The flag.Value is returned if it was specified, and the boolean value indicates +// to the caller if the flag was specified by the end user. +func (f *FlagSets) flagValue(flagName string) (flag.Value, bool) { + var result flag.Value var isFlagSpecified bool if f != nil { f.Visit(func(fl *flag.Flag) { if fl.Name == flagName { - result = fl.Value.String() + result = fl.Value isFlagSpecified = true } }) @@ -104,51 +113,63 @@ func (f *FlagSets) getValue(flagName string) (string, bool) { return result, isFlagSpecified } -// getAggregatedConfigValue uses the provided keys to check CLI flags and environment +// overrideValue uses the provided keys to check CLI flags and environment // variables for values that may be used to override any specified configuration. -// If nothing can be found in flags/env vars or config, the 'fallback' (default) value will be provided. -func (p *valuesProvider) getAggregatedConfigValue(flagKey, envVarKey, current, fallback string) string { +func (p *valuesProvider) overrideValue(flagKey, envVarKey string) (string, bool) { var result string - current = strings.TrimSpace(current) + found := true flg, flgFound := p.flagProvider(flagKey) env, envFound := p.envVarProvider(envVarKey) switch { case flgFound: - result = flg + result = flg.String() case envFound: - // Use value from env var result = env - case current != "": - // Use value from config - result = current default: - // Use the default value - result = fallback + found = false } - return result + return result, found } -// updateLogConfig will accept a shared config and specifically attempt to update the 'log' related config keys. -// For each 'log' key we aggregate file config/env vars and CLI flags to select the one with the highest precedence. +// applyLogConfigOverrides will accept a shared config and specifically attempt to update the 'log' related config keys. +// For each 'log' key, we aggregate file config, env vars and CLI flags to select the one with the highest precedence. // This method mutates the config object passed into it. -func (f *FlagSets) updateLogConfig(config *configutil.SharedConfig) { +func (f *FlagSets) applyLogConfigOverrides(config *configutil.SharedConfig) { p := &valuesProvider{ - flagProvider: func(key string) (string, bool) { return f.getValue(key) }, - envVarProvider: func(key string) (string, bool) { - if key == "" { - return "", false - } - return os.LookupEnv(key) - }, + flagProvider: f.flagValue, + envVarProvider: envVarValue, } - config.LogLevel = p.getAggregatedConfigValue(flagNameLogLevel, EnvVaultLogLevel, config.LogLevel, "info") - config.LogFormat = p.getAggregatedConfigValue(flagNameLogFormat, EnvVaultLogFormat, config.LogFormat, "") - config.LogFile = p.getAggregatedConfigValue(flagNameLogFile, "", config.LogFile, "") - config.LogRotateDuration = p.getAggregatedConfigValue(flagNameLogRotateDuration, "", config.LogRotateDuration, "") - config.LogRotateBytes = p.getAggregatedConfigValue(flagNameLogRotateBytes, "", config.LogRotateBytes, "") - config.LogRotateMaxFiles = p.getAggregatedConfigValue(flagNameLogRotateMaxFiles, "", config.LogRotateMaxFiles, "") + // Update log level + if val, found := p.overrideValue(flagNameLogLevel, EnvVaultLogLevel); found { + config.LogLevel = val + } + + // Update log format + if val, found := p.overrideValue(flagNameLogFormat, EnvVaultLogFormat); found { + config.LogFormat = val + } + + // Update log file name + if val, found := p.overrideValue(flagNameLogFile, ""); found { + config.LogFile = val + } + + // Update log rotation duration + if val, found := p.overrideValue(flagNameLogRotateDuration, ""); found { + config.LogRotateDuration = val + } + + // Update log max files + if val, found := p.overrideValue(flagNameLogRotateMaxFiles, ""); found { + config.LogRotateMaxFiles, _ = strconv.Atoi(val) + } + + // Update log rotation max bytes + if val, found := p.overrideValue(flagNameLogRotateBytes, ""); found { + config.LogRotateBytes, _ = strconv.Atoi(val) + } } diff --git a/command/log_flags_test.go b/command/log_flags_test.go index d4924f736..78ca51c4d 100644 --- a/command/log_flags_test.go +++ b/command/log_flags_test.go @@ -1,6 +1,7 @@ package command import ( + "flag" "testing" "github.com/stretchr/testify/assert" @@ -10,66 +11,81 @@ func TestLogFlags_ValuesProvider(t *testing.T) { cases := map[string]struct { flagKey string envVarKey string - current string - fallback string - want string + wantValue string + wantFound bool }{ - "only-fallback": { - flagKey: "invalid", - envVarKey: "invalid", - current: "", - fallback: "foo", - want: "foo", - }, - "only-config": { - flagKey: "invalid", - envVarKey: "invalid", - current: "bar", - fallback: "", - want: "bar", - }, "flag-missing": { flagKey: "invalid", envVarKey: "valid-env-var", - current: "my-config-value1", - fallback: "", - want: "envVarValue", + wantValue: "envVarValue", + wantFound: true, }, "envVar-missing": { flagKey: "valid-flag", envVarKey: "invalid", - current: "my-config-value1", - fallback: "", - want: "flagValue", + wantValue: "flagValue", + wantFound: true, }, "all-present": { flagKey: "valid-flag", envVarKey: "valid-env-var", - current: "my-config-value1", - fallback: "foo", - want: "flagValue", + wantValue: "flagValue", + wantFound: true, + }, + "all-missing": { + flagKey: "invalid", + envVarKey: "invalid", + wantValue: "", + wantFound: false, }, } - // Sneaky little fake provider - fakeProvider := func(key string) (string, bool) { - switch key { - case "valid-flag": - return "flagValue", true - case "valid-env-var": - return "envVarValue", true + // Sneaky little fake providers + flagFaker := func(key string) (flag.Value, bool) { + var result fakeFlag + var found bool + + if key == "valid-flag" { + result.Set("flagValue") + found = true } - return "", false + return &result, found + } + + envFaker := func(key string) (string, bool) { + var found bool + var result string + + if key == "valid-env-var" { + result = "envVarValue" + found = true + } + + return result, found } vp := valuesProvider{ - flagProvider: fakeProvider, - envVarProvider: fakeProvider, + flagProvider: flagFaker, + envVarProvider: envFaker, } - for _, tc := range cases { - got := vp.getAggregatedConfigValue(tc.flagKey, tc.envVarKey, tc.current, tc.fallback) - assert.Equal(t, tc.want, got) + for name, tc := range cases { + val, found := vp.overrideValue(tc.flagKey, tc.envVarKey) + assert.Equal(t, tc.wantFound, found, name) + assert.Equal(t, tc.wantValue, val, name) } } + +type fakeFlag struct { + value string +} + +func (v *fakeFlag) String() string { + return v.value +} + +func (v *fakeFlag) Set(raw string) error { + v.value = raw + return nil +} diff --git a/command/server.go b/command/server.go index ed194f07c..eff271e7e 100644 --- a/command/server.go +++ b/command/server.go @@ -433,7 +433,7 @@ func (c *ServerCommand) runRecoveryMode() int { } // Update the 'log' related aspects of shared config based on config/env var/cli - c.Flags().updateLogConfig(config.SharedConfig) + c.Flags().applyLogConfigOverrides(config.SharedConfig) l, err := c.configureLogging(config) if err != nil { c.UI.Error(err.Error()) @@ -1039,7 +1039,7 @@ func (c *ServerCommand) Run(args []string) int { return 1 } - f.updateLogConfig(config.SharedConfig) + f.applyLogConfigOverrides(config.SharedConfig) // Set 'trace' log level for the following 'dev' clusters if c.flagDevThreeNode || c.flagDevFourCluster { @@ -1696,24 +1696,14 @@ func (c *ServerCommand) configureLogging(config *server.Config) (hclog.Intercept return nil, err } - logRotateBytes, err := parseutil.ParseInt(config.LogRotateBytes) - if err != nil { - return nil, err - } - - logRotateMaxFiles, err := parseutil.ParseInt(config.LogRotateMaxFiles) - if err != nil { - return nil, err - } - logCfg := &loghelper.LogConfig{ Name: "vault", LogLevel: logLevel, LogFormat: logFormat, LogFilePath: config.LogFile, LogRotateDuration: logRotateDuration, - LogRotateBytes: int(logRotateBytes), - LogRotateMaxFiles: int(logRotateMaxFiles), + LogRotateBytes: config.LogRotateBytes, + LogRotateMaxFiles: config.LogRotateMaxFiles, } return loghelper.Setup(logCfg, c.logWriter) diff --git a/helper/logging/logger.go b/helper/logging/logger.go index 05ef205ee..e876d54f1 100644 --- a/helper/logging/logger.go +++ b/helper/logging/logger.go @@ -122,6 +122,7 @@ func Setup(config *LogConfig, w io.Writer) (log.InterceptLogger, error) { if config.LogRotateDuration == 0 { config.LogRotateDuration = defaultRotateDuration } + logFile := &LogFile{ fileName: fileName, logPath: dir, diff --git a/internalshared/configutil/config.go b/internalshared/configutil/config.go index dd63239c7..f4c2ec11a 100644 --- a/internalshared/configutil/config.go +++ b/internalshared/configutil/config.go @@ -38,12 +38,14 @@ type SharedConfig struct { // LogFormat specifies the log format. Valid values are "standard" and // "json". The values are case-insenstive. If no log format is specified, // then standard format will be used. - LogFormat string `hcl:"log_format"` - LogLevel string `hcl:"log_level"` - LogFile string `hcl:"log_file"` - LogRotateBytes string `hcl:"log_rotate_bytes"` - LogRotateDuration string `hcl:"log_rotate_duration"` - LogRotateMaxFiles string `hcl:"log_rotate_max_files"` + LogFormat string `hcl:"log_format"` + LogLevel string `hcl:"log_level"` + LogFile string `hcl:"log_file"` + LogRotateDuration string `hcl:"log_rotate_duration"` + LogRotateBytes int `hcl:"log_rotate_bytes"` + LogRotateBytesRaw interface{} `hcl:"log_rotate_bytes"` + LogRotateMaxFiles int `hcl:"log_rotate_max_files"` + LogRotateMaxFilesRaw interface{} `hcl:"log_rotate_max_files"` PidFile string `hcl:"pid_file"` diff --git a/internalshared/configutil/merge.go b/internalshared/configutil/merge.go index 791bd41a7..4bc30e62d 100644 --- a/internalshared/configutil/merge.go +++ b/internalshared/configutil/merge.go @@ -69,13 +69,15 @@ func (c *SharedConfig) Merge(c2 *SharedConfig) *SharedConfig { } result.LogRotateBytes = c.LogRotateBytes - if c2.LogRotateBytes != "" { + if c2.LogRotateBytesRaw != nil { result.LogRotateBytes = c2.LogRotateBytes + result.LogRotateBytesRaw = c2.LogRotateBytesRaw } result.LogRotateMaxFiles = c.LogRotateMaxFiles - if c2.LogRotateMaxFiles != "" { + if c2.LogRotateMaxFilesRaw != nil { result.LogRotateMaxFiles = c2.LogRotateMaxFiles + result.LogRotateMaxFilesRaw = c2.LogRotateMaxFilesRaw } result.LogRotateDuration = c.LogRotateDuration