open-nomad/helper/gated-writer/writer_test.go
2023-04-10 15:36:59 +00:00

130 lines
2.1 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package gatedwriter
import (
"bytes"
"io"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestWriter_impl asserts at compile time that *Writer satisfies io.Writer;
// it performs no runtime checks.
func TestWriter_impl(t *testing.T) {
	var _ io.Writer = (*Writer)(nil)
}
// slowTestWriter is an io.Writer that sleeps for a millisecond on every Write,
// giving concurrent callers a window in which to interleave with it. It closes
// called once its first Write has finished.
type slowTestWriter struct {
	buf       *bytes.Buffer // destination for all written bytes
	called    chan struct{} // closed when the first Write returns
	callCount int           // number of Write calls so far (unsynchronized)
}

// Write counts the call, sleeps briefly, and appends p to buf, returning
// buf's result. The first call closes called on its way out, after the data
// has been written.
func (s *slowTestWriter) Write(p []byte) (int, error) {
	s.callCount++
	if s.callCount == 1 {
		// Deferred so the signal fires only after the write below completes.
		defer close(s.called)
	}
	time.Sleep(time.Millisecond)
	return s.buf.Write(p)
}
// TestWriter_WithSlowWriter verifies that a Write issued while Flush is still
// draining buffered data into a slow underlying writer is not lost and lands
// after all of the previously buffered data.
func TestWriter_WithSlowWriter(t *testing.T) {
	// Bound every wait so a deadlock in Flush or Write fails the test quickly
	// instead of hanging until the global `go test` timeout.
	const waitTimeout = 5 * time.Second

	buf := new(bytes.Buffer)
	called := make(chan struct{})
	w := &slowTestWriter{
		buf:    buf,
		called: called,
	}
	writer := &Writer{Writer: w}

	for _, s := range []string{"foo\n", "bar\n", "baz\n"} {
		_, err := writer.Write([]byte(s))
		require.NoError(t, err)
	}

	flushed := make(chan struct{})
	go func() {
		writer.Flush()
		close(flushed)
	}()

	// wait for the flush to call Write on slowTestWriter
	select {
	case <-called:
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for Flush to write to the underlying writer")
	}

	// Write while Flush may still be draining the buffer; this data must not
	// be lost and must follow everything that was buffered earlier.
	_, err := writer.Write([]byte("quux\n"))
	require.NoError(t, err)

	// wait for the flush to finish to assert
	select {
	case <-flushed:
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for Flush to return")
	}

	require.Equal(t, "foo\nbar\nbaz\nquux\n", buf.String())
}
// TestWriter verifies that writes are buffered until Flush, that Flush emits
// the buffered data in order, and that writes after Flush pass straight
// through to the underlying writer.
func TestWriter(t *testing.T) {
	buf := new(bytes.Buffer)
	w := &Writer{Writer: buf}

	// write performs one Write and checks it against the io.Writer contract:
	// a nil error implies the whole slice was accepted.
	write := func(s string) {
		t.Helper()
		n, err := w.Write([]byte(s))
		require.NoError(t, err)
		require.Equal(t, len(s), n)
	}

	write("foo\n")
	write("bar\n")
	// Nothing reaches the underlying writer before Flush.
	require.Empty(t, buf.String())

	w.Flush()
	require.Equal(t, "foo\nbar\n", buf.String())

	// After Flush, writes are no longer buffered.
	write("baz\n")
	require.Equal(t, "foo\nbar\nbaz\n", buf.String())
}
// TestWriter_WithMultipleWriters verifies that concurrent writes made before
// Flush are all buffered and that Flush emits each of them exactly once.
func TestWriter_WithMultipleWriters(t *testing.T) {
	buf := new(bytes.Buffer)
	writer := &Writer{Writer: buf}

	strs := []string{
		"foo\n",
		"bar\n",
		"baz\n",
		"quux\n",
	}

	waitCh := make(chan struct{})
	wg := &sync.WaitGroup{}
	for _, str := range strs {
		str := str // per-iteration copy for the goroutine (needed before Go 1.22)
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-waitCh
			// require (FailNow) must not be called off the test goroutine,
			// so report failures with t.Errorf instead.
			if _, err := writer.Write([]byte(str)); err != nil {
				t.Errorf("write %q: %v", str, err)
			}
		}()
	}

	// synchronize calls to Write() as closely as possible
	close(waitCh)
	wg.Wait()

	// Every write happened before Flush, so nothing may have been emitted yet.
	require.Empty(t, buf.String())

	writer.Flush()

	// Goroutine scheduling makes the order nondeterministic, but each string
	// must appear exactly once and intact. SplitAfter leaves an empty final
	// element after the trailing newline.
	lines := strings.SplitAfter(buf.String(), "\n")
	require.Equal(t, "", lines[len(lines)-1], "output should end with a newline")
	require.ElementsMatch(t, strs, lines[:len(lines)-1])
}