Add a escaping reader that mimics ssh behavior
Adds an escaping reading that mimics ssh handling of input escape sequences. The reader parses chunks to look for \n~
This commit is contained in:
parent
84cc5ddd57
commit
b02852ef62
|
@ -0,0 +1,163 @@
|
|||
package escapingio
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// Handler is a callback for handling an escaped char. Reader would skip
|
||||
// the escape char and passed char if returns true; otherwise, it preserves them
|
||||
// in output
|
||||
type Handler func(c byte) bool
|
||||
|
||||
// NewReader returns a reader that escapes the c character (following new lines),
|
||||
// in the same manner OpenSSH handling, which defaults to `~`.
|
||||
//
|
||||
// For illustrative purposes, we use `~` in documentation as a shorthand for escaping character.
|
||||
//
|
||||
// If following a new line, reader sees:
|
||||
// * `~~`, only one is emitted
|
||||
// * `~.` (or any character), the handler is invoked with the character.
|
||||
// If handler returns true, `~.` will be skipped; otherwise, it's propagated.
|
||||
// * `~` and it's the last character in stream, it's propagated
|
||||
//
|
||||
// Appearances of `~` when not followed by a new line is propagated unmodified.
|
||||
func NewReader(r io.Reader, c byte, h Handler) io.Reader {
|
||||
return &reader{
|
||||
impl: r,
|
||||
escapeChar: c,
|
||||
state: sLookEscapeChar,
|
||||
handler: h,
|
||||
}
|
||||
}
|
||||
|
||||
// lookState represents the state of reader for what character of `\n~.` sequence
|
||||
// reader is looking for
|
||||
type lookState int
|
||||
|
||||
const (
|
||||
// sLookNewLine indicates that reader is looking for new line
|
||||
sLookNewLine lookState = iota
|
||||
|
||||
// sLookEscapeChar indicates that reader is looking for ~
|
||||
sLookEscapeChar
|
||||
|
||||
// sLookChar indicates that reader just read `~` is waiting for next character
|
||||
// before acting
|
||||
sLookChar
|
||||
)
|
||||
|
||||
// to ease comments, i'll assume escape character to be `~`
|
||||
type reader struct {
|
||||
impl io.Reader
|
||||
escapeChar uint8
|
||||
handler Handler
|
||||
|
||||
state lookState
|
||||
|
||||
// unread is a buffered character for next read if not-nil
|
||||
unread *byte
|
||||
}
|
||||
|
||||
func (r *reader) Read(buf []byte) (int, error) {
|
||||
START:
|
||||
var n int
|
||||
var err error
|
||||
|
||||
if r.unread != nil {
|
||||
// try to return the unread character immediately
|
||||
// without trying to block for another read
|
||||
buf[0] = *r.unread
|
||||
n = 1
|
||||
r.unread = nil
|
||||
} else {
|
||||
n, err = r.impl.Read(buf)
|
||||
}
|
||||
|
||||
// when we get to the end, check if we have any unprocessed \n~
|
||||
if n == 0 && err != nil {
|
||||
if r.state == sLookChar && err != nil {
|
||||
buf[0] = r.escapeChar
|
||||
n = 1
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// inspect the state at beginning of read
|
||||
if r.state == sLookChar {
|
||||
r.state = sLookNewLine
|
||||
|
||||
// escape character hasn't been emitted yet
|
||||
if buf[0] == r.escapeChar {
|
||||
// earlier ~ was sallowed already, so leave this as is
|
||||
} else if handled := r.handler(buf[0]); handled {
|
||||
// need to drop a single letter
|
||||
copy(buf, buf[1:n])
|
||||
n--
|
||||
} else {
|
||||
// we need to re-introduce ~ with rest of body
|
||||
// but be mindful if reintroducing ~ causes buffer to overflow
|
||||
if n == len(buf) {
|
||||
// in which case, save it for next read
|
||||
c := buf[n-1]
|
||||
r.unread = &c
|
||||
copy(buf[1:], buf[:n])
|
||||
buf[0] = r.escapeChar
|
||||
} else {
|
||||
copy(buf[1:], buf[:n])
|
||||
buf[0] = r.escapeChar
|
||||
n++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n = r.processBuffer(buf, n)
|
||||
if n == 0 && err == nil {
|
||||
goto START
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// handles escaped character inside body of read buf.
|
||||
func (r *reader) processBuffer(buf []byte, read int) int {
|
||||
b := 0
|
||||
|
||||
for b < read {
|
||||
|
||||
c := buf[b]
|
||||
if r.state == sLookEscapeChar && r.escapeChar == c {
|
||||
r.state = sLookEscapeChar
|
||||
|
||||
// are we at the end of read; wait for next read
|
||||
if b == read-1 {
|
||||
read--
|
||||
r.state = sLookChar
|
||||
return read
|
||||
}
|
||||
|
||||
// otherwise peek at next
|
||||
nc := buf[b+1]
|
||||
if nc == r.escapeChar {
|
||||
// repeated ~, only emit one - skip one character
|
||||
copy(buf[b:], buf[b+1:read])
|
||||
read--
|
||||
b++
|
||||
continue
|
||||
} else if handled := r.handler(nc); handled {
|
||||
// need to drop both ~ and letter
|
||||
copy(buf[b:], buf[b+2:read])
|
||||
read -= 2
|
||||
continue
|
||||
} else {
|
||||
// need to pass output unmodified with ~ and letter
|
||||
}
|
||||
} else if c == '\n' || c == '\r' {
|
||||
r.state = sLookEscapeChar
|
||||
} else {
|
||||
r.state = sLookNewLine
|
||||
}
|
||||
b++
|
||||
}
|
||||
|
||||
return read
|
||||
}
|
|
@ -0,0 +1,319 @@
|
|||
package escapingio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
"testing/quick"
|
||||
"unicode"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEscapingReader_Static(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected string
|
||||
escaped string
|
||||
}{
|
||||
{"hello", "hello", ""},
|
||||
{"he\nllo", "he\nllo", ""},
|
||||
{"he~.lo", "he~.lo", ""},
|
||||
{"he\n~.rest", "he\nrest", "."},
|
||||
{"he\n~.r\n~.est", "he\nr\nest", ".."},
|
||||
{"he\n~~r\n~~est", "he\n~r\n~est", ""},
|
||||
{"he\n~~r\n~.est", "he\n~r\nest", "."},
|
||||
{"he\nr~~est", "he\nr~~est", ""},
|
||||
{"he\nr\n~qest", "he\nr\n~qest", "q"},
|
||||
{"he\nr\r~qe\r~.st", "he\nr\r~qe\rst", "q."},
|
||||
{"~q", "~q", "q"},
|
||||
{"~.", "", "."},
|
||||
{"m~.", "m~.", ""},
|
||||
{"\n~.", "\n", "."},
|
||||
{"~", "~", ""},
|
||||
{"\r~.", "\r", "."},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("sanity check naive implementation", func(t *testing.T) {
|
||||
foundEscaped := ""
|
||||
h := testHandler(&foundEscaped)
|
||||
|
||||
processed := naiveEscapeCharacters(c.input, '~', h)
|
||||
require.Equal(t, c.expected, processed)
|
||||
require.Equal(t, c.escaped, foundEscaped)
|
||||
})
|
||||
|
||||
t.Run("chunks at a time: "+c.input, func(t *testing.T) {
|
||||
var found bytes.Buffer
|
||||
|
||||
input := strings.NewReader(c.input)
|
||||
|
||||
foundEscaped := ""
|
||||
h := testHandler(&foundEscaped)
|
||||
|
||||
filter := NewReader(input, '~', h)
|
||||
|
||||
_, err := io.Copy(&found, filter)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, c.expected, found.String())
|
||||
require.Equal(t, c.escaped, foundEscaped)
|
||||
})
|
||||
|
||||
t.Run("1 byte at a time: "+c.input, func(t *testing.T) {
|
||||
var found bytes.Buffer
|
||||
|
||||
input := iotest.OneByteReader(strings.NewReader(c.input))
|
||||
|
||||
foundEscaped := ""
|
||||
h := testHandler(&foundEscaped)
|
||||
|
||||
filter := NewReader(input, '~', h)
|
||||
_, err := io.Copy(&found, filter)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, c.expected, found.String())
|
||||
require.Equal(t, c.escaped, foundEscaped)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) {
|
||||
called := 0
|
||||
f := func(v readingInput) bool {
|
||||
called++
|
||||
return checkEquivalenceToNaive(t, string(v))
|
||||
}
|
||||
|
||||
require.NoError(t, quick.Check(f, &quick.Config{
|
||||
MaxCountScale: 200,
|
||||
}))
|
||||
|
||||
fmt.Println("CALLED ", called)
|
||||
}
|
||||
|
||||
// testHandler returns a handler that stores all basic ascii letters in result
|
||||
// reference. We avoid complicated unicode characters that may cross
|
||||
// byte boundary
|
||||
func testHandler(result *string) Handler {
|
||||
return func(c byte) bool {
|
||||
rc := rune(c)
|
||||
simple := unicode.IsLetter(rc) ||
|
||||
unicode.IsDigit(rc) ||
|
||||
unicode.IsPunct(rc) ||
|
||||
unicode.IsSymbol(rc)
|
||||
|
||||
if simple {
|
||||
*result += string([]byte{c})
|
||||
}
|
||||
return c == '.'
|
||||
}
|
||||
}
|
||||
|
||||
// checkEquivalence returns true if parsing input with naive implementation
|
||||
// is equivalent to our reader
|
||||
func checkEquivalenceToNaive(t *testing.T, input string) bool {
|
||||
nfe := ""
|
||||
nh := testHandler(&nfe)
|
||||
expected := naiveEscapeCharacters(input, '~', nh)
|
||||
|
||||
foundEscaped := ""
|
||||
h := testHandler(&foundEscaped)
|
||||
|
||||
var inputReader io.Reader = bytes.NewBufferString(input)
|
||||
inputReader = &arbtiraryReader{
|
||||
buf: inputReader.(*bytes.Buffer),
|
||||
maxReadOnce: 10,
|
||||
}
|
||||
filter := NewReader(inputReader, '~', h)
|
||||
var found bytes.Buffer
|
||||
_, err := io.Copy(&found, filter)
|
||||
if err != nil {
|
||||
t.Logf("unexpected error while reading: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if nfe == foundEscaped && expected == found.String() {
|
||||
return true
|
||||
}
|
||||
|
||||
t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped)
|
||||
t.Logf("read differed=%v expected=%s found=%v", expected != found.String(), expected, found.String())
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) {
|
||||
called := 0
|
||||
f := func(v readingInput) bool {
|
||||
called++
|
||||
return checkEquivalenceToNaive(t, string(v))
|
||||
}
|
||||
|
||||
require.NoError(t, quick.Check(f, &quick.Config{
|
||||
MaxCountScale: 200,
|
||||
}))
|
||||
|
||||
fmt.Println("CALLED ", called)
|
||||
}
|
||||
|
||||
// checkEquivalenceToReadOnce returns true if parsing input in a single
|
||||
// read matches multiple reads
|
||||
func checkEquivalenceToReadOnce(t *testing.T, input string) bool {
|
||||
nfe := ""
|
||||
var expected bytes.Buffer
|
||||
|
||||
// getting expected value from read all at once
|
||||
{
|
||||
h := testHandler(&nfe)
|
||||
|
||||
buf := make([]byte, len(input)+5)
|
||||
inputReader := NewReader(bytes.NewBufferString(input), '~', h)
|
||||
_, err := io.CopyBuffer(&expected, inputReader, buf)
|
||||
if err != nil {
|
||||
t.Logf("unexpected error while reading: %v", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
foundEscaped := ""
|
||||
var found bytes.Buffer
|
||||
|
||||
// getting found by using arbitrary reader
|
||||
{
|
||||
h := testHandler(&foundEscaped)
|
||||
|
||||
inputReader := &arbtiraryReader{
|
||||
buf: bytes.NewBufferString(input),
|
||||
maxReadOnce: 10,
|
||||
}
|
||||
filter := NewReader(inputReader, '~', h)
|
||||
_, err := io.Copy(&found, filter)
|
||||
if err != nil {
|
||||
t.Logf("unexpected error while reading: %v", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if nfe == foundEscaped && expected.String() == found.String() {
|
||||
return true
|
||||
}
|
||||
|
||||
t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped)
|
||||
t.Logf("read differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String())
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
// readingInput is a string with some quick generation capability to
|
||||
// inject some \n, \n~., \n~q in text
|
||||
type readingInput string
|
||||
|
||||
func (i readingInput) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
v, ok := quick.Value(reflect.TypeOf(""), rand)
|
||||
if !ok {
|
||||
panic("couldn't generate a string")
|
||||
}
|
||||
|
||||
// inject some terminals
|
||||
var b bytes.Buffer
|
||||
injectProbabilistically := func() {
|
||||
p := rand.Float32()
|
||||
if p < 0.05 {
|
||||
b.WriteString("\n~.")
|
||||
} else if p < 0.10 {
|
||||
b.WriteString("\n~q")
|
||||
} else if p < 0.15 {
|
||||
b.WriteString("\n")
|
||||
} else if p < 0.2 {
|
||||
b.WriteString("~")
|
||||
} else if p < 0.25 {
|
||||
b.WriteString("~~")
|
||||
}
|
||||
}
|
||||
|
||||
for _, c := range v.String() {
|
||||
injectProbabilistically()
|
||||
b.WriteRune(c)
|
||||
}
|
||||
|
||||
injectProbabilistically()
|
||||
|
||||
return reflect.ValueOf(readingInput(b.String()))
|
||||
}
|
||||
|
||||
// naiveEscapeCharacters is a simplified implementation that operates
|
||||
// on entire unchunked string. Uses regexp implementation.
|
||||
//
|
||||
// It differs from the other implementation in handling unicode characters
|
||||
// proceeding `\n~`
|
||||
func naiveEscapeCharacters(input string, escapeChar byte, h Handler) string {
|
||||
reg := regexp.MustCompile(fmt.Sprintf("(\n|\r)%c.", escapeChar))
|
||||
|
||||
// check first appearances
|
||||
if len(input) > 1 && input[0] == escapeChar {
|
||||
if input[1] == escapeChar {
|
||||
input = input[1:]
|
||||
} else if h(input[1]) {
|
||||
input = input[2:]
|
||||
} else {
|
||||
// we are good
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return reg.ReplaceAllStringFunc(input, func(match string) string {
|
||||
if len(match) != 3 {
|
||||
panic(fmt.Errorf("match isn't 3 characters: %s", match))
|
||||
}
|
||||
|
||||
c := match[2]
|
||||
|
||||
// ignore some unicode partial codes
|
||||
ltr := ('a' <= c && c <= 'z') ||
|
||||
('A' <= c && c <= 'Z') ||
|
||||
('0' <= c && c <= '9') ||
|
||||
(c == '~' || c == '.' || c == escapeChar)
|
||||
|
||||
if c == escapeChar {
|
||||
return match[:2]
|
||||
} else if ltr && h(c) {
|
||||
return match[:1]
|
||||
} else {
|
||||
return match
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// arbitraryReader is a reader that reads arbitrary length at a time
|
||||
// to simulate input being read in chunks.
|
||||
type arbtiraryReader struct {
|
||||
buf *bytes.Buffer
|
||||
maxReadOnce int
|
||||
}
|
||||
|
||||
func (r *arbtiraryReader) Read(buf []byte) (int, error) {
|
||||
l := r.buf.Len()
|
||||
if l == 0 || l == 1 {
|
||||
return r.buf.Read(buf)
|
||||
}
|
||||
|
||||
if l > r.maxReadOnce {
|
||||
l = r.maxReadOnce
|
||||
}
|
||||
if l != 1 {
|
||||
l = rand.Intn(l-1) + 1
|
||||
}
|
||||
if l > len(buf) {
|
||||
l = len(buf)
|
||||
}
|
||||
|
||||
return r.buf.Read(buf[:l])
|
||||
}
|
Loading…
Reference in New Issue