open-consul/vendor/github.com/gocql/gocql/conn.go

1201 lines
28 KiB
Go

// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gocql/gocql/internal/lru"
"github.com/gocql/gocql/internal/streams"
)
var (
approvedAuthenticators = [...]string{
"org.apache.cassandra.auth.PasswordAuthenticator",
"com.instaclustr.cassandra.auth.SharedSecretAuthenticator",
"com.datastax.bdp.cassandra.auth.DseAuthenticator",
}
)
func approve(authenticator string) bool {
for _, s := range approvedAuthenticators {
if authenticator == s {
return true
}
}
return false
}
//JoinHostPort is a utility to return a address string that can be used
//gocql.Conn to form a connection with a host.
func JoinHostPort(addr string, port int) string {
addr = strings.TrimSpace(addr)
if _, _, err := net.SplitHostPort(addr); err != nil {
addr = net.JoinHostPort(addr, strconv.Itoa(port))
}
return addr
}
type Authenticator interface {
Challenge(req []byte) (resp []byte, auth Authenticator, err error)
Success(data []byte) error
}
type PasswordAuthenticator struct {
Username string
Password string
}
func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
if !approve(string(req)) {
return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
}
resp := make([]byte, 2+len(p.Username)+len(p.Password))
resp[0] = 0
copy(resp[1:], p.Username)
resp[len(p.Username)+1] = 0
copy(resp[2+len(p.Username):], p.Password)
return resp, nil, nil
}
func (p PasswordAuthenticator) Success(data []byte) error {
return nil
}
type SslOptions struct {
*tls.Config
// CertPath and KeyPath are optional depending on server
// config, but both fields must be omitted to avoid using a
// client certificate
CertPath string
KeyPath string
CaPath string //optional depending on server config
// If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
// This option is basically the inverse of InSecureSkipVerify
// See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info
EnableHostVerification bool
}
type ConnConfig struct {
ProtoVersion int
CQLVersion string
Timeout time.Duration
ConnectTimeout time.Duration
Compressor Compressor
Authenticator Authenticator
Keepalive time.Duration
tlsConfig *tls.Config
}
type ConnErrorHandler interface {
HandleError(conn *Conn, err error, closed bool)
}
type connErrorHandlerFn func(conn *Conn, err error, closed bool)
func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
fn(conn, err, closed)
}
// If not zero, how many timeouts we will allow to occur before the connection is closed
// and restarted. This is to prevent a single query timeout from killing a connection
// which may be serving more queries just fine.
// Default is 0, should not be changed concurrently with queries.
//
// depreciated
var TimeoutLimit int64 = 0
// Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
type Conn struct {
conn net.Conn
r *bufio.Reader
timeout time.Duration
cfg *ConnConfig
frameObserver FrameHeaderObserver
headerBuf [maxFrameHeaderSize]byte
streams *streams.IDGenerator
mu sync.RWMutex
calls map[int]*callReq
errorHandler ConnErrorHandler
compressor Compressor
auth Authenticator
addr string
version uint8
currentKeyspace string
host *HostInfo
session *Session
closed int32
quit chan struct{}
timeouts int64
}
// Connect establishes a connection to a Cassandra node.
func (s *Session) dial(host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) {
ip := host.ConnectAddress()
port := host.port
// TODO(zariel): remove these
if len(ip) == 0 || ip.IsUnspecified() {
panic(fmt.Sprintf("host missing connect ip address: %v", ip))
} else if port == 0 {
panic(fmt.Sprintf("host missing port: %v", port))
}
var (
err error
conn net.Conn
)
dialer := &net.Dialer{
Timeout: cfg.ConnectTimeout,
}
// TODO(zariel): handle ipv6 zone
addr := (&net.TCPAddr{IP: ip, Port: port}).String()
if cfg.tlsConfig != nil {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
} else {
conn, err = dialer.Dial("tcp", addr)
}
if err != nil {
return nil, err
}
c := &Conn{
conn: conn,
r: bufio.NewReader(conn),
cfg: cfg,
calls: make(map[int]*callReq),
timeout: cfg.Timeout,
version: uint8(cfg.ProtoVersion),
addr: conn.RemoteAddr().String(),
errorHandler: errorHandler,
compressor: cfg.Compressor,
auth: cfg.Authenticator,
quit: make(chan struct{}),
session: s,
streams: streams.New(cfg.ProtoVersion),
host: host,
frameObserver: s.frameObserver,
}
if cfg.Keepalive > 0 {
c.setKeepalive(cfg.Keepalive)
}
var (
ctx context.Context
cancel func()
)
if cfg.ConnectTimeout > 0 {
ctx, cancel = context.WithTimeout(context.Background(), cfg.ConnectTimeout)
} else {
ctx, cancel = context.WithCancel(context.Background())
}
defer cancel()
frameTicker := make(chan struct{}, 1)
startupErr := make(chan error)
go func() {
for range frameTicker {
err := c.recv()
if err != nil {
select {
case startupErr <- err:
case <-ctx.Done():
}
return
}
}
}()
go func() {
defer close(frameTicker)
err := c.startup(ctx, frameTicker)
select {
case startupErr <- err:
case <-ctx.Done():
}
}()
select {
case err := <-startupErr:
if err != nil {
c.Close()
return nil, err
}
case <-ctx.Done():
c.Close()
return nil, errors.New("gocql: no response to connection startup within timeout")
}
go c.serve()
return c, nil
}
func (c *Conn) Write(p []byte) (int, error) {
if c.timeout > 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
}
return c.conn.Write(p)
}
func (c *Conn) Read(p []byte) (n int, err error) {
const maxAttempts = 5
for i := 0; i < maxAttempts; i++ {
var nn int
if c.timeout > 0 {
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
}
nn, err = io.ReadFull(c.r, p[n:])
n += nn
if err == nil {
break
}
if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
break
}
}
return
}
func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
m := map[string]string{
"CQL_VERSION": c.cfg.CQLVersion,
}
if c.compressor != nil {
m["COMPRESSION"] = c.compressor.Name()
}
select {
case frameTicker <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil)
if err != nil {
return err
}
frame, err := framer.parseFrame()
if err != nil {
return err
}
switch v := frame.(type) {
case error:
return v
case *readyFrame:
return nil
case *authenticateFrame:
return c.authenticateHandshake(ctx, v, frameTicker)
default:
return NewErrProtocol("Unknown type of response to startup frame: %s", v)
}
}
func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
if c.auth == nil {
return fmt.Errorf("authentication required (using %q)", authFrame.class)
}
resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
if err != nil {
return err
}
req := &writeAuthResponseFrame{data: resp}
for {
select {
case frameTicker <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
framer, err := c.exec(ctx, req, nil)
if err != nil {
return err
}
frame, err := framer.parseFrame()
if err != nil {
return err
}
switch v := frame.(type) {
case error:
return v
case *authSuccessFrame:
if challenger != nil {
return challenger.Success(v.data)
}
return nil
case *authChallengeFrame:
resp, challenger, err = challenger.Challenge(v.data)
if err != nil {
return err
}
req = &writeAuthResponseFrame{
data: resp,
}
default:
return fmt.Errorf("unknown frame response during authentication: %v", v)
}
}
}
func (c *Conn) closeWithError(err error) {
if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
return
}
// we should attempt to deliver the error back to the caller if it
// exists
if err != nil {
c.mu.RLock()
for _, req := range c.calls {
// we need to send the error to all waiting queries, put the state
// of this conn into not active so that it can not execute any queries.
select {
case req.resp <- err:
case <-req.timeout:
}
}
c.mu.RUnlock()
}
// if error was nil then unblock the quit channel
close(c.quit)
cerr := c.close()
if err != nil {
c.errorHandler.HandleError(c, err, true)
} else if cerr != nil {
// TODO(zariel): is it a good idea to do this?
c.errorHandler.HandleError(c, cerr, true)
}
}
func (c *Conn) close() error {
return c.conn.Close()
}
func (c *Conn) Close() {
c.closeWithError(nil)
}
// Serve starts the stream multiplexer for this connection, which is required
// to execute any queries. This method runs as long as the connection is
// open and is therefore usually called in a separate goroutine.
func (c *Conn) serve() {
var err error
for err == nil {
err = c.recv()
}
c.closeWithError(err)
}
func (c *Conn) discardFrame(head frameHeader) error {
_, err := io.CopyN(ioutil.Discard, c, int64(head.length))
if err != nil {
return err
}
return nil
}
type protocolError struct {
frame frame
}
func (p *protocolError) Error() string {
if err, ok := p.frame.(error); ok {
return err.Error()
}
return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame)
}
func (c *Conn) recv() error {
// not safe for concurrent reads
// read a full header, ignore timeouts, as this is being ran in a loop
// TODO: TCP level deadlines? or just query level deadlines?
if c.timeout > 0 {
c.conn.SetReadDeadline(time.Time{})
}
headStartTime := time.Now()
// were just reading headers over and over and copy bodies
head, err := readHeader(c.r, c.headerBuf[:])
headEndTime := time.Now()
if err != nil {
return err
}
if c.frameObserver != nil {
c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{
Version: byte(head.version),
Flags: head.flags,
Stream: int16(head.stream),
Opcode: byte(head.op),
Length: int32(head.length),
Start: headStartTime,
End: headEndTime,
})
}
if head.stream > c.streams.NumStreams {
return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream)
} else if head.stream == -1 {
// TODO: handle cassandra event frames, we shouldnt get any currently
framer := newFramer(c, c, c.compressor, c.version)
if err := framer.readFrame(&head); err != nil {
return err
}
go c.session.handleEvent(framer)
return nil
} else if head.stream <= 0 {
// reserved stream that we dont use, probably due to a protocol error
// or a bug in Cassandra, this should be an error, parse it and return.
framer := newFramer(c, c, c.compressor, c.version)
if err := framer.readFrame(&head); err != nil {
return err
}
frame, err := framer.parseFrame()
if err != nil {
return err
}
return &protocolError{
frame: frame,
}
}
c.mu.RLock()
call, ok := c.calls[head.stream]
c.mu.RUnlock()
if call == nil || call.framer == nil || !ok {
Logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
return c.discardFrame(head)
}
err = call.framer.readFrame(&head)
if err != nil {
// only net errors should cause the connection to be closed. Though
// cassandra returning corrupt frames will be returned here as well.
if _, ok := err.(net.Error); ok {
return err
}
}
// we either, return a response to the caller, the caller timedout, or the
// connection has closed. Either way we should never block indefinatly here
select {
case call.resp <- err:
case <-call.timeout:
c.releaseStream(head.stream)
case <-c.quit:
}
return nil
}
func (c *Conn) releaseStream(stream int) {
c.mu.Lock()
call := c.calls[stream]
if call != nil && stream != call.streamID {
panic(fmt.Sprintf("attempt to release streamID with invalid stream: %d -> %+v\n", stream, call))
} else if call == nil {
panic(fmt.Sprintf("releasing a stream not in use: %d", stream))
}
delete(c.calls, stream)
c.mu.Unlock()
if call.timer != nil {
call.timer.Stop()
}
streamPool.Put(call)
c.streams.Clear(stream)
}
func (c *Conn) handleTimeout() {
if TimeoutLimit > 0 && atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit {
c.closeWithError(ErrTooManyTimeouts)
}
}
var (
streamPool = sync.Pool{
New: func() interface{} {
return &callReq{
resp: make(chan error),
}
},
}
)
type callReq struct {
// could use a waitgroup but this allows us to do timeouts on the read/send
resp chan error
framer *framer
timeout chan struct{} // indicates to recv() that a call has timedout
streamID int // current stream in use
timer *time.Timer
}
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
// TODO: move tracer onto conn
stream, ok := c.streams.GetStream()
if !ok {
return nil, ErrNoStreams
}
// resp is basically a waiting semaphore protecting the framer
framer := newFramer(c, c, c.compressor, c.version)
call := streamPool.Get().(*callReq)
call.framer = framer
call.timeout = make(chan struct{})
call.streamID = stream
c.mu.Lock()
existingCall := c.calls[stream]
if existingCall == nil {
c.calls[stream] = call
}
c.mu.Unlock()
if existingCall != nil {
return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, existingCall.streamID)
}
if tracer != nil {
framer.trace()
}
err := req.writeFrame(framer, stream)
if err != nil {
// closeWithError will block waiting for this stream to either receive a response
// or for us to timeout, close the timeout chan here. Im not entirely sure
// but we should not get a response after an error on the write side.
close(call.timeout)
// I think this is the correct thing to do, im not entirely sure. It is not
// ideal as readers might still get some data, but they probably wont.
// Here we need to be careful as the stream is not available and if all
// writes just timeout or fail then the pool might use this connection to
// send a frame on, with all the streams used up and not returned.
c.closeWithError(err)
return nil, err
}
var timeoutCh <-chan time.Time
if c.timeout > 0 {
if call.timer == nil {
call.timer = time.NewTimer(0)
<-call.timer.C
} else {
if !call.timer.Stop() {
select {
case <-call.timer.C:
default:
}
}
}
call.timer.Reset(c.timeout)
timeoutCh = call.timer.C
}
var ctxDone <-chan struct{}
if ctx != nil {
ctxDone = ctx.Done()
}
select {
case err := <-call.resp:
close(call.timeout)
if err != nil {
if !c.Closed() {
// if the connection is closed then we cant release the stream,
// this is because the request is still outstanding and we have
// been handed another error from another stream which caused the
// connection to close.
c.releaseStream(stream)
}
return nil, err
}
case <-timeoutCh:
close(call.timeout)
c.handleTimeout()
return nil, ErrTimeoutNoResponse
case <-ctxDone:
close(call.timeout)
return nil, ctx.Err()
case <-c.quit:
return nil, ErrConnectionClosed
}
// dont release the stream if detect a timeout as another request can reuse
// that stream and get a response for the old request, which we have no
// easy way of detecting.
//
// Ensure that the stream is not released if there are potentially outstanding
// requests on the stream to prevent nil pointer dereferences in recv().
defer c.releaseStream(stream)
if v := framer.header.version.version(); v != c.version {
return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
}
return framer, nil
}
type preparedStatment struct {
id []byte
request preparedMetadata
response resultMetadata
}
type inflightPrepare struct {
wg sync.WaitGroup
err error
preparedStatment *preparedStatment
}
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
flight := new(inflightPrepare)
flight.wg.Add(1)
lru.Add(stmtCacheKey, flight)
return flight
})
if ok {
flight.wg.Wait()
return flight.preparedStatment, flight.err
}
prep := &writePrepareFrame{
statement: stmt,
}
framer, err := c.exec(ctx, prep, tracer)
if err != nil {
flight.err = err
flight.wg.Done()
c.session.stmtsLRU.remove(stmtCacheKey)
return nil, err
}
frame, err := framer.parseFrame()
if err != nil {
flight.err = err
flight.wg.Done()
c.session.stmtsLRU.remove(stmtCacheKey)
return nil, err
}
// TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated
// everytime we need to parse a frame.
if len(framer.traceID) > 0 && tracer != nil {
tracer.Trace(framer.traceID)
}
switch x := frame.(type) {
case *resultPreparedFrame:
flight.preparedStatment = &preparedStatment{
// defensively copy as we will recycle the underlying buffer after we
// return.
id: copyBytes(x.preparedID),
// the type info's should _not_ have a reference to the framers read buffer,
// therefore we can just copy them directly.
request: x.reqMeta,
response: x.respMeta,
}
case error:
flight.err = x
default:
flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
}
flight.wg.Done()
if flight.err != nil {
c.session.stmtsLRU.remove(stmtCacheKey)
}
return flight.preparedStatment, flight.err
}
func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
if named, ok := value.(*namedValue); ok {
dst.name = named.name
value = named.value
}
if _, ok := value.(unsetColumn); !ok {
val, err := Marshal(typ, value)
if err != nil {
return err
}
dst.value = val
} else {
dst.isUnset = true
}
return nil
}
func (c *Conn) executeQuery(qry *Query) *Iter {
params := queryParams{
consistency: qry.cons,
}
// frame checks that it is not 0
params.serialConsistency = qry.serialCons
params.defaultTimestamp = qry.defaultTimestamp
params.defaultTimestampValue = qry.defaultTimestampValue
if len(qry.pageState) > 0 {
params.pagingState = qry.pageState
}
if qry.pageSize > 0 {
params.pageSize = qry.pageSize
}
var (
frame frameWriter
info *preparedStatment
)
if qry.shouldPrepare() {
// Prepare all DML queries. Other queries can not be prepared.
var err error
info, err = c.prepareStatement(qry.context, qry.stmt, qry.trace)
if err != nil {
return &Iter{err: err}
}
var values []interface{}
if qry.binding == nil {
values = qry.values
} else {
values, err = qry.binding(&QueryInfo{
Id: info.id,
Args: info.request.columns,
Rval: info.response.columns,
PKeyColumns: info.request.pkeyColumns,
})
if err != nil {
return &Iter{err: err}
}
}
if len(values) != info.request.actualColCount {
return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))}
}
params.values = make([]queryValues, len(values))
for i := 0; i < len(values); i++ {
v := &params.values[i]
value := values[i]
typ := info.request.columns[i].TypeInfo
if err := marshalQueryValue(typ, value, v); err != nil {
return &Iter{err: err}
}
}
params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata)
frame = &writeExecuteFrame{
preparedID: info.id,
params: params,
}
} else {
frame = &writeQueryFrame{
statement: qry.stmt,
params: params,
}
}
framer, err := c.exec(qry.context, frame, qry.trace)
if err != nil {
return &Iter{err: err}
}
resp, err := framer.parseFrame()
if err != nil {
return &Iter{err: err}
}
if len(framer.traceID) > 0 && qry.trace != nil {
qry.trace.Trace(framer.traceID)
}
switch x := resp.(type) {
case *resultVoidFrame:
return &Iter{framer: framer}
case *resultRowsFrame:
iter := &Iter{
meta: x.meta,
framer: framer,
numRows: x.numRows,
}
if params.skipMeta {
if info != nil {
iter.meta = info.response
iter.meta.pagingState = x.meta.pagingState
} else {
return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")}
}
} else {
iter.meta = x.meta
}
if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
iter.next = &nextIter{
qry: *qry,
pos: int((1 - qry.prefetch) * float64(x.numRows)),
conn: c,
}
iter.next.qry.pageState = copyBytes(x.meta.pagingState)
if iter.next.pos < 1 {
iter.next.pos = 1
}
}
return iter
case *resultKeyspaceFrame:
return &Iter{framer: framer}
case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType:
iter := &Iter{framer: framer}
if err := c.awaitSchemaAgreement(); err != nil {
// TODO: should have this behind a flag
Logger.Println(err)
}
// dont return an error from this, might be a good idea to give a warning
// though. The impact of this returning an error would be that the cluster
// is not consistent with regards to its schema.
return iter
case *RequestErrUnprepared:
stmtCacheKey := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, qry.stmt)
if c.session.stmtsLRU.remove(stmtCacheKey) {
return c.executeQuery(qry)
}
return &Iter{err: x, framer: framer}
case error:
return &Iter{err: x, framer: framer}
default:
return &Iter{
err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x),
framer: framer,
}
}
}
func (c *Conn) Pick(qry *Query) *Conn {
if c.Closed() {
return nil
}
return c
}
func (c *Conn) Closed() bool {
return atomic.LoadInt32(&c.closed) == 1
}
func (c *Conn) Address() string {
return c.addr
}
func (c *Conn) AvailableStreams() int {
return c.streams.Available()
}
func (c *Conn) UseKeyspace(keyspace string) error {
q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
q.params.consistency = Any
framer, err := c.exec(context.Background(), q, nil)
if err != nil {
return err
}
resp, err := framer.parseFrame()
if err != nil {
return err
}
switch x := resp.(type) {
case *resultKeyspaceFrame:
case error:
return x
default:
return NewErrProtocol("unknown frame in response to USE: %v", x)
}
c.currentKeyspace = keyspace
return nil
}
func (c *Conn) executeBatch(batch *Batch) *Iter {
if c.version == protoVersion1 {
return &Iter{err: ErrUnsupported}
}
n := len(batch.Entries)
req := &writeBatchFrame{
typ: batch.Type,
statements: make([]batchStatment, n),
consistency: batch.Cons,
serialConsistency: batch.serialCons,
defaultTimestamp: batch.defaultTimestamp,
defaultTimestampValue: batch.defaultTimestampValue,
}
stmts := make(map[string]string, len(batch.Entries))
for i := 0; i < n; i++ {
entry := &batch.Entries[i]
b := &req.statements[i]
if len(entry.Args) > 0 || entry.binding != nil {
info, err := c.prepareStatement(batch.context, entry.Stmt, nil)
if err != nil {
return &Iter{err: err}
}
var values []interface{}
if entry.binding == nil {
values = entry.Args
} else {
values, err = entry.binding(&QueryInfo{
Id: info.id,
Args: info.request.columns,
Rval: info.response.columns,
PKeyColumns: info.request.pkeyColumns,
})
if err != nil {
return &Iter{err: err}
}
}
if len(values) != info.request.actualColCount {
return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))}
}
b.preparedID = info.id
stmts[string(info.id)] = entry.Stmt
b.values = make([]queryValues, info.request.actualColCount)
for j := 0; j < info.request.actualColCount; j++ {
v := &b.values[j]
value := values[j]
typ := info.request.columns[j].TypeInfo
if err := marshalQueryValue(typ, value, v); err != nil {
return &Iter{err: err}
}
}
} else {
b.statement = entry.Stmt
}
}
// TODO: should batch support tracing?
framer, err := c.exec(batch.context, req, nil)
if err != nil {
return &Iter{err: err}
}
resp, err := framer.parseFrame()
if err != nil {
return &Iter{err: err, framer: framer}
}
switch x := resp.(type) {
case *resultVoidFrame:
return &Iter{}
case *RequestErrUnprepared:
stmt, found := stmts[string(x.StatementId)]
if found {
key := c.session.stmtsLRU.keyFor(c.addr, c.currentKeyspace, stmt)
c.session.stmtsLRU.remove(key)
}
if found {
return c.executeBatch(batch)
} else {
return &Iter{err: x, framer: framer}
}
case *resultRowsFrame:
iter := &Iter{
meta: x.meta,
framer: framer,
numRows: x.numRows,
}
return iter
case error:
return &Iter{err: x, framer: framer}
default:
return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer}
}
}
func (c *Conn) setKeepalive(d time.Duration) error {
if tc, ok := c.conn.(*net.TCPConn); ok {
err := tc.SetKeepAlivePeriod(d)
if err != nil {
return err
}
return tc.SetKeepAlive(true)
}
return nil
}
func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One)
return c.executeQuery(q)
}
func (c *Conn) awaitSchemaAgreement() (err error) {
const (
peerSchemas = "SELECT schema_version, peer FROM system.peers"
localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
)
var versions map[string]struct{}
endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
for time.Now().Before(endDeadline) {
iter := c.query(peerSchemas)
versions = make(map[string]struct{})
var schemaVersion string
var peer string
for iter.Scan(&schemaVersion, &peer) {
if schemaVersion == "" {
Logger.Printf("skipping peer entry with empty schema_version: peer=%q", peer)
continue
}
versions[schemaVersion] = struct{}{}
schemaVersion = ""
}
if err = iter.Close(); err != nil {
goto cont
}
iter = c.query(localSchemas)
for iter.Scan(&schemaVersion) {
versions[schemaVersion] = struct{}{}
schemaVersion = ""
}
if err = iter.Close(); err != nil {
goto cont
}
if len(versions) <= 1 {
return nil
}
cont:
time.Sleep(200 * time.Millisecond)
}
if err != nil {
return
}
schemas := make([]string, 0, len(versions))
for schema := range versions {
schemas = append(schemas, schema)
}
// not exported
return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
}
const localHostInfo = "SELECT * FROM system.local WHERE key='local'"
func (c *Conn) localHostInfo() (*HostInfo, error) {
row, err := c.query(localHostInfo).rowMap()
if err != nil {
return nil, err
}
port := c.conn.RemoteAddr().(*net.TCPAddr).Port
// TODO(zariel): avoid doing this here
host, err := c.session.hostInfoFromMap(row, port)
if err != nil {
return nil, err
}
return c.session.ring.addOrUpdate(host), nil
}
var (
ErrQueryArgLength = errors.New("gocql: query argument length mismatch")
ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection")
ErrConnectionClosed = errors.New("gocql: connection closed waiting for response")
ErrNoStreams = errors.New("gocql: no streams available on connection")
)