diff --git a/lib/eof.go b/lib/eof.go index f77844fd6..e29d15110 100644 --- a/lib/eof.go +++ b/lib/eof.go @@ -1,7 +1,10 @@ package lib import ( + "errors" + "fmt" "io" + "net/rpc" "strings" "github.com/hashicorp/yamux" @@ -13,7 +16,7 @@ var yamuxSessionShutdown = yamux.ErrSessionShutdown.Error() // IsErrEOF returns true if we get an EOF error from the socket itself, or // an EOF equivalent error from yamux. func IsErrEOF(err error) bool { - if err == io.EOF { + if errors.Is(err, io.EOF) { return true } @@ -23,5 +26,10 @@ func IsErrEOF(err error) bool { return true } + var serverError rpc.ServerError + if errors.As(err, &serverError) { + return strings.HasSuffix(err.Error(), fmt.Sprintf(": %s", io.EOF.Error())) + } + return false } diff --git a/lib/eof_test.go b/lib/eof_test.go new file mode 100644 index 000000000..38106ae99 --- /dev/null +++ b/lib/eof_test.go @@ -0,0 +1,31 @@ +package lib + +import ( + "fmt" + "io" + "net/rpc" + "testing" + + "github.com/hashicorp/yamux" + "github.com/stretchr/testify/require" +) + +func TestErrIsEOF(t *testing.T) { + var tests = []struct { + name string + err error + }{ + {name: "EOF", err: io.EOF}, + {name: "Wrapped EOF", err: fmt.Errorf("test: %w", io.EOF)}, + {name: "yamuxStreamClosed", err: yamux.ErrStreamClosed}, + {name: "yamuxSessionShutdown", err: yamux.ErrSessionShutdown}, + {name: "ServerError(___: EOF)", err: rpc.ServerError(fmt.Sprintf("rpc error: %s", io.EOF.Error()))}, + {name: "Wrapped ServerError(___: EOF)", err: fmt.Errorf("rpc error: %w", rpc.ServerError(fmt.Sprintf("rpc error: %s", io.EOF.Error())))}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.True(t, IsErrEOF(tt.err)) + }) + } +}