open-vault/vendor/github.com/jackc/pgx/conn.go
Jeff Mitchell 9ebc57581d
Switch to go modules (#6585)
* Switch to go modules

* Make fmt
2019-04-13 03:44:06 -04:00

1998 lines
51 KiB
Go

package pgx
import (
"context"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
// Connection lifecycle states stored in Conn.status. The numeric order is
// significant: code in this file treats a status >= connStatusIdle as alive
// (see IsAlive) and a status below connStatusIdle as not usable (see Close).
const (
connStatusUninitialized = iota
connStatusClosed
connStatusIdle
connStatusBusy
)
// minimalConnInfo has just enough static type information to establish the
// connection and retrieve the type data. (*Conn).connect compares a
// connection's ConnInfo against this shared instance to decide whether the
// full type information still has to be loaded from the server.
var minimalConnInfo *pgtype.ConnInfo
// init builds minimalConnInfo with only the types listed below.
func init() {
minimalConnInfo = pgtype.NewConnInfo()
minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{
"int4": pgtype.Int4OID,
"name": pgtype.NameOID,
"oid": pgtype.OIDOID,
"text": pgtype.TextOID,
"varchar": pgtype.VarcharOID,
})
}
// NoticeHandler is a function that can handle notices received from the
// PostgreSQL server. Notices can be received at any time, usually during
// handling of a query response. The *Conn is provided so the handler is aware
// of the origin of the notice, but it must not invoke any query method. Be
// aware that this is distinct from LISTEN/NOTIFY notification.
//
// A handler is installed through ConnConfig.OnNotice.
type NoticeHandler func(*Conn, *Notice)
// DialFunc is a function that can be used to connect to a PostgreSQL server.
// It is called with the network ("tcp" or "unix") and address computed by
// ConnConfig.networkAddress and returns the established connection.
type DialFunc func(network, addr string) (net.Conn, error)
// ConnConfig contains all the options used to establish a connection.
type ConnConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16 // default: 5432
Database string
User string // default: OS user name
Password string
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
FallbackTLSConfig *tls.Config // config for fallback TLS connection (only used if UseFallbackTLS is true) -- nil disables TLS
Logger Logger
LogLevel int // zero means "not set"; connect then defaults to LogLevelDebug
Dial DialFunc // nil means a default net.Dialer with keep-alive is used
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
OnNotice NoticeHandler // Callback function called when a notice response is received.
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
CustomCancel func(*Conn) error // Callback function used to override cancellation behavior
// PreferSimpleProtocol disables implicit prepared statement usage. By default
// pgx automatically uses the unnamed prepared statement for Query and
// QueryRow. It also uses a prepared statement when Exec has arguments. This
// can improve performance due to being able to use the binary format. It also
// does not rely on client side parameter sanitization. However, it does incur
// two round-trips per query and may be incompatible with proxies such as
// PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
// used by default. The same functionality can be controlled on a per query
// basis by setting QueryExOptions.SimpleProtocol.
PreferSimpleProtocol bool
}
// networkAddress returns the network ("tcp" or "unix") and the address to
// hand to the dial function for this config.
//
// If Host names an existing filesystem path the connection is made over a unix
// domain socket. Host may be the socket directory (preferred), in which case
// the conventional ".s.PGSQL.<port>" file name is appended, or, for backward
// compatibility, the full path of the socket file itself.
//
// Otherwise the result is a TCP "host:port" address. An IPv6 literal given
// without brackets (e.g. "::1") is wrapped in brackets so that the port
// separator is unambiguous; a host that is already bracketed is used as is.
func (cc *ConnConfig) networkAddress() (network, address string) {
	// See if host is a valid path, if yes connect with a socket
	if _, err := os.Stat(cc.Host); err == nil {
		// For backward compatibility accept socket file paths -- but directories are now preferred
		address = cc.Host
		if !strings.Contains(address, "/.s.PGSQL.") {
			address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
		}
		return "unix", address
	}

	host := cc.Host
	// "::1:5432" cannot be split back into host and port by the dialer, so
	// bare IPv6 literals need the "[host]:port" form.
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	return "tcp", fmt.Sprintf("%s:%d", host, cc.Port)
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
type Conn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
lastActivityTime time.Time // the last time the connection was used
wbuf []byte // scratch buffer outgoing messages are appended to; allocated in connect
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
RuntimeParams map[string]string // parameters that have been reported by the server
config ConnConfig // config used when establishing this connection
txStatus byte
preparedStatements map[string]*PreparedStatement // prepared statements recorded by prepareEx, keyed by name
channels map[string]struct{} // channels subscribed to through Listen
notifications []*Notification // received notifications not yet handed out by WaitForNotification
logger Logger
logLevel int
fp *fastpath
poolResetCount int
preallocatedRows []Rows
onNotice NoticeHandler
mux sync.Mutex // held in this file while reading or writing status and causeOfDeath
status byte // One of connStatus* constants
causeOfDeath error
pendingReadyForQueryCount int // number of ReadyForQuery messages expected
cancelQueryCompleted chan struct{} // created already closed in connect; NOTE(review): presumably closed again whenever a cancel request finishes -- confirm
lastStmtSent bool
// context support
ctxInProgress bool
doneChan chan struct{}
closedChan chan error
ConnInfo *pgtype.ConnInfo
frontend *pgproto3.Frontend
}
// PreparedStatement is a description of a prepared statement. prepareEx fills
// FieldDescriptions from the server's RowDescription message and ParameterOIDs
// from its ParameterDescription message.
type PreparedStatement struct {
Name string
SQL string
FieldDescriptions []FieldDescription
ParameterOIDs []pgtype.OID
}
// PrepareExOptions is an option struct that can be passed to PrepareEx.
type PrepareExOptions struct {
// ParameterOIDs are sent with the Parse message. prepareEx rejects more
// than 65535 entries.
ParameterOIDs []pgtype.OID
}
// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system.
// Values are returned by WaitForNotification.
type Notification struct {
PID uint32 // backend pid that sent the notification
Channel string // channel from which notification was received
Payload string
}
// CommandTag is the result of an Exec function
type CommandTag string

// RowsAffected reports the row count carried by the tag: the text after the
// last space, parsed as a base-10 integer. Tags that contain no space, or
// whose last word is not a number (such as "CREATE TABLE"), yield 0.
func (ct CommandTag) RowsAffected() int64 {
	tag := string(ct)
	sep := strings.LastIndex(tag, " ")
	if sep < 0 {
		return 0
	}
	// The parse error is deliberately ignored: the value ParseInt returns
	// alongside it is what gets reported.
	count, _ := strconv.ParseInt(tag[sep+1:], 10, 64)
	return count
}
// Identifier is a PostgreSQL identifier or name. Identifiers can be composed
// of multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string

// Sanitize returns a sanitized string safe for SQL interpolation. Every part
// is wrapped in double quotes, with embedded double quotes doubled, and the
// parts are joined with ".".
func (ident Identifier) Sanitize() string {
	quoted := make([]string, 0, len(ident))
	for _, part := range ident {
		escaped := strings.Replace(part, `"`, `""`, -1)
		quoted = append(quoted, `"`+escaped+`"`)
	}
	return strings.Join(quoted, ".")
}
// Sentinel errors returned by this package.
var (
	// ErrNoRows occurs when rows are expected but none are returned.
	ErrNoRows = errors.New("no rows in result set")

	// ErrDeadConn occurs on an attempt to use a dead connection.
	ErrDeadConn = errors.New("conn is dead")

	// ErrTLSRefused occurs when the connection attempt requires TLS and the
	// PostgreSQL server refuses to use TLS.
	ErrTLSRefused = errors.New("server refused TLS connection")

	// ErrConnBusy occurs when the connection is busy (for example, in the
	// middle of reading query results) and another action is attempted.
	ErrConnBusy = errors.New("conn is busy")

	// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
	ErrInvalidLogLevel = errors.New("invalid log level")
)
// ProtocolError occurs when unexpected data is received from PostgreSQL
type ProtocolError string

