2018-10-03 16:55:26 +00:00
package pgx
import (
"context"
"crypto/md5"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
// Connection lifecycle states stored in Conn.status. The ordering matters:
// Close treats any status below connStatusIdle as "nothing to close".
const (
	// connStatusUninitialized is the zero value of a Conn that has not been
	// connected yet.
	connStatusUninitialized = iota
	// connStatusClosed marks a connection that was closed or failed to connect.
	connStatusClosed
	// connStatusIdle marks an established connection.
	connStatusIdle
	// connStatusBusy presumably marks a connection in the middle of an
	// operation (see ErrConnBusy) -- TODO confirm against the code that sets it.
	connStatusBusy
)
// minimalConnInfo carries only the handful of statically known types that are
// required to establish a connection and then query the server for the full
// type catalog.
var minimalConnInfo *pgtype.ConnInfo

// init populates minimalConnInfo with the bootstrap type name to OID mapping.
func init() {
	bootstrapTypes := map[string]pgtype.OID{
		"int4":    pgtype.Int4OID,
		"name":    pgtype.NameOID,
		"oid":     pgtype.OIDOID,
		"text":    pgtype.TextOID,
		"varchar": pgtype.VarcharOID,
	}

	minimalConnInfo = pgtype.NewConnInfo()
	minimalConnInfo.InitializeDataTypes(bootstrapTypes)
}
// NoticeHandler is a callback invoked for notices sent by the PostgreSQL
// server. A notice may arrive at any moment, most commonly while a query
// response is being processed. The originating *Conn is passed so the handler
// knows where the notice came from; the handler must not call any query method
// on it. Notices are unrelated to LISTEN/NOTIFY notifications.
type NoticeHandler func(*Conn, *Notice)
// DialFunc opens the network connection to a PostgreSQL server. It receives
// the network ("tcp" or "unix") and the address computed from the ConnConfig.
type DialFunc func(network, addr string) (net.Conn, error)
// ConnConfig contains all the options used to establish a connection.
type ConnConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16 // default: 5432
Database string
User string // default: OS user name
Password string
TLSConfig * tls . Config // config for TLS connection -- nil disables TLS
UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa
FallbackTLSConfig * tls . Config // config for fallback TLS connection (only used if UseFallBackTLS is true)-- nil disables TLS
Logger Logger
2019-04-13 07:44:06 +00:00
LogLevel int
2018-10-03 16:55:26 +00:00
Dial DialFunc
RuntimeParams map [ string ] string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
OnNotice NoticeHandler // Callback function called when a notice response is received.
CustomConnInfo func ( * Conn ) ( * pgtype . ConnInfo , error ) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
2019-04-12 15:51:37 +00:00
CustomCancel func ( * Conn ) error // Callback function used to override cancellation behavior
2018-10-03 16:55:26 +00:00
// PreferSimpleProtocol disables implicit prepared statement usage. By default
// pgx automatically uses the unnamed prepared statement for Query and
// QueryRow. It also uses a prepared statement when Exec has arguments. This
// can improve performance due to being able to use the binary format. It also
// does not rely on client side parameter sanitization. However, it does incur
// two round-trips per query and may be incompatible proxies such as
// PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
// used by default. The same functionality can be controlled on a per query
// basis by setting QueryExOptions.SimpleProtocol.
PreferSimpleProtocol bool
}
// networkAddress returns the network ("tcp" or "unix") and the address to
// dial for this configuration.
//
// If Host names an existing filesystem path it is treated as a unix domain
// socket. A path that already contains "/.s.PGSQL." is used verbatim (socket
// file, accepted for backward compatibility); otherwise Host is taken to be the
// socket directory and the conventional ".s.PGSQL.<port>" file name is
// appended. In every other case a TCP host:port address is produced.
func (cc *ConnConfig) networkAddress() (network, address string) {
	// See if host is a valid path, if yes connect with a socket
	if _, err := os.Stat(cc.Host); err == nil {
		address = cc.Host
		if !strings.Contains(address, "/.s.PGSQL.") {
			address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
		}
		return "unix", address
	}

	// net.JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:5432"), which
	// plain "%s:%d" formatting does not. A host that is already bracketed is
	// unwrapped first so it is not bracketed twice.
	host := cc.Host
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	return "tcp", net.JoinHostPort(host, strconv.FormatUint(uint64(cc.Port), 10))
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
type Conn struct {
conn net . Conn // the underlying TCP or unix domain socket connection
lastActivityTime time . Time // the last time the connection was used
wbuf [ ] byte
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
RuntimeParams map [ string ] string // parameters that have been reported by the server
config ConnConfig // config used when establishing this connection
txStatus byte
preparedStatements map [ string ] * PreparedStatement
channels map [ string ] struct { }
notifications [ ] * Notification
logger Logger
2019-04-13 07:44:06 +00:00
logLevel int
2018-10-03 16:55:26 +00:00
fp * fastpath
poolResetCount int
preallocatedRows [ ] Rows
onNotice NoticeHandler
mux sync . Mutex
status byte // One of connStatus* constants
causeOfDeath error
2019-04-12 15:51:37 +00:00
pendingReadyForQueryCount int // number of ReadyForQuery messages expected
2018-10-03 16:55:26 +00:00
cancelQueryCompleted chan struct { }
2019-04-12 15:51:37 +00:00
lastStmtSent bool
2018-10-03 16:55:26 +00:00
// context support
ctxInProgress bool
doneChan chan struct { }
closedChan chan error
ConnInfo * pgtype . ConnInfo
frontend * pgproto3 . Frontend
}
// PreparedStatement describes a statement that has been prepared on the
// server: its name, its SQL text, the result columns it produces and the OIDs
// of its parameters.
type PreparedStatement struct {
	Name              string
	SQL               string
	FieldDescriptions []FieldDescription
	ParameterOIDs     []pgtype.OID
}
// PrepareExOptions holds the optional settings accepted by PrepareEx.
type PrepareExOptions struct {
	// ParameterOIDs are the parameter type OIDs sent with the Parse message.
	ParameterOIDs []pgtype.OID
}
// Notification is a single message delivered through the PostgreSQL
// LISTEN/NOTIFY mechanism.
type Notification struct {
	PID     uint32 // backend pid that sent the notification
	Channel string // channel from which notification was received
	Payload string
}
// CommandTag is the status text the server returns for an executed command,
// e.g. "INSERT 0 5" or "CREATE TABLE".
type CommandTag string

// RowsAffected reports how many rows the command touched. It parses the text
// after the last space of the tag; tags without a space, or whose last word is
// not a number (such as "CREATE TABLE"), yield 0.
func (ct CommandTag) RowsAffected() int64 {
	tag := string(ct)

	sep := strings.LastIndexByte(tag, ' ')
	if sep < 0 {
		return 0
	}

	// A parse failure is deliberately ignored: the value returned alongside
	// the error is used as-is.
	count, _ := strconv.ParseInt(tag[sep+1:], 10, 64)
	return count
}
// Identifier is a PostgreSQL identifier or name. Identifiers can be composed
// of multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string

// Sanitize returns a sanitized string safe for SQL interpolation. Every part
// is wrapped in double quotes with embedded double quotes doubled, and the
// parts are joined with ".".
func (ident Identifier) Sanitize() string {
	parts := make([]string, len(ident))
	for i := range ident {
		// The quote literals must not carry any padding: a padded search
		// pattern would leave a bare `"` unescaped and allow the identifier
		// to break out of its quoting.
		parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"`
	}
	return strings.Join(parts, ".")
}
// Sentinel errors returned by this package.
var (
	// ErrNoRows occurs when rows are expected but none are returned.
	ErrNoRows = errors.New("no rows in result set")

	// ErrDeadConn occurs on an attempt to use a dead connection.
	ErrDeadConn = errors.New("conn is dead")

	// ErrTLSRefused occurs when the connection attempt requires TLS and the
	// PostgreSQL server refuses to use TLS.
	ErrTLSRefused = errors.New("server refused TLS connection")

	// ErrConnBusy occurs when the connection is busy (for example, in the
	// middle of reading query results) and another action is attempted.
	ErrConnBusy = errors.New("conn is busy")

	// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
	ErrInvalidLogLevel = errors.New("invalid log level")
)
// ProtocolError reports that unexpected data was received from PostgreSQL.
// The string value is the error message.
type ProtocolError string

// Error implements the error interface by returning the message itself.
func (e ProtocolError) Error() string {
	msg := string(e)
	return msg
}
// Connect opens a connection to a PostgreSQL server as described by config.
// config.Host must be set; config.User falls back to the OS user name and
// config.Port to 5432. All remaining fields are optional.
//
// The connection starts out with minimalConnInfo, which causes the full type
// information to be loaded from the server once the session is ready.
func Connect(config ConnConfig) (c *Conn, err error) {
	return connect(config, minimalConnInfo)
}
// defaultDialer returns a fresh net.Dialer with a five minute TCP keep-alive
// and no other option set. Callers may adjust the result (e.g. Timeout).
func defaultDialer() *net.Dialer {
	d := new(net.Dialer)
	d.KeepAlive = 5 * time.Minute
	return d
}
// connect builds a Conn from config, fills in defaults (log level, user, port,
// dialer) and establishes the session. connInfo is the type information the
// connection starts with.
//
// The first attempt uses config.TLSConfig. If it fails and
// config.UseFallbackTLS is set, a second attempt is made with
// config.FallbackTLSConfig. On failure nil and the last error are returned.
func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
	c = new(Conn)
	c.config = config
	c.ConnInfo = connInfo
	c.logger = c.config.Logger

	// Preserve pre-LogLevel behavior by defaulting to LogLevelDebug
	c.logLevel = LogLevelDebug
	if c.config.LogLevel != 0 {
		c.logLevel = c.config.LogLevel
	}

	if c.config.User == "" {
		osUser, lookupErr := user.Current()
		if lookupErr != nil {
			return nil, lookupErr
		}
		c.config.User = osUser.Username
		if c.shouldLog(LogLevelDebug) {
			c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User})
		}
	}

	if c.config.Port == 0 {
		c.config.Port = 5432
		if c.shouldLog(LogLevelDebug) {
			c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port})
		}
	}

	c.onNotice = config.OnNotice

	network, address := c.config.networkAddress()
	if c.config.Dial == nil {
		c.config.Dial = defaultDialer().Dial
	}

	if c.shouldLog(LogLevelInfo) {
		c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
	}

	err = c.connect(config, network, address, config.TLSConfig)
	if err != nil && config.UseFallbackTLS {
		if c.shouldLog(LogLevelInfo) {
			c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
		}
		err = c.connect(config, network, address, config.FallbackTLSConfig)
	}

	if err == nil {
		return c, nil
	}

	if c.shouldLog(LogLevelError) {
		c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
	}
	return nil, err
}
func ( c * Conn ) connect ( config ConnConfig , network , address string , tlsConfig * tls . Config ) ( err error ) {
c . conn , err = c . config . Dial ( network , address )
if err != nil {
return err
}
defer func ( ) {
if c != nil && err != nil {
c . conn . Close ( )
c . mux . Lock ( )
c . status = connStatusClosed
c . mux . Unlock ( )
}
} ( )
c . RuntimeParams = make ( map [ string ] string )
c . preparedStatements = make ( map [ string ] * PreparedStatement )
c . channels = make ( map [ string ] struct { } )
c . lastActivityTime = time . Now ( )
2019-04-12 15:51:37 +00:00
c . cancelQueryCompleted = make ( chan struct { } )
close ( c . cancelQueryCompleted )
2018-10-03 16:55:26 +00:00
c . doneChan = make ( chan struct { } )
c . closedChan = make ( chan error )
c . wbuf = make ( [ ] byte , 0 , 1024 )
c . mux . Lock ( )
c . status = connStatusIdle
c . mux . Unlock ( )
if tlsConfig != nil {
if c . shouldLog ( LogLevelDebug ) {
c . log ( LogLevelDebug , "starting TLS handshake" , nil )
}
if err := c . startTLS ( tlsConfig ) ; err != nil {
return err
}
}
c . frontend , err = pgproto3 . NewFrontend ( c . conn , c . conn )
if err != nil {
return err
}
startupMsg := pgproto3 . StartupMessage {
ProtocolVersion : pgproto3 . ProtocolVersionNumber ,
Parameters : make ( map [ string ] string ) ,
}
// Copy default run-time params
for k , v := range config . RuntimeParams {
startupMsg . Parameters [ k ] = v
}
startupMsg . Parameters [ "user" ] = c . config . User
if c . config . Database != "" {
startupMsg . Parameters [ "database" ] = c . config . Database
}
if _ , err := c . conn . Write ( startupMsg . Encode ( nil ) ) ; err != nil {
return err
}
c . pendingReadyForQueryCount = 1
for {
msg , err := c . rxMsg ( )
if err != nil {
return err
}
switch msg := msg . ( type ) {
case * pgproto3 . BackendKeyData :
c . rxBackendKeyData ( msg )
case * pgproto3 . Authentication :
if err = c . rxAuthenticationX ( msg ) ; err != nil {
return err
}
case * pgproto3 . ReadyForQuery :
c . rxReadyForQuery ( msg )
if c . shouldLog ( LogLevelInfo ) {
c . log ( LogLevelInfo , "connection established" , nil )
}
// Replication connections can't execute the queries to
// populate the c.PgTypes and c.pgsqlAfInet
if _ , ok := config . RuntimeParams [ "replication" ] ; ok {
return nil
}
if c . ConnInfo == minimalConnInfo {
err = c . initConnInfo ( )
if err != nil {
return err
}
}
return nil
default :
if err = c . processContextFreeMsg ( msg ) ; err != nil {
return err
}
}
}
}
// initPostgresql loads the type catalog of a regular PostgreSQL server and
// returns a ConnInfo built from it.
//
// Types in the pg_catalog and public schemas are registered under their bare
// name; types of any other schema are registered as "schema.typname".
// Afterwards arrays of enums and domains are registered as well. The first
// error encountered is returned.
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
	// The schema names in the IN list must match nspname exactly; any padding
	// inside the quotes would make the comparison always false.
	const (
		namedOIDQuery = ` select t . oid ,
case when nsp . nspname in ( 'pg_catalog' , 'public' ) then t . typname
else nsp . nspname || '.' || t . typname
end
from pg_type t
left join pg_type base_type on t . typelem = base_type . oid
left join pg_namespace nsp on t . typnamespace = nsp . oid
where (
t . typtype in ( 'b' , 'p' , 'r' , 'e' )
and ( base_type . oid is null or base_type . typtype in ( 'b' , 'p' , 'r' ) )
) `
	)

	nameOIDs, err := connInfoFromRows(c.Query(namedOIDQuery))
	if err != nil {
		return nil, err
	}

	cinfo := pgtype.NewConnInfo()
	cinfo.InitializeDataTypes(nameOIDs)

	if err = c.initConnInfoEnumArray(cinfo); err != nil {
		return nil, err
	}

	if err = c.initConnInfoDomains(cinfo); err != nil {
		return nil, err
	}

	return cinfo, nil
}
// initConnInfo replaces c.ConnInfo with type information obtained from the
// server.
//
// When the config provides a CustomConnInfo callback, its result is used and
// nothing else is attempted. Otherwise the regular PostgreSQL catalog queries
// are run; if they fail, the CrateDB specific fallback gets a chance to
// recover from that error. c.ConnInfo is only overwritten by the built-in
// strategies when one of them succeeds.
func (c *Conn) initConnInfo() (err error) {
	if custom := c.config.CustomConnInfo; custom != nil {
		c.ConnInfo, err = custom(c)
		return err
	}

	info, err := initPostgresql(c)
	if err != nil {
		// Check if CrateDB specific approach might still allow us to connect.
		info, err = c.crateDBTypesQuery(err)
	}
	if err == nil {
		c.ConnInfo = info
	}
	return err
}
// initConnInfoEnumArray queries the server for array types whose element type
// is an enum and registers each of them in cinfo as a pgtype.EnumArray.
//
// Registration only happens after the whole result set was read without
// error, so a failed query leaves cinfo untouched.
func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error {
	rows, err := c.Query(` select t . oid , t . typname
from pg_type t
join pg_type base_type on t . typelem = base_type . oid
where t . typtype = 'b'
and base_type . typtype = 'e' `)
	if err != nil {
		return err
	}

	enumArrays := make(map[string]pgtype.OID, 16)
	for rows.Next() {
		var (
			typeOID  pgtype.OID
			typeName pgtype.Text
		)
		if scanErr := rows.Scan(&typeOID, &typeName); scanErr != nil {
			return scanErr
		}
		enumArrays[typeName.String] = typeOID
	}

	if rows.Err() != nil {
		return rows.Err()
	}

	for typeName, typeOID := range enumArrays {
		dataType := pgtype.DataType{
			Value: &pgtype.EnumArray{},
			Name:  typeName,
			OID:   typeOID,
		}
		cinfo.RegisterDataType(dataType)
	}

	return nil
}
// initConnInfoDomains queries the server for domains over base types and
// registers each of them in cinfo, using a fresh value of the same Go type
// that is registered for the domain's base type. Domains whose base type is
// unknown to cinfo are skipped.
func (c *Conn) initConnInfoDomains(cinfo *pgtype.ConnInfo) error {
	type domain struct {
		oid     pgtype.OID
		name    pgtype.Text
		baseOID pgtype.OID
	}

	rows, err := c.Query(` select t . oid , t . typname , t . typbasetype
from pg_type t
join pg_type base_type on t . typbasetype = base_type . oid
where t . typtype = 'd'
and base_type . typtype = 'b' `)
	if err != nil {
		return err
	}

	found := make([]*domain, 0, 16)
	for rows.Next() {
		entry := new(domain)
		if scanErr := rows.Scan(&entry.oid, &entry.name, &entry.baseOID); scanErr != nil {
			return scanErr
		}
		found = append(found, entry)
	}

	if rows.Err() != nil {
		return rows.Err()
	}

	for _, entry := range found {
		base, known := cinfo.DataTypeForOID(entry.baseOID)
		if !known {
			continue
		}

		// Allocate a new zero value of the base type's concrete value type.
		valueType := reflect.ValueOf(base.Value).Elem().Type()
		cinfo.RegisterDataType(pgtype.DataType{
			Value: reflect.New(valueType).Interface().(pgtype.Value),
			Name:  entry.name.String,
			OID:   entry.oid,
		})
	}

	return nil
}
// crateDBTypesQuery checks whether err looks like the result of CrateDB not
// implementing the pg_type table correctly. If so, a CrateDB specific query
// against pg_type is executed and its result is returned. If not, the
// original error is returned unchanged.
func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) {
	// CrateDB 2.1.6 [1] is a database that implements the PostgreSQL wire
	// protocol, but not perfectly. In particular, the pg_catalog schema
	// containing the pg_type table is not visible by default and the
	// pg_type.typtype column is not implemented. Therefore the regular type
	// query currently fails with the following error:
	//
	// pgx.PgError{Severity:"ERROR", Code:"XX000",
	// Message:"TableUnknownException: Table 'test.pg_type' unknown",
	// Detail:"", Hint:"", Position:0, InternalPosition:0, InternalQuery:"",
	// Where:"", SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
	// ConstraintName:"", File:"Schemas.java", Line:99, Routine:"getTableInfo"}
	//
	// If CrateDB were to fix the pg_type table visibility in the future, we
	// would still get this error until the typtype column is implemented:
	//
	// pgx.PgError{Severity:"ERROR", Code:"XX000",
	// Message:"ColumnUnknownException: Column typtype unknown", Detail:"",
	// Hint:"", Position:0, InternalPosition:0, InternalQuery:"", Where:"",
	// SchemaName:"", TableName:"", ColumnName:"", DataTypeName:"",
	// ConstraintName:"", File:"FullQualifiedNameFieldProvider.java", Line:132,
	//
	// Additionally CrateDB doesn't implement Postgres error codes [2], and
	// instead always returns "XX000" (internal_error). The check below uses
	// all of this knowledge as a heuristic to detect CrateDB. If CrateDB is
	// detected, a CrateDB specific pg_type query is executed instead.
	//
	// The heuristic is designed to still work even if CrateDB fixes [2] or
	// renames its internal exception names. If both are changed but pg_type
	// isn't fixed, this code will need to be changed.
	//
	// There is also a small chance the heuristic will yield a false positive
	// for non-CrateDB databases (e.g. if a real Postgres instance returns a
	// XX000 error), but hopefully there will be no harm in attempting the
	// alternative query in this case.
	//
	// CrateDB also uses the type varchar for the typname column which
	// required adding varchar to the minimalConnInfo init code.
	//
	// Also see the discussion here [3].
	//
	// [1] https://crate.io/
	// [2] https://github.com/crate/crate/issues/5027
	// [3] https://github.com/jackc/pgx/issues/320
	pgErr, isPgErr := err.(PgError)
	if !isPgErr {
		return nil, err
	}

	looksLikeCrateDB := pgErr.Code == "XX000" ||
		strings.Contains(pgErr.Message, "TableUnknownException") ||
		strings.Contains(pgErr.Message, "ColumnUnknownException")
	if !looksLikeCrateDB {
		return nil, err
	}

	nameOIDs, err := connInfoFromRows(c.Query(` select oid, typname from pg_catalog.pg_type `))
	if err != nil {
		return nil, err
	}

	cinfo := pgtype.NewConnInfo()
	cinfo.InitializeDataTypes(nameOIDs)
	return cinfo, err
}
// PID reports the process ID of the server backend that serves this
// connection.
func (c *Conn) PID() uint32 {
	backendPID := c.pid
	return backendPID
}
2019-04-12 15:51:37 +00:00
// LocalAddr returns the local address of the underlying network connection.
// An error is returned when the connection is not alive.
func (c *Conn) LocalAddr() (net.Addr, error) {
	if c.IsAlive() {
		return c.conn.LocalAddr(), nil
	}
	return nil, errors.New("connection not ready")
}
2018-10-03 16:55:26 +00:00
// Close closes a connection. It is safe to call Close on an already closed
// connection, in which case it does nothing and returns nil.
//
// Close marks the connection closed, sends a Terminate message and waits up
// to five seconds for the server to close its end. The underlying socket is
// closed in every case. The first failing step aborts the shutdown and its
// error is returned; this does not depend on the configured log level.
func (c *Conn) Close() (err error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status < connStatusIdle {
		return nil
	}
	c.status = connStatusClosed

	defer func() {
		c.conn.Close()
		c.causeOfDeath = errors.New("Closed")
		if c.shouldLog(LogLevelInfo) {
			c.log(LogLevelInfo, "closed connection", nil)
		}
	}()

	// Clear any deadline left over from earlier operations so the Terminate
	// message can be written.
	err = c.conn.SetDeadline(time.Time{})
	if err != nil {
		if c.shouldLog(LogLevelWarn) {
			c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err})
		}
		return err
	}

	// Terminate message: type byte 'X' followed by the int32 length 4.
	_, err = c.conn.Write([]byte{'X', 0, 0, 0, 4})
	if err != nil {
		if c.shouldLog(LogLevelWarn) {
			c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err})
		}
		return err
	}

	err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if err != nil {
		if c.shouldLog(LogLevelWarn) {
			c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err})
		}
		return err
	}

	// io.EOF is the expected outcome: the server closed its end.
	_, err = c.conn.Read(make([]byte, 1))
	if err != io.EOF {
		return err
	}

	return nil
}
// Merge returns a new ConnConfig with the attributes of old and other
// combined. When an attribute is set on both, other takes precedence.
//
// As a security precaution, if the other TLSConfig is nil, all old TLS
// attributes will be preserved.
func ( old ConnConfig ) Merge ( other ConnConfig ) ConnConfig {
cc := old
if other . Host != "" {
cc . Host = other . Host
}
if other . Port != 0 {
cc . Port = other . Port
}
if other . Database != "" {
cc . Database = other . Database
}
if other . User != "" {
cc . User = other . User
}
if other . Password != "" {
cc . Password = other . Password
}
if other . TLSConfig != nil {
cc . TLSConfig = other . TLSConfig
cc . UseFallbackTLS = other . UseFallbackTLS
cc . FallbackTLSConfig = other . FallbackTLSConfig
}
if other . Logger != nil {
cc . Logger = other . Logger
}
if other . LogLevel != 0 {
cc . LogLevel = other . LogLevel
}
if other . Dial != nil {
cc . Dial = other . Dial
}
2019-04-13 07:44:06 +00:00
cc . PreferSimpleProtocol = other . PreferSimpleProtocol
2018-10-03 16:55:26 +00:00
cc . RuntimeParams = make ( map [ string ] string )
for k , v := range old . RuntimeParams {
cc . RuntimeParams [ k ] = v
}
for k , v := range other . RuntimeParams {
cc . RuntimeParams [ k ] = v
}
return cc
}
// ParseURI parses a database URI into ConnConfig
//
// Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams.
func ParseURI ( uri string ) ( ConnConfig , error ) {
var cp ConnConfig
url , err := url . Parse ( uri )
if err != nil {
return cp , err
}
if url . User != nil {
cp . User = url . User . Username ( )
cp . Password , _ = url . User . Password ( )
}
parts := strings . SplitN ( url . Host , ":" , 2 )
cp . Host = parts [ 0 ]
if len ( parts ) == 2 {
p , err := strconv . ParseUint ( parts [ 1 ] , 10 , 16 )
if err != nil {
return cp , err
}
cp . Port = uint16 ( p )
}
cp . Database = strings . TrimLeft ( url . Path , "/" )
if pgtimeout := url . Query ( ) . Get ( "connect_timeout" ) ; pgtimeout != "" {
timeout , err := strconv . ParseInt ( pgtimeout , 10 , 64 )
if err != nil {
return cp , err
}
d := defaultDialer ( )
d . Timeout = time . Duration ( timeout ) * time . Second
cp . Dial = d . Dial
}
tlsArgs := configTLSArgs {
sslCert : url . Query ( ) . Get ( "sslcert" ) ,
sslKey : url . Query ( ) . Get ( "sslkey" ) ,
sslMode : url . Query ( ) . Get ( "sslmode" ) ,
sslRootCert : url . Query ( ) . Get ( "sslrootcert" ) ,
}
err = configTLS ( tlsArgs , & cp )
if err != nil {
return cp , err
}
ignoreKeys := map [ string ] struct { } {
"connect_timeout" : { } ,
"sslcert" : { } ,
"sslkey" : { } ,
"sslmode" : { } ,
"sslrootcert" : { } ,
}
cp . RuntimeParams = make ( map [ string ] string )
for k , v := range url . Query ( ) {
if _ , ok := ignoreKeys [ k ] ; ok {
continue
}
2019-04-12 15:51:37 +00:00
if k == "host" {
cp . Host = v [ 0 ]
continue
}
2018-10-03 16:55:26 +00:00
cp . RuntimeParams [ k ] = v [ 0 ]
}
if cp . Password == "" {
pgpass ( & cp )
}
return cp , nil
}
// dsnRegexp matches one key=value pair of a DSN. Group 1 is the key (letters
// and underscores); group 2 is the value, which is either a double-quoted
// string (quotes included in the match) or a run of non-space characters.
// The pattern must not be padded with spaces, otherwise a pair is only
// recognized when it has a space on both sides.
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
// ParseDSN parses a database DSN (data source name) into a ConnConfig.
//
// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable")
//
// Any options not used by the connection process are parsed into
// ConnConfig.RuntimeParams.
//
// e.g. ParseDSN("application_name=pgxtest search_path=admin user=username password=password host=1.2.3.4 dbname=mydb")
//
// ParseDSN tries to match libpq behavior with regard to sslmode. See comments
// for ParseEnvLibpq for more information on the security implications of
// sslmode options.
//
// If the DSN carries no password, pgpass is consulted.
func ParseDSN(s string) (ConnConfig, error) {
	var (
		cp      ConnConfig
		tlsArgs configTLSArgs
	)
	cp.RuntimeParams = make(map[string]string)

	for _, pair := range dsnRegexp.FindAllStringSubmatch(s, -1) {
		key, value := pair[1], pair[2]

		switch key {
		case "user":
			cp.User = value
		case "password":
			cp.Password = value
		case "host":
			cp.Host = value
		case "dbname":
			cp.Database = value
		case "port":
			port, err := strconv.ParseUint(value, 10, 16)
			if err != nil {
				return cp, err
			}
			cp.Port = uint16(port)
		case "connect_timeout":
			seconds, err := strconv.ParseInt(value, 10, 64)
			if err != nil {
				return cp, err
			}
			dialer := defaultDialer()
			dialer.Timeout = time.Duration(seconds) * time.Second
			cp.Dial = dialer.Dial
		case "sslmode":
			tlsArgs.sslMode = value
		case "sslrootcert":
			tlsArgs.sslRootCert = value
		case "sslcert":
			tlsArgs.sslCert = value
		case "sslkey":
			tlsArgs.sslKey = value
		default:
			cp.RuntimeParams[key] = value
		}
	}

	if err := configTLS(tlsArgs, &cp); err != nil {
		return cp, err
	}

	if cp.Password == "" {
		pgpass(&cp)
	}

	return cp, nil
}
// ParseConnectionString parses a connection string that is either a URI or a
// DSN. A string that parses as a URL with a non-empty scheme is handed to
// ParseURI; everything else is handed to ParseDSN. See those functions for
// details.
func ParseConnectionString(s string) (ConnConfig, error) {
	parsed, err := url.Parse(s)
	if err != nil || parsed.Scheme == "" {
		return ParseDSN(s)
	}
	return ParseURI(s)
}
// ParseEnvLibpq builds a ConnConfig from the process environment the way
// libpq does.
//
// See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details
// on the meaning of environment variables.
//
// ParseEnvLibpq currently recognizes the following environment variables:
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGAPPNAME
// PGCONNECT_TIMEOUT
//
// Important TLS Security Notes:
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
// includes defaulting to "prefer" behavior if no environment variable is set.
//
// See http://www.postgresql.org/docs/9.4/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION
// for details on what level of security each sslmode provides.
//
// "verify-ca" mode currently is treated as "verify-full", i.e. it has stronger
// security guarantees than it would with libpq. Do not rely on this behavior
// as it may be possible to match libpq in the future. If you need full
// security use "verify-full".
//
// Several of the PGSSLMODE options (including the default behavior of
// "prefer") will set UseFallbackTLS to true and FallbackTLSConfig to a
// disabled or weakened TLS mode. This means that if ParseEnvLibpq is used, but
// TLSConfig is later set from a different source, UseFallbackTLS MUST be set
// to false to avoid the possibility of falling back to weaker or disabled
// security.
func ParseEnvLibpq() (ConnConfig, error) {
	var cc ConnConfig

	cc.Host = os.Getenv("PGHOST")

	if rawPort := os.Getenv("PGPORT"); rawPort != "" {
		port, err := strconv.ParseUint(rawPort, 10, 16)
		if err != nil {
			return cc, err
		}
		cc.Port = uint16(port)
	}

	cc.Database = os.Getenv("PGDATABASE")
	cc.User = os.Getenv("PGUSER")
	cc.Password = os.Getenv("PGPASSWORD")

	if rawTimeout := os.Getenv("PGCONNECT_TIMEOUT"); rawTimeout != "" {
		seconds, err := strconv.ParseInt(rawTimeout, 10, 64)
		if err != nil {
			return cc, err
		}
		dialer := defaultDialer()
		dialer.Timeout = time.Duration(seconds) * time.Second
		cc.Dial = dialer.Dial
	}

	tlsArgs := configTLSArgs{
		sslMode:     os.Getenv("PGSSLMODE"),
		sslKey:      os.Getenv("PGSSLKEY"),
		sslCert:     os.Getenv("PGSSLCERT"),
		sslRootCert: os.Getenv("PGSSLROOTCERT"),
	}
	if err := configTLS(tlsArgs, &cc); err != nil {
		return cc, err
	}

	cc.RuntimeParams = make(map[string]string)
	if appName := os.Getenv("PGAPPNAME"); appName != "" {
		cc.RuntimeParams["application_name"] = appName
	}

	if cc.Password == "" {
		pgpass(&cc)
	}

	return cc, nil
}
// configTLSArgs groups the libpq style TLS settings that configTLS turns into
// tls.Config values. Empty strings mean "not specified".
type configTLSArgs struct {
	sslMode     string // sslmode value; configTLS treats "" as "prefer"
	sslRootCert string // path to a PEM file with the root CA certificate(s)
	sslCert     string // path to the client certificate
	sslKey      string // path to the client private key
}
// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config.
// Inputs are parsed out and provided by ParseDSN(), ParseURI() or
// ParseEnvLibpq().
//
// An empty sslMode is treated as "prefer" to match libpq. The modes map to:
//
//	disable      no TLS at all
//	allow        plain first, unverified TLS as fallback
//	prefer       unverified TLS first, plain as fallback
//	require      unverified TLS only
//	verify-ca,
//	verify-full  verified TLS with ServerName set to cc.Host
//
// sslRootCert and the sslCert/sslKey pair are applied to the TLS config that
// the chosen mode actually uses (TLSConfig, or FallbackTLSConfig in "allow"
// mode). An error is returned for an unknown mode, an unreadable or unusable
// CA file, a certificate without key (or vice versa) and an unloadable key
// pair.
func configTLS(args configTLSArgs, cc *ConnConfig) error {
	// Match libpq default behavior
	if args.sslMode == "" {
		args.sslMode = "prefer"
	}

	switch args.sslMode {
	case "disable":
		cc.UseFallbackTLS = false
		cc.TLSConfig = nil
		cc.FallbackTLSConfig = nil
		return nil
	case "allow":
		cc.UseFallbackTLS = true
		cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
	case "prefer":
		cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
		cc.UseFallbackTLS = true
		cc.FallbackTLSConfig = nil
	case "require":
		cc.TLSConfig = &tls.Config{InsecureSkipVerify: true}
	case "verify-ca", "verify-full":
		cc.TLSConfig = &tls.Config{
			ServerName: cc.Host,
		}
	default:
		return errors.New("sslmode is invalid")
	}

	// In "allow" mode the only TLS config lives in FallbackTLSConfig while
	// TLSConfig stays nil, so certificates have to be attached there.
	tlsConfig := cc.TLSConfig
	if tlsConfig == nil {
		tlsConfig = cc.FallbackTLSConfig
	}

	if args.sslRootCert != "" {
		caCertPool := x509.NewCertPool()

		caPath := args.sslRootCert
		caCert, err := ioutil.ReadFile(caPath)
		if err != nil {
			return errors.Wrapf(err, "unable to read CA file %q", caPath)
		}

		// There is no underlying error to wrap here; a fresh error has to
		// be created so the failure is actually reported.
		if !caCertPool.AppendCertsFromPEM(caCert) {
			return errors.New("unable to add CA to cert pool")
		}

		tlsConfig.RootCAs = caCertPool
		tlsConfig.ClientCAs = caCertPool
	}

	sslcert := args.sslCert
	sslkey := args.sslKey

	if (sslcert != "") != (sslkey != "") {
		return fmt.Errorf(` both "sslcert" and "sslkey" are required `)
	}

	if sslcert != "" && sslkey != "" {
		cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
		if err != nil {
			return errors.Wrap(err, "unable to read cert")
		}

		tlsConfig.Certificates = []tls.Certificate{cert}
	}

	return nil
}
// Prepare creates a prepared statement called name for sql. The SQL text may
// contain placeholders for bound parameters, referenced positionally as $1,
// $2, etc.
//
// Prepare is idempotent: calling it repeatedly with the same name and sql is
// safe, so a code path can Prepare and then Query/Exec without tracking
// whether the statement already exists.
//
// It is PrepareEx with a background context and no options.
func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
	ctx := context.Background()
	return c.PrepareEx(ctx, name, sql, nil)
}
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
2019-04-12 15:51:37 +00:00
// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct
2018-10-03 16:55:26 +00:00
//
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared.
func ( c * Conn ) PrepareEx ( ctx context . Context , name , sql string , opts * PrepareExOptions ) ( ps * PreparedStatement , err error ) {
err = c . waitForPreviousCancelQuery ( ctx )
if err != nil {
return nil , err
}
err = c . initContext ( ctx )
if err != nil {
return nil , err
}
ps , err = c . prepareEx ( name , sql , opts )
err = c . termContext ( err )
return ps , err
}
func ( c * Conn ) prepareEx ( name , sql string , opts * PrepareExOptions ) ( ps * PreparedStatement , err error ) {
if name != "" {
if ps , ok := c . preparedStatements [ name ] ; ok && ps . SQL == sql {
return ps , nil
}
}
if err := c . ensureConnectionReadyForQuery ( ) ; err != nil {
return nil , err
}
if c . shouldLog ( LogLevelError ) {
defer func ( ) {
if err != nil {
c . log ( LogLevelError , "prepareEx failed" , map [ string ] interface { } { "err" : err , "name" : name , "sql" : sql } )
}
} ( )
}
if opts == nil {
opts = & PrepareExOptions { }
}
if len ( opts . ParameterOIDs ) > 65535 {
return nil , errors . Errorf ( "Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d" , len ( opts . ParameterOIDs ) )
}
buf := appendParse ( c . wbuf , name , sql , opts . ParameterOIDs )
buf = appendDescribe ( buf , 'S' , name )
buf = appendSync ( buf )
2019-04-13 07:44:06 +00:00
n , err := c . conn . Write ( buf )
2018-10-03 16:55:26 +00:00
if err != nil {
2019-04-13 07:44:06 +00:00
if fatalWriteErr ( n , err ) {
c . die ( err )
}
2018-10-03 16:55:26 +00:00
return nil , err
}
c . pendingReadyForQueryCount ++
ps = & PreparedStatement { Name : name , SQL : sql }
var softErr error
for {
msg , err := c . rxMsg ( )
if err != nil {
return nil , err
}
switch msg := msg . ( type ) {
case * pgproto3 . ParameterDescription :
ps . ParameterOIDs = c . rxParameterDescription ( msg )
if len ( ps . ParameterOIDs ) > 65535 && softErr == nil {
softErr = errors . Errorf ( "PostgreSQL supports maximum of 65535 parameters, received %d" , len ( ps . ParameterOIDs ) )
}
case * pgproto3 . RowDescription :
ps . FieldDescriptions = c . rxRowDescription ( msg )
for i := range ps . FieldDescriptions {
if dt , ok := c . ConnInfo . DataTypeForOID ( ps . FieldDescriptions [ i ] . DataType ) ; ok {
ps . FieldDescriptions [ i ] . DataTypeName = dt . Name
if _ , ok := dt . Value . ( pgtype . BinaryDecoder ) ; ok {
ps . FieldDescriptions [ i ] . FormatCode = BinaryFormatCode
} else {
ps . FieldDescriptions [ i ] . FormatCode = TextFormatCode
}
} else {
2019-04-13 07:44:06 +00:00
return nil , errors . Errorf ( "unknown oid: %d" , ps . FieldDescriptions [ i ] . DataType )
2018-10-03 16:55:26 +00:00
}
}
case * pgproto3 . ReadyForQuery :
c . rxReadyForQuery ( msg )
if softErr == nil {
c . preparedStatements [ name ] = ps
}
return ps , softErr
default :
if e := c . processContextFreeMsg ( msg ) ; e != nil && softErr == nil {
softErr = e
}
}
}
}
// Deallocate releases the prepared statement registered under name. It runs
// with a background context, so it is never cancelled by the caller.
func (c *Conn) Deallocate(name string) error {
	ctx := context.Background()
	return c.deallocateContext(ctx, name)
}
// deallocateContext removes the prepared statement called name from the
// connection's statement cache, then sends a Close message for it followed by
// a Flush, and reads messages until CloseComplete arrives.
//
// Any write failure kills the connection.
//
// TODO - consider making this public
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
	if err = c.waitForPreviousCancelQuery(ctx); err != nil {
		return err
	}
	if err = c.initContext(ctx); err != nil {
		return err
	}
	// From here on every return is routed through termContext.
	defer func() {
		err = c.termContext(err)
	}()

	if readyErr := c.ensureConnectionReadyForQuery(); readyErr != nil {
		return readyErr
	}

	delete(c.preparedStatements, name)

	// Close message: 'C', int32 length (patched in below), 'S', name, NUL.
	out := append(c.wbuf, 'C')
	lenPos := len(out)
	out = pgio.AppendInt32(out, -1)
	out = append(out, 'S')
	out = append(out, name...)
	out = append(out, 0)
	pgio.SetInt32(out[lenPos:], int32(len(out)-lenPos))

	// Flush message: 'H' with a fixed length of 4.
	out = pgio.AppendInt32(append(out, 'H'), 4)

	if _, err = c.conn.Write(out); err != nil {
		c.die(err)
		return err
	}

	for {
		msg, rxErr := c.rxMsg()
		if rxErr != nil {
			return rxErr
		}

		if _, closed := msg.(*pgproto3.CloseComplete); closed {
			return nil
		}
		if msgErr := c.processContextFreeMsg(msg); msgErr != nil {
			return msgErr
		}
	}
}
// Listen subscribes the connection to channel by executing a LISTEN statement
// with the quoted channel name. The channel is recorded in c.channels only
// after the statement succeeds.
func (c *Conn) Listen(channel string) error {
	stmt := "listen " + quoteIdentifier(channel)
	if _, err := c.Exec(stmt); err != nil {
		return err
	}

	c.channels[channel] = struct{}{}
	return nil
}
// Unlisten unsubscribes the connection from channel by executing an UNLISTEN
// statement with the quoted channel name. The channel is removed from
// c.channels only after the statement succeeds.
func (c *Conn) Unlisten(channel string) error {
	stmt := "unlisten " + quoteIdentifier(channel)
	if _, err := c.Exec(stmt); err != nil {
		return err
	}

	delete(c.channels, channel)
	return nil
}
// WaitForNotification returns the next PostgreSQL notification.
//
// A notification that has already been received and queued is returned
// immediately, without touching the connection. Otherwise the connection is
// locked and messages are read until a notification shows up, an error occurs,
// or ctx ends the wait.
func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) {
	// dequeue pops the oldest queued notification, if any.
	dequeue := func() (*Notification, bool) {
		if len(c.notifications) == 0 {
			return nil, false
		}
		head := c.notifications[0]
		c.notifications = c.notifications[1:]
		return head, true
	}

	if queued, ok := dequeue(); ok {
		return queued, nil
	}

	if err = c.waitForPreviousCancelQuery(ctx); err != nil {
		return nil, err
	}
	if err = c.initContext(ctx); err != nil {
		return nil, err
	}
	defer func() {
		err = c.termContext(err)
	}()

	if err = c.lock(); err != nil {
		return nil, err
	}
	// Runs before the termContext defer above; an unlock failure is reported
	// only when nothing else went wrong.
	defer func() {
		unlockErr := c.unlock()
		if err == nil && unlockErr != nil {
			err = unlockErr
		}
	}()

	if readyErr := c.ensureConnectionReadyForQuery(); readyErr != nil {
		return nil, readyErr
	}

	for {
		msg, rxErr := c.rxMsg()
		if rxErr != nil {
			return nil, rxErr
		}
		if msgErr := c.processContextFreeMsg(msg); msgErr != nil {
			return nil, msgErr
		}
		if received, ok := dequeue(); ok {
			return received, nil
		}
	}
}
// IsAlive reports whether the connection status is idle or busy, i.e. neither
// uninitialized nor closed. The status is read under the connection mutex.
func (c *Conn) IsAlive() bool {
	c.mux.Lock()
	alive := c.status >= connStatusIdle
	c.mux.Unlock()
	return alive
}
// CauseOfDeath returns the error recorded when the connection was killed by
// die. The value is read under the connection mutex.
func (c *Conn) CauseOfDeath() error {
	c.mux.Lock()
	cause := c.causeOfDeath
	c.mux.Unlock()
	return cause
}
// sendQuery sends sql to the server. When sql is the name of a statement in
// c.preparedStatements the prepared path is taken; otherwise sql is treated as
// SQL text and handed to sendSimpleQuery.
func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
	ps, prepared := c.preparedStatements[sql]
	if !prepared {
		return c.sendSimpleQuery(sql, arguments...)
	}
	return c.sendPreparedQuery(ps, arguments...)
}
// sendSimpleQuery sends sql as a Query message when there are no arguments.
// With arguments it prepares an unnamed statement and sends it through
// sendPreparedQuery instead.
//
// A failed write of the Query message kills the connection.
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
	if err := c.ensureConnectionReadyForQuery(); err != nil {
		return err
	}

	if len(args) > 0 {
		ps, err := c.Prepare("", sql)
		if err != nil {
			return err
		}
		return c.sendPreparedQuery(ps, args...)
	}

	if _, err := c.conn.Write(appendQuery(c.wbuf, sql)); err != nil {
		c.die(err)
		return err
	}
	c.pendingReadyForQueryCount++
	return nil
}
func ( c * Conn ) sendPreparedQuery ( ps * PreparedStatement , arguments ... interface { } ) ( err error ) {
if len ( ps . ParameterOIDs ) != len ( arguments ) {
return errors . Errorf ( "Prepared statement \"%v\" requires %d parameters, but %d were provided" , ps . Name , len ( ps . ParameterOIDs ) , len ( arguments ) )
}
if err := c . ensureConnectionReadyForQuery ( ) ; err != nil {
return err
}
resultFormatCodes := make ( [ ] int16 , len ( ps . FieldDescriptions ) )
for i , fd := range ps . FieldDescriptions {
resultFormatCodes [ i ] = fd . FormatCode
}
buf , err := appendBind ( c . wbuf , "" , ps . Name , c . ConnInfo , ps . ParameterOIDs , arguments , resultFormatCodes )
if err != nil {
return err
}
buf = appendExecute ( buf , "" , 0 )
buf = appendSync ( buf )
2019-04-13 07:44:06 +00:00
n , err := c . conn . Write ( buf )
2018-10-03 16:55:26 +00:00
if err != nil {
2019-04-13 07:44:06 +00:00
if fatalWriteErr ( n , err ) {
c . die ( err )
}
2018-10-03 16:55:26 +00:00
return err
}
c . pendingReadyForQueryCount ++
return nil
}
2019-04-13 07:44:06 +00:00
// fatalWriteErr takes the result of a net.Conn.Write that returned err and
// reports whether the failure leaves the connection unusable.
//
// Any write that put at least one byte on the wire is fatal, because a partial
// message breaks the protocol stream. With nothing written, the only
// non-fatal failure is a net.Error whose Timeout method returns true.
func fatalWriteErr(bytesWritten int, err error) bool {
	if bytesWritten > 0 {
		return true
	}
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		return false
	}
	return true
}
2018-10-03 16:55:26 +00:00
// Exec executes sql. sql can be either a prepared statement name or an SQL
// string. arguments should be referenced positionally from the sql string as
// $1, $2, etc. It is ExecEx with a background context and no options.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
	commandTag, err = c.ExecEx(context.Background(), sql, nil, arguments...)
	return commandTag, err
}
// processContextFreeMsg processes messages that are not exclusive to one
// context such as authentication or query response. The response to these
// messages is the same regardless of when they occur. It also ignores messages
// that are only meaningful in a given context. These messages can occur due to
// a context deadline interrupting message processing. For example, an
// interrupted query may have left DataRow messages on the wire.
//
// Only an ErrorResponse produces a non-nil error.
func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
	switch m := msg.(type) {
	case *pgproto3.ReadyForQuery:
		c.rxReadyForQuery(m)
	case *pgproto3.ParameterStatus:
		c.rxParameterStatus(m)
	case *pgproto3.NoticeResponse:
		c.rxNoticeResponse(m)
	case *pgproto3.NotificationResponse:
		c.rxNotificationResponse(m)
	case *pgproto3.ErrorResponse:
		err = c.rxErrorResponse(m)
	}
	return err
}
// rxMsg receives the next backend message. It returns ErrDeadConn without
// reading when the connection is no longer alive. A receive error kills the
// connection unless it is a net.Error timeout. On success the connection's
// last activity time is updated.
func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
	if !c.IsAlive() {
		return nil, ErrDeadConn
	}

	msg, err := c.frontend.Receive()
	if err == nil {
		c.lastActivityTime = time.Now()
		// fmt.Printf("rxMsg: %#v\n", msg)
		return msg, nil
	}

	netErr, isNetErr := err.(net.Error)
	if !isNetErr || !netErr.Timeout() {
		c.die(err)
	}
	return nil, err
}
// rxAuthenticationX answers an authentication request from the server.
// AuthTypeOk needs no reply. Cleartext requests are answered with the
// configured password, and MD5 requests with
// "md5" + md5hex(md5hex(password + user) + salt). Any other type is an error.
func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
	switch msg.Type {
	case pgproto3.AuthTypeOk:
		return nil
	case pgproto3.AuthTypeCleartextPassword:
		return c.txPasswordMessage(c.config.Password)
	case pgproto3.AuthTypeMD5Password:
		inner := hexMD5(c.config.Password + c.config.User)
		salted := hexMD5(inner + string(msg.Salt[:]))
		return c.txPasswordMessage("md5" + salted)
	default:
		return errors.New("Received unknown authentication message")
	}
}
// hexMD5 returns the MD5 digest of s as a lowercase hexadecimal string.
func hexMD5(s string) string {
	digest := md5.New()
	// The returned error is dropped, exactly as before.
	_, _ = io.WriteString(digest, s)

	encoded := make([]byte, hex.EncodedLen(md5.Size))
	hex.Encode(encoded, digest.Sum(nil))
	return string(encoded)
}
// rxParameterStatus records a parameter reported by the server in
// c.RuntimeParams, overwriting any earlier value for the same name.
func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) {
	name, value := msg.Name, msg.Value
	c.RuntimeParams[name] = value
}
// rxErrorResponse converts an ErrorResponse message into a PgError, copying
// every field across. An error with severity "FATAL" also kills the
// connection, with the PgError recorded as the cause.
func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError {
	var pgErr PgError
	pgErr.Severity = msg.Severity
	pgErr.Code = msg.Code
	pgErr.Message = msg.Message
	pgErr.Detail = msg.Detail
	pgErr.Hint = msg.Hint
	pgErr.Position = msg.Position
	pgErr.InternalPosition = msg.InternalPosition
	pgErr.InternalQuery = msg.InternalQuery
	pgErr.Where = msg.Where
	pgErr.SchemaName = msg.SchemaName
	pgErr.TableName = msg.TableName
	pgErr.ColumnName = msg.ColumnName
	pgErr.DataTypeName = msg.DataTypeName
	pgErr.ConstraintName = msg.ConstraintName
	pgErr.File = msg.File
	pgErr.Line = msg.Line
	pgErr.Routine = msg.Routine

	if pgErr.Severity == "FATAL" {
		c.die(pgErr)
	}
	return pgErr
}
// rxNoticeResponse converts a NoticeResponse message into a Notice and passes
// it to the connection's notice handler. Without a handler the message is
// dropped.
func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) {
	if c.onNotice == nil {
		return
	}

	var notice Notice
	notice.Severity = msg.Severity
	notice.Code = msg.Code
	notice.Message = msg.Message
	notice.Detail = msg.Detail
	notice.Hint = msg.Hint
	notice.Position = msg.Position
	notice.InternalPosition = msg.InternalPosition
	notice.InternalQuery = msg.InternalQuery
	notice.Where = msg.Where
	notice.SchemaName = msg.SchemaName
	notice.TableName = msg.TableName
	notice.ColumnName = msg.ColumnName
	notice.DataTypeName = msg.DataTypeName
	notice.ConstraintName = msg.ConstraintName
	notice.File = msg.File
	notice.Line = msg.Line
	notice.Routine = msg.Routine

	c.onNotice(c, &notice)
}
// rxBackendKeyData stores the backend process ID and secret key from msg on
// the connection. doCancel later sends both values in its cancel request.
func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
	c.pid, c.secretKey = msg.ProcessID, msg.SecretKey
}
// rxReadyForQuery records the transaction status carried by msg and marks one
// outstanding ReadyForQuery as received.
func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
	c.txStatus = msg.TxStatus
	c.pendingReadyForQueryCount -= 1
}
// rxRowDescription converts the fields of a RowDescription message into a
// slice of FieldDescription of the same length. The format code is copied
// as reported in the message.
func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription {
	fields := make([]FieldDescription, len(msg.Fields))
	for i := range fields {
		src := msg.Fields[i]
		dst := &fields[i]

		dst.Name = src.Name
		dst.Table = pgtype.OID(src.TableOID)
		dst.AttributeNumber = src.TableAttributeNumber
		dst.DataType = pgtype.OID(src.DataTypeOID)
		dst.DataTypeSize = src.DataTypeSize
		dst.Modifier = src.TypeModifier
		dst.FormatCode = src.Format
	}
	return fields
}
// rxParameterDescription converts the parameter OIDs of a
// ParameterDescription message into a slice of pgtype.OID of the same length.
func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.OID {
	parameters := make([]pgtype.OID, len(msg.ParameterOIDs))
	for i, oid := range msg.ParameterOIDs {
		parameters[i] = pgtype.OID(oid)
	}
	return parameters
}
// rxNotificationResponse turns a NotificationResponse message into a
// Notification and appends it to the connection's queue of received
// notifications.
func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) {
	n := &Notification{
		PID:     msg.PID,
		Channel: msg.Channel,
		Payload: msg.Payload,
	}
	c.notifications = append(c.notifications, n)
}
// startTLS asks the server to switch the connection to TLS. It writes an
// 8-byte request (length 8, code 80877103) and reads a single response byte.
// Any byte other than 'S' yields ErrTLSRefused. On 'S' the connection is
// wrapped in a TLS client using tlsConfig.
func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) {
	request := []int32{8, 80877103}
	if err = binary.Write(c.conn, binary.BigEndian, request); err != nil {
		return err
	}

	var response [1]byte
	if _, err = io.ReadFull(c.conn, response[:]); err != nil {
		return err
	}
	if response[0] != 'S' {
		return ErrTLSRefused
	}

	c.conn = tls.Client(c.conn, tlsConfig)
	return nil
}
// txPasswordMessage writes a password message to the server: the byte 'p', an
// int32 length covering the length field and payload, then the NUL-terminated
// password.
func (c *Conn) txPasswordMessage(password string) (err error) {
	out := append(c.wbuf, 'p')
	lenPos := len(out)
	out = pgio.AppendInt32(out, -1) // placeholder, patched once the size is known
	out = append(out, password...)
	out = append(out, 0)
	pgio.SetInt32(out[lenPos:], int32(len(out)-lenPos))

	_, err = c.conn.Write(out)
	return err
}
// die marks the connection as closed, records err as the cause of death and
// closes the underlying network connection. Calling it on a connection that is
// already closed does nothing, so the first recorded cause is kept.
func (c *Conn) die(err error) {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status != connStatusClosed {
		c.status = connStatusClosed
		c.causeOfDeath = err
		c.conn.Close()
	}
}
// lock moves the connection from idle to busy. It returns ErrConnBusy, and
// leaves the status untouched, when the connection is in any state other than
// idle.
func (c *Conn) lock() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status == connStatusIdle {
		c.status = connStatusBusy
		return nil
	}
	return ErrConnBusy
}
// unlock moves the connection from busy back to idle. It returns an error, and
// leaves the status untouched, when the connection is not busy.
func (c *Conn) unlock() error {
	c.mux.Lock()
	defer c.mux.Unlock()

	if c.status == connStatusBusy {
		c.status = connStatusIdle
		return nil
	}
	return errors.New("unlock conn that is not busy")
}
2019-04-13 07:44:06 +00:00
func ( c * Conn ) shouldLog ( lvl int ) bool {
2018-10-03 16:55:26 +00:00
return c . logger != nil && c . logLevel >= lvl
}
// log forwards msg and data to the connection's logger at level lvl. A nil
// data map is replaced by an empty one, and the backend PID is added under the
// "pid" key once it is known (non-zero). The caller's map is modified in
// place. The logger must not be nil.
func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
	fields := data
	if fields == nil {
		fields = make(map[string]interface{})
	}
	if pid := c.pid; pid != 0 {
		fields["pid"] = pid
	}
	c.logger.Log(lvl, msg, fields)
}
// SetLogger installs logger as the connection's logger and returns the logger
// that was in place before the call.
func (c *Conn) SetLogger(logger Logger) Logger {
	previous := c.logger
	c.logger = logger
	return previous
}
// SetLogLevel replaces the current log level and returns the previous log
// level.
2019-04-13 07:44:06 +00:00
func ( c * Conn ) SetLogLevel ( lvl int ) ( int , error ) {
2018-10-03 16:55:26 +00:00
oldLvl := c . logLevel
if lvl < LogLevelNone || lvl > LogLevelTrace {
return oldLvl , ErrInvalidLogLevel
}
c . logLevel = lvl
return lvl , nil
}
// quoteIdentifier returns s as a double-quoted SQL identifier. Every double
// quote inside s is doubled so it cannot terminate the identifier early. No
// other characters are added: the result is exactly `"` + escaped s + `"`.
func quoteIdentifier(s string) string {
	return `"` + strings.Replace(s, `"`, `""`, -1) + `"`
}
2019-04-12 15:51:37 +00:00
// doCancel is the default cancel implementation. It dials a fresh connection
// to the configured server address and sends a 16-byte cancel request (length
// 16, code 80877102, backend PID, secret key). It then expects the server to
// close the connection: the read that follows must return io.EOF, anything
// else is reported as an error.
//
// The whole exchange is bounded by a 15 second deadline on the new connection.
func doCancel(c *Conn) error {
	network, address := c.config.networkAddress()
	cancelConn, err := c.config.Dial(network, address)
	if err != nil {
		return err
	}
	defer cancelConn.Close()

	// If server doesn't process cancellation request in bounded time then abort.
	deadline := time.Now().Add(15 * time.Second)
	if err = cancelConn.SetDeadline(deadline); err != nil {
		return err
	}

	buf := make([]byte, 16)
	for i, word := range []uint32{16, 80877102, uint32(c.pid), uint32(c.secretKey)} {
		binary.BigEndian.PutUint32(buf[4*i:4*i+4], word)
	}

	if _, err = cancelConn.Write(buf); err != nil {
		return err
	}

	if _, err = cancelConn.Read(buf); err != io.EOF {
		return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
	}
	return nil
}
2018-10-03 16:55:26 +00:00
// cancelQuery sends a cancel request to the PostgreSQL server. It returns an
// error if unable to deliver the cancel request, but lack of an error does not
// ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
func ( c * Conn ) cancelQuery ( ) {
if err := c . conn . SetDeadline ( time . Now ( ) ) ; err != nil {
c . Close ( ) // Close connection if unable to set deadline
return
}
2019-04-12 15:51:37 +00:00
var cancelFn func ( * Conn ) error
completeCh := make ( chan struct { } )
c . mux . Lock ( )
c . cancelQueryCompleted = completeCh
c . mux . Unlock ( )
if c . config . CustomCancel != nil {
cancelFn = c . config . CustomCancel
} else {
cancelFn = doCancel
2018-10-03 16:55:26 +00:00
}
go func ( ) {
2019-04-12 15:51:37 +00:00
defer close ( completeCh )
err := cancelFn ( c )
2018-10-03 16:55:26 +00:00
if err != nil {
c . Close ( ) // Something is very wrong. Terminate the connection.
}
} ( )
}
// Ping checks the connection by executing the statement ";" through ExecEx
// and returns whatever error that produces.
func (c *Conn) Ping(ctx context.Context) error {
	if _, err := c.ExecEx(ctx, ";", nil); err != nil {
		return err
	}
	return nil
}
func ( c * Conn ) ExecEx ( ctx context . Context , sql string , options * QueryExOptions , arguments ... interface { } ) ( CommandTag , error ) {
2019-04-12 15:51:37 +00:00
c . lastStmtSent = false
2018-10-03 16:55:26 +00:00
err := c . waitForPreviousCancelQuery ( ctx )
if err != nil {
return "" , err
}
if err := c . lock ( ) ; err != nil {
return "" , err
}
defer c . unlock ( )
startTime := time . Now ( )
c . lastActivityTime = startTime
commandTag , err := c . execEx ( ctx , sql , options , arguments ... )
if err != nil {
if c . shouldLog ( LogLevelError ) {
c . log ( LogLevelError , "Exec" , map [ string ] interface { } { "sql" : sql , "args" : logQueryArgs ( arguments ) , "err" : err } )
}
return commandTag , err
}
if c . shouldLog ( LogLevelInfo ) {
endTime := time . Now ( )
c . log ( LogLevelInfo , "Exec" , map [ string ] interface { } { "sql" : sql , "args" : logQueryArgs ( arguments ) , "time" : endTime . Sub ( startTime ) , "commandTag" : commandTag } )
}
return commandTag , err
}
func ( c * Conn ) execEx ( ctx context . Context , sql string , options * QueryExOptions , arguments ... interface { } ) ( commandTag CommandTag , err error ) {
err = c . initContext ( ctx )
if err != nil {
return "" , err
}
defer func ( ) {
err = c . termContext ( err )
} ( )
if ( options == nil && c . config . PreferSimpleProtocol ) || ( options != nil && options . SimpleProtocol ) {
2019-04-12 15:51:37 +00:00
c . lastStmtSent = true
2018-10-03 16:55:26 +00:00
err = c . sanitizeAndSendSimpleQuery ( sql , arguments ... )
if err != nil {
return "" , err
}
} else if options != nil && len ( options . ParameterOIDs ) > 0 {
if err := c . ensureConnectionReadyForQuery ( ) ; err != nil {
return "" , err
}
buf , err := c . buildOneRoundTripExec ( c . wbuf , sql , options , arguments )
if err != nil {
return "" , err
}
buf = appendSync ( buf )
2019-04-12 15:51:37 +00:00
c . lastStmtSent = true
2019-04-13 07:44:06 +00:00
n , err := c . conn . Write ( buf )
if err != nil && fatalWriteErr ( n , err ) {
2018-10-03 16:55:26 +00:00
c . die ( err )
return "" , err
}
c . pendingReadyForQueryCount ++
} else {
if len ( arguments ) > 0 {
ps , ok := c . preparedStatements [ sql ]
if ! ok {
var err error
ps , err = c . prepareEx ( "" , sql , nil )
if err != nil {
return "" , err
}
}
2019-04-12 15:51:37 +00:00
c . lastStmtSent = true
2018-10-03 16:55:26 +00:00
err = c . sendPreparedQuery ( ps , arguments ... )
if err != nil {
return "" , err
}
} else {
2019-04-12 15:51:37 +00:00
c . lastStmtSent = true
2018-10-03 16:55:26 +00:00
if err = c . sendQuery ( sql , arguments ... ) ; err != nil {
return
}
}
}
var softErr error
for {
msg , err := c . rxMsg ( )
if err != nil {
return commandTag , err
}
switch msg := msg . ( type ) {
case * pgproto3 . ReadyForQuery :
c . rxReadyForQuery ( msg )
return commandTag , softErr
case * pgproto3 . CommandComplete :
commandTag = CommandTag ( msg . CommandTag )
default :
if e := c . processContextFreeMsg ( msg ) ; e != nil && softErr == nil {
softErr = e
}
}
}
}
// buildOneRoundTripExec appends Parse, Bind and Execute messages for an
// unnamed statement to buf and returns the extended buffer. Sync is not
// appended. The number of arguments must equal the number of
// options.ParameterOIDs, which in turn may not exceed 65535. options must not
// be nil.
func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) {
	oidCount := len(options.ParameterOIDs)

	if argCount := len(arguments); argCount != oidCount {
		return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", argCount, oidCount)
	}
	if oidCount > 65535 {
		return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", oidCount)
	}

	parsed := appendParse(buf, "", sql, options.ParameterOIDs)
	bound, err := appendBind(parsed, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil)
	if err != nil {
		return nil, err
	}
	return appendExecute(bound, "", 0), nil
}
// initContext starts watching ctx for the duration of one operation.
//
// It fails if another context is already being watched. A context whose Done
// channel is nil needs no watcher and is accepted as is. A context that is
// already done is rejected with ctx.Err(). Otherwise contextHandler is started
// in a goroutine and c.ctxInProgress is set; termContext must be called to
// stop it.
func (c *Conn) initContext(ctx context.Context) error {
	if c.ctxInProgress {
		return errors.New("ctx already in progress")
	}

	done := ctx.Done()
	if done == nil {
		return nil
	}

	// Non-blocking check: do not start anything for an already finished ctx.
	select {
	case <-done:
		return ctx.Err()
	default:
	}

	c.ctxInProgress = true
	go c.contextHandler(ctx)
	return nil
}
// termContext stops the context watcher started by initContext and decides
// which error the operation reports.
//
// With no watcher running opErr is returned untouched. Otherwise it either
// hands a done signal to the watcher, in which case opErr is returned, or
// receives the context error the watcher published. The context error replaces
// opErr only when the operation itself failed; a successful operation stays
// successful.
func (c *Conn) termContext(opErr error) error {
	if !c.ctxInProgress {
		return opErr
	}

	result := opErr
	select {
	case ctxErr := <-c.closedChan:
		if opErr != nil {
			result = ctxErr
		}
	case c.doneChan <- struct{}{}:
	}

	c.ctxInProgress = false
	return result
}
// contextHandler waits for whichever comes first: ctx finishing or a done
// signal on c.doneChan. If ctx finishes first it cancels the running query and
// publishes ctx.Err() on c.closedChan. A done signal makes it return without
// doing anything.
func (c *Conn) contextHandler(ctx context.Context) {
	select {
	case <-c.doneChan:
		return
	case <-ctx.Done():
	}

	c.cancelQuery()
	c.closedChan <- ctx.Err()
}
2019-04-12 15:51:37 +00:00
// WaitUntilReady will return when the connection is ready for another query
func ( c * Conn ) WaitUntilReady ( ctx context . Context ) error {
err := c . waitForPreviousCancelQuery ( ctx )
if err != nil {
return err
2018-10-03 16:55:26 +00:00
}
2019-04-12 15:51:37 +00:00
return c . ensureConnectionReadyForQuery ( )
}
2018-10-03 16:55:26 +00:00
2019-04-12 15:51:37 +00:00
func ( c * Conn ) waitForPreviousCancelQuery ( ctx context . Context ) error {
c . mux . Lock ( )
completeCh := c . cancelQueryCompleted
c . mux . Unlock ( )
2018-10-03 16:55:26 +00:00
select {
2019-04-12 15:51:37 +00:00
case <- completeCh :
2018-10-03 16:55:26 +00:00
if err := c . conn . SetDeadline ( time . Time { } ) ; err != nil {
c . Close ( ) // Close connection if unable to disable deadline
return err
}
return nil
case <- ctx . Done ( ) :
return ctx . Err ( )
}
}
// ensureConnectionReadyForQuery reads and processes messages until every
// outstanding ReadyForQuery has been received. ErrorResponse messages left
// over from earlier work are swallowed unless their severity is "FATAL".
// Receive errors are returned as they occur.
func (c *Conn) ensureConnectionReadyForQuery() error {
	for c.pendingReadyForQueryCount > 0 {
		msg, err := c.rxMsg()
		if err != nil {
			return err
		}

		if errMsg, isErr := msg.(*pgproto3.ErrorResponse); isErr {
			if pgErr := c.rxErrorResponse(errMsg); pgErr.Severity == "FATAL" {
				return pgErr
			}
			continue
		}

		if msgErr := c.processContextFreeMsg(msg); msgErr != nil {
			return msgErr
		}
	}
	return nil
}
// connInfoFromRows builds a name-to-OID map from a query result. It takes the
// (rows, err) pair of a query call directly, so a failed query is passed
// straight through. Each row is scanned as an OID followed by a text name.
// rows is always closed before returning.
func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) {
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	result := make(map[string]pgtype.OID, 256)
	for rows.Next() {
		var (
			typeOID  pgtype.OID
			typeName pgtype.Text
		)
		if scanErr := rows.Scan(&typeOID, &typeName); scanErr != nil {
			return nil, scanErr
		}
		result[typeName.String] = typeOID
	}

	if rowsErr := rows.Err(); rowsErr != nil {
		return nil, rowsErr
	}
	return result, nil
}
2019-04-12 15:51:37 +00:00
// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to
// send the statement over the wire. Each such call first resets the value to
// false and sets it to true right before the statement is written. A true
// result does NOT mean that the statement was successful or even received; it
// only means that a write was attempted and the statement could therefore have
// been executed. Calls that merely prepare a statement are ignored: the value
// becomes true only when execution of the prepared statement is attempted.
func (c *Conn) LastStmtSent() bool {
	sent := c.lastStmtSent
	return sent
}