Fixes race on StreamFramer Destroy

This PR:
* Fixes a race in which StreamFramer could panic while closing.
* Simplifies the logic of the StreamFramer.
* Removes a potentially leaked goroutine.
This commit is contained in:
Alex Dadgar 2016-11-17 20:14:47 -08:00
parent c8ddd98e3b
commit 7ead95c333

View file

@ -227,16 +227,18 @@ func (s *StreamFrame) IsHeartbeat() bool {
// StreamFramer is used to buffer and send frames as well as heartbeat. // StreamFramer is used to buffer and send frames as well as heartbeat.
type StreamFramer struct { type StreamFramer struct {
out io.WriteCloser out io.WriteCloser
enc *codec.Encoder enc *codec.Encoder
frameSize int encLock sync.Mutex
heartbeat *time.Ticker
flusher *time.Ticker frameSize int
heartbeat *time.Ticker
flusher *time.Ticker
shutdownCh chan struct{} shutdownCh chan struct{}
exitCh chan struct{} exitCh chan struct{}
outbound chan *StreamFrame
// The mutex protects everything below // The mutex protects everything below
l sync.Mutex l sync.Mutex
@ -266,7 +268,6 @@ func NewStreamFramer(out io.WriteCloser, heartbeatRate, batchWindow time.Duratio
frameSize: frameSize, frameSize: frameSize,
heartbeat: heartbeat, heartbeat: heartbeat,
flusher: flusher, flusher: flusher,
outbound: make(chan *StreamFrame),
data: bytes.NewBuffer(make([]byte, 0, 2*frameSize)), data: bytes.NewBuffer(make([]byte, 0, 2*frameSize)),
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
exitCh: make(chan struct{}), exitCh: make(chan struct{}),
@ -279,10 +280,11 @@ func (s *StreamFramer) Destroy() {
close(s.shutdownCh) close(s.shutdownCh)
s.heartbeat.Stop() s.heartbeat.Stop()
s.flusher.Stop() s.flusher.Stop()
running := s.running
s.l.Unlock() s.l.Unlock()
// Ensure things were flushed // Ensure things were flushed
if s.running { if running {
<-s.exitCh <-s.exitCh
} }
s.out.Close() s.out.Close()
@ -309,90 +311,60 @@ func (s *StreamFramer) ExitCh() <-chan struct{} {
// run is the internal run method. It exits if Destroy is called or an error // run is the internal run method. It exits if Destroy is called or an error
// occurs, in which case the exit channel is closed. // occurs, in which case the exit channel is closed.
func (s *StreamFramer) run() { func (s *StreamFramer) run() {
// Store any error and mark it as not running
var err error var err error
defer func() { defer func() {
close(s.exitCh) close(s.exitCh)
s.l.Lock() s.l.Lock()
close(s.outbound)
s.Err = err
s.running = false s.running = false
s.Err = err
s.l.Unlock() s.l.Unlock()
}() }()
// Start a heartbeat/flusher go-routine. This is done separately to avoid blocking
// the outbound channel.
go func() {
for {
select {
case <-s.exitCh:
return
case <-s.shutdownCh:
return
case <-s.flusher.C:
// Skip if there is nothing to flush
s.l.Lock()
if s.f == nil {
s.l.Unlock()
continue
}
// Read the data for the frame, and send it
s.f.Data = s.readData()
select {
case s.outbound <- s.f:
s.f = nil
case <-s.exitCh:
}
s.l.Unlock()
case <-s.heartbeat.C:
// Send a heartbeat frame
s.l.Lock()
select {
case s.outbound <- &StreamFrame{}:
default:
}
s.l.Unlock()
}
}
}()
OUTER: OUTER:
for { for {
select { select {
case <-s.shutdownCh: case <-s.shutdownCh:
break OUTER break OUTER
case o := <-s.outbound: case <-s.flusher.C:
// Send the frame // Skip if there is nothing to flush
if err = s.enc.Encode(o); err != nil { s.l.Lock()
return if s.f == nil {
s.l.Unlock()
continue
} }
}
}
// Flush any existing frames // Read the data for the frame, and send it
FLUSH: s.f.Data = s.readData()
for { err = s.send(s.f)
select { s.f = nil
case o := <-s.outbound: s.l.Unlock()
// Send the frame and then clear the current working frame if err != nil {
if err = s.enc.Encode(o); err != nil { return
}
case <-s.heartbeat.C:
// Send a heartbeat frame
if err = s.send(&StreamFrame{}); err != nil {
return return
} }
default:
break FLUSH
} }
} }
s.l.Lock() s.l.Lock()
if s.f != nil { if s.f != nil {
s.f.Data = s.readData() s.f.Data = s.readData()
s.enc.Encode(s.f) err = s.send(s.f)
s.f = nil
} }
s.l.Unlock() s.l.Unlock()
} }
// send encodes the given StreamFrame and sends it, holding encLock for the duration of the encode
func (s *StreamFramer) send(f *StreamFrame) error {
s.encLock.Lock()
defer s.encLock.Unlock()
return s.enc.Encode(f)
}
// readData is a helper which reads the buffered data returning up to the frame // readData is a helper which reads the buffered data returning up to the frame
// size of data. Must be called with the lock held. The returned value is // size of data. Must be called with the lock held. The returned value is
// invalid on the next read or write into the StreamFramer buffer // invalid on the next read or write into the StreamFramer buffer
@ -424,6 +396,7 @@ func (s *StreamFramer) Send(file, fileEvent string, data []byte, offset int64) e
if s.Err != nil { if s.Err != nil {
return s.Err return s.Err
} }
return fmt.Errorf("StreamFramer not running") return fmt.Errorf("StreamFramer not running")
} }
@ -435,8 +408,12 @@ func (s *StreamFramer) Send(file, fileEvent string, data []byte, offset int64) e
select { select {
case <-s.exitCh: case <-s.exitCh:
return nil return nil
case s.outbound <- &f: default:
s.f = nil }
err := s.send(&f)
s.f = nil
if err != nil {
return err
} }
} }
@ -457,11 +434,16 @@ func (s *StreamFramer) Send(file, fileEvent string, data []byte, offset int64) e
select { select {
case <-s.exitCh: case <-s.exitCh:
return nil return nil
case s.outbound <- &StreamFrame{ default:
}
f := &StreamFrame{
Offset: s.f.Offset, Offset: s.f.Offset,
File: s.f.File, File: s.f.File,
FileEvent: s.f.FileEvent, FileEvent: s.f.FileEvent,
}: }
if err := s.send(f); err != nil {
return err
} }
} }
@ -472,12 +454,17 @@ func (s *StreamFramer) Send(file, fileEvent string, data []byte, offset int64) e
select { select {
case <-s.exitCh: case <-s.exitCh:
return nil return nil
case s.outbound <- &StreamFrame{ default:
}
f := &StreamFrame{
Offset: s.f.Offset, Offset: s.f.Offset,
File: s.f.File, File: s.f.File,
FileEvent: s.f.FileEvent, FileEvent: s.f.FileEvent,
Data: d, Data: d,
}: }
if err := s.send(f); err != nil {
return err
} }
} }
@ -866,6 +853,10 @@ func blockUntilNextLog(fs allocdir.AllocDirFS, t *tomb.Tomb, logPath, task, logT
scanCh := time.Tick(nextLogCheckRate) scanCh := time.Tick(nextLogCheckRate)
for { for {
select { select {
case <-t.Dead():
next <- fmt.Errorf("shutdown triggered")
close(next)
return
case err := <-eofCancelCh: case err := <-eofCancelCh:
next <- err next <- err
close(next) close(next)