// Error implements the error interface; the message is the string value
// itself.
func (e ProtocolError) Error() string {
	msg := string(e)
	return msg
}
// Connect establishes a connection with a PostgreSQL server using config.
// config.Host must be specified. config.User will default to the OS user name.
// Other config fields are optional.
//
// The connection starts out with minimalConnInfo as its type information,
// which is what makes the startup sequence load the full type information
// from the server.
func Connect(config ConnConfig) (c *Conn, err error) {
return connect(config, minimalConnInfo)
}
// defaultDialer returns a fresh net.Dialer with a five minute keep-alive
// period and every other field left at its zero value. Callers may adjust the
// returned dialer (e.g. set Timeout) without affecting other callers.
func defaultDialer() *net.Dialer {
	dialer := new(net.Dialer)
	dialer.KeepAlive = 5 * time.Minute
	return dialer
}
// connect creates a Conn for config, fills in defaults on the Conn's own copy
// of the config (log level, user, port, dial function) and then performs the
// startup sequence via (*Conn).connect. If the first attempt with
// config.TLSConfig fails and config.UseFallbackTLS is set, a second attempt is
// made with config.FallbackTLSConfig. connInfo becomes the connection's
// initial type information.
//
// On failure the returned *Conn is nil.
func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
c = new(Conn)
c.config = config
c.ConnInfo = connInfo
if c.config.LogLevel != 0 {
c.logLevel = c.config.LogLevel
} else {
// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
c.logLevel = LogLevelDebug
}
c.logger = c.config.Logger
if c.config.User == "" {
// No user configured: fall back to the OS user running this process.
user, err := user.Current()
if err != nil {
return nil, err
}
c.config.User = user.Username
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User})
}
}
if c.config.Port == 0 {
c.config.Port = 5432
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port})
}
}
c.onNotice = config.OnNotice
// Computed after the port default has been applied so the address carries
// the effective port.
network, address := c.config.networkAddress()
if c.config.Dial == nil {
d := defaultDialer()
c.config.Dial = d.Dial
}
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
}
// Note: the caller's original config is passed on here; the defaults applied
// above live in c.config.
err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
}
err = c.connect(config, network, address, config.FallbackTLSConfig)
}
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
}
return nil, err
}
return c, nil
}
// connect dials address on network with c.config.Dial, optionally upgrades the
// connection to TLS, sends the startup message and processes server messages
// until the first ReadyForQuery arrives.
//
// Of the config argument only RuntimeParams is read here (to seed the startup
// parameters and to detect replication connections); user and database come
// from c.config. A nil tlsConfig skips the TLS handshake.
//
// If anything fails after the dial succeeded, the deferred function closes the
// socket and marks the connection closed.
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
c.conn, err = c.config.Dial(network, address)
if err != nil {
return err
}
defer func() {
// err is the named result, so this observes whatever error is being returned.
if c != nil && err != nil {
c.conn.Close()
c.mux.Lock()
c.status = connStatusClosed
c.mux.Unlock()
}
}()
// Reset per-connection state; this also runs again for the fallback attempt.
c.RuntimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*PreparedStatement)
c.channels = make(map[string]struct{})
c.lastActivityTime = time.Now()
// Start with an already closed channel so nothing blocks on a cancel request
// that was never issued.
c.cancelQueryCompleted = make(chan struct{})
close(c.cancelQueryCompleted)
c.doneChan = make(chan struct{})
c.closedChan = make(chan error)
c.wbuf = make([]byte, 0, 1024)
c.mux.Lock()
c.status = connStatusIdle
c.mux.Unlock()
if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "starting TLS handshake", nil)
}
if err := c.startTLS(tlsConfig); err != nil {
return err
}
}
c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn)
if err != nil {
return err
}
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
Parameters: make(map[string]string),
}
// Copy default run-time params
for k, v := range config.RuntimeParams {
startupMsg.Parameters[k] = v
}
// Set after the copy so RuntimeParams cannot override user or database.
startupMsg.Parameters["user"] = c.config.User
if c.config.Database != "" {
startupMsg.Parameters["database"] = c.config.Database
}
if _, err := c.conn.Write(startupMsg.Encode(nil)); err != nil {
return err
}
// The startup sequence ends with exactly one ReadyForQuery.
c.pendingReadyForQueryCount = 1
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.BackendKeyData:
c.rxBackendKeyData(msg)
case *pgproto3.Authentication:
if err = c.rxAuthenticationX(msg); err != nil {
return err
}
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connection established", nil)
}
// Replication connections can't execute the queries needed to
// populate c.ConnInfo.
if _, ok := config.RuntimeParams["replication"]; ok {
return nil
}
// Only load the full type information when the caller did not supply
// its own ConnInfo.
if c.ConnInfo == minimalConnInfo {
err = c.initConnInfo()
if err != nil {
return err
}
}
return nil
default:
if err = c.processContextFreeMsg(msg); err != nil {
return err
}
}
}
}
// initPostgresql builds a ConnInfo from the server's pg_type catalog: it
// queries the OID and name of the selected types, initializes a new ConnInfo
// with them and then registers enum arrays and domains on top. Any query
// error is returned with a nil ConnInfo.
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
const (
// Type names outside pg_catalog and public are schema-qualified.
namedOIDQuery = `select t.oid,
case when nsp.nspname in ('pg_catalog', 'public') then t.typname
else nsp.nspname||'.'||t.typname
end
from pg_type t
left join pg_type base_type on t.typelem=base_type.oid
left join pg_namespace nsp on t.typnamespace=nsp.oid
where (
t.typtype in('b', 'p', 'r', 'e')
and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))
)`
)
nameOIDs, err := connInfoFromRows(c.Query(namedOIDQuery))
if err != nil {
return nil, err
}
cinfo := pgtype.NewConnInfo()
cinfo.InitializeDataTypes(nameOIDs)
if err = c.initConnInfoEnumArray(cinfo); err != nil {
return nil, err
}
if err = c.initConnInfoDomains(cinfo); err != nil {
return nil, err
}
return cinfo, nil
}
// initConnInfo loads the connection's type information into c.ConnInfo.
//
// A configured CustomConnInfo callback takes full responsibility: its result
// is stored and its error returned. Otherwise the standard PostgreSQL catalog
// queries are tried first and, if they fail, the CrateDB specific fallback
// gets a chance to recover from that error.
func (c *Conn) initConnInfo() (err error) {
	if custom := c.config.CustomConnInfo; custom != nil {
		c.ConnInfo, err = custom(c)
		return err
	}

	var info *pgtype.ConnInfo

	info, err = initPostgresql(c)
	if err == nil {
		c.ConnInfo = info
		return nil
	}

	// Check if CrateDB specific approach might still allow us to connect.
	info, err = c.crateDBTypesQuery(err)
	if err != nil {
		return err
	}
	c.ConnInfo = info
	return nil
}
// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them.
// Each matching pg_type row is registered on cinfo as a pgtype.EnumArray under
// the row's type name and OID. Query, scan and iteration errors are returned
// as is.
func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error {
nameOIDs := make(map[string]pgtype.OID, 16)
rows, err := c.Query(`select t.oid, t.typname
from pg_type t
join pg_type base_type on t.typelem=base_type.oid
where t.typtype = 'b'
and base_type.typtype = 'e'`)
if err != nil {
return err
}
// Collect everything first; registration happens only after the result set
// has been read completely.
for rows.Next() {
var oid pgtype.OID
var name pgtype.Text
if err := rows.Scan(&oid, &name); err != nil {
return err
}
nameOIDs[name.String] = oid
}
if rows.Err() != nil {
return rows.Err()
}
for name, oid := range nameOIDs {
cinfo.RegisterDataType(pgtype.DataType{
Value: &pgtype.EnumArray{},
Name: name,
OID: oid,
})
}
return nil
}
// initConnInfoDomains introspects for domains and registers a data type for them.
// A domain is registered on cinfo with a fresh value of the same Go type as
// its base type's registered value; domains whose base type is not known to
// cinfo are skipped. Query, scan and iteration errors are returned as is.
func (c *Conn) initConnInfoDomains(cinfo *pgtype.ConnInfo) error {
type domain struct {
oid pgtype.OID
name pgtype.Text
baseOID pgtype.OID
}
domains := make([]*domain, 0, 16)
rows, err := c.Query(`select t.oid, t.typname, t.typbasetype
from pg_type t
join pg_type base_type on t.typbasetype=base_type.oid
where t.typtype = 'd'
and base_type.typtype = 'b'`)
if err != nil {
return err
}
for rows.Next() {
var d domain
if err := rows.Scan(&d.oid, &d.name, &d.baseOID); err != nil {
return err
}
domains = append(domains, &d)
}
if rows.Err() != nil {
return rows.Err()
}
for _, d := range domains {
baseDataType, ok := cinfo.DataTypeForOID(d.baseOID)
if ok {
cinfo.RegisterDataType(pgtype.DataType{
// reflect.New allocates a new zero value of the base value's concrete
// type so the domain does not share state with the base type's value.
Value: reflect.New(reflect.ValueOf(baseDataType.Value).Elem().Type()).Interface().(pgtype.Value),
Name: d.name.String,
OID: d.oid,
})
}
}
return nil
}
// crateDBTypesQuery checks if the given err is likely to be the result of
// CrateDB not implementing the pg_types table correctly. If yes, a CrateDB
// specific query against pg_types is executed and its results are returned. If
// not, the original error is returned.
func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) {
// CrateDB [1] 2.1.6 is a database that implements the PostgreSQL wire
// protocol, but not perfectly. In particular, the pg_catalog schema
// containing the pg_type table is not visible by default and the
// pg_type.typtype column is not implemented. Therefore the query above
// currently returns the following error:
//
// pgx.PgError{Severity:"ERROR", Code:"XX000",
// Message:"TableUnknownException: Table 'test.pg_type' unknown",
// Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"",
// Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
// ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"}
//
// If CrateDB was to fix the pg_type table visibility in the future, we'd
// still get this error until typtype column is implemented:
//
// pgx.PgError{Severity:"ERROR", Code:"XX000",
// Message:"ColumnUnknownException: Column typtype unknown", Detail:"",
// Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"",
// SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
// ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132,
//
// Additionally CrateDB doesn't implement Postgres error codes [2], and
// instead always returns "XX000" (internal_error). The code below uses all
// of this knowledge as a heuristic to detect CrateDB. If CrateDB is
// detected, a CrateDB specific pg_type query is executed instead.
//
// The heuristic is designed to still work even if CrateDB fixes [2] or
// renames its internal exception names. If both are changed but pg_types
// isn't fixed, this code will need to be changed.
//
// There is also a small chance the heuristic will yield a false positive for
// non-CrateDB databases (e.g. if a real Postgres instance returns a XX000
// error), but hopefully there will be no harm in attempting the alternative
// query in this case.
//
// CrateDB also uses the type varchar for the typname column which required
// adding varchar to the minimalConnInfo init code.
//
// Also see the discussion here [3].
//
// [1] https://crate.io/
// [2] https://github.com/crate/crate/issues/5027
// [3] https://github.com/jackc/pgx/issues/320
if pgErr, ok := err.(PgError); ok &&
(pgErr.Code == "XX000" ||
strings.Contains(pgErr.Message, "TableUnknownException") ||
strings.Contains(pgErr.Message, "ColumnUnknownException")) {
var (
nameOIDs map[string]pgtype.OID
)
// From here on err is reused for the fallback query; the original error
// is discarded once the heuristic matched.
if nameOIDs, err = connInfoFromRows(c.Query(`select oid, typname from pg_catalog.pg_type`)); err != nil {
return nil, err
}
cinfo := pgtype.NewConnInfo()
cinfo.InitializeDataTypes(nameOIDs)
return cinfo, err
}
// Not recognized as a CrateDB failure: hand the original error back.
return nil, err
}
// PID returns the backend PID for this connection. It reads the pid field
// without taking the connection mutex.
func (c *Conn) PID() uint32 {
return c.pid
}
// LocalAddr returns the underlying connection's local address. It returns a
// "connection not ready" error when the connection is not alive (see IsAlive).
func (c *Conn) LocalAddr() (net.Addr, error) {
if !c.IsAlive() {
return nil, errors.New("connection not ready")
}
return c.conn.LocalAddr(), nil
}
// Close closes a connection. It is safe to call Close on a already closed
// connection.
//
// Close marks the connection closed, sends a Terminate message and then waits
// up to five seconds for the server to close its end, which is observed as
// io.EOF on a read. The socket is always closed and the cause of death
// recorded before Close returns, whichever step fails.
//
// The first error encountered while shutting down is returned. Whether an
// error is returned does not depend on the configured log level; the log level
// only controls whether the failure is also logged.
func (c *Conn) Close() (err error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	// Uninitialized or already closed: nothing to do.
	if c.status < connStatusIdle {
		return nil
	}
	c.status = connStatusClosed

	defer func() {
		c.conn.Close()
		c.causeOfDeath = errors.New("Closed")
		if c.shouldLog(LogLevelInfo) {
			c.log(LogLevelInfo, "closed connection", nil)
		}
	}()

	// Clear any deadline left over from earlier use so the terminate message
	// is not rejected by an already expired deadline.
	err = c.conn.SetDeadline(time.Time{})
	if err != nil {
		if c.shouldLog(LogLevelWarn) {
			c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err})
		}
		return err
	}

	// Terminate message: type byte 'X' followed by the int32 length 4.
	_, err = c.conn.Write([]byte{'X', 0, 0, 0, 4})
	if err != nil {
		if c.shouldLog(LogLevelWarn) {
			c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err})
		}
		return err
	}

	err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if err != nil {
		if c.shouldLog(LogLevelWarn) {
			c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err})
		}
		return err
	}

	// io.EOF is the expected outcome: the server closed the connection.
	_, err = c.conn.Read(make([]byte, 1))
	if err != io.EOF {
		return err
	}
	return nil
}
// Merge returns a new ConnConfig with the attributes of old and other
// combined. When an attribute is set on both, other takes precedence.
//
// As a security precaution, if the other TLSConfig is nil, all old TLS
// attributes will be preserved.
//
// The callbacks (OnNotice, CustomConnInfo, CustomCancel) follow the same rule
// as the other attributes: a non-nil callback on other replaces the one from
// old. PreferSimpleProtocol is a plain bool, so "unset" cannot be told apart
// from false; the merged value is true when either config enables it.
//
// RuntimeParams is always a newly allocated map holding old's entries
// overlaid with other's, so the result shares no map with either input.
func (old ConnConfig) Merge(other ConnConfig) ConnConfig {
	cc := old

	if other.Host != "" {
		cc.Host = other.Host
	}
	if other.Port != 0 {
		cc.Port = other.Port
	}
	if other.Database != "" {
		cc.Database = other.Database
	}
	if other.User != "" {
		cc.User = other.User
	}
	if other.Password != "" {
		cc.Password = other.Password
	}

	// The three TLS attributes only make sense together, so they are taken
	// from other as a unit.
	if other.TLSConfig != nil {
		cc.TLSConfig = other.TLSConfig
		cc.UseFallbackTLS = other.UseFallbackTLS
		cc.FallbackTLSConfig = other.FallbackTLSConfig
	}

	if other.Logger != nil {
		cc.Logger = other.Logger
	}
	if other.LogLevel != 0 {
		cc.LogLevel = other.LogLevel
	}
	if other.Dial != nil {
		cc.Dial = other.Dial
	}
	if other.OnNotice != nil {
		cc.OnNotice = other.OnNotice
	}
	if other.CustomConnInfo != nil {
		cc.CustomConnInfo = other.CustomConnInfo
	}
	if other.CustomCancel != nil {
		cc.CustomCancel = other.CustomCancel
	}

	cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol

	cc.RuntimeParams = make(map[string]string, len(old.RuntimeParams)+len(other.RuntimeParams))
	for k, v := range old.RuntimeParams {
		cc.RuntimeParams[k] = v
	}
	for k, v := range other.RuntimeParams {
		cc.RuntimeParams[k] = v
	}

	return cc
}
// ParseURI parses a database URI into ConnConfig
//
// The user, password, host, port and database are taken from the URI itself.
// The host may be a bracketed IPv6 literal (e.g. postgres://[::1]:5432/db);
// the brackets are not part of the resulting Host. The query parameters
// connect_timeout, sslmode, sslcert, sslkey and sslrootcert configure dialing
// and TLS, and a "host" query parameter overrides the URI host.
//
// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams.
//
// If the URI carries no password, pgpass is consulted for one.
func ParseURI(uri string) (ConnConfig, error) {
	var cp ConnConfig

	u, err := url.Parse(uri)
	if err != nil {
		return cp, err
	}

	if u.User != nil {
		cp.User = u.User.Username()
		cp.Password, _ = u.User.Password()
	}

	// Split host and port by hand rather than with u.Hostname/u.Port so that
	// a malformed or empty port is still reported as an error. The port
	// separator is a colon that comes after any closing bracket, which keeps
	// the colons inside an IPv6 literal out of the way.
	host := u.Host
	if i := strings.LastIndexByte(host, ':'); i > strings.LastIndexByte(host, ']') {
		p, err := strconv.ParseUint(host[i+1:], 10, 16)
		if err != nil {
			return cp, err
		}
		cp.Port = uint16(p)
		host = host[:i]
	}
	cp.Host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")

	cp.Database = strings.TrimLeft(u.Path, "/")

	// Parse the query string once; url.URL.Query re-parses on every call.
	query := u.Query()

	if pgtimeout := query.Get("connect_timeout"); pgtimeout != "" {
		timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
		if err != nil {
			return cp, err
		}
		d := defaultDialer()
		d.Timeout = time.Duration(timeout) * time.Second
		cp.Dial = d.Dial
	}

	tlsArgs := configTLSArgs{
		sslCert:     query.Get("sslcert"),
		sslKey:      query.Get("sslkey"),
		sslMode:     query.Get("sslmode"),
		sslRootCert: query.Get("sslrootcert"),
	}
	if err := configTLS(tlsArgs, &cp); err != nil {
		return cp, err
	}

	// Parameters already consumed above must not leak into RuntimeParams.
	ignoreKeys := map[string]struct{}{
		"connect_timeout": {},
		"sslcert":         {},
		"sslkey":          {},
		"sslmode":         {},
		"sslrootcert":     {},
	}

	cp.RuntimeParams = make(map[string]string)
	for k, v := range query {
		if _, ok := ignoreKeys[k]; ok {
			continue
		}
		// Only the first value of a repeated parameter is used.
		if k == "host" {
			cp.Host = v[0]
			continue
		}
		cp.RuntimeParams[k] = v[0]
	}

	if cp.Password == "" {
		pgpass(&cp)
	}
	return cp, nil
}
// dsnRegexp matches one key=value pair of a DSN. The key (group 1) consists of
// letters and underscores. The value (group 2) is either a double-quoted run
// of non-quote characters or a run of non-space characters; for the quoted
// form the capture includes the surrounding quotes.
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
// ParseDSN parses a database DSN (data source name) into a ConnConfig
//
// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable")
//
// Any options not used by the connection process are parsed into ConnConfig.RuntimeParams.
//
// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb")
//
// ParseDSN tries to match libpq behavior with regard to sslmode. See comments
// for ParseEnvLibpq for more information on the security implications of
// sslmode options.
//
// If the DSN carries no password, pgpass is consulted for one.
func ParseDSN(s string) (ConnConfig, error) {
	var (
		cp      ConnConfig
		tlsArgs configTLSArgs
	)
	cp.RuntimeParams = make(map[string]string)

	for _, pair := range dsnRegexp.FindAllStringSubmatch(s, -1) {
		key, value := pair[1], pair[2]

		switch key {
		case "user":
			cp.User = value
		case "password":
			cp.Password = value
		case "host":
			cp.Host = value
		case "dbname":
			cp.Database = value
		case "port":
			port, err := strconv.ParseUint(value, 10, 16)
			if err != nil {
				return cp, err
			}
			cp.Port = uint16(port)
		case "connect_timeout":
			seconds, err := strconv.ParseInt(value, 10, 64)
			if err != nil {
				return cp, err
			}
			dialer := defaultDialer()
			dialer.Timeout = time.Duration(seconds) * time.Second
			cp.Dial = dialer.Dial
		case "sslmode":
			tlsArgs.sslMode = value
		case "sslrootcert":
			tlsArgs.sslRootCert = value
		case "sslcert":
			tlsArgs.sslCert = value
		case "sslkey":
			tlsArgs.sslKey = value
		default:
			// Anything unrecognized is passed to the server as a run-time
			// parameter.
			cp.RuntimeParams[key] = value
		}
	}

	if err := configTLS(tlsArgs, &cp); err != nil {
		return cp, err
	}

	if cp.Password == "" {
		pgpass(&cp)
	}
	return cp, nil
}
// ParseConnectionString parses either a URI or a DSN connection string.
// A string that parses as a URL with a non-empty scheme is handled by
// ParseURI; everything else is handled by ParseDSN. See those functions for
// details.
func ParseConnectionString(s string) (ConnConfig, error) {
	parsed, err := url.Parse(s)
	if err != nil || parsed.Scheme == "" {
		return ParseDSN(s)
	}
	return ParseURI(s)
}
// ParseEnvLibpq parses the environment like libpq does into a ConnConfig
//
// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details
// on the meaning of environment variables.
//
// ParseEnvLibpq currently recognizes the following environment variables:
//
//	PGHOST
//	PGPORT
//	PGDATABASE
//	PGUSER
//	PGPASSWORD
//	PGSSLMODE
//	PGSSLCERT
//	PGSSLKEY
//	PGSSLROOTCERT
//	PGAPPNAME
//	PGCONNECT_TIMEOUT
//
// Important TLS Security Notes:
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
// includes defaulting to "prefer" behavior if no environment variable is set.
//
// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION
// for details on what level of security each sslmode provides.
//
// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger
// security guarantees than it would with libpq. Do not rely on this behavior as it
// may be possible to match libpq in the future. If you need full security use
// "verify-full".
//
// Several of the PGSSLMODE options (including the default behavior of "prefer")
// will set UseFallbackTLS to true and FallbackTLSConfig to a disabled or
// weakened TLS mode. This means that if ParseEnvLibpq is used, but TLSConfig is
// later set from a different source that UseFallbackTLS MUST be set false to
// avoid the possibility of falling back to weaker or disabled security.
//
// If PGPASSWORD is empty, pgpass is consulted for a password.
func ParseEnvLibpq() (ConnConfig, error) {
	var cc ConnConfig

	cc.Host = os.Getenv("PGHOST")

	if rawPort := os.Getenv("PGPORT"); rawPort != "" {
		port, err := strconv.ParseUint(rawPort, 10, 16)
		if err != nil {
			return cc, err
		}
		cc.Port = uint16(port)
	}

	cc.Database = os.Getenv("PGDATABASE")
	cc.User = os.Getenv("PGUSER")
	cc.Password = os.Getenv("PGPASSWORD")

	if rawTimeout := os.Getenv("PGCONNECT_TIMEOUT"); rawTimeout != "" {
		seconds, err := strconv.ParseInt(rawTimeout, 10, 64)
		if err != nil {
			return cc, err
		}
		dialer := defaultDialer()
		dialer.Timeout = time.Duration(seconds) * time.Second
		cc.Dial = dialer.Dial
	}

	tlsArgs := configTLSArgs{
		sslMode:     os.Getenv("PGSSLMODE"),
		sslKey:      os.Getenv("PGSSLKEY"),
		sslCert:     os.Getenv("PGSSLCERT"),
		sslRootCert: os.Getenv("PGSSLROOTCERT"),
	}
	if err := configTLS(tlsArgs, &cc); err != nil {
		return cc, err
	}

	cc.RuntimeParams = make(map[string]string)
	if appName := os.Getenv("PGAPPNAME"); appName != "" {
		cc.RuntimeParams["application_name"] = appName
	}

	if cc.Password == "" {
		pgpass(&cc)
	}
	return cc, nil
}
// configTLSArgs carries the libpq-style TLS settings handed to configTLS.
// Empty strings mean "not set".
type configTLSArgs struct {
sslMode string // one of disable, allow, prefer, require, verify-ca, verify-full; empty is treated as prefer
sslRootCert string // path of a PEM file with the CA certificate(s)
sslCert string // path of the client certificate; requires sslKey
sslKey string // path of the client key; requires sslCert
}
// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config.
// Inputs are parsed out and provided by ParseDSN(), ParseURI() or
// ParseEnvLibpq().
//
// The sslmode decides which of cc's TLS fields are set:
//
//	disable      no TLS at all; returns immediately, certificate arguments are ignored
//	allow        fallback enabled, FallbackTLSConfig is an unverified TLS config
//	prefer       unverified TLS first, fallback to no TLS (the default)
//	require      unverified TLS, no fallback configured here
//	verify-ca,
//	verify-full  verified TLS with ServerName set to cc.Host
//
// sslRootCert and the sslCert/sslKey pair are applied to the TLS config that
// the chosen mode created. It returns an error for an unknown sslmode, an
// unreadable or unparsable CA file, a cert without key (or vice versa) and an
// unloadable key pair.
func configTLS(args configTLSArgs, cc *ConnConfig) error {
	// Match libpq default behavior
	if args.sslMode == "" {
		args.sslMode = "prefer"
	}

	// target is the TLS config created for this mode. In "allow" mode that
	// is the fallback config, because cc.TLSConfig is not set there.
	var target *tls.Config

	switch args.sslMode {
	case "disable":
		cc.UseFallbackTLS = false
		cc.TLSConfig = nil
		cc.FallbackTLSConfig = nil
		return nil
	case "allow":
		cc.UseFallbackTLS = true
		cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
		target = cc.FallbackTLSConfig
	case "prefer":
		cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
		cc.UseFallbackTLS = true
		cc.FallbackTLSConfig = nil
		target = cc.TLSConfig
	case "require":
		cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
		target = cc.TLSConfig
	case "verify-ca", "verify-full":
		cc.TLSConfig = &tls.Config{
			ServerName: cc.Host,
		}
		target = cc.TLSConfig
	default:
		return errors.New("sslmode is invalid")
	}

	if args.sslRootCert != "" {
		caCertPool := x509.NewCertPool()

		caPath := args.sslRootCert
		caCert, err := ioutil.ReadFile(caPath)
		if err != nil {
			return errors.Wrapf(err, "unable to read CA file %q", caPath)
		}

		// There is no underlying error to wrap here, and wrapping a nil
		// error yields nil, which would report success.
		if !caCertPool.AppendCertsFromPEM(caCert) {
			return errors.New("unable to add CA to cert pool")
		}

		target.RootCAs = caCertPool
		target.ClientCAs = caCertPool
	}

	sslcert := args.sslCert
	sslkey := args.sslKey

	if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
		return fmt.Errorf(`both "sslcert" and "sslkey" are required`)
	}

	if sslcert != "" && sslkey != "" {
		cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
		if err != nil {
			return errors.Wrap(err, "unable to read cert")
		}

		target.Certificates = []tls.Certificate{cert}
	}

	return nil
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
//
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
//
// Prepare is PrepareEx with a background context and no options.
func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
return c.PrepareEx(context.Background(), name, sql, nil)
}
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct
//
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared.
//
// The call first waits for a previous cancel request to complete and sets up
// context handling for ctx; the error from the prepare itself is passed
// through termContext before being returned.
func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
	if err = c.waitForPreviousCancelQuery(ctx); err != nil {
		return nil, err
	}
	if err = c.initContext(ctx); err != nil {
		return nil, err
	}

	ps, err = c.prepareEx(name, sql, opts)
	return ps, c.termContext(err)
}
// prepareEx does the work of PrepareEx without any context handling. It sends
// Parse, Describe and Sync for the statement in a single write and then reads
// messages until ReadyForQuery.
//
// A named statement that is already recorded with identical SQL is returned
// without contacting the server. A nil opts is treated as empty options. The
// statement is recorded in c.preparedStatements only if no error was seen
// while reading the response.
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
if name != "" {
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
}
}
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
if c.shouldLog(LogLevelError) {
// err is the named result, so the deferred function logs whatever error
// is finally returned.
defer func() {
if err != nil {
c.log(LogLevelError, "prepareEx failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
}
}()
}
if opts == nil {
opts = &PrepareExOptions{}
}
if len(opts.ParameterOIDs) > 65535 {
return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs))
}
buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs)
buf = appendDescribe(buf, 'S', name)
buf = appendSync(buf)
n, err := c.conn.Write(buf)
if err != nil {
if fatalWriteErr(n, err) {
c.die(err)
}
return nil, err
}
// The Sync just sent will be answered with a ReadyForQuery.
c.pendingReadyForQueryCount++
ps = &PreparedStatement{Name: name, SQL: sql}
// softErr remembers the first non-fatal error so the loop can keep reading
// until ReadyForQuery before reporting it.
var softErr error
for {
msg, err := c.rxMsg()
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.ParameterDescription:
ps.ParameterOIDs = c.rxParameterDescription(msg)
if len(ps.ParameterOIDs) > 65535 && softErr == nil {
softErr = errors.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs))
}
case *pgproto3.RowDescription:
ps.FieldDescriptions = c.rxRowDescription(msg)
for i := range ps.FieldDescriptions {
if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok {
ps.FieldDescriptions[i].DataTypeName = dt.Name
// Use the binary format whenever the registered value can decode it.
if _, ok := dt.Value.(pgtype.BinaryDecoder); ok {
ps.FieldDescriptions[i].FormatCode = BinaryFormatCode
} else {
ps.FieldDescriptions[i].FormatCode = TextFormatCode
}
} else {
// NOTE(review): this returns before ReadyForQuery has been read;
// presumably ensureConnectionReadyForQuery drains it on the next
// use -- confirm.
return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType)
}
}
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
if softErr == nil {
c.preparedStatements[name] = ps
}
return ps, softErr
default:
if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
softErr = e
}
}
}
}
// Deallocate releases a prepared statement. It is deallocateContext with a
// background context.
func (c *Conn) Deallocate(name string) error {
return c.deallocateContext(context.Background(), name)
}
// deallocateContext releases the prepared statement called name. It forgets
// the statement locally, sends a Close message for it followed by a Flush, and
// reads messages until the server answers with CloseComplete.
//
// The local entry is removed before the request is sent, so it stays removed
// even if the request fails. A write error kills the connection.
//
// TODO - consider making this public
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return err
}
err = c.initContext(ctx)
if err != nil {
return err
}
defer func() {
err = c.termContext(err)
}()
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
delete(c.preparedStatements, name)
// close
buf := c.wbuf
buf = append(buf, 'C')
// Reserve the length field with a placeholder and patch it once the message
// body is complete.
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
buf = append(buf, 'S')
buf = append(buf, name...)
buf = append(buf, 0)
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
// flush
buf = append(buf, 'H')
buf = pgio.AppendInt32(buf, 4)
_, err = c.conn.Write(buf)
if err != nil {
c.die(err)
return err
}
for {
msg, err := c.rxMsg()
if err != nil {
return err
}
switch msg.(type) {
case *pgproto3.CloseComplete:
return nil
default:
err = c.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
}
// Listen establishes a PostgreSQL listen/notify to channel. The channel name
// is quoted as an identifier before it is sent; the subscription is recorded
// on the connection only after the server accepted the command.
func (c *Conn) Listen(channel string) error {
	if _, err := c.Exec("listen " + quoteIdentifier(channel)); err != nil {
		return err
	}

	c.channels[channel] = struct{}{}
	return nil
}
// Unlisten unsubscribes from a listen channel. The channel name is quoted as
// an identifier before being sent; on success the channel is removed from
// c.channels.
func (c *Conn) Unlisten(channel string) error {
	if _, err := c.Exec("unlisten " + quoteIdentifier(channel)); err != nil {
		return err
	}

	delete(c.channels, channel)
	return nil
}
// WaitForNotification waits for a PostgreSQL notification. A notification that
// is already queued on the connection is returned immediately; otherwise the
// connection is locked and messages are read until one arrives, an error
// occurs, or ctx ends the wait.
func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
	// Return already received notification immediately
	if len(c.notifications) > 0 {
		queued := c.notifications[0]
		c.notifications = c.notifications[1:]
		return queued, nil
	}

	if err = c.waitForPreviousCancelQuery(ctx); err != nil {
		return nil, err
	}

	if err = c.initContext(ctx); err != nil {
		return nil, err
	}
	defer func() {
		err = c.termContext(err)
	}()

	if err = c.lock(); err != nil {
		return nil, err
	}
	// Runs before the termContext defer above; an unlock failure is only
	// reported when nothing else went wrong.
	defer func() {
		unlockErr := c.unlock()
		if err == nil && unlockErr != nil {
			err = unlockErr
		}
	}()

	if readyErr := c.ensureConnectionReadyForQuery(); readyErr != nil {
		return nil, readyErr
	}

	for {
		msg, rxErr := c.rxMsg()
		if rxErr != nil {
			return nil, rxErr
		}

		if msgErr := c.processContextFreeMsg(msg); msgErr != nil {
			return nil, msgErr
		}

		// processContextFreeMsg queues notification messages; hand out the
		// oldest one as soon as the queue is non-empty.
		if len(c.notifications) > 0 {
			received := c.notifications[0]
			c.notifications = c.notifications[1:]
			return received, nil
		}
	}
}
// IsAlive reports whether the connection status is idle or busy, i.e. it is
// neither uninitialized nor closed. The status is read under c.mux.
func (c *Conn) IsAlive() bool {
	c.mux.Lock()
	alive := c.status >= connStatusIdle
	c.mux.Unlock()
	return alive
}
// CauseOfDeath returns the error recorded when the connection was killed by
// die, read under c.mux.
func (c *Conn) CauseOfDeath() error {
	c.mux.Lock()
	cause := c.causeOfDeath
	c.mux.Unlock()
	return cause
}
// sendQuery sends sql to the server. When sql is a key in
// c.preparedStatements it is executed as that prepared statement; otherwise it
// is handed to sendSimpleQuery.
func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
	ps, prepared := c.preparedStatements[sql]
	if !prepared {
		return c.sendSimpleQuery(sql, arguments...)
	}
	return c.sendPreparedQuery(ps, arguments...)
}
// sendSimpleQuery sends sql that is not a known prepared statement. Without
// args the query message is written directly; with args the sql is first
// prepared under the empty name and then executed via sendPreparedQuery.
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
	if err := c.ensureConnectionReadyForQuery(); err != nil {
		return err
	}

	if len(args) > 0 {
		ps, err := c.Prepare("", sql)
		if err != nil {
			return err
		}
		return c.sendPreparedQuery(ps, args...)
	}

	query := appendQuery(c.wbuf, sql)
	if _, err := c.conn.Write(query); err != nil {
		// Any write failure on this path kills the connection.
		c.die(err)
		return err
	}

	c.pendingReadyForQueryCount++
	return nil
}
// sendPreparedQuery binds arguments to ps, then writes the bind, execute and
// sync messages in a single write. The number of arguments must equal the
// number of parameter OIDs of ps. A write error is always returned, but the
// connection is only killed when fatalWriteErr deems the error fatal.
func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
	if want, got := len(ps.ParameterOIDs), len(arguments); want != got {
		return errors.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, want, got)
	}

	if readyErr := c.ensureConnectionReadyForQuery(); readyErr != nil {
		return readyErr
	}

	// Request each result column in the format recorded on the statement.
	resultFormatCodes := make([]int16, 0, len(ps.FieldDescriptions))
	for i := range ps.FieldDescriptions {
		resultFormatCodes = append(resultFormatCodes, ps.FieldDescriptions[i].FormatCode)
	}

	out, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOIDs, arguments, resultFormatCodes)
	if err != nil {
		return err
	}
	out = appendSync(appendExecute(out, "", 0))

	written, err := c.conn.Write(out)
	if err == nil {
		c.pendingReadyForQueryCount++
		return nil
	}

	if fatalWriteErr(written, err) {
		c.die(err)
	}
	return err
}
// fatalWriteErr takes the response of a net.Conn.Write and determines if it is
// fatal. A write that transferred any bytes is always fatal; otherwise the
// error is fatal unless it is a net.Error reporting a timeout.
func fatalWriteErr(bytesWritten int, err error) bool {
	// Partial writes break the connection
	if bytesWritten > 0 {
		return true
	}

	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		return false
	}
	return true
}
// Exec executes sql. sql can be either a prepared statement name or an SQL
// string. arguments should be referenced positionally from the sql string as
// $1, $2, etc. It is ExecEx with a background context and no options.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
	commandTag, err = c.ExecEx(context.Background(), sql, nil, arguments...)
	return commandTag, err
}
// processContextFreeMsg handles messages that are not exclusive to one context
// such as authentication or query response. The response to these messages is
// the same regardless of when they occur. Message types without a case below
// are silently ignored; such messages are only meaningful in a given context
// and can show up when a context deadline interrupted message processing (for
// example, an interrupted query may have left DataRow messages on the wire).
//
// Only an ErrorResponse produces a non-nil return value.
func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
	switch m := msg.(type) {
	case *pgproto3.NoticeResponse:
		c.rxNoticeResponse(m)
	case *pgproto3.NotificationResponse:
		c.rxNotificationResponse(m)
	case *pgproto3.ParameterStatus:
		c.rxParameterStatus(m)
	case *pgproto3.ReadyForQuery:
		c.rxReadyForQuery(m)
	case *pgproto3.ErrorResponse:
		return c.rxErrorResponse(m)
	}

	return nil
}
// rxMsg receives the next backend message from c.frontend. It returns
// ErrDeadConn without reading when the connection is not alive. A receive
// error kills the connection unless it is a net.Error timeout; either way the
// error is returned. On success c.lastActivityTime is refreshed.
func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
	if !c.IsAlive() {
		return nil, ErrDeadConn
	}

	msg, err := c.frontend.Receive()
	if err == nil {
		c.lastActivityTime = time.Now()
		return msg, nil
	}

	netErr, isNetErr := err.(net.Error)
	if !isNetErr || !netErr.Timeout() {
		c.die(err)
	}
	return nil, err
}
// rxAuthenticationX answers an authentication message from the server.
// AuthTypeOk needs no reply, cleartext requests are answered with the
// configured password, and MD5 requests with
// "md5" + hexMD5(hexMD5(password+user) + salt). Any other type is an error.
func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
	switch msg.Type {
	case pgproto3.AuthTypeOk:
		return nil
	case pgproto3.AuthTypeCleartextPassword:
		return c.txPasswordMessage(c.config.Password)
	case pgproto3.AuthTypeMD5Password:
		credentials := hexMD5(c.config.Password + c.config.User)
		salted := hexMD5(credentials + string(msg.Salt[:]))
		return c.txPasswordMessage("md5" + salted)
	default:
		return errors.New("Received unknown authentication message")
	}
}
// hexMD5 returns the MD5 digest of s as a lowercase hexadecimal string.
func hexMD5(s string) string {
	hasher := md5.New()
	// Writing to a hash.Hash never returns an error, so the result is ignored.
	_, _ = io.WriteString(hasher, s)

	digest := hasher.Sum(nil)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest)
	return string(encoded)
}
// rxParameterStatus stores the parameter reported by the server in
// c.RuntimeParams, overwriting any previous value for the same name.
func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) {
	name, value := msg.Name, msg.Value
	c.RuntimeParams[name] = value
}
// rxErrorResponse converts an ErrorResponse message into a PgError by copying
// every field across. An error with severity "FATAL" additionally kills the
// connection, with the PgError as the cause of death.
func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError {
	var pgErr PgError
	pgErr.Severity = msg.Severity
	pgErr.Code = msg.Code
	pgErr.Message = msg.Message
	pgErr.Detail = msg.Detail
	pgErr.Hint = msg.Hint
	pgErr.Position = msg.Position
	pgErr.InternalPosition = msg.InternalPosition
	pgErr.InternalQuery = msg.InternalQuery
	pgErr.Where = msg.Where
	pgErr.SchemaName = msg.SchemaName
	pgErr.TableName = msg.TableName
	pgErr.ColumnName = msg.ColumnName
	pgErr.DataTypeName = msg.DataTypeName
	pgErr.ConstraintName = msg.ConstraintName
	pgErr.File = msg.File
	pgErr.Line = msg.Line
	pgErr.Routine = msg.Routine

	if pgErr.Severity == "FATAL" {
		c.die(pgErr)
	}

	return pgErr
}
// rxNoticeResponse forwards a NoticeResponse message to the connection's
// notice handler as a *Notice. Nothing happens when no handler is set.
func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) {
	if c.onNotice != nil {
		c.onNotice(c, &Notice{
			Severity:         msg.Severity,
			Code:             msg.Code,
			Message:          msg.Message,
			Detail:           msg.Detail,
			Hint:             msg.Hint,
			Position:         msg.Position,
			InternalPosition: msg.InternalPosition,
			InternalQuery:    msg.InternalQuery,
			Where:            msg.Where,
			SchemaName:       msg.SchemaName,
			TableName:        msg.TableName,
			ColumnName:       msg.ColumnName,
			DataTypeName:     msg.DataTypeName,
			ConstraintName:   msg.ConstraintName,
			File:             msg.File,
			Line:             msg.Line,
			Routine:          msg.Routine,
		})
	}
}
// rxBackendKeyData remembers the backend process ID and secret key; doCancel
// later sends both values in its cancel request.
func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
	c.pid, c.secretKey = msg.ProcessID, msg.SecretKey
}
// rxReadyForQuery records the transaction status reported by the server and
// marks one outstanding ReadyForQuery as received.
func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
	c.txStatus = msg.TxStatus
	c.pendingReadyForQueryCount--
}
// rxRowDescription converts the fields of a RowDescription message into a
// slice of FieldDescription, one per field and in the same order. Fields of
// FieldDescription that have no counterpart in the message keep their zero
// value.
func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription {
	fields := make([]FieldDescription, len(msg.Fields))
	for i := range fields {
		src := msg.Fields[i]
		fields[i] = FieldDescription{
			Name:            src.Name,
			Table:           pgtype.OID(src.TableOID),
			AttributeNumber: src.TableAttributeNumber,
			DataType:        pgtype.OID(src.DataTypeOID),
			DataTypeSize:    src.DataTypeSize,
			Modifier:        src.TypeModifier,
			FormatCode:      src.Format,
		}
	}
	return fields
}
// rxParameterDescription converts the parameter OIDs of a ParameterDescription
// message into a slice of pgtype.OID of the same length and order.
func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.OID {
	oids := make([]pgtype.OID, len(msg.ParameterOIDs))
	for i, raw := range msg.ParameterOIDs {
		oids[i] = pgtype.OID(raw)
	}
	return oids
}
// rxNotificationResponse appends the received notification to the
// connection's queue of pending notifications.
func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
	c.notifications = append(c.notifications, &Notification{
		PID:     msg.PID,
		Channel: msg.Channel,
		Payload: msg.Payload,
	})
}
// startTLS asks the server to switch to TLS by writing the two big-endian
// int32 values 8 and 80877103, then reads a one byte answer. Any answer other
// than 'S' yields ErrTLSRefused. On acceptance c.conn is wrapped in a TLS
// client connection using tlsConfig.
func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
	request := []int32{8, 80877103}
	if err = binary.Write(c.conn, binary.BigEndian, request); err != nil {
		return err
	}

	var answer [1]byte
	if _, err = io.ReadFull(c.conn, answer[:]); err != nil {
		return err
	}

	if answer[0] != 'S' {
		return ErrTLSRefused
	}

	c.conn = tls.Client(c.conn, tlsConfig)
	return nil
}
// txPasswordMessage writes a 'p' message carrying the null-terminated
// password. The int32 length is written as a placeholder first and patched
// once the message body is complete.
func (c *Conn) txPasswordMessage(password string) (err error) {
	out := c.wbuf
	out = append(out, 'p')
	lenPos := len(out)
	out = pgio.AppendInt32(out, -1)
	out = append(append(out, password...), 0)
	pgio.SetInt32(out[lenPos:], int32(len(out)-lenPos))

	_, err = c.conn.Write(out)
	return err
}
// die marks the connection as closed, records err as the cause of death and
// closes the underlying connection. It does nothing when the connection is
// already closed, so the first cause of death is kept.
func (c *Conn) die(err error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status != connStatusClosed {
		c.status = connStatusClosed
		c.causeOfDeath = err
		c.conn.Close()
	}
}
// lock moves the connection from idle to busy. It returns ErrConnBusy when the
// connection is in any state other than idle.
func (c *Conn) lock() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	switch c.status {
	case connStatusIdle:
		c.status = connStatusBusy
		return nil
	default:
		return ErrConnBusy
	}
}
// unlock moves the connection from busy back to idle. It returns an error when
// the connection is not currently busy.
func (c *Conn) unlock() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status == connStatusBusy {
		c.status = connStatusIdle
		return nil
	}
	return errors.New("unlock conn that is not busy")
}
// shouldLog reports whether a message at level lvl should be logged: a logger
// must be set and the connection's log level must be at least lvl.
func (c *Conn) shouldLog(lvl int) bool {
	if c.logger == nil {
		return false
	}
	return c.logLevel >= lvl
}
// log passes msg and data to the connection's logger at level lvl. A nil data
// map is replaced by an empty one, and a non-zero backend pid is added under
// the "pid" key (a map supplied by the caller is modified in place). The
// logger is used unconditionally, so callers check shouldLog first.
func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
	fields := data
	if fields == nil {
		fields = make(map[string]interface{})
	}

	if pid := c.pid; pid != 0 {
		fields["pid"] = pid
	}

	c.logger.Log(lvl, msg, fields)
}
// SetLogger replaces the current logger and returns the previous logger.
func (c *Conn) SetLogger(logger Logger) Logger {
	var previous Logger
	previous, c.logger = c.logger, logger
	return previous
}
// SetLogLevel replaces the current log level and returns the previous log
// level. When lvl lies outside the range LogLevelNone..LogLevelTrace the log
// level is left unchanged and the current level is returned together with
// ErrInvalidLogLevel.
func (c *Conn) SetLogLevel(lvl int) (int, error) {
	oldLvl := c.logLevel

	if lvl < LogLevelNone || lvl > LogLevelTrace {
		return oldLvl, ErrInvalidLogLevel
	}

	c.logLevel = lvl
	// Return the level that was replaced, as documented and as SetLogger does
	// for the logger, rather than echoing back the caller's own argument.
	return oldLvl, nil
}
// quoteIdentifier wraps s in double quotes and doubles every double quote
// inside it, so the result can be used as a quoted SQL identifier.
func quoteIdentifier(s string) string {
	const quote = `"`
	escaped := strings.Join(strings.Split(s, quote), quote+quote)
	return quote + escaped + quote
}
// doCancel is the default cancel function used by cancelQuery. It dials a
// separate connection to the server, writes a 16 byte cancel request made of
// four big-endian uint32 values (16, 80877102, the backend pid and the secret
// key) and then expects the server to close that connection. The whole
// exchange is bounded by a 15 second deadline.
func doCancel(c *Conn) error {
	network, address := c.config.networkAddress()
	cancelConn, err := c.config.Dial(network, address)
	if err != nil {
		return err
	}
	defer cancelConn.Close()

	// If server doesn't process cancellation request in bounded time then abort.
	deadline := time.Now().Add(15 * time.Second)
	if err = cancelConn.SetDeadline(deadline); err != nil {
		return err
	}

	req := make([]byte, 16)
	words := [...]uint32{16, 80877102, uint32(c.pid), uint32(c.secretKey)}
	for i, word := range words {
		binary.BigEndian.PutUint32(req[i*4:], word)
	}

	if _, err = cancelConn.Write(req); err != nil {
		return err
	}

	// Anything other than EOF means the server did not hang up as expected.
	if _, err = cancelConn.Read(req); err != io.EOF {
		return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, req)
	}

	return nil
}
// cancelQuery sends a cancel request to the PostgreSQL server. As specified in
// the documentation, there is no way to be sure a query was canceled. See
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
//
// The connection deadline is first set to now. The cancel function
// (c.config.CustomCancel if set, doCancel otherwise) then runs in its own
// goroutine; c.cancelQueryCompleted is closed when it finishes. The connection
// is closed if the deadline cannot be set or the cancel function fails.
func (c *Conn) cancelQuery() {
	if err := c.conn.SetDeadline(time.Now()); err != nil {
		c.Close() // Close connection if unable to set deadline
		return
	}

	finished := make(chan struct{})
	c.mux.Lock()
	c.cancelQueryCompleted = finished
	c.mux.Unlock()

	var cancel func(*Conn) error = doCancel
	if c.config.CustomCancel != nil {
		cancel = c.config.CustomCancel
	}

	go func() {
		defer close(finished)
		if err := cancel(c); err != nil {
			c.Close() // Something is very wrong. Terminate the connection.
		}
	}()
}
// Ping executes the statement ";" via ExecEx and returns the resulting error,
// if any.
func (c *Conn) Ping(ctx context.Context) error {
	if _, err := c.ExecEx(ctx, ";", nil); err != nil {
		return err
	}
	return nil
}
// ExecEx executes sql with the given options and arguments while holding the
// connection lock, and returns the resulting command tag. It resets
// c.lastStmtSent, waits for a previous cancel request to finish, and logs the
// outcome: failures at LogLevelError, successes (with elapsed time) at
// LogLevelInfo.
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) {
	c.lastStmtSent = false

	if err := c.waitForPreviousCancelQuery(ctx); err != nil {
		return "", err
	}

	if err := c.lock(); err != nil {
		return "", err
	}
	defer c.unlock()

	began := time.Now()
	c.lastActivityTime = began

	tag, err := c.execEx(ctx, sql, options, arguments...)
	switch {
	case err != nil:
		if c.shouldLog(LogLevelError) {
			c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
		}
	case c.shouldLog(LogLevelInfo):
		elapsed := time.Since(began)
		c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": elapsed, "commandTag": tag})
	}

	return tag, err
}
// execEx sends sql to the server using one of three strategies and then reads
// messages until ReadyForQuery arrives:
//
//   - simple protocol, when requested by options or (absent options) by the
//     connection config;
//   - a single round trip parse/bind/execute/sync, when options carries
//     parameter OIDs;
//   - otherwise a prepared statement (known, or prepared on the fly under the
//     empty name) when there are arguments, or sendQuery when there are none.
//
// c.lastStmtSent is set to true right before each attempt to put the statement
// on the wire. The returned command tag is the last CommandComplete tag seen.
// The first error produced by a context-free message is remembered and
// returned once ReadyForQuery is received; a receive error is returned
// immediately. termContext gets the final say on the returned error.
func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
	err = c.initContext(ctx)
	if err != nil {
		return "", err
	}
	defer func() {
		err = c.termContext(err)
	}()

	if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
		c.lastStmtSent = true
		err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
		if err != nil {
			return "", err
		}
	} else if options != nil && len(options.ParameterOIDs) > 0 {
		if err := c.ensureConnectionReadyForQuery(); err != nil {
			return "", err
		}

		buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments)
		if err != nil {
			return "", err
		}

		buf = appendSync(buf)

		c.lastStmtSent = true
		n, err := c.conn.Write(buf)
		if err != nil {
			// Every write error aborts the exec: nothing (or only part of the
			// message) reached the server, so no ReadyForQuery must be
			// counted as pending. Only fatal errors kill the connection,
			// mirroring sendPreparedQuery.
			if fatalWriteErr(n, err) {
				c.die(err)
			}
			return "", err
		}
		c.pendingReadyForQueryCount++
	} else {
		if len(arguments) > 0 {
			ps, ok := c.preparedStatements[sql]
			if !ok {
				var err error
				ps, err = c.prepareEx("", sql, nil)
				if err != nil {
					return "", err
				}
			}

			c.lastStmtSent = true
			err = c.sendPreparedQuery(ps, arguments...)
			if err != nil {
				return "", err
			}
		} else {
			c.lastStmtSent = true
			if err = c.sendQuery(sql, arguments...); err != nil {
				return commandTag, err
			}
		}
	}

	var softErr error

	for {
		msg, err := c.rxMsg()
		if err != nil {
			return commandTag, err
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			c.rxReadyForQuery(msg)
			return commandTag, softErr
		case *pgproto3.CommandComplete:
			commandTag = CommandTag(msg.CommandTag)
		default:
			// Keep only the first error; later ones must not mask it.
			if e := c.processContextFreeMsg(msg); e != nil && softErr == nil {
				softErr = e
			}
		}
	}
}
// buildOneRoundTripExec appends parse, bind and execute messages for sql to
// buf, all using the unnamed statement and portal, and returns the extended
// buffer. The number of arguments must equal the number of
// options.ParameterOIDs, and there may be at most 65535 of them.
func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
	oids := options.ParameterOIDs

	switch {
	case len(arguments) != len(oids):
		return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(oids))
	case len(oids) > 65535:
		return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(oids))
	}

	buf = appendParse(buf, "", sql, oids)

	buf, err := appendBind(buf, "", "", c.ConnInfo, oids, arguments, nil)
	if err != nil {
		return nil, err
	}

	return appendExecute(buf, "", 0), nil
}
// initContext starts watching ctx for the duration of one operation. It fails
// when another context is already being watched or when ctx is already done.
// A context whose Done channel is nil can never be canceled, so no watcher is
// started for it. Otherwise c.ctxInProgress is set and contextHandler is
// launched in its own goroutine; termContext must be called to stop it.
func (c *Conn) initContext(ctx context.Context) error {
	if c.ctxInProgress {
		return errors.New("ctx already in progress")
	}

	done := ctx.Done()
	if done == nil {
		return nil
	}

	// Non-blocking check: bail out early if ctx has already ended.
	select {
	case <-done:
		return ctx.Err()
	default:
	}

	c.ctxInProgress = true
	go c.contextHandler(ctx)

	return nil
}
// termContext stops the context watcher started by initContext and decides
// which error the operation reports. Without a watcher, opErr is returned
// unchanged. If the watcher already fired, the context's error replaces a
// non-nil opErr, while a nil opErr stays nil. If the watcher is still waiting,
// it is told to stop and opErr is returned.
func (c *Conn) termContext(opErr error) error {
	if !c.ctxInProgress {
		return opErr
	}

	result := opErr
	select {
	case ctxErr := <-c.closedChan:
		if opErr != nil {
			result = ctxErr
		}
	case c.doneChan <- struct{}{}:
	}

	c.ctxInProgress = false
	return result
}
// contextHandler waits until either ctx is done or a value arrives on
// c.doneChan. In the first case it cancels the running query and sends the
// context's error on c.closedChan; in the second it simply returns.
func (c *Conn) contextHandler(ctx context.Context) {
	select {
	case <-c.doneChan:
		return
	case <-ctx.Done():
	}

	c.cancelQuery()
	c.closedChan <- ctx.Err()
}
// WaitUntilReady will return when the connection is ready for another query.
// It first waits for a previous cancel request to finish and then consumes
// messages until no ReadyForQuery is outstanding.
func (c *Conn) WaitUntilReady(ctx context.Context) error {
	if err := c.waitForPreviousCancelQuery(ctx); err != nil {
		return err
	}
	return c.ensureConnectionReadyForQuery()
}
// waitForPreviousCancelQuery blocks until c.cancelQueryCompleted is readable
// or ctx is done. In the first case the connection deadline is cleared again
// (the connection is closed if that fails); in the second ctx's error is
// returned.
//
// NOTE(review): presumably c.cancelQueryCompleted is initialized to a closed
// channel elsewhere so this does not block when no cancel is pending; verify.
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
	c.mux.Lock()
	pending := c.cancelQueryCompleted
	c.mux.Unlock()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-pending:
	}

	if err := c.conn.SetDeadline(time.Time{}); err != nil {
		c.Close() // Close connection if unable to disable deadline
		return err
	}
	return nil
}
// ensureConnectionReadyForQuery reads and processes messages for as long as a
// ReadyForQuery is still outstanding. Error responses left over from earlier
// operations are dropped unless their severity is "FATAL", in which case the
// error is returned. Receive errors and errors from other messages are
// returned as well.
func (c *Conn) ensureConnectionReadyForQuery() error {
	for c.pendingReadyForQueryCount > 0 {
		msg, err := c.rxMsg()
		if err != nil {
			return err
		}

		if errResp, isErrResp := msg.(*pgproto3.ErrorResponse); isErrResp {
			if pgErr := c.rxErrorResponse(errResp); pgErr.Severity == "FATAL" {
				return pgErr
			}
			continue
		}

		if err = c.processContextFreeMsg(msg); err != nil {
			return err
		}
	}

	return nil
}
// connInfoFromRows builds a type name to OID map from rows, which must yield
// an OID followed by a text name per row. It accepts the (rows, err) pair of a
// query call directly: a non-nil err is returned as is, without touching rows.
// rows is closed before returning in every other case.
func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) {
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	nameOIDs := make(map[string]pgtype.OID, 256)
	for rows.Next() {
		var (
			oid  pgtype.OID
			name pgtype.Text
		)
		if scanErr := rows.Scan(&oid, &name); scanErr != nil {
			return nil, scanErr
		}

		nameOIDs[name.String] = oid
	}

	if rowsErr := rows.Err(); rowsErr != nil {
		return nil, rowsErr
	}

	return nameOIDs, nil
}
// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted
// to send the statement over the wire. Each such call first resets the value
// to false; it becomes true once a write of the statement is attempted. A true
// value does NOT mean that the statement was successful or even received, only
// that it could have been executed. Preparing a statement does not count; only
// the attempt to execute the prepared statement does.
func (c *Conn) LastStmtSent() bool {
	sent := c.lastStmtSent
	return sent
}