Support reading Vault's address from Agent's config file (#6306)

* Support reading Vault's address from Agent's config file

* use consts and switch

* Add tls options to agent config vault block

* Update command/agent/config/config.go

Co-Authored-By: vishalnayak <vishalnayak@users.noreply.github.com>

* remove fmt.Printfs
This commit is contained in:
Vishal Nayak 2019-02-28 17:29:28 -05:00 committed by GitHub
parent 4f35c548fe
commit ac2b499fc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 171 additions and 25 deletions

View File

@ -25,14 +25,15 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
const EnvVaultAgentAddress = "VAULT_AGENT_ADDR"
const EnvVaultAddress = "VAULT_ADDR" const EnvVaultAddress = "VAULT_ADDR"
const EnvVaultAgentAddr = "VAULT_AGENT_ADDR"
const EnvVaultCACert = "VAULT_CACERT" const EnvVaultCACert = "VAULT_CACERT"
const EnvVaultCAPath = "VAULT_CAPATH" const EnvVaultCAPath = "VAULT_CAPATH"
const EnvVaultClientCert = "VAULT_CLIENT_CERT" const EnvVaultClientCert = "VAULT_CLIENT_CERT"
const EnvVaultClientKey = "VAULT_CLIENT_KEY" const EnvVaultClientKey = "VAULT_CLIENT_KEY"
const EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT" const EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT"
const EnvVaultInsecure = "VAULT_SKIP_VERIFY" const EnvVaultSkipVerify = "VAULT_SKIP_VERIFY"
const EnvVaultNamespace = "VAULT_NAMESPACE"
const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME" const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME"
const EnvVaultWrapTTL = "VAULT_WRAP_TTL" const EnvVaultWrapTTL = "VAULT_WRAP_TTL"
const EnvVaultMaxRetries = "VAULT_MAX_RETRIES" const EnvVaultMaxRetries = "VAULT_MAX_RETRIES"
@ -243,7 +244,7 @@ func (c *Config) ReadEnvironment() error {
if v := os.Getenv(EnvVaultAddress); v != "" { if v := os.Getenv(EnvVaultAddress); v != "" {
envAddress = v envAddress = v
} }
if v := os.Getenv(EnvVaultAgentAddress); v != "" { if v := os.Getenv(EnvVaultAgentAddr); v != "" {
envAgentAddress = v envAgentAddress = v
} }
if v := os.Getenv(EnvVaultMaxRetries); v != "" { if v := os.Getenv(EnvVaultMaxRetries); v != "" {
@ -279,7 +280,7 @@ func (c *Config) ReadEnvironment() error {
} }
envClientTimeout = clientTimeout envClientTimeout = clientTimeout
} }
if v := os.Getenv(EnvVaultInsecure); v != "" { if v := os.Getenv(EnvVaultSkipVerify); v != "" {
var err error var err error
envInsecure, err = strconv.ParseBool(v) envInsecure, err = strconv.ParseBool(v)
if err != nil { if err != nil {

View File

@ -163,19 +163,19 @@ func TestClientEnvSettings(t *testing.T) {
oldCAPath := os.Getenv(EnvVaultCAPath) oldCAPath := os.Getenv(EnvVaultCAPath)
oldClientCert := os.Getenv(EnvVaultClientCert) oldClientCert := os.Getenv(EnvVaultClientCert)
oldClientKey := os.Getenv(EnvVaultClientKey) oldClientKey := os.Getenv(EnvVaultClientKey)
oldSkipVerify := os.Getenv(EnvVaultInsecure) oldSkipVerify := os.Getenv(EnvVaultSkipVerify)
oldMaxRetries := os.Getenv(EnvVaultMaxRetries) oldMaxRetries := os.Getenv(EnvVaultMaxRetries)
os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem") os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys") os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys")
os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem") os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem") os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem")
os.Setenv(EnvVaultInsecure, "true") os.Setenv(EnvVaultSkipVerify, "true")
os.Setenv(EnvVaultMaxRetries, "5") os.Setenv(EnvVaultMaxRetries, "5")
defer os.Setenv(EnvVaultCACert, oldCACert) defer os.Setenv(EnvVaultCACert, oldCACert)
defer os.Setenv(EnvVaultCAPath, oldCAPath) defer os.Setenv(EnvVaultCAPath, oldCAPath)
defer os.Setenv(EnvVaultClientCert, oldClientCert) defer os.Setenv(EnvVaultClientCert, oldClientCert)
defer os.Setenv(EnvVaultClientKey, oldClientKey) defer os.Setenv(EnvVaultClientKey, oldClientKey)
defer os.Setenv(EnvVaultInsecure, oldSkipVerify) defer os.Setenv(EnvVaultSkipVerify, oldSkipVerify)
defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries) defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries)
config := DefaultConfig() config := DefaultConfig()

View File

@ -2,6 +2,7 @@ package command
import ( import (
"context" "context"
"flag"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -206,6 +207,33 @@ func (c *AgentCommand) Run(args []string) int {
return 1 return 1
} }
if config.Vault != nil {
c.setStringFlag(f, config.Vault.Address, &StringVar{
Name: flagNameAddress,
Target: &c.flagAddress,
Default: "https://127.0.0.1:8200",
EnvVar: api.EnvVaultAddress,
})
c.setStringFlag(f, config.Vault.CACert, &StringVar{
Name: flagNameCACert,
Target: &c.flagCACert,
Default: "",
EnvVar: api.EnvVaultCACert,
})
c.setStringFlag(f, config.Vault.CAPath, &StringVar{
Name: flagNameCAPath,
Target: &c.flagCAPath,
Default: "",
EnvVar: api.EnvVaultCAPath,
})
c.setBoolFlag(f, config.Vault.TLSSkipVerify, &BoolVar{
Name: flagNameTLSSkipVerify,
Target: &c.flagTLSSkipVerify,
Default: false,
EnvVar: api.EnvVaultSkipVerify,
})
}
infoKeys := make([]string, 0, 10) infoKeys := make([]string, 0, 10)
info := make(map[string]string) info := make(map[string]string)
info["log level"] = c.flagLogLevel info["log level"] = c.flagLogLevel
@ -235,6 +263,9 @@ func (c *AgentCommand) Run(args []string) int {
return 0 return 0
} }
// Ignore any setting of agent's address. This client is used by the agent
// to reach out to Vault. This should never loop back to agent.
c.flagAgentAddress = ""
client, err := c.Client() client, err := c.Client()
if err != nil { if err != nil {
c.UI.Error(fmt.Sprintf( c.UI.Error(fmt.Sprintf(
@ -472,6 +503,54 @@ func (c *AgentCommand) Run(args []string) int {
return 0 return 0
} }
func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {
if f.Name == fVar.Name {
isFlagSet = true
}
})
flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
switch {
case isFlagSet:
// Don't do anything as the flag is already set from the command line
case flagEnvSet:
// Use value from env var
*fVar.Target = flagEnvValue
case configVal != "":
// Use value from config
*fVar.Target = configVal
default:
// Use the default value
*fVar.Target = fVar.Default
}
}
func (c *AgentCommand) setBoolFlag(f *FlagSets, configVal bool, fVar *BoolVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {
if f.Name == fVar.Name {
isFlagSet = true
}
})
flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
switch {
case isFlagSet:
// Don't do anything as the flag is already set from the command line
case flagEnvSet:
// Use value from env var
*fVar.Target = flagEnvValue != ""
case configVal == true:
// Use value from config
*fVar.Target = configVal
default:
// Use the default value
*fVar.Target = fVar.Default
}
}
// storePidFile is used to write out our PID to a file if necessary // storePidFile is used to write out our PID to a file if necessary
func (c *AgentCommand) storePidFile(pidPath string) error { func (c *AgentCommand) storePidFile(pidPath string) error {
// Quit fast if no pidfile // Quit fast if no pidfile

View File

@ -23,6 +23,14 @@ type Config struct {
ExitAfterAuth bool `hcl:"exit_after_auth"` ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"` PidFile string `hcl:"pid_file"`
Cache *Cache `hcl:"cache"` Cache *Cache `hcl:"cache"`
Vault *Vault `hcl:"vault"`
}
type Vault struct {
Address string `hcl:"address"`
CACert string `hcl:"ca_cert"`
CAPath string `hcl:"ca_path"`
TLSSkipVerify bool `hcl:"tls_skip_verify"`
} }
type Cache struct { type Cache struct {
@ -107,9 +115,35 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err) return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err)
} }
err = parseVault(&result, list)
if err != nil {
return nil, errwrap.Wrapf("error parsing 'vault':{{err}}", err)
}
return &result, nil return &result, nil
} }
func parseVault(result *Config, list *ast.ObjectList) error {
name := "vault"
vaultList := list.Filter(name)
if len(vaultList.Items) > 1 {
return fmt.Errorf("one and only one %q block is required", name)
}
item := vaultList.Items[0]
var v Vault
err := hcl.DecodeObject(&v, item.Val)
if err != nil {
return err
}
result.Vault = &v
return nil
}
func parseCache(result *Config, list *ast.ObjectList) error { func parseCache(result *Config, list *ast.ObjectList) error {
name := "cache" name := "cache"

View File

@ -67,6 +67,12 @@ func TestLoadConfigFile_AgentCache(t *testing.T) {
}, },
}, },
}, },
Vault: &Vault{
Address: "http://127.0.0.1:1111",
CACert: "config_ca_cert",
CAPath: "config_ca_path",
TLSSkipVerify: true,
},
PidFile: "./pidfile", PidFile: "./pidfile",
} }

