package proto import ( "fmt" "github.com/inconshreveable/muxado/proto/frame" "io" "net" "reflect" "sync" "sync/atomic" "time" ) const ( defaultWindowSize = 0x10000 // 64KB defaultAcceptQueueDepth = 100 MinExtensionType = 0xFFFFFFFF - 0x100 // 512 extensions ) // private interface for Sessions to call Streams type stream interface { IStream handleStreamData(*frame.RStreamData) handleStreamWndInc(*frame.RStreamWndInc) handleStreamRst(*frame.RStreamRst) closeWith(error) } // for extensions type ExtAccept func() (IStream, error) type Extension interface { Start(ISession, ExtAccept) frame.StreamType } type deadReason struct { errorCode frame.ErrorCode err error remoteDebug []byte } // factory function that creates new streams type streamFactory func(id frame.StreamId, priority frame.StreamPriority, streamType frame.StreamType, finLocal bool, finRemote bool, windowSize uint32, sess session) stream // checks the parity of a stream id (local vs remote, client vs server) type parityFn func(frame.StreamId) bool // state for each half of the session (remote and local) type halfState struct { goneAway int32 // true if that half of the stream has gone away lastId uint32 // last id used/seen from one half of the session } // Session implements a simple streaming session manager. It has the following characteristics: // // - When closing the Session, it does not linger, all pending write operations will fail immediately. // - It completely ignores stream priority when processing and writing frames // - It offers no customization of settings like window size/ping time type Session struct { conn net.Conn // connection the transport is running over transport frame.Transport // transport streams StreamMap // all active streams local halfState // client state remote halfState // server state syn *frame.WStreamSyn // STREAM_SYN frame for opens wr sync.Mutex // synchronization when writing frames accept chan stream // new streams opened by the remote diebit int32 // true if we're dying remoteDebug []byte // debugging data sent in the remote's GoAway frame defaultWindowSize uint32 // window size when creating new streams newStream streamFactory // factory function to make new streams dead chan deadReason // dead isLocal parityFn // determines if a stream id is local or remote exts map[frame.StreamType]chan stream // map of extension stream type -> accept channel for the extension } func NewSession(conn net.Conn, newStream streamFactory, isClient bool, exts []Extension) ISession { sess := &Session{ conn: conn, transport: frame.NewBasicTransport(conn), streams: NewConcurrentStreamMap(), local: halfState{lastId: 0}, remote: halfState{lastId: 0}, syn: frame.NewWStreamSyn(), diebit: 0, defaultWindowSize: defaultWindowSize, accept: make(chan stream, defaultAcceptQueueDepth), newStream: newStream, dead: make(chan deadReason, 1), // don't block die() if there is no Wait call exts: make(map[frame.StreamType]chan stream), } if isClient { sess.isLocal = sess.isClient sess.local.lastId += 1 } else { sess.isLocal = sess.isServer sess.remote.lastId += 1 } for _, ext := range exts { sess.startExtension(ext) } go sess.reader() return sess } //////////////////////////////// // public interface //////////////////////////////// func (s *Session) Open() (IStream, error) { return s.OpenStream(0, 0, false) } func (s *Session) OpenStream(priority frame.StreamPriority, streamType frame.StreamType, fin bool) (ret IStream, err error) { // check if the remote has gone away if atomic.LoadInt32(&s.remote.goneAway) == 1 { return nil, fmt.Errorf("Failed to create stream, remote has gone away.") } // this lock prevents the following race: // goroutine1 goroutine2 // - inc stream id // - inc stream id // - send streamsyn // - send streamsyn s.wr.Lock() // get the next id we can use nextId := frame.StreamId(atomic.AddUint32(&s.local.lastId, 2)) // make the stream str := s.newStream(nextId, priority, streamType, fin, false, s.defaultWindowSize, s) // add to to the stream map s.streams.Set(nextId, str) // write the frame if err = s.syn.Set(nextId, priority, streamType, fin); err != nil { s.wr.Unlock() s.die(frame.InternalError, err) return } if err = s.transport.WriteFrame(s.syn); err != nil { s.wr.Unlock() s.die(frame.InternalError, err) return } s.wr.Unlock() return str, nil } func (s *Session) Accept() (str IStream, err error) { var ok bool if str, ok = <-s.accept; !ok { return nil, fmt.Errorf("Session closed") } return } func (s *Session) Kill() error { return s.transport.Close() } func (s *Session) Close() error { return s.die(frame.NoError, fmt.Errorf("Session Close()")) } func (s *Session) GoAway(errorCode frame.ErrorCode, debug []byte) (err error) { if !atomic.CompareAndSwapInt32(&s.local.goneAway, 0, 1) { return fmt.Errorf("Already sent GoAway!") } s.wr.Lock() f := frame.NewWGoAway() remoteId := frame.StreamId(atomic.LoadUint32(&s.remote.lastId)) if err = f.Set(remoteId, errorCode, debug); err != nil { s.wr.Unlock() s.die(frame.InternalError, err) return } if err = s.transport.WriteFrame(f); err != nil { s.wr.Unlock() s.die(frame.InternalError, err) return } s.wr.Unlock() return } func (s *Session) LocalAddr() net.Addr { return s.conn.LocalAddr() } func (s *Session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } func (s *Session) Wait() (frame.ErrorCode, error, []byte) { reason := <-s.dead return reason.errorCode, reason.err, reason.remoteDebug } //////////////////////////////// // private interface for streams //////////////////////////////// // removeStream removes a stream from this session's stream registry // // It does not error if the stream is not present func (s *Session) removeStream(id frame.StreamId) { s.streams.Delete(id) return } // writeFrame writes the given frame to the transport and returns the error from the write operation func (s *Session) writeFrame(f frame.WFrame, dl time.Time) (err error) { s.wr.Lock() s.conn.SetWriteDeadline(dl) err = s.transport.WriteFrame(f) s.wr.Unlock() return } // die closes the session cleanly with the given error and protocol error code func (s *Session) die(errorCode frame.ErrorCode, err error) error { // only one shutdown ever happens if !atomic.CompareAndSwapInt32(&s.diebit, 0, 1) { return fmt.Errorf("Shutdown already in progress") } // send a go away frame s.GoAway(errorCode, []byte(err.Error())) // now we're safe to stop accepting incoming connections close(s.accept) // we cleaned up as best as possible, close the transport s.transport.Close() // notify all of the streams that we're closing s.streams.Each(func(id frame.StreamId, str stream) { str.closeWith(fmt.Errorf("Session closed")) }) s.dead <- deadReason{errorCode, err, s.remoteDebug} return nil } //////////////////////////////// // internal methods //////////////////////////////// // reader() reads frames from the underlying transport and handles passes them to handleFrame func (s *Session) reader() { defer s.recoverPanic("reader()") // close all of the extension accept channels when we're done // we do this here instead of in die() since otherwise it wouldn't // be safe to access s.exts defer func() { for _, extAccept := range s.exts { close(extAccept) } }() for { f, err := s.transport.ReadFrame() if err != nil { // if we fail to read a frame, terminate the session _, ok := err.(*frame.FramingError) if ok { s.die(frame.ProtocolError, err) } else { s.die(frame.InternalError, err) } return } s.handleFrame(f) } } func (s *Session) handleFrame(rf frame.RFrame) { switch f := rf.(type) { case *frame.RStreamSyn: // if we're going away, refuse new streams if atomic.LoadInt32(&s.local.goneAway) == 1 { rstF := frame.NewWStreamRst() rstF.Set(f.StreamId(), frame.RefusedStream) go s.writeFrame(rstF, time.Time{}) return } if f.StreamId() <= frame.StreamId(atomic.LoadUint32(&s.remote.lastId)) { s.die(frame.ProtocolError, fmt.Errorf("Stream id %d is less than last remote id.", f.StreamId())) return } if s.isLocal(f.StreamId()) { s.die(frame.ProtocolError, fmt.Errorf("Stream id has wrong parity for remote endpoint: %d", f.StreamId())) return } // update last remote id atomic.StoreUint32(&s.remote.lastId, uint32(f.StreamId())) // make the new stream str := s.newStream(f.StreamId(), f.StreamPriority(), f.StreamType(), false, f.Fin(), s.defaultWindowSize, s) // add it to the stream map s.streams.Set(f.StreamId(), str) // check if this is an extension stream if f.StreamType() >= MinExtensionType { extAccept, ok := s.exts[f.StreamType()] if !ok { // Extension type of stream not registered fRst := frame.NewWStreamRst() if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil { s.die(frame.InternalError, err) } s.wr.Lock() defer s.wr.Unlock() s.transport.WriteFrame(fRst) } else { extAccept <- str } return } // put the new stream on the accept channel s.accept <- str case *frame.RStreamData: if str := s.getStream(f.StreamId()); str != nil { str.handleStreamData(f) } else { // if we get a data frame on a non-existent connection, we still // need to read out the frame body so that the stream stays in a // good state. read the payload into a throwaway buffer discard := make([]byte, f.Length()) io.ReadFull(f.Reader(), discard) // DATA frames on closed connections are just stream-level errors fRst := frame.NewWStreamRst() if err := fRst.Set(f.StreamId(), frame.StreamClosed); err != nil { s.die(frame.InternalError, err) } s.wr.Lock() defer s.wr.Unlock() s.transport.WriteFrame(fRst) return } case *frame.RStreamRst: // delegate to the stream to handle these frames if str := s.getStream(f.StreamId()); str != nil { str.handleStreamRst(f) } case *frame.RStreamWndInc: // delegate to the stream to handle these frames if str := s.getStream(f.StreamId()); str != nil { str.handleStreamWndInc(f) } case *frame.RGoAway: atomic.StoreInt32(&s.remote.goneAway, 1) s.remoteDebug = f.Debug() lastId := f.LastStreamId() s.streams.Each(func(id frame.StreamId, str stream) { // close all streams that we opened above the last handled id if s.isLocal(str.Id()) && str.Id() > lastId { str.closeWith(fmt.Errorf("Remote is going away")) } }) default: s.die(frame.ProtocolError, fmt.Errorf("Unrecognized frame type: %v", reflect.TypeOf(f))) return } } func (s *Session) recoverPanic(prefix string) { if r := recover(); r != nil { s.die(frame.InternalError, fmt.Errorf("%s panic: %v", prefix, r)) } } func (s *Session) getStream(id frame.StreamId) (str stream) { // decide if this id is in the "idle" state (i.e. greater than any we've seen for that parity) var lastId *uint32 if s.isLocal(id) { lastId = &s.local.lastId } else { lastId = &s.remote.lastId } if uint32(id) > atomic.LoadUint32(lastId) { s.die(frame.ProtocolError, fmt.Errorf("%d is an invalid, unassigned stream id", id)) } // find the stream in the stream map var ok bool if str, ok = s.streams.Get(id); !ok { return nil } return } // check if a stream id is for a client stream. client streams are odd func (s *Session) isClient(id frame.StreamId) bool { return uint32(id)&1 == 1 } func (s *Session) isServer(id frame.StreamId) bool { return !s.isClient(id) } ////////////////////////////////////////////// // session extensions ////////////////////////////////////////////// func (s *Session) startExtension(ext Extension) { accept := make(chan stream) extAccept := func() (IStream, error) { s, ok := <-accept if !ok { return nil, fmt.Errorf("Failed to accept connection, shutting down") } return s, nil } extType := ext.Start(s, extAccept) s.exts[extType] = accept } ////////////////////////////////////////////// // net adaptors ////////////////////////////////////////////// func (s *Session) NetDial(_, _ string) (net.Conn, error) { str, err := s.Open() return net.Conn(str), err } func (s *Session) NetListener() net.Listener { return &netListenerAdaptor{s} } type netListenerAdaptor struct { *Session } func (a *netListenerAdaptor) Addr() net.Addr { return a.LocalAddr() } func (a *netListenerAdaptor) Accept() (net.Conn, error) { str, err := a.Session.Accept() return net.Conn(str), err }