Various refactoring to clean up code organization

Brought to you by: Dept of 2nd thoughts before pushing enter on `git push`
This commit is contained in:
Sean Chittenden 2016-04-23 19:53:21 -07:00
parent 53f9cea87c
commit 60006f550f
5 changed files with 169 additions and 137 deletions

View File

@ -203,9 +203,6 @@ func (c *ServerCommand) Run(args []string) int {
if envAA := os.Getenv("VAULT_ADVERTISE_ADDR"); envAA != "" { if envAA := os.Getenv("VAULT_ADVERTISE_ADDR"); envAA != "" {
coreConfig.AdvertiseAddr = envAA coreConfig.AdvertiseAddr = envAA
if consulBackend, ok := (backend).(*physical.ConsulBackend); ok {
consulBackend.UpdateAdvertiseAddr(envAA)
}
} }
// Attempt to detect the advertise address, if possible // Attempt to detect the advertise address, if possible
@ -223,9 +220,6 @@ func (c *ServerCommand) Run(args []string) int {
c.Ui.Error("Failed to detect advertise address.") c.Ui.Error("Failed to detect advertise address.")
} else { } else {
coreConfig.AdvertiseAddr = advertise coreConfig.AdvertiseAddr = advertise
if consulBackend, ok := (backend).(*physical.ConsulBackend); ok {
consulBackend.UpdateAdvertiseAddr(advertise)
}
} }
} }
@ -296,6 +290,11 @@ func (c *ServerCommand) Run(args []string) int {
if coreConfig.HAPhysical != nil { if coreConfig.HAPhysical != nil {
sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery) sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery)
if ok { if ok {
if err := sd.UpdateAdvertiseAddr(coreConfig.AdvertiseAddr); err != nil {
c.Ui.Error(fmt.Sprintf("Error configuring service discovery: %v", err))
return 1
}
if err := sd.RunServiceDiscovery(c.ShutdownCh); err != nil { if err := sd.RunServiceDiscovery(c.ShutdownCh); err != nil {
c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err))
return 1 return 1

View File

@ -44,21 +44,23 @@ const (
// prefix within Consul. It is used for most production situations as // prefix within Consul. It is used for most production situations as
// it allows Vault to run on multiple machines in a highly-available manner. // it allows Vault to run on multiple machines in a highly-available manner.
type ConsulBackend struct { type ConsulBackend struct {
path string path string
client *api.Client client *api.Client
kv *api.KV kv *api.KV
permitPool *PermitPool permitPool *PermitPool
serviceLock sync.RWMutex serviceLock sync.RWMutex
service *api.AgentServiceRegistration service *api.AgentServiceRegistration
sealedCheck *api.AgentCheckRegistration sealedCheck *api.AgentCheckRegistration
advertiseAddr string advertiseHost string
consulClientConf *api.Config advertisePort int
serviceName string consulClientConf *api.Config
running bool serviceName string
active bool running bool
sealed bool active bool
checkTimeout time.Duration sealed bool
checkTimer *time.Timer disableRegistration bool
checkTimeout time.Duration
checkTimer *time.Timer
} }
// newConsulBackend constructs a Consul backend using the given API client // newConsulBackend constructs a Consul backend using the given API client
@ -78,6 +80,17 @@ func newConsulBackend(conf map[string]string) (Backend, error) {
path = strings.TrimPrefix(path, "/") path = strings.TrimPrefix(path, "/")
} }
// Allow admins to disable consul integration
disableReg, ok := conf["disable_registration"]
var disableRegistration bool
if ok && disableReg != "" {
b, err := strconv.ParseBool(disableReg)
if err != nil {
return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err)
}
disableRegistration = b
}
// Get the service name to advertise in Consul // Get the service name to advertise in Consul
service, ok := conf["service"] service, ok := conf["service"]
if !ok { if !ok {
@ -141,14 +154,15 @@ func newConsulBackend(conf map[string]string) (Backend, error) {
// Setup the backend // Setup the backend
c := &ConsulBackend{ c := &ConsulBackend{
path: path, path: path,
client: client, client: client,
kv: client.KV(), kv: client.KV(),
permitPool: NewPermitPool(maxParInt), permitPool: NewPermitPool(maxParInt),
consulClientConf: consulConf, consulClientConf: consulConf,
serviceName: service, serviceName: service,
checkTimeout: checkTimeout, checkTimeout: checkTimeout,
checkTimer: time.NewTimer(checkTimeout), checkTimer: time.NewTimer(checkTimeout),
disableRegistration: disableRegistration,
} }
return c, nil return c, nil
} }
@ -160,21 +174,13 @@ func (c *ConsulBackend) UpdateAdvertiseAddr(addr string) error {
return fmt.Errorf("service registration unable to update advertise address, backend already running") return fmt.Errorf("service registration unable to update advertise address, backend already running")
} }
url, err := url.Parse(addr) host, port, err := parseAdvertiseAddr(addr)
if err != nil { if err != nil {
return errwrap.Wrapf(fmt.Sprintf(`updating advertise address failed to parse URL "%v": {{err}}`, addr), err) return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise address "%v": {{err}}`, addr), err)
} }
_, portStr, err := net.SplitHostPort(url.Host) c.advertiseHost = host
if err != nil { c.advertisePort = int(port)
return errwrap.Wrapf(fmt.Sprintf(`updating advertise address failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err)
}
_, err = strconv.ParseInt(portStr, 10, 0)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf(`updating advertise address failed to parse port "%v": {{err}}`, portStr), err)
}
c.advertiseAddr = addr
return nil return nil
} }
@ -197,10 +203,12 @@ func (c *ConsulBackend) AdvertiseActive(active bool) error {
return nil return nil
} }
c.service.Tags = serviceTags(active) if !c.disableRegistration {
agent := c.client.Agent() c.service.Tags = serviceTags(active)
if err := agent.ServiceRegister(c.service); err != nil { agent := c.client.Agent()
return errwrap.Wrapf("service registration failed: {{err}}", err) if err := agent.ServiceRegister(c.service); err != nil {
return errwrap.Wrapf("service registration failed: {{err}}", err)
}
} }
// Save a cached copy of the active state: no way to query Core // Save a cached copy of the active state: no way to query Core
@ -219,8 +227,10 @@ func (c *ConsulBackend) AdvertiseSealed(sealed bool) error {
return nil return nil
} }
// Push a TTL check immediately to update the state if !c.disableRegistration {
c.runCheck() // Push a TTL check immediately to update the state
c.runCheck()
}
return nil return nil
} }
@ -229,35 +239,22 @@ func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel) (err err
c.serviceLock.Lock() c.serviceLock.Lock()
defer c.serviceLock.Unlock() defer c.serviceLock.Unlock()
if c.disableRegistration {
return nil
}
if c.running { if c.running {
return fmt.Errorf("service registration routine already running") return fmt.Errorf("service registration routine already running")
} }
url, err := url.Parse(c.advertiseAddr) serviceID := c.serviceID()
if err != nil {
return errwrap.Wrapf(fmt.Sprintf(`service registration failed to parse URL "%v": {{err}}`, c.advertiseAddr), err)
}
host, portStr, err := net.SplitHostPort(url.Host)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf(`service registration failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err)
}
port, err := strconv.ParseInt(portStr, 10, 0)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf(`service registration failed to parse port "%v": {{err}}`, portStr), err)
}
serviceID, err := c.serviceID()
if err != nil {
return err
}
c.service = &api.AgentServiceRegistration{ c.service = &api.AgentServiceRegistration{
ID: serviceID, ID: serviceID,
Name: c.serviceName, Name: c.serviceName,
Tags: serviceTags(c.active), Tags: serviceTags(c.active),
Port: int(port), Port: c.advertisePort,
Address: host, Address: c.advertiseHost,
EnableTagOverride: false, EnableTagOverride: false,
} }
@ -351,22 +348,31 @@ func (c *ConsulBackend) checkID() string {
// serviceID returns the Vault ServiceID for use in Consul. Assume at least // serviceID returns the Vault ServiceID for use in Consul. Assume at least
// a read lock is held. // a read lock is held.
func (c *ConsulBackend) serviceID() (string, error) { func (c *ConsulBackend) serviceID() string {
url, err := url.Parse(c.advertiseAddr) return fmt.Sprintf("%s:%s:%d", c.serviceName, c.advertiseHost, c.advertisePort)
if err != nil { }
return "", errwrap.Wrapf(fmt.Sprintf(`service registration failed to parse URL "%v": {{err}}`, c.advertiseAddr), err)
func parseAdvertiseAddr(addr string) (host string, port int, err error) {
if addr == "" {
return "", -1, fmt.Errorf("advertise address must not be empty")
} }
host, portStr, err := net.SplitHostPort(url.Host) url, err := url.Parse(addr)
if err != nil { if err != nil {
return "", errwrap.Wrapf(fmt.Sprintf(`service registration failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err) return "", -2, errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err)
}
port, err := strconv.ParseInt(portStr, 10, 0)
if err != nil {
return "", errwrap.Wrapf(fmt.Sprintf(`service registration failed to parse port "%v": {{err}}`, portStr), err)
} }
return fmt.Sprintf("%s:%s:%d", c.serviceName, host, int(port)), nil var portStr string
host, portStr, err = net.SplitHostPort(url.Host)
if err != nil {
return "", -3, errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err)
}
portNum, err := strconv.ParseInt(portStr, 10, 0)
if err != nil || portNum < 1 || portNum > 65535 {
return "", -4, errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err)
}
return host, int(portNum), nil
} }
func setupTLSConfig(conf map[string]string) (*tls.Config, error) { func setupTLSConfig(conf map[string]string) (*tls.Config, error) {

View File

@ -86,67 +86,81 @@ func testConsul_testConsulBackend(t *testing.T) {
func TestConsul_newConsulBackend(t *testing.T) { func TestConsul_newConsulBackend(t *testing.T) {
tests := []struct { tests := []struct {
Name string name string
Config map[string]string consulConfig map[string]string
Fail bool fail bool
checkTimeout time.Duration advertiseAddr string
path string checkTimeout time.Duration
service string path string
address string service string
scheme string address string
token string scheme string
max_parallel int token string
max_parallel int
disableReg bool
}{ }{
{ {
Name: "Valid default config", name: "Valid default config",
Config: map[string]string{}, consulConfig: map[string]string{},
checkTimeout: 5 * time.Second, checkTimeout: 5 * time.Second,
path: "vault", advertiseAddr: "http://127.0.0.1:8200",
service: "vault", path: "vault/",
address: "127.0.0.1", service: "vault",
scheme: "http", address: "127.0.0.1:8500",
token: "", scheme: "http",
max_parallel: 4, token: "",
max_parallel: 4,
disableReg: false,
}, },
{ {
Name: "Valid modified config", name: "Valid modified config",
Config: map[string]string{ consulConfig: map[string]string{
"path": "seaTech/", "path": "seaTech/",
"service": "astronomy", "service": "astronomy",
"check_timeout": "6s", "advertiseAddr": "http://127.0.0.2:8200",
"address": "127.0.0.2", "check_timeout": "6s",
"scheme": "https", "address": "127.0.0.2",
"token": "deadbeef-cafeefac-deadc0de-feedface", "scheme": "https",
"max_parallel": "4", "token": "deadbeef-cafeefac-deadc0de-feedface",
"max_parallel": "4",
"disable_registration": "false",
}, },
checkTimeout: 6 * time.Second, checkTimeout: 6 * time.Second,
path: "seaTech/", path: "seaTech/",
service: "astronomy", service: "astronomy",
address: "127.0.0.2", advertiseAddr: "http://127.0.0.2:8200",
scheme: "https", address: "127.0.0.2",
token: "deadbeef-cafeefac-deadc0de-feedface", scheme: "https",
max_parallel: 4, token: "deadbeef-cafeefac-deadc0de-feedface",
max_parallel: 4,
}, },
{ {
Name: "check timeout too short", name: "check timeout too short",
Fail: true, fail: true,
Config: map[string]string{ consulConfig: map[string]string{
"check_timeout": "99ms", "check_timeout": "99ms",
}, },
}, },
} }
for _, test := range tests { for _, test := range tests {
be, err := newConsulBackend(test.Config) be, err := newConsulBackend(test.consulConfig)
if test.Fail && err == nil { if test.fail {
t.Fatalf("Expected config %s to fail", test.Name) if err == nil {
} else if !test.Fail && err != nil { t.Fatalf(`Expected config "%s" to fail`, test.name)
t.Fatalf("Expected config %s to not fail: %v", test.Name, err) } else {
continue
}
} else if !test.fail && err != nil {
t.Fatalf("Expected config %s to not fail: %v", test.name, err)
} }
c, ok := be.(*ConsulBackend) c, ok := be.(*ConsulBackend)
if !ok { if !ok {
t.Fatalf("Expected ConsulBackend") t.Fatalf("Expected ConsulBackend: %s", test.name)
}
if err := c.UpdateAdvertiseAddr(test.advertiseAddr); err != nil {
t.Fatalf("bad: %v", err)
} }
if test.checkTimeout != c.checkTimeout { if test.checkTimeout != c.checkTimeout {
@ -154,7 +168,7 @@ func TestConsul_newConsulBackend(t *testing.T) {
} }
if test.path != c.path { if test.path != c.path {
t.Errorf("bad: %v != %v", test.path, c.path) t.Errorf("bad: %s %v != %v", test.name, test.path, c.path)
} }
if test.service != c.serviceName { if test.service != c.serviceName {
@ -162,7 +176,7 @@ func TestConsul_newConsulBackend(t *testing.T) {
} }
if test.address != c.consulClientConf.Address { if test.address != c.consulClientConf.Address {
t.Errorf("bad: %v != %v", test.address, c.consulClientConf.Address) t.Errorf("bad: %s %v != %v", test.name, test.address, c.consulClientConf.Address)
} }
if test.scheme != c.consulClientConf.Scheme { if test.scheme != c.consulClientConf.Scheme {
@ -206,14 +220,20 @@ func TestConsul_serviceTags(t *testing.T) {
func TestConsul_UpdateAdvertiseAddr(t *testing.T) { func TestConsul_UpdateAdvertiseAddr(t *testing.T) {
tests := []struct { tests := []struct {
addr string addr string
host string
port int
pass bool pass bool
}{ }{
{ {
addr: "http://127.0.0.1:8200/", addr: "http://127.0.0.1:8200/",
host: "127.0.0.1",
port: 8200,
pass: true, pass: true,
}, },
{ {
addr: "http://127.0.0.1:8200", addr: "http://127.0.0.1:8200",
host: "127.0.0.1",
port: 8200,
pass: true, pass: true,
}, },
{ {
@ -244,8 +264,12 @@ func TestConsul_UpdateAdvertiseAddr(t *testing.T) {
} }
} }
if c.advertiseAddr != test.addr { if c.advertiseHost != test.host {
t.Fatalf("bad: %v != %v", c.advertiseAddr, test.addr) t.Fatalf("bad: %v != %v", c.advertiseHost, test.host)
}
if c.advertisePort != test.port {
t.Fatalf("bad: %v != %v", c.advertisePort, test.port)
} }
} }
} }
@ -330,21 +354,25 @@ func TestConsul_checkID(t *testing.T) {
func TestConsul_serviceID(t *testing.T) { func TestConsul_serviceID(t *testing.T) {
passingTests := []struct { passingTests := []struct {
name string
advertiseAddr string advertiseAddr string
serviceName string serviceName string
expected string expected string
}{ }{
{ {
name: "valid host w/o slash",
advertiseAddr: "http://127.0.0.1:8200", advertiseAddr: "http://127.0.0.1:8200",
serviceName: "sea-tech-astronomy", serviceName: "sea-tech-astronomy",
expected: "sea-tech-astronomy:127.0.0.1:8200", expected: "sea-tech-astronomy:127.0.0.1:8200",
}, },
{ {
name: "valid host w/ slash",
advertiseAddr: "http://127.0.0.1:8200/", advertiseAddr: "http://127.0.0.1:8200/",
serviceName: "sea-tech-astronomy", serviceName: "sea-tech-astronomy",
expected: "sea-tech-astronomy:127.0.0.1:8200", expected: "sea-tech-astronomy:127.0.0.1:8200",
}, },
{ {
name: "valid https host w/ slash",
advertiseAddr: "https://127.0.0.1:8200/", advertiseAddr: "https://127.0.0.1:8200/",
serviceName: "sea-tech-astronomy", serviceName: "sea-tech-astronomy",
expected: "sea-tech-astronomy:127.0.0.1:8200", expected: "sea-tech-astronomy:127.0.0.1:8200",
@ -357,14 +385,10 @@ func TestConsul_serviceID(t *testing.T) {
}) })
if err := c.UpdateAdvertiseAddr(test.advertiseAddr); err != nil { if err := c.UpdateAdvertiseAddr(test.advertiseAddr); err != nil {
t.Fatalf("bad: %v", err) t.Fatalf("bad: %s %v", test.name, err)
}
serviceID, err := c.serviceID()
if err != nil {
t.Fatalf("bad: %v", err)
} }
serviceID := c.serviceID()
if serviceID != test.expected { if serviceID != test.expected {
t.Fatalf("bad: %v != %v", serviceID, test.expected) t.Fatalf("bad: %v != %v", serviceID, test.expected)
} }

View File

@ -43,11 +43,6 @@ type HABackend interface {
type AdvertiseDetect interface { type AdvertiseDetect interface {
// DetectHostAddr is used to detect the host address // DetectHostAddr is used to detect the host address
DetectHostAddr() (string, error) DetectHostAddr() (string, error)
// UpdateAdvertiseAddr allows for a non-Running backend to update the
// advertise address. HABackends may want to present a different
// address that wasn't available when a Backend was created.
UpdateAdvertiseAddr(addr string) error
} }
// ServiceDiscovery is an optional interface that an HABackend can implement. // ServiceDiscovery is an optional interface that an HABackend can implement.
@ -65,6 +60,11 @@ type ServiceDiscovery interface {
// Run executes any background service discovery tasks until the // Run executes any background service discovery tasks until the
// shutdown channel is closed. // shutdown channel is closed.
RunServiceDiscovery(ShutdownChannel) error RunServiceDiscovery(ShutdownChannel) error
// UpdateAdvertiseAddr allows for a non-Running backend to update the
// advertise address. HABackends may want to present a different
// address that wasn't available when a Backend was created.
UpdateAdvertiseAddr(addr string) error
} }
type Lock interface { type Lock interface {

View File

@ -200,6 +200,9 @@ For Consul, the following options are supported:
* `scheme` (optional) - "http" or "https" for talking to Consul. * `scheme` (optional) - "http" or "https" for talking to Consul.
* `disable_registration` (optional) - If true, then Vault will not register
itself with Vault. Defaults to "false".
* `token` (optional) - An access token to use to write data to Consul. * `token` (optional) - An access token to use to write data to Consul.
* `max_parallel` (optional) - The maximum number of connections to Consul; * `max_parallel` (optional) - The maximum number of connections to Consul;