escapingio: handle stalled readers
Handle stalled readers (e.g. network write got stalled), by having escaping io have a buffer so it looks for escaped characters in the stream. This simplifies the implementation considerably, as we can look for new lines followed by escaped characters directly. Also, we add a test to ensure that any partial results are flushed to readers.
This commit is contained in:
parent
5bd946d790
commit
4013847ada
|
@ -1,6 +1,7 @@
|
|||
package escapingio
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
@ -22,12 +23,16 @@ type Handler func(c byte) bool
|
|||
//
|
||||
// Appearances of `~` when not preceded by a new line are propagated unmodified.
|
||||
func NewReader(r io.Reader, c byte, h Handler) io.Reader {
|
||||
return &reader{
|
||||
pr, pw := io.Pipe()
|
||||
reader := &reader{
|
||||
impl: r,
|
||||
escapeChar: c,
|
||||
state: sLookEscapeChar,
|
||||
handler: h,
|
||||
pr: pr,
|
||||
pw: pw,
|
||||
}
|
||||
go reader.pipe()
|
||||
return reader
|
||||
}
|
||||
|
||||
// lookState represents the state of reader for what character of `\n~.` sequence
|
||||
|
@ -52,112 +57,115 @@ type reader struct {
|
|||
escapeChar uint8
|
||||
handler Handler
|
||||
|
||||
state lookState
|
||||
|
||||
// unread is a buffered character for next read if not-nil
|
||||
unread *byte
|
||||
// buffers
|
||||
pw *io.PipeWriter
|
||||
pr *io.PipeReader
|
||||
}
|
||||
|
||||
func (r *reader) Read(buf []byte) (int, error) {
|
||||
return r.pr.Read(buf)
|
||||
}
|
||||
|
||||
func (r *reader) pipe() {
|
||||
rb := make([]byte, 4096)
|
||||
bw := bufio.NewWriter(r.pw)
|
||||
|
||||
state := sLookEscapeChar
|
||||
|
||||
for {
|
||||
n, err := r.impl.Read(rb)
|
||||
|
||||
if n > 0 {
|
||||
state = r.processBuf(bw, rb, n, state)
|
||||
bw.Flush()
|
||||
if state == sLookChar {
|
||||
// terminated with ~ - let's read one more character
|
||||
n, err = r.impl.Read(rb[:1])
|
||||
if n == 1 {
|
||||
state = sLookNewLine
|
||||
if rb[0] == r.escapeChar {
|
||||
// only emit escape character once
|
||||
bw.WriteByte(rb[0])
|
||||
bw.Flush()
|
||||
} else if r.handler(rb[0]) {
|
||||
// skip if handled
|
||||
} else {
|
||||
bw.WriteByte(r.escapeChar)
|
||||
bw.WriteByte(rb[0])
|
||||
bw.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// write ~ if it's the last thing
|
||||
if state == sLookChar {
|
||||
bw.WriteByte(r.escapeChar)
|
||||
}
|
||||
bw.Flush()
|
||||
r.pw.CloseWithError(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processBuf process buffer and emits all output to writer
|
||||
// if the last part of buffer is a new line followed by sequnce, it writes
|
||||
// all output until the new line and returns sLookChar
|
||||
func (r *reader) processBuf(bw io.Writer, buf []byte, n int, s lookState) lookState {
|
||||
i := 0
|
||||
|
||||
wi := 0
|
||||
|
||||
START:
|
||||
var n int
|
||||
var err error
|
||||
|
||||
if r.unread != nil {
|
||||
// try to return the unread character immediately
|
||||
// without trying to block for another read
|
||||
buf[0] = *r.unread
|
||||
n = 1
|
||||
r.unread = nil
|
||||
} else {
|
||||
n, err = r.impl.Read(buf)
|
||||
}
|
||||
|
||||
// when we get to the end, check if we have any unprocessed \n~
|
||||
if n == 0 && err != nil {
|
||||
if r.state == sLookChar && err != nil {
|
||||
buf[0] = r.escapeChar
|
||||
n = 1
|
||||
if s == sLookEscapeChar && buf[i] == r.escapeChar {
|
||||
if i+1 >= n {
|
||||
// buf terminates with ~ - write all before
|
||||
bw.Write(buf[wi:i])
|
||||
return sLookChar
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// inspect the state at beginning of read
|
||||
if r.state == sLookChar {
|
||||
r.state = sLookNewLine
|
||||
|
||||
// escape character hasn't been emitted yet
|
||||
if buf[0] == r.escapeChar {
|
||||
// earlier ~ was swallowed already, so leave this as is
|
||||
} else if handled := r.handler(buf[0]); handled {
|
||||
// need to drop a single letter
|
||||
copy(buf, buf[1:n])
|
||||
n--
|
||||
nc := buf[i+1]
|
||||
if nc == r.escapeChar {
|
||||
// skip one escape char
|
||||
bw.Write(buf[wi:i])
|
||||
i++
|
||||
wi = i
|
||||
} else if r.handler(nc) {
|
||||
// skip both characters
|
||||
bw.Write(buf[wi:i])
|
||||
i = i + 2
|
||||
wi = i
|
||||
} else {
|
||||
// we need to re-introduce ~ with rest of body
|
||||
// but be mindful if reintroducing ~ causes buffer to overflow
|
||||
if n == len(buf) {
|
||||
// in which case, save it for next read
|
||||
c := buf[n-1]
|
||||
r.unread = &c
|
||||
copy(buf[1:], buf[:n])
|
||||
buf[0] = r.escapeChar
|
||||
} else {
|
||||
copy(buf[1:], buf[:n])
|
||||
buf[0] = r.escapeChar
|
||||
n++
|
||||
}
|
||||
i = i + 2
|
||||
// need to write everything keep going
|
||||
}
|
||||
}
|
||||
|
||||
n = r.processBuffer(buf, n)
|
||||
if n == 0 && err == nil {
|
||||
goto START
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// handles escaped character inside body of read buf.
|
||||
func (r *reader) processBuffer(buf []byte, read int) int {
|
||||
b := 0
|
||||
|
||||
for b < read {
|
||||
|
||||
c := buf[b]
|
||||
if r.state == sLookEscapeChar && r.escapeChar == c {
|
||||
r.state = sLookEscapeChar
|
||||
|
||||
// are we at the end of read; wait for next read
|
||||
if b == read-1 {
|
||||
read--
|
||||
r.state = sLookChar
|
||||
return read
|
||||
}
|
||||
|
||||
// otherwise peek at next
|
||||
nc := buf[b+1]
|
||||
if nc == r.escapeChar {
|
||||
// repeated ~, only emit one - skip one character
|
||||
copy(buf[b:], buf[b+1:read])
|
||||
read--
|
||||
b++
|
||||
continue
|
||||
} else if handled := r.handler(nc); handled {
|
||||
// need to drop both ~ and letter
|
||||
copy(buf[b:], buf[b+2:read])
|
||||
read -= 2
|
||||
continue
|
||||
} else {
|
||||
// need to pass output unmodified with ~ and letter
|
||||
}
|
||||
} else if c == '\n' || c == '\r' {
|
||||
r.state = sLookEscapeChar
|
||||
} else {
|
||||
r.state = sLookNewLine
|
||||
// search until we get \n~, or buf terminates
|
||||
for {
|
||||
if i >= n {
|
||||
// got to end without new line, write and return
|
||||
bw.Write(buf[wi:n])
|
||||
return sLookNewLine
|
||||
}
|
||||
b++
|
||||
}
|
||||
|
||||
return read
|
||||
if buf[i] == '\n' || buf[i] == '\r' {
|
||||
// buf terminated at new line
|
||||
if i+1 >= n {
|
||||
bw.Write(buf[wi:n])
|
||||
return sLookEscapeChar
|
||||
}
|
||||
|
||||
// peek to see escape character go back to START if so
|
||||
if buf[i+1] == r.escapeChar {
|
||||
s = sLookEscapeChar
|
||||
i++
|
||||
goto START
|
||||
}
|
||||
}
|
||||
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package escapingio
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
|
@ -12,8 +13,10 @@ import (
|
|||
"testing"
|
||||
"testing/iotest"
|
||||
"testing/quick"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -80,9 +83,164 @@ func TestEscapingReader_Static(t *testing.T) {
|
|||
require.Equal(t, c.expected, found.String())
|
||||
require.Equal(t, c.escaped, h.escaped())
|
||||
})
|
||||
|
||||
t.Run("without reading: "+c.input, func(t *testing.T) {
|
||||
input := strings.NewReader(c.input)
|
||||
|
||||
h := &testHandler{}
|
||||
|
||||
filter := NewReader(input, '~', h.handler)
|
||||
|
||||
// don't read to mimic a stalled reader
|
||||
_ = filter
|
||||
|
||||
assertEventually(t, func() (bool, error) {
|
||||
escaped := h.escaped()
|
||||
if c.escaped == escaped {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("expected %v but found %v", c.escaped, escaped)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEscapingReader_EmitsPartialReads should emit partial results
|
||||
// if next character is not read
|
||||
func TestEscapingReader_FlushesPartialReads(t *testing.T) {
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
h := &testHandler{}
|
||||
filter := NewReader(pr, '~', h.handler)
|
||||
|
||||
var lock sync.Mutex
|
||||
var read bytes.Buffer
|
||||
|
||||
// helper for asserting reads
|
||||
requireRead := func(expected *bytes.Buffer) {
|
||||
readSoFar := ""
|
||||
|
||||
start := time.Now()
|
||||
for time.Since(start) < 2*time.Second {
|
||||
lock.Lock()
|
||||
readSoFar = read.String()
|
||||
lock.Unlock()
|
||||
|
||||
if readSoFar == expected.String() {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
require.Equal(t, expected.String(), readSoFar, "timed out without output")
|
||||
}
|
||||
|
||||
var rerr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
// goroutine for reading partial data
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, err := filter.Read(buf)
|
||||
lock.Lock()
|
||||
read.Write(buf[:n])
|
||||
lock.Unlock()
|
||||
|
||||
if err != nil {
|
||||
rerr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
expected := &bytes.Buffer{}
|
||||
|
||||
// test basic start and no new lines
|
||||
pw.Write([]byte("first data"))
|
||||
expected.WriteString("first data")
|
||||
requireRead(expected)
|
||||
require.Equal(t, "", h.escaped())
|
||||
|
||||
// test ~. appearing in middle of line but stop at new line
|
||||
pw.Write([]byte("~.inmiddleappears\n"))
|
||||
expected.WriteString("~.inmiddleappears\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, "", h.escaped())
|
||||
|
||||
// from here on we test \n~ at boundary
|
||||
|
||||
// ~~ after new line; and stop at \n~
|
||||
pw.Write([]byte("~~second line\n~"))
|
||||
expected.WriteString("~second line\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, "", h.escaped())
|
||||
|
||||
// . to be skipped; stop at \n~ again
|
||||
pw.Write([]byte(".third line\n~"))
|
||||
expected.WriteString("third line\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".", h.escaped())
|
||||
|
||||
// q to be emitted; stop at \n
|
||||
pw.Write([]byte("qfourth line\n"))
|
||||
expected.WriteString("~qfourth line\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q", h.escaped())
|
||||
|
||||
// ~. to be skipped; stop at \n~
|
||||
pw.Write([]byte("~.fifth line\n~"))
|
||||
expected.WriteString("fifth line\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q.", h.escaped())
|
||||
|
||||
// ~ alone after \n~ - should be emitted
|
||||
pw.Write([]byte("~"))
|
||||
expected.WriteString("~")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q.", h.escaped())
|
||||
|
||||
// rest of line ending with \n~
|
||||
pw.Write([]byte("rest of line\n~"))
|
||||
expected.WriteString("rest of line\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q.", h.escaped())
|
||||
|
||||
// m alone after \n~ - should be emitted with ~
|
||||
pw.Write([]byte("m"))
|
||||
expected.WriteString("~m")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q.m", h.escaped())
|
||||
|
||||
// rest of line and end with \n
|
||||
pw.Write([]byte("onemore line\n"))
|
||||
expected.WriteString("onemore line\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q.m", h.escaped())
|
||||
|
||||
// ~q to be emitted stop at \n~; last charcater
|
||||
pw.Write([]byte("~qlast line\n~"))
|
||||
expected.WriteString("~qlast line\n")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q.mq", h.escaped())
|
||||
|
||||
// last ~ gets emitted and we preserve error
|
||||
eerr := errors.New("my custom error")
|
||||
pw.CloseWithError(eerr)
|
||||
expected.WriteString("~")
|
||||
requireRead(expected)
|
||||
require.Equal(t, ".q.mq", h.escaped())
|
||||
|
||||
wg.Wait()
|
||||
require.Error(t, rerr)
|
||||
require.Equal(t, eerr, rerr)
|
||||
}
|
||||
|
||||
func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
|
||||
f := func(v readingInput) bool {
|
||||
return checkEquivalenceToNaive(t, string(v))
|
||||
|
@ -315,3 +473,21 @@ func (r *arbtiraryReader) Read(buf []byte) (int, error) {
|
|||
|
||||
return r.buf.Read(buf[:l])
|
||||
}
|
||||
|
||||
func assertEventually(t *testing.T, testFn func() (bool, error)) {
|
||||
start := time.Now()
|
||||
var err error
|
||||
var b bool
|
||||
for {
|
||||
if time.Since(start) > 2*time.Second {
|
||||
assert.Fail(t, "timed out", "error: %v", err)
|
||||
}
|
||||
|
||||
b, err = testFn()
|
||||
if b {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue