340 lines
7.1 KiB
Go
340 lines
7.1 KiB
Go
package pgx
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/jackc/pgx/pgio"
|
|
"github.com/jackc/pgx/pgproto3"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// CopyFromRows returns a CopyFromSource interface over the provided rows slice
|
|
// making it usable by *Conn.CopyFrom.
|
|
func CopyFromRows(rows [][]interface{}) CopyFromSource {
|
|
return ©FromRows{rows: rows, idx: -1}
|
|
}
|
|
|
|
// copyFromRows is the CopyFromSource returned by CopyFromRows. It walks an
// in-memory slice of rows; idx is the position of the current row and must
// start at -1 so the first Next lands on rows[0].
type copyFromRows struct {
	rows [][]interface{}
	idx  int
}

// Next moves to the following row and reports whether it exists.
func (cr *copyFromRows) Next() bool {
	cr.idx++
	return len(cr.rows) > cr.idx
}

// Values returns the row selected by the most recent Next. The slice is the
// caller's own row, not a copy. It never returns an error.
func (cr *copyFromRows) Values() ([]interface{}, error) {
	current := cr.rows[cr.idx]
	return current, nil
}

// Err always returns nil: iterating an in-memory slice cannot fail.
func (cr *copyFromRows) Err() error {
	return nil
}
|
|
|
|
// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
//
// CopyFrom repeatedly calls Next and, each time it reports true, calls Values
// to fetch that row. When Next reports false, CopyFrom consults Err; a non-nil
// result (or an error from Values) aborts the copy.
type CopyFromSource interface {
	// Next advances to the next row and reports whether one is now available
	// through Values. It reports false once the rows are exhausted or an
	// error has occurred.
	Next() bool

	// Values yields the values of the current row, one per column named in
	// the CopyFrom call.
	Values() ([]interface{}, error)

	// Err reports any error the source has run into. When it is non-nil,
	// *Conn.CopyFrom aborts the copy.
	Err() error
}
|
|
|
|
type copyFrom struct {
|
|
conn *Conn
|
|
tableName Identifier
|
|
columnNames []string
|
|
rowSrc CopyFromSource
|
|
readerErrChan chan error
|
|
}
|
|
|
|
func (ct *copyFrom) readUntilReadyForQuery() {
|
|
for {
|
|
msg, err := ct.conn.rxMsg()
|
|
if err != nil {
|
|
ct.readerErrChan <- err
|
|
close(ct.readerErrChan)
|
|
return
|
|
}
|
|
|
|
switch msg := msg.(type) {
|
|
case *pgproto3.ReadyForQuery:
|
|
ct.conn.rxReadyForQuery(msg)
|
|
close(ct.readerErrChan)
|
|
return
|
|
case *pgproto3.CommandComplete:
|
|
case *pgproto3.ErrorResponse:
|
|
ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
|
|
default:
|
|
err = ct.conn.processContextFreeMsg(msg)
|
|
if err != nil {
|
|
ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ct *copyFrom) waitForReaderDone() error {
|
|
var err error
|
|
for err = range ct.readerErrChan {
|
|
}
|
|
return err
|
|
}
|
|
|
|
// run performs the COPY FROM and returns the number of rows sent.
//
// Steps:
//   - Prepare `select <cols> from <table>` so its FieldDescriptions supply the
//     column types that buildCopyBuf uses to encode values.
//   - Start `copy <table> ( <cols> ) from stdin binary`.
//   - Stream rows from ct.rowSrc in CopyData messages.
//   - Finish with the binary trailer and CopyDone.
//
// While data is being written, a background goroutine (readUntilReadyForQuery)
// reads server responses, so a server-side error can stop the copy early.
// Every return path waits for that goroutine to finish, via the deferred
// waitForReaderDone, before the connection is handed back.
//
// If a panic escapes (e.g. from the row source), the connection is killed
// because its protocol state is unknown. Every error return reports 0 rows.
func (ct *copyFrom) run() (int, error) {
	quotedTableName := ct.tableName.Sanitize()
	cbuf := &bytes.Buffer{}
	for i, cn := range ct.columnNames {
		if i != 0 {
			cbuf.WriteString(", ")
		}
		cbuf.WriteString(quoteIdentifier(cn))
	}
	quotedColumnNames := cbuf.String()

	ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
	if err != nil {
		return 0, err
	}

	err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
	if err != nil {
		return 0, err
	}

	err = ct.conn.readUntilCopyInResponse()
	if err != nil {
		return 0, err
	}

	panicked := true

	go ct.readUntilReadyForQuery()
	// Deferred calls run LIFO: the panic check fires first, then we wait for
	// the reader goroutine so it never outlives this call.
	defer ct.waitForReaderDone()
	defer func() {
		if panicked {
			ct.conn.die(errors.New("panic while in copy from"))
		}
	}()

	// Reuse the connection's write buffer, starting from length 0 so every
	// offset below is independent of whatever length wbuf currently has.
	buf := ct.conn.wbuf[:0]
	buf = append(buf, copyData)
	sp := len(buf)
	buf = pgio.AppendInt32(buf, -1) // length placeholder, patched before each write
	msgHeaderLen := len(buf)

	// Binary COPY file header: signature, flags field, header extension length.
	buf = append(buf, "PGCOPY\n\377\r\n\000"...)
	buf = pgio.AppendInt32(buf, 0)
	buf = pgio.AppendInt32(buf, 0)

	var sentCount int
	for moreRows := true; moreRows; {
		// Non-blocking check: stop as soon as the reader reports an error.
		select {
		case readerErr := <-ct.readerErrChan:
			panicked = false
			return 0, readerErr
		default:
		}

		var addedRows int
		moreRows, buf, addedRows, err = ct.buildCopyBuf(buf, ps)
		if err != nil {
			panicked = false
			// Best effort: cancelCopyIn already kills the connection if its
			// write fails, and the row error is the one worth reporting.
			_ = ct.cancelCopyIn()
			return 0, err
		}
		sentCount += addedRows
		pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))

		if _, err = ct.conn.conn.Write(buf); err != nil {
			panicked = false
			ct.conn.die(err)
			return 0, err
		}

		// Keep only the CopyData header so the backing array is reused for
		// the next message.
		buf = buf[:msgHeaderLen]
	}

	if srcErr := ct.rowSrc.Err(); srcErr != nil {
		panicked = false
		_ = ct.cancelCopyIn()
		return 0, srcErr
	}

	buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
	pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))

	buf = append(buf, copyDone)
	buf = pgio.AppendInt32(buf, 4)

	if _, err = ct.conn.conn.Write(buf); err != nil {
		panicked = false
		ct.conn.die(err)
		return 0, err
	}

	if err = ct.waitForReaderDone(); err != nil {
		panicked = false
		return 0, err
	}

	panicked = false
	return sentCount, nil
}
|
|
|
|
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, int, error) {
|
|
var rowCount int
|
|
|
|
for ct.rowSrc.Next() {
|
|
values, err := ct.rowSrc.Values()
|
|
if err != nil {
|
|
return false, nil, 0, err
|
|
}
|
|
if len(values) != len(ct.columnNames) {
|
|
return false, nil, 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
|
|
}
|
|
|
|
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
|
for i, val := range values {
|
|
buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
|
|
if err != nil {
|
|
return false, nil, 0, err
|
|
}
|
|
}
|
|
|
|
rowCount++
|
|
|
|
if len(buf) > 65536 {
|
|
return true, buf, rowCount, nil
|
|
}
|
|
}
|
|
|
|
return false, buf, rowCount, nil
|
|
}
|
|
|
|
func (c *Conn) readUntilCopyInResponse() error {
|
|
for {
|
|
msg, err := c.rxMsg()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch msg := msg.(type) {
|
|
case *pgproto3.CopyInResponse:
|
|
return nil
|
|
default:
|
|
err = c.processContextFreeMsg(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ct *copyFrom) cancelCopyIn() error {
|
|
buf := ct.conn.wbuf
|
|
buf = append(buf, copyFail)
|
|
sp := len(buf)
|
|
buf = pgio.AppendInt32(buf, -1)
|
|
buf = append(buf, "client error: abort"...)
|
|
buf = append(buf, 0)
|
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
|
|
|
|
_, err := ct.conn.conn.Write(buf)
|
|
if err != nil {
|
|
ct.conn.die(err)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
|
|
// It returns the number of rows copied and an error.
|
|
//
|
|
// CopyFrom requires all values use the binary format. Almost all types
|
|
// implemented by pgx use the binary format by default. Types implementing
|
|
// Encoder can only be used if they encode to the binary format.
|
|
func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
|
ct := ©From{
|
|
conn: c,
|
|
tableName: tableName,
|
|
columnNames: columnNames,
|
|
rowSrc: rowSrc,
|
|
readerErrChan: make(chan error),
|
|
}
|
|
|
|
return ct.run()
|
|
}
|
|
|
|
// CopyFromReader executes sql, which is expected to be a COPY ... FROM STDIN
// statement. It then streams the bytes read from r to the server as CopyData
// messages. The bytes are sent verbatim, so they must already be in the
// format that sql specifies (for example text or CSV).
//
// Error handling:
//   - If r fails with an error other than io.EOF, the COPY is aborted with
//     CopyFail and that read error is returned.
//   - Otherwise the first error reported by the server, if any, is returned.
//   - Unless a network error occurs, the server's responses are consumed
//     through ReadyForQuery before returning.
func (c *Conn) CopyFromReader(r io.Reader, sql string) error {
	if err := c.sendSimpleQuery(sql); err != nil {
		return err
	}

	if err := c.readUntilCopyInResponse(); err != nil {
		return err
	}

	// The read window must never be empty: an io.Reader may legitimately
	// return 0, nil for an empty buffer, which would loop forever.
	const minChunk = 1024

	// Each CopyData message is built in place: type byte, int32 length, payload.
	buf := c.wbuf[:0]
	buf = append(buf, copyData)
	sp := len(buf)
	buf = pgio.AppendInt32(buf, -1) // length placeholder
	dataStart := len(buf)
	if cap(buf)-dataStart < minChunk {
		grown := make([]byte, dataStart, dataStart+minChunk)
		copy(grown, buf)
		buf = grown
	}

	var readErr error
	for {
		n, err := r.Read(buf[dataStart:cap(buf)])
		// Per the io.Reader contract, use any bytes returned before looking at err.
		if n > 0 {
			buf = buf[:dataStart+n]
			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
			if _, werr := c.conn.Write(buf); werr != nil {
				c.die(werr)
				return werr
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			readErr = err
			break
		}
	}

	if readErr != nil {
		// Ask the server to abandon the COPY; it will answer with an
		// ErrorResponse followed by ReadyForQuery, drained below.
		if err := (&copyFrom{conn: c}).cancelCopyIn(); err != nil {
			// cancelCopyIn has already passed the connection to die; report
			// the reader's error, which is the root cause.
			return readErr
		}
	} else {
		buf = buf[:0]
		buf = append(buf, copyDone)
		buf = pgio.AppendInt32(buf, 4)
		if _, err := c.conn.Write(buf); err != nil {
			c.die(err)
			return err
		}
	}

	// Read through ReadyForQuery, remembering the first server-side error.
	// Notices and other asynchronous messages must not end the wait early.
	var serverErr error
	for {
		msg, err := c.rxMsg()
		if err != nil {
			return err
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			c.rxReadyForQuery(msg)
			if readErr != nil {
				return readErr
			}
			return serverErr
		case *pgproto3.CommandComplete:
		case *pgproto3.ErrorResponse:
			if err := c.rxErrorResponse(msg); serverErr == nil {
				serverErr = err
			}
		default:
			if err := c.processContextFreeMsg(msg); err != nil && serverErr == nil {
				serverErr = err
			}
		}
	}
}
|