Support reading Vault's address from Agent's config file (#6306)

* Support reading Vault's address from Agent's config file

* use consts and switch

* Add tls options to agent config vault block

* Update command/agent/config/config.go

Co-Authored-By: vishalnayak <vishalnayak@users.noreply.github.com>

* remove fmt.Printfs
This commit is contained in:
Vishal Nayak 2019-02-28 17:29:28 -05:00 committed by GitHub
parent 4f35c548fe
commit ac2b499fc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 171 additions and 25 deletions

View File

@ -25,14 +25,15 @@ import (
"golang.org/x/time/rate"
)
const EnvVaultAgentAddress = "VAULT_AGENT_ADDR"
const EnvVaultAddress = "VAULT_ADDR"
const EnvVaultAgentAddr = "VAULT_AGENT_ADDR"
const EnvVaultCACert = "VAULT_CACERT"
const EnvVaultCAPath = "VAULT_CAPATH"
const EnvVaultClientCert = "VAULT_CLIENT_CERT"
const EnvVaultClientKey = "VAULT_CLIENT_KEY"
const EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT"
const EnvVaultInsecure = "VAULT_SKIP_VERIFY"
const EnvVaultSkipVerify = "VAULT_SKIP_VERIFY"
const EnvVaultNamespace = "VAULT_NAMESPACE"
const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME"
const EnvVaultWrapTTL = "VAULT_WRAP_TTL"
const EnvVaultMaxRetries = "VAULT_MAX_RETRIES"
@ -243,7 +244,7 @@ func (c *Config) ReadEnvironment() error {
if v := os.Getenv(EnvVaultAddress); v != "" {
envAddress = v
}
if v := os.Getenv(EnvVaultAgentAddress); v != "" {
if v := os.Getenv(EnvVaultAgentAddr); v != "" {
envAgentAddress = v
}
if v := os.Getenv(EnvVaultMaxRetries); v != "" {
@ -279,7 +280,7 @@ func (c *Config) ReadEnvironment() error {
}
envClientTimeout = clientTimeout
}
if v := os.Getenv(EnvVaultInsecure); v != "" {
if v := os.Getenv(EnvVaultSkipVerify); v != "" {
var err error
envInsecure, err = strconv.ParseBool(v)
if err != nil {

View File

@ -163,19 +163,19 @@ func TestClientEnvSettings(t *testing.T) {
oldCAPath := os.Getenv(EnvVaultCAPath)
oldClientCert := os.Getenv(EnvVaultClientCert)
oldClientKey := os.Getenv(EnvVaultClientKey)
oldSkipVerify := os.Getenv(EnvVaultInsecure)
oldSkipVerify := os.Getenv(EnvVaultSkipVerify)
oldMaxRetries := os.Getenv(EnvVaultMaxRetries)
os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys")
os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem")
os.Setenv(EnvVaultInsecure, "true")
os.Setenv(EnvVaultSkipVerify, "true")
os.Setenv(EnvVaultMaxRetries, "5")
defer os.Setenv(EnvVaultCACert, oldCACert)
defer os.Setenv(EnvVaultCAPath, oldCAPath)
defer os.Setenv(EnvVaultClientCert, oldClientCert)
defer os.Setenv(EnvVaultClientKey, oldClientKey)
defer os.Setenv(EnvVaultInsecure, oldSkipVerify)
defer os.Setenv(EnvVaultSkipVerify, oldSkipVerify)
defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries)
config := DefaultConfig()

View File

@ -2,6 +2,7 @@ package command
import (
"context"
"flag"
"fmt"
"io"
"net"
@ -206,6 +207,33 @@ func (c *AgentCommand) Run(args []string) int {
return 1
}
if config.Vault != nil {
c.setStringFlag(f, config.Vault.Address, &StringVar{
Name: flagNameAddress,
Target: &c.flagAddress,
Default: "https://127.0.0.1:8200",
EnvVar: api.EnvVaultAddress,
})
c.setStringFlag(f, config.Vault.CACert, &StringVar{
Name: flagNameCACert,
Target: &c.flagCACert,
Default: "",
EnvVar: api.EnvVaultCACert,
})
c.setStringFlag(f, config.Vault.CAPath, &StringVar{
Name: flagNameCAPath,
Target: &c.flagCAPath,
Default: "",
EnvVar: api.EnvVaultCAPath,
})
c.setBoolFlag(f, config.Vault.TLSSkipVerify, &BoolVar{
Name: flagNameTLSSkipVerify,
Target: &c.flagTLSSkipVerify,
Default: false,
EnvVar: api.EnvVaultSkipVerify,
})
}
infoKeys := make([]string, 0, 10)
info := make(map[string]string)
info["log level"] = c.flagLogLevel
@ -235,6 +263,9 @@ func (c *AgentCommand) Run(args []string) int {
return 0
}
// Ignore any setting of agent's address. This client is used by the agent
// to reach out to Vault. This should never loop back to agent.
c.flagAgentAddress = ""
client, err := c.Client()
if err != nil {
c.UI.Error(fmt.Sprintf(
@ -472,6 +503,54 @@ func (c *AgentCommand) Run(args []string) int {
return 0
}
func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {
if f.Name == fVar.Name {
isFlagSet = true
}
})
flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
switch {
case isFlagSet:
// Don't do anything as the flag is already set from the command line
case flagEnvSet:
// Use value from env var
*fVar.Target = flagEnvValue
case configVal != "":
// Use value from config
*fVar.Target = configVal
default:
// Use the default value
*fVar.Target = fVar.Default
}
}
func (c *AgentCommand) setBoolFlag(f *FlagSets, configVal bool, fVar *BoolVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {
if f.Name == fVar.Name {
isFlagSet = true
}
})
flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
switch {
case isFlagSet:
// Don't do anything as the flag is already set from the command line
case flagEnvSet:
// Use value from env var
*fVar.Target = flagEnvValue != ""
case configVal == true:
// Use value from config
*fVar.Target = configVal
default:
// Use the default value
*fVar.Target = fVar.Default
}
}
// storePidFile is used to write out our PID to a file if necessary
func (c *AgentCommand) storePidFile(pidPath string) error {
// Quit fast if no pidfile

View File

@ -23,6 +23,14 @@ type Config struct {
ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"`
Cache *Cache `hcl:"cache"`
Vault *Vault `hcl:"vault"`
}
type Vault struct {
Address string `hcl:"address"`
CACert string `hcl:"ca_cert"`
CAPath string `hcl:"ca_path"`
TLSSkipVerify bool `hcl:"tls_skip_verify"`
}
type Cache struct {
@ -107,9 +115,35 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err)
}
err = parseVault(&result, list)
if err != nil {
return nil, errwrap.Wrapf("error parsing 'vault':{{err}}", err)
}
return &result, nil
}
func parseVault(result *Config, list *ast.ObjectList) error {
name := "vault"
vaultList := list.Filter(name)
if len(vaultList.Items) > 1 {
return fmt.Errorf("one and only one %q block is required", name)
}
item := vaultList.Items[0]
var v Vault
err := hcl.DecodeObject(&v, item.Val)
if err != nil {
return err
}
result.Vault = &v
return nil
}
func parseCache(result *Config, list *ast.ObjectList) error {
name := "cache"

View File

@ -67,6 +67,12 @@ func TestLoadConfigFile_AgentCache(t *testing.T) {
},
},
},
Vault: &Vault{
Address: "http://127.0.0.1:1111",
CACert: "config_ca_cert",
CAPath: "config_ca_path",
TLSSkipVerify: true,
},
PidFile: "./pidfile",
}

View File

@ -42,3 +42,10 @@ cache {
tls_cert_file = "/path/to/cacert.pem"
}
}
vault {
address = "http://127.0.0.1:1111"
ca_cert = "config_ca_cert"
ca_path = "config_ca_path"
tls_skip_verify = "true"
}

View File

@ -39,3 +39,10 @@ cache {
tls_cert_file = "/path/to/cacert.pem"
}
}
vault {
address = "http://127.0.0.1:1111"
ca_cert = "config_ca_cert"
ca_path = "config_ca_path"
tls_skip_verify = "true"
}

View File

@ -188,15 +188,15 @@ cache {
}
}()
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress)
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddr)
// Create a client that talks to the agent
os.Setenv(api.EnvVaultAgentAddress, socketf)
os.Setenv(api.EnvVaultAgentAddr, socketf)
testClient, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress)
os.Setenv(api.EnvVaultAgentAddr, originalVaultAgentAddress)
// Start the agent
go cmd.Run([]string{"-config", conf})

