open-vault/vendor/github.com/jackc/pgx/pgproto3/startup_message.go

98 lines
2.1 KiB
Go
Raw Normal View History

2018-10-03 16:55:26 +00:00
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgx/pgio"
"github.com/pkg/errors"
)
const (
	// ProtocolVersionNumber is protocol version 3.0: the major version (3)
	// occupies the upper 16 bits and the minor version (0) the lower 16 bits,
	// which yields 196608.
	ProtocolVersionNumber = 3<<16 | 0

	// sslRequestNumber is the value a client sends in the protocol-version
	// position to ask for SSL: 1234 in the upper 16 bits and 5679 in the lower
	// 16 bits, which yields 80877103.
	sslRequestNumber = 1234<<16 | 5679
)
// StartupMessage holds a startup message: the protocol version number the
// client asks for, followed by a set of name/value parameters.
type StartupMessage struct {
	// ProtocolVersion is the 32-bit protocol version number from the message.
	ProtocolVersion uint32
	// Parameters maps each parameter name sent in the message to its value.
	Parameters map[string]string
}

// Frontend is an empty marker method identifying StartupMessage as a message
// sent by the frontend (client). Presumably it exists to satisfy the package's
// frontend-message interface; it performs no work.
func (m *StartupMessage) Frontend() {}
// Decode decodes src into dst. src must hold the complete startup message
// body without its 4-byte length prefix:
//   - a 4-byte big-endian protocol version number,
//   - NUL-terminated name/value string pairs,
//   - a single NUL byte that ends the list.
//
// The SSL request code and any version other than ProtocolVersionNumber are
// rejected. A message that carries only the terminator decodes to an empty
// Parameters map, so the output of Encode for an empty map round-trips. Any
// bytes after the terminator make the message malformed. On error dst may be
// partially populated.
func (dst *StartupMessage) Decode(src []byte) error {
	if len(src) < 4 {
		return errors.Errorf("startup message too short")
	}

	dst.ProtocolVersion = binary.BigEndian.Uint32(src)
	rp := 4

	if dst.ProtocolVersion == sslRequestNumber {
		return errors.Errorf("can't handle ssl connection request")
	}

	if dst.ProtocolVersion != ProtocolVersionNumber {
		return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
	}

	dst.Parameters = make(map[string]string)
	for {
		switch {
		case rp >= len(src):
			// Ran out of bytes before reaching the terminator.
			return &invalidMessageFormatErr{messageType: "StartupMessage"}
		case src[rp] == 0:
			// A zero byte where a name would start is the list terminator,
			// not an empty-named parameter. It must be the final byte.
			if rp != len(src)-1 {
				return &invalidMessageFormatErr{messageType: "StartupMessage"}
			}
			return nil
		case rp == len(src)-1:
			// Exactly one byte remains and it is not the terminator.
			return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
		}

		idx := bytes.IndexByte(src[rp:], 0)
		if idx < 0 {
			return &invalidMessageFormatErr{messageType: "StartupMessage"}
		}
		key := string(src[rp : rp+idx])
		rp += idx + 1

		idx = bytes.IndexByte(src[rp:], 0)
		if idx < 0 {
			return &invalidMessageFormatErr{messageType: "StartupMessage"}
		}
		value := string(src[rp : rp+idx])
		rp += idx + 1

		dst.Parameters[key] = value
	}
}
// Encode appends the wire form of src to dst and returns the extended slice.
//
// The encoding has no message-type byte. It consists of:
//   - a 4-byte length that counts itself,
//   - the 4-byte protocol version,
//   - each parameter as a NUL-terminated name followed by a NUL-terminated value,
//   - a final NUL byte.
//
// Parameters are emitted in Go map iteration order, which is not fixed.
func (src *StartupMessage) Encode(dst []byte) []byte {
	start := len(dst)

	// Placeholder for the length prefix; overwritten once the body is written.
	dst = pgio.AppendInt32(dst, -1)
	dst = pgio.AppendUint32(dst, src.ProtocolVersion)

	for name, value := range src.Parameters {
		dst = append(append(dst, name...), 0)
		dst = append(append(dst, value...), 0)
	}
	dst = append(dst, 0) // terminates the parameter list

	msgLen := len(dst) - start
	pgio.SetInt32(dst[start:], int32(msgLen))
	return dst
}
// MarshalJSON encodes the message as a JSON object containing a "Type" field
// fixed to "StartupMessage" plus the ProtocolVersion and Parameters fields.
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
	type startupMessageJSON struct {
		Type            string
		ProtocolVersion uint32
		Parameters      map[string]string
	}

	out := startupMessageJSON{
		Type:            "StartupMessage",
		ProtocolVersion: src.ProtocolVersion,
		Parameters:      src.Parameters,
	}
	return json.Marshal(out)
}