Merge pull request #5722 from hashicorp/f-nomad-exec-escape-try2

escapingio: handle stalled readers
This commit is contained in:
Mahmood Ali 2019-05-17 14:52:54 -04:00 committed by GitHub
commit ccac5ad3e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 337 additions and 143 deletions

View File

@ -265,6 +265,11 @@ func (l *AllocExecCommand) execImpl(client *api.Client, alloc *api.Allocation, t
stdin = escapingio.NewReader(stdin, escapeChar[0], func(c byte) bool {
switch c {
case '.':
// need to restore tty state so error reporting here
// gets emitted at beginning of line
outCleanup()
inCleanup()
stderr.Write([]byte("\nConnection closed\n"))
cancelFn()
return true
@ -272,7 +277,6 @@ func (l *AllocExecCommand) execImpl(client *api.Client, alloc *api.Allocation, t
return false
}
})
}
}

View File

@ -1,6 +1,7 @@
package escapingio
import (
"bufio"
"io"
)
@ -22,12 +23,16 @@ type Handler func(c byte) bool
//
// Appearances of `~` when not preceded by a new line are propagated unmodified.
func NewReader(r io.Reader, c byte, h Handler) io.Reader {
return &reader{
pr, pw := io.Pipe()
reader := &reader{
impl: r,
escapeChar: c,
state: sLookEscapeChar,
handler: h,
pr: pr,
pw: pw,
}
go reader.pipe()
return reader
}
// lookState represents the state of reader for what character of `\n~.` sequence
@ -52,112 +57,115 @@ type reader struct {
escapeChar uint8
handler Handler
state lookState
// unread is a buffered character for next read if not-nil
unread *byte
// buffers
pw *io.PipeWriter
pr *io.PipeReader
}
func (r *reader) Read(buf []byte) (int, error) {
return r.pr.Read(buf)
}
func (r *reader) pipe() {
rb := make([]byte, 4096)
bw := bufio.NewWriter(r.pw)
state := sLookEscapeChar
for {
n, err := r.impl.Read(rb)
if n > 0 {
state = r.processBuf(bw, rb, n, state)
bw.Flush()
if state == sLookChar {
// terminated with ~ - let's read one more character
n, err = r.impl.Read(rb[:1])
if n == 1 {
state = sLookNewLine
if rb[0] == r.escapeChar {
// only emit escape character once
bw.WriteByte(rb[0])
bw.Flush()
} else if r.handler(rb[0]) {
// skip if handled
} else {
bw.WriteByte(r.escapeChar)
bw.WriteByte(rb[0])
bw.Flush()
}
}
}
}
if err != nil {
// write ~ if it's the last thing
if state == sLookChar {
bw.WriteByte(r.escapeChar)
}
bw.Flush()
r.pw.CloseWithError(err)
break
}
}
}
// processBuf process buffer and emits all output to writer
// if the last part of buffer is a new line followed by sequnce, it writes
// all output until the new line and returns sLookChar
func (r *reader) processBuf(bw io.Writer, buf []byte, n int, s lookState) lookState {
i := 0
wi := 0
START:
var n int
var err error
if r.unread != nil {
// try to return the unread character immediately
// without trying to block for another read
buf[0] = *r.unread
n = 1
r.unread = nil
} else {
n, err = r.impl.Read(buf)
}
// when we get to the end, check if we have any unprocessed \n~
if n == 0 && err != nil {
if r.state == sLookChar && err != nil {
buf[0] = r.escapeChar
n = 1
if s == sLookEscapeChar && buf[i] == r.escapeChar {
if i+1 >= n {
// buf terminates with ~ - write all before
bw.Write(buf[wi:i])
return sLookChar
}
return n, err
}
// inspect the state at beginning of read
if r.state == sLookChar {
r.state = sLookNewLine
// escape character hasn't been emitted yet
if buf[0] == r.escapeChar {
// earlier ~ was swallowed already, so leave this as is
} else if handled := r.handler(buf[0]); handled {
// need to drop a single letter
copy(buf, buf[1:n])
n--
nc := buf[i+1]
if nc == r.escapeChar {
// skip one escape char
bw.Write(buf[wi:i])
i++
wi = i
} else if r.handler(nc) {
// skip both characters
bw.Write(buf[wi:i])
i = i + 2
wi = i
} else {
// we need to re-introduce ~ with rest of body
// but be mindful if reintroducing ~ causes buffer to overflow
if n == len(buf) {
// in which case, save it for next read
c := buf[n-1]
r.unread = &c
copy(buf[1:], buf[:n])
buf[0] = r.escapeChar
} else {
copy(buf[1:], buf[:n])
buf[0] = r.escapeChar
n++
}
i = i + 2
// need to write everything keep going
}
}
n = r.processBuffer(buf, n)
if n == 0 && err == nil {
goto START
}
return n, err
}
// handles escaped character inside body of read buf.
func (r *reader) processBuffer(buf []byte, read int) int {
b := 0
for b < read {
c := buf[b]
if r.state == sLookEscapeChar && r.escapeChar == c {
r.state = sLookEscapeChar
// are we at the end of read; wait for next read
if b == read-1 {
read--
r.state = sLookChar
return read
}
// otherwise peek at next
nc := buf[b+1]
if nc == r.escapeChar {
// repeated ~, only emit one - skip one character
copy(buf[b:], buf[b+1:read])
read--
b++
continue
} else if handled := r.handler(nc); handled {
// need to drop both ~ and letter
copy(buf[b:], buf[b+2:read])
read -= 2
continue
} else {
// need to pass output unmodified with ~ and letter
}
} else if c == '\n' || c == '\r' {
r.state = sLookEscapeChar
} else {
r.state = sLookNewLine
// search until we get \n~, or buf terminates
for {
if i >= n {
// got to end without new line, write and return
bw.Write(buf[wi:n])
return sLookNewLine
}
b++
}
return read
if buf[i] == '\n' || buf[i] == '\r' {
// buf terminated at new line
if i+1 >= n {
bw.Write(buf[wi:n])
return sLookEscapeChar
}
// peek to see escape character go back to START if so
if buf[i+1] == r.escapeChar {
s = sLookEscapeChar
i++
goto START
}
}
i++
}
}

View File

@ -2,17 +2,21 @@ package escapingio
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
"reflect"
"regexp"
"strings"
"sync"
"testing"
"testing/iotest"
"testing/quick"
"time"
"unicode"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -42,12 +46,11 @@ func TestEscapingReader_Static(t *testing.T) {
for _, c := range cases {
t.Run("sanity check naive implementation", func(t *testing.T) {
foundEscaped := ""
h := testHandler(&foundEscaped)
h := &testHandler{}
processed := naiveEscapeCharacters(c.input, '~', h)
processed := naiveEscapeCharacters(c.input, '~', h.handler)
require.Equal(t, c.expected, processed)
require.Equal(t, c.escaped, foundEscaped)
require.Equal(t, c.escaped, h.escaped())
})
t.Run("chunks at a time: "+c.input, func(t *testing.T) {
@ -55,16 +58,15 @@ func TestEscapingReader_Static(t *testing.T) {
input := strings.NewReader(c.input)
foundEscaped := ""
h := testHandler(&foundEscaped)
h := &testHandler{}
filter := NewReader(input, '~', h)
filter := NewReader(input, '~', h.handler)
_, err := io.Copy(&found, filter)
require.NoError(t, err)
require.Equal(t, c.expected, found.String())
require.Equal(t, c.escaped, foundEscaped)
require.Equal(t, c.escaped, h.escaped())
})
t.Run("1 byte at a time: "+c.input, func(t *testing.T) {
@ -72,19 +74,173 @@ func TestEscapingReader_Static(t *testing.T) {
input := iotest.OneByteReader(strings.NewReader(c.input))
foundEscaped := ""
h := testHandler(&foundEscaped)
h := &testHandler{}
filter := NewReader(input, '~', h)
filter := NewReader(input, '~', h.handler)
_, err := io.Copy(&found, filter)
require.NoError(t, err)
require.Equal(t, c.expected, found.String())
require.Equal(t, c.escaped, foundEscaped)
require.Equal(t, c.escaped, h.escaped())
})
t.Run("without reading: "+c.input, func(t *testing.T) {
input := strings.NewReader(c.input)
h := &testHandler{}
filter := NewReader(input, '~', h.handler)
// don't read to mimic a stalled reader
_ = filter
assertEventually(t, func() (bool, error) {
escaped := h.escaped()
if c.escaped == escaped {
return true, nil
}
return false, fmt.Errorf("expected %v but found %v", c.escaped, escaped)
})
})
}
}
// TestEscapingReader_EmitsPartialReads should emit partial results
// if next character is not read
func TestEscapingReader_FlushesPartialReads(t *testing.T) {
pr, pw := io.Pipe()
h := &testHandler{}
filter := NewReader(pr, '~', h.handler)
var lock sync.Mutex
var read bytes.Buffer
// helper for asserting reads
requireRead := func(expected *bytes.Buffer) {
readSoFar := ""
start := time.Now()
for time.Since(start) < 2*time.Second {
lock.Lock()
readSoFar = read.String()
lock.Unlock()
if readSoFar == expected.String() {
break
}
time.Sleep(50 * time.Millisecond)
}
require.Equal(t, expected.String(), readSoFar, "timed out without output")
}
var rerr error
var wg sync.WaitGroup
wg.Add(1)
// goroutine for reading partial data
go func() {
defer wg.Done()
buf := make([]byte, 1024)
for {
n, err := filter.Read(buf)
lock.Lock()
read.Write(buf[:n])
lock.Unlock()
if err != nil {
rerr = err
break
}
}
}()
expected := &bytes.Buffer{}
// test basic start and no new lines
pw.Write([]byte("first data"))
expected.WriteString("first data")
requireRead(expected)
require.Equal(t, "", h.escaped())
// test ~. appearing in middle of line but stop at new line
pw.Write([]byte("~.inmiddleappears\n"))
expected.WriteString("~.inmiddleappears\n")
requireRead(expected)
require.Equal(t, "", h.escaped())
// from here on we test \n~ at boundary
// ~~ after new line; and stop at \n~
pw.Write([]byte("~~second line\n~"))
expected.WriteString("~second line\n")
requireRead(expected)
require.Equal(t, "", h.escaped())
// . to be skipped; stop at \n~ again
pw.Write([]byte(".third line\n~"))
expected.WriteString("third line\n")
requireRead(expected)
require.Equal(t, ".", h.escaped())
// q to be emitted; stop at \n
pw.Write([]byte("qfourth line\n"))
expected.WriteString("~qfourth line\n")
requireRead(expected)
require.Equal(t, ".q", h.escaped())
// ~. to be skipped; stop at \n~
pw.Write([]byte("~.fifth line\n~"))
expected.WriteString("fifth line\n")
requireRead(expected)
require.Equal(t, ".q.", h.escaped())
// ~ alone after \n~ - should be emitted
pw.Write([]byte("~"))
expected.WriteString("~")
requireRead(expected)
require.Equal(t, ".q.", h.escaped())
// rest of line ending with \n~
pw.Write([]byte("rest of line\n~"))
expected.WriteString("rest of line\n")
requireRead(expected)
require.Equal(t, ".q.", h.escaped())
// m alone after \n~ - should be emitted with ~
pw.Write([]byte("m"))
expected.WriteString("~m")
requireRead(expected)
require.Equal(t, ".q.m", h.escaped())
// rest of line and end with \n
pw.Write([]byte("onemore line\n"))
expected.WriteString("onemore line\n")
requireRead(expected)
require.Equal(t, ".q.m", h.escaped())
// ~q to be emitted stop at \n~; last charcater
pw.Write([]byte("~qlast line\n~"))
expected.WriteString("~qlast line\n")
requireRead(expected)
require.Equal(t, ".q.mq", h.escaped())
// last ~ gets emitted and we preserve error
eerr := errors.New("my custom error")
pw.CloseWithError(eerr)
expected.WriteString("~")
requireRead(expected)
require.Equal(t, ".q.mq", h.escaped())
wg.Wait()
require.Error(t, rerr)
require.Equal(t, eerr, rerr)
}
func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
f := func(v readingInput) bool {
return checkEquivalenceToNaive(t, string(v))
@ -95,40 +251,52 @@ func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
}))
}
// testHandler returns a handler that stores all basic ascii letters in result
// reference. We avoid complicated unicode characters that may cross
// byte boundary
func testHandler(result *string) Handler {
return func(c byte) bool {
rc := rune(c)
simple := unicode.IsLetter(rc) ||
unicode.IsDigit(rc) ||
unicode.IsPunct(rc) ||
unicode.IsSymbol(rc)
// testHandler is a conveneient struct for finding "escaped" ascii letters
// in escaping reader.
// We avoid complicated unicode characters that may cross byte boundary
type testHandler struct {
l sync.Mutex
result string
}
if simple {
*result += string([]byte{c})
}
return c == '.'
// handler is method to be passed to escaping io reader
func (t *testHandler) handler(c byte) bool {
rc := rune(c)
simple := unicode.IsLetter(rc) ||
unicode.IsDigit(rc) ||
unicode.IsPunct(rc) ||
unicode.IsSymbol(rc)
if simple {
t.l.Lock()
t.result += string([]byte{c})
t.l.Unlock()
}
return c == '.'
}
// escaped returns all seen escaped characters so far
func (t *testHandler) escaped() string {
t.l.Lock()
defer t.l.Unlock()
return t.result
}
// checkEquivalence returns true if parsing input with naive implementation
// is equivalent to our reader
func checkEquivalenceToNaive(t *testing.T, input string) bool {
nfe := ""
nh := testHandler(&nfe)
expected := naiveEscapeCharacters(input, '~', nh)
nh := &testHandler{}
expected := naiveEscapeCharacters(input, '~', nh.handler)
foundEscaped := ""
h := testHandler(&foundEscaped)
foundH := &testHandler{}
var inputReader io.Reader = bytes.NewBufferString(input)
inputReader = &arbtiraryReader{
buf: inputReader.(*bytes.Buffer),
maxReadOnce: 10,
}
filter := NewReader(inputReader, '~', h)
filter := NewReader(inputReader, '~', foundH.handler)
var found bytes.Buffer
_, err := io.Copy(&found, filter)
if err != nil {
@ -136,11 +304,11 @@ func checkEquivalenceToNaive(t *testing.T, input string) bool {
return false
}
if nfe == foundEscaped && expected == found.String() {
if nh.escaped() == foundH.escaped() && expected == found.String() {
return true
}
t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped)
t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
t.Logf("read differed=%v expected=%s found=%v", expected != found.String(), expected, found.String())
return false
@ -159,15 +327,13 @@ func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) {
// checkEquivalenceToReadOnce returns true if parsing input in a single
// read matches multiple reads
func checkEquivalenceToReadOnce(t *testing.T, input string) bool {
nfe := ""
nh := &testHandler{}
var expected bytes.Buffer
// getting expected value from read all at once
{
h := testHandler(&nfe)
buf := make([]byte, len(input)+5)
inputReader := NewReader(bytes.NewBufferString(input), '~', h)
inputReader := NewReader(bytes.NewBufferString(input), '~', nh.handler)
_, err := io.CopyBuffer(&expected, inputReader, buf)
if err != nil {
t.Logf("unexpected error while reading: %v", err)
@ -175,18 +341,16 @@ func checkEquivalenceToReadOnce(t *testing.T, input string) bool {
}
}
foundEscaped := ""
foundH := &testHandler{}
var found bytes.Buffer
// getting found by using arbitrary reader
{
h := testHandler(&foundEscaped)
inputReader := &arbtiraryReader{
buf: bytes.NewBufferString(input),
maxReadOnce: 10,
}
filter := NewReader(inputReader, '~', h)
filter := NewReader(inputReader, '~', foundH.handler)
_, err := io.Copy(&found, filter)
if err != nil {
t.Logf("unexpected error while reading: %v", err)
@ -194,11 +358,11 @@ func checkEquivalenceToReadOnce(t *testing.T, input string) bool {
}
}
if nfe == foundEscaped && expected.String() == found.String() {
if nh.escaped() == foundH.escaped() && expected.String() == found.String() {
return true
}
t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped)
t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped())
t.Logf("read differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String())
return false
@ -309,3 +473,21 @@ func (r *arbtiraryReader) Read(buf []byte) (int, error) {
return r.buf.Read(buf[:l])
}
func assertEventually(t *testing.T, testFn func() (bool, error)) {
start := time.Now()
var err error
var b bool
for {
if time.Since(start) > 2*time.Second {
assert.Fail(t, "timed out", "error: %v", err)
}
b, err = testFn()
if b {
return
}
time.Sleep(50 * time.Millisecond)
}
}