diff --git a/CHANGELOG.md b/CHANGELOG.md index f77d8ebe6..3dd8c2b21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -67,6 +67,8 @@ IMPROVEMENTS: the `transit` backend are now disabled as well [GH-1346] * credential/cert: Renewal requests are rejected if the set of policies has changed since the token was issued [GH-477] + * credential/cert: Check CRLs for specific non-CA certs configured in the + backend [GH-1404] * credential/ldap: If `groupdn` is not configured, skip searching LDAP and only return policies for local groups, plus a warning [GH-1283] * credential/userpass: Add list support for users [GH-911] diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 795ff3ff7..79e7abf77 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -110,6 +110,80 @@ func failOnError(t *testing.T, resp *logical.Response, err error) { } } +func TestBackend_RegisteredNonCA_CRL(t *testing.T) { + config := logical.TestBackendConfig() + storage := &logical.InmemStorage{} + config.StorageView = storage + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + nonCACert, err := ioutil.ReadFile(testCertPath1) + if err != nil { + t.Fatal(err) + } + + // Register the Non-CA certificate of the client key pair + certData := map[string]interface{}{ + "certificate": nonCACert, + "policies": "abc", + "display_name": "cert1", + "ttl": 10000, + } + certReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "certs/cert1", + Storage: storage, + Data: certData, + } + + resp, err := b.HandleRequest(certReq) + failOnError(t, resp, err) + + // Connection state is presenting the client Non-CA cert and its key. + // This is exactly what is registered at the backend. + connState := connectionState(t, serverCAPath, serverCertPath, serverKeyPath, testCertPath1, testKeyPath1) + loginReq := &logical.Request{ + Operation: logical.UpdateOperation, + Storage: storage, + Path: "login", + Connection: &logical.Connection{ + ConnState: &connState, + }, + } + // Login should succeed. + resp, err = b.HandleRequest(loginReq) + failOnError(t, resp, err) + + // Register a CRL containing the issued client certificate used above. + issuedCRL, err := ioutil.ReadFile(testIssuedCertCRL) + if err != nil { + t.Fatal(err) + } + crlData := map[string]interface{}{ + "crl": issuedCRL, + } + crlReq := &logical.Request{ + Operation: logical.UpdateOperation, + Storage: storage, + Path: "crls/issuedcrl", + Data: crlData, + } + resp, err = b.HandleRequest(crlReq) + failOnError(t, resp, err) + + // Attempt login with the same connection state but with the CRL registered + resp, err = b.HandleRequest(loginReq) + if err != nil { + t.Fatal(err) + } + if resp == nil || !resp.IsError() { + t.Fatalf("expected failure due to revoked certificate") + } +} + func TestBackend_CRLs(t *testing.T) { config := logical.TestBackendConfig() storage := &logical.InmemStorage{} diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 00a1dabc1..f37540e80 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -149,7 +149,7 @@ func (b *backend) verifyCredentials(req *logical.Request) (*ParsedCert, *logical // with the backend. if len(trustedNonCAs) != 0 { policy := b.matchNonCAPolicy(connState.PeerCertificates[0], trustedNonCAs) - if policy != nil { + if policy != nil && !b.checkForChainInCRLs(policy.Certificates) { return policy, nil, nil } } @@ -245,18 +245,21 @@ func (b *backend) loadTrustedCerts(store logical.Storage) (pool *x509.CertPool, return } -func (b *backend) checkForValidChain(store logical.Storage, chains [][]*x509.Certificate) bool { - var badChain bool - for _, chain := range chains { - badChain = false - for _, cert := range chain { - badCRLs := b.findSerialInCRLs(cert.SerialNumber) - if len(badCRLs) != 0 { - badChain = true - break - } +func (b *backend) checkForChainInCRLs(chain []*x509.Certificate) bool { + badChain := false + for _, cert := range chain { + badCRLs := b.findSerialInCRLs(cert.SerialNumber) + if len(badCRLs) != 0 { + badChain = true + break } - if !badChain { + } + return badChain +} + +func (b *backend) checkForValidChain(store logical.Storage, chains [][]*x509.Certificate) bool { + for _, chain := range chains { + if !b.checkForChainInCRLs(chain) { return true } } diff --git a/command/server.go b/command/server.go index 47ec3e0bf..8c3764b09 100644 --- a/command/server.go +++ b/command/server.go @@ -298,7 +298,21 @@ func (c *ServerCommand) Run(args []string) int { if coreConfig.HAPhysical != nil { sd, ok := coreConfig.HAPhysical.(physical.ServiceDiscovery) if ok { - if err := sd.RunServiceDiscovery(c.ShutdownCh, coreConfig.AdvertiseAddr); err != nil { + activeFunc := func() bool { + if isLeader, _, err := core.Leader(); err != nil { + return isLeader + } + return false + } + + sealedFunc := func() bool { + if sealed, err := core.Sealed(); err != nil { + return sealed + } + return true + } + + if err := sd.RunServiceDiscovery(c.ShutdownCh, coreConfig.AdvertiseAddr, activeFunc, sealedFunc); err != nil { c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) return 1 } diff --git a/physical/consul.go b/physical/consul.go index c5aacf9ac..11dfaf048 100644 --- a/physical/consul.go +++ b/physical/consul.go @@ -30,22 +30,24 @@ const ( // be executed too close to the TTL check timeout checkMinBuffer = 100 * time.Millisecond + // consulRetryInterval specifies the retry duration to use when an + // API call to the Consul agent fails. + consulRetryInterval = 1 * time.Second + // defaultCheckTimeout changes the timeout of TTL checks defaultCheckTimeout = 5 * time.Second - // defaultCheckInterval specifies the default interval used to send - // checks - defaultCheckInterval = 4 * time.Second - // defaultServiceName is the default Consul service name used when // advertising a Vault instance. defaultServiceName = "vault" - // registrationRetryInterval specifies the retry duration to use when - // a registration to the Consul agent fails. - registrationRetryInterval = 1 * time.Second + // reconcileTimeout is how often Vault should query Consul to detect + // and fix any state drift. + reconcileTimeout = 60 * time.Second ) +type notifyEvent struct{} + // ConsulBackend is a physical backend that stores data at specific // prefix within Consul. It is used for most production situations as // it allows Vault to run on multiple machines in a highly-available manner. @@ -56,19 +58,14 @@ type ConsulBackend struct { kv *api.KV permitPool *PermitPool serviceLock sync.RWMutex - service *api.AgentServiceRegistration - sealedCheck *api.AgentCheckRegistration - registrationLock int64 advertiseHost string advertisePort int64 - consulClientConf *api.Config serviceName string - running bool - active bool - unsealed bool disableRegistration bool checkTimeout time.Duration - checkTimer *time.Timer + + notifyActiveCh chan notifyEvent + notifySealedCh chan notifyEvent } // newConsulBackend constructs a Consul backend using the given API client @@ -79,6 +76,7 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro if !ok { path = "vault/" } + logger.Printf("[DEBUG]: consul: config path set to %v", path) // Ensure path is suffixed but not prefixed if !strings.HasSuffix(path, "/") { @@ -100,12 +98,14 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro } disableRegistration = b } + logger.Printf("[DEBUG]: consul: config disable_registration set to %v", disableRegistration) // Get the service name to advertise in Consul service, ok := conf["service"] if !ok { service = defaultServiceName } + logger.Printf("[DEBUG]: consul: config service set to %s", service) checkTimeout := defaultCheckTimeout checkTimeoutStr, ok := conf["check_timeout"] @@ -121,6 +121,7 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro } checkTimeout = d + logger.Printf("[DEBUG]: consul: config check_timeout set to %v", d) } // Configure the client @@ -128,12 +129,15 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro if addr, ok := conf["address"]; ok { consulConf.Address = addr + logger.Printf("[DEBUG]: consul: config address set to %d", addr) } if scheme, ok := conf["scheme"]; ok { consulConf.Scheme = scheme + logger.Printf("[DEBUG]: consul: config scheme set to %d", scheme) } if token, ok := conf["token"]; ok { consulConf.Token = token + logger.Printf("[DEBUG]: consul: config token set") } if consulConf.Scheme == "https" { @@ -146,6 +150,7 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro transport.MaxIdleConnsPerHost = 4 transport.TLSClientConfig = tlsClientConfig consulConf.HttpClient.Transport = transport + logger.Printf("[DEBUG]: consul: configured TLS") } client, err := api.NewClient(consulConf) @@ -170,235 +175,13 @@ func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, erro client: client, kv: client.KV(), permitPool: NewPermitPool(maxParInt), - consulClientConf: consulConf, serviceName: service, checkTimeout: checkTimeout, - checkTimer: time.NewTimer(checkTimeout), disableRegistration: disableRegistration, } return c, nil } -// serviceTags returns all of the relevant tags for Consul. -func serviceTags(active bool) []string { - activeTag := "standby" - if active { - activeTag = "active" - } - return []string{activeTag} -} - -func (c *ConsulBackend) AdvertiseActive(active bool) error { - c.serviceLock.Lock() - defer c.serviceLock.Unlock() - - // Vault is still bootstrapping - if c.service == nil { - return nil - } - - // Save a cached copy of the active state: no way to query Core - c.active = active - - // Ensure serial registration to the Consul agent. Allow for - // concurrent calls to update active status while a single task - // attempts, until successful, to update the Consul Agent. - if !c.disableRegistration && atomic.CompareAndSwapInt64(&c.registrationLock, 0, 1) { - defer atomic.CompareAndSwapInt64(&c.registrationLock, 1, 0) - - // Retry agent registration until successful - for { - c.service.Tags = serviceTags(c.active) - agent := c.client.Agent() - err := agent.ServiceRegister(c.service) - if err == nil { - // Success - return nil - } - - c.logger.Printf("[WARN] consul: service registration failed: %v", err) - c.serviceLock.Unlock() - time.Sleep(registrationRetryInterval) - c.serviceLock.Lock() - - if !c.running { - // Shutting down - return err - } - } - } - - // Successful concurrent update to active state - return nil -} - -func (c *ConsulBackend) AdvertiseSealed(sealed bool) error { - c.serviceLock.Lock() - defer c.serviceLock.Unlock() - c.unsealed = !sealed - - // Vault is still bootstrapping - if c.service == nil { - return nil - } - - if !c.disableRegistration { - // Push a TTL check immediately to update the state - c.runCheck() - } - - return nil -} - -func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string) (err error) { - c.serviceLock.Lock() - defer c.serviceLock.Unlock() - - if c.disableRegistration { - return nil - } - - if err := c.setAdvertiseAddr(advertiseAddr); err != nil { - return err - } - - serviceID := c.serviceID() - - c.service = &api.AgentServiceRegistration{ - ID: serviceID, - Name: c.serviceName, - Tags: serviceTags(c.active), - Port: int(c.advertisePort), - Address: c.advertiseHost, - EnableTagOverride: false, - } - - checkStatus := api.HealthCritical - if c.unsealed { - checkStatus = api.HealthPassing - } - - c.sealedCheck = &api.AgentCheckRegistration{ - ID: c.checkID(), - Name: "Vault Sealed Status", - Notes: "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server", - ServiceID: serviceID, - AgentServiceCheck: api.AgentServiceCheck{ - TTL: c.checkTimeout.String(), - Status: checkStatus, - }, - } - - agent := c.client.Agent() - if err := agent.ServiceRegister(c.service); err != nil { - return errwrap.Wrapf("service registration failed: {{err}}", err) - } - - if err := agent.CheckRegister(c.sealedCheck); err != nil { - return errwrap.Wrapf("service registration check registration failed: {{err}}", err) - } - - go c.checkRunner(shutdownCh) - c.running = true - - // Deregister upon shutdown - go func() { - shutdown: - for { - select { - case <-shutdownCh: - c.logger.Printf("[INFO]: consul: Shutting down consul backend") - break shutdown - } - } - - if err := agent.ServiceDeregister(serviceID); err != nil { - c.logger.Printf("[WARN]: consul: service deregistration failed: {{err}}", err) - } - c.running = false - }() - - return nil -} - -// checkRunner periodically runs TTL checks -func (c *ConsulBackend) checkRunner(shutdownCh ShutdownChannel) { - defer c.checkTimer.Stop() - - for { - select { - case <-c.checkTimer.C: - go func() { - c.serviceLock.Lock() - defer c.serviceLock.Unlock() - c.runCheck() - }() - case <-shutdownCh: - return - } - } -} - -// runCheck immediately pushes a TTL check. Assumes c.serviceLock is held -// exclusively. -func (c *ConsulBackend) runCheck() { - // Reset timer before calling run check in order to not slide the - // window of the next check. - c.checkTimer.Reset(lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)) - - // Run a TTL check - agent := c.client.Agent() - if c.unsealed { - agent.PassTTL(c.checkID(), "Vault Unsealed") - } else { - agent.FailTTL(c.checkID(), "Vault Sealed") - } -} - -// checkID returns the ID used for a Consul Check. Assume at least a read -// lock is held. -func (c *ConsulBackend) checkID() string { - return "vault-sealed-check" -} - -// serviceID returns the Vault ServiceID for use in Consul. Assume at least -// a read lock is held. -func (c *ConsulBackend) serviceID() string { - return fmt.Sprintf("%s:%s:%d", c.serviceName, c.advertiseHost, c.advertisePort) -} - -func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) { - if addr == "" { - return fmt.Errorf("advertise address must not be empty") - } - - url, err := url.Parse(addr) - if err != nil { - return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err) - } - - var portStr string - c.advertiseHost, portStr, err = net.SplitHostPort(url.Host) - if err != nil { - if url.Scheme == "http" { - portStr = "80" - } else if url.Scheme == "https" { - portStr = "443" - } else if url.Scheme == "unix" { - portStr = "-1" - c.advertiseHost = url.Path - } else { - return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err) - } - } - c.advertisePort, err = strconv.ParseInt(portStr, 10, 0) - if err != nil || c.advertisePort < -1 || c.advertisePort > 65535 { - return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err) - } - - return nil -} - func setupTLSConfig(conf map[string]string) (*tls.Config, error) { serverName := strings.Split(conf["address"], ":") @@ -577,3 +360,274 @@ func (c *ConsulLock) Value() (bool, string, error) { value := string(pair.Value) return held, value, nil } + +func (c *ConsulBackend) NotifyActiveStateChange() error { + select { + case c.notifyActiveCh <- notifyEvent{}: + default: + // NOTE: If this occurs Vault's active status could be out of + // sync with Consul until reconcileTimer expires. + c.logger.Printf("[WARN]: consul: Concurrent state change notify dropped") + } + + return nil +} + +func (c *ConsulBackend) NotifySealedStateChange() error { + select { + case c.notifySealedCh <- notifyEvent{}: + default: + // NOTE: If this occurs Vault's sealed status could be out of + // sync with Consul until checkTimer expires. + c.logger.Printf("[WARN]: consul: Concurrent sealed state change notify dropped") + } + + return nil +} + +func (c *ConsulBackend) checkDuration() time.Duration { + return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) +} + +func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) { + if err := c.setAdvertiseAddr(advertiseAddr); err != nil { + return err + } + + go c.runEventDemuxer(shutdownCh, advertiseAddr, activeFunc, sealedFunc) + + return nil +} + +func (c *ConsulBackend) runEventDemuxer(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) { + // Fire the reconcileTimer immediately upon starting the event demuxer + reconcileTimer := time.NewTimer(0) + defer reconcileTimer.Stop() + + // Schedule the first check. Consul TTL checks are passing by + // default, checkTimer does not need to be run immediately. + checkTimer := time.NewTimer(c.checkDuration()) + defer checkTimer.Stop() + + // Use a reactor pattern to handle and dispatch events to singleton + // goroutine handlers for execution. It is not acceptable to drop + // inbound events from Notify*(). + // + // goroutines are dispatched if the demuxer can acquire a lock (via + // an atomic CAS incr) on the handler. Handlers are responsible for + // deregistering themselves (atomic CAS decr). Handlers and the + // demuxer share a lock to synchronize information at the beginning + // and end of a handler's life (or after a handler wakes up from + // sleeping during a back-off/retry). + var shutdown bool + var checkLock int64 + var registeredServiceID string + var serviceRegLock int64 +shutdown: + for { + select { + case <-c.notifyActiveCh: + // Run reconcile immediately upon active state change notification + reconcileTimer.Reset(0) + case <-c.notifySealedCh: + // Run check timer immediately upon a seal state change notification + checkTimer.Reset(0) + case <-reconcileTimer.C: + // Unconditionally rearm the reconcileTimer + reconcileTimer.Reset(reconcileTimeout - lib.RandomStagger(reconcileTimeout/checkJitterFactor)) + + // Abort if service discovery is disabled or a + // reconcile handler is already active + if !c.disableRegistration && atomic.CompareAndSwapInt64(&serviceRegLock, 0, 1) { + // Enter handler with serviceRegLock held + go func() { + defer atomic.CompareAndSwapInt64(&serviceRegLock, 1, 0) + for !shutdown { + serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc) + if err != nil { + c.logger.Printf("[WARN]: consul: reconcile unable to talk with Consul backend: %v", err) + time.Sleep(consulRetryInterval) + continue + } + + c.serviceLock.Lock() + defer c.serviceLock.Unlock() + + registeredServiceID = serviceID + return + } + }() + } + case <-checkTimer.C: + checkTimer.Reset(c.checkDuration()) + // Abort if service discovery is disabled or a + // reconcile handler is active + if !c.disableRegistration && atomic.CompareAndSwapInt64(&checkLock, 0, 1) { + // Enter handler with serviceRegLock held + go func() { + defer atomic.CompareAndSwapInt64(&checkLock, 1, 0) + for !shutdown { + unsealed := sealedFunc() + if err := c.runCheck(unsealed); err != nil { + c.logger.Printf("[WARN]: consul: check unable to talk with Consul backend: %v", err) + time.Sleep(consulRetryInterval) + continue + } + return + } + }() + } + case <-shutdownCh: + c.logger.Printf("[INFO]: consul: Shutting down consul backend") + shutdown = true + break shutdown + } + } + + c.serviceLock.RLock() + defer c.serviceLock.RUnlock() + if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil { + c.logger.Printf("[WARN]: consul: service deregistration failed: %v", err) + } +} + +// checkID returns the ID used for a Consul Check. Assume at least a read +// lock is held. +func (c *ConsulBackend) checkID() string { + return "vault-sealed-check" +} + +// reconcileConsul queries the state of Vault Core and Consul and fixes up +// Consul's state according to what's in Vault. reconcileConsul is called +// without any locks held and can be run concurrently, therefore no changes +// to ConsulBackend can be made in this method (i.e. wtb const receiver for +// compiler enforced safety). +func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc activeFunction, sealedFunc sealedFunction) (serviceID string, err error) { + // Query vault Core for its current state + active := activeFunc() + sealed := sealedFunc() + + agent := c.client.Agent() + + // Get the current state of Vault from Consul + var currentVaultService *api.AgentService + if services, err := agent.Services(); err == nil { + if service, ok := services[c.serviceName]; ok { + currentVaultService = service + } + } + + serviceID = c.serviceID() + tags := serviceTags(active) + + var reregister bool + switch { + case currentVaultService == nil, + registeredServiceID == "": + reregister = true + default: + switch { + case len(currentVaultService.Tags) != 1, + currentVaultService.Tags[0] != tags[0]: + reregister = true + } + } + + if !reregister { + return "", nil + } + + service := &api.AgentServiceRegistration{ + ID: serviceID, + Name: c.serviceName, + Tags: tags, + Port: int(c.advertisePort), + Address: c.advertiseHost, + EnableTagOverride: false, + } + + checkStatus := api.HealthCritical + if !sealed { + checkStatus = api.HealthPassing + } + + sealedCheck := &api.AgentCheckRegistration{ + ID: c.checkID(), + Name: "Vault Sealed Status", + Notes: "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server", + ServiceID: serviceID, + AgentServiceCheck: api.AgentServiceCheck{ + TTL: c.checkTimeout.String(), + Status: checkStatus, + }, + } + + if err := agent.ServiceRegister(service); err != nil { + return "", errwrap.Wrapf(`service registration failed: {{err}}`, err) + } + + if err := agent.CheckRegister(sealedCheck); err != nil { + return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err) + } + + return serviceID, nil +} + +// runCheck immediately pushes a TTL check. Assumes c.serviceLock is held +// exclusively. +func (c *ConsulBackend) runCheck(unsealed bool) error { + // Run a TTL check + agent := c.client.Agent() + if unsealed { + return agent.PassTTL(c.checkID(), "Vault Unsealed") + } else { + return agent.FailTTL(c.checkID(), "Vault Sealed") + } +} + +// serviceID returns the Vault ServiceID for use in Consul. Assume at least +// a read lock is held. +func (c *ConsulBackend) serviceID() string { + return fmt.Sprintf("%s:%s:%d", c.serviceName, c.advertiseHost, c.advertisePort) +} + +// serviceTags returns all of the relevant tags for Consul. +func serviceTags(active bool) []string { + activeTag := "standby" + if active { + activeTag = "active" + } + return []string{activeTag} +} + +func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) { + if addr == "" { + return fmt.Errorf("advertise address must not be empty") + } + + url, err := url.Parse(addr) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err) + } + + var portStr string + c.advertiseHost, portStr, err = net.SplitHostPort(url.Host) + if err != nil { + if url.Scheme == "http" { + portStr = "80" + } else if url.Scheme == "https" { + portStr = "443" + } else if url.Scheme == "unix" { + portStr = "-1" + c.advertiseHost = url.Path + } else { + return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err) + } + } + c.advertisePort, err = strconv.ParseInt(portStr, 10, 0) + if err != nil || c.advertisePort < -1 || c.advertisePort > 65535 { + return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err) + } + + return nil +} diff --git a/physical/consul_test.go b/physical/consul_test.go index e7d4c3cbc..32284e771 100644 --- a/physical/consul_test.go +++ b/physical/consul_test.go @@ -3,6 +3,7 @@ package physical import ( "fmt" "log" + "math/rand" "os" "reflect" "testing" @@ -39,28 +40,6 @@ func testConsulBackendConfig(t *testing.T, conf *consulConf) *ConsulBackend { t.Fatalf("Expected ConsulBackend") } - c.consulClientConf = api.DefaultConfig() - - c.service = &api.AgentServiceRegistration{ - ID: c.serviceID(), - Name: c.serviceName, - Tags: serviceTags(c.active), - Port: 8200, - Address: testHostIP(), - EnableTagOverride: false, - } - - c.sealedCheck = &api.AgentCheckRegistration{ - ID: c.checkID(), - Name: "Vault Sealed Status", - Notes: "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server", - ServiceID: c.serviceID(), - AgentServiceCheck: api.AgentServiceCheck{ - TTL: c.checkTimeout.String(), - Status: api.HealthPassing, - }, - } - return c } @@ -69,21 +48,27 @@ func testConsul_testConsulBackend(t *testing.T) { if c == nil { t.Fatalf("bad") } +} - if c.active != false { - t.Fatalf("bad") +func testActiveFunc(activePct float64) activeFunction { + return func() bool { + var active bool + standbyProb := rand.Float64() + if standbyProb > activePct { + active = true + } + return active } +} - if c.unsealed != false { - t.Fatalf("bad") - } - - if c.service == nil { - t.Fatalf("bad") - } - - if c.sealedCheck == nil { - t.Fatalf("bad") +func testSealedFunc(sealedPct float64) sealedFunction { + return func() bool { + var sealed bool + unsealedProb := rand.Float64() + if unsealedProb > sealedPct { + sealed = true + } + return sealed } } @@ -165,8 +150,15 @@ func TestConsul_newConsulBackend(t *testing.T) { } c.disableRegistration = true + if c.disableRegistration == false { + addr := os.Getenv("CONSUL_HTTP_ADDR") + if addr == "" { + continue + } + } + var shutdownCh ShutdownChannel - if err := c.RunServiceDiscovery(shutdownCh, test.advertiseAddr); err != nil { + if err := c.RunServiceDiscovery(shutdownCh, test.advertiseAddr, testActiveFunc(0.5), testSealedFunc(0.5)); err != nil { t.Fatalf("bad: %v", err) } @@ -182,18 +174,6 @@ func TestConsul_newConsulBackend(t *testing.T) { t.Errorf("bad: %v != %v", test.service, c.serviceName) } - if test.address != c.consulClientConf.Address { - t.Errorf("bad: %s %v != %v", test.name, test.address, c.consulClientConf.Address) - } - - if test.scheme != c.consulClientConf.Scheme { - t.Errorf("bad: %v != %v", test.scheme, c.consulClientConf.Scheme) - } - - if test.token != c.consulClientConf.Token { - t.Errorf("bad: %v != %v", test.token, c.consulClientConf.Token) - } - // FIXME(sean@): Unable to test max_parallel // if test.max_parallel != cap(c.permitPool) { // t.Errorf("bad: %v != %v", test.max_parallel, cap(c.permitPool)) @@ -289,7 +269,7 @@ func TestConsul_setAdvertiseAddr(t *testing.T) { } } -func TestConsul_AdvertiseActive(t *testing.T) { +func TestConsul_NotifyActiveStateChange(t *testing.T) { addr := os.Getenv("CONSUL_HTTP_ADDR") if addr == "" { t.Skipf("No consul process running, skipping test") @@ -297,32 +277,12 @@ func TestConsul_AdvertiseActive(t *testing.T) { c := testConsulBackend(t) - if c.active != false { - t.Fatalf("bad") - } - - if err := c.AdvertiseActive(true); err != nil { - t.Fatalf("bad: %v", err) - } - - if err := c.AdvertiseActive(true); err != nil { - t.Fatalf("bad: %v", err) - } - - if err := c.AdvertiseActive(false); err != nil { - t.Fatalf("bad: %v", err) - } - - if err := c.AdvertiseActive(false); err != nil { - t.Fatalf("bad: %v", err) - } - - if err := c.AdvertiseActive(true); err != nil { + if err := c.NotifyActiveStateChange(); err != nil { t.Fatalf("bad: %v", err) } } -func TestConsul_AdvertiseSealed(t *testing.T) { +func TestConsul_NotifySealedStateChange(t *testing.T) { addr := os.Getenv("CONSUL_HTTP_ADDR") if addr == "" { t.Skipf("No consul process running, skipping test") @@ -330,44 +290,9 @@ func TestConsul_AdvertiseSealed(t *testing.T) { c := testConsulBackend(t) - if c.unsealed == true { - t.Fatalf("bad") - } - - if err := c.AdvertiseSealed(true); err != nil { + if err := c.NotifySealedStateChange(); err != nil { t.Fatalf("bad: %v", err) } - if c.unsealed == true { - t.Fatalf("bad") - } - - if err := c.AdvertiseSealed(true); err != nil { - t.Fatalf("bad: %v", err) - } - if c.unsealed == true { - t.Fatalf("bad") - } - - if err := c.AdvertiseSealed(false); err != nil { - t.Fatalf("bad: %v", err) - } - if c.unsealed == false { - t.Fatalf("bad") - } - - if err := c.AdvertiseSealed(false); err != nil { - t.Fatalf("bad: %v", err) - } - if c.unsealed == false { - t.Fatalf("bad") - } - - if err := c.AdvertiseSealed(true); err != nil { - t.Fatalf("bad: %v", err) - } - if c.unsealed == true { - t.Fatalf("bad") - } } func TestConsul_checkID(t *testing.T) { diff --git a/physical/physical.go b/physical/physical.go index 5b98815b7..51b331685 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -48,21 +48,27 @@ type AdvertiseDetect interface { DetectHostAddr() (string, error) } +// Callback signatures for RunServiceDiscovery +type activeFunction func() bool +type sealedFunction func() bool + // ServiceDiscovery is an optional interface that an HABackend can implement. // If they do, the state of a backend is advertised to the service discovery // network. type ServiceDiscovery interface { - // AdvertiseActive is used to reflect whether or not a backend is in - // an active or standby state. - AdvertiseActive(bool) error + // NotifyActiveStateChange is used by Core to notify a backend + // capable of ServiceDiscovery that this Vault instance has changed + // its status to active or standby. + NotifyActiveStateChange() error - // AdvertiseSealed is used to reflect whether or not a backend is in - // a sealed state or not. - AdvertiseSealed(bool) error + // NotifySealedStateChange is used by Core to notify a backend + // capable of ServiceDiscovery that Vault has changed its Sealed + // status to sealed or unsealed. + NotifySealedStateChange() error // Run executes any background service discovery tasks until the // shutdown channel is closed. - RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string) error + RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) error } type Lock interface { diff --git a/vault/core.go b/vault/core.go index 8f544de49..02dc252b4 100644 --- a/vault/core.go +++ b/vault/core.go @@ -489,7 +489,7 @@ func (c *Core) Standby() (bool, error) { } // Leader is used to get the current active leader -func (c *Core) Leader() (bool, string, error) { +func (c *Core) Leader() (isLeader bool, leaderAddr string, err error) { c.stateLock.RLock() defer c.stateLock.RUnlock() // Check if HA enabled @@ -650,11 +650,9 @@ func (c *Core) Unseal(key []byte) (bool, error) { if c.ha != nil { sd, ok := c.ha.(physical.ServiceDiscovery) if ok { - go func() { - if err := sd.AdvertiseSealed(false); err != nil { - c.logger.Printf("[WARN] core: failed to advertise unsealed status: %v", err) - } - }() + if err := sd.NotifySealedStateChange(); err != nil { + c.logger.Printf("[WARN] core: failed to notify unsealed status: %v", err) + } } } return true, nil @@ -822,11 +820,9 @@ func (c *Core) sealInternal() error { if c.ha != nil { sd, ok := c.ha.(physical.ServiceDiscovery) if ok { - go func() { - if err := sd.AdvertiseSealed(true); err != nil { - c.logger.Printf("[WARN] core: failed to advertise sealed status: %v", err) - } - }() + if err := sd.NotifySealedStateChange(); err != nil { + c.logger.Printf("[WARN] core: failed to notify sealed status: %v", err) + } } } @@ -1146,11 +1142,9 @@ func (c *Core) advertiseLeader(uuid string, leaderLostCh <-chan struct{}) error sd, ok := c.ha.(physical.ServiceDiscovery) if ok { - go func() { - if err := sd.AdvertiseActive(true); err != nil { - c.logger.Printf("[WARN] core: failed to advertise active status: %v", err) - } - }() + if err := sd.NotifyActiveStateChange(); err != nil { + c.logger.Printf("[WARN] core: failed to notify active status: %v", err) + } } return nil } @@ -1182,11 +1176,9 @@ func (c *Core) clearLeader(uuid string) error { // Advertise ourselves as a standby sd, ok := c.ha.(physical.ServiceDiscovery) if ok { - go func() { - if err := sd.AdvertiseActive(false); err != nil { - c.logger.Printf("[WARN] core: failed to advertise standby status: %v", err) - } - }() + if err := sd.NotifyActiveStateChange(); err != nil { + c.logger.Printf("[WARN] core: failed to notify standby status: %v", err) + } } return err diff --git a/vault/logical_system_helpers.go b/vault/logical_system_helpers.go index a3fb94579..bcda207f4 100644 --- a/vault/logical_system_helpers.go +++ b/vault/logical_system_helpers.go @@ -44,7 +44,7 @@ func (b *SystemBackend) tuneMountTTLs(path string, meConfig *MountConfig, newDef int(newDefault.Seconds()), int(b.Core.maxLeaseTTL.Seconds())) } } else { - if meConfig.MaxLeaseTTL < *newDefault { + if newMax == nil && *newDefault > meConfig.MaxLeaseTTL { return fmt.Errorf("new backend default lease TTL of %d greater than backend max lease TTL of %d", int(newDefault.Seconds()), int(meConfig.MaxLeaseTTL.Seconds())) }