Improving Handling of Unix Domain Socket Addresses (#11904)
* Removed redundant checks for same env var in ReadEnvironment, extracted Unix domain socket logic to function, and made use of this logic in SetAddress. Adjusted unit tests to verify proper Unix domain socket handling. * Adding case to revert from Unix domain socket dial function back to TCP * Adding changelog file * Only adjust DialContext if RoundTripper is an http.Transport * Switching from read lock to normal lock * only reset transport DialContext when setting different address type * made ParseAddress a method on Config * Adding additional tests to cover transitions to/from TCP to Unix * Moved Config type method ParseAddress closer to type's other methods. * make release note more end-user focused * adopt review feedback to add comment about holding a lock
This commit is contained in:
parent
d1971a9f19
commit
03d75a7b60
|
@ -347,8 +347,6 @@ func (c *Config) ReadEnvironment() error {
|
|||
}
|
||||
if v := os.Getenv(EnvVaultAgentAddr); v != "" {
|
||||
envAgentAddress = v
|
||||
} else if v := os.Getenv(EnvVaultAgentAddress); v != "" {
|
||||
envAgentAddress = v
|
||||
}
|
||||
if v := os.Getenv(EnvVaultMaxRetries); v != "" {
|
||||
maxRetries, err := strconv.ParseUint(v, 10, 32)
|
||||
|
@ -392,12 +390,6 @@ func (c *Config) ReadEnvironment() error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("could not parse VAULT_SKIP_VERIFY")
|
||||
}
|
||||
} else if v := os.Getenv(EnvVaultInsecure); v != "" {
|
||||
var err error
|
||||
envInsecure, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse VAULT_INSECURE")
|
||||
}
|
||||
}
|
||||
if v := os.Getenv(EnvVaultSRVLookup); v != "" {
|
||||
var err error
|
||||
|
@ -470,6 +462,51 @@ func (c *Config) ReadEnvironment() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ParseAddress transforms the provided address into a url.URL and handles
|
||||
// the case of Unix domain sockets by setting the DialContext in the
|
||||
// configuration's HttpClient.Transport. This function must be called with
|
||||
// c.modifyLock held for write access.
|
||||
func (c *Config) ParseAddress(address string) (*url.URL, error) {
|
||||
u, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.Address = address
|
||||
|
||||
if strings.HasPrefix(address, "unix://") {
|
||||
// When the address begins with unix://, always change the transport's
|
||||
// DialContext (to match previous behaviour)
|
||||
socket := strings.TrimPrefix(address, "unix://")
|
||||
|
||||
if transport, ok := c.HttpClient.Transport.(*http.Transport); ok {
|
||||
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
|
||||
return net.Dial("unix", socket)
|
||||
}
|
||||
|
||||
// Since the address points to a unix domain socket, the scheme in the
|
||||
// *URL would be set to `unix`. The *URL in the client is expected to
|
||||
// be pointing to the protocol used in the application layer and not to
|
||||
// the transport layer. Hence, setting the fields accordingly.
|
||||
u.Scheme = "http"
|
||||
u.Host = socket
|
||||
u.Path = ""
|
||||
} else {
|
||||
return nil, fmt.Errorf("attempting to specify unix:// address with non-transport transport")
|
||||
}
|
||||
} else if strings.HasPrefix(c.Address, "unix://") {
|
||||
// When the address being set does not begin with unix:// but the previous
|
||||
// address in the Config did, change the transport's DialContext back to
|
||||
// use the default configuration that cleanhttp uses.
|
||||
|
||||
if transport, ok := c.HttpClient.Transport.(*http.Transport); ok {
|
||||
transport.DialContext = cleanhttp.DefaultPooledTransport().DialContext
|
||||
}
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func parseRateLimit(val string) (rate float64, burst int, err error) {
|
||||
_, err = fmt.Sscanf(val, "%f:%d", &rate, &burst)
|
||||
if err != nil {
|
||||
|
@ -542,27 +579,11 @@ func NewClient(c *Config) (*Client, error) {
|
|||
address = c.AgentAddress
|
||||
}
|
||||
|
||||
u, err := url.Parse(address)
|
||||
u, err := c.ParseAddress(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(address, "unix://") {
|
||||
socket := strings.TrimPrefix(address, "unix://")
|
||||
transport := c.HttpClient.Transport.(*http.Transport)
|
||||
transport.DialContext = func(context.Context, string, string) (net.Conn, error) {
|
||||
return net.Dial("unix", socket)
|
||||
}
|
||||
|
||||
// Since the address points to a unix domain socket, the scheme in the
|
||||
// *URL would be set to `unix`. The *URL in the client is expected to
|
||||
// be pointing to the protocol used in the application layer and not to
|
||||
// the transport layer. Hence, setting the fields accordingly.
|
||||
u.Scheme = "http"
|
||||
u.Host = socket
|
||||
u.Path = ""
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
addr: u,
|
||||
config: c,
|
||||
|
@ -621,14 +642,11 @@ func (c *Client) SetAddress(addr string) error {
|
|||
c.modifyLock.Lock()
|
||||
defer c.modifyLock.Unlock()
|
||||
|
||||
parsedAddr, err := url.Parse(addr)
|
||||
parsedAddr, err := c.config.ParseAddress(addr)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to set address: {{err}}", err)
|
||||
}
|
||||
|
||||
c.config.modifyLock.Lock()
|
||||
c.config.Address = addr
|
||||
c.config.modifyLock.Unlock()
|
||||
c.addr = parsedAddr
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -69,17 +69,63 @@ func TestClientNilConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClientDefaultHttpClient_unixSocket(t *testing.T) {
|
||||
os.Setenv("VAULT_AGENT_ADDR", "unix:///var/run/vault.sock")
|
||||
defer os.Setenv("VAULT_AGENT_ADDR", "")
|
||||
|
||||
client, err := NewClient(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("expected a non-nil client")
|
||||
}
|
||||
if client.addr.Scheme != "http" {
|
||||
t.Fatalf("bad: %s", client.addr.Scheme)
|
||||
}
|
||||
if client.addr.Host != "/var/run/vault.sock" {
|
||||
t.Fatalf("bad: %s", client.addr.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSetAddress(t *testing.T) {
|
||||
client, err := NewClient(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Start with TCP address using HTTP
|
||||
if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.addr.Host != "172.168.2.1:8300" {
|
||||
t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
|
||||
}
|
||||
// Test switching to Unix Socket address from TCP address
|
||||
if err := client.SetAddress("unix:///var/run/vault.sock"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.addr.Scheme != "http" {
|
||||
t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme)
|
||||
}
|
||||
if client.addr.Host != "/var/run/vault.sock" {
|
||||
t.Fatalf("bad: expected: '/var/run/vault.sock' actual: %q", client.addr.Host)
|
||||
}
|
||||
if client.addr.Path != "" {
|
||||
t.Fatalf("bad: expected '' actual: %q", client.addr.Path)
|
||||
}
|
||||
if client.config.HttpClient.Transport.(*http.Transport).DialContext == nil {
|
||||
t.Fatal("bad: expected DialContext to not be nil")
|
||||
}
|
||||
// Test switching to TCP address from Unix Socket address
|
||||
if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.addr.Host != "172.168.2.1:8300" {
|
||||
t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
|
||||
}
|
||||
if client.addr.Scheme != "http" {
|
||||
t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientToken(t *testing.T) {
|
||||
|
@ -426,6 +472,20 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClientNonTransportRoundTripperUnixAddress(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
|
||||
}
|
||||
|
||||
_, err := NewClient(&Config{
|
||||
HttpClient: client,
|
||||
Address: "unix:///var/run/vault.sock",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("bad: expected error got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
type fields struct{}
|
||||
tests := []struct {
|
||||
|
@ -1284,3 +1344,25 @@ func TestVaultProxy(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAddressWithUnixSocket(t *testing.T) {
|
||||
address := "unix:///var/run/vault.sock"
|
||||
config := DefaultConfig()
|
||||
|
||||
u, err := config.ParseAddress(address)
|
||||
if err != nil {
|
||||
t.Fatal("Error not expected")
|
||||
}
|
||||
if u.Scheme != "http" {
|
||||
t.Fatal("Scheme not changed to http")
|
||||
}
|
||||
if u.Host != "/var/run/vault.sock" {
|
||||
t.Fatal("Host not changed to socket name")
|
||||
}
|
||||
if u.Path != "" {
|
||||
t.Fatal("Path expected to be blank")
|
||||
}
|
||||
if config.HttpClient.Transport.(*http.Transport).DialContext == nil {
|
||||
t.Fatal("DialContext function not set in config.HttpClient.Transport")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
```release-note:bug
|
||||
api: properly handle switching to/from unix domain socket when changing client address
|
||||
```
|
Loading…
Reference in New Issue