only detect advertise address if derived value is any (#3506)

* only detect advertise address if derived value is any

* determine detect function only when advertise addr is any
This commit is contained in:
Frank Schröder 2017-09-27 21:59:47 +02:00 committed by James Phillips
parent d677999258
commit beb803f0d9
2 changed files with 52 additions and 21 deletions

View File

@ -323,27 +323,27 @@ func (b *Builder) Build() (rt RuntimeConfig, err error) {
} }
bindAddr := bindAddrs[0].(*net.IPAddr) bindAddr := bindAddrs[0].(*net.IPAddr)
advertiseAddr := b.makeIPAddr(b.expandFirstIP("advertise_addr", c.AdvertiseAddrLAN), bindAddr)
if ipaddr.IsAny(advertiseAddr) {
var addrtyp string var addrtyp string
var detect func() ([]*net.IPAddr, error) var detect func() ([]*net.IPAddr, error)
switch { switch {
case ipaddr.IsAnyV4(c.BindAddr) || ipaddr.IsAnyV4(bindAddr): case ipaddr.IsAnyV4(advertiseAddr):
addrtyp = "private IPv4" addrtyp = "private IPv4"
detect = b.GetPrivateIPv4 detect = b.GetPrivateIPv4
if detect == nil { if detect == nil {
detect = ipaddr.GetPrivateIPv4 detect = ipaddr.GetPrivateIPv4
}
case ipaddr.IsAnyV6(advertiseAddr):
addrtyp = "public IPv6"
detect = b.GetPublicIPv6
if detect == nil {
detect = ipaddr.GetPublicIPv6
}
} }
case ipaddr.IsAnyV6(c.BindAddr) || ipaddr.IsAnyV6(bindAddr):
addrtyp = "public IPv6"
detect = b.GetPublicIPv6
if detect == nil {
detect = ipaddr.GetPublicIPv6
}
}
advertiseAddr := bindAddr
if detect != nil {
advertiseAddrs, err := detect() advertiseAddrs, err := detect()
if err != nil { if err != nil {
return RuntimeConfig{}, fmt.Errorf("Error detecting %s address: %s", addrtyp, err) return RuntimeConfig{}, fmt.Errorf("Error detecting %s address: %s", addrtyp, err)

View File

@ -653,6 +653,31 @@ func TestConfigFlagsAndEdgecases(t *testing.T) {
return []*net.IPAddr{ipAddr("dead:beef::1")}, nil return []*net.IPAddr{ipAddr("dead:beef::1")}, nil
}, },
}, },
{
desc: "bind addr any and advertise set should not detect",
flags: []string{`-data-dir=` + dataDir},
json: []string{`{ "bind_addr":"0.0.0.0", "advertise_addr": "1.2.3.4" }`},
hcl: []string{`bind_addr = "0.0.0.0" advertise_addr = "1.2.3.4"`},
patch: func(rt *RuntimeConfig) {
rt.AdvertiseAddrLAN = ipAddr("1.2.3.4")
rt.AdvertiseAddrWAN = ipAddr("1.2.3.4")
rt.BindAddr = ipAddr("0.0.0.0")
rt.RPCAdvertiseAddr = tcpAddr("1.2.3.4:8300")
rt.RPCBindAddr = tcpAddr("0.0.0.0:8300")
rt.SerfAdvertiseAddrLAN = tcpAddr("1.2.3.4:8301")
rt.SerfAdvertiseAddrWAN = tcpAddr("1.2.3.4:8302")
rt.SerfBindAddrLAN = tcpAddr("0.0.0.0:8301")
rt.SerfBindAddrWAN = tcpAddr("0.0.0.0:8302")
rt.TaggedAddresses = map[string]string{
"lan": "1.2.3.4",
"wan": "1.2.3.4",
}
rt.DataDir = dataDir
},
privatev4: func() ([]*net.IPAddr, error) {
return nil, fmt.Errorf("should not detect advertise_addr")
},
},
{ {
desc: "client addr and ports == 0", desc: "client addr and ports == 0",
flags: []string{`-data-dir=` + dataDir}, flags: []string{`-data-dir=` + dataDir},
@ -1764,8 +1789,14 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
return []*net.IPAddr{ipAddr("10.0.0.1")}, nil return []*net.IPAddr{ipAddr("10.0.0.1")}, nil
} }
} }
publicv6 := tt.publicv6
if publicv6 == nil {
publicv6 = func() ([]*net.IPAddr, error) {
return []*net.IPAddr{ipAddr("dead:beef::1")}, nil
}
}
b.GetPrivateIPv4 = privatev4 b.GetPrivateIPv4 = privatev4
b.GetPublicIPv6 = tt.publicv6 b.GetPublicIPv6 = publicv6
// read the source fragements // read the source fragements
for i, data := range srcs { for i, data := range srcs {
@ -1824,8 +1855,8 @@ func testConfig(t *testing.T, tests []configTest, dataDir string) {
t.Fatal(err) t.Fatal(err)
} }
x.Hostname = b.Hostname x.Hostname = b.Hostname
x.GetPrivateIPv4 = b.GetPrivateIPv4 x.GetPrivateIPv4 = func() ([]*net.IPAddr, error) { return []*net.IPAddr{ipAddr("10.0.0.1")}, nil }
x.GetPublicIPv6 = b.GetPublicIPv6 x.GetPublicIPv6 = func() ([]*net.IPAddr, error) { return []*net.IPAddr{ipAddr("dead:beef::1")}, nil }
patchedRT, err := x.Build() patchedRT, err := x.Build()
if err != nil { if err != nil {
t.Fatalf("build default failed: %s", err) t.Fatalf("build default failed: %s", err)