auto-encrypt: Fix port resolution and fallback to default port (#6205)

Auto-encrypt meant to fallback to the default port when it wasn't provided, but it hadn't been because of an issue with the error handling. We were checking against an incomplete error value:
"missing port in address" vs "address $HOST: missing port in address"

Additionally, all RPCs to AutoEncrypt.Sign were using a.config.ServerPort, so those were updated to use ports resolved by resolveAddrs, if they are available.
This commit is contained in:
Freddy 2019-07-24 16:49:37 -07:00 committed by GitHub
parent 33f51db661
commit 7dbbe7e55a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 110 additions and 11 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt"
"log"
"net"
"strconv"
"strings"
"time"
@ -18,7 +19,7 @@ const (
retryJitterWindow = 30 * time.Second
)
func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) {
func (c *Client) RequestAutoEncryptCerts(servers []string, defaultPort int, token string, interruptCh chan struct{}) (*structs.SignedResponse, string, error) {
errFn := func(err error) (*structs.SignedResponse, string, error) {
return nil, "", err
}
@ -81,11 +82,12 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
// Translate host to net.TCPAddr to make life easier for
// RPCInsecure.
for _, s := range servers {
ips, err := resolveAddr(s, c.logger)
ips, port, err := resolveAddr(s, defaultPort, c.logger)
if err != nil {
c.logger.Printf("[WARN] agent: AutoEncrypt resolveAddr failed: %v", err)
continue
}
for _, ip := range ips {
addr := net.TCPAddr{IP: ip, Port: port}
@ -112,16 +114,29 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
}
}
// resolveAddr is used to resolve the address into an address,
// port, and error. If no port is given, use the default
func resolveAddr(rawHost string, logger *log.Logger) ([]net.IP, error) {
host, _, err := net.SplitHostPort(rawHost)
if err != nil && err.Error() != "missing port in address" {
return nil, err
// resolveAddr is used to resolve the host into IPs, port, and error.
// If no port is given, use the default
func resolveAddr(rawHost string, defaultPort int, logger *log.Logger) ([]net.IP, int, error) {
host, splitPort, err := net.SplitHostPort(rawHost)
if err != nil && err.Error() != fmt.Sprintf("address %s: missing port in address", rawHost) {
return nil, defaultPort, err
}
// SplitHostPort returns empty host and splitPort on missingPort err,
// so those are set to defaults
var port int
if err != nil {
host = rawHost
port = defaultPort
} else {
port, err = strconv.Atoi(splitPort)
if err != nil {
port = defaultPort
}
}
if ip := net.ParseIP(host); ip != nil {
return []net.IP{ip}, nil
return []net.IP{ip}, port, nil
}
// First try TCP so we have the best chance for the largest list of
@ -130,13 +145,17 @@ func resolveAddr(rawHost string, logger *log.Logger) ([]net.IP, error) {
if ips, err := tcpLookupIP(host, logger); err != nil {
logger.Printf("[DEBUG] agent: TCP-first lookup failed for '%s', falling back to UDP: %s", host, err)
} else if len(ips) > 0 {
return ips, nil
return ips, port, nil
}
// If TCP didn't yield anything then use the normal Go resolver which
// will try UDP, then might possibly try TCP again if the UDP response
// indicates it was truncated.
return net.LookupIP(host)
ips, err := net.LookupIP(host)
if err != nil {
return nil, port, err
}
return ips, port, nil
}
// tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host.

View File

@ -0,0 +1,80 @@
package consul
import (
"github.com/stretchr/testify/require"
"log"
"net"
"os"
"testing"
)
func TestAutoEncrypt_resolveAddr(t *testing.T) {
type args struct {
rawHost string
defaultPort int
logger *log.Logger
}
tests := []struct {
name string
args args
ips []net.IP
port int
wantErr bool
}{
{
name: "host without port",
args: args{
"127.0.0.1",
8300,
log.New(os.Stderr, "", log.LstdFlags),
},
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
port: 8300,
wantErr: false,
},
{
name: "host with port",
args: args{
"127.0.0.1:1234",
8300,
log.New(os.Stderr, "", log.LstdFlags),
},
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
port: 1234,
wantErr: false,
},
{
name: "host with broken port",
args: args{
"127.0.0.1:xyz",
8300,
log.New(os.Stderr, "", log.LstdFlags),
},
ips: []net.IP{net.IPv4(127, 0, 0, 1)},
port: 8300,
wantErr: false,
},
{
name: "not an address",
args: args{
"abc",
8300,
log.New(os.Stderr, "", log.LstdFlags),
},
ips: nil,
port: 8300,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ips, port, err := resolveAddr(tt.args.rawHost, tt.args.defaultPort, tt.args.logger)
if (err != nil) != tt.wantErr {
t.Errorf("resolveAddr error: %v, wantErr: %v", err, tt.wantErr)
return
}
require.Equal(t, tt.ips, ips)
require.Equal(t, tt.port, port)
})
}
}