From 9f8991e0ccb8bdda5a18d4fce2766e90ac949600 Mon Sep 17 00:00:00 2001 From: Matt Keeler Date: Mon, 16 Jul 2018 16:30:15 -0400 Subject: [PATCH] Fix issue with choosing a client addr that is 0.0.0.0 or :: --- agent/agent.go | 36 ++++---------- agent/config/runtime.go | 87 +++++++++++++++++++++++++++++---- agent/config/runtime_test.go | 94 ++++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 37 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 13c55b156..c77565840 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -710,25 +710,14 @@ func (a *Agent) reloadWatches(cfg *config.RuntimeConfig) error { watchPlans = append(watchPlans, wp) } - // Determine the primary http(s) endpoint. - var netaddr net.Addr - https := false - if len(cfg.HTTPAddrs) > 0 { - netaddr = cfg.HTTPAddrs[0] - } else { - netaddr = cfg.HTTPSAddrs[0] - https = true - } - addr := netaddr.String() - if netaddr.Network() == "unix" { - addr = "unix://" + addr - https = false - } else if https { - addr = "https://" + addr - } - // Fire off a goroutine for each new watch plan. for _, wp := range watchPlans { + config, err := a.config.APIConfig(true) + if err != nil { + a.logger.Printf("[ERR] agent: Failed to run watch: %v", err) + continue + } + a.watchPlans = append(a.watchPlans, wp) go func(wp *watch.Plan) { if h, ok := wp.Exempt["handler"]; ok { @@ -741,16 +730,9 @@ func (a *Agent) reloadWatches(cfg *config.RuntimeConfig) error { } wp.LogOutput = a.LogOutput - config := api.DefaultConfig() - if https { - if a.config.CAPath != "" { - config.TLSConfig.CAPath = a.config.CAPath - } - if a.config.CAFile != "" { - config.TLSConfig.CAFile = a.config.CAFile - } - // use the original address without the https:// prefix - config.TLSConfig.Address = netaddr.String() + addr := config.Address + if config.Scheme == "https" { + addr = "https://" + addr } if err := wp.RunWithConfig(addr, config); err != nil { diff --git a/agent/config/runtime.go b/agent/config/runtime.go index 2914c38bf..d63c7e20d 100644 --- a/agent/config/runtime.go +++ b/agent/config/runtime.go @@ -1193,7 +1193,7 @@ func (c *RuntimeConfig) IncomingHTTPSConfig() (*tls.Config, error) { func (c *RuntimeConfig) apiAddresses(maxPerType int) (unixAddrs, httpAddrs, httpsAddrs []string) { if len(c.HTTPSAddrs) > 0 { for i, addr := range c.HTTPSAddrs { - if i < maxPerType { + if maxPerType < 1 || i < maxPerType { httpsAddrs = append(httpsAddrs, addr.String()) } else { break @@ -1206,12 +1206,12 @@ func (c *RuntimeConfig) apiAddresses(maxPerType int) (unixAddrs, httpAddrs, http for _, addr := range c.HTTPAddrs { switch addr.(type) { case *net.UnixAddr: - if unix_count < maxPerType { + if maxPerType < 1 || unix_count < maxPerType { unixAddrs = append(unixAddrs, addr.String()) unix_count += 1 } default: - if http_count < maxPerType { + if maxPerType < 1 || http_count < maxPerType { httpAddrs = append(httpAddrs, addr.String()) http_count += 1 } @@ -1222,28 +1222,95 @@ func (c *RuntimeConfig) apiAddresses(maxPerType int) (unixAddrs, httpAddrs, http return } +func (c *RuntimeConfig) ClientAddress() (unixAddr, httpAddr, httpsAddr string) { + unixAddrs, httpAddrs, httpsAddrs := c.apiAddresses(0) + + if len(unixAddrs) > 0 { + unixAddr = "unix://" + unixAddrs[0] + } + + http_any := "" + if len(httpAddrs) > 0 { + for _, addr := range httpAddrs { + host, port, err := net.SplitHostPort(addr) + if err != nil { + continue + } + + if host == "0.0.0.0" || host == "::" { + if http_any == "" { + if host == "0.0.0.0" { + http_any = net.JoinHostPort("127.0.0.1", port) + } else { + http_any = net.JoinHostPort("::1", port) + } + } + continue + } + + httpAddr = addr + break + } + + if httpAddr == "" && http_any != "" { + httpAddr = http_any + } + } + + https_any := "" + if len(httpsAddrs) > 0 { + for _, addr := range httpsAddrs { + host, port, err := net.SplitHostPort(addr) + if err != nil { + continue + } + + if host == "0.0.0.0" || host == "::" { + if https_any == "" { + if host == "0.0.0.0" { + https_any = net.JoinHostPort("127.0.0.1", port) + } else { + https_any = net.JoinHostPort("::1", port) + } + } + continue + } + + httpsAddr = addr + break + } + + if httpsAddr == "" && https_any != "" { + httpsAddr = https_any + } + } + + return +} + func (c *RuntimeConfig) APIConfig(includeClientCerts bool) (*api.Config, error) { cfg := &api.Config{ Datacenter: c.Datacenter, TLSConfig: api.TLSConfig{InsecureSkipVerify: !c.VerifyOutgoing}, } - unixAddrs, httpAddrs, httpsAddrs := c.apiAddresses(1) + unixAddr, httpAddr, httpsAddr := c.ClientAddress() - if len(httpsAddrs) > 0 { - cfg.Address = httpsAddrs[0] + if httpsAddr != "" { + cfg.Address = httpsAddr cfg.Scheme = "https" cfg.TLSConfig.CAFile = c.CAFile cfg.TLSConfig.CAPath = c.CAPath + cfg.TLSConfig.Address = httpsAddr if includeClientCerts { cfg.TLSConfig.CertFile = c.CertFile cfg.TLSConfig.KeyFile = c.KeyFile } - } else if len(httpAddrs) > 0 { - cfg.Address = httpAddrs[0] + } else if httpAddr != "" { + cfg.Address = httpAddr cfg.Scheme = "http" - } else if len(unixAddrs) > 0 { - cfg.Address = "unix://" + unixAddrs[0] + } else if unixAddr != "" { + cfg.Address = unixAddr // this should be ignored - however we are still talking http over a unix socket // so it makes sense to set it like this cfg.Scheme = "http" diff --git a/agent/config/runtime_test.go b/agent/config/runtime_test.go index 4df0bd5a0..9326e8ff6 100644 --- a/agent/config/runtime_test.go +++ b/agent/config/runtime_test.go @@ -4608,6 +4608,100 @@ func TestRuntime_APIConfigUNIX(t *testing.T) { require.Equal(t, "", cfg.TLSConfig.KeyFile) } +func TestRuntime_APIConfigANYAddrV4(t *testing.T) { + rt := RuntimeConfig{ + HTTPAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 5678}, + }, + Datacenter: "dc-test", + } + + cfg, err := rt.APIConfig(false) + require.NoError(t, err) + require.Equal(t, rt.Datacenter, cfg.Datacenter) + require.Equal(t, "127.0.0.1:5678", cfg.Address) + require.Equal(t, "http", cfg.Scheme) + require.Equal(t, "", cfg.TLSConfig.CAFile) + require.Equal(t, "", cfg.TLSConfig.CAPath) + require.Equal(t, "", cfg.TLSConfig.CertFile) + require.Equal(t, "", cfg.TLSConfig.KeyFile) +} + +func TestRuntime_APIConfigANYAddrV6(t *testing.T) { + rt := RuntimeConfig{ + HTTPAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("::"), Port: 5678}, + }, + Datacenter: "dc-test", + } + + cfg, err := rt.APIConfig(false) + require.NoError(t, err) + require.Equal(t, rt.Datacenter, cfg.Datacenter) + require.Equal(t, "[::1]:5678", cfg.Address) + require.Equal(t, "http", cfg.Scheme) + require.Equal(t, "", cfg.TLSConfig.CAFile) + require.Equal(t, "", cfg.TLSConfig.CAPath) + require.Equal(t, "", cfg.TLSConfig.CertFile) + require.Equal(t, "", cfg.TLSConfig.KeyFile) +} + +func TestRuntime_ClientAddress(t *testing.T) { + rt := RuntimeConfig{ + HTTPAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("::"), Port: 5678}, + &net.TCPAddr{IP: net.ParseIP("198.18.0.1"), Port: 5679}, + &net.UnixAddr{Name: "/var/run/foo", Net: "unix"}, + }, + HTTPSAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("::"), Port: 5688}, + &net.TCPAddr{IP: net.ParseIP("198.18.0.1"), Port: 5689}, + }, + } + + unix, http, https := rt.ClientAddress() + + require.Equal(t, "unix:///var/run/foo", unix) + require.Equal(t, "198.18.0.1:5679", http) + require.Equal(t, "198.18.0.1:5689", https) +} + +func TestRuntime_ClientAddressAnyV4(t *testing.T) { + rt := RuntimeConfig{ + HTTPAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 5678}, + &net.UnixAddr{Name: "/var/run/foo", Net: "unix"}, + }, + HTTPSAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 5688}, + }, + } + + unix, http, https := rt.ClientAddress() + + require.Equal(t, "unix:///var/run/foo", unix) + require.Equal(t, "127.0.0.1:5678", http) + require.Equal(t, "127.0.0.1:5688", https) +} + +func TestRuntime_ClientAddressAnyV6(t *testing.T) { + rt := RuntimeConfig{ + HTTPAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("::"), Port: 5678}, + &net.UnixAddr{Name: "/var/run/foo", Net: "unix"}, + }, + HTTPSAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("::"), Port: 5688}, + }, + } + + unix, http, https := rt.ClientAddress() + + require.Equal(t, "unix:///var/run/foo", unix) + require.Equal(t, "[::1]:5678", http) + require.Equal(t, "[::1]:5688", https) +} + func splitIPPort(hostport string) (net.IP, int) { h, p, err := net.SplitHostPort(hostport) if err != nil {