View File

@ -42,3 +42,10 @@ cache {
tls_cert_file = "/path/to/cacert.pem" tls_cert_file = "/path/to/cacert.pem"
} }
} }
vault {
address = "http://127.0.0.1:1111"
ca_cert = "config_ca_cert"
ca_path = "config_ca_path"
tls_skip_verify = "true"
}

View File

@ -39,3 +39,10 @@ cache {
tls_cert_file = "/path/to/cacert.pem" tls_cert_file = "/path/to/cacert.pem"
} }
} }
vault {
address = "http://127.0.0.1:1111"
ca_cert = "config_ca_cert"
ca_path = "config_ca_path"
tls_skip_verify = "true"
}

View File

@ -188,15 +188,15 @@ cache {
} }
}() }()
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress) originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddr)
// Create a client that talks to the agent // Create a client that talks to the agent
os.Setenv(api.EnvVaultAgentAddress, socketf) os.Setenv(api.EnvVaultAgentAddr, socketf)
testClient, err := api.NewClient(api.DefaultConfig()) testClient, err := api.NewClient(api.DefaultConfig())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress) os.Setenv(api.EnvVaultAgentAddr, originalVaultAgentAddress)
// Start the agent // Start the agent
go cmd.Run([]string{"-config", conf}) go cmd.Run([]string{"-config", conf})