View File

@ -211,9 +211,9 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
f := set.NewFlagSet("HTTP Options")
addrStringVar := &StringVar{
Name: "address",
Name: flagNameAddress,
Target: &c.flagAddress,
EnvVar: "VAULT_ADDR",
EnvVar: api.EnvVaultAddress,
Completion: complete.PredictAnything,
Usage: "Address of the Vault server.",
}
@ -227,17 +227,17 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
agentAddrStringVar := &StringVar{
Name: "agent-address",
Target: &c.flagAgentAddress,
EnvVar: "VAULT_AGENT_ADDR",
EnvVar: api.EnvVaultAgentAddr,
Completion: complete.PredictAnything,
Usage: "Address of the Agent.",
}
f.StringVar(agentAddrStringVar)
f.StringVar(&StringVar{
Name: "ca-cert",
Name: flagNameCACert,
Target: &c.flagCACert,
Default: "",
EnvVar: "VAULT_CACERT",
EnvVar: api.EnvVaultCACert,
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to verify the Vault server's SSL certificate. This " +
@ -245,10 +245,10 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
})
f.StringVar(&StringVar{
Name: "ca-path",
Name: flagNameCAPath,
Target: &c.flagCAPath,
Default: "",
EnvVar: "VAULT_CAPATH",
EnvVar: api.EnvVaultCAPath,
Completion: complete.PredictDirs("*"),
Usage: "Path on the local disk to a directory of PEM-encoded CA " +
"certificates to verify the Vault server's SSL certificate.",
@ -258,7 +258,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "client-cert",
Target: &c.flagClientCert,
Default: "",
EnvVar: "VAULT_CLIENT_CERT",
EnvVar: api.EnvVaultClientCert,
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to use for TLS authentication to the Vault server. If " +
@ -269,7 +269,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "client-key",
Target: &c.flagClientKey,
Default: "",
EnvVar: "VAULT_CLIENT_KEY",
EnvVar: api.EnvVaultClientKey,
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded private key " +
"matching the client certificate from -client-cert.",
@ -279,7 +279,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "namespace",
Target: &c.flagNamespace,
Default: notSetValue, // this can never be a real value
EnvVar: "VAULT_NAMESPACE",
EnvVar: api.EnvVaultNamespace,
Completion: complete.PredictAnything,
Usage: "The namespace to use for the command. Setting this is not " +
"necessary but allows using relative paths. -ns can be used as " +
@ -299,17 +299,17 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "tls-server-name",
Target: &c.flagTLSServerName,
Default: "",
EnvVar: "VAULT_TLS_SERVER_NAME",
EnvVar: api.EnvVaultTLSServerName,
Completion: complete.PredictAnything,
Usage: "Name to use as the SNI host when connecting to the Vault " +
"server via TLS.",
})
f.BoolVar(&BoolVar{
Name: "tls-skip-verify",
Name: flagNameTLSSkipVerify,
Target: &c.flagTLSSkipVerify,
Default: false,
EnvVar: "VAULT_SKIP_VERIFY",
EnvVar: api.EnvVaultSkipVerify,
Usage: "Disable verification of TLS certificates. Using this option " +
"is highly discouraged and decreases the security of data " +
"transmissions to and from the Vault server.",
@ -327,7 +327,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "wrap-ttl",
Target: &c.flagWrapTTL,
Default: 0,
EnvVar: "VAULT_WRAP_TTL",
EnvVar: api.EnvVaultWrapTTL,
Completion: complete.PredictAnything,
Usage: "Wraps the response in a cubbyhole token with the requested " +
"TTL. The response is available via the \"vault unwrap\" command. " +

View File

@ -743,7 +743,7 @@ func (f *FlagSet) VarFlag(i *VarFlag) {
}
// Var is a lower-level API for adding something to the flags. It should be used
// wtih caution, since it bypasses all validation. Consider VarFlag instead.
// with caution, since it bypasses all validation. Consider VarFlag instead.
func (f *FlagSet) Var(value flag.Value, name, usage string) {
f.mainSet.Var(value, name, usage)
f.flagSet.Var(value, name, usage)

View File

@ -66,6 +66,18 @@ const (
// EnvVaultFormat is the output format
EnvVaultFormat = `VAULT_FORMAT`
// flagNameAddress is the flag used in the base command to read in the
// address of the Vault server.
flagNameAddress = "address"
// flagnameCACert is the flag used in the base command to read in the CA
// cert.
flagNameCACert = "ca-cert"
// flagnameCAPath is the flag used in the base command to read in the CA
// cert path.
flagNameCAPath = "ca-path"
// flagNameTLSSkipVerify is the flag used in the base command to read in
// the option to ignore TLS certificate verification.
flagNameTLSSkipVerify = "tls-skip-verify"
// flagNameAuditNonHMACRequestKeys is the flag name used for auth/secrets enable
flagNameAuditNonHMACRequestKeys = "audit-non-hmac-request-keys"
// flagNameAuditNonHMACResponseKeys is the flag name used for auth/secrets enable