Fix SRV Lookups (#8520)

* Pin HTTP Host header for all client requests
* Drop port map scheme
* Add SRV Lookup environment var
* Lookup SRV records only when env var is specified
* Add docs

Co-Authored-By: Michel Vocks <michelvocks@gmail.com>
This commit is contained in:
Daniel Spangenberg 2020-03-11 14:22:58 +01:00 committed by GitHub
parent 0b09580c36
commit 8007845ba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 19 deletions

View File

@ -32,6 +32,7 @@ const EnvVaultCAPath = "VAULT_CAPATH"
const EnvVaultClientCert = "VAULT_CLIENT_CERT"
const EnvVaultClientKey = "VAULT_CLIENT_KEY"
const EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT"
const EnvVaultSRVLookup = "VAULT_SRV_LOOKUP"
const EnvVaultSkipVerify = "VAULT_SKIP_VERIFY"
const EnvVaultNamespace = "VAULT_NAMESPACE"
const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME"
@ -105,6 +106,9 @@ type Config struct {
// Note: It is not thread-safe to set this and make concurrent requests
// with the same client. Cloning a client will not clone this value.
OutputCurlString bool
// SRVLookup enables the client to lookup the host through DNS SRV lookup
SRVLookup bool
}
// TLSConfig contains the parameters needed to configure TLS on the HTTP client
@ -245,6 +249,7 @@ func (c *Config) ReadEnvironment() error {
var envInsecure bool
var envTLSServerName string
var envMaxRetries *uint64
var envSRVLookup bool
var limit *rate.Limiter
// Parse the environment variables
@ -302,6 +307,13 @@ func (c *Config) ReadEnvironment() error {
return fmt.Errorf("could not parse VAULT_INSECURE")
}
}
if v := os.Getenv(EnvVaultSRVLookup); v != "" {
var err error
envSRVLookup, err = strconv.ParseBool(v)
if err != nil {
return fmt.Errorf("could not parse %s", EnvVaultSRVLookup)
}
}
if v := os.Getenv(EnvVaultTLSServerName); v != "" {
envTLSServerName = v
@ -320,6 +332,7 @@ func (c *Config) ReadEnvironment() error {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.SRVLookup = envSRVLookup
c.Limiter = limit
if err := c.ConfigureTLS(t); err != nil {
@ -686,12 +699,6 @@ func (c *Client) SetPolicyOverride(override bool) {
c.policyOverride = override
}
// portMap defines the standard port map
var portMap = map[string]string{
"http": "80",
"https": "443",
}
// NewRequest creates a new raw request object to query the Vault server
// configured for this client. This is an advanced method and generally
// doesn't need to be called externally.
@ -704,20 +711,14 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
policyOverride := c.policyOverride
c.modifyLock.RUnlock()
var host = addr.Host
// if SRV records exist (see https://tools.ietf.org/html/draft-andrews-http-srv-02), lookup the SRV
// record and take the highest match; this is not designed for high-availability, just discovery
var host string = addr.Host
if addr.Port() == "" {
// Avoid lookup of SRV record if scheme is known
port, ok := portMap[addr.Scheme]
if ok {
host = net.JoinHostPort(host, port)
} else {
// Internet Draft specifies that the SRV record is ignored if a port is given
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())
if err == nil && len(addrs) > 0 {
host = fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)
}
// Internet Draft specifies that the SRV record is ignored if a port is given
if addr.Port() == "" && c.config.SRVLookup {
_, addrs, err := net.LookupSRV("http", "tcp", addr.Hostname())
if err == nil && len(addrs) > 0 {
host = fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)
}
}
@ -729,6 +730,7 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
Host: host,
Path: path.Join(addr.Path, requestPath),
},
Host: addr.Host,
ClientToken: token,
Params: make(map[string][]string),
}

View File

@ -97,6 +97,37 @@ func TestClientToken(t *testing.T) {
}
}
func TestClientHostHeader(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Host))
}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
config.Address = strings.ReplaceAll(config.Address, "127.0.0.1", "localhost")
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
// Set the token manually
client.SetToken("foo")
resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
if err != nil {
t.Fatal(err)
}
// Copy the response
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
// Verify we got the response from the primary
if buf.String() != strings.ReplaceAll(config.Address, "http://", "") {
t.Fatalf("Bad address: %s", buf.String())
}
}
func TestClientBadToken(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {}

View File

@ -18,6 +18,7 @@ import (
type Request struct {
Method string
URL *url.URL
Host string
Params url.Values
Headers http.Header
ClientToken string
@ -115,7 +116,7 @@ func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) {
req.URL.User = r.URL.User
req.URL.Scheme = r.URL.Scheme
req.URL.Host = r.URL.Host
req.Host = r.URL.Host
req.Host = r.Host
if r.Headers != nil {
for header, vals := range r.Headers {

View File

@ -288,6 +288,13 @@ this environment variable is most useful when using the Go
The namespace to use for the command. Setting this is not necessary
but allows using relative paths.
### `VAULT_SRV_LOOKUP`
Enables the client to lookup the host through DNS SRV look up as described in this
[draft](https://tools.ietf.org/html/draft-andrews-http-srv-02).
This is not designed for high-availability, just discovery.
The draft specifies that the SRV record lookup is ignored if a port is given.
### `VAULT_MFA`
**ENTERPRISE ONLY**