Rejig where the reload functions live
This commit is contained in:
parent
4a505bfa3e
commit
85315ff188
|
@ -5,7 +5,6 @@ import (
|
|||
|
||||
auditFile "github.com/hashicorp/vault/builtin/audit/file"
|
||||
auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog"
|
||||
"github.com/hashicorp/vault/command/server"
|
||||
"github.com/hashicorp/vault/version"
|
||||
|
||||
credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
|
||||
|
@ -87,9 +86,8 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
|
|||
"ssh": ssh.Factory,
|
||||
"rabbitmq": rabbitmq.Factory,
|
||||
},
|
||||
ShutdownCh: command.MakeShutdownCh(),
|
||||
SighupCh: command.MakeSighupCh(),
|
||||
ReloadFuncs: map[string][]server.ReloadFunc{},
|
||||
ShutdownCh: command.MakeShutdownCh(),
|
||||
SighupCh: command.MakeSighupCh(),
|
||||
}, nil
|
||||
},
|
||||
|
||||
|
|
|
@ -54,7 +54,8 @@ type ServerCommand struct {
|
|||
|
||||
logger log.Logger
|
||||
|
||||
ReloadFuncs map[string][]server.ReloadFunc
|
||||
reloadFuncsLock *sync.RWMutex
|
||||
reloadFuncs *map[string][]vault.ReloadFunc
|
||||
}
|
||||
|
||||
func (c *ServerCommand) Run(args []string) int {
|
||||
|
@ -338,6 +339,10 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
}
|
||||
}
|
||||
|
||||
// Copy the reload funcs pointers back
|
||||
c.reloadFuncs = coreConfig.ReloadFuncs
|
||||
c.reloadFuncsLock = coreConfig.ReloadFuncsLock
|
||||
|
||||
// Compile server information for output later
|
||||
info["backend"] = config.Backend.Type
|
||||
info["log level"] = logLevel
|
||||
|
@ -374,6 +379,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
clusterAddrs := []*net.TCPAddr{}
|
||||
|
||||
// Initialize the listeners
|
||||
c.reloadFuncsLock.Lock()
|
||||
lns := make([]net.Listener, 0, len(config.Listeners))
|
||||
for i, lnConfig := range config.Listeners {
|
||||
if lnConfig.Type == "atlas" {
|
||||
|
@ -396,9 +402,9 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
lns = append(lns, ln)
|
||||
|
||||
if reloadFunc != nil {
|
||||
relSlice := c.ReloadFuncs["listener|"+lnConfig.Type]
|
||||
relSlice := (*c.reloadFuncs)["listener|"+lnConfig.Type]
|
||||
relSlice = append(relSlice, reloadFunc)
|
||||
c.ReloadFuncs["listener|"+lnConfig.Type] = relSlice
|
||||
(*c.reloadFuncs)["listener|"+lnConfig.Type] = relSlice
|
||||
}
|
||||
|
||||
if !disableClustering && lnConfig.Type == "tcp" {
|
||||
|
@ -440,6 +446,7 @@ func (c *ServerCommand) Run(args []string) int {
|
|||
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", "))
|
||||
|
||||
}
|
||||
c.reloadFuncsLock.Unlock()
|
||||
if !disableClustering {
|
||||
if c.logger.IsTrace() {
|
||||
c.logger.Trace("cluster listener addresses synthesized", "cluster_addresses", clusterAddrs)
|
||||
|
@ -855,11 +862,14 @@ func (c *ServerCommand) Reload(configPath []string) error {
|
|||
return retErr
|
||||
}
|
||||
|
||||
c.reloadFuncsLock.RLock()
|
||||
defer c.reloadFuncsLock.RUnlock()
|
||||
|
||||
var reloadErrors *multierror.Error
|
||||
// Call reload on the listeners. This will call each listener with each
|
||||
// config block, but they verify the address.
|
||||
for _, lnConfig := range config.Listeners {
|
||||
for _, relFunc := range c.ReloadFuncs["listener|"+lnConfig.Type] {
|
||||
for _, relFunc := range (*c.reloadFuncs)["listener|"+lnConfig.Type] {
|
||||
if err := relFunc(lnConfig.Config); err != nil {
|
||||
retErr := fmt.Errorf("Error encountered reloading configuration: %s", err)
|
||||
reloadErrors = multierror.Append(retErr)
|
||||
|
|
|
@ -17,9 +17,6 @@ import (
|
|||
"github.com/hashicorp/hcl/hcl/ast"
|
||||
)
|
||||
|
||||
// ReloadFunc are functions that are called when a reload is requested.
|
||||
type ReloadFunc func(map[string]string) error
|
||||
|
||||
// Config is the configuration for the vault server.
|
||||
type Config struct {
|
||||
Listeners []*Listener `hcl:"-"`
|
||||
|
|
|
@ -12,10 +12,11 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
// ListenerFactory is the factory function to create a listener.
|
||||
type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, ReloadFunc, error)
|
||||
type ListenerFactory func(map[string]string, io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error)
|
||||
|
||||
// BuiltinListeners is the list of built-in listener types.
|
||||
var BuiltinListeners = map[string]ListenerFactory{
|
||||
|
@ -25,7 +26,7 @@ var BuiltinListeners = map[string]ListenerFactory{
|
|||
|
||||
// NewListener creates a new listener of the given type with the given
|
||||
// configuration. The type is looked up in the BuiltinListeners map.
|
||||
func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
func NewListener(t string, config map[string]string, logger io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
|
||||
f, ok := BuiltinListeners[t]
|
||||
if !ok {
|
||||
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
|
||||
|
@ -37,7 +38,7 @@ func NewListener(t string, config map[string]string, logger io.Writer) (net.List
|
|||
func listenerWrapTLS(
|
||||
ln net.Listener,
|
||||
props map[string]string,
|
||||
config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
config map[string]string) (net.Listener, map[string]string, vault.ReloadFunc, error) {
|
||||
props["tls"] = "disabled"
|
||||
|
||||
if v, ok := config["tls_disable"]; ok {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/hashicorp/scada-client/scada"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/hashicorp/vault/version"
|
||||
)
|
||||
|
||||
|
@ -26,7 +27,7 @@ func (s *SCADAListener) Addr() net.Addr {
|
|||
return s.ln.Addr()
|
||||
}
|
||||
|
||||
func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
func atlasListenerFactory(config map[string]string, logger io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
|
||||
scadaConfig := &scada.Config{
|
||||
Service: "vault",
|
||||
Version: version.GetVersion().VersionNumber(),
|
||||
|
|
|
@ -4,9 +4,11 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, ReloadFunc, error) {
|
||||
func tcpListenerFactory(config map[string]string, _ io.Writer) (net.Listener, map[string]string, vault.ReloadFunc, error) {
|
||||
addr, ok := config["address"]
|
||||
if !ok {
|
||||
addr = "127.0.0.1:8200"
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/command/server"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
@ -183,9 +182,8 @@ func TestServer_ReloadListener(t *testing.T) {
|
|||
Meta: meta.Meta{
|
||||
Ui: ui,
|
||||
},
|
||||
ShutdownCh: MakeShutdownCh(),
|
||||
SighupCh: MakeSighupCh(),
|
||||
ReloadFuncs: map[string][]server.ReloadFunc{},
|
||||
ShutdownCh: MakeShutdownCh(),
|
||||
SighupCh: MakeSighupCh(),
|
||||
}
|
||||
|
||||
finished := false
|
||||
|
|
|
@ -89,6 +89,9 @@ var (
|
|||
manualStepDownSleepPeriod = 10 * time.Second
|
||||
)
|
||||
|
||||
// ReloadFunc are functions that are called when a reload is requested.
|
||||
type ReloadFunc func(map[string]string) error
|
||||
|
||||
// NonFatalError is an error that can be returned during NewCore that should be
|
||||
// displayed but not cause a program exit
|
||||
type NonFatalError struct {
|
||||
|
@ -242,6 +245,12 @@ type Core struct {
|
|||
// cachingDisabled indicates whether caches are disabled
|
||||
cachingDisabled bool
|
||||
|
||||
// reloadFuncs is a map containing reload functions
|
||||
reloadFuncs map[string][]ReloadFunc
|
||||
|
||||
// reloadFuncsLock controlls access to the funcs
|
||||
reloadFuncsLock sync.RWMutex
|
||||
|
||||
//
|
||||
// Cluster information
|
||||
//
|
||||
|
@ -322,6 +331,9 @@ type CoreConfig struct {
|
|||
MaxLeaseTTL time.Duration `json:"max_lease_ttl" structs:"max_lease_ttl" mapstructure:"max_lease_ttl"`
|
||||
|
||||
ClusterName string `json:"cluster_name" structs:"cluster_name" mapstructure:"cluster_name"`
|
||||
|
||||
ReloadFuncs *map[string][]ReloadFunc
|
||||
ReloadFuncsLock *sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCore is used to construct a new core
|
||||
|
@ -415,6 +427,14 @@ func NewCore(conf *CoreConfig) (*Core, error) {
|
|||
c.ha = conf.HAPhysical
|
||||
}
|
||||
|
||||
// We create the funcs here, then populate the given config with it so that
|
||||
// the caller can share state
|
||||
conf.ReloadFuncsLock = &c.reloadFuncsLock
|
||||
c.reloadFuncsLock.Lock()
|
||||
c.reloadFuncs = make(map[string][]ReloadFunc)
|
||||
c.reloadFuncsLock.Unlock()
|
||||
conf.ReloadFuncs = &c.reloadFuncs
|
||||
|
||||
// Setup the backends
|
||||
logicalBackends := make(map[string]logical.Factory)
|
||||
for k, f := range conf.LogicalBackends {
|
||||
|
|
Loading…
Reference in a new issue