From 8c67bed7ae9ed36d3fe606841d182bd19aaca53f Mon Sep 17 00:00:00 2001 From: Mark Gritter Date: Wed, 16 Dec 2020 16:00:32 -0600 Subject: [PATCH] Send a test message before committing a new audit device. (#10520) * Send a test message before committing a new audit device. Also, lower timeout on connection attempts in socket device. * added changelog * go mod vendor (picked up some unrelated changes.) * Skip audit device check in integration test. Co-authored-by: swayne275 --- audit/audit.go | 6 ++++ audit/format.go | 22 +++++++++++++ builtin/audit/file/backend.go | 18 +++++++++++ builtin/audit/socket/backend.go | 29 ++++++++++++++++- builtin/audit/syslog/backend.go | 12 +++++++ changelog/10520.txt | 3 ++ command/audit_enable_test.go | 3 +- sdk/helper/salt/salt.go | 17 ++++++++++ vault/audit.go | 32 +++++++++++++++++++ vault/testing.go | 17 ++++++++++ .../hashicorp/vault/sdk/helper/salt/salt.go | 17 ++++++++++ 11 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 changelog/10520.txt diff --git a/audit/audit.go b/audit/audit.go index 6f1e7208e..5641b449a 100644 --- a/audit/audit.go +++ b/audit/audit.go @@ -24,6 +24,12 @@ type Backend interface { // a possibility. LogResponse(context.Context, *logical.LogInput) error + // LogTestMessage is used to check an audit backend before adding it + // permanently. It should attempt to synchronously log the given test + // message, WITHOUT using the normal Salt (which would require a storage + // operation on creation, which is currently disallowed.) + LogTestMessage(context.Context, *logical.LogInput, map[string]string) error + // GetHash is used to return the given data with the backend's hash, // so that a caller can determine if a value in the audit log matches // an expected plaintext value diff --git a/audit/format.go b/audit/format.go index 89d0934ac..37092e2cc 100644 --- a/audit/format.go +++ b/audit/format.go @@ -434,3 +434,25 @@ func parseVaultTokenFromJWT(token string) *string { return &claims.ID } + +// Create a formatter not backed by a persistent salt. +func NewTemporaryFormatter(format, prefix string) *AuditFormatter { + temporarySalt := func(ctx context.Context) (*salt.Salt, error) { + return salt.NewNonpersistentSalt(), nil + } + ret := &AuditFormatter{} + + switch format { + case "jsonx": + ret.AuditFormatWriter = &JSONxFormatWriter{ + Prefix: prefix, + SaltFunc: temporarySalt, + } + default: + ret.AuditFormatWriter = &JSONFormatWriter{ + Prefix: prefix, + SaltFunc: temporarySalt, + } + } + return ret +} diff --git a/builtin/audit/file/backend.go b/builtin/audit/file/backend.go index ebe38a00c..5cb1d9b60 100644 --- a/builtin/audit/file/backend.go +++ b/builtin/audit/file/backend.go @@ -258,6 +258,24 @@ func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { return b.log(ctx, buf, writer) } +func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { + var writer io.Writer + switch b.path { + case "stdout": + writer = os.Stdout + case "discard": + return nil + } + + var buf bytes.Buffer + temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"]) + if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil { + return err + } + + return b.log(ctx, &buf, writer) +} + // The file lock must be held before calling this func (b *Backend) open() error { if b.f != nil { diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index ddec20cc4..2aef3a539 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -177,6 +177,30 @@ func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { return err } +func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { + var buf bytes.Buffer + temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"]) + if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil { + return err + } + + b.Lock() + defer b.Unlock() + + err := b.write(ctx, buf.Bytes()) + if err != nil { + rErr := b.reconnect(ctx) + if rErr != nil { + err = multierror.Append(err, rErr) + } else { + // Try once more after reconnecting + err = b.write(ctx, buf.Bytes()) + } + } + + return err +} + func (b *Backend) write(ctx context.Context, buf []byte) error { if b.connection == nil { if err := b.reconnect(ctx); err != nil { @@ -203,8 +227,11 @@ func (b *Backend) reconnect(ctx context.Context) error { b.connection = nil } + timeoutContext, cancel := context.WithTimeout(ctx, b.writeDuration) + defer cancel() + dialer := net.Dialer{} - conn, err := dialer.DialContext(ctx, b.socketType, b.address) + conn, err := dialer.DialContext(timeoutContext, b.socketType, b.address) if err != nil { return err } diff --git a/builtin/audit/syslog/backend.go b/builtin/audit/syslog/backend.go index ee3eb78f9..9c7b775b8 100644 --- a/builtin/audit/syslog/backend.go +++ b/builtin/audit/syslog/backend.go @@ -140,6 +140,18 @@ func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { return err } +func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { + var buf bytes.Buffer + temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"]) + if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil { + return err + } + + // Send to syslog + _, err := b.logger.Write(buf.Bytes()) + return err +} + func (b *Backend) Reload(_ context.Context) error { return nil } diff --git a/changelog/10520.txt b/changelog/10520.txt new file mode 100644 index 000000000..a8caf7732 --- /dev/null +++ b/changelog/10520.txt @@ -0,0 +1,3 @@ +```release-note:improvement +core: Check audit device with a test message before adding it. +``` diff --git a/command/audit_enable_test.go b/command/audit_enable_test.go index cef209f26..1f55703c2 100644 --- a/command/audit_enable_test.go +++ b/command/audit_enable_test.go @@ -189,7 +189,8 @@ func TestAuditEnableCommand_Run(t *testing.T) { case "file": args = append(args, "file_path=discard") case "socket": - args = append(args, "address=127.0.0.1:8888") + args = append(args, "address=127.0.0.1:8888", + "skip_test=true") case "syslog": if _, exists := os.LookupEnv("WSLENV"); exists { t.Log("skipping syslog test on WSL") diff --git a/sdk/helper/salt/salt.go b/sdk/helper/salt/salt.go index e9b7b6e98..50e0cad90 100644 --- a/sdk/helper/salt/salt.go +++ b/sdk/helper/salt/salt.go @@ -115,6 +115,23 @@ func NewSalt(ctx context.Context, view logical.Storage, config *Config) (*Salt, return s, nil } +// NewNonpersistentSalt creates a new salt with default configuration and no storage usage. +func NewNonpersistentSalt() *Salt { + // Setup the configuration + config := &Config{} + config.Location = "" + config.HashFunc = SHA256Hash + config.HMAC = sha256.New + config.HMACType = "hmac-sha256" + + s := &Salt{ + config: config, + } + s.salt, _ = uuid.GenerateUUID() + s.generated = true + return s +} + // SaltID is used to apply a salt and hash function to an ID to make sure // it is not reversible func (s *Salt) SaltID(id string) string { diff --git a/vault/audit.go b/vault/audit.go index 5231397cf..acf98a388 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -40,6 +40,24 @@ var ( errLoadAuditFailed = errors.New("failed to setup audit table") ) +func (c *Core) generateAuditTestProbe() (*logical.LogInput, error) { + requestId, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + return &logical.LogInput{ + Type: "request", + Auth: nil, + Request: &logical.Request{ + ID: requestId, + Operation: "update", + Path: "sys/audit/test", + }, + Response: nil, + OuterErr: nil, + }, nil +} + // enableAudit is used to enable a new audit backend func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage bool) error { // Ensure we end the path in a slash @@ -103,6 +121,20 @@ func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage return fmt.Errorf("nil audit backend of type %q returned from factory", entry.Type) } + if entry.Options["skip_test"] != "true" { + // Test the new audit device and report failure if it doesn't work. + testProbe, err := c.generateAuditTestProbe() + if err != nil { + return err + } + err = backend.LogTestMessage(ctx, testProbe, entry.Options) + if err != nil { + c.logger.Error("new audit backend failed test", "path", entry.Path, "type", entry.Type, "error", err) + return fmt.Errorf("audit backend failed test message: %w", err) + + } + } + newTable := c.audit.shallowClone() newTable.Entries = append(newTable.Entries, entry) diff --git a/vault/testing.go b/vault/testing.go index 35620cc9c..e8188dfa0 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -561,6 +561,19 @@ func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error return nil } +func (n *noopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { + n.l.Lock() + defer n.l.Unlock() + var w bytes.Buffer + tempFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"]) + err := tempFormatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in) + if err != nil { + return err + } + n.records = append(n.records, w.Bytes()) + return nil +} + func (n *noopAudit) Reload(_ context.Context) error { return nil } @@ -2127,6 +2140,10 @@ func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error return n.RespErr } +func (n *NoopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput, options map[string]string) error { + return nil +} + func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) { n.saltMutex.RLock() if n.salt != nil { diff --git a/vendor/github.com/hashicorp/vault/sdk/helper/salt/salt.go b/vendor/github.com/hashicorp/vault/sdk/helper/salt/salt.go index e9b7b6e98..50e0cad90 100644 --- a/vendor/github.com/hashicorp/vault/sdk/helper/salt/salt.go +++ b/vendor/github.com/hashicorp/vault/sdk/helper/salt/salt.go @@ -115,6 +115,23 @@ func NewSalt(ctx context.Context, view logical.Storage, config *Config) (*Salt, return s, nil } +// NewNonpersistentSalt creates a new salt with default configuration and no storage usage. +func NewNonpersistentSalt() *Salt { + // Setup the configuration + config := &Config{} + config.Location = "" + config.HashFunc = SHA256Hash + config.HMAC = sha256.New + config.HMACType = "hmac-sha256" + + s := &Salt{ + config: config, + } + s.salt, _ = uuid.GenerateUUID() + s.generated = true + return s +} + // SaltID is used to apply a salt and hash function to an ID to make sure // it is not reversible func (s *Salt) SaltID(id string) string {