open-vault/physical/consul.go
2016-04-25 20:10:55 -07:00

580 lines
14 KiB
Go

package physical
import (
"fmt"
"io/ioutil"
"log"
"net"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"crypto/tls"
"crypto/x509"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
)
const (
	// checkJitterFactor specifies the jitter factor used to stagger checks.
	checkJitterFactor = 16

	// checkMinBuffer provides a guarantee that a check will not be
	// executed too close to the TTL check timeout.
	checkMinBuffer = 100 * time.Millisecond

	// defaultCheckTimeout is the TTL check timeout used when the
	// check_timeout configuration parameter is not set.
	defaultCheckTimeout = 5 * time.Second

	// defaultCheckInterval specifies the default interval used to send
	// checks.
	defaultCheckInterval = 4 * time.Second

	// defaultServiceName is the default Consul service name used when
	// advertising a Vault instance.
	defaultServiceName = "vault"

	// registrationRetryInterval specifies the retry duration to use when
	// a registration to the Consul agent fails.
	registrationRetryInterval = 1 * time.Second
)
// ConsulBackend is a physical backend that stores data at a specific
// prefix within Consul. It is used for most production situations as
// it allows Vault to run on multiple machines in a highly-available manner.
type ConsulBackend struct {
	// path is the KV prefix prepended to every key handled by this backend.
	path string

	logger *log.Logger

	// client is the Consul API client; kv is the KV handle obtained from it.
	client *api.Client

	kv *api.KV

	// permitPool is acquired around each KV operation (Put/Get/Delete/List).
	permitPool *PermitPool

	// serviceLock is held by the Advertise*/RunServiceDiscovery methods and
	// by the periodic check goroutine while they touch the service state
	// below.
	serviceLock sync.RWMutex

	// service is the registration sent to the Consul agent. It stays nil
	// until RunServiceDiscovery builds it.
	service *api.AgentServiceRegistration

	// sealedCheck is the TTL check registered alongside the service.
	sealedCheck *api.AgentCheckRegistration

	// registrationLock is a 0/1 flag flipped with atomic compare-and-swap in
	// AdvertiseActive so that only one caller runs the registration retry
	// loop at a time.
	registrationLock int64

	// advertiseHost and advertisePort are populated by setAdvertiseAddr.
	advertiseHost string

	advertisePort int64

	// consulClientConf is the configuration the client was created from.
	consulClientConf *api.Config

	// serviceName is the Consul service name used in the registration.
	serviceName string

	// running is set to true once RunServiceDiscovery has registered the
	// service and reset to false after the shutdown deregistration.
	running bool

	// active and unsealed cache the last state passed to AdvertiseActive
	// and AdvertiseSealed respectively.
	active bool

	unsealed bool

	// disableRegistration turns off all service/check registration.
	disableRegistration bool

	// checkTimeout is the TTL of the sealed-status check; checkTimer fires
	// when the next TTL update should be pushed.
	checkTimeout time.Duration

	checkTimer *time.Timer
}
// newConsulBackend constructs a Consul backend from the given configuration
// map. Data is stored under the configured prefix in Consul's KV store.
//
// Recognized conf keys (all optional):
//   - path: KV prefix, default "vault/"; normalized to have a trailing but
//     no leading slash.
//   - disable_registration: boolean string; disables service registration.
//   - service: service name to advertise, default defaultServiceName.
//   - check_timeout: duration string for the TTL check, default
//     defaultCheckTimeout.
//   - address, scheme, token: passed through to the Consul client config.
//   - tls_*: consumed by setupTLSConfig when the scheme is "https".
//   - max_parallel: integer passed to NewPermitPool.
//
// An error is returned when a parameter cannot be parsed, the check timeout
// is too short, TLS setup fails, or the Consul client cannot be created.
func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, error) {
	// Get the path in Consul
	path, ok := conf["path"]
	if !ok {
		path = "vault/"
	}

	// Ensure path is suffixed but not prefixed
	if !strings.HasSuffix(path, "/") {
		logger.Printf("[WARN]: consul: appending trailing forward slash to path")
		path += "/"
	}
	if strings.HasPrefix(path, "/") {
		logger.Printf("[WARN]: consul: trimming path of its forward slash")
		path = strings.TrimPrefix(path, "/")
	}

	// Allow admins to disable consul integration
	disableReg, ok := conf["disable_registration"]
	var disableRegistration bool
	if ok && disableReg != "" {
		b, err := strconv.ParseBool(disableReg)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err)
		}
		disableRegistration = b
	}

	// Get the service name to advertise in Consul
	service, ok := conf["service"]
	if !ok {
		service = defaultServiceName
	}

	checkTimeout := defaultCheckTimeout
	checkTimeoutStr, ok := conf["check_timeout"]
	if ok {
		d, err := time.ParseDuration(checkTimeoutStr)
		if err != nil {
			// Wrap with the parameter name, like the other parameters do.
			return nil, errwrap.Wrapf("failed parsing check_timeout parameter: {{err}}", err)
		}

		// The shortest interval at which checks would be pushed once the
		// buffer and jitter are subtracted must itself leave room for the
		// buffer.
		minInterval, _ := lib.DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor)
		if minInterval < checkMinBuffer {
			return nil, fmt.Errorf("Consul check_timeout %v is too short: the resulting check interval %v must be at least %v", d, minInterval, checkMinBuffer)
		}

		checkTimeout = d
	}

	// Configure the client
	consulConf := api.DefaultConfig()

	if addr, ok := conf["address"]; ok {
		consulConf.Address = addr
	}
	if scheme, ok := conf["scheme"]; ok {
		consulConf.Scheme = scheme
	}
	if token, ok := conf["token"]; ok {
		consulConf.Token = token
	}

	if consulConf.Scheme == "https" {
		tlsClientConfig, err := setupTLSConfig(conf)
		if err != nil {
			return nil, err
		}

		transport := cleanhttp.DefaultPooledTransport()
		transport.MaxIdleConnsPerHost = 4
		transport.TLSClientConfig = tlsClientConfig
		consulConf.HttpClient.Transport = transport
	}

	client, err := api.NewClient(consulConf)
	if err != nil {
		return nil, errwrap.Wrapf("client setup failed: {{err}}", err)
	}

	maxParStr, ok := conf["max_parallel"]
	var maxParInt int
	if ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
		}
		logger.Printf("[DEBUG]: consul: max_parallel set to %d", maxParInt)
	}

	// Setup the backend
	c := &ConsulBackend{
		path:                path,
		logger:              logger,
		client:              client,
		kv:                  client.KV(),
		permitPool:          NewPermitPool(maxParInt),
		consulClientConf:    consulConf,
		serviceName:         service,
		checkTimeout:        checkTimeout,
		checkTimer:          time.NewTimer(checkTimeout),
		disableRegistration: disableRegistration,
	}
	return c, nil
}
// serviceTags builds the tag list attached to the Consul service: a single
// tag, "active" for the active instance and "standby" otherwise.
func serviceTags(active bool) []string {
	if active {
		return []string{"active"}
	}
	return []string{"standby"}
}
// AdvertiseActive caches the given active state and, unless registration
// is disabled, re-registers the service with the Consul agent so that its
// tags reflect that state.
//
// It returns nil when the service has not been set up yet, when another
// caller already owns the registration retry loop, or when registration
// succeeds. It returns the last registration error if the backend is no
// longer running while retrying.
func (c *ConsulBackend) AdvertiseActive(active bool) error {
	c.serviceLock.Lock()
	defer c.serviceLock.Unlock()

	// Vault is still bootstrapping
	if c.service == nil {
		return nil
	}

	// Save a cached copy of the active state: no way to query Core
	c.active = active

	// Ensure serial registration to the Consul agent. Allow for
	// concurrent calls to update active status while a single task
	// attempts, until successful, to update the Consul Agent.
	if !c.disableRegistration && atomic.CompareAndSwapInt64(&c.registrationLock, 0, 1) {
		// Hand the registration flag back when this call returns.
		defer atomic.CompareAndSwapInt64(&c.registrationLock, 1, 0)

		// Retry agent registration until successful
		for {
			// Re-read c.active on every attempt: a concurrent call may
			// have changed it while the lock was released below.
			c.service.Tags = serviceTags(c.active)
			agent := c.client.Agent()
			err := agent.ServiceRegister(c.service)
			if err == nil {
				// Success
				return nil
			}

			c.logger.Printf("[WARN] consul: service registration failed: %v", err)

			// Release the lock while sleeping so other callers are not
			// blocked, then re-acquire it; the deferred Unlock above
			// expects the lock to be held on return.
			c.serviceLock.Unlock()
			time.Sleep(registrationRetryInterval)
			c.serviceLock.Lock()

			if !c.running {
				// Shutting down
				return err
			}
		}
	}

	// Successful concurrent update to active state
	return nil
}
// AdvertiseSealed caches the sealed state and, once the service has been
// set up and registration is enabled, immediately pushes a TTL check so
// the Consul agent sees the new state. It always returns nil.
func (c *ConsulBackend) AdvertiseSealed(sealed bool) error {
	c.serviceLock.Lock()
	defer c.serviceLock.Unlock()

	c.unsealed = !sealed

	// A nil service means Vault is still bootstrapping; there is nothing
	// to update in that case, nor when registration is disabled.
	if c.service != nil && !c.disableRegistration {
		c.runCheck()
	}
	return nil
}
// RunServiceDiscovery registers this Vault instance and its sealed-status
// TTL check with the Consul agent, starts the periodic check runner, and
// starts a goroutine that deregisters the service once shutdownCh fires.
//
// advertiseAddr is parsed by setAdvertiseAddr to obtain the advertised host
// and port. The method is a no-op returning nil when registration is
// disabled. An error is returned when the address cannot be parsed or when
// either registration call fails.
func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string) (err error) {
	c.serviceLock.Lock()
	defer c.serviceLock.Unlock()

	if c.disableRegistration {
		return nil
	}

	if err := c.setAdvertiseAddr(advertiseAddr); err != nil {
		return err
	}

	serviceID := c.serviceID()

	c.service = &api.AgentServiceRegistration{
		ID:                serviceID,
		Name:              c.serviceName,
		Tags:              serviceTags(c.active),
		Port:              int(c.advertisePort),
		Address:           c.advertiseHost,
		EnableTagOverride: false,
	}

	checkStatus := api.HealthCritical
	if c.unsealed {
		checkStatus = api.HealthPassing
	}

	c.sealedCheck = &api.AgentCheckRegistration{
		ID:        c.checkID(),
		Name:      "Vault Sealed Status",
		Notes:     "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server",
		ServiceID: serviceID,
		AgentServiceCheck: api.AgentServiceCheck{
			TTL:    c.checkTimeout.String(),
			Status: checkStatus,
		},
	}

	agent := c.client.Agent()
	if err := agent.ServiceRegister(c.service); err != nil {
		return errwrap.Wrapf("service registration failed: {{err}}", err)
	}

	if err := agent.CheckRegister(c.sealedCheck); err != nil {
		return errwrap.Wrapf("service registration check registration failed: {{err}}", err)
	}

	go c.checkRunner(shutdownCh)
	c.running = true

	// Deregister upon shutdown
	go func() {
		<-shutdownCh
		c.logger.Printf("[INFO]: consul: Shutting down consul backend")

		if err := agent.ServiceDeregister(serviceID); err != nil {
			// Printf needs a formatting verb; the errwrap-style {{err}}
			// placeholder is not interpreted here.
			c.logger.Printf("[WARN]: consul: service deregistration failed: %v", err)
		}

		// running is read under serviceLock by AdvertiseActive, so the
		// write must be made under the same lock.
		c.serviceLock.Lock()
		c.running = false
		c.serviceLock.Unlock()
	}()

	return nil
}
// checkRunner loops until shutdownCh fires, pushing a TTL check each time
// the check timer expires. The timer is stopped when the loop exits.
func (c *ConsulBackend) checkRunner(shutdownCh ShutdownChannel) {
	defer c.checkTimer.Stop()

	// pushCheck performs one TTL update while holding the service lock, as
	// runCheck requires.
	pushCheck := func() {
		c.serviceLock.Lock()
		defer c.serviceLock.Unlock()
		c.runCheck()
	}

	for {
		select {
		case <-shutdownCh:
			return
		case <-c.checkTimer.C:
			// Run in its own goroutine so waiting for the lock never
			// delays noticing a shutdown.
			go pushCheck()
		}
	}
}
// runCheck immediately pushes a TTL check reflecting the cached sealed
// state and re-arms the check timer. Assumes c.serviceLock is held
// exclusively.
//
// A failure to update the TTL is logged rather than returned: the next
// timer tick retries the update.
func (c *ConsulBackend) runCheck() {
	// Reset timer before calling run check in order to not slide the
	// window of the next check.
	c.checkTimer.Reset(lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor))

	// Run a TTL check
	agent := c.client.Agent()
	var err error
	if c.unsealed {
		err = agent.PassTTL(c.checkID(), "Vault Unsealed")
	} else {
		err = agent.FailTTL(c.checkID(), "Vault Sealed")
	}
	if err != nil {
		c.logger.Printf("[WARN]: consul: TTL check update failed: %v", err)
	}
}
// checkID returns the fixed identifier of the sealed-status Consul check.
// Assume at least a read lock is held.
func (c *ConsulBackend) checkID() string {
	const sealedCheckID = "vault-sealed-check"
	return sealedCheckID
}
// serviceID builds the Consul ServiceID for this Vault instance in the form
// "<service name>:<advertise host>:<advertise port>". Assume at least a
// read lock is held.
func (c *ConsulBackend) serviceID() string {
	port := strconv.FormatInt(c.advertisePort, 10)
	return c.serviceName + ":" + c.advertiseHost + ":" + port
}
// setAdvertiseAddr parses addr as a URL and stores the host and port to
// advertise in c.advertiseHost and c.advertisePort.
//
// When the URL carries no explicit port, the port defaults to 80 for
// "http" and 443 for "https"; for "unix" the URL path is used as the host
// and the port is set to -1. Any other scheme without a port is an error.
// The receiver is only modified when the whole address is valid.
func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) {
	if addr == "" {
		return fmt.Errorf("advertise address must not be empty")
	}

	u, err := url.Parse(addr)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err)
	}

	host, portStr, err := net.SplitHostPort(u.Host)
	if err != nil {
		// SplitHostPort returns an empty host on failure, so the host has
		// to be recovered from the URL itself. Brackets around an IPv6
		// literal are dropped, matching what SplitHostPort yields when a
		// port is present.
		bareHost := strings.TrimSuffix(strings.TrimPrefix(u.Host, "["), "]")

		switch u.Scheme {
		case "http":
			host, portStr = bareHost, "80"
		case "https":
			host, portStr = bareHost, "443"
		case "unix":
			host, portStr = u.Path, "-1"
		default:
			return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, u.Host), err)
		}
	}

	port, err := strconv.ParseInt(portStr, 10, 0)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err)
	}
	// Range errors are reported separately so the message never embeds a
	// nil parse error. -1 is allowed because it marks a unix socket.
	if port < -1 || port > 65535 {
		return fmt.Errorf(`failed to parse valid port "%v": port out of range`, portStr)
	}

	c.advertiseHost = host
	c.advertisePort = port
	return nil
}
// setupTLSConfig builds the TLS client configuration for talking to Consul
// from the backend configuration map.
//
// Recognized conf keys:
//   - address: its host part becomes the TLS ServerName.
//   - tls_skip_verify: boolean string (as accepted by strconv.ParseBool);
//     certificate verification is skipped only when it parses to true. An
//     empty value is treated as unset.
//   - tls_cert_file / tls_key_file: client certificate pair, loaded only
//     when both are present.
//   - tls_ca_file: PEM file whose certificates become the root CA pool.
//
// An error is returned when tls_skip_verify is not a valid boolean or when
// any of the referenced files cannot be loaded or parsed.
func setupTLSConfig(conf map[string]string) (*tls.Config, error) {
	// Take the host part of the address. SplitHostPort understands
	// bracketed IPv6 literals, which a plain split on ":" does not. When
	// the address has no port, use it as-is minus any IPv6 brackets.
	addr := conf["address"]
	serverName, _, err := net.SplitHostPort(addr)
	if err != nil {
		serverName = strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
	}

	// Only skip verification when explicitly asked to: the mere presence of
	// the key (e.g. tls_skip_verify = "false") must not disable it.
	insecureSkipVerify := false
	if raw, ok := conf["tls_skip_verify"]; ok && raw != "" {
		insecureSkipVerify, err = strconv.ParseBool(raw)
		if err != nil {
			return nil, fmt.Errorf("failed parsing tls_skip_verify parameter: %v", err)
		}
	}

	tlsClientConfig := &tls.Config{
		InsecureSkipVerify: insecureSkipVerify,
		ServerName:         serverName,
	}

	certFile, okCert := conf["tls_cert_file"]
	keyFile, okKey := conf["tls_key_file"]
	if okCert && okKey {
		tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			return nil, fmt.Errorf("client tls setup failed: %v", err)
		}

		tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
	}

	if tlsCaFile, ok := conf["tls_ca_file"]; ok {
		caPool := x509.NewCertPool()

		data, err := ioutil.ReadFile(tlsCaFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read CA file: %v", err)
		}

		if !caPool.AppendCertsFromPEM(data) {
			return nil, fmt.Errorf("failed to parse CA certificate")
		}

		tlsClientConfig.RootCAs = caPool
	}

	return tlsClientConfig, nil
}
// Put writes the entry's value to Consul's KV store under the backend
// prefix, inserting or overwriting as needed. The call is timed under the
// "consul.put" metric and is bounded by the permit pool.
func (c *ConsulBackend) Put(entry *Entry) error {
	defer metrics.MeasureSince([]string{"consul", "put"}, time.Now())

	kvPair := &api.KVPair{Key: c.path + entry.Key, Value: entry.Value}

	c.permitPool.Acquire()
	defer c.permitPool.Release()

	if _, err := c.kv.Put(kvPair, nil); err != nil {
		return err
	}
	return nil
}
// Get reads the value stored for key under the backend prefix. It returns
// (nil, nil) when Consul has no such key. The call is timed under the
// "consul.get" metric and is bounded by the permit pool.
func (c *ConsulBackend) Get(key string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"consul", "get"}, time.Now())

	c.permitPool.Acquire()
	defer c.permitPool.Release()

	pair, _, err := c.kv.Get(c.path+key, nil)
	switch {
	case err != nil:
		return nil, err
	case pair == nil:
		// Key does not exist.
		return nil, nil
	}

	// The entry carries the caller's key, not the prefixed Consul key.
	return &Entry{Key: key, Value: pair.Value}, nil
}
// Delete removes the key under the backend prefix from Consul's KV store.
// The call is timed under the "consul.delete" metric and is bounded by the
// permit pool.
func (c *ConsulBackend) Delete(key string) error {
	defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now())

	c.permitPool.Acquire()
	defer c.permitPool.Release()

	if _, err := c.kv.Delete(c.path+key, nil); err != nil {
		return err
	}
	return nil
}
// List returns the keys found under the given prefix, up to the next "/"
// separator, with the scanned prefix stripped from each returned key. The
// call is timed under the "consul.list" metric and is bounded by the
// permit pool.
func (c *ConsulBackend) List(prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"consul", "list"}, time.Now())

	scan := c.path + prefix

	// A trailing "//" would defeat the TrimPrefix below. It shows up when
	// the root of a logical backend is listed via "/" instead of "", so
	// collapse it to a single slash.
	if strings.HasSuffix(scan, "//") {
		scan = strings.TrimSuffix(scan, "/")
	}

	c.permitPool.Acquire()
	defer c.permitPool.Release()

	keys, _, err := c.kv.Keys(scan, "/", nil)
	for i := range keys {
		keys[i] = strings.TrimPrefix(keys[i], scan)
	}
	return keys, err
}
// LockWith creates a lock for mutual exclusion on the given key (under the
// backend prefix), associating value with it. The lock is only created
// here, not acquired.
func (c *ConsulBackend) LockWith(key, value string) (Lock, error) {
	fullKey := c.path + key

	// Create the lock
	lock, err := c.client.LockOpts(&api.LockOptions{
		Key:            fullKey,
		Value:          []byte(value),
		SessionName:    "Vault Lock",
		MonitorRetries: 5,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create lock: %v", err)
	}

	return &ConsulLock{
		client: c.client,
		key:    fullKey,
		lock:   lock,
	}, nil
}
// DetectHostAddr asks the Consul agent about itself and returns the
// "Addr" field of its "Member" section. An error is returned when the
// agent cannot be queried or the field is not a string.
func (c *ConsulBackend) DetectHostAddr() (string, error) {
	self, err := c.client.Agent().Self()
	if err != nil {
		return "", err
	}

	if addr, ok := self["Member"]["Addr"].(string); ok {
		return addr, nil
	}
	return "", fmt.Errorf("Unable to convert an address to string")
}
// ConsulLock is used to provide the Lock interface backed by Consul
type ConsulLock struct {
	// client is used by Value to read the lock key from the KV store.
	client *api.Client

	// key is the full (prefixed) KV key the lock is held on.
	key string

	// lock is the underlying Consul lock that Lock and Unlock delegate to.
	lock *api.Lock
}
// Lock delegates to the underlying Consul lock, passing stopCh through and
// returning its results unchanged.
// NOTE(review): presumably the returned channel signals loss of the lock
// and stopCh aborts the acquisition; confirm against the Consul API docs.
func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	return c.lock.Lock(stopCh)
}
// Unlock delegates to the underlying Consul lock's Unlock and returns its
// error unchanged.
func (c *ConsulLock) Unlock() error {
	return c.lock.Unlock()
}
// Value reads the lock key from Consul's KV store. It reports whether a
// session is attached to the key (held) along with the stored value. A
// missing key yields (false, "", nil).
func (c *ConsulLock) Value() (bool, string, error) {
	pair, _, err := c.client.KV().Get(c.key, nil)
	if err != nil {
		return false, "", err
	}
	if pair == nil {
		return false, "", nil
	}

	// A non-empty session on the key means the lock is held.
	return pair.Session != "", string(pair.Value), nil
}