4e7ce6f42b
* Update deps * Change azure dep to match plugin
276 lines
6.7 KiB
Go
276 lines
6.7 KiB
Go
package pgx
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"math"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/pgio"
|
|
"github.com/jackc/pgx/pgtype"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// PostgreSQL format codes
|
|
const (
|
|
TextFormatCode = 0
|
|
BinaryFormatCode = 1
|
|
)
|
|
|
|
// SerializationError occurs on failure to encode or decode a value
|
|
type SerializationError string
|
|
|
|
func (e SerializationError) Error() string {
|
|
return string(e)
|
|
}
|
|
|
|
func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) {
|
|
if arg == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
switch arg := arg.(type) {
|
|
|
|
// https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface
|
|
// []byte to database/sql instead of string. But that caused problems with the
|
|
// simple protocol because the driver.Valuer case got taken before the
|
|
// pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual
|
|
// case because of https://github.com/jackc/pgx/issues/339. So instead we
|
|
// special case JSON and JSONB.
|
|
case *pgtype.JSON:
|
|
buf, err := arg.EncodeText(ci, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if buf == nil {
|
|
return nil, nil
|
|
}
|
|
return string(buf), nil
|
|
case *pgtype.JSONB:
|
|
buf, err := arg.EncodeText(ci, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if buf == nil {
|
|
return nil, nil
|
|
}
|
|
return string(buf), nil
|
|
|
|
case driver.Valuer:
|
|
return callValuerValue(arg)
|
|
case pgtype.TextEncoder:
|
|
buf, err := arg.EncodeText(ci, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if buf == nil {
|
|
return nil, nil
|
|
}
|
|
return string(buf), nil
|
|
case int64:
|
|
return arg, nil
|
|
case float64:
|
|
return arg, nil
|
|
case bool:
|
|
return arg, nil
|
|
case time.Time:
|
|
return arg, nil
|
|
case string:
|
|
return arg, nil
|
|
case []byte:
|
|
return arg, nil
|
|
case int8:
|
|
return int64(arg), nil
|
|
case int16:
|
|
return int64(arg), nil
|
|
case int32:
|
|
return int64(arg), nil
|
|
case int:
|
|
return int64(arg), nil
|
|
case uint8:
|
|
return int64(arg), nil
|
|
case uint16:
|
|
return int64(arg), nil
|
|
case uint32:
|
|
return int64(arg), nil
|
|
case uint64:
|
|
if arg > math.MaxInt64 {
|
|
return nil, errors.Errorf("arg too big for int64: %v", arg)
|
|
}
|
|
return int64(arg), nil
|
|
case uint:
|
|
if uint64(arg) > math.MaxInt64 {
|
|
return nil, errors.Errorf("arg too big for int64: %v", arg)
|
|
}
|
|
return int64(arg), nil
|
|
case float32:
|
|
return float64(arg), nil
|
|
}
|
|
|
|
refVal := reflect.ValueOf(arg)
|
|
|
|
if refVal.Kind() == reflect.Ptr {
|
|
if refVal.IsNil() {
|
|
return nil, nil
|
|
}
|
|
arg = refVal.Elem().Interface()
|
|
return convertSimpleArgument(ci, arg)
|
|
}
|
|
|
|
if strippedArg, ok := stripNamedType(&refVal); ok {
|
|
return convertSimpleArgument(ci, strippedArg)
|
|
}
|
|
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
|
|
}
|
|
|
|
func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.OID, arg interface{}) ([]byte, error) {
|
|
if arg == nil {
|
|
return pgio.AppendInt32(buf, -1), nil
|
|
}
|
|
|
|
switch arg := arg.(type) {
|
|
case pgtype.BinaryEncoder:
|
|
sp := len(buf)
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
argBuf, err := arg.EncodeBinary(ci, buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if argBuf != nil {
|
|
buf = argBuf
|
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
|
}
|
|
return buf, nil
|
|
case pgtype.TextEncoder:
|
|
sp := len(buf)
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
argBuf, err := arg.EncodeText(ci, buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if argBuf != nil {
|
|
buf = argBuf
|
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
|
}
|
|
return buf, nil
|
|
case string:
|
|
buf = pgio.AppendInt32(buf, int32(len(arg)))
|
|
buf = append(buf, arg...)
|
|
return buf, nil
|
|
}
|
|
|
|
refVal := reflect.ValueOf(arg)
|
|
|
|
if refVal.Kind() == reflect.Ptr {
|
|
if refVal.IsNil() {
|
|
return pgio.AppendInt32(buf, -1), nil
|
|
}
|
|
arg = refVal.Elem().Interface()
|
|
return encodePreparedStatementArgument(ci, buf, oid, arg)
|
|
}
|
|
|
|
if dt, ok := ci.DataTypeForOID(oid); ok {
|
|
value := dt.Value
|
|
err := value.Set(arg)
|
|
if err != nil {
|
|
{
|
|
if arg, ok := arg.(driver.Valuer); ok {
|
|
v, err := callValuerValue(arg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return encodePreparedStatementArgument(ci, buf, oid, v)
|
|
}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
sp := len(buf)
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
var argBuf []byte
|
|
switch valueEncoder := value.(type) {
|
|
case pgtype.BinaryEncoder:
|
|
argBuf, err = valueEncoder.EncodeBinary(ci, buf)
|
|
case pgtype.TextEncoder:
|
|
argBuf, err = valueEncoder.EncodeText(ci, buf)
|
|
default:
|
|
return nil, fmt.Errorf("invalid encode type %v", valueEncoder)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if argBuf != nil {
|
|
buf = argBuf
|
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
|
}
|
|
return buf, nil
|
|
}
|
|
|
|
if strippedArg, ok := stripNamedType(&refVal); ok {
|
|
return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
|
|
}
|
|
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
|
}
|
|
|
|
// chooseParameterFormatCode determines the correct format code for an
|
|
// argument to a prepared statement. It defaults to TextFormatCode if no
|
|
// determination can be made.
|
|
func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) int16 {
|
|
switch arg.(type) {
|
|
case pgtype.BinaryEncoder:
|
|
return BinaryFormatCode
|
|
case string, *string, pgtype.TextEncoder:
|
|
return TextFormatCode
|
|
}
|
|
|
|
if dt, ok := ci.DataTypeForOID(oid); ok {
|
|
if _, ok := dt.Value.(pgtype.BinaryEncoder); ok {
|
|
return BinaryFormatCode
|
|
}
|
|
}
|
|
|
|
return TextFormatCode
|
|
}
|
|
|
|
func stripNamedType(val *reflect.Value) (interface{}, bool) {
|
|
switch val.Kind() {
|
|
case reflect.Int:
|
|
convVal := int(val.Int())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Int8:
|
|
convVal := int8(val.Int())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Int16:
|
|
convVal := int16(val.Int())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Int32:
|
|
convVal := int32(val.Int())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Int64:
|
|
convVal := int64(val.Int())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Uint:
|
|
convVal := uint(val.Uint())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Uint8:
|
|
convVal := uint8(val.Uint())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Uint16:
|
|
convVal := uint16(val.Uint())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Uint32:
|
|
convVal := uint32(val.Uint())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.Uint64:
|
|
convVal := uint64(val.Uint())
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
case reflect.String:
|
|
convVal := val.String()
|
|
return convVal, reflect.TypeOf(convVal) != val.Type()
|
|
}
|
|
|
|
return nil, false
|
|
}
|