Rejig where the reload functions live

This commit is contained in:
Jeff Mitchell 2016-09-30 00:06:40 -04:00
parent 4a505bfa3e
commit 85315ff188
8 changed files with 47 additions and 20 deletions

View file

@ -5,7 +5,6 @@ import (
auditFile "github.com/hashicorp/vault/builtin/audit/file"
auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/version"
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
@ -87,9 +86,8 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
"ssh": ssh.Factory,
"rabbitmq": rabbitmq.Factory,
},
ShutdownCh: command.MakeShutdownCh(),
SighupCh: command.MakeSighupCh(),
ReloadFuncs: map[string][]server.ReloadFunc{},
ShutdownCh: command.MakeShutdownCh(),
SighupCh: command.MakeSighupCh(),
}, nil
},

View file

@ -54,7 +54,8 @@ type ServerCommand struct {
logger log.Logger
ReloadFuncs map[string][]server.ReloadFunc
reloadFuncsLock *sync.RWMutex
reloadFuncs *map[string][]vault.ReloadFunc
}
func (c *ServerCommand) Run(args []string) int {
@ -338,6 +339,10 @@ func (c *ServerCommand) Run(args []string) int {
}
}
// Copy the reload funcs pointers back
c.reloadFuncs = coreConfig.ReloadFuncs
c.reloadFuncsLock = coreConfig.ReloadFuncsLock
// Compile server information for output later
info["backend"] = config.Backend.Type
info["log level"] = logLevel
@ -374,6 +379,7 @@ func (c *ServerCommand) Run(args []string) int {
clusterAddrs := []*net.TCPAddr{}
// Initialize the listeners
c.reloadFuncsLock.Lock()
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
if lnConfig.Type == "atlas" {
@ -396,9 +402,9 @@ func (c *ServerCommand) Run(args []string) int {
lns = append(lns, ln)
if reloadFunc != nil {
relSlice := c.ReloadFuncs["listener|"+lnConfig.Type]
relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type]
relSlice = append(relSlice, reloadFunc)
c.ReloadFuncs["listener|"+lnConfig.Type] = relSlice
(*c.reloadFuncs)["listener|"+lnConfig.Type] = relSlice
}
if !disableClustering && lnConfig.Type == "tcp" {
@ -440,6 +446,7 @@ func (c *ServerCommand) Run(args []string) int {
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", "))
}
c.reloadFuncsLock.Unlock()
if !disableClustering {
if c.logger.IsTrace() {
c.logger.Trace("cluster listener addresses synthesized", "cluster_addresses", clusterAddrs)
@ -855,11 +862,14 @@ func (c *ServerCommand) Reload(configPath []string) error {
return retErr
}
c.reloadFuncsLock.RLock()
defer c.reloadFuncsLock.RUnlock()
var reloadErrors *multierror.Error
// Call reload on the listeners. This will call each listener with each
// config block, but they verify the address.
for _, lnConfig := range config.Listeners {
for _, relFunc := range c.ReloadFuncs["listener|"+lnConfig.Type] {
for _, relFunc := range (*c.reloadFuncs)["listener|"+lnConfig.Type] {
if err := relFunc(lnConfig.Config); err != nil {
retErr := fmt.Errorf("Error encountered reloading configuration: %s", err)
reloadErrors = multierror.Append(retErr)

View file

@ -17,9 +17,6 @@ import (
"github.com/hashicorp/hcl/hcl/ast"
)
// ReloadFunc are functions that are called when a reload is requested.
type ReloadFunc func(map[string]string) error
// Config is the configuration for the vault server.
type Config struct {
Listeners []*Listener `hcl:"-"`

View file

@ -12,10 +12,11 @@ import (
"sync"
"github.com/hashicorp/vault/helper/tlsutil"
"github.com/hashicorp/vault/vault"
)
// ListenerFactory is the factory function to create a listener.
type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, ReloadFunc, error)
type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error)
// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
@ -25,7 +26,7 @@ var BuiltinListeners = map[string]ListenerFactory{
// NewListener creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
f, ok := BuiltinListeners[t]
if !ok {
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
@ -37,7 +38,7 @@ func NewListener(t string, config map[string]string, logger io.Writer) (net.List
func listenerWrapTLS(
ln net.Listener,
props map[string]string,
config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
props["tls"] = "disabled"
if v, ok := config["tls_disable"]; ok {

View file

@ -5,6 +5,7 @@ import (
"net"
"github.com/hashicorp/scada-client/scada"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/version"
)
@ -26,7 +27,7 @@ func (s *SCADAListener) Addr() net.Addr {
return s.ln.Addr()
}
func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
scadaConfig := &scada.Config{
Service: "vault",
Version: version.GetVersion().VersionNumber(),

View file

@ -4,9 +4,11 @@ import (
"io"
"net"
"time"
"github.com/hashicorp/vault/vault"
)
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
addr, ok := config["address"]
if !ok {
addr = "127.0.0.1:8200"

View file

@ -14,7 +14,6 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/meta"
"github.com/mitchellh/cli"
)
@ -183,9 +182,8 @@ func TestServer_ReloadListener(t *testing.T) {
Meta: meta.Meta{
Ui: ui,
},
ShutdownCh: MakeShutdownCh(),
SighupCh: MakeSighupCh(),
ReloadFuncs: map[string][]server.ReloadFunc{},
ShutdownCh: MakeShutdownCh(),
SighupCh: MakeSighupCh(),
}
finished := false

View file

@ -89,6 +89,9 @@ var (
manualStepDownSleepPeriod = 10 * time.Second
)
// ReloadFunc are functions that are called when a reload is requested.
type ReloadFunc func(map[string]string) error
// NonFatalError is an error that can be returned during NewCore that should be
// displayed but not cause a program exit
type NonFatalError struct {
@ -242,6 +245,12 @@ type Core struct {
// cachingDisabled indicates whether caches are disabled
cachingDisabled bool
// reloadFuncs is a map containing reload functions
reloadFuncs map[string][]ReloadFunc
// reloadFuncsLock controlls access to the funcs
reloadFuncsLock sync.RWMutex
//
// Cluster information
//
@ -322,6 +331,9 @@ type CoreConfig struct {
MaxLeaseTTL time.Duration `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"`
ClusterName string `json:"cluster_name" structs:"cluster_name" mapstructure:"cluster_name"`
ReloadFuncs *map[string][]ReloadFunc
ReloadFuncsLock *sync.RWMutex
}
// NewCore is used to construct a new core
@ -415,6 +427,14 @@ func NewCore(conf *CoreConfig) (*Core, error) {
c.ha = conf.HAPhysical
}
// We create the funcs here, then populate the given config with it so that
// the caller can share state
conf.ReloadFuncsLock = &c.reloadFuncsLock
c.reloadFuncsLock.Lock()
c.reloadFuncs = make(map[string][]ReloadFunc)
c.reloadFuncsLock.Unlock()
conf.ReloadFuncs = &c.reloadFuncs
// Setup the backends
logicalBackends := make(map[string]logical.Factory)
for k, f := range conf.LogicalBackends {