open-vault/physical/consul.go

633 lines
17 KiB
Go

package physical
import (
"fmt"
"io/ioutil"
"log"
"net"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"crypto/tls"
"crypto/x509"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
)
const (
// checkJitterFactor specifies the jitter factor used to stagger checks
// and the periodic reconcile.
checkJitterFactor = 16
// checkMinBuffer provides a guarantee that a check will not
// be executed too close to the TTL check timeout
checkMinBuffer = 100 * time.Millisecond
// consulRetryInterval specifies the retry duration to use when an
// API call to the Consul agent fails.
consulRetryInterval = 1 * time.Second
// defaultCheckTimeout is the TTL check timeout used when the
// check_timeout configuration key is not set.
defaultCheckTimeout = 5 * time.Second
// defaultServiceName is the default Consul service name used when
// advertising a Vault instance.
defaultServiceName = "vault"
// reconcileTimeout is how often Vault should query Consul to detect
// and fix any state drift.
reconcileTimeout = 60 * time.Second
)
// notifyEvent is the empty signal value sent on the backend's notify
// channels; it carries no data.
type notifyEvent struct{}
// ConsulBackend is a physical backend that stores data at a specific
// prefix within Consul. It is used for most production situations as
// it allows Vault to run on multiple machines in a highly-available manner.
type ConsulBackend struct {
// path is the prefix prepended to every key passed to Consul's KV store.
path string
// logger receives the backend's debug, warning and info messages.
logger *log.Logger
// client is the Consul API client used for locks and agent calls.
client *api.Client
// kv is the KV handle obtained from client.
kv *api.KV
// permitPool is acquired around each KV operation; it is sized from the
// max_parallel setting (presumably bounding concurrency — see PermitPool).
permitPool *PermitPool
// serviceLock guards the registered service ID tracked by runEventDemuxer.
serviceLock sync.RWMutex
// advertiseHost and advertisePort are parsed from the advertise address
// by setAdvertiseAddr and used when registering the service.
advertiseHost string
advertisePort int64
// serviceName is the Consul service name Vault registers under.
serviceName string
// disableRegistration turns off service registration and TTL checks.
disableRegistration bool
// checkTimeout is the TTL of the sealed-status check.
checkTimeout time.Duration
// notifyActiveCh is signaled by NotifyActiveStateChange and consumed by
// runEventDemuxer to trigger an immediate reconcile.
notifyActiveCh chan notifyEvent
// notifySealedCh is signaled by NotifySealedStateChange and consumed by
// runEventDemuxer to trigger an immediate check.
notifySealedCh chan notifyEvent
}
// newConsulBackend constructs a Consul backend from the given
// configuration map.
//
// Recognised keys: "path" (KV prefix, default "vault/"),
// "disable_registration" (bool), "service" (service name, default
// defaultServiceName), "check_timeout" (duration), "address", "scheme",
// "token", "max_parallel" (int), plus the TLS keys read by setupTLSConfig
// when the scheme is "https".
//
// It returns an error when a value cannot be parsed, when check_timeout
// is too small, or when the Consul client or TLS setup fails.
func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, error) {
	// Get the path in Consul
	path, ok := conf["path"]
	if !ok {
		path = "vault/"
	}
	logger.Printf("[DEBUG]: consul: config path set to %v", path)

	// Ensure path is suffixed but not prefixed
	if !strings.HasSuffix(path, "/") {
		logger.Printf("[WARN]: consul: appending trailing forward slash to path")
		path += "/"
	}
	if strings.HasPrefix(path, "/") {
		logger.Printf("[WARN]: consul: trimming path of its forward slash")
		path = strings.TrimPrefix(path, "/")
	}

	// Allow admins to disable consul integration
	var disableRegistration bool
	if disableReg, ok := conf["disable_registration"]; ok && disableReg != "" {
		b, err := strconv.ParseBool(disableReg)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err)
		}
		disableRegistration = b
	}
	logger.Printf("[DEBUG]: consul: config disable_registration set to %v", disableRegistration)

	// Get the service name to advertise in Consul
	service, ok := conf["service"]
	if !ok {
		service = defaultServiceName
	}
	logger.Printf("[DEBUG]: consul: config service set to %s", service)

	checkTimeout := defaultCheckTimeout
	if checkTimeoutStr, ok := conf["check_timeout"]; ok {
		d, err := time.ParseDuration(checkTimeoutStr)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing check_timeout parameter: {{err}}", err)
		}

		minCheck, _ := lib.DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
		if minCheck < checkMinBuffer {
			return nil, fmt.Errorf("Consul check_timeout must be greater than %v", minCheck)
		}

		checkTimeout = d
		logger.Printf("[DEBUG]: consul: config check_timeout set to %v", d)
	}

	// Configure the client
	consulConf := api.DefaultConfig()

	if addr, ok := conf["address"]; ok {
		consulConf.Address = addr
		// addr is a string; %s avoids the "%!d(string=...)" output of %d.
		logger.Printf("[DEBUG]: consul: config address set to %s", addr)
	}
	if scheme, ok := conf["scheme"]; ok {
		consulConf.Scheme = scheme
		logger.Printf("[DEBUG]: consul: config scheme set to %s", scheme)
	}
	if token, ok := conf["token"]; ok {
		consulConf.Token = token
		// Never log the token itself.
		logger.Printf("[DEBUG]: consul: config token set")
	}

	if consulConf.Scheme == "https" {
		tlsClientConfig, err := setupTLSConfig(conf)
		if err != nil {
			return nil, err
		}

		transport := cleanhttp.DefaultPooledTransport()
		transport.MaxIdleConnsPerHost = 4
		transport.TLSClientConfig = tlsClientConfig
		consulConf.HttpClient.Transport = transport
		logger.Printf("[DEBUG]: consul: configured TLS")
	}

	client, err := api.NewClient(consulConf)
	if err != nil {
		return nil, errwrap.Wrapf("client setup failed: {{err}}", err)
	}

	var maxParInt int
	if maxParStr, ok := conf["max_parallel"]; ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
		}
		logger.Printf("[DEBUG]: consul: max_parallel set to %d", maxParInt)
	}

	// Setup the backend. The notify channels must be allocated here: a
	// send on a nil channel is never ready, so the Notify* methods would
	// always fall through to their default case and the event demuxer
	// would never observe a state change.
	c := &ConsulBackend{
		path:                path,
		logger:              logger,
		client:              client,
		kv:                  client.KV(),
		permitPool:          NewPermitPool(maxParInt),
		serviceName:         service,
		checkTimeout:        checkTimeout,
		disableRegistration: disableRegistration,
		notifyActiveCh:      make(chan notifyEvent),
		notifySealedCh:      make(chan notifyEvent),
	}
	return c, nil
}
// setupTLSConfig builds the TLS client configuration used to talk to
// Consul over HTTPS.
//
// Configuration keys read from conf:
//   - "address": its host part becomes the TLS ServerName. Bracketed IPv6
//     addresses such as "[::1]:8500" are handled; an address without a
//     port is used as-is.
//   - "tls_skip_verify": parsed as a boolean. An empty or absent value
//     leaves certificate verification enabled.
//   - "tls_cert_file" / "tls_key_file": when both are present, the key
//     pair is loaded as the client certificate.
//   - "tls_ca_file": PEM bundle used as the root CA pool.
//
// An error is returned when tls_skip_verify is not a valid boolean, or
// when the certificate, key or CA file cannot be loaded or parsed.
func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
	// Strip the port (and IPv6 brackets) when one is present; otherwise
	// fall back to the raw address.
	serverName := conf["address"]
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		serverName = host
	}

	// Honor the configured value rather than the mere presence of the key,
	// so that tls_skip_verify = "false" keeps verification on.
	insecureSkipVerify := false
	if raw, ok := conf["tls_skip_verify"]; ok && raw != "" {
		b, err := strconv.ParseBool(raw)
		if err != nil {
			return nil, fmt.Errorf("failed parsing tls_skip_verify parameter: %v", err)
		}
		insecureSkipVerify = b
	}

	tlsClientConfig := &tls.Config{
		InsecureSkipVerify: insecureSkipVerify,
		ServerName:         serverName,
	}

	certFile, okCert := conf["tls_cert_file"]
	keyFile, okKey := conf["tls_key_file"]
	if okCert && okKey {
		tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			return nil, fmt.Errorf("client tls setup failed: %v", err)
		}
		tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
	}

	if tlsCaFile, ok := conf["tls_ca_file"]; ok {
		data, err := ioutil.ReadFile(tlsCaFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read CA file: %v", err)
		}

		caPool := x509.NewCertPool()
		if !caPool.AppendCertsFromPEM(data) {
			return nil, fmt.Errorf("failed to parse CA certificate")
		}
		tlsClientConfig.RootCAs = caPool
	}

	return tlsClientConfig, nil
}
// Put writes entry into Consul's KV store under the backend's path
// prefix, creating the key or replacing its current value. It returns
// whatever error the Consul client reports.
func (c *ConsulBackend) Put(entry *Entry) error {
	defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())

	kvPair := new(api.KVPair)
	kvPair.Key = c.path + entry.Key
	kvPair.Value = entry.Value

	// Hold a permit for the duration of the KV call.
	c.permitPool.Acquire()
	defer c.permitPool.Release()

	if _, err := c.kv.Put(kvPair, nil); err != nil {
		return err
	}
	return nil
}
// Get looks up key beneath the backend's path prefix. It returns
// (nil, nil) when the key is absent from Consul, and the client's error
// when the lookup fails.
func (c *ConsulBackend) Get(key string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"consul", "get"}, time.Now())

	// Hold a permit for the duration of the KV call.
	c.permitPool.Acquire()
	defer c.permitPool.Release()

	kvPair, _, err := c.kv.Get(c.path+key, nil)
	switch {
	case err != nil:
		return nil, err
	case kvPair == nil:
		return nil, nil
	}

	// The returned entry carries the caller's key, not the prefixed one.
	return &Entry{Key: key, Value: kvPair.Value}, nil
}
// Delete permanently removes key (beneath the backend's path prefix)
// from Consul's KV store and returns the client's error, if any.
func (c *ConsulBackend) Delete(key string) error {
	defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now())

	// Hold a permit for the duration of the KV call.
	c.permitPool.Acquire()
	defer c.permitPool.Release()

	if _, err := c.kv.Delete(c.path+key, nil); err != nil {
		return err
	}
	return nil
}
// List returns the keys found under prefix, up to the next "/"
// separator, with the scanned prefix stripped from each result.
func (c *ConsulBackend) List(prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"consul", "list"}, time.Now())

	// A trailing "//" (e.g. when a logical backend lists its root via "/"
	// rather than "") would defeat the TrimPrefix below, so collapse it to
	// a single slash first.
	scan := c.path + prefix
	if strings.HasSuffix(scan, "//") {
		scan = strings.TrimSuffix(scan, "/")
	}

	// Hold a permit for the duration of the KV call.
	c.permitPool.Acquire()
	defer c.permitPool.Release()

	keys, _, err := c.kv.Keys(scan, "/", nil)
	for i := range keys {
		keys[i] = strings.TrimPrefix(keys[i], scan)
	}
	return keys, err
}
// LockWith creates a lock for mutual exclusion on key (beneath the
// backend's path prefix), with value as the payload stored in the lock
// key. It only constructs the lock; acquiring it is done through the
// returned Lock.
func (c *ConsulBackend) LockWith(key, value string) (Lock, error) {
	lockKey := c.path + key

	consulLock, err := c.client.LockOpts(&api.LockOptions{
		Key:            lockKey,
		Value:          []byte(value),
		SessionName:    "Vault Lock",
		MonitorRetries: 5,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create lock: %v", err)
	}

	return &ConsulLock{
		client: c.client,
		key:    lockKey,
		lock:   consulLock,
	}, nil
}
// DetectHostAddr asks the local Consul agent for its own member address
// and returns it. It fails when the agent cannot be queried or when the
// reported Member/Addr value is not a string.
func (c *ConsulBackend) DetectHostAddr() (string, error) {
	self, err := c.client.Agent().Self()
	if err != nil {
		return "", err
	}

	if addr, ok := self["Member"]["Addr"].(string); ok {
		return addr, nil
	}
	return "", fmt.Errorf("Unable to convert an address to string")
}
// ConsulLock is used to provide the Lock interface backed by Consul
type ConsulLock struct {
// client is used by Value to read the lock key from the KV store.
client *api.Client
// key is the full (prefixed) KV key the lock is held on.
key string
// lock is the underlying Consul lock that Lock and Unlock delegate to.
lock *api.Lock
}
// Lock delegates to the underlying Consul lock, passing stopCh through
// and returning that call's channel and error unchanged.
func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	ch, err := c.lock.Lock(stopCh)
	return ch, err
}
// Unlock delegates to the underlying Consul lock and returns its error.
func (c *ConsulLock) Unlock() error {
	err := c.lock.Unlock()
	return err
}
// Value reads the lock key from Consul. It reports whether the key is
// currently attached to a session (i.e. the lock is held) together with
// the stored value. A missing key yields (false, "", nil).
func (c *ConsulLock) Value() (bool, string, error) {
	pair, _, err := c.client.KV().Get(c.key, nil)
	switch {
	case err != nil:
		return false, "", err
	case pair == nil:
		return false, "", nil
	}
	return pair.Session != "", string(pair.Value), nil
}
// NotifyActiveStateChange signals the event demuxer that Vault's active
// state changed. The send is non-blocking: when nobody is ready to
// receive, the notification is dropped with a warning. It always
// returns nil.
func (c *ConsulBackend) NotifyActiveStateChange() error {
	select {
	case c.notifyActiveCh <- notifyEvent{}:
		return nil
	default:
	}

	// NOTE: If this occurs Vault's active status could be out of
	// sync with Consul until reconcileTimer expires.
	c.logger.Printf("[WARN]: consul: Concurrent state change notify dropped")
	return nil
}
// NotifySealedStateChange signals the event demuxer that Vault's sealed
// state changed. The send is non-blocking: when nobody is ready to
// receive, the notification is dropped with a warning. It always
// returns nil.
func (c *ConsulBackend) NotifySealedStateChange() error {
	select {
	case c.notifySealedCh <- notifyEvent{}:
		return nil
	default:
	}

	// NOTE: If this occurs Vault's sealed status could be out of
	// sync with Consul until checkTimer expires.
	c.logger.Printf("[WARN]: consul: Concurrent sealed state change notify dropped")
	return nil
}
// checkDuration returns the delay before the next TTL check, derived
// from the configured check timeout, checkMinBuffer and
// checkJitterFactor.
func (c *ConsulBackend) checkDuration() time.Duration {
	next := lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
	return next
}
// RunServiceDiscovery records the advertise address and then starts the
// event demuxer in its own goroutine. It returns an error only when
// advertiseAddr cannot be parsed; in that case no goroutine is started.
func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) {
	err = c.setAdvertiseAddr(advertiseAddr)
	if err != nil {
		return err
	}

	// The demuxer runs until shutdownCh fires.
	go c.runEventDemuxer(shutdownCh, advertiseAddr, activeFunc, sealedFunc)
	return nil
}
// runEventDemuxer is the service-discovery event loop. It multiplexes
// active/sealed notifications, the periodic reconcile timer, the TTL
// check timer and the shutdown channel, and dispatches at most one
// reconcile handler and one check handler goroutine at a time. It runs
// until shutdownCh fires, then deregisters the service it registered
// (if any).
func (c *ConsulBackend) runEventDemuxer(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) {
	// Fire the reconcileTimer immediately upon starting the event demuxer
	reconcileTimer := time.NewTimer(0)
	defer reconcileTimer.Stop()

	// Schedule the first check. Consul TTL checks are passing by
	// default, checkTimer does not need to be run immediately.
	checkTimer := time.NewTimer(c.checkDuration())
	defer checkTimer.Stop()

	// Use a reactor pattern to handle and dispatch events to singleton
	// goroutine handlers for execution. It is not acceptable to drop
	// inbound events from Notify*().
	//
	// goroutines are dispatched if the demuxer can acquire a lock (via
	// an atomic CAS incr) on the handler. Handlers are responsible for
	// deregistering themselves (atomic CAS decr). Handlers and the
	// demuxer share a lock to synchronize information at the beginning
	// and end of a handler's life (or after a handler wakes up from
	// sleeping during a back-off/retry).
	//
	// shuttingDown is written by the demuxer and read by the handler
	// goroutines, so it is only ever accessed through sync/atomic.
	var shuttingDown int32
	var checkLock int64
	var registeredServiceID string
	var serviceRegLock int64

shutdown:
	for {
		select {
		case <-c.notifyActiveCh:
			// Run reconcile immediately upon active state change notification
			reconcileTimer.Reset(0)

		case <-c.notifySealedCh:
			// Run check timer immediately upon a seal state change notification
			checkTimer.Reset(0)

		case <-reconcileTimer.C:
			// Unconditionally rearm the reconcileTimer
			reconcileTimer.Reset(reconcileTimeout - lib.RandomStagger(reconcileTimeout/checkJitterFactor))

			// Abort if service discovery is disabled or a
			// reconcile handler is already active
			if !c.disableRegistration && atomic.CompareAndSwapInt64(&serviceRegLock, 0, 1) {
				// Enter handler with serviceRegLock held
				go func() {
					defer atomic.CompareAndSwapInt64(&serviceRegLock, 1, 0)
					for atomic.LoadInt32(&shuttingDown) == 0 {
						serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc)
						if err != nil {
							c.logger.Printf("[WARN]: consul: reconcile unable to talk with Consul backend: %v", err)
							time.Sleep(consulRetryInterval)
							continue
						}

						// Publish the ID for the deregistration at shutdown.
						c.serviceLock.Lock()
						registeredServiceID = serviceID
						c.serviceLock.Unlock()
						return
					}
				}()
			}

		case <-checkTimer.C:
			checkTimer.Reset(c.checkDuration())

			// Abort if service discovery is disabled or a
			// check handler is already active
			if !c.disableRegistration && atomic.CompareAndSwapInt64(&checkLock, 0, 1) {
				// Enter handler with checkLock held
				go func() {
					defer atomic.CompareAndSwapInt64(&checkLock, 1, 0)
					for atomic.LoadInt32(&shuttingDown) == 0 {
						sealed := sealedFunc()
						if err := c.runCheck(sealed); err != nil {
							c.logger.Printf("[WARN]: consul: check unable to talk with Consul backend: %v", err)
							time.Sleep(consulRetryInterval)
							continue
						}
						return
					}
				}()
			}

		case <-shutdownCh:
			c.logger.Printf("[INFO]: consul: Shutting down consul backend")
			atomic.StoreInt32(&shuttingDown, 1)
			break shutdown
		}
	}

	c.serviceLock.RLock()
	defer c.serviceLock.RUnlock()

	// Nothing was recorded as registered (e.g. registration is disabled
	// or never succeeded), so there is nothing to deregister.
	if registeredServiceID == "" {
		return
	}
	if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil {
		c.logger.Printf("[WARN]: consul: service deregistration failed: %v", err)
	}
}
// checkID returns the fixed ID under which the sealed-status check is
// registered in Consul.
func (c *ConsulBackend) checkID() string {
	const sealedCheckID = "vault-sealed-check"
	return sealedCheckID
}
// reconcileConsul queries the state of Vault Core and Consul and fixes up
// Consul's state according to what's in Vault. reconcileConsul is called
// without any locks held and can be run concurrently, therefore no changes
// to ConsulBackend can be made in this method (i.e. wtb const receiver for
// compiler enforced safety).
//
// registeredServiceID is the ID recorded by a previous successful
// reconcile ("" if none). The returned serviceID is the ID of the Vault
// service in Consul; it is returned even when no re-registration was
// needed, so the caller can keep recording it. A non-nil error means
// service or check registration failed.
func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc activeFunction, sealedFunc sealedFunction) (serviceID string, err error) {
	// Query vault Core for its current state
	active := activeFunc()
	sealed := sealedFunc()

	agent := c.client.Agent()
	serviceID = c.serviceID()
	tags := serviceTags(active)

	// Get the current state of Vault from Consul. The agent's services map
	// is keyed by service ID (not service name), and the service is
	// registered under serviceID below, so that is the key to look up.
	// A failed query is treated the same as "not registered".
	var currentVaultService *api.AgentService
	if services, err := agent.Services(); err == nil {
		if service, ok := services[serviceID]; ok {
			currentVaultService = service
		}
	}

	var reregister bool
	switch {
	case currentVaultService == nil,
		registeredServiceID == "":
		reregister = true
	default:
		switch {
		// The second expression is only evaluated when exactly one tag
		// is present, so the index is safe.
		case len(currentVaultService.Tags) != 1,
			currentVaultService.Tags[0] != tags[0]:
			reregister = true
		}
	}

	if !reregister {
		// Return the valid serviceID: returning "" would make the caller
		// forget the registration and force a needless re-registration
		// on the next cycle.
		return serviceID, nil
	}

	service := &api.AgentServiceRegistration{
		ID:                serviceID,
		Name:              c.serviceName,
		Tags:              tags,
		Port:              int(c.advertisePort),
		Address:           c.advertiseHost,
		EnableTagOverride: false,
	}

	checkStatus := api.HealthCritical
	if !sealed {
		checkStatus = api.HealthPassing
	}

	sealedCheck := &api.AgentCheckRegistration{
		ID:        c.checkID(),
		Name:      "Vault Sealed Status",
		Notes:     "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server",
		ServiceID: serviceID,
		AgentServiceCheck: api.AgentServiceCheck{
			TTL:    c.checkTimeout.String(),
			Status: checkStatus,
		},
	}

	if err := agent.ServiceRegister(service); err != nil {
		return "", errwrap.Wrapf(`service registration failed: {{err}}`, err)
	}

	if err := agent.CheckRegister(sealedCheck); err != nil {
		return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err)
	}

	return serviceID, nil
}
// runCheck immediately pushes the sealed-status TTL check to the Consul
// agent: failing when sealed is true, passing otherwise.
func (c *ConsulBackend) runCheck(sealed bool) error {
	agent := c.client.Agent()
	if sealed {
		return agent.FailTTL(c.checkID(), "Vault Sealed")
	}
	return agent.PassTTL(c.checkID(), "Vault Unsealed")
}
// serviceID returns the Vault ServiceID for use in Consul, formatted as
// "<serviceName>:<advertiseHost>:<advertisePort>". Assume at least a read
// lock is held.
func (c *ConsulBackend) serviceID() string {
	const layout = "%s:%s:%d"
	id := fmt.Sprintf(layout, c.serviceName, c.advertiseHost, c.advertisePort)
	return id
}
// serviceTags returns the tags to attach to the Consul service: a
// single-element slice holding "active" when active is true and
// "standby" otherwise.
func serviceTags(active bool) []string {
	if active {
		return []string{"active"}
	}
	return []string{"standby"}
}
// setAdvertiseAddr parses addr (a URL such as "https://vault:8200") and
// stores its host and port on the backend for service registration.
//
// When the URL carries no explicit port, the port defaults to 80 for
// "http" and 443 for "https". For the "unix" scheme the URL path is used
// as the host and the port is set to -1. Any other port-less scheme is
// rejected.
//
// It returns an error when addr is empty, is not a parseable URL, has no
// usable host:port, or names a port outside [-1, 65535]. The backend's
// fields are only updated on success.
func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) {
	if addr == "" {
		return fmt.Errorf("advertise address must not be empty")
	}

	u, err := url.Parse(addr)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err)
	}

	host, portStr, err := net.SplitHostPort(u.Host)
	if err != nil {
		// SplitHostPort returns an empty host on failure, so the host has
		// to be taken from the URL itself once a default port is chosen.
		switch u.Scheme {
		case "http":
			host, portStr = hostWithoutBrackets(u.Host), "80"
		case "https":
			host, portStr = hostWithoutBrackets(u.Host), "443"
		case "unix":
			host, portStr = u.Path, "-1"
		default:
			return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, u.Host), err)
		}
	}

	port, err := strconv.ParseInt(portStr, 10, 0)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err)
	}
	// Reported separately from the parse failure: there is no underlying
	// error to wrap when the number is merely out of range.
	if port < -1 || port > 65535 {
		return fmt.Errorf(`failed to parse valid port "%v": port out of range`, portStr)
	}

	c.advertiseHost, c.advertisePort = host, port
	return nil
}

// hostWithoutBrackets strips the square brackets surrounding a bare IPv6
// literal ("[::1]" -> "::1"), matching what net.SplitHostPort yields when
// a port is present. Other hosts are returned unchanged.
func hostWithoutBrackets(host string) string {
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		return host[1 : len(host)-1]
	}
	return host
}