Send a test message before committing a new audit device. (#10520)

* Send a test message before committing a new audit device.
Also, lower timeout on connection attempts in socket device.
* added changelog
* go mod vendor (picked up some unrelated changes.)
* Skip audit device check in integration test.
Co-authored-by: swayne275 <swayne@hashicorp.com>
This commit is contained in:
Mark Gritter 2020-12-16 16:00:32 -06:00 committed by GitHub
parent 5ac1c93c4a
commit 8c67bed7ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 174 additions and 2 deletions

View File

@ -24,6 +24,12 @@ type Backend interface {
// a possibility.
LogResponse(context.Context, *logical.LogInput) error
// LogTestMessage is used to check an audit backend before adding it
// permanently. It should attempt to synchronously log the given test
// message, WITHOUT using the normal Salt (which would require a storage
// operation on creation, which is currently disallowed.)
LogTestMessage(context.Context, *logical.LogInput, map[string]string) error
// GetHash is used to return the given data with the backend's hash,
// so that a caller can determine if a value in the audit log matches
// an expected plaintext value

View File

@ -434,3 +434,25 @@ func parseVaultTokenFromJWT(token string) *string {
return &claims.ID
}
// Create a formatter not backed by a persistent salt.
func NewTemporaryFormatter(format, prefix string) *AuditFormatter {
temporarySalt := func(ctx context.Context) (*salt.Salt, error) {
return salt.NewNonpersistentSalt(), nil
}
ret := &AuditFormatter{}
switch format {
case "jsonx":
ret.AuditFormatWriter = &JSONxFormatWriter{
Prefix: prefix,
SaltFunc: temporarySalt,
}
default:
ret.AuditFormatWriter = &JSONFormatWriter{
Prefix: prefix,
SaltFunc: temporarySalt,
}
}
return ret
}

View File

@ -258,6 +258,24 @@ func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error {
return b.log(ctx, buf, writer)
}
func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error {
var writer io.Writer
switch b.path {
case "stdout":
writer = os.Stdout
case "discard":
return nil
}
var buf bytes.Buffer
temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"])
if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil {
return err
}
return b.log(ctx, &buf, writer)
}
// The file lock must be held before calling this
func (b *Backend) open() error {
if b.f != nil {

View File

@ -177,6 +177,30 @@ func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error {
return err
}
func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error {
var buf bytes.Buffer
temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"])
if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil {
return err
}
b.Lock()
defer b.Unlock()
err := b.write(ctx, buf.Bytes())
if err != nil {
rErr := b.reconnect(ctx)
if rErr != nil {
err = multierror.Append(err, rErr)
} else {
// Try once more after reconnecting
err = b.write(ctx, buf.Bytes())
}
}
return err
}
func (b *Backend) write(ctx context.Context, buf []byte) error {
if b.connection == nil {
if err := b.reconnect(ctx); err != nil {
@ -203,8 +227,11 @@ func (b *Backend) reconnect(ctx context.Context) error {
b.connection = nil
}
timeoutContext, cancel := context.WithTimeout(ctx, b.writeDuration)
defer cancel()
dialer := net.Dialer{}
conn, err := dialer.DialContext(ctx, b.socketType, b.address)
conn, err := dialer.DialContext(timeoutContext, b.socketType, b.address)
if err != nil {
return err
}

View File

@ -140,6 +140,18 @@ func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error {
return err
}
func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error {
var buf bytes.Buffer
temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"])
if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil {
return err
}
// Send to syslog
_, err := b.logger.Write(buf.Bytes())
return err
}
func (b *Backend) Reload(_ context.Context) error {
return nil
}

3
changelog/10520.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:improvement
core: Check audit device with a test message before adding it.
```

View File

@ -189,7 +189,8 @@ func TestAuditEnableCommand_Run(t *testing.T) {
case "file":
args = append(args, "file_path=discard")
case "socket":
args = append(args, "address=127.0.0.1:8888")
args = append(args, "address=127.0.0.1:8888",
"skip_test=true")
case "syslog":
if _, exists := os.LookupEnv("WSLENV"); exists {
t.Log("skipping syslog test on WSL")

View File

@ -115,6 +115,23 @@ func NewSalt(ctx context.Context, view logical.Storage, config *Config) (*Salt,
return s, nil
}
// NewNonpersistentSalt creates a new salt with default configuration and no storage usage.
func NewNonpersistentSalt() *Salt {
// Setup the configuration
config := &Config{}
config.Location = ""
config.HashFunc = SHA256Hash
config.HMAC = sha256.New
config.HMACType = "hmac-sha256"
s := &Salt{
config: config,
}
s.salt, _ = uuid.GenerateUUID()
s.generated = true
return s
}
// SaltID is used to apply a salt and hash function to an ID to make sure
// it is not reversible
func (s *Salt) SaltID(id string) string {

View File

@ -40,6 +40,24 @@ var (
errLoadAuditFailed = errors.New("failed to setup audit table")
)
func (c *Core) generateAuditTestProbe() (*logical.LogInput, error) {
requestId, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
return &logical.LogInput{
Type: "request",
Auth: nil,
Request: &logical.Request{
ID: requestId,
Operation: "update",
Path: "sys/audit/test",
},
Response: nil,
OuterErr: nil,
}, nil
}
// enableAudit is used to enable a new audit backend
func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage bool) error {
// Ensure we end the path in a slash
@ -103,6 +121,20 @@ func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage
return fmt.Errorf("nil audit backend of type %q returned from factory", entry.Type)
}
if entry.Options["skip_test"] != "true" {
// Test the new audit device and report failure if it doesn't work.
testProbe, err := c.generateAuditTestProbe()
if err != nil {
return err
}
err = backend.LogTestMessage(ctx, testProbe, entry.Options)
if err != nil {
c.logger.Error("new audit backend failed test", "path", entry.Path, "type", entry.Type, "error", err)
return fmt.Errorf("audit backend failed test message: %w", err)
}
}
newTable := c.audit.shallowClone()
newTable.Entries = append(newTable.Entries, entry)

View File

@ -561,6 +561,19 @@ func (n *noopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error
return nil
}
func (n *noopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error {
n.l.Lock()
defer n.l.Unlock()
var w bytes.Buffer
tempFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"])
err := tempFormatter.FormatResponse(ctx, &w, audit.FormatterConfig{}, in)
if err != nil {
return err
}
n.records = append(n.records, w.Bytes())
return nil
}
func (n *noopAudit) Reload(_ context.Context) error {
return nil
}
@ -2127,6 +2140,10 @@ func (n *NoopAudit) LogResponse(ctx context.Context, in *logical.LogInput) error
return n.RespErr
}
func (n *NoopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput, options map[string]string) error {
return nil
}
func (n *NoopAudit) Salt(ctx context.Context) (*salt.Salt, error) {
n.saltMutex.RLock()
if n.salt != nil {

View File

@ -115,6 +115,23 @@ func NewSalt(ctx context.Context, view logical.Storage, config *Config) (*Salt,
return s, nil
}
// NewNonpersistentSalt creates a new salt with default configuration and no storage usage.
func NewNonpersistentSalt() *Salt {
// Setup the configuration
config := &Config{}
config.Location = ""
config.HashFunc = SHA256Hash
config.HMAC = sha256.New
config.HMACType = "hmac-sha256"
s := &Salt{
config: config,
}
s.salt, _ = uuid.GenerateUUID()
s.generated = true
return s
}
// SaltID is used to apply a salt and hash function to an ID to make sure
// it is not reversible
func (s *Salt) SaltID(id string) string {