diff --git a/api/client.go b/api/client.go index 957ba5d82..ee8bdd051 100644 --- a/api/client.go +++ b/api/client.go @@ -347,8 +347,6 @@ func (c *Config) ReadEnvironment() error { } if v := os.Getenv(EnvVaultAgentAddr); v != "" { envAgentAddress = v - } else if v := os.Getenv(EnvVaultAgentAddress); v != "" { - envAgentAddress = v } if v := os.Getenv(EnvVaultMaxRetries); v != "" { maxRetries, err := strconv.ParseUint(v, 10, 32) @@ -392,12 +390,6 @@ func (c *Config) ReadEnvironment() error { if err != nil { return fmt.Errorf("could not parse VAULT_SKIP_VERIFY") } - } else if v := os.Getenv(EnvVaultInsecure); v != "" { - var err error - envInsecure, err = strconv.ParseBool(v) - if err != nil { - return fmt.Errorf("could not parse VAULT_INSECURE") - } } if v := os.Getenv(EnvVaultSRVLookup); v != "" { var err error @@ -470,6 +462,51 @@ func (c *Config) ReadEnvironment() error { return nil } +// ParseAddress transforms the provided address into a url.URL and handles +// the case of Unix domain sockets by setting the DialContext in the +// configuration's HttpClient.Transport. This function must be called with +// c.modifyLock held for write access. +func (c *Config) ParseAddress(address string) (*url.URL, error) { + u, err := url.Parse(address) + if err != nil { + return nil, err + } + + c.Address = address + + if strings.HasPrefix(address, "unix://") { + // When the address begins with unix://, always change the transport's + // DialContext (to match previous behaviour) + socket := strings.TrimPrefix(address, "unix://") + + if transport, ok := c.HttpClient.Transport.(*http.Transport); ok { + transport.DialContext = func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", socket) + } + + // Since the address points to a unix domain socket, the scheme in the + // *URL would be set to `unix`. The *URL in the client is expected to + // be pointing to the protocol used in the application layer and not to + // the transport layer. Hence, setting the fields accordingly. + u.Scheme = "http" + u.Host = socket + u.Path = "" + } else { + return nil, fmt.Errorf("attempting to specify unix:// address with non-transport transport") + } + } else if strings.HasPrefix(c.Address, "unix://") { + // When the address being set does not begin with unix:// but the previous + // address in the Config did, change the transport's DialContext back to + // use the default configuration that cleanhttp uses. + + if transport, ok := c.HttpClient.Transport.(*http.Transport); ok { + transport.DialContext = cleanhttp.DefaultPooledTransport().DialContext + } + } + + return u, nil +} + func parseRateLimit(val string) (rate float64, burst int, err error) { _, err = fmt.Sscanf(val, "%f:%d", &rate, &burst) if err != nil { @@ -542,27 +579,11 @@ func NewClient(c *Config) (*Client, error) { address = c.AgentAddress } - u, err := url.Parse(address) + u, err := c.ParseAddress(address) if err != nil { return nil, err } - if strings.HasPrefix(address, "unix://") { - socket := strings.TrimPrefix(address, "unix://") - transport := c.HttpClient.Transport.(*http.Transport) - transport.DialContext = func(context.Context, string, string) (net.Conn, error) { - return net.Dial("unix", socket) - } - - // Since the address points to a unix domain socket, the scheme in the - // *URL would be set to `unix`. The *URL in the client is expected to - // be pointing to the protocol used in the application layer and not to - // the transport layer. Hence, setting the fields accordingly. - u.Scheme = "http" - u.Host = socket - u.Path = "" - } - client := &Client{ addr: u, config: c, @@ -621,14 +642,11 @@ func (c *Client) SetAddress(addr string) error { c.modifyLock.Lock() defer c.modifyLock.Unlock() - parsedAddr, err := url.Parse(addr) + parsedAddr, err := c.config.ParseAddress(addr) if err != nil { return errwrap.Wrapf("failed to set address: {{err}}", err) } - c.config.modifyLock.Lock() - c.config.Address = addr - c.config.modifyLock.Unlock() c.addr = parsedAddr return nil } diff --git a/api/client_test.go b/api/client_test.go index 74f1b9354..ca165e010 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -69,17 +69,63 @@ func TestClientNilConfig(t *testing.T) { } } +func TestClientDefaultHttpClient_unixSocket(t *testing.T) { + os.Setenv("VAULT_AGENT_ADDR", "unix:///var/run/vault.sock") + defer os.Setenv("VAULT_AGENT_ADDR", "") + + client, err := NewClient(nil) + if err != nil { + t.Fatal(err) + } + if client == nil { + t.Fatal("expected a non-nil client") + } + if client.addr.Scheme != "http" { + t.Fatalf("bad: %s", client.addr.Scheme) + } + if client.addr.Host != "/var/run/vault.sock" { + t.Fatalf("bad: %s", client.addr.Host) + } +} + func TestClientSetAddress(t *testing.T) { client, err := NewClient(nil) if err != nil { t.Fatal(err) } + // Start with TCP address using HTTP if err := client.SetAddress("http://172.168.2.1:8300"); err != nil { t.Fatal(err) } if client.addr.Host != "172.168.2.1:8300" { t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host) } + // Test switching to Unix Socket address from TCP address + if err := client.SetAddress("unix:///var/run/vault.sock"); err != nil { + t.Fatal(err) + } + if client.addr.Scheme != "http" { + t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme) + } + if client.addr.Host != "/var/run/vault.sock" { + t.Fatalf("bad: expected: '/var/run/vault.sock' actual: %q", client.addr.Host) + } + if client.addr.Path != "" { + t.Fatalf("bad: expected '' actual: %q", client.addr.Path) + } + if client.config.HttpClient.Transport.(*http.Transport).DialContext == nil { + t.Fatal("bad: expected DialContext to not be nil") + } + // Test switching to TCP address from Unix Socket address + if err := client.SetAddress("http://172.168.2.1:8300"); err != nil { + t.Fatal(err) + } + if client.addr.Host != "172.168.2.1:8300" { + t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host) + } + if client.addr.Scheme != "http" { + t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme) + } } func TestClientToken(t *testing.T) { @@ -426,6 +472,20 @@ func TestClientNonTransportRoundTripper(t *testing.T) { } } +func TestClientNonTransportRoundTripperUnixAddress(t *testing.T) { + client := &http.Client{ + Transport: roundTripperFunc(http.DefaultTransport.RoundTrip), + } + + _, err := NewClient(&Config{ + HttpClient: client, + Address: "unix:///var/run/vault.sock", + }) + if err == nil { + t.Fatal("bad: expected error got nil") + } +} + func TestClone(t *testing.T) { type fields struct{} tests := []struct { @@ -1284,3 +1344,25 @@ func TestVaultProxy(t *testing.T) { }) } } + +func TestParseAddressWithUnixSocket(t *testing.T) { + address := "unix:///var/run/vault.sock" + config := DefaultConfig() + + u, err := config.ParseAddress(address) + if err != nil { + t.Fatal("Error not expected") + } + if u.Scheme != "http" { + t.Fatal("Scheme not changed to http") + } + if u.Host != "/var/run/vault.sock" { + t.Fatal("Host not changed to socket name") + } + if u.Path != "" { + t.Fatal("Path expected to be blank") + } + if config.HttpClient.Transport.(*http.Transport).DialContext == nil { + t.Fatal("DialContext function not set in config.HttpClient.Transport") + } +} diff --git a/changelog/11904.txt b/changelog/11904.txt new file mode 100644 index 000000000..584aeae8d --- /dev/null +++ b/changelog/11904.txt @@ -0,0 +1,3 @@ +```release-note:bug +api: properly handle switching to/from unix domain socket when changing client address +```