Observer pattern for service registration interface (#8123)
* use observer pattern for service discovery * update perf standby method * fix test * revert usersTags to being called serviceTags * use previous consul code * vault isnt a performance standby before starting * log err * changes from feedback * add Run method to interface * changes from feedback * fix core test * update example
This commit is contained in:
parent
36f0c05744
commit
759f9b38f7
|
@ -161,7 +161,7 @@ var (
|
|||
}
|
||||
|
||||
serviceRegistrations = map[string]sr.Factory{
|
||||
"consul": csr.NewConsulServiceRegistration,
|
||||
"consul": csr.NewServiceRegistration,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -919,11 +919,22 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
if config.Storage.Type == "raft" {
|
||||
// Do any custom configuration needed per backend
|
||||
switch config.Storage.Type {
|
||||
case "consul":
|
||||
if config.ServiceRegistration == nil {
|
||||
// If Consul is configured for storage and service registration is unconfigured,
|
||||
// use Consul for service registration without requiring additional configuration.
|
||||
// This maintains backward-compatibility.
|
||||
config.ServiceRegistration = &server.ServiceRegistration{
|
||||
Type: "consul",
|
||||
Config: config.Storage.Config,
|
||||
}
|
||||
}
|
||||
case "raft":
|
||||
if envCA := os.Getenv("VAULT_CLUSTER_ADDR"); envCA != "" {
|
||||
config.ClusterAddr = envCA
|
||||
}
|
||||
|
||||
if len(config.ClusterAddr) == 0 {
|
||||
c.UI.Error("Cluster address must be set when using raft storage")
|
||||
return 1
|
||||
|
@ -943,6 +954,9 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
return 1
|
||||
}
|
||||
|
||||
// Instantiate the wait group
|
||||
c.WaitGroup = &sync.WaitGroup{}
|
||||
|
||||
// Initialize the Service Discovery, if there is one
|
||||
var configSR sr.ServiceRegistration
|
||||
if config.ServiceRegistration != nil {
|
||||
|
@ -954,11 +968,25 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
|
||||
namedSDLogger := c.logger.Named("service_registration." + config.ServiceRegistration.Type)
|
||||
allLoggers = append(allLoggers, namedSDLogger)
|
||||
configSR, err = sdFactory(config.ServiceRegistration.Config, namedSDLogger)
|
||||
|
||||
// Since we haven't even begun starting Vault's core yet,
|
||||
// we know that Vault is in its pre-running state.
|
||||
state := sr.State{
|
||||
VaultVersion: version.GetVersion().VersionNumber(),
|
||||
IsInitialized: false,
|
||||
IsSealed: true,
|
||||
IsActive: false,
|
||||
IsPerformanceStandby: false,
|
||||
}
|
||||
configSR, err = sdFactory(config.ServiceRegistration.Config, namedSDLogger, state, config.Storage.RedirectAddr)
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing service_registration of type %s: %s", config.ServiceRegistration.Type, err))
|
||||
return 1
|
||||
}
|
||||
if err := configSR.Run(c.ShutdownCh, c.WaitGroup); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error running service_registration of type %s: %s", config.ServiceRegistration.Type, err))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
infoKeys := make([]string, 0, 10)
|
||||
|
@ -1514,26 +1542,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
|||
}
|
||||
}
|
||||
|
||||
// Perform service discovery registrations and initialization of
|
||||
// HTTP server after the verifyOnly check.
|
||||
|
||||
// Instantiate the wait group
|
||||
c.WaitGroup = &sync.WaitGroup{}
|
||||
|
||||
// If service discovery is available, run service discovery
|
||||
if disc := coreConfig.GetServiceRegistration(); disc != nil {
|
||||
activeFunc := func() bool {
|
||||
if isLeader, _, _, err := core.Leader(); err == nil {
|
||||
return isLeader
|
||||
}
|
||||
return false
|
||||
}
|
||||
if err := disc.RunServiceRegistration(c.WaitGroup, c.ShutdownCh, coreConfig.RedirectAddr, activeFunc, core.Sealed, core.PerfStandby); err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error initializing service discovery: %v", err))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// Perform initialization of HTTP server after the verifyOnly check.
|
||||
// If we're in Dev mode, then initialize the core
|
||||
if c.flagDev && !c.flagDevSkipInit {
|
||||
init, err := c.enableDev(core, coreConfig)
|
||||
|
|
|
@ -4,20 +4,21 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/hashicorp/errwrap"
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/tlsutil"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
sr "github.com/hashicorp/vault/serviceregistration"
|
||||
csr "github.com/hashicorp/vault/serviceregistration/consul"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -35,14 +36,12 @@ var _ physical.Backend = (*ConsulBackend)(nil)
|
|||
var _ physical.HABackend = (*ConsulBackend)(nil)
|
||||
var _ physical.Lock = (*ConsulLock)(nil)
|
||||
var _ physical.Transactional = (*ConsulBackend)(nil)
|
||||
var _ sr.ServiceRegistration = (*ConsulBackend)(nil)
|
||||
|
||||
// ConsulBackend is a physical backend that stores data at specific
|
||||
// prefix within Consul. It is used for most production situations as
|
||||
// it allows Vault to run on multiple machines in a highly-available manner.
|
||||
type ConsulBackend struct {
|
||||
*csr.ConsulServiceRegistration
|
||||
|
||||
client *api.Client
|
||||
path string
|
||||
kv *api.KV
|
||||
permitPool *physical.PermitPool
|
||||
|
@ -55,15 +54,6 @@ type ConsulBackend struct {
|
|||
// NewConsulBackend constructs a Consul backend using the given API client
|
||||
// and the prefix in the KV store.
|
||||
func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
||||
|
||||
// Create the ConsulServiceRegistration struct that we will embed in the
|
||||
// ConsulBackend
|
||||
sreg, err := csr.NewConsulServiceRegistration(conf, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csreg := sreg.(*csr.ConsulServiceRegistration)
|
||||
|
||||
// Get the path in Consul
|
||||
path, ok := conf["path"]
|
||||
if !ok {
|
||||
|
@ -112,7 +102,7 @@ func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backe
|
|||
maxParStr, ok := conf["max_parallel"]
|
||||
var maxParInt int
|
||||
if ok {
|
||||
maxParInt, err = strconv.Atoi(maxParStr)
|
||||
maxParInt, err := strconv.Atoi(maxParStr)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
|
||||
}
|
||||
|
@ -132,12 +122,68 @@ func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backe
|
|||
consistencyMode = consistencyModeDefault
|
||||
}
|
||||
|
||||
// Configure the client
|
||||
consulConf := api.DefaultConfig()
|
||||
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
|
||||
consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
|
||||
|
||||
if addr, ok := conf["address"]; ok {
|
||||
consulConf.Address = addr
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("config address set", "address", addr)
|
||||
}
|
||||
|
||||
// Copied from the Consul API module; set the Scheme based on
|
||||
// the protocol field if address looks ike a URL.
|
||||
// This can enable the TLS configuration below.
|
||||
parts := strings.SplitN(addr, "://", 2)
|
||||
if len(parts) == 2 {
|
||||
if parts[0] == "http" || parts[0] == "https" {
|
||||
consulConf.Scheme = parts[0]
|
||||
consulConf.Address = parts[1]
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("config address parsed", "scheme", parts[0])
|
||||
logger.Debug("config scheme parsed", "address", parts[1])
|
||||
}
|
||||
} // allow "unix:" or whatever else consul supports in the future
|
||||
}
|
||||
}
|
||||
if scheme, ok := conf["scheme"]; ok {
|
||||
consulConf.Scheme = scheme
|
||||
if logger.IsDebug() {
|
||||
logger.Debug("config scheme set", "scheme", scheme)
|
||||
}
|
||||
}
|
||||
if token, ok := conf["token"]; ok {
|
||||
consulConf.Token = token
|
||||
logger.Debug("config token set")
|
||||
}
|
||||
|
||||
if consulConf.Scheme == "https" {
|
||||
// Use the parsed Address instead of the raw conf['address']
|
||||
tlsClientConfig, err := tlsutil.SetupTLSConfig(conf, consulConf.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
consulConf.Transport.TLSClientConfig = tlsClientConfig
|
||||
if err := http2.ConfigureTransport(consulConf.Transport); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Debug("configured TLS")
|
||||
}
|
||||
|
||||
consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
|
||||
client, err := api.NewClient(consulConf)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("client setup failed: {{err}}", err)
|
||||
}
|
||||
|
||||
// Setup the backend
|
||||
c := &ConsulBackend{
|
||||
ConsulServiceRegistration: csreg,
|
||||
|
||||
path: path,
|
||||
kv: csreg.Client.KV(),
|
||||
client: client,
|
||||
kv: client.KV(),
|
||||
permitPool: physical.NewPermitPool(maxParInt),
|
||||
consistencyMode: consistencyMode,
|
||||
|
||||
|
@ -302,12 +348,12 @@ func (c *ConsulBackend) LockWith(key, value string) (physical.Lock, error) {
|
|||
SessionTTL: c.sessionTTL,
|
||||
LockWaitTime: c.lockWaitTime,
|
||||
}
|
||||
lock, err := c.Client.LockOpts(opts)
|
||||
lock, err := c.client.LockOpts(opts)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to create lock: {{err}}", err)
|
||||
}
|
||||
cl := &ConsulLock{
|
||||
client: c.Client,
|
||||
client: c.client,
|
||||
key: c.path + key,
|
||||
lock: lock,
|
||||
consistencyMode: c.consistencyMode,
|
||||
|
@ -323,7 +369,7 @@ func (c *ConsulBackend) HAEnabled() bool {
|
|||
|
||||
// DetectHostAddr is used to detect the host address by asking the Consul agent
|
||||
func (c *ConsulBackend) DetectHostAddr() (string, error) {
|
||||
agent := c.Client.Agent()
|
||||
agent := c.client.Agent()
|
||||
self, err := agent.Self()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -16,75 +15,8 @@ import (
|
|||
"github.com/hashicorp/vault/helper/testhelpers/consul"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
sr "github.com/hashicorp/vault/serviceregistration"
|
||||
)
|
||||
|
||||
type consulConf map[string]string
|
||||
|
||||
var (
|
||||
addrCount int = 0
|
||||
)
|
||||
|
||||
func testConsulBackend(t *testing.T) *ConsulBackend {
|
||||
return testConsulBackendConfig(t, &consulConf{})
|
||||
}
|
||||
|
||||
func testConsulBackendConfig(t *testing.T, conf *consulConf) *ConsulBackend {
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
be, err := NewConsulBackend(*conf, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected Consul to initialize: %v", err)
|
||||
}
|
||||
|
||||
c, ok := be.(*ConsulBackend)
|
||||
if !ok {
|
||||
t.Fatalf("Expected ConsulBackend")
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func testConsul_testConsulBackend(t *testing.T) {
|
||||
c := testConsulBackend(t)
|
||||
if c == nil {
|
||||
t.Fatalf("bad")
|
||||
}
|
||||
}
|
||||
|
||||
func testActiveFunc(activePct float64) sr.ActiveFunction {
|
||||
return func() bool {
|
||||
var active bool
|
||||
standbyProb := rand.Float64()
|
||||
if standbyProb > activePct {
|
||||
active = true
|
||||
}
|
||||
return active
|
||||
}
|
||||
}
|
||||
|
||||
func testSealedFunc(sealedPct float64) sr.SealedFunction {
|
||||
return func() bool {
|
||||
var sealed bool
|
||||
unsealedProb := rand.Float64()
|
||||
if unsealedProb > sealedPct {
|
||||
sealed = true
|
||||
}
|
||||
return sealed
|
||||
}
|
||||
}
|
||||
|
||||
func testPerformanceStandbyFunc(perfPct float64) sr.PerformanceStandbyFunction {
|
||||
return func() bool {
|
||||
var ps bool
|
||||
unsealedProb := rand.Float64()
|
||||
if unsealedProb > perfPct {
|
||||
ps = true
|
||||
}
|
||||
return ps
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsul_newConsulBackend(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -175,13 +107,6 @@ func TestConsul_newConsulBackend(t *testing.T) {
|
|||
disableReg: false,
|
||||
consistencyMode: "default",
|
||||
},
|
||||
{
|
||||
name: "check timeout too short",
|
||||
fail: true,
|
||||
consulConfig: map[string]string{
|
||||
"check_timeout": "99ms",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
@ -203,12 +128,6 @@ func TestConsul_newConsulBackend(t *testing.T) {
|
|||
t.Fatalf("Expected ConsulBackend: %s", test.name)
|
||||
}
|
||||
|
||||
var shutdownCh sr.ShutdownChannel
|
||||
waitGroup := &sync.WaitGroup{}
|
||||
if err := c.RunServiceRegistration(waitGroup, shutdownCh, test.redirectAddr, testActiveFunc(0.5), testSealedFunc(0.5), testPerformanceStandbyFunc(0.5)); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
||||
if test.path != c.path {
|
||||
t.Errorf("bad: %s %v != %v", test.name, test.path, c.path)
|
||||
}
|
||||
|
@ -219,7 +138,7 @@ func TestConsul_newConsulBackend(t *testing.T) {
|
|||
|
||||
// The configuration stored in the Consul "client" object is not exported, so
|
||||
// we either have to skip validating it, or add a method to export it, or use reflection.
|
||||
consulConfig := reflect.Indirect(reflect.ValueOf(c.Client)).FieldByName("config")
|
||||
consulConfig := reflect.Indirect(reflect.ValueOf(c.client)).FieldByName("config")
|
||||
consulConfigScheme := consulConfig.FieldByName("Scheme").String()
|
||||
consulConfigAddress := consulConfig.FieldByName("Address").String()
|
||||
|
||||
|
@ -238,22 +157,6 @@ func TestConsul_newConsulBackend(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConsul_NotifyActiveStateChange(t *testing.T) {
|
||||
c := testConsulBackend(t)
|
||||
|
||||
if err := c.NotifyActiveStateChange(); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsul_NotifySealedStateChange(t *testing.T) {
|
||||
c := testConsulBackend(t)
|
||||
|
||||
if err := c.NotifySealedStateChange(); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsulBackend(t *testing.T) {
|
||||
consulToken := os.Getenv("CONSUL_HTTP_TOKEN")
|
||||
addr := os.Getenv("CONSUL_HTTP_ADDR")
|
||||
|
|
|
@ -5,7 +5,12 @@ import (
|
|||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
)
|
||||
|
||||
|
@ -110,3 +115,72 @@ func ClientTLSConfig(caCert []byte, clientCert []byte, clientKey []byte) (*tls.C
|
|||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func SetupTLSConfig(conf map[string]string, address string) (*tls.Config, error) {
|
||||
serverName, _, err := net.SplitHostPort(address)
|
||||
switch {
|
||||
case err == nil:
|
||||
case strings.Contains(err.Error(), "missing port"):
|
||||
serverName = conf["address"]
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insecureSkipVerify := false
|
||||
tlsSkipVerify := conf["tls_skip_verify"]
|
||||
|
||||
if tlsSkipVerify != "" {
|
||||
b, err := parseutil.ParseBool(tlsSkipVerify)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed parsing tls_skip_verify parameter: {{err}}", err)
|
||||
}
|
||||
insecureSkipVerify = b
|
||||
}
|
||||
|
||||
tlsMinVersionStr, ok := conf["tls_min_version"]
|
||||
if !ok {
|
||||
// Set the default value
|
||||
tlsMinVersionStr = "tls12"
|
||||
}
|
||||
|
||||
tlsMinVersion, ok := TLSLookup[tlsMinVersionStr]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version'")
|
||||
}
|
||||
|
||||
tlsClientConfig := &tls.Config{
|
||||
MinVersion: tlsMinVersion,
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
_, okCert := conf["tls_cert_file"]
|
||||
_, okKey := conf["tls_key_file"]
|
||||
|
||||
if okCert && okKey {
|
||||
tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"])
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("client tls setup failed: {{err}}", err)
|
||||
}
|
||||
|
||||
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
|
||||
} else if okCert || okKey {
|
||||
return nil, fmt.Errorf("both tls_cert_file and tls_key_file must be provided")
|
||||
}
|
||||
|
||||
if tlsCaFile, ok := conf["tls_ca_file"]; ok {
|
||||
caPool := x509.NewCertPool()
|
||||
|
||||
data, err := ioutil.ReadFile(tlsCaFile)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to read CA file: {{err}}", err)
|
||||
}
|
||||
|
||||
if !caPool.AppendCertsFromPEM(data) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate")
|
||||
}
|
||||
|
||||
tlsClientConfig.RootCAs = caPool
|
||||
}
|
||||
return tlsClientConfig, nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -52,17 +49,13 @@ const (
|
|||
reconcileTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type notifyEvent struct{}
|
||||
|
||||
var _ sr.ServiceRegistration = (*ConsulServiceRegistration)(nil)
|
||||
|
||||
var (
|
||||
hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
)
|
||||
|
||||
// ConsulServiceRegistration is a ServiceRegistration that advertises the state of
|
||||
// serviceRegistration is a ServiceRegistration that advertises the state of
|
||||
// Vault to Consul.
|
||||
type ConsulServiceRegistration struct {
|
||||
type serviceRegistration struct {
|
||||
Client *api.Client
|
||||
|
||||
logger log.Logger
|
||||
|
@ -74,14 +67,18 @@ type ConsulServiceRegistration struct {
|
|||
serviceAddress *string
|
||||
disableRegistration bool
|
||||
checkTimeout time.Duration
|
||||
redirectAddr string
|
||||
|
||||
notifyActiveCh chan notifyEvent
|
||||
notifySealedCh chan notifyEvent
|
||||
notifyPerfStandbyCh chan notifyEvent
|
||||
notifyActiveCh chan bool
|
||||
notifySealedCh chan bool
|
||||
notifyPerfStandbyCh chan bool
|
||||
|
||||
stateLock sync.RWMutex
|
||||
isActive, isSealed, isPerfStandby bool
|
||||
}
|
||||
|
||||
// NewConsulServiceRegistration constructs a Consul-based ServiceRegistration.
|
||||
func NewConsulServiceRegistration(conf map[string]string, logger log.Logger) (sr.ServiceRegistration, error) {
|
||||
func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.State, redirectAddr string) (sr.ServiceRegistration, error) {
|
||||
|
||||
// Allow admins to disable consul integration
|
||||
disableReg, ok := conf["disable_registration"]
|
||||
|
@ -183,7 +180,7 @@ func NewConsulServiceRegistration(conf map[string]string, logger log.Logger) (sr
|
|||
|
||||
if consulConf.Scheme == "https" {
|
||||
// Use the parsed Address instead of the raw conf['address']
|
||||
tlsClientConfig, err := setupTLSConfig(conf, consulConf.Address)
|
||||
tlsClientConfig, err := tlsutil.SetupTLSConfig(conf, consulConf.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -202,7 +199,7 @@ func NewConsulServiceRegistration(conf map[string]string, logger log.Logger) (sr
|
|||
}
|
||||
|
||||
// Setup the backend
|
||||
c := &ConsulServiceRegistration{
|
||||
c := &serviceRegistration{
|
||||
Client: client,
|
||||
|
||||
logger: logger,
|
||||
|
@ -211,87 +208,33 @@ func NewConsulServiceRegistration(conf map[string]string, logger log.Logger) (sr
|
|||
serviceAddress: serviceAddr,
|
||||
checkTimeout: checkTimeout,
|
||||
disableRegistration: disableRegistration,
|
||||
redirectAddr: redirectAddr,
|
||||
|
||||
notifyActiveCh: make(chan notifyEvent),
|
||||
notifySealedCh: make(chan notifyEvent),
|
||||
notifyPerfStandbyCh: make(chan notifyEvent),
|
||||
notifyActiveCh: make(chan bool),
|
||||
notifySealedCh: make(chan bool),
|
||||
notifyPerfStandbyCh: make(chan bool),
|
||||
|
||||
isActive: state.IsActive,
|
||||
isSealed: state.IsSealed,
|
||||
isPerfStandby: state.IsPerformanceStandby,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func setupTLSConfig(conf map[string]string, address string) (*tls.Config, error) {
|
||||
serverName, _, err := net.SplitHostPort(address)
|
||||
switch {
|
||||
case err == nil:
|
||||
case strings.Contains(err.Error(), "missing port"):
|
||||
serverName = conf["address"]
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insecureSkipVerify := false
|
||||
tlsSkipVerify, ok := conf["tls_skip_verify"]
|
||||
|
||||
if ok && tlsSkipVerify != "" {
|
||||
b, err := parseutil.ParseBool(tlsSkipVerify)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed parsing tls_skip_verify parameter: {{err}}", err)
|
||||
func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error {
|
||||
go func() {
|
||||
if err := c.runServiceRegistration(wait, shutdownCh, c.redirectAddr); err != nil {
|
||||
if c.logger.IsError() {
|
||||
c.logger.Error(fmt.Sprintf("error running service registration: %s", err))
|
||||
}
|
||||
}
|
||||
insecureSkipVerify = b
|
||||
}
|
||||
|
||||
tlsMinVersionStr, ok := conf["tls_min_version"]
|
||||
if !ok {
|
||||
// Set the default value
|
||||
tlsMinVersionStr = "tls12"
|
||||
}
|
||||
|
||||
tlsMinVersion, ok := tlsutil.TLSLookup[tlsMinVersionStr]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version'")
|
||||
}
|
||||
|
||||
tlsClientConfig := &tls.Config{
|
||||
MinVersion: tlsMinVersion,
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
_, okCert := conf["tls_cert_file"]
|
||||
_, okKey := conf["tls_key_file"]
|
||||
|
||||
if okCert && okKey {
|
||||
tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"])
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("client tls setup failed: {{err}}", err)
|
||||
}
|
||||
|
||||
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
|
||||
} else if okCert || okKey {
|
||||
return nil, fmt.Errorf("both tls_cert_file and tls_key_file must be provided")
|
||||
}
|
||||
|
||||
if tlsCaFile, ok := conf["tls_ca_file"]; ok {
|
||||
caPool := x509.NewCertPool()
|
||||
|
||||
data, err := ioutil.ReadFile(tlsCaFile)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to read CA file: {{err}}", err)
|
||||
}
|
||||
|
||||
if !caPool.AppendCertsFromPEM(data) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate")
|
||||
}
|
||||
|
||||
tlsClientConfig.RootCAs = caPool
|
||||
}
|
||||
|
||||
return tlsClientConfig, nil
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConsulServiceRegistration) NotifyActiveStateChange() error {
|
||||
func (c *serviceRegistration) NotifyActiveStateChange(isActive bool) error {
|
||||
select {
|
||||
case c.notifyActiveCh <- notifyEvent{}:
|
||||
case c.notifyActiveCh <- isActive:
|
||||
default:
|
||||
// NOTE: If this occurs Vault's active status could be out of
|
||||
// sync with Consul until reconcileTimer expires.
|
||||
|
@ -301,9 +244,9 @@ func (c *ConsulServiceRegistration) NotifyActiveStateChange() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *ConsulServiceRegistration) NotifyPerformanceStandbyStateChange() error {
|
||||
func (c *serviceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool) error {
|
||||
select {
|
||||
case c.notifyPerfStandbyCh <- notifyEvent{}:
|
||||
case c.notifyPerfStandbyCh <- isStandby:
|
||||
default:
|
||||
// NOTE: If this occurs Vault's active status could be out of
|
||||
// sync with Consul until reconcileTimer expires.
|
||||
|
@ -313,9 +256,9 @@ func (c *ConsulServiceRegistration) NotifyPerformanceStandbyStateChange() error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *ConsulServiceRegistration) NotifySealedStateChange() error {
|
||||
func (c *serviceRegistration) NotifySealedStateChange(isSealed bool) error {
|
||||
select {
|
||||
case c.notifySealedCh <- notifyEvent{}:
|
||||
case c.notifySealedCh <- isSealed:
|
||||
default:
|
||||
// NOTE: If this occurs Vault's sealed status could be out of
|
||||
// sync with Consul until checkTimer expires.
|
||||
|
@ -325,11 +268,18 @@ func (c *ConsulServiceRegistration) NotifySealedStateChange() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *ConsulServiceRegistration) checkDuration() time.Duration {
|
||||
func (c *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) error {
|
||||
// This is not implemented because to date, Consul service registration has
|
||||
// never reported out on whether Vault was initialized. We may someday want to
|
||||
// do this, but it has not yet been requested.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serviceRegistration) checkDuration() time.Duration {
|
||||
return durationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor)
|
||||
}
|
||||
|
||||
func (c *ConsulServiceRegistration) RunServiceRegistration(waitGroup *sync.WaitGroup, shutdownCh sr.ShutdownChannel, redirectAddr string, activeFunc sr.ActiveFunction, sealedFunc sr.SealedFunction, perfStandbyFunc sr.PerformanceStandbyFunction) (err error) {
|
||||
func (c *serviceRegistration) runServiceRegistration(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}, redirectAddr string) (err error) {
|
||||
if err := c.setRedirectAddr(redirectAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -337,12 +287,12 @@ func (c *ConsulServiceRegistration) RunServiceRegistration(waitGroup *sync.WaitG
|
|||
// 'server' command will wait for the below goroutine to complete
|
||||
waitGroup.Add(1)
|
||||
|
||||
go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr, activeFunc, sealedFunc, perfStandbyFunc)
|
||||
go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConsulServiceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh sr.ShutdownChannel, redirectAddr string, activeFunc sr.ActiveFunction, sealedFunc sr.SealedFunction, perfStandbyFunc sr.PerformanceStandbyFunction) {
|
||||
func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}, redirectAddr string) {
|
||||
// This defer statement should be executed last. So push it first.
|
||||
defer waitGroup.Done()
|
||||
|
||||
|
@ -372,13 +322,25 @@ func (c *ConsulServiceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, s
|
|||
|
||||
for !shutdown {
|
||||
select {
|
||||
case <-c.notifyActiveCh:
|
||||
case isActive := <-c.notifyActiveCh:
|
||||
c.stateLock.Lock()
|
||||
c.isActive = isActive
|
||||
c.stateLock.Unlock()
|
||||
|
||||
// Run reconcile immediately upon active state change notification
|
||||
reconcileTimer.Reset(0)
|
||||
case <-c.notifySealedCh:
|
||||
case isSealed := <-c.notifySealedCh:
|
||||
c.stateLock.Lock()
|
||||
c.isSealed = isSealed
|
||||
c.stateLock.Unlock()
|
||||
|
||||
// Run check timer immediately upon a seal state change notification
|
||||
checkTimer.Reset(0)
|
||||
case <-c.notifyPerfStandbyCh:
|
||||
case isStandby := <-c.notifyPerfStandbyCh:
|
||||
c.stateLock.Lock()
|
||||
c.isPerfStandby = isStandby
|
||||
c.stateLock.Unlock()
|
||||
|
||||
// Run check timer immediately upon a seal state change notification
|
||||
checkTimer.Reset(0)
|
||||
case <-reconcileTimer.C:
|
||||
|
@ -392,7 +354,7 @@ func (c *ConsulServiceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, s
|
|||
go func() {
|
||||
defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0)
|
||||
for !shutdown {
|
||||
serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc, perfStandbyFunc)
|
||||
serviceID, err := c.reconcileConsul(registeredServiceID)
|
||||
if err != nil {
|
||||
if c.logger.IsWarn() {
|
||||
c.logger.Warn("reconcile unable to talk with Consul backend", "error", err)
|
||||
|
@ -418,7 +380,9 @@ func (c *ConsulServiceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, s
|
|||
go func() {
|
||||
defer atomic.CompareAndSwapInt32(checkLock, 1, 0)
|
||||
for !shutdown {
|
||||
sealed := sealedFunc()
|
||||
c.stateLock.RLock()
|
||||
sealed := c.isSealed
|
||||
c.stateLock.RUnlock()
|
||||
if err := c.runCheck(sealed); err != nil {
|
||||
if c.logger.IsWarn() {
|
||||
c.logger.Warn("check unable to talk with Consul backend", "error", err)
|
||||
|
@ -447,26 +411,28 @@ func (c *ConsulServiceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, s
|
|||
|
||||
// checkID returns the ID used for a Consul Check. Assume at least a read
|
||||
// lock is held.
|
||||
func (c *ConsulServiceRegistration) checkID() string {
|
||||
func (c *serviceRegistration) checkID() string {
|
||||
return fmt.Sprintf("%s:vault-sealed-check", c.serviceID())
|
||||
}
|
||||
|
||||
// serviceID returns the Vault ServiceID for use in Consul. Assume at least
|
||||
// a read lock is held.
|
||||
func (c *ConsulServiceRegistration) serviceID() string {
|
||||
func (c *serviceRegistration) serviceID() string {
|
||||
return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort)
|
||||
}
|
||||
|
||||
// reconcileConsul queries the state of Vault Core and Consul and fixes up
|
||||
// Consul's state according to what's in Vault. reconcileConsul is called
|
||||
// without any locks held and can be run concurrently, therefore no changes
|
||||
// to ConsulServiceRegistration can be made in this method (i.e. wtb const receiver for
|
||||
// to serviceRegistration can be made in this method (i.e. wtb const receiver for
|
||||
// compiler enforced safety).
|
||||
func (c *ConsulServiceRegistration) reconcileConsul(registeredServiceID string, activeFunc sr.ActiveFunction, sealedFunc sr.SealedFunction, perfStandbyFunc sr.PerformanceStandbyFunction) (serviceID string, err error) {
|
||||
func (c *serviceRegistration) reconcileConsul(registeredServiceID string) (serviceID string, err error) {
|
||||
// Query vault Core for its current state
|
||||
active := activeFunc()
|
||||
sealed := sealedFunc()
|
||||
perfStandby := perfStandbyFunc()
|
||||
c.stateLock.RLock()
|
||||
active := c.isActive
|
||||
sealed := c.isSealed
|
||||
perfStandby := c.isPerfStandby
|
||||
c.stateLock.RUnlock()
|
||||
|
||||
agent := c.Client.Agent()
|
||||
catalog := c.Client.Catalog()
|
||||
|
@ -550,7 +516,7 @@ func (c *ConsulServiceRegistration) reconcileConsul(registeredServiceID string,
|
|||
}
|
||||
|
||||
// runCheck immediately pushes a TTL check.
|
||||
func (c *ConsulServiceRegistration) runCheck(sealed bool) error {
|
||||
func (c *serviceRegistration) runCheck(sealed bool) error {
|
||||
// Run a TTL check
|
||||
agent := c.Client.Agent()
|
||||
if !sealed {
|
||||
|
@ -561,7 +527,7 @@ func (c *ConsulServiceRegistration) runCheck(sealed bool) error {
|
|||
}
|
||||
|
||||
// fetchServiceTags returns all of the relevant tags for Consul.
|
||||
func (c *ConsulServiceRegistration) fetchServiceTags(active bool, perfStandby bool) []string {
|
||||
func (c *serviceRegistration) fetchServiceTags(active bool, perfStandby bool) []string {
|
||||
activeTag := "standby"
|
||||
if active {
|
||||
activeTag = "active"
|
||||
|
@ -576,7 +542,7 @@ func (c *ConsulServiceRegistration) fetchServiceTags(active bool, perfStandby bo
|
|||
return result
|
||||
}
|
||||
|
||||
func (c *ConsulServiceRegistration) setRedirectAddr(addr string) (err error) {
|
||||
func (c *serviceRegistration) setRedirectAddr(addr string) (err error) {
|
||||
if addr == "" {
|
||||
return fmt.Errorf("redirect address must not be empty")
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package consul
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
@ -22,59 +21,33 @@ import (
|
|||
|
||||
type consulConf map[string]string
|
||||
|
||||
func testConsulServiceRegistration(t *testing.T) *ConsulServiceRegistration {
|
||||
func testConsulServiceRegistration(t *testing.T) *serviceRegistration {
|
||||
return testConsulServiceRegistrationConfig(t, &consulConf{})
|
||||
}
|
||||
|
||||
func testConsulServiceRegistrationConfig(t *testing.T, conf *consulConf) *ConsulServiceRegistration {
|
||||
func testConsulServiceRegistrationConfig(t *testing.T, conf *consulConf) *serviceRegistration {
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
be, err := NewConsulServiceRegistration(*conf, logger)
|
||||
shutdownCh := make(chan struct{})
|
||||
defer func() {
|
||||
close(shutdownCh)
|
||||
}()
|
||||
be, err := NewServiceRegistration(*conf, logger, sr.State{}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected Consul to initialize: %v", err)
|
||||
}
|
||||
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c, ok := be.(*ConsulServiceRegistration)
|
||||
c, ok := be.(*serviceRegistration)
|
||||
if !ok {
|
||||
t.Fatalf("Expected ConsulServiceRegistration")
|
||||
t.Fatalf("Expected serviceRegistration")
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func testActiveFunc(activePct float64) sr.ActiveFunction {
|
||||
return func() bool {
|
||||
var active bool
|
||||
standbyProb := rand.Float64()
|
||||
if standbyProb > activePct {
|
||||
active = true
|
||||
}
|
||||
return active
|
||||
}
|
||||
}
|
||||
|
||||
func testSealedFunc(sealedPct float64) sr.SealedFunction {
|
||||
return func() bool {
|
||||
var sealed bool
|
||||
unsealedProb := rand.Float64()
|
||||
if unsealedProb > sealedPct {
|
||||
sealed = true
|
||||
}
|
||||
return sealed
|
||||
}
|
||||
}
|
||||
|
||||
func testPerformanceStandbyFunc(perfPct float64) sr.PerformanceStandbyFunction {
|
||||
return func() bool {
|
||||
var ps bool
|
||||
unsealedProb := rand.Float64()
|
||||
if unsealedProb > perfPct {
|
||||
ps = true
|
||||
}
|
||||
return ps
|
||||
}
|
||||
}
|
||||
|
||||
// TestConsul_ServiceRegistration tests whether consul ServiceRegistration works
|
||||
func TestConsul_ServiceRegistration(t *testing.T) {
|
||||
|
||||
|
@ -110,15 +83,24 @@ func TestConsul_ServiceRegistration(t *testing.T) {
|
|||
return nil
|
||||
}
|
||||
|
||||
shutdownCh := make(chan struct{})
|
||||
defer func() {
|
||||
close(shutdownCh)
|
||||
}()
|
||||
const redirectAddr = "http://127.0.0.1:8200"
|
||||
|
||||
// Create a ServiceRegistration that points to our consul instance
|
||||
logger := logging.NewVaultLogger(log.Trace)
|
||||
sd, err := NewConsulServiceRegistration(map[string]string{
|
||||
sd, err := NewServiceRegistration(map[string]string{
|
||||
"address": addr,
|
||||
"token": token,
|
||||
}, logger)
|
||||
}, logger, sr.State{}, redirectAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := sd.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create the core
|
||||
inm, err := inmem.NewInmemHA(nil, logger)
|
||||
|
@ -129,7 +111,6 @@ func TestConsul_ServiceRegistration(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
const redirectAddr = "http://127.0.0.1:8200"
|
||||
core, err := vault.NewCore(&vault.CoreConfig{
|
||||
ServiceRegistration: sd,
|
||||
Physical: inm,
|
||||
|
@ -152,21 +133,6 @@ func TestConsul_ServiceRegistration(t *testing.T) {
|
|||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
// Run service discovery on the core
|
||||
wg := &sync.WaitGroup{}
|
||||
var shutdown chan struct{}
|
||||
activeFunc := func() bool {
|
||||
if isLeader, _, _, err := core.Leader(); err == nil {
|
||||
return isLeader
|
||||
}
|
||||
return false
|
||||
}
|
||||
err = sd.RunServiceRegistration(
|
||||
wg, shutdown, redirectAddr, activeFunc, core.Sealed, core.PerfStandby)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Vault should soon be registered with Consul in standby mode
|
||||
services = transitionFrom(t, map[string][]string{
|
||||
"consul": []string{},
|
||||
|
@ -220,12 +186,20 @@ func TestConsul_ServiceTags(t *testing.T) {
|
|||
}
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
be, err := NewConsulServiceRegistration(consulConfig, logger)
|
||||
shutdownCh := make(chan struct{})
|
||||
defer func() {
|
||||
close(shutdownCh)
|
||||
}()
|
||||
|
||||
be, err := NewServiceRegistration(consulConfig, logger, sr.State{}, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c, ok := be.(*ConsulServiceRegistration)
|
||||
c, ok := be.(*serviceRegistration)
|
||||
if !ok {
|
||||
t.Fatalf("failed to create physical Consul backend")
|
||||
}
|
||||
|
@ -273,14 +247,18 @@ func TestConsul_ServiceAddress(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, test := range tests {
|
||||
shutdownCh := make(chan struct{})
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
be, err := NewConsulServiceRegistration(test.consulConfig, logger)
|
||||
be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("expected Consul to initialize: %v", err)
|
||||
}
|
||||
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c, ok := be.(*ConsulServiceRegistration)
|
||||
c, ok := be.(*serviceRegistration)
|
||||
if !ok {
|
||||
t.Fatalf("Expected ConsulServiceRegistration")
|
||||
}
|
||||
|
@ -294,6 +272,7 @@ func TestConsul_ServiceAddress(t *testing.T) {
|
|||
t.Fatalf("did not expect service address to be nil")
|
||||
}
|
||||
}
|
||||
close(shutdownCh)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -397,9 +376,10 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, test := range tests {
|
||||
shutdownCh := make(chan struct{})
|
||||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
be, err := NewConsulServiceRegistration(test.consulConfig, logger)
|
||||
be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}, "")
|
||||
if test.fail {
|
||||
if err == nil {
|
||||
t.Fatalf(`Expected config "%s" to fail`, test.name)
|
||||
|
@ -409,8 +389,11 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) {
|
|||
} else if !test.fail && err != nil {
|
||||
t.Fatalf("Expected config %s to not fail: %v", test.name, err)
|
||||
}
|
||||
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c, ok := be.(*ConsulServiceRegistration)
|
||||
c, ok := be.(*serviceRegistration)
|
||||
if !ok {
|
||||
t.Fatalf("Expected ConsulServiceRegistration: %s", test.name)
|
||||
}
|
||||
|
@ -423,12 +406,6 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var shutdownCh sr.ShutdownChannel
|
||||
waitGroup := &sync.WaitGroup{}
|
||||
if err := c.RunServiceRegistration(waitGroup, shutdownCh, test.redirectAddr, testActiveFunc(0.5), testSealedFunc(0.5), testPerformanceStandbyFunc(0.5)); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
||||
if test.checkTimeout != c.checkTimeout {
|
||||
t.Errorf("bad: %v != %v", test.checkTimeout, c.checkTimeout)
|
||||
}
|
||||
|
@ -455,6 +432,7 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) {
|
|||
// if test.max_parallel != cap(c.permitPool) {
|
||||
// t.Errorf("bad: %v != %v", test.max_parallel, cap(c.permitPool))
|
||||
// }
|
||||
close(shutdownCh)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -602,9 +580,10 @@ func TestConsul_serviceID(t *testing.T) {
|
|||
logger := logging.NewVaultLogger(log.Debug)
|
||||
|
||||
for _, test := range tests {
|
||||
be, err := NewConsulServiceRegistration(consulConf{
|
||||
shutdownCh := make(chan struct{})
|
||||
be, err := NewServiceRegistration(consulConf{
|
||||
"service": test.serviceName,
|
||||
}, logger)
|
||||
}, logger, sr.State{}, "")
|
||||
if !test.valid {
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error initializing for name %q", test.serviceName)
|
||||
|
@ -614,10 +593,13 @@ func TestConsul_serviceID(t *testing.T) {
|
|||
if test.valid && err != nil {
|
||||
t.Fatalf("expected Consul to initialize: %v", err)
|
||||
}
|
||||
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c, ok := be.(*ConsulServiceRegistration)
|
||||
c, ok := be.(*serviceRegistration)
|
||||
if !ok {
|
||||
t.Fatalf("Expected ConsulServiceRegistration")
|
||||
t.Fatalf("Expected serviceRegistration")
|
||||
}
|
||||
|
||||
if err := c.setRedirectAddr(test.redirectAddr); err != nil {
|
||||
|
|
|
@ -1,40 +1,93 @@
|
|||
package serviceregistration
|
||||
|
||||
/*
|
||||
ServiceRegistration is an interface that can be fulfilled to use
|
||||
varying applications for service discovery, regardless of the physical
|
||||
back-end used.
|
||||
|
||||
Service registration implements notifications for changes in _dynamic_
|
||||
properties regarding Vault's health. Vault's version is the only static
|
||||
property given in state for now, but if there's a need for more in the future,
|
||||
we could add them on.
|
||||
*/
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
VaultVersion string
|
||||
IsInitialized, IsSealed, IsActive, IsPerformanceStandby bool
|
||||
}
|
||||
|
||||
// Factory is the factory function to create a ServiceRegistration.
|
||||
type Factory func(config map[string]string, logger log.Logger) (ServiceRegistration, error)
|
||||
// The config is the key/value pairs set _inside_ the service registration config stanza.
|
||||
// The state is the initial state.
|
||||
// The redirectAddr is Vault core's RedirectAddr.
|
||||
type Factory func(config map[string]string, logger log.Logger, state State, redirectAddr string) (ServiceRegistration, error)
|
||||
|
||||
// ServiceRegistration is an interface that advertises the state of Vault to a
|
||||
// service discovery network.
|
||||
type ServiceRegistration interface {
|
||||
// Run provides a shutdownCh and wait WaitGroup. The shutdownCh
|
||||
// is for monitoring when a shutdown occurs and initiating any actions needed
|
||||
// to leave service registration in a final state. When finished, signalling
|
||||
// that with wait means that Vault will wait until complete.
|
||||
// Run is called just after Factory instantiation so can be relied upon
|
||||
// for controlling shutdown behavior.
|
||||
// Here is an example of its intended use:
|
||||
// func Run(shutdownCh <-chan struct{}, wait sync.WaitGroup) error {
|
||||
//
|
||||
// // Run shutdown code in a goroutine so Run doesn't block.
|
||||
// go func(){
|
||||
// // Since we are going to want Vault to wait to shutdown
|
||||
// // until after we do cleanup...
|
||||
// wait.Add(1)
|
||||
//
|
||||
// // Ensure that when this ends, no matter how it ends,
|
||||
// // we don't cause Vault to hang on shutdown.
|
||||
// defer wait.Done()
|
||||
//
|
||||
// // Now wait until we're actually receiving a shutdown.
|
||||
// <-shutdownCh
|
||||
//
|
||||
// // Now do whatever we need to clean up.
|
||||
// if err := someService.SetFinalState(); err != nil {
|
||||
// // Log it at error level.
|
||||
// }
|
||||
// }()
|
||||
// return nil
|
||||
// }
|
||||
Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error
|
||||
|
||||
// NotifyActiveStateChange is used by Core to notify that this Vault
|
||||
// instance has changed its status to active or standby.
|
||||
NotifyActiveStateChange() error
|
||||
// instance has changed its status on whether it's active or is
|
||||
// a standby.
|
||||
// If errors are returned, Vault only logs a warning, so it is
|
||||
// the implementation's responsibility to retry updating state
|
||||
// in the face of errors.
|
||||
NotifyActiveStateChange(isActive bool) error
|
||||
|
||||
// NotifySealedStateChange is used by Core to notify that Vault has changed
|
||||
// its Sealed status to sealed or unsealed.
|
||||
NotifySealedStateChange() error
|
||||
// If errors are returned, Vault only logs a warning, so it is
|
||||
// the implementation's responsibility to retry updating state
|
||||
// in the face of errors.
|
||||
NotifySealedStateChange(isSealed bool) error
|
||||
|
||||
// NotifyPerformanceStandbyStateChange is used by Core to notify that this
|
||||
// Vault instance has changed it status to performance standby or standby.
|
||||
NotifyPerformanceStandbyStateChange() error
|
||||
// Vault instance has changed its performance standby status.
|
||||
// If errors are returned, Vault only logs a warning, so it is
|
||||
// the implementation's responsibility to retry updating state
|
||||
// in the face of errors.
|
||||
NotifyPerformanceStandbyStateChange(isStandby bool) error
|
||||
|
||||
// Run executes any background service discovery tasks until the
|
||||
// shutdown channel is closed.
|
||||
RunServiceRegistration(
|
||||
waitGroup *sync.WaitGroup, shutdownCh ShutdownChannel, redirectAddr string,
|
||||
activeFunc ActiveFunction, sealedFunc SealedFunction, perfStandbyFunc PerformanceStandbyFunction) error
|
||||
// NotifyInitializedStateChange is used by Core to notify that the core is
|
||||
// initialized.
|
||||
// If errors are returned, Vault only logs a warning, so it is
|
||||
// the implementation's responsibility to retry updating state
|
||||
// in the face of errors.
|
||||
NotifyInitializedStateChange(isInitialized bool) error
|
||||
}
|
||||
|
||||
// Callback signatures for RunServiceRegistration
|
||||
type ActiveFunction func() bool
|
||||
type SealedFunction func() bool
|
||||
type PerformanceStandbyFunction func() bool
|
||||
|
||||
// ShutdownChannel is the shutdown signal for RunServiceRegistration
|
||||
type ShutdownChannel chan struct{}
|
||||
|
|
|
@ -1418,7 +1418,7 @@ func (c *Core) unsealInternal(ctx context.Context, masterKey []byte) (bool, erro
|
|||
}
|
||||
|
||||
if c.serviceRegistration != nil {
|
||||
if err := c.serviceRegistration.NotifySealedStateChange(); err != nil {
|
||||
if err := c.serviceRegistration.NotifySealedStateChange(false); err != nil {
|
||||
if c.logger.IsWarn() {
|
||||
c.logger.Warn("failed to notify unsealed status", "error", err)
|
||||
}
|
||||
|
@ -1719,7 +1719,7 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, shutdownRaft b
|
|||
}
|
||||
|
||||
if c.serviceRegistration != nil {
|
||||
if err := c.serviceRegistration.NotifySealedStateChange(); err != nil {
|
||||
if err := c.serviceRegistration.NotifySealedStateChange(true); err != nil {
|
||||
if c.logger.IsWarn() {
|
||||
c.logger.Warn("failed to notify sealed status", "error", err)
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
"github.com/hashicorp/vault/sdk/physical/inmem"
|
||||
sr "github.com/hashicorp/vault/serviceregistration"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -2454,30 +2453,32 @@ type mockServiceRegistration struct {
|
|||
notifyActiveCount int
|
||||
notifySealedCount int
|
||||
notifyPerfCount int
|
||||
notifyInitCount int
|
||||
runDiscoveryCount int
|
||||
}
|
||||
|
||||
func (m *mockServiceRegistration) NotifyActiveStateChange() error {
|
||||
func (m *mockServiceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error {
|
||||
m.runDiscoveryCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceRegistration) NotifyActiveStateChange(isActive bool) error {
|
||||
m.notifyActiveCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceRegistration) NotifySealedStateChange() error {
|
||||
func (m *mockServiceRegistration) NotifySealedStateChange(isSealed bool) error {
|
||||
m.notifySealedCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceRegistration) NotifyPerformanceStandbyStateChange() error {
|
||||
func (m *mockServiceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool) error {
|
||||
m.notifyPerfCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockServiceRegistration) RunServiceRegistration(
|
||||
_ *sync.WaitGroup, _ sr.ShutdownChannel, _ string,
|
||||
_ sr.ActiveFunction, _ sr.SealedFunction,
|
||||
_ sr.PerformanceStandbyFunction) error {
|
||||
|
||||
m.runDiscoveryCount++
|
||||
func (m *mockServiceRegistration) NotifyInitializedStateChange(isInitialized bool) error {
|
||||
m.notifyInitCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -2514,21 +2515,6 @@ func TestCore_ServiceRegistration(t *testing.T) {
|
|||
t.Fatal(diff)
|
||||
}
|
||||
|
||||
// Run service discovery on the core
|
||||
wg := &sync.WaitGroup{}
|
||||
var shutdown chan struct{}
|
||||
activeFunc := func() bool {
|
||||
if isLeader, _, _, err := core.Leader(); err == nil {
|
||||
return isLeader
|
||||
}
|
||||
return false
|
||||
}
|
||||
err = sr.RunServiceRegistration(
|
||||
wg, shutdown, redirectAddr, activeFunc, core.Sealed, core.PerfStandby)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Vault should be registered
|
||||
if diff := deep.Equal(sr, &mockServiceRegistration{
|
||||
runDiscoveryCount: 1,
|
||||
|
@ -2555,6 +2541,7 @@ func TestCore_ServiceRegistration(t *testing.T) {
|
|||
runDiscoveryCount: 1,
|
||||
notifyActiveCount: 1,
|
||||
notifySealedCount: 1,
|
||||
notifyInitCount: 1,
|
||||
}); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
|
|
@ -925,7 +925,7 @@ func (c *Core) advertiseLeader(ctx context.Context, uuid string, leaderLostCh <-
|
|||
}
|
||||
|
||||
if c.serviceRegistration != nil {
|
||||
if err := c.serviceRegistration.NotifyActiveStateChange(); err != nil {
|
||||
if err := c.serviceRegistration.NotifyActiveStateChange(true); err != nil {
|
||||
if c.logger.IsWarn() {
|
||||
c.logger.Warn("failed to notify active status", "error", err)
|
||||
}
|
||||
|
@ -960,7 +960,7 @@ func (c *Core) clearLeader(uuid string) error {
|
|||
|
||||
// Advertise ourselves as a standby
|
||||
if c.serviceRegistration != nil {
|
||||
if err := c.serviceRegistration.NotifyActiveStateChange(); err != nil {
|
||||
if err := c.serviceRegistration.NotifyActiveStateChange(false); err != nil {
|
||||
if c.logger.IsWarn() {
|
||||
c.logger.Warn("failed to notify standby status", "error", err)
|
||||
}
|
||||
|
|
|
@ -396,6 +396,14 @@ func (c *Core) Initialize(ctx context.Context, initParams *InitParams) (*InitRes
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if c.serviceRegistration != nil {
|
||||
if err := c.serviceRegistration.NotifyInitializedStateChange(true); err != nil {
|
||||
if c.logger.IsWarn() {
|
||||
c.logger.Warn("notification of initialization failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,12 @@ import (
|
|||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
)
|
||||
|
||||
|
@ -110,3 +115,72 @@ func ClientTLSConfig(caCert []byte, clientCert []byte, clientKey []byte) (*tls.C
|
|||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func SetupTLSConfig(conf map[string]string, address string) (*tls.Config, error) {
|
||||
serverName, _, err := net.SplitHostPort(address)
|
||||
switch {
|
||||
case err == nil:
|
||||
case strings.Contains(err.Error(), "missing port"):
|
||||
serverName = conf["address"]
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insecureSkipVerify := false
|
||||
tlsSkipVerify, ok := conf["tls_skip_verify"]
|
||||
|
||||
if ok && tlsSkipVerify != "" {
|
||||
b, err := parseutil.ParseBool(tlsSkipVerify)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed parsing tls_skip_verify parameter: {{err}}", err)
|
||||
}
|
||||
insecureSkipVerify = b
|
||||
}
|
||||
|
||||
tlsMinVersionStr, ok := conf["tls_min_version"]
|
||||
if !ok {
|
||||
// Set the default value
|
||||
tlsMinVersionStr = "tls12"
|
||||
}
|
||||
|
||||
tlsMinVersion, ok := TLSLookup[tlsMinVersionStr]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version'")
|
||||
}
|
||||
|
||||
tlsClientConfig := &tls.Config{
|
||||
MinVersion: tlsMinVersion,
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
_, okCert := conf["tls_cert_file"]
|
||||
_, okKey := conf["tls_key_file"]
|
||||
|
||||
if okCert && okKey {
|
||||
tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"])
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("client tls setup failed: {{err}}", err)
|
||||
}
|
||||
|
||||
tlsClientConfig.Certificates = []tls.Certificate{tlsCert}
|
||||
} else if okCert || okKey {
|
||||
return nil, fmt.Errorf("both tls_cert_file and tls_key_file must be provided")
|
||||
}
|
||||
|
||||
if tlsCaFile, ok := conf["tls_ca_file"]; ok {
|
||||
caPool := x509.NewCertPool()
|
||||
|
||||
data, err := ioutil.ReadFile(tlsCaFile)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to read CA file: {{err}}", err)
|
||||
}
|
||||
|
||||
if !caPool.AppendCertsFromPEM(data) {
|
||||
return nil, fmt.Errorf("failed to parse CA certificate")
|
||||
}
|
||||
|
||||
tlsClientConfig.RootCAs = caPool
|
||||
}
|
||||
return tlsClientConfig, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue