Improving Handling of Unix Domain Socket Addresses (#11904)

* Removed redundant checks for same env var in ReadEnvironment, extracted Unix domain socket logic to function, and made use of this logic in SetAddress.  Adjusted unit tests to verify proper Unix domain socket handling.

* Adding case to revert from Unix domain socket dial function back to TCP

* Adding changelog file

* Only adjust DialContext if RoundTripper is an http.Transport

* Switching from read lock to normal lock

* only reset transport DialContext when setting different address type

* made ParseAddress a method on Config

* Adding additional tests to cover transitions to/from TCP to Unix

* Moved Config type method ParseAddress closer to type's other methods.

* make release note more end-user focused

* adopt review feedback to add comment about holding a lock
This commit is contained in:
Marc Boudreau 2022-06-21 18:16:58 -04:00 committed by GitHub
parent d1971a9f19
commit 03d75a7b60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 29 deletions

View File

@ -347,8 +347,6 @@ func (c *Config) ReadEnvironment() error {
} }
if v := os.Getenv(EnvVaultAgentAddr); v != "" { if v := os.Getenv(EnvVaultAgentAddr); v != "" {
envAgentAddress = v envAgentAddress = v
} else if v := os.Getenv(EnvVaultAgentAddress); v != "" {
envAgentAddress = v
} }
if v := os.Getenv(EnvVaultMaxRetries); v != "" { if v := os.Getenv(EnvVaultMaxRetries); v != "" {
maxRetries, err := strconv.ParseUint(v, 10, 32) maxRetries, err := strconv.ParseUint(v, 10, 32)
@ -392,12 +390,6 @@ func (c *Config) ReadEnvironment() error {
if err != nil { if err != nil {
return fmt.Errorf("could not parse VAULT_SKIP_VERIFY") return fmt.Errorf("could not parse VAULT_SKIP_VERIFY")
} }
} else if v := os.Getenv(EnvVaultInsecure); v != "" {
var err error
envInsecure, err = strconv.ParseBool(v)
if err != nil {
return fmt.Errorf("could not parse VAULT_INSECURE")
}
} }
if v := os.Getenv(EnvVaultSRVLookup); v != "" { if v := os.Getenv(EnvVaultSRVLookup); v != "" {
var err error var err error
@ -470,6 +462,51 @@ func (c *Config) ReadEnvironment() error {
return nil return nil
} }
// ParseAddress transforms the provided address into a url.URL and handles
// the case of Unix domain sockets by setting the DialContext in the
// configuration's HttpClient.Transport. This function must be called with
// c.modifyLock held for write access.
func (c *Config) ParseAddress(address string) (*url.URL, error) {
u, err := url.Parse(address)
if err != nil {
return nil, err
}
c.Address = address
if strings.HasPrefix(address, "unix://") {
// When the address begins with unix://, always change the transport's
// DialContext (to match previous behaviour)
socket := strings.TrimPrefix(address, "unix://")
if transport, ok := c.HttpClient.Transport.(*http.Transport); ok {
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", socket)
}
// Since the address points to a unix domain socket, the scheme in the
// *URL would be set to `unix`. The *URL in the client is expected to
// be pointing to the protocol used in the application layer and not to
// the transport layer. Hence, setting the fields accordingly.
u.Scheme = "http"
u.Host = socket
u.Path = ""
} else {
return nil, fmt.Errorf("attempting to specify unix:// address with non-transport transport")
}
} else if strings.HasPrefix(c.Address, "unix://") {
// When the address being set does not begin with unix:// but the previous
// address in the Config did, change the transport's DialContext back to
// use the default configuration that cleanhttp uses.
if transport, ok := c.HttpClient.Transport.(*http.Transport); ok {
transport.DialContext = cleanhttp.DefaultPooledTransport().DialContext
}
}
return u, nil
}
func parseRateLimit(val string) (rate float64, burst int, err error) { func parseRateLimit(val string) (rate float64, burst int, err error) {
_, err = fmt.Sscanf(val, "%f:%d", &rate, &burst) _, err = fmt.Sscanf(val, "%f:%d", &rate, &burst)
if err != nil { if err != nil {
@ -542,27 +579,11 @@ func NewClient(c *Config) (*Client, error) {
address = c.AgentAddress address = c.AgentAddress
} }
u, err := url.Parse(address) u, err := c.ParseAddress(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if strings.HasPrefix(address, "unix://") {
socket := strings.TrimPrefix(address, "unix://")
transport := c.HttpClient.Transport.(*http.Transport)
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
return net.Dial("unix", socket)
}
// Since the address points to a unix domain socket, the scheme in the
// *URL would be set to `unix`. The *URL in the client is expected to
// be pointing to the protocol used in the application layer and not to
// the transport layer. Hence, setting the fields accordingly.
u.Scheme = "http"
u.Host = socket
u.Path = ""
}
client := &Client{ client := &Client{
addr: u, addr: u,
config: c, config: c,
@ -621,14 +642,11 @@ func (c *Client) SetAddress(addr string) error {
c.modifyLock.Lock() c.modifyLock.Lock()
defer c.modifyLock.Unlock() defer c.modifyLock.Unlock()
parsedAddr, err := url.Parse(addr) parsedAddr, err := c.config.ParseAddress(addr)
if err != nil { if err != nil {
return errwrap.Wrapf("failed to set address: {{err}}", err) return errwrap.Wrapf("failed to set address: {{err}}", err)
} }
c.config.modifyLock.Lock()
c.config.Address = addr
c.config.modifyLock.Unlock()
c.addr = parsedAddr c.addr = parsedAddr
return nil return nil
} }

View File

@ -69,17 +69,63 @@ func TestClientNilConfig(t *testing.T) {
} }
} }
func TestClientDefaultHttpClient_unixSocket(t *testing.T) {
os.Setenv("VAULT_AGENT_ADDR", "unix:///var/run/vault.sock")
defer os.Setenv("VAULT_AGENT_ADDR", "")
client, err := NewClient(nil)
if err != nil {
t.Fatal(err)
}
if client == nil {
t.Fatal("expected a non-nil client")
}
if client.addr.Scheme != "http" {
t.Fatalf("bad: %s", client.addr.Scheme)
}
if client.addr.Host != "/var/run/vault.sock" {
t.Fatalf("bad: %s", client.addr.Host)
}
}
func TestClientSetAddress(t *testing.T) { func TestClientSetAddress(t *testing.T) {
client, err := NewClient(nil) client, err := NewClient(nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Start with TCP address using HTTP
if err := client.SetAddress("http://172.168.2.1:8300"); err != nil { if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if client.addr.Host != "172.168.2.1:8300" { if client.addr.Host != "172.168.2.1:8300" {
t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host) t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
} }
// Test switching to Unix Socket address from TCP address
if err := client.SetAddress("unix:///var/run/vault.sock"); err != nil {
t.Fatal(err)
}
if client.addr.Scheme != "http" {
t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme)
}
if client.addr.Host != "/var/run/vault.sock" {
t.Fatalf("bad: expected: '/var/run/vault.sock' actual: %q", client.addr.Host)
}
if client.addr.Path != "" {
t.Fatalf("bad: expected '' actual: %q", client.addr.Path)
}
if client.config.HttpClient.Transport.(*http.Transport).DialContext == nil {
t.Fatal("bad: expected DialContext to not be nil")
}
// Test switching to TCP address from Unix Socket address
if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
t.Fatal(err)
}
if client.addr.Host != "172.168.2.1:8300" {
t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
}
if client.addr.Scheme != "http" {
t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme)
}
} }
func TestClientToken(t *testing.T) { func TestClientToken(t *testing.T) {
@ -426,6 +472,20 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
} }
} }
func TestClientNonTransportRoundTripperUnixAddress(t *testing.T) {
client := &http.Client{
Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
}
_, err := NewClient(&Config{
HttpClient: client,
Address: "unix:///var/run/vault.sock",
})
if err == nil {
t.Fatal("bad: expected error got nil")
}
}
func TestClone(t *testing.T) { func TestClone(t *testing.T) {
type fields struct{} type fields struct{}
tests := []struct { tests := []struct {
@ -1284,3 +1344,25 @@ func TestVaultProxy(t *testing.T) {
}) })
} }
} }
func TestParseAddressWithUnixSocket(t *testing.T) {
address := "unix:///var/run/vault.sock"
config := DefaultConfig()
u, err := config.ParseAddress(address)
if err != nil {
t.Fatal("Error not expected")
}
if u.Scheme != "http" {
t.Fatal("Scheme not changed to http")
}
if u.Host != "/var/run/vault.sock" {
t.Fatal("Host not changed to socket name")
}
if u.Path != "" {
t.Fatal("Path expected to be blank")
}
if config.HttpClient.Transport.(*http.Transport).DialContext == nil {
t.Fatal("DialContext function not set in config.HttpClient.Transport")
}
}

3
changelog/11904.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
api: properly handle switching to/from unix domain socket when changing client address
```