escapingio: handle stalled readers

Handle stalled readers (e.g. when a network write has stalled) by giving
escapingio its own buffer, so that it looks for escape characters in the
stream.

This simplifies the implementation considerably, as we can look
for new lines followed by escaped characters directly.

Also, we add a test to ensure that any partial results are flushed to
readers.
This commit is contained in:
Mahmood Ali 2019-05-17 10:23:40 -04:00
parent 5bd946d790
commit 4013847ada
2 changed files with 282 additions and 98 deletions

View file

@ -1,6 +1,7 @@
package escapingio
import (
"bufio"
"io"
)
@ -22,12 +23,16 @@ type Handler func(c byte) bool
//
// Appearances of `~` when not preceded by a new line are propagated unmodified.
func NewReader(r io.Reader, c byte, h Handler) io.Reader {
// NOTE(review): the next line looks like the pre-change form of the
// constructor kept by the diff rendering - confirm before relying on it.
return &reader{
// pipe() writes the filtered output into pw; Read serves it from pr.
pr, pw := io.Pipe()
reader := &reader{
impl: r,
escapeChar: c,
state: sLookEscapeChar,
handler: h,
pr: pr,
pw: pw,
}
// Filter the underlying reader in a separate goroutine; pipe() loops
// until the underlying reader returns an error.
go reader.pipe()
return reader
}
// lookState represents the state of reader for what character of `\n~.` sequence
@ -52,112 +57,115 @@ type reader struct {
escapeChar uint8
handler Handler
state lookState
// unread is a buffered character for next read if not-nil
unread *byte
// buffers
pw *io.PipeWriter
pr *io.PipeReader
}
// Read implements io.Reader by serving the filtered output that pipe() has
// written into the internal pipe.
func (r *reader) Read(buf []byte) (int, error) {
	n, err := r.pr.Read(buf)
	return n, err
}
// pipe copies data from the underlying reader into the pipe writer, dropping
// the escape sequences accepted by the handler along the way.
//
// Output is flushed after every chunk so that partial results reach the
// consumer even while pipe is blocked waiting for more input. The loop ends
// when the underlying reader returns an error; that error is handed to the
// consumer through the pipe.
func (r *reader) pipe() {
	rb := make([]byte, 4096)
	bw := bufio.NewWriter(r.pw)

	state := sLookEscapeChar

	for {
		n, err := r.impl.Read(rb)

		if n > 0 {
			state = r.processBuf(bw, rb, n, state)
			bw.Flush()

			// The chunk ended with an escape character in escape position, so
			// the next byte decides what to emit. Do not read again once the
			// underlying reader has reported an error: that would discard the
			// error we are about to propagate.
			if state == sLookChar && err == nil {
				// A reader may legally return (0, nil); retry until we get
				// the byte or an error, otherwise the pending escape
				// character would be lost on the next iteration.
				n = 0
				for n == 0 && err == nil {
					n, err = r.impl.Read(rb[:1])
				}

				if n == 1 {
					state = sLookNewLine

					switch {
					case rb[0] == r.escapeChar:
						// only emit escape character once
						bw.WriteByte(rb[0])
					case r.handler(rb[0]):
						// skip if handled
					default:
						// not an escape sequence: emit both bytes unmodified
						bw.WriteByte(r.escapeChar)
						bw.WriteByte(rb[0])
					}
					bw.Flush()
				}
			}
		}

		if err != nil {
			// write the escape character if it's the last thing in the stream
			if state == sLookChar {
				bw.WriteByte(r.escapeChar)
			}
			bw.Flush()
			r.pw.CloseWithError(err)
			break
		}
	}
}
// processBuf processes the first n bytes of buf, starting in state s, and
// writes all resulting output to bw.
//
// If the buffer ends with the escape character while in state
// sLookEscapeChar (e.g. a new line followed by the escape character), it
// writes all output before that escape character and returns sLookChar.
// Otherwise it returns sLookEscapeChar when the buffer ends with a new line,
// and sLookNewLine when it ends with any other character.
//
// NOTE(review): this span appears to interleave lines of the previous
// Read/processBuffer implementation with the new processBuf (diff
// rendering); the description above covers the processBuf lines only.
func (r *reader) processBuf(bw io.Writer, buf []byte, n int, s lookState) lookState {
// i is the scan position; wi is the start of the part of buf that has not
// been written to bw yet.
i := 0
wi := 0
START:
var n int
var err error
if r.unread != nil {
// try to return the unread character immediately
// without trying to block for another read
buf[0] = *r.unread
n = 1
r.unread = nil
} else {
n, err = r.impl.Read(buf)
}
// when we get to the end, check if we have any unprocessed \n~
if n == 0 && err != nil {
if r.state == sLookChar && err != nil {
buf[0] = r.escapeChar
n = 1
if s == sLookEscapeChar && buf[i] == r.escapeChar {
if i+1 >= n {
// buf terminates with ~ - write all before
bw.Write(buf[wi:i])
return sLookChar
}
return n, err
}
// inspect the state at beginning of read
if r.state == sLookChar {
r.state = sLookNewLine
// escape character hasn't been emitted yet
if buf[0] == r.escapeChar {
// earlier ~ was swallowed already, so leave this as is
} else if handled := r.handler(buf[0]); handled {
// need to drop a single letter
copy(buf, buf[1:n])
n--
nc := buf[i+1]
if nc == r.escapeChar {
// skip one escape char
bw.Write(buf[wi:i])
i++
wi = i
} else if r.handler(nc) {
// skip both characters
bw.Write(buf[wi:i])
i = i + 2
wi = i
} else {
// we need to re-introduce ~ with rest of body
// but be mindful if reintroducing ~ causes buffer to overflow
if n == len(buf) {
// in which case, save it for next read
c := buf[n-1]
r.unread = &c
copy(buf[1:], buf[:n])
buf[0] = r.escapeChar
} else {
copy(buf[1:], buf[:n])
buf[0] = r.escapeChar
n++
}
i = i + 2
// need to write everything keep going
}
}
n = r.processBuffer(buf, n)
if n == 0 && err == nil {
goto START
}
return n, err
}
// handles escaped character inside body of read buf.
func (r *reader) processBuffer(buf []byte, read int) int {
b := 0
for b < read {
c := buf[b]
if r.state == sLookEscapeChar && r.escapeChar == c {
r.state = sLookEscapeChar
// are we at the end of read; wait for next read
if b == read-1 {
read--
r.state = sLookChar
return read
}
// otherwise peek at next
nc := buf[b+1]
if nc == r.escapeChar {
// repeated ~, only emit one - skip one character
copy(buf[b:], buf[b+1:read])
read--
b++
continue
} else if handled := r.handler(nc); handled {
// need to drop both ~ and letter
copy(buf[b:], buf[b+2:read])
read -= 2
continue
} else {
// need to pass output unmodified with ~ and letter
}
} else if c == '\n' || c == '\r' {
r.state = sLookEscapeChar
} else {
r.state = sLookNewLine
// search until we get \n~, or buf terminates
for {
if i >= n {
// got to end without new line, write and return
bw.Write(buf[wi:n])
return sLookNewLine
}
b++
}
return read
if buf[i] == '\n' || buf[i] == '\r' {
// buf terminated at new line
if i+1 >= n {
bw.Write(buf[wi:n])
return sLookEscapeChar
}
// peek to see escape character go back to START if so
if buf[i+1] == r.escapeChar {
s = sLookEscapeChar
i++
goto START
}
}
i++
}
}

View file

@ -2,6 +2,7 @@ package escapingio
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
@ -12,8 +13,10 @@ import (
"testing"
"testing/iotest"
"testing/quick"
"time"
"unicode"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -80,9 +83,164 @@ func TestEscapingReader_Static(t *testing.T) {
require.Equal(t, c.expected, found.String())
require.Equal(t, c.escaped, h.escaped())
})
t.Run("without reading: "+c.input, func(t *testing.T) {
input := strings.NewReader(c.input)
h := &testHandler{}
filter := NewReader(input, '~', h.handler)
// don't read to mimic a stalled reader
_ = filter
assertEventually(t, func() (bool, error) {
escaped := h.escaped()
if c.escaped == escaped {
return true, nil
}
return false, fmt.Errorf("expected %v but found %v", c.escaped, escaped)
})
})
}
}
// TestEscapingReader_FlushesPartialReads ensures that partial results are
// flushed to the consumer even when the next character has not been read
// yet, e.g. while the filter waits for the character following `\n~`.
func TestEscapingReader_FlushesPartialReads(t *testing.T) {
	pr, pw := io.Pipe()
	// Unblock the reading goroutines even when an assertion aborts the test
	// early. This is a no-op once CloseWithError below has run.
	defer pw.Close()

	h := &testHandler{}
	filter := NewReader(pr, '~', h.handler)

	var lock sync.Mutex
	var read bytes.Buffer

	// helper for asserting reads: polls for up to two seconds until the
	// accumulated output matches expected
	requireRead := func(expected *bytes.Buffer) {
		readSoFar := ""

		start := time.Now()
		for time.Since(start) < 2*time.Second {
			lock.Lock()
			readSoFar = read.String()
			lock.Unlock()

			if readSoFar == expected.String() {
				break
			}

			time.Sleep(50 * time.Millisecond)
		}

		require.Equal(t, expected.String(), readSoFar, "timed out without output")
	}

	var rerr error
	var wg sync.WaitGroup
	wg.Add(1)

	// goroutine for reading partial data
	go func() {
		defer wg.Done()

		buf := make([]byte, 1024)
		for {
			n, err := filter.Read(buf)
			lock.Lock()
			read.Write(buf[:n])
			lock.Unlock()

			if err != nil {
				rerr = err
				break
			}
		}
	}()

	// Each step writes input, then checks what has been appended to the
	// output (emitted) and the full list of characters seen by the handler
	// so far (escaped).
	steps := []struct {
		desc    string
		input   string
		emitted string
		escaped string
	}{
		{"basic start and no new lines", "first data", "first data", ""},
		{"~. appearing in middle of line but stop at new line", "~.inmiddleappears\n", "~.inmiddleappears\n", ""},
		// from here on we test \n~ at boundary
		{"~~ after new line; and stop at \\n~", "~~second line\n~", "~second line\n", ""},
		{". to be skipped; stop at \\n~ again", ".third line\n~", "third line\n", "."},
		{"q to be emitted; stop at \\n", "qfourth line\n", "~qfourth line\n", ".q"},
		{"~. to be skipped; stop at \\n~", "~.fifth line\n~", "fifth line\n", ".q."},
		{"~ alone after \\n~ - should be emitted", "~", "~", ".q."},
		{"rest of line ending with \\n~", "rest of line\n~", "rest of line\n", ".q."},
		{"m alone after \\n~ - should be emitted with ~", "m", "~m", ".q.m"},
		{"rest of line and end with \\n", "onemore line\n", "onemore line\n", ".q.m"},
		{"~q to be emitted stop at \\n~; last character", "~qlast line\n~", "~qlast line\n", ".q.mq"},
	}

	expected := &bytes.Buffer{}
	for _, s := range steps {
		_, err := pw.Write([]byte(s.input))
		require.NoError(t, err, s.desc)

		expected.WriteString(s.emitted)
		requireRead(expected)
		require.Equal(t, s.escaped, h.escaped(), s.desc)
	}

	// last ~ gets emitted and we preserve error
	eerr := errors.New("my custom error")
	pw.CloseWithError(eerr)
	expected.WriteString("~")
	requireRead(expected)
	require.Equal(t, ".q.mq", h.escaped())

	wg.Wait()
	require.Error(t, rerr)
	require.Equal(t, eerr, rerr)
}
func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
f := func(v readingInput) bool {
return checkEquivalenceToNaive(t, string(v))
@ -315,3 +473,21 @@ func (r *arbtiraryReader) Read(buf []byte) (int, error) {
return r.buf.Read(buf[:l])
}
// assertEventually polls testFn every 50ms until it reports true. If that
// does not happen within two seconds, the test is marked as failed with the
// last error returned by testFn, and assertEventually returns.
func assertEventually(t *testing.T, testFn func() (bool, error)) {
	start := time.Now()
	var err error
	var b bool
	for {
		if time.Since(start) > 2*time.Second {
			assert.Fail(t, "timed out", "error: %v", err)
			// assert.Fail only records the failure; return explicitly so a
			// condition that never holds cannot keep this loop spinning.
			return
		}

		b, err = testFn()
		if b {
			return
		}

		time.Sleep(50 * time.Millisecond)
	}
}