From b38eeec96a95e3d4f617bfe55df6e9734e8c3dda Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 2 Feb 2017 15:44:56 -0800 Subject: [PATCH] Add write deadline and a Reload function --- builtin/audit/socket/backend.go | 69 ++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/builtin/audit/socket/backend.go b/builtin/audit/socket/backend.go index 1b910bb76..cf109efe4 100644 --- a/builtin/audit/socket/backend.go +++ b/builtin/audit/socket/backend.go @@ -5,8 +5,11 @@ import ( "fmt" "net" "strconv" + "sync" + "time" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/duration" "github.com/hashicorp/vault/logical" ) @@ -20,9 +23,18 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { return nil, fmt.Errorf("address is required") } - socket_type, ok := conf.Config["socket_type"] + socketType, ok := conf.Config["socket_type"] if !ok { - socket_type = "tcp" + socketType = "tcp" + } + + writeDeadline, ok := conf.Config["write_deadline"] + if !ok { + writeDeadline = "2s" + } + writeDuration, err := duration.ParseDurationSecond(writeDeadline) + if err != nil { + return nil, err } format, ok := conf.Config["format"] @@ -55,7 +67,7 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { logRaw = b } - conn, err := net.Dial(socket_type, address) + conn, err := net.Dial(socketType, address) if err != nil { return nil, err } @@ -67,6 +79,9 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { Salt: conf.Salt, HMACAccessor: hmacAccessor, }, + writeDuration: writeDuration, + address: address, + socketType: socketType, } switch format { @@ -85,6 +100,12 @@ type Backend struct { formatter audit.AuditFormatter formatConfig audit.FormatterConfig + + writeDuration time.Duration + address string + socketType string + + sync.Mutex } func (b *Backend) GetHash(data string) string { @@ -97,20 +118,50 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr return err } - b.connection.Write(buf.Bytes()) - return nil + b.Lock() + + b.connection.SetDeadline(time.Now().Add(b.writeDuration)) + _, err := b.connection.Write(buf.Bytes()) + + b.Unlock() + if err != nil { + b.Reload() + } + + return err } func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, - resp *logical.Response, err error) error { + resp *logical.Response, outerErr error) error { var buf bytes.Buffer - if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, err); err != nil { + if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, outerErr); err != nil { return err } - b.connection.Write(buf.Bytes()) - return nil + + b.Lock() + + b.connection.SetDeadline(time.Now().Add(b.writeDuration)) + _, err := b.connection.Write(buf.Bytes()) + + b.Unlock() + if err != nil { + b.Reload() + } + + return err } func (b *Backend) Reload() error { + b.Lock() + defer b.Unlock() + + conn, err := net.Dial(b.socketType, b.address) + if err != nil { + return err + } + + b.connection.Close() + b.connection = conn + return nil }