View File

@ -211,9 +211,9 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
f := set.NewFlagSet("HTTP Options") f := set.NewFlagSet("HTTP Options")
addrStringVar := &StringVar{ addrStringVar := &StringVar{
Name: "address", Name: flagNameAddress,
Target: &c.flagAddress, Target: &c.flagAddress,
EnvVar: "VAULT_ADDR", EnvVar: api.EnvVaultAddress,
Completion: complete.PredictAnything, Completion: complete.PredictAnything,
Usage: "Address of the Vault server.", Usage: "Address of the Vault server.",
} }
@ -227,17 +227,17 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
agentAddrStringVar := &StringVar{ agentAddrStringVar := &StringVar{
Name: "agent-address", Name: "agent-address",
Target: &c.flagAgentAddress, Target: &c.flagAgentAddress,
EnvVar: "VAULT_AGENT_ADDR", EnvVar: api.EnvVaultAgentAddr,
Completion: complete.PredictAnything, Completion: complete.PredictAnything,
Usage: "Address of the Agent.", Usage: "Address of the Agent.",
} }
f.StringVar(agentAddrStringVar) f.StringVar(agentAddrStringVar)
f.StringVar(&StringVar{ f.StringVar(&StringVar{
Name: "ca-cert", Name: flagNameCACert,
Target: &c.flagCACert, Target: &c.flagCACert,
Default: "", Default: "",
EnvVar: "VAULT_CACERT", EnvVar: api.EnvVaultCACert,
Completion: complete.PredictFiles("*"), Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " + Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to verify the Vault server's SSL certificate. This " + "certificate to verify the Vault server's SSL certificate. This " +
@ -245,10 +245,10 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
}) })
f.StringVar(&StringVar{ f.StringVar(&StringVar{
Name: "ca-path", Name: flagNameCAPath,
Target: &c.flagCAPath, Target: &c.flagCAPath,
Default: "", Default: "",
EnvVar: "VAULT_CAPATH", EnvVar: api.EnvVaultCAPath,
Completion: complete.PredictDirs("*"), Completion: complete.PredictDirs("*"),
Usage: "Path on the local disk to a directory of PEM-encoded CA " + Usage: "Path on the local disk to a directory of PEM-encoded CA " +
"certificates to verify the Vault server's SSL certificate.", "certificates to verify the Vault server's SSL certificate.",
@ -258,7 +258,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "client-cert", Name: "client-cert",
Target: &c.flagClientCert, Target: &c.flagClientCert,
Default: "", Default: "",
EnvVar: "VAULT_CLIENT_CERT", EnvVar: api.EnvVaultClientCert,
Completion: complete.PredictFiles("*"), Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " + Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to use for TLS authentication to the Vault server. If " + "certificate to use for TLS authentication to the Vault server. If " +
@ -269,7 +269,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "client-key", Name: "client-key",
Target: &c.flagClientKey, Target: &c.flagClientKey,
Default: "", Default: "",
EnvVar: "VAULT_CLIENT_KEY", EnvVar: api.EnvVaultClientKey,
Completion: complete.PredictFiles("*"), Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded private key " + Usage: "Path on the local disk to a single PEM-encoded private key " +
"matching the client certificate from -client-cert.", "matching the client certificate from -client-cert.",
@ -279,7 +279,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "namespace", Name: "namespace",
Target: &c.flagNamespace, Target: &c.flagNamespace,
Default: notSetValue, // this can never be a real value Default: notSetValue, // this can never be a real value
EnvVar: "VAULT_NAMESPACE", EnvVar: api.EnvVaultNamespace,
Completion: complete.PredictAnything, Completion: complete.PredictAnything,
Usage: "The namespace to use for the command. Setting this is not " + Usage: "The namespace to use for the command. Setting this is not " +
"necessary but allows using relative paths. -ns can be used as " + "necessary but allows using relative paths. -ns can be used as " +
@ -299,17 +299,17 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "tls-server-name", Name: "tls-server-name",
Target: &c.flagTLSServerName, Target: &c.flagTLSServerName,
Default: "", Default: "",
EnvVar: "VAULT_TLS_SERVER_NAME", EnvVar: api.EnvVaultTLSServerName,
Completion: complete.PredictAnything, Completion: complete.PredictAnything,
Usage: "Name to use as the SNI host when connecting to the Vault " + Usage: "Name to use as the SNI host when connecting to the Vault " +
"server via TLS.", "server via TLS.",
}) })
f.BoolVar(&BoolVar{ f.BoolVar(&BoolVar{
Name: "tls-skip-verify", Name: flagNameTLSSkipVerify,
Target: &c.flagTLSSkipVerify, Target: &c.flagTLSSkipVerify,
Default: false, Default: false,
EnvVar: "VAULT_SKIP_VERIFY", EnvVar: api.EnvVaultSkipVerify,
Usage: "Disable verification of TLS certificates. Using this option " + Usage: "Disable verification of TLS certificates. Using this option " +
"is highly discouraged and decreases the security of data " + "is highly discouraged and decreases the security of data " +
"transmissions to and from the Vault server.", "transmissions to and from the Vault server.",
@ -327,7 +327,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "wrap-ttl", Name: "wrap-ttl",
Target: &c.flagWrapTTL, Target: &c.flagWrapTTL,
Default: 0, Default: 0,
EnvVar: "VAULT_WRAP_TTL", EnvVar: api.EnvVaultWrapTTL,
Completion: complete.PredictAnything, Completion: complete.PredictAnything,
Usage: "Wraps the response in a cubbyhole token with the requested " + Usage: "Wraps the response in a cubbyhole token with the requested " +
"TTL. The response is available via the \"vault unwrap\" command. " + "TTL. The response is available via the \"vault unwrap\" command. " +

View File

@ -743,7 +743,7 @@ func (f *FlagSet) VarFlag(i *VarFlag) {
} }
// Var is a lower-level API for adding something to the flags. It should be used // Var is a lower-level API for adding something to the flags. It should be used
// wtih caution, since it bypasses all validation. Consider VarFlag instead. // with caution, since it bypasses all validation. Consider VarFlag instead.
func (f *FlagSet) Var(value flag.Value, name, usage string) { func (f *FlagSet) Var(value flag.Value, name, usage string) {
f.mainSet.Var(value, name, usage) f.mainSet.Var(value, name, usage)
f.flagSet.Var(value, name, usage) f.flagSet.Var(value, name, usage)

View File

@ -66,6 +66,18 @@ const (
// EnvVaultFormat is the output format // EnvVaultFormat is the output format
EnvVaultFormat = `VAULT_FORMAT` EnvVaultFormat = `VAULT_FORMAT`
// flagNameAddress is the flag used in the base command to read in the
// address of the Vault server.
flagNameAddress = "address"
// flagnameCACert is the flag used in the base command to read in the CA
// cert.
flagNameCACert = "ca-cert"
// flagnameCAPath is the flag used in the base command to read in the CA
// cert path.
flagNameCAPath = "ca-path"
// flagNameTLSSkipVerify is the flag used in the base command to read in
// the option to ignore TLS certificate verification.
flagNameTLSSkipVerify = "tls-skip-verify"
// flagNameAuditNonHMACRequestKeys is the flag name used for auth/secrets enable // flagNameAuditNonHMACRequestKeys is the flag name used for auth/secrets enable
flagNameAuditNonHMACRequestKeys = "audit-non-hmac-request-keys" flagNameAuditNonHMACRequestKeys = "audit-non-hmac-request-keys"
// flagNameAuditNonHMACResponseKeys is the flag name used for auth/secrets enable // flagNameAuditNonHMACResponseKeys is the flag name used for auth/secrets enable