command/envoy: Refactor flag parsing/validation (#7504)
This commit is contained in:
parent
02cacf8128
commit
2569b2c6dd
|
@ -11,6 +11,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/hashicorp/consul/agent/structs"
|
||||
|
@ -19,9 +20,6 @@ import (
|
|||
proxyCmd "github.com/hashicorp/consul/command/connect/proxy"
|
||||
"github.com/hashicorp/consul/command/flags"
|
||||
"github.com/hashicorp/consul/ipaddr"
|
||||
"github.com/hashicorp/go-sockaddr/template"
|
||||
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
func New(ui cli.Ui) *cmd {
|
||||
|
@ -60,10 +58,10 @@ type cmd struct {
|
|||
|
||||
// mesh gateway registration information
|
||||
register bool
|
||||
address string
|
||||
wanAddress string
|
||||
lanAddress ServiceAddressValue
|
||||
wanAddress ServiceAddressValue
|
||||
deregAfterCritical string
|
||||
bindAddresses map[string]string
|
||||
bindAddresses ServiceAddressMapValue
|
||||
exposeServers bool
|
||||
|
||||
meshGatewaySvcName string
|
||||
|
@ -120,13 +118,13 @@ func (c *cmd) init() {
|
|||
c.flags.BoolVar(&c.register, "register", false,
|
||||
"Register a new Mesh Gateway service before configuring and starting Envoy")
|
||||
|
||||
c.flags.StringVar(&c.address, "address", "",
|
||||
c.flags.Var(&c.lanAddress, "address",
|
||||
"LAN address to advertise in the Mesh Gateway service registration")
|
||||
|
||||
c.flags.StringVar(&c.wanAddress, "wan-address", "",
|
||||
c.flags.Var(&c.wanAddress, "wan-address",
|
||||
"WAN address to advertise in the Mesh Gateway service registration")
|
||||
|
||||
c.flags.Var((*flags.FlagMapValue)(&c.bindAddresses), "bind-address", "Bind "+
|
||||
c.flags.Var(&c.bindAddresses, "bind-address", "Bind "+
|
||||
"address to use instead of the default binding rules given as `<name>=<ip>:<port>` "+
|
||||
"pairs. This flag may be specified multiple times to add multiple bind addresses.")
|
||||
|
||||
|
@ -145,38 +143,6 @@ func (c *cmd) init() {
|
|||
c.help = flags.Usage(help, c.flags)
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultMeshGatewayPort int = 443
|
||||
)
|
||||
|
||||
func parseAddress(addrStr string) (string, int, error) {
|
||||
if addrStr == "" {
|
||||
// defaulting the port to 443
|
||||
return "", DefaultMeshGatewayPort, nil
|
||||
}
|
||||
|
||||
x, err := template.Parse(addrStr)
|
||||
if err != nil {
|
||||
return "", DefaultMeshGatewayPort, fmt.Errorf("Error parsing address %q: %v", addrStr, err)
|
||||
}
|
||||
|
||||
addr, portStr, err := net.SplitHostPort(x)
|
||||
if err != nil {
|
||||
return "", DefaultMeshGatewayPort, fmt.Errorf("Error parsing address %q: %v", x, err)
|
||||
}
|
||||
|
||||
port := DefaultMeshGatewayPort
|
||||
|
||||
if portStr != "" {
|
||||
port, err = strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return "", DefaultMeshGatewayPort, fmt.Errorf("Error parsing port %q: %v", portStr, err)
|
||||
}
|
||||
}
|
||||
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
// canBindInternal is here mainly so we can unit test this with a constant net.Addr list
|
||||
func canBindInternal(addr string, ifAddrs []net.Addr) bool {
|
||||
if addr == "" {
|
||||
|
@ -206,13 +172,13 @@ func canBindInternal(addr string, ifAddrs []net.Addr) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func canBind(addr string) bool {
|
||||
func canBind(addr api.ServiceAddress) bool {
|
||||
ifAddrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return canBindInternal(addr, ifAddrs)
|
||||
return canBindInternal(addr.Address, ifAddrs)
|
||||
}
|
||||
|
||||
func (c *cmd) Run(args []string) int {
|
||||
|
@ -246,30 +212,18 @@ func (c *cmd) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
lanAddr, lanPort, err := parseAddress(c.address)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Failed to parse the -address parameter: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
taggedAddrs := make(map[string]api.ServiceAddress)
|
||||
|
||||
if lanAddr != "" {
|
||||
taggedAddrs[structs.TaggedAddressLAN] = api.ServiceAddress{Address: lanAddr, Port: lanPort}
|
||||
lanAddr := c.lanAddress.Value()
|
||||
if lanAddr.Address != "" {
|
||||
taggedAddrs[structs.TaggedAddressLAN] = lanAddr
|
||||
}
|
||||
|
||||
wanAddr := ""
|
||||
wanPort := lanPort
|
||||
if c.wanAddress != "" {
|
||||
wanAddr, wanPort, err = parseAddress(c.wanAddress)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Failed to parse the -wan-address parameter: %v", err))
|
||||
return 1
|
||||
}
|
||||
taggedAddrs[structs.TaggedAddressWAN] = api.ServiceAddress{Address: wanAddr, Port: wanPort}
|
||||
wanAddr := c.wanAddress.Value()
|
||||
if wanAddr.Address != "" {
|
||||
taggedAddrs[structs.TaggedAddressWAN] = wanAddr
|
||||
}
|
||||
|
||||
tcpCheckAddr := lanAddr
|
||||
tcpCheckAddr := lanAddr.Address
|
||||
if tcpCheckAddr == "" {
|
||||
// fallback to localhost as the gateway has to reside in the same network namespace
|
||||
// as the agent
|
||||
|
@ -278,24 +232,12 @@ func (c *cmd) Run(args []string) int {
|
|||
|
||||
var proxyConf *api.AgentServiceConnectProxyConfig
|
||||
|
||||
if len(c.bindAddresses) > 0 {
|
||||
if len(c.bindAddresses.value) > 0 {
|
||||
// override all default binding rules and just bind to the user-supplied addresses
|
||||
bindAddresses := make(map[string]api.ServiceAddress)
|
||||
|
||||
for addrName, addrStr := range c.bindAddresses {
|
||||
addr, port, err := parseAddress(addrStr)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Failed to parse the bind address: %s=%s: %v", addrName, addrStr, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
bindAddresses[addrName] = api.ServiceAddress{Address: addr, Port: port}
|
||||
}
|
||||
|
||||
proxyConf = &api.AgentServiceConnectProxyConfig{
|
||||
Config: map[string]interface{}{
|
||||
"envoy_mesh_gateway_no_default_bind": true,
|
||||
"envoy_mesh_gateway_bind_addresses": bindAddresses,
|
||||
"envoy_mesh_gateway_bind_addresses": c.bindAddresses.value,
|
||||
},
|
||||
}
|
||||
} else if canBind(lanAddr) && canBind(wanAddr) {
|
||||
|
@ -307,8 +249,8 @@ func (c *cmd) Run(args []string) int {
|
|||
"envoy_mesh_gateway_bind_tagged_addresses": true,
|
||||
},
|
||||
}
|
||||
} else if !canBind(lanAddr) && lanAddr != "" {
|
||||
c.UI.Error(fmt.Sprintf("The LAN address %q will not be bindable. Either set a bindable address or override the bind addresses with -bind-address", lanAddr))
|
||||
} else if !canBind(lanAddr) && lanAddr.Address != "" {
|
||||
c.UI.Error(fmt.Sprintf("The LAN address %q will not be bindable. Either set a bindable address or override the bind addresses with -bind-address", lanAddr.Address))
|
||||
return 1
|
||||
}
|
||||
|
||||
|
@ -320,14 +262,14 @@ func (c *cmd) Run(args []string) int {
|
|||
svc := api.AgentServiceRegistration{
|
||||
Kind: api.ServiceKindMeshGateway,
|
||||
Name: c.meshGatewaySvcName,
|
||||
Address: lanAddr,
|
||||
Port: lanPort,
|
||||
Address: lanAddr.Address,
|
||||
Port: lanAddr.Port,
|
||||
Meta: meta,
|
||||
TaggedAddresses: taggedAddrs,
|
||||
Proxy: proxyConf,
|
||||
Check: &api.AgentServiceCheck{
|
||||
Name: "Mesh Gateway Listening",
|
||||
TCP: ipaddr.FormatAddressPort(tcpCheckAddr, lanPort),
|
||||
TCP: ipaddr.FormatAddressPort(tcpCheckAddr, lanAddr.Port),
|
||||
Interval: "10s",
|
||||
DeregisterCriticalServiceAfter: c.deregAfterCritical,
|
||||
},
|
||||
|
|
95
command/connect/envoy/flags.go
Normal file
95
command/connect/envoy/flags.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package envoy
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/go-sockaddr/template"
|
||||
)
|
||||
|
||||
const defaultMeshGatewayPort int = 443
|
||||
|
||||
// ServiceAddressValue implements a flag.Value that may be used to parse an
|
||||
// addr:port string into an api.ServiceAddress.
|
||||
type ServiceAddressValue struct {
|
||||
value api.ServiceAddress
|
||||
}
|
||||
|
||||
func (s *ServiceAddressValue) String() string {
|
||||
if s == nil {
|
||||
return fmt.Sprintf(":%d", defaultMeshGatewayPort)
|
||||
}
|
||||
return fmt.Sprintf("%v:%d", s.value.Address, s.value.Port)
|
||||
}
|
||||
|
||||
func (s *ServiceAddressValue) Value() api.ServiceAddress {
|
||||
if s == nil || s.value.Port == 0 && s.value.Address == "" {
|
||||
return api.ServiceAddress{Port: defaultMeshGatewayPort}
|
||||
}
|
||||
return s.value
|
||||
}
|
||||
|
||||
func (s *ServiceAddressValue) Set(raw string) error {
|
||||
var err error
|
||||
s.value, err = parseAddress(raw)
|
||||
return err
|
||||
}
|
||||
|
||||
func parseAddress(raw string) (api.ServiceAddress, error) {
|
||||
result := api.ServiceAddress{}
|
||||
x, err := template.Parse(raw)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("Error parsing address %q: %v", raw, err)
|
||||
}
|
||||
|
||||
addr, portStr, err := net.SplitHostPort(x)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("Error parsing address %q: %v", x, err)
|
||||
}
|
||||
|
||||
port := defaultMeshGatewayPort
|
||||
if portStr != "" {
|
||||
port, err = strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("Error parsing port %q: %v", portStr, err)
|
||||
}
|
||||
}
|
||||
|
||||
result.Address = addr
|
||||
result.Port = port
|
||||
return result, nil
|
||||
}
|
||||
|
||||
var _ flag.Value = (*ServiceAddressValue)(nil)
|
||||
|
||||
type ServiceAddressMapValue struct {
|
||||
value map[string]api.ServiceAddress
|
||||
}
|
||||
|
||||
func (s *ServiceAddressMapValue) String() string {
|
||||
buf := new(strings.Builder)
|
||||
for k, v := range s.value {
|
||||
buf.WriteString(fmt.Sprintf("%v=%v:%d,", k, v.Address, v.Port))
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (s *ServiceAddressMapValue) Set(raw string) error {
|
||||
if s.value == nil {
|
||||
s.value = make(map[string]api.ServiceAddress)
|
||||
}
|
||||
idx := strings.Index(raw, "=")
|
||||
if idx == -1 {
|
||||
return fmt.Errorf(`Missing "=" in argument: %s`, raw)
|
||||
}
|
||||
key, value := raw[0:idx], raw[idx+1:]
|
||||
var err error
|
||||
s.value[key], err = parseAddress(value)
|
||||
return err
|
||||
}
|
||||
|
||||
var _ flag.Value = (*ServiceAddressMapValue)(nil)
|
80
command/connect/envoy/flags_test.go
Normal file
80
command/connect/envoy/flags_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package envoy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServiceAddressValue_Value(t *testing.T) {
|
||||
t.Run("nil receiver", func(t *testing.T) {
|
||||
var addr *ServiceAddressValue
|
||||
require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort})
|
||||
})
|
||||
|
||||
t.Run("default value", func(t *testing.T) {
|
||||
addr := &ServiceAddressValue{}
|
||||
require.Equal(t, addr.Value(), api.ServiceAddress{Port: defaultMeshGatewayPort})
|
||||
})
|
||||
|
||||
t.Run("set value", func(t *testing.T) {
|
||||
addr := &ServiceAddressValue{}
|
||||
require.NoError(t, addr.Set("localhost:3333"))
|
||||
require.Equal(t, addr.Value(), api.ServiceAddress{
|
||||
Address: "localhost",
|
||||
Port: 3333,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceAddressValue_Set(t *testing.T) {
|
||||
var testcases = []struct {
|
||||
name string
|
||||
input string
|
||||
expectedErr string
|
||||
expectedValue api.ServiceAddress
|
||||
}{
|
||||
{
|
||||
name: "default port",
|
||||
input: "8.8.8.8:",
|
||||
expectedValue: api.ServiceAddress{
|
||||
Address: "8.8.8.8",
|
||||
Port: defaultMeshGatewayPort,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid address",
|
||||
input: "8.8.8.8:1234",
|
||||
expectedValue: api.ServiceAddress{Address: "8.8.8.8", Port: 1234},
|
||||
},
|
||||
{
|
||||
name: "invalid address",
|
||||
input: "not-an-address",
|
||||
expectedErr: "missing port in address",
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
input: "localhost:notaport",
|
||||
expectedErr: `Error parsing port "notaport"`,
|
||||
},
|
||||
{
|
||||
name: "invalid address format",
|
||||
input: "too:many:colons",
|
||||
expectedErr: "address too:many:colons: too many colons",
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr := &ServiceAddressValue{}
|
||||
err := addr.Set(tc.input)
|
||||
if tc.expectedErr != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tc.expectedErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, addr.Value(), tc.expectedValue)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue