diff --git a/api/client.go b/api/client.go index eeab4f32b..f46cda942 100644 --- a/api/client.go +++ b/api/client.go @@ -25,14 +25,15 @@ import ( "golang.org/x/time/rate" ) -const EnvVaultAgentAddress = "VAULT_AGENT_ADDR" const EnvVaultAddress = "VAULT_ADDR" +const EnvVaultAgentAddr = "VAULT_AGENT_ADDR" const EnvVaultCACert = "VAULT_CACERT" const EnvVaultCAPath = "VAULT_CAPATH" const EnvVaultClientCert = "VAULT_CLIENT_CERT" const EnvVaultClientKey = "VAULT_CLIENT_KEY" const EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT" -const EnvVaultInsecure = "VAULT_SKIP_VERIFY" +const EnvVaultSkipVerify = "VAULT_SKIP_VERIFY" +const EnvVaultNamespace = "VAULT_NAMESPACE" const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME" const EnvVaultWrapTTL = "VAULT_WRAP_TTL" const EnvVaultMaxRetries = "VAULT_MAX_RETRIES" @@ -243,7 +244,7 @@ func (c *Config) ReadEnvironment() error { if v := os.Getenv(EnvVaultAddress); v != "" { envAddress = v } - if v := os.Getenv(EnvVaultAgentAddress); v != "" { + if v := os.Getenv(EnvVaultAgentAddr); v != "" { envAgentAddress = v } if v := os.Getenv(EnvVaultMaxRetries); v != "" { @@ -279,7 +280,7 @@ func (c *Config) ReadEnvironment() error { } envClientTimeout = clientTimeout } - if v := os.Getenv(EnvVaultInsecure); v != "" { + if v := os.Getenv(EnvVaultSkipVerify); v != "" { var err error envInsecure, err = strconv.ParseBool(v) if err != nil { diff --git a/api/client_test.go b/api/client_test.go index 5678478ea..13fdd6e9d 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -163,19 +163,19 @@ func TestClientEnvSettings(t *testing.T) { oldCAPath := os.Getenv(EnvVaultCAPath) oldClientCert := os.Getenv(EnvVaultClientCert) oldClientKey := os.Getenv(EnvVaultClientKey) - oldSkipVerify := os.Getenv(EnvVaultInsecure) + oldSkipVerify := os.Getenv(EnvVaultSkipVerify) oldMaxRetries := os.Getenv(EnvVaultMaxRetries) os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem") os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys") os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem") os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem") - os.Setenv(EnvVaultInsecure, "true") + os.Setenv(EnvVaultSkipVerify, "true") os.Setenv(EnvVaultMaxRetries, "5") defer os.Setenv(EnvVaultCACert, oldCACert) defer os.Setenv(EnvVaultCAPath, oldCAPath) defer os.Setenv(EnvVaultClientCert, oldClientCert) defer os.Setenv(EnvVaultClientKey, oldClientKey) - defer os.Setenv(EnvVaultInsecure, oldSkipVerify) + defer os.Setenv(EnvVaultSkipVerify, oldSkipVerify) defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries) config := DefaultConfig() diff --git a/command/agent.go b/command/agent.go index f3b40fcb3..21b9811be 100644 --- a/command/agent.go +++ b/command/agent.go @@ -2,6 +2,7 @@ package command import ( "context" + "flag" "fmt" "io" "net" @@ -206,6 +207,33 @@ func (c *AgentCommand) Run(args []string) int { return 1 } + if config.Vault != nil { + c.setStringFlag(f, config.Vault.Address, &StringVar{ + Name: flagNameAddress, + Target: &c.flagAddress, + Default: "https://127.0.0.1:8200", + EnvVar: api.EnvVaultAddress, + }) + c.setStringFlag(f, config.Vault.CACert, &StringVar{ + Name: flagNameCACert, + Target: &c.flagCACert, + Default: "", + EnvVar: api.EnvVaultCACert, + }) + c.setStringFlag(f, config.Vault.CAPath, &StringVar{ + Name: flagNameCAPath, + Target: &c.flagCAPath, + Default: "", + EnvVar: api.EnvVaultCAPath, + }) + c.setBoolFlag(f, config.Vault.TLSSkipVerify, &BoolVar{ + Name: flagNameTLSSkipVerify, + Target: &c.flagTLSSkipVerify, + Default: false, + EnvVar: api.EnvVaultSkipVerify, + }) + } + infoKeys := make([]string, 0, 10) info := make(map[string]string) info["log level"] = c.flagLogLevel @@ -235,6 +263,9 @@ func (c *AgentCommand) Run(args []string) int { return 0 } + // Ignore any setting of agent's address. This client is used by the agent + // to reach out to Vault. This should never loop back to agent. + c.flagAgentAddress = "" client, err := c.Client() if err != nil { c.UI.Error(fmt.Sprintf( @@ -472,6 +503,54 @@ func (c *AgentCommand) Run(args []string) int { return 0 } +func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) { + var isFlagSet bool + f.Visit(func(f *flag.Flag) { + if f.Name == fVar.Name { + isFlagSet = true + } + }) + + flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar) + switch { + case isFlagSet: + // Don't do anything as the flag is already set from the command line + case flagEnvSet: + // Use value from env var + *fVar.Target = flagEnvValue + case configVal != "": + // Use value from config + *fVar.Target = configVal + default: + // Use the default value + *fVar.Target = fVar.Default + } +} + +func (c *AgentCommand) setBoolFlag(f *FlagSets, configVal bool, fVar *BoolVar) { + var isFlagSet bool + f.Visit(func(f *flag.Flag) { + if f.Name == fVar.Name { + isFlagSet = true + } + }) + + flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar) + switch { + case isFlagSet: + // Don't do anything as the flag is already set from the command line + case flagEnvSet: + // Use value from env var + *fVar.Target = flagEnvValue != "" + case configVal == true: + // Use value from config + *fVar.Target = configVal + default: + // Use the default value + *fVar.Target = fVar.Default + } +} + // storePidFile is used to write out our PID to a file if necessary func (c *AgentCommand) storePidFile(pidPath string) error { // Quit fast if no pidfile diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 2c6ffcc23..d6d9875e3 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -23,6 +23,14 @@ type Config struct { ExitAfterAuth bool `hcl:"exit_after_auth"` PidFile string `hcl:"pid_file"` Cache *Cache `hcl:"cache"` + Vault *Vault `hcl:"vault"` +} + +type Vault struct { + Address string `hcl:"address"` + CACert string `hcl:"ca_cert"` + CAPath string `hcl:"ca_path"` + TLSSkipVerify bool `hcl:"tls_skip_verify"` } type Cache struct { @@ -107,9 +115,35 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) { return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err) } + err = parseVault(&result, list) + if err != nil { + return nil, errwrap.Wrapf("error parsing 'vault':{{err}}", err) + } + return &result, nil } +func parseVault(result *Config, list *ast.ObjectList) error { + name := "vault" + + vaultList := list.Filter(name) + if len(vaultList.Items) > 1 { + return fmt.Errorf("one and only one %q block is required", name) + } + + item := vaultList.Items[0] + + var v Vault + err := hcl.DecodeObject(&v, item.Val) + if err != nil { + return err + } + + result.Vault = &v + + return nil +} + func parseCache(result *Config, list *ast.ObjectList) error { name := "cache" diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go index 49621b50c..eeec7bf37 100644 --- a/command/agent/config/config_test.go +++ b/command/agent/config/config_test.go @@ -67,6 +67,12 @@ func TestLoadConfigFile_AgentCache(t *testing.T) { }, }, }, + Vault: &Vault{ + Address: "http://127.0.0.1:1111", + CACert: "config_ca_cert", + CAPath: "config_ca_path", + TLSSkipVerify: true, + }, PidFile: "./pidfile", } diff --git a/command/agent/config/test-fixtures/config-cache-embedded-type.hcl b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl index 3079b29d7..01c466e93 100644 --- a/command/agent/config/test-fixtures/config-cache-embedded-type.hcl +++ b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl @@ -42,3 +42,10 @@ cache { tls_cert_file = "/path/to/cacert.pem" } } + +vault { + address = "http://127.0.0.1:1111" + ca_cert = "config_ca_cert" + ca_path = "config_ca_path" + tls_skip_verify = "true" +} diff --git a/command/agent/config/test-fixtures/config-cache.hcl b/command/agent/config/test-fixtures/config-cache.hcl index f2ae5cb38..132980811 100644 --- a/command/agent/config/test-fixtures/config-cache.hcl +++ b/command/agent/config/test-fixtures/config-cache.hcl @@ -39,3 +39,10 @@ cache { tls_cert_file = "/path/to/cacert.pem" } } + +vault { + address = "http://127.0.0.1:1111" + ca_cert = "config_ca_cert" + ca_path = "config_ca_path" + tls_skip_verify = "true" +} diff --git a/command/agent_test.go b/command/agent_test.go index 7bcc32bc3..0d8de9305 100644 --- a/command/agent_test.go +++ b/command/agent_test.go @@ -188,15 +188,15 @@ cache { } }() - originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress) + originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddr) // Create a client that talks to the agent - os.Setenv(api.EnvVaultAgentAddress, socketf) + os.Setenv(api.EnvVaultAgentAddr, socketf) testClient, err := api.NewClient(api.DefaultConfig()) if err != nil { t.Fatal(err) } - os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress) + os.Setenv(api.EnvVaultAgentAddr, originalVaultAgentAddress) // Start the agent go cmd.Run([]string{"-config", conf}) diff --git a/command/base.go b/command/base.go index 144e16435..edef8e761 100644 --- a/command/base.go +++ b/command/base.go @@ -211,9 +211,9 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { f := set.NewFlagSet("HTTP Options") addrStringVar := &StringVar{ - Name: "address", + Name: flagNameAddress, Target: &c.flagAddress, - EnvVar: "VAULT_ADDR", + EnvVar: api.EnvVaultAddress, Completion: complete.PredictAnything, Usage: "Address of the Vault server.", } @@ -227,17 +227,17 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { agentAddrStringVar := &StringVar{ Name: "agent-address", Target: &c.flagAgentAddress, - EnvVar: "VAULT_AGENT_ADDR", + EnvVar: api.EnvVaultAgentAddr, Completion: complete.PredictAnything, Usage: "Address of the Agent.", } f.StringVar(agentAddrStringVar) f.StringVar(&StringVar{ - Name: "ca-cert", + Name: flagNameCACert, Target: &c.flagCACert, Default: "", - EnvVar: "VAULT_CACERT", + EnvVar: api.EnvVaultCACert, Completion: complete.PredictFiles("*"), Usage: "Path on the local disk to a single PEM-encoded CA " + "certificate to verify the Vault server's SSL certificate. This " + @@ -245,10 +245,10 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { }) f.StringVar(&StringVar{ - Name: "ca-path", + Name: flagNameCAPath, Target: &c.flagCAPath, Default: "", - EnvVar: "VAULT_CAPATH", + EnvVar: api.EnvVaultCAPath, Completion: complete.PredictDirs("*"), Usage: "Path on the local disk to a directory of PEM-encoded CA " + "certificates to verify the Vault server's SSL certificate.", @@ -258,7 +258,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { Name: "client-cert", Target: &c.flagClientCert, Default: "", - EnvVar: "VAULT_CLIENT_CERT", + EnvVar: api.EnvVaultClientCert, Completion: complete.PredictFiles("*"), Usage: "Path on the local disk to a single PEM-encoded CA " + "certificate to use for TLS authentication to the Vault server. If " + @@ -269,7 +269,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { Name: "client-key", Target: &c.flagClientKey, Default: "", - EnvVar: "VAULT_CLIENT_KEY", + EnvVar: api.EnvVaultClientKey, Completion: complete.PredictFiles("*"), Usage: "Path on the local disk to a single PEM-encoded private key " + "matching the client certificate from -client-cert.", @@ -279,7 +279,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { Name: "namespace", Target: &c.flagNamespace, Default: notSetValue, // this can never be a real value - EnvVar: "VAULT_NAMESPACE", + EnvVar: api.EnvVaultNamespace, Completion: complete.PredictAnything, Usage: "The namespace to use for the command. Setting this is not " + "necessary but allows using relative paths. -ns can be used as " + @@ -299,17 +299,17 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { Name: "tls-server-name", Target: &c.flagTLSServerName, Default: "", - EnvVar: "VAULT_TLS_SERVER_NAME", + EnvVar: api.EnvVaultTLSServerName, Completion: complete.PredictAnything, Usage: "Name to use as the SNI host when connecting to the Vault " + "server via TLS.", }) f.BoolVar(&BoolVar{ - Name: "tls-skip-verify", + Name: flagNameTLSSkipVerify, Target: &c.flagTLSSkipVerify, Default: false, - EnvVar: "VAULT_SKIP_VERIFY", + EnvVar: api.EnvVaultSkipVerify, Usage: "Disable verification of TLS certificates. Using this option " + "is highly discouraged and decreases the security of data " + "transmissions to and from the Vault server.", @@ -327,7 +327,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { Name: "wrap-ttl", Target: &c.flagWrapTTL, Default: 0, - EnvVar: "VAULT_WRAP_TTL", + EnvVar: api.EnvVaultWrapTTL, Completion: complete.PredictAnything, Usage: "Wraps the response in a cubbyhole token with the requested " + "TTL. The response is available via the \"vault unwrap\" command. " + diff --git a/command/base_flags.go b/command/base_flags.go index 472317037..0cdb01f8d 100644 --- a/command/base_flags.go +++ b/command/base_flags.go @@ -743,7 +743,7 @@ func (f *FlagSet) VarFlag(i *VarFlag) { } // Var is a lower-level API for adding something to the flags. It should be used -// wtih caution, since it bypasses all validation. Consider VarFlag instead. +// with caution, since it bypasses all validation. Consider VarFlag instead. func (f *FlagSet) Var(value flag.Value, name, usage string) { f.mainSet.Var(value, name, usage) f.flagSet.Var(value, name, usage) diff --git a/command/commands.go b/command/commands.go index e009d5733..9c67ddd12 100644 --- a/command/commands.go +++ b/command/commands.go @@ -66,6 +66,18 @@ const ( // EnvVaultFormat is the output format EnvVaultFormat = `VAULT_FORMAT` + // flagNameAddress is the flag used in the base command to read in the + // address of the Vault server. + flagNameAddress = "address" + // flagnameCACert is the flag used in the base command to read in the CA + // cert. + flagNameCACert = "ca-cert" + // flagnameCAPath is the flag used in the base command to read in the CA + // cert path. + flagNameCAPath = "ca-path" + // flagNameTLSSkipVerify is the flag used in the base command to read in + // the option to ignore TLS certificate verification. + flagNameTLSSkipVerify = "tls-skip-verify" // flagNameAuditNonHMACRequestKeys is the flag name used for auth/secrets enable flagNameAuditNonHMACRequestKeys = "audit-non-hmac-request-keys" // flagNameAuditNonHMACResponseKeys is the flag name used for auth/secrets enable