475 lines
13 KiB
Go
475 lines
13 KiB
Go
package proto
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/inconshreveable/muxado/proto/frame"
|
|
"io"
|
|
"net"
|
|
"reflect"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
const (
	// defaultWindowSize is the initial flow-control window (64KB) given to
	// every stream the session creates.
	defaultWindowSize = 0x10000

	// defaultAcceptQueueDepth is how many remotely-opened streams may sit in
	// the accept queue before the reader goroutine blocks handing off the next.
	defaultAcceptQueueDepth = 100

	// MinExtensionType is the smallest stream type reserved for extensions.
	// Stream types in [MinExtensionType, 0xFFFFFFFF] — the top 0x100 values
	// plus 0xFFFFFFFF itself, 257 in total — are routed to registered
	// extensions instead of Accept.
	MinExtensionType = 0xFFFFFFFF - 0x100
)
|
|
|
|
// private interface for Sessions to call Streams
|
|
type stream interface {
|
|
IStream
|
|
handleStreamData(*frame.RStreamData)
|
|
handleStreamWndInc(*frame.RStreamWndInc)
|
|
handleStreamRst(*frame.RStreamRst)
|
|
closeWith(error)
|
|
}
|
|
|
|
// for extensions
|
|
type ExtAccept func() (IStream, error)
|
|
type Extension interface {
|
|
Start(ISession, ExtAccept) frame.StreamType
|
|
}
|
|
|
|
type deadReason struct {
|
|
errorCode frame.ErrorCode
|
|
err error
|
|
remoteDebug []byte
|
|
}
|
|
|
|
// factory function that creates new streams
|
|
type streamFactory func(id frame.StreamId, priority frame.StreamPriority, streamType frame.StreamType, finLocal bool, finRemote bool, windowSize uint32, sess session) stream
|
|
|
|
// checks the parity of a stream id (local vs remote, client vs server)
|
|
type parityFn func(frame.StreamId) bool
|
|
|
|
// halfState tracks one side of the session (this endpoint's or the peer's).
// Both fields are read and written with sync/atomic.
type halfState struct {
	goneAway int32  // 1 once this side has sent (local) or we received (remote) a GoAway
	lastId   uint32 // highest stream id allocated (local) or seen (remote) for this side
}
|
|
|
|
// Session implements a simple streaming session manager. It has the following characteristics:
|
|
//
|
|
// - When closing the Session, it does not linger, all pending write operations will fail immediately.
|
|
// - It completely ignores stream priority when processing and writing frames
|
|
// - It offers no customization of settings like window size/ping time
|
|
type Session struct {
|
|
conn net.Conn // connection the transport is running over
|
|
transport frame.Transport // transport
|
|
streams StreamMap // all active streams
|
|
local halfState // client state
|
|
remote halfState // server state
|
|
syn *frame.WStreamSyn // STREAM_SYN frame for opens
|
|
wr sync.Mutex // synchronization when writing frames
|
|
accept chan stream // new streams opened by the remote
|
|
diebit int32 // true if we're dying
|
|
remoteDebug []byte // debugging data sent in the remote's GoAway frame
|
|
defaultWindowSize uint32 // window size when creating new streams
|
|
newStream streamFactory // factory function to make new streams
|
|
dead chan deadReason // dead
|
|
isLocal parityFn // determines if a stream id is local or remote
|
|
exts map[frame.StreamType]chan stream // map of extension stream type -> accept channel for the extension
|
|
}
|
|
|
|
func NewSession(conn net.Conn, newStream streamFactory, isClient bool, exts []Extension) ISession {
|
|
sess := &Session{
|
|
conn: conn,
|
|
transport: frame.NewBasicTransport(conn),
|
|
streams: NewConcurrentStreamMap(),
|
|
local: halfState{lastId: 0},
|
|
remote: halfState{lastId: 0},
|
|
syn: frame.NewWStreamSyn(),
|
|
diebit: 0,
|
|
defaultWindowSize: defaultWindowSize,
|
|
accept: make(chan stream, defaultAcceptQueueDepth),
|
|
newStream: newStream,
|
|
dead: make(chan deadReason, 1), // don't block die() if there is no Wait call
|
|
exts: make(map[frame.StreamType]chan stream),
|
|
}
|
|
|
|
if isClient {
|
|
sess.isLocal = sess.isClient
|
|
sess.local.lastId += 1
|
|
} else {
|
|
sess.isLocal = sess.isServer
|
|
sess.remote.lastId += 1
|
|
}
|
|
|
|
for _, ext := range exts {
|
|
sess.startExtension(ext)
|
|
}
|
|
|
|
go sess.reader()
|
|
|
|
return sess
|
|
}
|
|
|
|
////////////////////////////////
|
|
// public interface
|
|
////////////////////////////////
|
|
|
|
func (s *Session) Open() (IStream, error) {
|
|
return s.OpenStream(0, 0, false)
|
|
}
|
|
|
|
func (s *Session) OpenStream(priority frame.StreamPriority, streamType frame.StreamType, fin bool) (ret IStream, err error) {
|
|
// check if the remote has gone away
|
|
if atomic.LoadInt32(&s.remote.goneAway) == 1 {
|
|
return nil, fmt.Errorf("Failed to create stream, remote has gone away.")
|
|
}
|
|
|
|
// this lock prevents the following race:
|
|
// goroutine1 goroutine2
|
|
// - inc stream id
|
|
// - inc stream id
|
|
// - send streamsyn
|
|
// - send streamsyn
|
|
s.wr.Lock()
|
|
|
|
// get the next id we can use
|
|
nextId := frame.StreamId(atomic.AddUint32(&s.local.lastId, 2))
|
|
|
|
// make the stream
|
|
str := s.newStream(nextId, priority, streamType, fin, false, s.defaultWindowSize, s)
|
|
|
|
// add to to the stream map
|
|
s.streams.Set(nextId, str)
|
|
|
|
// write the frame
|
|
if err = s.syn.Set(nextId, priority, streamType, fin); err != nil {
|
|
s.wr.Unlock()
|
|
s.die(frame.InternalError, err)
|
|
return
|
|
}
|
|
|
|
if err = s.transport.WriteFrame(s.syn); err != nil {
|
|
s.wr.Unlock()
|
|
s.die(frame.InternalError, err)
|
|
return
|
|
}
|
|
|
|
s.wr.Unlock()
|
|
return str, nil
|
|
}
|
|
|
|
func (s *Session) Accept() (str IStream, err error) {
|
|
var ok bool
|
|
if str, ok = <-s.accept; !ok {
|
|
return nil, fmt.Errorf("Session closed")
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (s *Session) Kill() error {
|
|
return s.transport.Close()
|
|
}
|
|
|
|
func (s *Session) Close() error {
|
|
return s.die(frame.NoError, fmt.Errorf("Session Close()"))
|
|
}
|
|
|
|
func (s *Session) GoAway(errorCode frame.ErrorCode, debug []byte) (err error) {
|
|
if !atomic.CompareAndSwapInt32(&s.local.goneAway, 0, 1) {
|
|
return fmt.Errorf("Already sent GoAway!")
|
|
}
|
|
|
|
s.wr.Lock()
|
|
f := frame.NewWGoAway()
|
|
remoteId := frame.StreamId(atomic.LoadUint32(&s.remote.lastId))
|
|
if err = f.Set(remoteId, errorCode, debug); err != nil {
|
|
s.wr.Unlock()
|
|
s.die(frame.InternalError, err)
|
|
return
|
|
}
|
|
|
|
if err = s.transport.WriteFrame(f); err != nil {
|
|
s.wr.Unlock()
|
|
s.die(frame.InternalError, err)
|
|
return
|
|
}
|
|
|
|
s.wr.Unlock()
|
|
return
|
|
}
|
|
|
|
func (s *Session) LocalAddr() net.Addr {
|
|
return s.conn.LocalAddr()
|
|
}
|
|
|
|
func (s *Session) RemoteAddr() net.Addr {
|
|
return s.conn.RemoteAddr()
|
|
}
|
|
|
|
func (s *Session) Wait() (frame.ErrorCode, error, []byte) {
|
|
reason := <-s.dead
|
|
return reason.errorCode, reason.err, reason.remoteDebug
|
|
}
|
|
|
|
////////////////////////////////
|
|
// private interface for streams
|
|
////////////////////////////////
|
|
|
|
// removeStream removes a stream from this session's stream registry
|
|
//
|
|
// It does not error if the stream is not present
|
|
func (s *Session) removeStream(id frame.StreamId) {
|
|
s.streams.Delete(id)
|
|
return
|
|
}
|
|
|
|
// writeFrame writes the given frame to the transport and returns the error from the write operation
|
|
func (s *Session) writeFrame(f frame.WFrame, dl time.Time) (err error) {
|
|
s.wr.Lock()
|
|
s.conn.SetWriteDeadline(dl)
|
|
err = s.transport.WriteFrame(f)
|
|
s.wr.Unlock()
|
|
return
|
|
}
|
|
|
|
// die closes the session cleanly with the given error and protocol error code
|
|
func (s *Session) die(errorCode frame.ErrorCode, err error) error {
|
|
// only one shutdown ever happens
|
|
if !atomic.CompareAndSwapInt32(&s.diebit, 0, 1) {
|
|
return fmt.Errorf("Shutdown already in progress")
|
|
}
|
|
|
|
// send a go away frame
|
|
s.GoAway(errorCode, []byte(err.Error()))
|
|
|
|
// now we're safe to stop accepting incoming connections
|
|
close(s.accept)
|
|
|
|
// we cleaned up as best as possible, close the transport
|
|
s.transport.Close()
|
|
|
|
// notify all of the streams that we're closing
|
|
s.streams.Each(func(id frame.StreamId, str stream) {
|
|
str.closeWith(fmt.Errorf("Session closed"))
|
|
})
|
|
|
|
s.dead <- deadReason{errorCode, err, s.remoteDebug}
|
|
|
|
return nil
|
|
}
|
|
|
|
////////////////////////////////
|
|
// internal methods
|
|
////////////////////////////////
|
|
|
|
// reader() reads frames from the underlying transport and handles passes them to handleFrame
|
|
func (s *Session) reader() {
|
|
defer s.recoverPanic("reader()")
|
|
|
|
// close all of the extension accept channels when we're done
|
|
// we do this here instead of in die() since otherwise it wouldn't
|
|
// be safe to access s.exts
|
|
defer func() {
|
|
for _, extAccept := range s.exts {
|
|
close(extAccept)
|
|
}
|
|
}()
|
|
|
|
for {
|
|
f, err := s.transport.ReadFrame()
|
|
if err != nil {
|
|
// if we fail to read a frame, terminate the session
|
|
_, ok := err.(*frame.FramingError)
|
|
if ok {
|
|
s.die(frame.ProtocolError, err)
|
|
} else {
|
|
s.die(frame.InternalError, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
s.handleFrame(f)
|
|
}
|
|
}
|
|
|
|
func (s *Session) handleFrame(rf frame.RFrame) {
|
|
switch f := rf.(type) {
|
|
case *frame.RStreamSyn:
|
|
// if we're going away, refuse new streams
|
|
if atomic.LoadInt32(&s.local.goneAway) == 1 {
|
|
rstF := frame.NewWStreamRst()
|
|
rstF.Set(f.StreamId(), frame.RefusedStream)
|
|
go s.writeFrame(rstF, time.Time{})
|
|
return
|
|
}
|
|
|
|
if f.StreamId() <= frame.StreamId(atomic.LoadUint32(&s.remote.lastId)) {
|
|
s.die(frame.ProtocolError, fmt.Errorf("Stream id %d is less than last remote id.", f.StreamId()))
|
|
return
|
|
}
|
|
|
|
if s.isLocal(f.StreamId()) {
|
|
s.die(frame.ProtocolError, fmt.Errorf("Stream id has wrong parity for remote endpoint: %d", f.StreamId()))
|
|
return
|
|
}
|
|
|
|
// update last remote id
|
|
atomic.StoreUint32(&s.remote.lastId, uint32(f.StreamId()))
|
|
|
|
// make the new stream
|
|
str := s.newStream(f.StreamId(), f.StreamPriority(), f.StreamType(), false, f.Fin(), s.defaultWindowSize, s)
|
|
|
|
// add it to the stream map
|
|
s.streams.Set(f.StreamId(), str)
|
|
|
|
// check if this is an extension stream
|
|
if f.StreamType() >= MinExtensionType {
|
|
extAccept, ok := s.exts[f.StreamType()]
|
|
if !ok {
|
|
// Extension type of stream not registered
|
|
fRst := frame.NewWStreamRst()
|
|
if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil {
|
|
s.die(frame.InternalError, err)
|
|
}
|
|
|
|
s.wr.Lock()
|
|
defer s.wr.Unlock()
|
|
s.transport.WriteFrame(fRst)
|
|
} else {
|
|
extAccept <- str
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// put the new stream on the accept channel
|
|
s.accept <- str
|
|
|
|
case *frame.RStreamData:
|
|
if str := s.getStream(f.StreamId()); str != nil {
|
|
str.handleStreamData(f)
|
|
} else {
|
|
// if we get a data frame on a non-existent connection, we still
|
|
// need to read out the frame body so that the stream stays in a
|
|
// good state. read the payload into a throwaway buffer
|
|
discard := make([]byte, f.Length())
|
|
io.ReadFull(f.Reader(), discard)
|
|
|
|
// DATA frames on closed connections are just stream-level errors
|
|
fRst := frame.NewWStreamRst()
|
|
if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil {
|
|
s.die(frame.InternalError, err)
|
|
}
|
|
|
|
s.wr.Lock()
|
|
defer s.wr.Unlock()
|
|
s.transport.WriteFrame(fRst)
|
|
return
|
|
}
|
|
|
|
case *frame.RStreamRst:
|
|
// delegate to the stream to handle these frames
|
|
if str := s.getStream(f.StreamId()); str != nil {
|
|
str.handleStreamRst(f)
|
|
}
|
|
case *frame.RStreamWndInc:
|
|
// delegate to the stream to handle these frames
|
|
if str := s.getStream(f.StreamId()); str != nil {
|
|
str.handleStreamWndInc(f)
|
|
}
|
|
|
|
case *frame.RGoAway:
|
|
atomic.StoreInt32(&s.remote.goneAway, 1)
|
|
s.remoteDebug = f.Debug()
|
|
|
|
lastId := f.LastStreamId()
|
|
s.streams.Each(func(id frame.StreamId, str stream) {
|
|
// close all streams that we opened above the last handled id
|
|
if s.isLocal(str.Id()) && str.Id() > lastId {
|
|
str.closeWith(fmt.Errorf("Remote is going away"))
|
|
}
|
|
})
|
|
|
|
default:
|
|
s.die(frame.ProtocolError, fmt.Errorf("Unrecognized frame type: %v", reflect.TypeOf(f)))
|
|
return
|
|
}
|
|
}
|
|
|
|
func (s *Session) recoverPanic(prefix string) {
|
|
if r := recover(); r != nil {
|
|
s.die(frame.InternalError, fmt.Errorf("%s panic: %v", prefix, r))
|
|
}
|
|
}
|
|
|
|
func (s *Session) getStream(id frame.StreamId) (str stream) {
|
|
// decide if this id is in the "idle" state (i.e. greater than any we've seen for that parity)
|
|
var lastId *uint32
|
|
if s.isLocal(id) {
|
|
lastId = &s.local.lastId
|
|
} else {
|
|
lastId = &s.remote.lastId
|
|
}
|
|
|
|
if uint32(id) > atomic.LoadUint32(lastId) {
|
|
s.die(frame.ProtocolError, fmt.Errorf("%d is an invalid, unassigned stream id", id))
|
|
}
|
|
|
|
// find the stream in the stream map
|
|
var ok bool
|
|
if str, ok = s.streams.Get(id); !ok {
|
|
return nil
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// check if a stream id is for a client stream. client streams are odd
|
|
func (s *Session) isClient(id frame.StreamId) bool {
|
|
return uint32(id)&1 == 1
|
|
}
|
|
|
|
func (s *Session) isServer(id frame.StreamId) bool {
|
|
return !s.isClient(id)
|
|
}
|
|
|
|
//////////////////////////////////////////////
|
|
// session extensions
|
|
//////////////////////////////////////////////
|
|
func (s *Session) startExtension(ext Extension) {
|
|
accept := make(chan stream)
|
|
extAccept := func() (IStream, error) {
|
|
s, ok := <-accept
|
|
if !ok {
|
|
return nil, fmt.Errorf("Failed to accept connection, shutting down")
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
extType := ext.Start(s, extAccept)
|
|
s.exts[extType] = accept
|
|
}
|
|
|
|
//////////////////////////////////////////////
|
|
// net adaptors
|
|
//////////////////////////////////////////////
|
|
func (s *Session) NetDial(_, _ string) (net.Conn, error) {
|
|
str, err := s.Open()
|
|
return net.Conn(str), err
|
|
}
|
|
|
|
func (s *Session) NetListener() net.Listener {
|
|
return &netListenerAdaptor{s}
|
|
}
|
|
|
|
type netListenerAdaptor struct {
|
|
*Session
|
|
}
|
|
|
|
func (a *netListenerAdaptor) Addr() net.Addr {
|
|
return a.LocalAddr()
|
|
}
|
|
|
|
func (a *netListenerAdaptor) Accept() (net.Conn, error) {
|
|
str, err := a.Session.Accept()
|
|
return net.Conn(str), err
|
|
}
|