open-vault/vendor/github.com/denisenkom/go-mssqldb/buf.go

256 lines
4.9 KiB
Go
Raw Normal View History

2016-03-03 15:16:59 +00:00
package mssql
import (
"encoding/binary"
2016-12-01 19:46:57 +00:00
"errors"
2017-01-27 01:16:19 +00:00
"io"
2016-03-03 15:16:59 +00:00
)
type (
	// packetType is the message-type byte that opens every TDS packet header.
	packetType uint8

	// header is the fixed 8-byte TDS packet header. Field order and widths
	// follow the wire layout, so the struct can be decoded with encoding/binary
	// in big-endian order.
	header struct {
		PacketType packetType // kind of message the packet belongs to
		Status     uint8      // status flags; writers use 1 to mark the last packet
		Size       uint16     // whole packet length in bytes, header included
		Spid       uint16     // SPID
		PacketNo   uint8      // packet id: 1 for the first packet, then incremented
		Pad        uint8      // "window" byte, written as 0
	}
)
2017-01-27 01:16:19 +00:00
// tdsBuffer reads and writes TDS packets of data to the transport.
2017-03-31 00:03:13 +00:00
// The write and read buffers are separate to make sending attn signals
2017-01-27 01:16:19 +00:00
// possible without locks. Currently attn signals are only sent during
// reads, not writes.
2016-03-03 15:16:59 +00:00
type tdsBuffer struct {
2017-01-27 01:16:19 +00:00
transport io.ReadWriteCloser
// Write fields.
wbuf []byte
wpos uint16
// Read fields.
rbuf []byte
rpos uint16
rsize uint16
2016-03-03 15:16:59 +00:00
final bool
2017-01-04 21:47:38 +00:00
packet_type packetType
2017-01-27 01:16:19 +00:00
// afterFirst is assigned to right after tdsBuffer is created and
// before the first use. It is executed after the first packet is
2017-03-31 00:03:13 +00:00
// written and then removed.
2017-01-27 01:16:19 +00:00
afterFirst func()
2016-03-03 15:16:59 +00:00
}
2017-04-17 15:17:06 +00:00
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
2016-03-03 15:16:59 +00:00
w := new(tdsBuffer)
2017-01-27 01:16:19 +00:00
w.wbuf = make([]byte, bufsize)
w.rbuf = make([]byte, bufsize)
w.wpos = 0
w.rpos = 8
2016-03-03 15:16:59 +00:00
w.transport = transport
return w
}
2017-01-27 01:16:19 +00:00
func (rw *tdsBuffer) ResizeBuffer(packetsizei int) {
if len(rw.rbuf) != packetsizei {
newbuf := make([]byte, packetsizei)
copy(newbuf, rw.rbuf)
rw.rbuf = newbuf
}
if len(rw.wbuf) != packetsizei {
newbuf := make([]byte, packetsizei)
copy(newbuf, rw.wbuf)
rw.wbuf = newbuf
}
}
// PackageSize reports the current packet size, taken as the length of the
// write buffer.
func (w *tdsBuffer) PackageSize() uint32 {
	size := len(w.wbuf)
	return uint32(size)
}
2016-03-03 15:16:59 +00:00
func (w *tdsBuffer) flush() (err error) {
2016-12-01 19:46:57 +00:00
// writing packet size
2017-01-27 01:16:19 +00:00
binary.BigEndian.PutUint16(w.wbuf[2:], w.wpos)
2016-12-01 19:46:57 +00:00
// writing packet into underlying transport
2017-01-27 01:16:19 +00:00
if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil {
2016-03-03 15:16:59 +00:00
return err
}
2016-12-01 19:46:57 +00:00
// execute afterFirst hook if it is set
2016-03-03 15:16:59 +00:00
if w.afterFirst != nil {
w.afterFirst()
w.afterFirst = nil
}
2016-12-01 19:46:57 +00:00
2017-01-27 01:16:19 +00:00
w.wpos = 8
2016-12-01 19:46:57 +00:00
// packet number
2017-01-27 01:16:19 +00:00
w.wbuf[6] += 1
2016-03-03 15:16:59 +00:00
return nil
}
2016-12-01 19:46:57 +00:00
func (w *tdsBuffer) Write(p []byte) (total int, err error) {
total = 0
2016-03-03 15:16:59 +00:00
for {
2017-01-27 01:16:19 +00:00
copied := copy(w.wbuf[w.wpos:], p)
w.wpos += uint16(copied)
2016-03-03 15:16:59 +00:00
total += copied
if copied == len(p) {
break
}
if err = w.flush(); err != nil {
2016-12-01 19:46:57 +00:00
return
2016-03-03 15:16:59 +00:00
}
p = p[copied:]
}
2016-12-01 19:46:57 +00:00
return
2016-03-03 15:16:59 +00:00
}
func (w *tdsBuffer) WriteByte(b byte) error {
2017-01-27 01:16:19 +00:00
if int(w.wpos) == len(w.wbuf) {
2016-03-03 15:16:59 +00:00
if err := w.flush(); err != nil {
return err
}
}
2017-01-27 01:16:19 +00:00
w.wbuf[w.wpos] = b
w.wpos += 1
2016-03-03 15:16:59 +00:00
return nil
}
2017-01-04 21:47:38 +00:00
func (w *tdsBuffer) BeginPacket(packet_type packetType) {
2017-01-27 01:16:19 +00:00
w.wbuf[0] = byte(packet_type)
w.wbuf[1] = 0 // packet is incomplete
w.wbuf[4] = 0 // spid
w.wbuf[5] = 0
w.wbuf[6] = 1 // packet id
w.wbuf[7] = 0 // window
w.wpos = 8
2016-03-03 15:16:59 +00:00
}
2016-12-01 19:46:57 +00:00
func (w *tdsBuffer) FinishPacket() error {
2017-01-27 01:16:19 +00:00
w.wbuf[1] = 1 // this is last packet
2016-12-01 19:46:57 +00:00
return w.flush()
2016-03-03 15:16:59 +00:00
}
func (r *tdsBuffer) readNextPacket() error {
header := header{}
var err error
err = binary.Read(r.transport, binary.BigEndian, &header)
if err != nil {
2017-09-05 22:06:47 +00:00
return err
2016-03-03 15:16:59 +00:00
}
offset := uint16(binary.Size(header))
2017-01-27 01:16:19 +00:00
if int(header.Size) > len(r.rbuf) {
2016-12-01 19:46:57 +00:00
return errors.New("Invalid packet size, it is longer than buffer size")
}
if int(offset) > int(header.Size) {
return errors.New("Invalid packet size, it is shorter than header size")
}
2017-01-27 01:16:19 +00:00
_, err = io.ReadFull(r.transport, r.rbuf[offset:header.Size])
2016-03-03 15:16:59 +00:00
if err != nil {
2017-09-05 22:06:47 +00:00
return err
2016-03-03 15:16:59 +00:00
}
2017-01-27 01:16:19 +00:00
r.rpos = offset
r.rsize = header.Size
2016-03-03 15:16:59 +00:00
r.final = header.Status != 0
r.packet_type = header.PacketType
return nil
}
2017-01-04 21:47:38 +00:00
func (r *tdsBuffer) BeginRead() (packetType, error) {
2016-03-03 15:16:59 +00:00
err := r.readNextPacket()
if err != nil {
return 0, err
}
return r.packet_type, nil
}
func (r *tdsBuffer) ReadByte() (res byte, err error) {
2017-01-27 01:16:19 +00:00
if r.rpos == r.rsize {
2016-03-03 15:16:59 +00:00
if r.final {
return 0, io.EOF
}
err = r.readNextPacket()
if err != nil {
return 0, err
}
}
2017-01-27 01:16:19 +00:00
res = r.rbuf[r.rpos]
r.rpos++
2016-03-03 15:16:59 +00:00
return res, nil
}
// byte reads one byte from the message stream. Any error is handed to
// badStreamPanic.
func (r *tdsBuffer) byte() byte {
	b, err := r.ReadByte()
	if err == nil {
		return b
	}
	badStreamPanic(err)
	return b
}
func (r *tdsBuffer) ReadFull(buf []byte) {
_, err := io.ReadFull(r, buf[:])
if err != nil {
2017-09-05 22:06:47 +00:00
badStreamPanic(err)
2016-03-03 15:16:59 +00:00
}
}
// uint64 reads 8 bytes from the stream as a little-endian unsigned integer.
func (r *tdsBuffer) uint64() uint64 {
	raw := make([]byte, 8)
	r.ReadFull(raw)
	return binary.LittleEndian.Uint64(raw)
}
// int32 reads a little-endian 4-byte value and reinterprets its bits as a
// signed integer.
func (r *tdsBuffer) int32() int32 {
	u := r.uint32()
	return int32(u)
}
// uint32 reads 4 bytes from the stream as a little-endian unsigned integer.
func (r *tdsBuffer) uint32() uint32 {
	raw := make([]byte, 4)
	r.ReadFull(raw)
	return binary.LittleEndian.Uint32(raw)
}
// uint16 reads 2 bytes from the stream as a little-endian unsigned integer.
func (r *tdsBuffer) uint16() uint16 {
	raw := make([]byte, 2)
	r.ReadFull(raw)
	return binary.LittleEndian.Uint16(raw)
}
// BVarChar reads a string whose character count is given by one preceding
// byte, then decodes the characters with readUcs2.
func (r *tdsBuffer) BVarChar() string {
	return r.readUcs2(int(r.byte()))
}
// UsVarChar reads a string whose character count is given by a preceding
// little-endian uint16, then decodes the characters with readUcs2.
func (r *tdsBuffer) UsVarChar() string {
	return r.readUcs2(int(r.uint16()))
}
// readUcs2 reads numchars characters, two bytes each, and converts them to a
// Go string with ucs22str. A conversion error is handed to badStreamPanic.
func (r *tdsBuffer) readUcs2(numchars int) string {
	raw := make([]byte, 2*numchars)
	r.ReadFull(raw)
	s, err := ucs22str(raw)
	if err == nil {
		return s
	}
	badStreamPanic(err)
	return s
}
2016-12-01 19:46:57 +00:00
func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
copied = 0
err = nil
2017-01-27 01:16:19 +00:00
if r.rpos == r.rsize {
2016-03-03 15:16:59 +00:00
if r.final {
return 0, io.EOF
}
err = r.readNextPacket()
if err != nil {
2016-12-01 19:46:57 +00:00
return
2016-03-03 15:16:59 +00:00
}
}
2017-01-27 01:16:19 +00:00
copied = copy(buf, r.rbuf[r.rpos:r.rsize])
r.rpos += uint16(copied)
2016-12-01 19:46:57 +00:00
return
2016-03-03 15:16:59 +00:00
}