diff --git a/audit/audit.go b/audit/audit.go index 3f1d5edfc..dffa8eee5 100644 --- a/audit/audit.go +++ b/audit/audit.go @@ -26,6 +26,9 @@ type Backend interface { // so that a caller can determine if a value in the audit log matches // an expected plaintext value GetHash(string) string + + // Reload is called on SIGHUP for supporting backends. + Reload() error } type BackendConfig struct { diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index 80e654630..afdc8e019 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -91,8 +91,8 @@ type Backend struct { formatter audit.AuditFormatter formatConfig audit.FormatterConfig - once sync.Once - f *os.File + fileLock sync.RWMutex + f *os.File } func (b *Backend) GetHash(data string) string { @@ -100,6 +100,9 @@ func (b *Backend) GetHash(data string) string { } func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { + b.fileLock.Lock() + defer b.fileLock.Unlock() + if err := b.open(); err != nil { return err } @@ -112,6 +115,10 @@ func (b *Backend) LogResponse( req *logical.Request, resp *logical.Response, err error) error { + + b.fileLock.Lock() + defer b.fileLock.Unlock() + if err := b.open(); err != nil { return err } @@ -119,6 +126,7 @@ func (b *Backend) LogResponse( return b.formatter.FormatResponse(b.f, b.formatConfig, auth, req, resp, err) } +// The file lock must be held before calling this func (b *Backend) open() error { if b.f != nil { return nil @@ -135,3 +143,22 @@ func (b *Backend) open() error { return nil } + +func (b *Backend) Reload() error { + b.fileLock.Lock() + defer b.fileLock.Unlock() + + if b.f == nil { + return b.open() + } + + err := b.f.Close() + // Set to nil here so that even if we error out, on the next access open() + // will be tried + b.f = nil + if err != nil { + return err + } + + return b.open() +} diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index bde7ca764..2d594af8e 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -116,3 +116,7 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, _, err = b.logger.Write(buf.Bytes()) return err } + +func (b *Backend) Reload() error { + return nil +} diff --git a/command/server.go b/command/server.go index 3c35d40e6..2663a05c6 100644 --- a/command/server.go +++ b/command/server.go @@ -69,7 +69,7 @@ func (c *ServerCommand) Run(args []string) int { flags.StringVar(&logLevel, "log-level", "info", "") flags.BoolVar(&verifyOnly, "verify-only", false, "") flags.BoolVar(&devHA, "dev-ha", false, "") - flags.Usage = func() { c.Ui.Error(c.Help()) } + flags.Usage = func() { c.Ui.Output(c.Help()) } flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config") if err := flags.Parse(args); err != nil { return 1 @@ -128,11 +128,11 @@ func (c *ServerCommand) Run(args []string) int { if !dev { switch { case len(configPath) == 0: - c.Ui.Error("At least one config path must be specified with -config") + c.Ui.Output("At least one config path must be specified with -config") flags.Usage() return 1 case devRootTokenID != "": - c.Ui.Error("Root token ID can only be specified with -dev") + c.Ui.Output("Root token ID can only be specified with -dev") flags.Usage() return 1 } @@ -149,7 +149,7 @@ func (c *ServerCommand) Run(args []string) int { for _, path := range configPath { current, err := server.LoadConfig(path, c.logger) if err != nil { - c.Ui.Error(fmt.Sprintf( + c.Ui.Output(fmt.Sprintf( "Error loading configuration from %s: %s", path, err)) return 1 } @@ -163,13 +163,13 @@ func (c *ServerCommand) Run(args []string) int { // Ensure at least one config was found. if config == nil { - c.Ui.Error("No configuration files found.") + c.Ui.Output("No configuration files found.") return 1 } // Ensure that a backend is provided if config.Backend == nil { - c.Ui.Error("A physical backend must be specified") + c.Ui.Output("A physical backend must be specified") return 1 } @@ -183,7 +183,7 @@ func (c *ServerCommand) Run(args []string) int { } if err := c.setupTelemetry(config); err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing telemetry: %s", err)) + c.Ui.Output(fmt.Sprintf("Error initializing telemetry: %s", err)) return 1 } @@ -191,7 +191,7 @@ func (c *ServerCommand) Run(args []string) int { backend, err := physical.NewBackend( config.Backend.Type, c.logger, config.Backend.Config) if err != nil { - c.Ui.Error(fmt.Sprintf( + c.Ui.Output(fmt.Sprintf( "Error initializing backend of type %s: %s", config.Backend.Type, err)) return 1 @@ -206,7 +206,7 @@ func (c *ServerCommand) Run(args []string) int { defer func() { err = seal.Finalize() if err != nil { - c.Ui.Error(fmt.Sprintf("Error finalizing seals: %v", err)) + c.Ui.Output(fmt.Sprintf("Error finalizing seals: %v", err)) } }() @@ -235,19 +235,19 @@ func (c *ServerCommand) Run(args []string) int { habackend, err := physical.NewBackend( config.HABackend.Type, c.logger, config.HABackend.Config) if err != nil { - c.Ui.Error(fmt.Sprintf( + c.Ui.Output(fmt.Sprintf( "Error initializing backend of type %s: %s", config.HABackend.Type, err)) return 1 } if coreConfig.HAPhysical, ok = habackend.(physical.HABackend); !ok { - c.Ui.Error("Specified HA backend does not support HA") + c.Ui.Output("Specified HA backend does not support HA") return 1 } if !coreConfig.HAPhysical.HAEnabled() { - c.Ui.Error("Specified HA backend has HA support disabled; please consult documentation") + c.Ui.Output("Specified HA backend has HA support disabled; please consult documentation") return 1 } @@ -282,9 +282,9 @@ func (c *ServerCommand) Run(args []string) int { if ok && coreConfig.RedirectAddr == "" { redirect, err := c.detectRedirect(detect, config) if err != nil { - c.Ui.Error(fmt.Sprintf("Error detecting redirect address: %s", err)) + c.Ui.Output(fmt.Sprintf("Error detecting redirect address: %s", err)) } else if redirect == "" { - c.Ui.Error("Failed to detect redirect address.") + c.Ui.Output("Failed to detect redirect address.") } else { coreConfig.RedirectAddr = redirect } @@ -299,7 +299,7 @@ func (c *ServerCommand) Run(args []string) int { } else if coreConfig.ClusterAddr == "" && coreConfig.RedirectAddr != "" { u, err := url.ParseRequestURI(coreConfig.RedirectAddr) if err != nil { - c.Ui.Error(fmt.Sprintf("Error parsing redirect address %s: %v", coreConfig.RedirectAddr, err)) + c.Ui.Output(fmt.Sprintf("Error parsing redirect address %s: %v", coreConfig.RedirectAddr, err)) return 1 } host, port, err := net.SplitHostPort(u.Host) @@ -311,7 +311,7 @@ func (c *ServerCommand) Run(args []string) int { nPort = 443 } if nPortErr != nil { - c.Ui.Error(fmt.Sprintf("Cannot parse %s as a numeric port: %v", port, nPortErr)) + c.Ui.Output(fmt.Sprintf("Cannot parse %s as a numeric port: %v", port, nPortErr)) return 1 } u.Host = net.JoinHostPort(host, strconv.Itoa(nPort+1)) @@ -323,7 +323,7 @@ func (c *ServerCommand) Run(args []string) int { // Force https as we'll always be TLS-secured u, err := url.ParseRequestURI(coreConfig.ClusterAddr) if err != nil { - c.Ui.Error(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err)) + c.Ui.Output(fmt.Sprintf("Error parsing cluster address %s: %v", coreConfig.RedirectAddr, err)) return 1 } u.Scheme = "https" @@ -334,7 +334,7 @@ func (c *ServerCommand) Run(args []string) int { core, newCoreError := vault.NewCore(coreConfig) if newCoreError != nil { if !errwrap.ContainsType(newCoreError, new(vault.NonFatalError)) { - c.Ui.Error(fmt.Sprintf("Error initializing core: %s", newCoreError)) + c.Ui.Output(fmt.Sprintf("Error initializing core: %s", newCoreError)) return 1 } } @@ -384,7 +384,7 @@ func (c *ServerCommand) Run(args []string) int { for i, lnConfig := range config.Listeners { if lnConfig.Type == "atlas" { if config.ClusterName == "" { - c.Ui.Error("cluster_name is not set in the config and is a required value") + c.Ui.Output("cluster_name is not set in the config and is a required value") return 1 } @@ -393,7 +393,7 @@ func (c *ServerCommand) Run(args []string) int { ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config, logGate) if err != nil { - c.Ui.Error(fmt.Sprintf( + c.Ui.Output(fmt.Sprintf( "Error initializing listener of type %s: %s", lnConfig.Type, err)) return 1 @@ -413,7 +413,7 @@ func (c *ServerCommand) Run(args []string) int { if addr, ok = lnConfig.Config["cluster_address"]; ok { tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { - c.Ui.Error(fmt.Sprintf( + c.Ui.Output(fmt.Sprintf( "Error resolving cluster_address: %s", err)) return 1 @@ -422,7 +422,7 @@ func (c *ServerCommand) Run(args []string) int { } else { tcpAddr, ok := ln.Addr().(*net.TCPAddr) if !ok { - c.Ui.Error("Failed to parse tcp listener") + c.Ui.Output("Failed to parse tcp listener") return 1 } clusterAddrs = append(clusterAddrs, &net.TCPAddr{ @@ -505,7 +505,7 @@ func (c *ServerCommand) Run(args []string) int { } if err := sd.RunServiceDiscovery(c.WaitGroup, c.ShutdownCh, coreConfig.RedirectAddr, activeFunc, sealedFunc); err != nil { - c.Ui.Error(fmt.Sprintf("Error initializing service discovery: %v", err)) + c.Ui.Output(fmt.Sprintf("Error initializing service discovery: %v", err)) return 1 } } @@ -522,7 +522,7 @@ func (c *ServerCommand) Run(args []string) int { if dev { init, err := c.enableDev(core, devRootTokenID) if err != nil { - c.Ui.Error(fmt.Sprintf( + c.Ui.Output(fmt.Sprintf( "Error initializing dev mode: %s", err)) return 1 } @@ -577,13 +577,13 @@ func (c *ServerCommand) Run(args []string) int { case <-c.ShutdownCh: c.Ui.Output("==> Vault shutdown triggered") if err := core.Shutdown(); err != nil { - c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err)) + c.Ui.Output(fmt.Sprintf("Error with core shutdown: %s", err)) } shutdownTriggered = true case <-c.SighupCh: c.Ui.Output("==> Vault reload triggered") if err := c.Reload(configPath); err != nil { - c.Ui.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) + c.Ui.Output(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) } } } @@ -838,14 +838,18 @@ func (c *ServerCommand) setupTelemetry(config *server.Config) error { } func (c *ServerCommand) Reload(configPath []string) error { + c.reloadFuncsLock.RLock() + defer c.reloadFuncsLock.RUnlock() + + var reloadErrors *multierror.Error + // Read the new config var config *server.Config for _, path := range configPath { current, err := server.LoadConfig(path, c.logger) if err != nil { - retErr := fmt.Errorf("Error loading configuration from %s: %s", path, err) - c.Ui.Error(retErr.Error()) - return retErr + reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("Error loading configuration from %s: %s", path, err)) + goto audit } if config == nil { @@ -857,22 +861,32 @@ func (c *ServerCommand) Reload(configPath []string) error { // Ensure at least one config was found. if config == nil { - retErr := fmt.Errorf("No configuration files found") - c.Ui.Error(retErr.Error()) - return retErr + reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("No configuration files found")) + goto audit } - c.reloadFuncsLock.RLock() - defer c.reloadFuncsLock.RUnlock() - - var reloadErrors *multierror.Error // Call reload on the listeners. This will call each listener with each // config block, but they verify the address. for _, lnConfig := range config.Listeners { for _, relFunc := range (*c.reloadFuncs)["listener|"+lnConfig.Type] { if err := relFunc(lnConfig.Config); err != nil { - retErr := fmt.Errorf("Error encountered reloading configuration: %s", err) - reloadErrors = multierror.Append(retErr) + reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("Error encountered reloading configuration: %s", err)) + goto audit + } + } + } + +audit: + // file audit reload funcs + for k, relFuncs := range *c.reloadFuncs { + if !strings.HasPrefix(k, "audit_file|") { + continue + } + for _, relFunc := range relFuncs { + if relFunc != nil { + if err := relFunc(nil); err != nil { + reloadErrors = multierror.Append(reloadErrors, fmt.Errorf("Error encountered reloading file audit backend at path %s: %v", strings.TrimPrefix(k, "audit_file|"), err)) + } } } } diff --git a/vault/audit.go b/vault/audit.go index db47a9c8b..33921a173 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -77,7 +77,7 @@ func (c *Core) enableAudit(entry *MountEntry) error { view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") // Lookup the new backend - backend, err := c.newAuditBackend(entry.Type, view, entry.Options) + backend, err := c.newAuditBackend(entry, view, entry.Options) if err != nil { return err } @@ -110,13 +110,15 @@ func (c *Core) disableAudit(path string) (bool, error) { defer c.auditLock.Unlock() newTable := c.audit.shallowClone() - found := newTable.remove(path) + entry := newTable.remove(path) // Ensure there was a match - if !found { + if entry == nil { return false, fmt.Errorf("no matching backend") } + c.removeAuditReloadFunc(entry) + // Update the audit table if err := c.persistAudit(newTable); err != nil { return true, errors.New("failed to update audit table") @@ -235,7 +237,7 @@ func (c *Core) setupAudits() error { view := NewBarrierView(c.barrier, auditBarrierPrefix+entry.UUID+"/") // Initialize the backend - audit, err := c.newAuditBackend(entry.Type, view, entry.Options) + audit, err := c.newAuditBackend(entry, view, entry.Options) if err != nil { c.logger.Error("core: failed to create audit entry", "path", entry.Path, "error", err) return errLoadAuditFailed @@ -254,16 +256,38 @@ func (c *Core) teardownAudits() error { c.auditLock.Lock() defer c.auditLock.Unlock() + for _, entry := range c.audit.Entries { + c.removeAuditReloadFunc(entry) + } + c.audit = nil c.auditBroker = nil return nil } +// removeAuditReloadFunc removes the reload func from the working set. The +// audit lock needs to be held before calling this. +func (c *Core) removeAuditReloadFunc(entry *MountEntry) { + switch entry.Type { + case "file": + key := "audit_file|" + entry.Path + c.reloadFuncsLock.Lock() + + if c.logger.IsDebug() { + c.logger.Debug("audit: removing reload function", "path", entry.Path) + } + + delete(c.reloadFuncs, key) + + c.reloadFuncsLock.Unlock() + } +} + // newAuditBackend is used to create and configure a new audit backend by name -func (c *Core) newAuditBackend(t string, view logical.Storage, conf map[string]string) (audit.Backend, error) { - f, ok := c.auditBackends[t] +func (c *Core) newAuditBackend(entry *MountEntry, view logical.Storage, conf map[string]string) (audit.Backend, error) { + f, ok := c.auditBackends[entry.Type] if !ok { - return nil, fmt.Errorf("unknown backend type: %s", t) + return nil, fmt.Errorf("unknown backend type: %s", entry.Type) } salter, err := salt.NewSalt(view, &salt.Config{ HMAC: sha256.New, @@ -272,10 +296,36 @@ func (c *Core) newAuditBackend(t string, view logical.Storage, conf map[string]s if err != nil { return nil, fmt.Errorf("core: unable to generate salt: %v", err) } - return f(&audit.BackendConfig{ + + be, err := f(&audit.BackendConfig{ Salt: salter, Config: conf, }) + if err != nil { + return nil, err + } + + switch entry.Type { + case "file": + key := "audit_file|" + entry.Path + + c.reloadFuncsLock.Lock() + + if c.logger.IsDebug() { + c.logger.Debug("audit: adding reload function", "path", entry.Path) + } + + c.reloadFuncs[key] = append(c.reloadFuncs[key], func(map[string]string) error { + if c.logger.IsInfo() { + c.logger.Info("audit: reloading file audit backend", "path", entry.Path) + } + return be.Reload() + }) + + c.reloadFuncsLock.Unlock() + } + + return be, err } // defaultAuditTable creates a default audit table @@ -294,7 +344,7 @@ type backendEntry struct { // AuditBroker is used to provide a single ingest interface to auditable // events given that multiple backends may be configured. type AuditBroker struct { - l sync.RWMutex + sync.RWMutex backends map[string]backendEntry logger log.Logger } @@ -310,8 +360,8 @@ func NewAuditBroker(log log.Logger) *AuditBroker { // Register is used to add new audit backend to the broker func (a *AuditBroker) Register(name string, b audit.Backend, v *BarrierView) { - a.l.Lock() - defer a.l.Unlock() + a.Lock() + defer a.Unlock() a.backends[name] = backendEntry{ backend: b, view: v, @@ -320,23 +370,23 @@ func (a *AuditBroker) Register(name string, b audit.Backend, v *BarrierView) { // Deregister is used to remove an audit backend from the broker func (a *AuditBroker) Deregister(name string) { - a.l.Lock() - defer a.l.Unlock() + a.Lock() + defer a.Unlock() delete(a.backends, name) } // IsRegistered is used to check if a given audit backend is registered func (a *AuditBroker) IsRegistered(name string) bool { - a.l.RLock() - defer a.l.RUnlock() + a.RLock() + defer a.RUnlock() _, ok := a.backends[name] return ok } // GetHash returns a hash using the salt of the given backend func (a *AuditBroker) GetHash(name string, input string) (string, error) { - a.l.RLock() - defer a.l.RUnlock() + a.RLock() + defer a.RUnlock() be, ok := a.backends[name] if !ok { return "", fmt.Errorf("unknown audit backend %s", name) @@ -349,8 +399,8 @@ func (a *AuditBroker) GetHash(name string, input string) (string, error) { // log the given request and that *at least one* succeeds. func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) (retErr error) { defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) - a.l.RLock() - defer a.l.RUnlock() + a.RLock() + defer a.RUnlock() defer func() { if r := recover(); r != nil { a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) @@ -389,8 +439,8 @@ func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, outer func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request, resp *logical.Response, err error) (reterr error) { defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) - a.l.RLock() - defer a.l.RUnlock() + a.RLock() + defer a.RUnlock() defer func() { if r := recover(); r != nil { a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) diff --git a/vault/audit_test.go b/vault/audit_test.go index 129f5b9c8..57c455d1e 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -49,6 +49,10 @@ func (n *NoopAudit) GetHash(data string) string { return n.Config.Salt.GetIdentifiedHMAC(data) } +func (n *NoopAudit) Reload() error { + return nil +} + func TestCore_EnableAudit(t *testing.T) { c, key, _ := TestCoreUnsealed(t) c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { diff --git a/vault/mount.go b/vault/mount.go index 727757143..c19366c1e 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -101,17 +101,18 @@ func (t *MountTable) setTaint(path string, value bool) bool { return false } -// remove is used to remove a given path entry -func (t *MountTable) remove(path string) bool { +// remove is used to remove a given path entry; returns the entry that was +// removed +func (t *MountTable) remove(path string) *MountEntry { n := len(t.Entries) for i := 0; i < n; i++ { - if t.Entries[i].Path == path { + if entry := t.Entries[i]; entry.Path == path { t.Entries[i], t.Entries[n-1] = t.Entries[n-1], nil t.Entries = t.Entries[:n-1] - return true + return entry } } - return false + return nil } // MountEntry is used to represent a mount table entry diff --git a/vault/testing.go b/vault/testing.go index e3a5903ba..f721da0e8 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -346,6 +346,10 @@ func (n *noopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical return nil } +func (n *noopAudit) Reload() error { + return nil +} + type rawHTTP struct{} func (n *rawHTTP) HandleRequest(req *logical.Request) (*logical.Response, error) {