diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index 9472a5526..24d600e7b 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -122,13 +122,14 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr b.Lock() defer b.Unlock() - b.connection.SetDeadline(time.Now().Add(b.writeDuration)) - _, err := b.connection.Write(buf.Bytes()) - + err := b.write(buf.Bytes()) if err != nil { rErr := b.reconnect() if rErr != nil { err = multierror.Append(err, rErr) + } else { + // Try once more after reconnecting + err = b.write(buf.Bytes()) } } @@ -145,19 +146,34 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, b.Lock() defer b.Unlock() - b.connection.SetDeadline(time.Now().Add(b.writeDuration)) - _, err := b.connection.Write(buf.Bytes()) - + err := b.write(buf.Bytes()) if err != nil { rErr := b.reconnect() if rErr != nil { err = multierror.Append(err, rErr) + } else { + // Try once more after reconnecting + err = b.write(buf.Bytes()) } } return err } +func (b *Backend) write(buf []byte) error { + err := b.connection.SetWriteDeadline(time.Now().Add(b.writeDuration)) + if err != nil { + return err + } + + _, err = b.connection.Write(buf) + if err != nil { + return err + } + + return err +} + func (b *Backend) reconnect() error { conn, err := net.Dial(b.socketType, b.address) if err != nil {