Validate hostnames when using TLS in Cassandra (#11365)
This commit is contained in:
parent
541ae8636c
commit
4279bc8b34
@@ -20,13 +20,18 @@ func TestBackend_basic(t *testing.T) {
        t.Fatal(err)
    }

    cleanup, hostname := cassandra.PrepareTestContainer(t, "latest")
    copyFromTo := map[string]string{
        "test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
    }
    host, cleanup := cassandra.PrepareTestContainer(t,
        cassandra.CopyFromTo(copyFromTo),
    )
    defer cleanup()

    logicaltest.Test(t, logicaltest.TestCase{
        LogicalBackend: b,
        Steps: []logicaltest.TestStep{
            testAccStepConfig(t, hostname),
            testAccStepConfig(t, host.ConnectionURL()),
            testAccStepRole(t),
            testAccStepReadCreds(t, "test"),
        },

@@ -41,13 +46,17 @@ func TestBackend_roleCrud(t *testing.T) {
        t.Fatal(err)
    }

    cleanup, hostname := cassandra.PrepareTestContainer(t, "latest")
    copyFromTo := map[string]string{
        "test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
    }
    host, cleanup := cassandra.PrepareTestContainer(t,
        cassandra.CopyFromTo(copyFromTo))
    defer cleanup()

    logicaltest.Test(t, logicaltest.TestCase{
        LogicalBackend: b,
        Steps: []logicaltest.TestStep{
            testAccStepConfig(t, hostname),
            testAccStepConfig(t, host.ConnectionURL()),
            testAccStepRole(t),
            testAccStepRoleWithOptions(t),
            testAccStepReadRole(t, "test", testRole),

@@ -0,0 +1,3 @@
```release-note:bug
secrets/database/cassandra: Fixed issue where hostnames were not being validated when using TLS
```
go.mod

@@ -49,7 +49,7 @@ require (
    github.com/go-ole/go-ole v1.2.4 // indirect
    github.com/go-sql-driver/mysql v1.5.0
    github.com/go-test/deep v1.0.7
    github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e
    github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
    github.com/golang/protobuf v1.4.2
    github.com/google/go-github v17.0.0+incompatible
    github.com/google/go-metrics-stackdriver v0.2.0
go.sum

@@ -442,8 +442,8 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe
github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ=
github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0=
github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309 h1:8MHuCGYDXh0skFrLumkCMlt9C29hxhqNx39+Haemeqw=
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=

@@ -1084,6 +1084,7 @@ github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6So
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.6.2 h1:aIihoIOHCiLZHxyoNQ+ABL4NKhFTgKLBdMLyEAh98m0=
github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rs/zerolog v1.4.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=

@@ -1631,6 +1632,7 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
@@ -2,9 +2,10 @@ package cassandra

import (
    "context"
    "errors"
    "fmt"
    "net"
    "os"
    "path/filepath"
    "testing"
    "time"

@@ -12,33 +13,75 @@ import (
    "github.com/hashicorp/vault/helper/testhelpers/docker"
)

func PrepareTestContainer(t *testing.T, version string) (func(), string) {
type containerConfig struct {
    version    string
    copyFromTo map[string]string
    sslOpts    *gocql.SslOptions
}

type ContainerOpt func(*containerConfig)

func Version(version string) ContainerOpt {
    return func(cfg *containerConfig) {
        cfg.version = version
    }
}

func CopyFromTo(copyFromTo map[string]string) ContainerOpt {
    return func(cfg *containerConfig) {
        cfg.copyFromTo = copyFromTo
    }
}

func SslOpts(sslOpts *gocql.SslOptions) ContainerOpt {
    return func(cfg *containerConfig) {
        cfg.sslOpts = sslOpts
    }
}

type Host struct {
    Name string
    Port string
}

func (h Host) ConnectionURL() string {
    return net.JoinHostPort(h.Name, h.Port)
}

func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
    t.Helper()
    if os.Getenv("CASSANDRA_HOSTS") != "" {
        return func() {}, os.Getenv("CASSANDRA_HOSTS")
        host, port, err := net.SplitHostPort(os.Getenv("CASSANDRA_HOSTS"))
        if err != nil {
            t.Fatalf("Failed to split host & port from CASSANDRA_HOSTS (%s): %s", os.Getenv("CASSANDRA_HOSTS"), err)
        }
        h := Host{
            Name: host,
            Port: port,
        }
        return h, func() {}
    }

    if version == "" {
        version = "3.11"
    containerCfg := &containerConfig{
        version: "3.11",
    }

    var copyFromTo map[string]string
    cwd, _ := os.Getwd()
    fixturePath := fmt.Sprintf("%s/test-fixtures/", cwd)
    if _, err := os.Stat(fixturePath); err != nil {
        if !errors.Is(err, os.ErrNotExist) {
            // If it doesn't exist, no biggie
            t.Fatal(err)
        }
    } else {
        copyFromTo = map[string]string{
            fixturePath: "/etc/cassandra",
    for _, opt := range opts {
        opt(containerCfg)
    }

    copyFromTo := map[string]string{}
    for from, to := range containerCfg.copyFromTo {
        absFrom, err := filepath.Abs(from)
        if err != nil {
            t.Fatalf("Unable to get absolute path for file %s", from)
        }
        copyFromTo[absFrom] = to
    }

    runner, err := docker.NewServiceRunner(docker.RunOptions{
        ImageRepo:  "cassandra",
        ImageTag:   version,
        ImageTag:   containerCfg.version,
        Ports:      []string{"9042/tcp"},
        CopyFromTo: copyFromTo,
        Env:        []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},

@@ -58,6 +101,8 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
        clusterConfig.ProtoVersion = 4
        clusterConfig.Port = port

        clusterConfig.SslOpts = containerCfg.sslOpts

        session, err := clusterConfig.CreateSession()
        if err != nil {
            return nil, fmt.Errorf("error creating session: %s", err)

@@ -65,19 +110,19 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
        defer session.Close()

        // Create keyspace
        q := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`)
        if err := q.Exec(); err != nil {
        query := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`)
        if err := query.Exec(); err != nil {
            t.Fatalf("could not create cassandra keyspace: %v", err)
        }

        // Create table
        q = session.Query(`CREATE TABLE "vault"."entries" (
        query = session.Query(`CREATE TABLE "vault"."entries" (
            bucket text,
            key text,
            value blob,
            PRIMARY KEY (bucket, key)
        ) WITH CLUSTERING ORDER BY (key ASC);`)
        if err := q.Exec(); err != nil {
        if err := query.Exec(); err != nil {
            t.Fatalf("could not create cassandra table: %v", err)
        }
        return cfg, nil

@@ -85,5 +130,14 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
    if err != nil {
        t.Fatalf("Could not start docker cassandra: %s", err)
    }
    return svc.Cleanup, svc.Config.Address()

    host, port, err := net.SplitHostPort(svc.Config.Address())
    if err != nil {
        t.Fatalf("Failed to split host & port from address (%s): %s", svc.Config.Address(), err)
    }
    h := Host{
        Name: host,
        Port: port,
    }
    return h, svc.Cleanup
}
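The helper above replaces the positional version argument with functional options, so call sites only pass what they need. A minimal sketch of a caller (the test name and option values here are illustrative, not taken from the diff):

```go
func TestExample(t *testing.T) {
    // Each ContainerOpt mutates the private containerConfig before the
    // container starts; unset fields keep their defaults (version "3.11").
    host, cleanup := cassandra.PrepareTestContainer(t,
        cassandra.Version("3.11"),
        cassandra.CopyFromTo(map[string]string{
            "test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
        }),
    )
    defer cleanup()

    // Host keeps name and port separate; ConnectionURL joins them with
    // net.JoinHostPort for APIs that want "name:port".
    t.Log(host.ConnectionURL(), host.Name, host.Port)
}
```

Adding options later (as SslOpts does here for the TLS fixtures) does not break existing callers, which is the usual reason to prefer this pattern over a growing positional signature.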
@@ -10,11 +10,10 @@ import (
    "strings"
    "time"

    "github.com/hashicorp/errwrap"
    log "github.com/hashicorp/go-hclog"

    metrics "github.com/armon/go-metrics"
    "github.com/gocql/gocql"
    "github.com/hashicorp/errwrap"
    log "github.com/hashicorp/go-hclog"
    "github.com/hashicorp/vault/sdk/helper/certutil"
    "github.com/hashicorp/vault/sdk/physical"
)

@@ -180,20 +179,18 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
        if err != nil {
            return err
        }
    } else {
        if pemJSONPath, ok := conf["pem_json_file"]; ok {
            pemJSONData, err := ioutil.ReadFile(pemJSONPath)
            if err != nil {
                return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err)
            }
            pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
            if err != nil {
                return err
            }
            tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
            if err != nil {
                return err
            }
    } else if pemJSONPath, ok := conf["pem_json_file"]; ok {
        pemJSONData, err := ioutil.ReadFile(pemJSONPath)
        if err != nil {
            return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err)
        }
        pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
        if err != nil {
            return err
        }
        tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
        if err != nil {
            return err
        }
    }

@@ -225,7 +222,8 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
    }

    cluster.SslOpts = &gocql.SslOptions{
        Config: tlsConfig.Clone(),
        Config:                 tlsConfig,
        EnableHostVerification: !tlsConfig.InsecureSkipVerify,
    }
    return nil
}
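The one-line addition above is the heart of the fix for this storage backend: gocql only checks that the server certificate matches the hostname when EnableHostVerification is true, and previously the field was left at its false zero value. A standalone sketch of the mapping (the cluster address is a placeholder, and tlsConfig stands in for the config built from the user's PEM settings as above):

```go
cluster := gocql.NewCluster("cassandra.example.com")
cluster.SslOpts = &gocql.SslOptions{
    Config: tlsConfig,
    // Hostname verification is on exactly when certificate
    // verification is, so insecure_tls still disables both.
    EnableHostVerification: !tlsConfig.InsecureSkipVerify,
}
```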
@@ -19,13 +19,13 @@ func TestCassandraBackend(t *testing.T) {
        t.Skip("skipping race test in CI pending https://github.com/gocql/gocql/pull/1474")
    }

    cleanup, hosts := cassandra.PrepareTestContainer(t, "")
    host, cleanup := cassandra.PrepareTestContainer(t)
    defer cleanup()

    // Run vault tests
    logger := logging.NewVaultLogger(log.Debug)
    b, err := NewCassandraBackend(map[string]string{
        "hosts":            hosts,
        "hosts":            host.ConnectionURL(),
        "protocol_version": "3",
    }, logger)
    if err != nil {
@@ -3,7 +3,6 @@ package cassandra

import (
    "context"
    "reflect"
    "strings"
    "testing"
    "time"

@@ -17,14 +16,16 @@ import (
)

func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()) {
    cleanup, connURL := cassandra.PrepareTestContainer(t, "latest")
    pieces := strings.Split(connURL, ":")
    host, cleanup := cassandra.PrepareTestContainer(t,
        cassandra.Version("latest"),
        cassandra.CopyFromTo(insecureFileMounts),
    )

    db := new()
    initReq := dbplugin.InitializeRequest{
        Config: map[string]interface{}{
            "hosts":            connURL,
            "port":             pieces[1],
            "hosts":            host.ConnectionURL(),
            "port":             host.Port,
            "username":         "cassandra",
            "password":         "cassandra",
            "protocol_version": protocolVersion,

@@ -34,8 +35,8 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
    }

    expectedConfig := map[string]interface{}{
        "hosts":            connURL,
        "port":             pieces[1],
        "hosts":            host.ConnectionURL(),
        "port":             host.Port,
        "username":         "cassandra",
        "password":         "cassandra",
        "protocol_version": protocolVersion,

@@ -53,7 +54,7 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
    return db, cleanup
}

func TestCassandra_Initialize(t *testing.T) {
func TestInitialize(t *testing.T) {
    db, cleanup := getCassandra(t, 4)
    defer cleanup()

@@ -66,7 +67,7 @@ func TestCassandra_Initialize(t *testing.T) {
    defer cleanup()
}

func TestCassandra_CreateUser(t *testing.T) {
func TestCreateUser(t *testing.T) {
    type testCase struct {
        // Config will have the hosts & port added to it during the test
        config map[string]interface{}

@@ -126,15 +127,17 @@ func TestCassandra_CreateUser(t *testing.T) {

    for name, test := range tests {
        t.Run(name, func(t *testing.T) {
            cleanup, connURL := cassandra.PrepareTestContainer(t, "latest")
            pieces := strings.Split(connURL, ":")
            host, cleanup := cassandra.PrepareTestContainer(t,
                cassandra.Version("latest"),
                cassandra.CopyFromTo(insecureFileMounts),
            )
            defer cleanup()

            db := new()

            config := test.config
            config["hosts"] = connURL
            config["port"] = pieces[1]
            config["hosts"] = host.ConnectionURL()
            config["port"] = host.Port

            initReq := dbplugin.InitializeRequest{
                Config: config,

@@ -162,7 +165,7 @@ func TestCassandra_CreateUser(t *testing.T) {
    }
}

func TestMyCassandra_UpdateUserPassword(t *testing.T) {
func TestUpdateUserPassword(t *testing.T) {
    db, cleanup := getCassandra(t, 4)
    defer cleanup()

@@ -198,7 +201,7 @@ func TestMyCassandra_UpdateUserPassword(t *testing.T) {
    assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
}

func TestCassandra_DeleteUser(t *testing.T) {
func TestDeleteUser(t *testing.T) {
    db, cleanup := getCassandra(t, 4)
    defer cleanup()

@@ -8,14 +8,13 @@ import (
    "sync"
    "time"

    "github.com/gocql/gocql"
    dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
    "github.com/hashicorp/vault/sdk/database/helper/connutil"
    "github.com/hashicorp/vault/sdk/database/helper/dbutil"
    "github.com/hashicorp/vault/sdk/helper/certutil"
    "github.com/hashicorp/vault/sdk/helper/parseutil"
    "github.com/hashicorp/vault/sdk/helper/tlsutil"

    "github.com/gocql/gocql"
    dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
    "github.com/mitchellh/mapstructure"
)

@@ -40,9 +39,7 @@ type cassandraConnectionProducer struct {

    connectTimeout  time.Duration
    socketKeepAlive time.Duration
    certificate     string
    privateKey      string
    issuingCA       string
    certBundle      *certutil.CertBundle
    rawConfig       map[string]interface{}

    Initialized bool

@@ -99,9 +96,7 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
        if err != nil {
            return fmt.Errorf("error marshaling PEM information: %w", err)
        }
        c.certificate = certBundle.Certificate
        c.privateKey = certBundle.PrivateKey
        c.issuingCA = certBundle.IssuingCA
        c.certBundle = certBundle
        c.TLS = true

    case len(c.PemBundle) != 0:

@@ -113,9 +108,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
        if err != nil {
            return fmt.Errorf("error marshaling PEM information: %w", err)
        }
        c.certificate = certBundle.Certificate
        c.privateKey = certBundle.PrivateKey
        c.issuingCA = certBundle.IssuingCA
        c.certBundle = certBundle
        c.TLS = true
    }

    if c.InsecureTLS {
        c.TLS = true
    }

@@ -185,49 +182,13 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql

    clusterConfig.Timeout = c.connectTimeout
    clusterConfig.SocketKeepalive = c.socketKeepAlive

    if c.TLS {
        var tlsConfig *tls.Config
        if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
            if len(c.certificate) > 0 && len(c.privateKey) == 0 {
                return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
            }

            certBundle := &certutil.CertBundle{}
            if len(c.certificate) > 0 {
                certBundle.Certificate = c.certificate
                certBundle.PrivateKey = c.privateKey
            }
            if len(c.issuingCA) > 0 {
                certBundle.IssuingCA = c.issuingCA
            }

            parsedCertBundle, err := certBundle.ToParsedCertBundle()
            if err != nil {
                return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
            }

            tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
            if err != nil || tlsConfig == nil {
                return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
            }
            tlsConfig.InsecureSkipVerify = c.InsecureTLS

            if c.TLSMinVersion != "" {
                var ok bool
                tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
                if !ok {
                    return nil, fmt.Errorf("invalid 'tls_min_version' in config")
                }
            } else {
                // MinVersion was not being set earlier. Reset it to
                // zero to gracefully handle upgrades.
                tlsConfig.MinVersion = 0
            }
        }

        clusterConfig.SslOpts = &gocql.SslOptions{
            Config: tlsConfig,
        sslOpts, err := getSslOpts(c.certBundle, c.TLSMinVersion, c.InsecureTLS)
        if err != nil {
            return nil, err
        }
        clusterConfig.SslOpts = sslOpts
    }

    if c.LocalDatacenter != "" {

@@ -269,6 +230,48 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
    return session, nil
}

func getSslOpts(certBundle *certutil.CertBundle, minTLSVersion string, insecureSkipVerify bool) (*gocql.SslOptions, error) {
    tlsConfig := &tls.Config{}
    if certBundle != nil {
        if certBundle.Certificate == "" && certBundle.PrivateKey != "" {
            return nil, fmt.Errorf("found private key for TLS authentication but no certificate")
        }
        if certBundle.Certificate != "" && certBundle.PrivateKey == "" {
            return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
        }

        parsedCertBundle, err := certBundle.ToParsedCertBundle()
        if err != nil {
            return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
        }

        tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
        if err != nil {
            return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
        }
    }

    tlsConfig.InsecureSkipVerify = insecureSkipVerify

    if minTLSVersion != "" {
        var ok bool
        tlsConfig.MinVersion, ok = tlsutil.TLSLookup[minTLSVersion]
        if !ok {
            return nil, fmt.Errorf("invalid 'tls_min_version' in config")
        }
    } else {
        // MinVersion was not being set earlier. Reset it to
        // zero to gracefully handle upgrades.
        tlsConfig.MinVersion = 0
    }

    opts := &gocql.SslOptions{
        Config:                 tlsConfig,
        EnableHostVerification: !insecureSkipVerify,
    }
    return opts, nil
}

func (c *cassandraConnectionProducer) secretValues() map[string]string {
    return map[string]string{
        c.Password: "[password]",
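A hedged unit-test sketch for the extracted getSslOpts helper. It is unexported, so this would have to live in the same package; the test name is hypothetical, and "tls12" assumes the usual key in Vault's tlsutil.TLSLookup map:

```go
func TestGetSslOpts_DefaultsToVerification(t *testing.T) {
    // No cert bundle, TLS 1.2 floor, insecure_tls not set.
    opts, err := getSslOpts(nil, "tls12", false)
    if err != nil {
        t.Fatal(err)
    }
    if opts.Config.InsecureSkipVerify {
        t.Fatal("expected certificate verification to be enabled")
    }
    if !opts.EnableHostVerification {
        t.Fatal("expected hostname verification to be enabled")
    }
}
```

Pulling the logic into a pure function like this is what makes the invariant testable at all; the old inline version could only be exercised through a live session.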
@@ -0,0 +1,95 @@
package cassandra

import (
    "context"
    "crypto/tls"
    "testing"
    "time"

    "github.com/gocql/gocql"
    "github.com/hashicorp/vault/helper/testhelpers/cassandra"
    "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
)

var (
    insecureFileMounts = map[string]string{
        "test-fixtures/no_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
    }
    secureFileMounts = map[string]string{
        "test-fixtures/with_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
        "test-fixtures/with_tls/keystore.jks":   "/etc/cassandra/keystore.jks",
        "test-fixtures/with_tls/.cassandra":     "/root/.cassandra/",
    }
)

func TestTLSConnection(t *testing.T) {
    type testCase struct {
        config    map[string]interface{}
        expectErr bool
    }

    tests := map[string]testCase{
        "tls not specified": {
            config:    map[string]interface{}{},
            expectErr: true,
        },
        "unrecognized certificate": {
            config: map[string]interface{}{
                "tls": "true",
            },
            expectErr: true,
        },
        "insecure TLS": {
            config: map[string]interface{}{
                "tls":          "true",
                "insecure_tls": true,
            },
            expectErr: false,
        },
    }

    for name, test := range tests {
        t.Run(name, func(t *testing.T) {
            host, cleanup := cassandra.PrepareTestContainer(t,
                cassandra.Version("3.11.9"),
                cassandra.CopyFromTo(secureFileMounts),
                cassandra.SslOpts(&gocql.SslOptions{
                    Config:                 &tls.Config{InsecureSkipVerify: true},
                    EnableHostVerification: false,
                }),
            )
            defer cleanup()

            // Set values that we don't know until the cassandra container is started
            config := map[string]interface{}{
                "hosts":            host.ConnectionURL(),
                "port":             host.Port,
                "username":         "cassandra",
                "password":         "cassandra",
                "protocol_version": "3",
                "connect_timeout":  "20s",
            }
            // Then add any values specified in the test config. Generally for these tests they shouldn't overlap
            for k, v := range test.config {
                config[k] = v
            }

            db := new()
            initReq := dbplugin.InitializeRequest{
                Config:           config,
                VerifyConnection: true,
            }

            ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
            defer cancel()

            _, err := db.Initialize(ctx, initReq)
            if test.expectErr && err == nil {
                t.Fatalf("err expected, got nil")
            }
            if !test.expectErr && err != nil {
                t.Fatalf("no error expected, got: %s", err)
            }
        })
    }
}
@@ -0,0 +1,3 @@
[ssl]
validate = false
version = SSLv23

File diff suppressed because it is too large
@@ -0,0 +1,46 @@
#!/bin/sh

################################################################
# Usage: ./gencert.sh
#
# Generates a keystore.jks file that can be used with a
# Cassandra server for TLS connections. This does not update
# a cassandra config file.
################################################################

set -e

KEYFILE="key.pem"
CERTFILE="cert.pem"
PKCSFILE="keystore.p12"
JKSFILE="keystore.jks"

HOST="127.0.0.1"
NAME="cassandra"
ALIAS="cassandra"
PASSWORD="cassandra"

echo "# Generating certificate keypair..."
go run /usr/local/go/src/crypto/tls/generate_cert.go --host=${HOST}

echo "# Creating keystore..."
openssl pkcs12 -export -in ${CERTFILE} -inkey ${KEYFILE} -name ${NAME} -password pass:${PASSWORD} > ${PKCSFILE}

echo "# Creating Java key store"
if [ -e "${JKSFILE}" ]; then
    echo "# Removing old key store"
    rm ${JKSFILE}
fi

set +e
keytool -importkeystore \
    -srckeystore ${PKCSFILE} \
    -srcstoretype PKCS12 \
    -srcstorepass ${PASSWORD} \
    -destkeystore ${JKSFILE} \
    -deststorepass ${PASSWORD} \
    -destkeypass ${PASSWORD} \
    -alias ${ALIAS}

echo "# Removing intermediate files"
rm ${KEYFILE} ${CERTFILE} ${PKCSFILE}
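The script assumes the Go toolchain source tree is installed at /usr/local/go. If it is not, the generate_cert.go step can be replaced with a few lines of the standard library; a self-contained sketch that writes an equivalent key.pem/cert.pem pair for 127.0.0.1 (ECDSA here, whereas generate_cert.go's defaults may differ):

```go
package main

import (
    "crypto/ecdsa"
    "crypto/elliptic"
    "crypto/rand"
    "crypto/x509"
    "crypto/x509/pkix"
    "encoding/pem"
    "log"
    "math/big"
    "net"
    "os"
    "time"
)

func writePEM(path, blockType string, der []byte) {
    f, err := os.Create(path)
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()
    if err := pem.Encode(f, &pem.Block{Type: blockType, Bytes: der}); err != nil {
        log.Fatal(err)
    }
}

func main() {
    key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    if err != nil {
        log.Fatal(err)
    }
    // Self-signed server certificate for 127.0.0.1, valid for one year.
    tmpl := x509.Certificate{
        SerialNumber: big.NewInt(1),
        Subject:      pkix.Name{CommonName: "cassandra"},
        NotBefore:    time.Now(),
        NotAfter:     time.Now().Add(365 * 24 * time.Hour),
        KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
        ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
        IPAddresses:  []net.IP{net.ParseIP("127.0.0.1")},
    }
    der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key)
    if err != nil {
        log.Fatal(err)
    }
    keyDER, err := x509.MarshalECPrivateKey(key)
    if err != nil {
        log.Fatal(err)
    }
    writePEM("cert.pem", "CERTIFICATE", der)
    writePEM("key.pem", "EC PRIVATE KEY", keyDER)
}
```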
BIN plugins/database/cassandra/test-fixtures/with_tls/keystore.jks (Stored with Git LFS). Binary file not shown.
@@ -31,8 +31,10 @@ env:
      AUTH=false

go:
  - 1.13.x
  - 1.14.x
  - 1.15.x
  - 1.16.x

go_import_path: github.com/gocql/gocql

install:
  - ./install_test_deps.sh $TRAVIS_REPO_SLUG
@@ -115,3 +115,8 @@ Pavel Buchinchik <p.buchinchik@gmail.com>
Rintaro Okamura <rintaro.okamura@gmail.com>
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
Jorge Bay <jorgebg@apache.org>
Dmitriy Kozlov <hummerd@mail.ru>
Alexey Romanovsky <alexus1024+gocql@gmail.com>
Jaume Marhuenda Beltran <jaumemarhuenda@gmail.com>
Piotr Dulikowski <piodul@scylladb.com>
Árni Dagur <arni@dagur.eu>
@@ -19,8 +19,8 @@ The following matrix shows the versions of Go and Cassandra that are tested with

Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
-------------| -------| ------| ---------
1.13         | yes    | yes   | yes
1.14         | yes    | yes   | yes
1.15         | yes    | yes   | yes
1.16         | yes    | yes   | yes

Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.

@@ -114,73 +114,7 @@ statement.
Example
-------

```go
/* Before you execute the program, Launch `cqlsh` and execute:
create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id));
create index on example.tweet(timeline);
*/
package main

import (
    "fmt"
    "log"

    "github.com/gocql/gocql"
)

func main() {
    // connect to the cluster
    cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
    cluster.Keyspace = "example"
    cluster.Consistency = gocql.Quorum
    session, _ := cluster.CreateSession()
    defer session.Close()

    // insert a tweet
    if err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
        "me", gocql.TimeUUID(), "hello world").Exec(); err != nil {
        log.Fatal(err)
    }

    var id gocql.UUID
    var text string

    /* Search for a specific set of records whose 'timeline' column matches
     * the value 'me'. The secondary index that we created earlier will be
     * used for optimizing the search */
    if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
        "me").Consistency(gocql.One).Scan(&id, &text); err != nil {
        log.Fatal(err)
    }
    fmt.Println("Tweet:", id, text)

    // list all tweets
    iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").Iter()
    for iter.Scan(&id, &text) {
        fmt.Println("Tweet:", id, text)
    }
    if err := iter.Close(); err != nil {
        log.Fatal(err)
    }
}
```


Authentication
-------

```go
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Authenticator = gocql.PasswordAuthenticator{
    Username: "user",
    Password: "password"
}
cluster.Keyspace = "example"
cluster.Consistency = gocql.Quorum
session, _ := cluster.CreateSession()
defer session.Close()
```
See [package documentation](https://pkg.go.dev/github.com/gocql/gocql#pkg-examples).

Data Binding
------------
@@ -44,8 +44,8 @@ func approve(authenticator string) bool {
    return false
}

//JoinHostPort is a utility to return a address string that can be used
//gocql.Conn to form a connection with a host.
// JoinHostPort is a utility to return an address string that can be used
// by `gocql.Conn` to form a connection with a host.
func JoinHostPort(addr string, port int) string {
    addr = strings.TrimSpace(addr)
    if _, _, err := net.SplitHostPort(addr); err != nil {

@@ -80,6 +80,19 @@ func (p PasswordAuthenticator) Success(data []byte) error {
    return nil
}

// SslOptions configures TLS use.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil             | false                  | do not verify host
// Config is nil             | true                   | verify host
// false                     | false                  | verify host
// true                      | false                  | do not verify host
// false                     | true                   | verify host
// true                      | true                   | verify host
type SslOptions struct {
    *tls.Config

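Reading the new table: a caller that supplies a tls.Config keeps full control, and verification follows InsecureSkipVerify; only the legacy nil-Config path depends on EnableHostVerification. A minimal sketch of the two secure configurations the comment describes (cluster address is a placeholder):

```go
cluster := gocql.NewCluster("cassandra.example.com")

// Preferred: a non-nil Config with the zero value of
// InsecureSkipVerify means the host is verified.
cluster.SslOpts = &gocql.SslOptions{
    Config: &tls.Config{},
}

// Legacy style with no Config: verification must be opted
// into explicitly, since SslOptions is insecure by default.
cluster.SslOpts = &gocql.SslOptions{EnableHostVerification: true}
```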
@@ -89,9 +102,12 @@ type SslOptions struct {
    CertPath string
    KeyPath  string
    CaPath   string //optional depending on server config
    // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
    // This option is basically the inverse of InSecureSkipVerify
    // See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info
    // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this
    // on.
    // This option is basically the inverse of tls.Config.InsecureSkipVerify.
    // See InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info.
    //
    // See SslOptions documentation to see how EnableHostVerification interacts with the provided tls.Config.
    EnableHostVerification bool
}

@@ -125,7 +141,7 @@ func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
// which may be serving more queries just fine.
// Default is 0, should not be changed concurrently with queries.
//
// depreciated
// Deprecated.
var TimeoutLimit int64 = 0

// Conn is a single connection to a Cassandra node. It can be used to execute

@@ -213,14 +229,26 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
        dialer = d
    }

    conn, err := dialer.DialContext(ctx, "tcp", host.HostnameAndPort())
    addr := host.HostnameAndPort()
    conn, err := dialer.DialContext(ctx, "tcp", addr)
    if err != nil {
        return nil, err
    }
    if cfg.tlsConfig != nil {
        // the TLS config is safe to be reused by connections but it must not
        // be modified after being used.
        tconn := tls.Client(conn, cfg.tlsConfig)
        tlsConfig := cfg.tlsConfig
        if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
            colonPos := strings.LastIndex(addr, ":")
            if colonPos == -1 {
                colonPos = len(addr)
            }
            hostname := addr[:colonPos]
            // clone config to avoid modifying the shared one.
            tlsConfig = tlsConfig.Clone()
            tlsConfig.ServerName = hostname
        }
        tconn := tls.Client(conn, tlsConfig)
        if err := tconn.Handshake(); err != nil {
            conn.Close()
            return nil, err
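This dial change is the upstream half of the hostname fix: crypto/tls only verifies the peer certificate against a name when tls.Config.ServerName is set (or the caller supplies its own verification), so a shared config with an empty ServerName must be cloned and filled in per connection. A standalone sketch of the same derivation outside gocql (function and variable names are mine):

```go
package main

import (
    "crypto/tls"
    "fmt"
    "strings"
)

// serverNameFor returns a config whose ServerName matches the address being
// dialed, cloning first so the shared config is never mutated.
func serverNameFor(addr string, cfg *tls.Config) *tls.Config {
    if cfg.InsecureSkipVerify || cfg.ServerName != "" {
        return cfg // nothing to derive
    }
    hostname := addr
    if i := strings.LastIndex(addr, ":"); i != -1 {
        hostname = addr[:i] // strip the port, keep the host
    }
    out := cfg.Clone()
    out.ServerName = hostname
    return out
}

func main() {
    cfg := &tls.Config{}
    fmt.Println(serverNameFor("cassandra.example.com:9042", cfg).ServerName)
    fmt.Println(cfg.ServerName == "") // original left untouched: true
}
```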
@@ -845,6 +873,10 @@ func (w *writeCoalescer) writeFlusher(interval time.Duration) {
}

func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
    if ctxErr := ctx.Err(); ctxErr != nil {
        return nil, ctxErr
    }

    // TODO: move tracer onto conn
    stream, ok := c.streams.GetStream()
    if !ok {
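The added guard is the standard fail-fast context check: report cancellation before acquiring a protocol stream that would otherwise be consumed for nothing. The pattern in isolation (names are illustrative):

```go
package main

import (
    "context"
    "errors"
    "fmt"
)

// execSketch mirrors the guard added above: a cancelled or expired
// context fails immediately, before any resources are claimed.
func execSketch(ctx context.Context) error {
    if err := ctx.Err(); err != nil {
        return err
    }
    // ... acquire stream and write the frame ...
    return nil
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    cancel()
    fmt.Println(errors.Is(execSketch(ctx), context.Canceled)) // true
}
```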
@@ -1173,12 +1205,16 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
    }

    if x.meta.morePages() && !qry.disableAutoPage {
        newQry := new(Query)
        *newQry = *qry
        newQry.pageState = copyBytes(x.meta.pagingState)
        newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}

        iter.next = &nextIter{
            qry: qry,
            qry: newQry,
            pos: int((1 - qry.prefetch) * float64(x.numRows)),
        }

        iter.next.qry.pageState = copyBytes(x.meta.pagingState)
        if iter.next.pos < 1 {
            iter.next.pos = 1
        }
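The paging change above removes a reuse hazard: the old code stored the caller's *Query in nextIter and then rewrote its page state through that pointer, so a Query held by the caller could be mutated while prefetch was running. The core pattern, reduced to a standalone example (the types here are illustrative, not gocql's):

```go
package main

import "fmt"

type query struct {
    stmt      string
    pageState []byte
}

// nextPage returns a copy of q pointed at the next page, leaving q
// untouched so holders of q never observe the mutation.
func nextPage(q *query, state []byte) *query {
    nq := new(query)
    *nq = *q // shallow copy of all fields
    nq.pageState = append([]byte(nil), state...)
    return nq
}

func main() {
    q := &query{stmt: "SELECT ..."}
    next := nextPage(q, []byte{0x01})
    fmt.Println(len(next.pageState), len(q.pageState)) // 1 0
}
```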
@ -1359,10 +1395,11 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
|
|||
}
|
||||
|
||||
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
|
||||
q := c.session.Query(statement, values...).Consistency(One)
|
||||
q.trace = nil
|
||||
q := c.session.Query(statement, values...).Consistency(One).Trace(nil)
|
||||
q.skipPrepare = true
|
||||
q.disableSkipMetadata = true
|
||||
// we want to keep the query on this connection
|
||||
q.conn = c
|
||||
return c.executeQuery(ctx, q)
|
||||
}
|
||||
|
||||
|
|
|
@@ -28,14 +28,31 @@ type SetPartitioner interface {
}

func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
    // Config.InsecureSkipVerify | EnableHostVerification | Result
    // Config is nil             | true                   | verify host
    // Config is nil             | false                  | do not verify host
    // false                     | false                  | verify host
    // true                      | false                  | do not verify host
    // false                     | true                   | verify host
    // true                      | true                   | verify host
    var tlsConfig *tls.Config
    if sslOpts.Config == nil {
        sslOpts.Config = &tls.Config{}
        tlsConfig = &tls.Config{
            InsecureSkipVerify: !sslOpts.EnableHostVerification,
        }
    } else {
        // use clone to avoid race.
        tlsConfig = sslOpts.Config.Clone()
    }

    if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
        tlsConfig.InsecureSkipVerify = false
    }

    // ca cert is optional
    if sslOpts.CaPath != "" {
        if sslOpts.RootCAs == nil {
            sslOpts.RootCAs = x509.NewCertPool()
        if tlsConfig.RootCAs == nil {
            tlsConfig.RootCAs = x509.NewCertPool()
        }

        pem, err := ioutil.ReadFile(sslOpts.CaPath)
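Why the clone-first rewrite matters: tls.Config travels by pointer, so the old code's in-place writes to sslOpts.Config leaked into every session sharing the same SslOptions and could race with in-flight handshakes. A standalone demonstration of the sharing semantics:

```go
package main

import (
    "crypto/tls"
    "fmt"
)

func main() {
    shared := &tls.Config{}

    // Clone gives an independent copy; writes stay local.
    clone := shared.Clone()
    clone.InsecureSkipVerify = true

    fmt.Println(shared.InsecureSkipVerify, clone.InsecureSkipVerify) // false true
}
```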
@@ -43,7 +60,7 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
            return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
        }

        if !sslOpts.RootCAs.AppendCertsFromPEM(pem) {
        if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
            return nil, errors.New("connectionpool: failed parsing or CA certs")
        }
    }

@@ -53,13 +70,10 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
        if err != nil {
            return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
        }
        sslOpts.Certificates = append(sslOpts.Certificates, mycert)
        tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
    }

    sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification

    // return clone to avoid race
    return sslOpts.Config.Clone(), nil
    return tlsConfig, nil
}

type policyConnPool struct {

@@ -238,12 +252,6 @@ func (p *policyConnPool) removeHost(ip net.IP) {
    go pool.Close()
}

func (p *policyConnPool) hostUp(host *HostInfo) {
    // TODO(zariel): have a set of up hosts and down hosts, we can internally
    // detect down hosts, then try to reconnect to them.
    p.addHost(host)
}

func (p *policyConnPool) hostDown(ip net.IP) {
    // TODO(zariel): mark host as down so we can try to connect to it later, for
    // now just treat it has removed.

@@ -429,6 +437,8 @@ func (pool *hostConnPool) fill() {
            }
            return
        }
        // notify the session that this node is connected
        go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)

        // filled one
        fillCount--

@@ -440,6 +450,11 @@ func (pool *hostConnPool) fill() {

        // mark the end of filling
        pool.fillingStopped(err != nil)

        if err == nil && startCount > 0 {
            // notify the session that this node is connected again
            go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)
        }
    }()
}

@@ -125,7 +125,7 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
    if err != nil {
        return nil, err
    } else if len(ips) == 0 {
        return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
        return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr)
    }

    // Filter to v4 addresses if any present

@@ -177,7 +177,7 @@ func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
        return conn, nil
    }

    Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err)
    Logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
}

return nil, err

@@ -285,8 +285,6 @@ func (c *controlConn) setupConn(conn *Conn) error {
    }

    c.conn.Store(ch)
    c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)

    return nil
}

@@ -452,6 +450,8 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter

    for {
        iter = c.withConn(func(conn *Conn) *Iter {
            // we want to keep the query on the control connection
            q.conn = conn
            return conn.executeQuery(context.TODO(), q)
        })

@@ -4,6 +4,319 @@

// Package gocql implements a fast and robust Cassandra driver for the
// Go programming language.
//
// Connecting to the cluster
//
// Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
//
//    cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
//
// Port can be specified as part of the address, the above is equivalent to:
//
//    cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
//
// It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
// an IP address not a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address
// then the driver may connect multiple times to the same host, and will not mark the node being down or up from events.
//
// Then you can customize more options (see ClusterConfig):
//
//    cluster.Keyspace = "example"
//    cluster.Consistency = gocql.Quorum
//    cluster.ProtoVersion = 4
//
// The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
// protocol version explicitly, as it's not defined which version will be used in certain situations (for example
// during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
//
// When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
//
//    session, err := cluster.CreateSession()
//    if err != nil {
//        return err
//    }
//    defer session.Close()
//
// Authentication
//
// CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
// client response pairs. The details of the exchanged messages depend on the authenticator used.
//
// To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider.
//
// PasswordAuthenticator is provided to use for username/password authentication:
//
//    cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
//    cluster.Authenticator = gocql.PasswordAuthenticator{
//        Username: "user",
//        Password: "password"
//    }
//    session, err := cluster.CreateSession()
//    if err != nil {
//        return err
//    }
//    defer session.Close()
//
// Transport layer security
//
// It is possible to secure traffic between the client and server with TLS.
//
// To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly.
// There are also helpers to load keys/certificates from files.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
//    Config.InsecureSkipVerify | EnableHostVerification | Result
//    Config is nil             | false                  | do not verify host
//    Config is nil             | true                   | verify host
//    false                     | false                  | verify host
//    true                      | false                  | do not verify host
//    false                     | true                   | verify host
//    true                      | true                   | verify host
//
// For example:
//
//    cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
//    cluster.SslOpts = &gocql.SslOptions{
//        EnableHostVerification: true,
//    }
//    session, err := cluster.CreateSession()
//    if err != nil {
//        return err
//    }
//    defer session.Close()
//
// Executing queries
//
// Create queries with Session.Query. Query values must not be reused between different executions and must not be
// modified after starting execution of the query.
//
// To execute a query without reading results, use Query.Exec:
//
//    err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
//        "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
//
// Single row can be read by calling Query.Scan:
//
//    err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
//        "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
//
// Multiple rows can be read using Iter.Scanner:
//
//    scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
//        "me").WithContext(ctx).Iter().Scanner()
//    for scanner.Next() {
//        var (
//            id   gocql.UUID
//            text string
//        )
//        err = scanner.Scan(&id, &text)
//        if err != nil {
//            log.Fatal(err)
//        }
//        fmt.Println("Tweet:", id, text)
//    }
//    // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
//    if err := scanner.Err(); err != nil {
//        log.Fatal(err)
//    }
//
// See Example for complete example.
//
// Prepared statements
//
// The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
// of prepared statements.
// CQL protocol does not support preparing other query types.
//
// When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column.
// This will cause the database to ignore writing the column.
// The main advantage is the ability to keep the same prepared statement even when you don't
// want to update some fields, where before you needed to make another prepared statement.
//
// Executing multiple queries concurrently
//
// Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
// from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries
// are executed asynchronously at the protocol level.
//
//    results := make(chan error, 2)
//    go func() {
//        results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
//            "me", gocql.TimeUUID(), "hello world 1").Exec()
//    }()
//    go func() {
//        results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
//            "me", gocql.TimeUUID(), "hello world 2").Exec()
//    }()
//
// Nulls
//
// Null values are are unmarshalled as zero value of the type. If you need to distinguish for example between text
// column being null and empty string, you can unmarshal into *string variable instead of string.
//
//    var text *string
//    err := scanner.Scan(&text)
//    if err != nil {
//        // handle error
//    }
//    if text != nil {
//        // not null
//    }
//    else {
//        // null
//    }
//
// See Example_nulls for full example.
//
// Reusing slices
//
// The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not
// need to be allocated for every processed row. However, you need to be careful when storing the slices to other
// memory structures.
//
//    scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
//    var myInts []int
//    for scanner.Next() {
//        // This scan reuses backing store of myInts for each row.
//        err = scanner.Scan(&myInts)
//        if err != nil {
//            log.Fatal(err)
//        }
//    }
//
// When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the
// slice variable within the scanner loop:
//
//    scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
//    for scanner.Next() {
//        var myInts []int
//        // This scan always gets pointer to fresh myInts slice, so does not reuse memory.
//        err = scanner.Scan(&myInts)
//        if err != nil {
//            log.Fatal(err)
//        }
//    }
//
// Paging
//
// The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch,
// Query.PageSize, and Query.Prefetch.
//
// It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch).
// Manual paging is useful if you want to store the page state externally, for example in a URL to allow users
// browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since
// it contains data from primary keys.
//
// Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that
// should not be modified. If you send paging state from different query or protocol version, then the behaviour
// is not defined (you might get unexpected results or an error from the server). For example, do not send paging state
// returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4,
// paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880).
//
// The driver does not check whether the paging state is from the same protocol version/statement.
// You might want to validate yourself as this could be a problem if you store paging state externally.
// For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster.
//
// Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by
// Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned
// by Iter.PageState is zero, there are no more pages available (or an error occurred).
//
// Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low.
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors
// explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't
// rely on the page having the exact count of items.
//
// See Example_paging for an example of manual paging.
//
// Dynamic list of columns
//
// There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied
// by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case.
//
// See Example_dynamicColumns.
//
// Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
// overhead to ensure this property.
// Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity.
// Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type.
// A counter batch can only contain statements to update counters.
//
// For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should
// involve only a single partition).
// Multi-partition batch needs to be split by the coordinator node and re-sent to
// correct nodes.
// With single-partition batches you can send the batch directly to the node for the partition without incurring the
// additional network hop.
//
// It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec.
// There are differences how those are executed.
// BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement.
// Session.ExecuteBatch prepares individual statements in the batch.
// If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient.
//
// See Example_batch for an example.
//
// Lightweight transactions
//
// Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an
// INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS.
//
// Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional
// statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and
// Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for
// Session.MapExecuteBatchCAS.
//
// Retries and speculative execution
//
// Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed
// multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative
// execution.
//
// Idempotent queries are retried in case of errors based on the configured RetryPolicy.
//
// Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can
// cause the driver to retry on a different node if the query is taking longer than a specified delay even before the
// driver receives an error or timeout from the server. When a query is speculatively executed, the original execution
// is still executing. The two parallel executions of the query race to return a result, the first received result will
// be returned.
//
// User-defined types
//
// UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing
// UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces).
//
// For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field:
//
//    type MyUDT struct {
//        FieldA int32 `cql:"a"`
//        FieldB string `cql:"b"`
//    }
//
// See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler.
//
// Metrics and tracing
//
// It is possible to provide observer implementations that could be used to gather metrics:
//
// - QueryObserver for monitoring individual queries.
// - BatchObserver for monitoring batch queries.
// - ConnectObserver for monitoring new connections from the driver to the database.
// - FrameHeaderObserver for monitoring individual protocol frames.
//
// CQL protocol also supports tracing of queries. When enabled, the database will write information about
// internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive
// the session ID that the database used to store the trace information in system_traces.sessions and
// system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer.
// Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead,
// so this feature should not be used on production systems with very high load unless you know what you are doing.
package gocql // import "github.com/gocql/gocql"

// TODO(tux21b): write more docs.
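The paging section of the new package docs describes the Iter.PageState/Query.PageState handshake without a code sample; a hedged sketch of one manual page fetch, borrowing the tweet table from the examples earlier in the comment (assumes the gocql and fmt imports, and a helper name of my choosing):

```go
// fetchPage prints one page of results and returns the state for the next
// page; a zero-length return means there are no more pages.
func fetchPage(session *gocql.Session, pageState []byte) ([]byte, error) {
    iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").
        PageSize(100).
        PageState(pageState). // nil fetches the first page
        Iter()

    var (
        id   gocql.UUID
        text string
    )
    for iter.Scan(&id, &text) {
        fmt.Println("Tweet:", id, text)
    }

    next := iter.PageState()
    if err := iter.Close(); err != nil {
        return nil, err
    }
    return next, nil
}
```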
@ -164,55 +164,43 @@ func (s *Session) handleNodeEvent(frames []frame) {

switch f.change {
case "NEW_NODE":
s.handleNewNode(f.host, f.port, true)
s.handleNewNode(f.host, f.port)
case "REMOVED_NODE":
s.handleRemovedNode(f.host, f.port)
case "MOVED_NODE":
// java-driver handles this, not mentioned in the spec
// TODO(zariel): refresh token map
case "UP":
s.handleNodeUp(f.host, f.port, true)
s.handleNodeUp(f.host, f.port)
case "DOWN":
s.handleNodeDown(f.host, f.port)
}
}
}

func (s *Session) addNewNode(host *HostInfo) {
if s.cfg.filterHost(host) {
return
}

host.setState(NodeUp)
s.pool.addHost(host)
s.policy.AddHost(host)
}

func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}

ip, port = s.cfg.translateAddressPort(ip, port)

func (s *Session) addNewNode(ip net.IP, port int) {
// Get host info and apply any filters to the host
hostInfo, err := s.hostSource.getHostInfo(ip, port)
if err != nil {
Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
return
} else if hostInfo == nil {
// If hostInfo is nil, this host was filtered out by cfg.HostFilter
// ignore if it's null because we couldn't find it
return
}

if t := hostInfo.Version().nodeUpDelay(); t > 0 && waitForBinary {
if t := hostInfo.Version().nodeUpDelay(); t > 0 {
time.Sleep(t)
}

// should this handle token moving?
hostInfo = s.ring.addOrUpdate(hostInfo)

s.addNewNode(hostInfo)
if !s.cfg.filterHost(hostInfo) {
// we let the pool call handleNodeUp to change the host state
s.pool.addHost(hostInfo)
s.policy.AddHost(hostInfo)
}

if s.control != nil && !s.cfg.IgnorePeerAddr {
// TODO(zariel): debounce ring refresh
@ -220,6 +208,22 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
}
}

func (s *Session) handleNewNode(ip net.IP, port int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}

ip, port = s.cfg.translateAddressPort(ip, port)

// if we already have the host and it's already up, then do nothing
host := s.ring.getHost(ip)
if host != nil && host.IsUp() {
return
}

s.addNewNode(ip, port)
}

func (s *Session) handleRemovedNode(ip net.IP, port int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
@ -232,45 +236,37 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
if host == nil {
host = &HostInfo{connectAddress: ip, port: port}
}

if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
s.ring.removeHost(ip)

host.setState(NodeDown)
s.policy.RemoveHost(host)
s.pool.removeHost(ip)
s.ring.removeHost(ip)
if !s.cfg.filterHost(host) {
s.policy.RemoveHost(host)
s.pool.removeHost(ip)
}

if !s.cfg.IgnorePeerAddr {
s.hostSource.refreshRing()
}
}

func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) {
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
}

ip, _ := s.cfg.translateAddressPort(eventIp, eventPort)
ip, port := s.cfg.translateAddressPort(eventIp, eventPort)

host := s.ring.getHost(ip)
if host == nil {
// TODO(zariel): avoid the need to translate twice in this
// case
s.handleNewNode(eventIp, eventPort, waitForBinary)
s.addNewNode(ip, port)
return
}

if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
host.setState(NodeUp)

if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
time.Sleep(t)
if !s.cfg.filterHost(host) {
s.policy.HostUp(host)
}

s.addNewNode(host)
}

func (s *Session) handleNodeDown(ip net.IP, port int) {
@ -283,11 +279,11 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
host = &HostInfo{connectAddress: ip, port: port}
}

if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
host.setState(NodeDown)
if s.cfg.filterHost(host) {
return
}

host.setState(NodeDown)
s.policy.HostDown(host)
s.pool.hostDown(ip)
}
@ -311,26 +311,10 @@ var (

const maxFrameHeaderSize = 9

func writeInt(p []byte, n int32) {
p[0] = byte(n >> 24)
p[1] = byte(n >> 16)
p[2] = byte(n >> 8)
p[3] = byte(n)
}

func readInt(p []byte) int32 {
return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3])
}

func writeShort(p []byte, n uint16) {
p[0] = byte(n >> 8)
p[1] = byte(n)
}

func readShort(p []byte) uint16 {
return uint16(p[0])<<8 | uint16(p[1])
}

type frameHeader struct {
version protoVersion
flags byte
@ -854,7 +838,7 @@ func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error {
if f.proto > protoVersion4 {
flags |= flagWithPreparedKeyspace
} else {
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
}
}
if f.proto > protoVersion4 {

@ -1502,7 +1486,7 @@ func (f *framer) writeQueryParams(opts *queryParams) {
if f.proto > protoVersion4 {
flags |= flagWithKeyspace
} else {
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
}
}
@ -1792,16 +1776,6 @@ func (f *framer) readShort() (n uint16) {
return
}

func (f *framer) readLong() (n int64) {
if len(f.rbuf) < 8 {
panic(fmt.Errorf("not enough bytes in buffer to read long require 8 got: %d", len(f.rbuf)))
}
n = int64(f.rbuf[0])<<56 | int64(f.rbuf[1])<<48 | int64(f.rbuf[2])<<40 | int64(f.rbuf[3])<<32 |
int64(f.rbuf[4])<<24 | int64(f.rbuf[5])<<16 | int64(f.rbuf[6])<<8 | int64(f.rbuf[7])
f.rbuf = f.rbuf[8:]
return
}

func (f *framer) readString() (s string) {
size := f.readShort()
@ -1915,19 +1889,6 @@ func (f *framer) readConsistency() Consistency {
return Consistency(f.readShort())
}

func (f *framer) readStringMap() map[string]string {
size := f.readShort()
m := make(map[string]string, size)

for i := 0; i < int(size); i++ {
k := f.readString()
v := f.readString()
m[k] = v
}

return m
}

func (f *framer) readBytesMap() map[string][]byte {
size := f.readShort()
m := make(map[string][]byte, size)
@ -2037,10 +1998,6 @@ func (f *framer) writeLongString(s string) {
f.wbuf = append(f.wbuf, s...)
}

func (f *framer) writeUUID(u *UUID) {
f.wbuf = append(f.wbuf, u[:]...)
}

func (f *framer) writeStringList(l []string) {
f.writeShort(uint16(len(l)))
for _, s := range l {

@ -2073,18 +2030,6 @@ func (f *framer) writeShortBytes(p []byte) {
f.wbuf = append(f.wbuf, p...)
}

func (f *framer) writeInet(ip net.IP, port int) {
f.wbuf = append(f.wbuf,
byte(len(ip)),
)

f.wbuf = append(f.wbuf,
[]byte(ip)...,
)

f.writeInt(int32(port))
}

func (f *framer) writeConsistency(cons Consistency) {
f.writeShort(uint16(cons))
}
@ -270,15 +270,6 @@ func getApacheCassandraType(class string) Type {
}
}

func typeCanBeNull(typ TypeInfo) bool {
switch typ.(type) {
case CollectionType, UDTTypeInfo, TupleTypeInfo:
return false
}

return true
}

func (r *RowData) rowMap(m map[string]interface{}) {
for i, column := range r.Columns {
val := dereference(r.Values[i])

@ -372,7 +363,7 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
// iter := session.Query(`SELECT * FROM mytable`).Iter()
// for {
// // New map each iteration
// row = make(map[string]interface{})
// row := make(map[string]interface{})
// if !iter.MapScan(row) {
// break
// }
@ -147,13 +147,6 @@ func (h *HostInfo) Peer() net.IP {
return h.peer
}

func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.peer = peer
return h
}

func (h *HostInfo) invalidConnectAddr() bool {
h.mu.RLock()
defer h.mu.RUnlock()

@ -233,13 +226,6 @@ func (h *HostInfo) DataCenter() string {
return dc
}

func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dataCenter
return h
}

func (h *HostInfo) Rack() string {
h.mu.RLock()
rack := h.rack

@ -247,26 +233,12 @@ func (h *HostInfo) Rack() string {
return rack
}

func (h *HostInfo) setRack(rack string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.rack = rack
return h
}

func (h *HostInfo) HostID() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.hostId
}

func (h *HostInfo) setHostID(hostID string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.hostId = hostID
return h
}

func (h *HostInfo) WorkLoad() string {
h.mu.RLock()
defer h.mu.RUnlock()

@ -303,13 +275,6 @@ func (h *HostInfo) Version() cassVersion {
return h.version
}

func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.version = cassVersion{major, minor, patch}
return h
}

func (h *HostInfo) State() nodeState {
h.mu.RLock()
defer h.mu.RUnlock()

@ -329,26 +294,12 @@ func (h *HostInfo) Tokens() []string {
return h.tokens
}

func (h *HostInfo) setTokens(tokens []string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.tokens = tokens
return h
}

func (h *HostInfo) Port() int {
h.mu.RLock()
defer h.mu.RUnlock()
return h.port
}

func (h *HostInfo) setPort(port int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.port = port
return h
}

func (h *HostInfo) update(from *HostInfo) {
if h == from {
return

@ -689,7 +640,7 @@ func (r *ringDescriber) refreshRing() error {

// TODO: move this to session
for _, h := range hosts {
if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
if r.session.cfg.filterHost(h) {
continue
}
@ -8,9 +8,3 @@ git clone https://github.com/pcmanus/ccm.git
pushd ccm
./setup.py install --user
popd

if [ "$1" != "gocql/gocql" ]; then
USER=$(echo $1 | cut -f1 -d'/')
cd ../..
mv ${USER} gocql
fi
@ -44,6 +44,52 @@ type Unmarshaler interface {

// Marshal returns the CQL encoding of the value for the Cassandra
// internal type described by the info parameter.
//
// nil is serialized as CQL null.
// If value implements Marshaler, its MarshalCQL method is called to marshal the data.
// If value is a pointer, the pointed-to value is marshaled.
//
// Supported conversions are as follows, other type combinations may be added in the future:
//
// CQL type | Go type (value) | Note
// varchar, ascii, blob, text | string, []byte |
// boolean | bool |
// tinyint, smallint, int | integer types |
// tinyint, smallint, int | string | formatted as base 10 number
// bigint, counter | integer types |
// bigint, counter | big.Int |
// bigint, counter | string | formatted as base 10 number
// float | float32 |
// double | float64 |
// decimal | inf.Dec |
// time | int64 | nanoseconds since start of day
// time | time.Duration | duration since start of day
// timestamp | int64 | milliseconds since Unix epoch
// timestamp | time.Time |
// list, set | slice, array |
// list, set | map[X]struct{} |
// map | map[X]Y |
// uuid, timeuuid | gocql.UUID |
// uuid, timeuuid | [16]byte | raw UUID bytes
// uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes
// uuid, timeuuid | string | hex representation, see ParseUUID
// varint | integer types |
// varint | big.Int |
// varint | string | value of number in decimal notation
// inet | net.IP |
// inet | string | IPv4 or IPv6 address string
// tuple | slice, array |
// tuple | struct | fields are marshaled in order of declaration
// user-defined type | gocql.UDTMarshaler | MarshalUDT is called
// user-defined type | map[string]interface{} |
// user-defined type | struct | struct fields' cql tags are used for column names
// date | int64 | milliseconds since Unix epoch to start of day (in UTC)
// date | time.Time | start of day (in UTC)
// date | string | parsed using "2006-01-02" format
// duration | int64 | duration in nanoseconds
// duration | time.Duration |
// duration | gocql.Duration |
// duration | string | parsed with time.ParseDuration
func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
if info.Version() < protoVersion1 {
panic("protocol version not set")
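A small sketch of round-tripping a value through Marshal and Unmarshal, assuming the exported NewNativeType constructor; protocol version 4 is an arbitrary illustration choice:

```go
// Describe a CQL varchar at protocol v4, encode a Go string into it,
// then decode it back.
info := gocql.NewNativeType(4, gocql.TypeVarchar, "")

b, err := gocql.Marshal(info, "hello")
if err != nil {
	// handle error
}

var s string
if err := gocql.Unmarshal(info, b, &s); err != nil {
	// handle error
}
// s == "hello"
```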
@ -118,6 +164,44 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
// Unmarshal parses the CQL encoded data based on the info parameter that
// describes the Cassandra internal data type and stores the result in the
// value pointed to by value.
//
// If value implements Unmarshaler, its UnmarshalCQL method is called to
// unmarshal the data.
// If value is a pointer to pointer, it is set to nil if the CQL value is
// null. Otherwise, nulls are unmarshalled as the zero value.
//
// Supported conversions are as follows, other type combinations may be added in the future:
//
// CQL type | Go type (value) | Note
// varchar, ascii, blob, text | *string |
// varchar, ascii, blob, text | *[]byte | non-nil buffer is reused
// bool | *bool |
// tinyint, smallint, int, bigint, counter | *integer types |
// tinyint, smallint, int, bigint, counter | *big.Int |
// tinyint, smallint, int, bigint, counter | *string | formatted as base 10 number
// float | *float32 |
// double | *float64 |
// decimal | *inf.Dec |
// time | *int64 | nanoseconds since start of day
// time | *time.Duration |
// timestamp | *int64 | milliseconds since Unix epoch
// timestamp | *time.Time |
// list, set | *slice, *array |
// map | *map[X]Y |
// uuid, timeuuid | *string | see UUID.String
// uuid, timeuuid | *[]byte | raw UUID bytes
// uuid, timeuuid | *gocql.UUID |
// timeuuid | *time.Time | timestamp of the UUID
// inet | *net.IP |
// inet | *string | IPv4 or IPv6 address string
// tuple | *slice, *array |
// tuple | *struct | struct fields are set in order of declaration
// user-defined types | gocql.UDTUnmarshaler | UnmarshalUDT is called
// user-defined types | *map[string]interface{} |
// user-defined types | *struct | cql tag is used to determine field name
// date | *time.Time | time of beginning of the day (in UTC)
// date | *string | formatted with 2006-01-02 format
// duration | *gocql.Duration |
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
if v, ok := value.(Unmarshaler); ok {
return v.UnmarshalCQL(info, data)
@ -1690,6 +1774,8 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
return nil, nil
case UUID:
return val.Bytes(), nil
case [16]byte:
return val[:], nil
case []byte:
if len(val) != 16 {
return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info)
@ -1711,7 +1797,7 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
}

func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
if data == nil || len(data) == 0 {
if len(data) == 0 {
switch v := value.(type) {
case *string:
*v = ""
@ -1726,9 +1812,22 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
return nil
}

if len(data) != 16 {
return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long")
}

switch v := value.(type) {
case *[16]byte:
copy((*v)[:], data)
return nil
case *UUID:
copy((*v)[:], data)
return nil
}

u, err := UUIDFromBytes(data)
if err != nil {
return unmarshalErrorf("Unable to parse UUID: %s", err)
return unmarshalErrorf("unable to parse UUID: %s", err)
}

switch v := value.(type) {

@ -1738,9 +1837,6 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
case *[]byte:
*v = u[:]
return nil
case *UUID:
*v = u
return nil
}
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
}
@ -1942,7 +2038,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
for i, elem := range tuple.Elems {
// each element inside data is a [bytes]
var p []byte
if len(data) > 4 {
if len(data) >= 4 {
p, data = readBytes(data)
}
err := Unmarshal(elem, p, v[i])

@ -1971,7 +2067,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {

for i, elem := range tuple.Elems {
var p []byte
if len(data) > 4 {
if len(data) >= 4 {
p, data = readBytes(data)
}

@ -1982,7 +2078,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {

switch rv.Field(i).Kind() {
case reflect.Ptr:
rv.Field(i).Set(reflect.ValueOf(v))
if p != nil {
rv.Field(i).Set(reflect.ValueOf(v))
} else {
rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v)))
}
default:
rv.Field(i).Set(reflect.ValueOf(v).Elem())
}

@ -2001,7 +2101,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {

for i, elem := range tuple.Elems {
var p []byte
if len(data) > 4 {
if len(data) >= 4 {
p, data = readBytes(data)
}

@ -2012,7 +2112,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {

switch rv.Index(i).Kind() {
case reflect.Ptr:
rv.Index(i).Set(reflect.ValueOf(v))
if p != nil {
rv.Index(i).Set(reflect.ValueOf(v))
} else {
rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v)))
}
default:
rv.Index(i).Set(reflect.ValueOf(v).Elem())
}
@ -2050,7 +2154,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for user defined types")
return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types")
case UDTMarshaler:
var buf []byte
for _, e := range udt.Elements {
@ -324,10 +324,10 @@ func compileMetadata(
keyspace.Functions[functions[i].Name] = &functions[i]
}
keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates))
for _, aggregate := range aggregates {
aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc]
aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc]
keyspace.Aggregates[aggregate.Name] = &aggregate
for i, _ := range aggregates {
aggregates[i].FinalFunc = *keyspace.Functions[aggregates[i].finalFunc]
aggregates[i].StateFunc = *keyspace.Functions[aggregates[i].stateFunc]
keyspace.Aggregates[aggregates[i].Name] = &aggregates[i]
}
keyspace.Views = make(map[string]*ViewMetadata, len(views))
for i := range views {
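The old loop above stored `&aggregate`, the address of the per-iteration range variable, so under pre-Go-1.22 loop semantics every map entry ended up pointing at the same variable (holding the last element); indexing into the slice gives each entry a distinct address. A standalone sketch of the pattern with hypothetical names:

```go
items := []string{"a", "b", "c"}

// Buggy before Go 1.22: v is one reused variable, so &v is the same
// pointer on every iteration and all entries print "c".
bad := map[int]*string{}
for i, v := range items {
	bad[i] = &v
}

// Fixed: take addresses of the slice elements themselves.
good := map[int]*string{}
for i := range items {
	good[i] = &items[i]
}
```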
@ -347,9 +347,9 @@ func compileMetadata(
keyspace.UserTypes[types[i].Name] = &types[i]
}
keyspace.MaterializedViews = make(map[string]*MaterializedViewMetadata, len(materializedViews))
for _, materializedView := range materializedViews {
materializedView.BaseTable = keyspace.Tables[materializedView.baseTableName]
keyspace.MaterializedViews[materializedView.Name] = &materializedView
for i, _ := range materializedViews {
materializedViews[i].BaseTable = keyspace.Tables[materializedViews[i].baseTableName]
keyspace.MaterializedViews[materializedViews[i].Name] = &materializedViews[i]
}

// add columns from the schema data
@ -559,7 +559,7 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
iter.Scan(&keyspace.DurableWrites, &replication)
err := iter.Close()
if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
return nil, fmt.Errorf("error querying keyspace schema: %v", err)
}

keyspace.StrategyClass = replication["class"]

@ -585,13 +585,13 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
err := iter.Close()
if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
return nil, fmt.Errorf("error querying keyspace schema: %v", err)
}

err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions)
if err != nil {
return nil, fmt.Errorf(
"Invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
"invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
strategyOptionsJSON, keyspace.Name, err,
)
}
@ -703,7 +703,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
if err != nil {
iter.Close()
return nil, fmt.Errorf(
"Invalid JSON value '%s' as key_aliases for in table '%s': %v",
"invalid JSON value '%s' as key_aliases for in table '%s': %v",
keyAliasesJSON, table.Name, err,
)
}

@ -716,7 +716,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
if err != nil {
iter.Close()
return nil, fmt.Errorf(
"Invalid JSON value '%s' as column_aliases for in table '%s': %v",
"invalid JSON value '%s' as column_aliases for in table '%s': %v",
columnAliasesJSON, table.Name, err,
)
}

@ -728,7 +728,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e

err := iter.Close()
if err != nil && err != ErrNotFound {
return nil, fmt.Errorf("Error querying table schema: %v", err)
return nil, fmt.Errorf("error querying table schema: %v", err)
}

return tables, nil
@ -777,7 +777,7 @@ func (s *Session) scanColumnMetadataV1(keyspace string) ([]ColumnMetadata, error
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
if err != nil {
return nil, fmt.Errorf(
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
indexOptionsJSON,
column.Name,
column.Table,

@ -837,7 +837,7 @@ func (s *Session) scanColumnMetadataV2(keyspace string) ([]ColumnMetadata, error
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
if err != nil {
return nil, fmt.Errorf(
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
indexOptionsJSON,
column.Name,
column.Table,

@ -915,7 +915,7 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata,
}

if err != nil && err != ErrNotFound {
return nil, fmt.Errorf("Error querying column schema: %v", err)
return nil, fmt.Errorf("error querying column schema: %v", err)
}

return columns, nil
@ -1,9 +1,12 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//This file will be the future home for more policies

package gocql

//This file will be the future home for more policies

import (
"context"
"errors"

@ -37,12 +40,6 @@ func (c *cowHostList) get() []*HostInfo {
return *l
}

func (c *cowHostList) set(list []*HostInfo) {
c.mu.Lock()
c.list.Store(&list)
c.mu.Unlock()
}

// add will add a host if it is not already in the list
func (c *cowHostList) add(host *HostInfo) bool {
c.mu.Lock()
@ -68,33 +65,6 @@ func (c *cowHostList) add(host *HostInfo) bool {
return true
}

func (c *cowHostList) update(host *HostInfo) {
c.mu.Lock()
l := c.get()

if len(l) == 0 {
c.mu.Unlock()
return
}

found := false
newL := make([]*HostInfo, len(l))
for i := range l {
if host.Equal(l[i]) {
newL[i] = host
found = true
} else {
newL[i] = l[i]
}
}

if found {
c.list.Store(&newL)
}

c.mu.Unlock()
}

func (c *cowHostList) remove(ip net.IP) bool {
c.mu.Lock()
l := c.get()

@ -304,7 +274,10 @@ type HostSelectionPolicy interface {
KeyspaceChanged(KeyspaceUpdateEvent)
Init(*Session)
IsLocal(host *HostInfo) bool
//Pick returns an iteration function over selected hosts
// Pick returns an iteration function over selected hosts.
// Multiple attempts of a single query execution won't call the returned NextHost function concurrently,
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
// a different instance of NextHost.
Pick(ExecutableQuery) NextHost
}
@ -880,6 +853,51 @@ func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
}

// ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After
// each host connects during session initialization, the Ready method will be
// called. If you only need a single Host to be up you can wrap a
// HostSelectionPolicy policy with SingleHostReadyPolicy.
type ReadyPolicy interface {
Ready() bool
}

// SingleHostReadyPolicy wraps a HostSelectionPolicy and returns Ready after a
// single host has been added via HostUp
func SingleHostReadyPolicy(p HostSelectionPolicy) *singleHostReadyPolicy {
return &singleHostReadyPolicy{
HostSelectionPolicy: p,
}
}

type singleHostReadyPolicy struct {
HostSelectionPolicy
ready bool
readyMux sync.Mutex
}

func (s *singleHostReadyPolicy) HostUp(host *HostInfo) {
s.HostSelectionPolicy.HostUp(host)

s.readyMux.Lock()
s.ready = true
s.readyMux.Unlock()
}

func (s *singleHostReadyPolicy) Ready() bool {
s.readyMux.Lock()
ready := s.ready
s.readyMux.Unlock()
if !ready {
return false
}

// in case the wrapped policy is also a ReadyPolicy, defer to that
if rdy, ok := s.HostSelectionPolicy.(ReadyPolicy); ok {
return rdy.Ready()
}
return true
}
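A sketch of wrapping a host selection policy so session setup only waits for a single host to come up, assuming a `gocql.ClusterConfig` named `cluster`:

```go
// Session initialization returns as soon as one host connects instead of
// waiting for the whole ring.
cluster.PoolConfig.HostSelectionPolicy = gocql.SingleHostReadyPolicy(
	gocql.RoundRobinHostPolicy(),
)
session, err := cluster.CreateSession()
if err != nil {
	// handle error
}
defer session.Close()
```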
// ConvictionPolicy interface is used by gocql to determine if a host should be
// marked as DOWN based on the error and host info
type ConvictionPolicy interface {

@ -14,18 +14,6 @@ type preparedLRU struct {
lru *lru.Cache
}

// Max adjusts the maximum size of the cache and cleans up the oldest records if
// the new max is lower than the previous value. Not concurrency safe.
func (p *preparedLRU) max(max int) {
p.mu.Lock()
defer p.mu.Unlock()

for p.lru.Len() > max {
p.lru.RemoveOldest()
}
p.lru.MaxEntries = max
}

func (p *preparedLRU) clear() {
p.mu.Lock()
defer p.mu.Unlock()
@ -2,6 +2,7 @@ package gocql

import (
"context"
"sync"
"time"
)

@ -34,14 +35,15 @@ func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, c
return iter
}

func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, results chan *Iter) *Iter {
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy,
hostIter NextHost, results chan *Iter) *Iter {
ticker := time.NewTicker(sp.Delay())
defer ticker.Stop()

for i := 0; i < sp.Attempts(); i++ {
select {
case <-ticker.C:
go q.run(ctx, qry, results)
go q.run(ctx, qry, hostIter, results)
case <-ctx.Done():
return &Iter{err: ctx.Err()}
case iter := <-results:

@ -53,11 +55,23 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
}

func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)

// if the query is not marked as idempotent, or no speculative attempts
// are configured, execute it without speculative execution
sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry), nil
return q.do(qry.Context(), qry, hostIter), nil
}

// When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below.
// To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it.
var mu sync.Mutex
origHostIter := hostIter
hostIter = func() SelectedHost {
mu.Lock()
defer mu.Unlock()
return origHostIter()
}

ctx, cancel := context.WithCancel(qry.Context())

@ -66,12 +80,12 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
results := make(chan *Iter, 1)

// Launch the main execution
go q.run(ctx, qry, results)
go q.run(ctx, qry, hostIter, results)

// The speculative executions are launched _in addition_ to the main
// execution, on a timer. So Speculation{2} would make 3 executions running
// in total.
if iter := q.speculate(ctx, qry, sp, results); iter != nil {
if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil {
return iter, nil
}

@ -83,8 +97,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
}
}

func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
hostIter := q.policy.Pick(qry)
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
selectedHost := hostIter()
rt := qry.retryPolicy()

@ -153,9 +166,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
return &Iter{err: ErrNoConnections}
}

func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan<- *Iter) {
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) {
select {
case results <- q.do(ctx, qry):
case results <- q.do(ctx, qry, hostIter):
case <-ctx.Done():
}
}
@ -63,29 +63,6 @@ func (r *ring) currentHosts() map[string]*HostInfo {
return hosts
}

func (r *ring) addHost(host *HostInfo) bool {
// TODO(zariel): key all host info by HostID instead of
// ip addresses
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
ip := host.ConnectAddress().String()

r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}

_, ok := r.hosts[ip]
if !ok {
r.hostList = append(r.hostList, host)
}

r.hosts[ip] = host
r.mu.Unlock()
return ok
}

func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
if existingHost, ok := r.addHostIfMissing(host); ok {
existingHost.update(host)
@ -27,7 +27,7 @@ import (
// scenario is to have one global session object to interact with the
// whole Cassandra cluster.
//
// This type extends the Node interface by adding a convinient query builder
// This type extends the Node interface by adding a convenient query builder
// and automatically sets a default consistency level on all operations
// that do not have a consistency level set.
type Session struct {

@ -62,7 +62,6 @@ type Session struct {
schemaEvents *eventDebouncer

// ring metadata
hosts []HostInfo
useSystemSchema bool
hasAggregatesAndFunctions bool
@ -227,18 +226,44 @@ func (s *Session) init() error {
}

hosts = hosts[:0]
// each host will increment left and decrement it after connecting and once
// there's none left, we'll close connectedCh
var left int64
// we will receive up to len(hostMap) messages so create a buffer so we
// don't end up stuck in a goroutine if we stopped listening
connectedCh := make(chan struct{}, len(hostMap))
// we add one here because we don't want to end up closing connectedCh until
// we're done looping and the decrement code might be reached before we've
// looped again
atomic.AddInt64(&left, 1)
for _, host := range hostMap {
host = s.ring.addOrUpdate(host)
host := s.ring.addOrUpdate(host)
if s.cfg.filterHost(host) {
continue
}

host.setState(NodeUp)
s.pool.addHost(host)
atomic.AddInt64(&left, 1)
go func() {
s.pool.addHost(host)
connectedCh <- struct{}{}

// if there are no hosts left, then close connectedCh to unblock the loop
// below if it's still waiting
if atomic.AddInt64(&left, -1) == 0 {
close(connectedCh)
}
}()

hosts = append(hosts, host)
}
// once we're done looping we subtract the one we initially added and check
// to see if we should close
if atomic.AddInt64(&left, -1) == 0 {
close(connectedCh)
}

// before waiting for them to connect, add them all to the policy so we can
// utilize efficiencies by calling AddHosts if the policy supports it
type bulkAddHosts interface {
AddHosts([]*HostInfo)
}

@ -250,6 +275,15 @@ func (s *Session) init() error {
}
}

readyPolicy, _ := s.policy.(ReadyPolicy)
// now loop over connectedCh until it's closed (meaning we've connected to all)
// or until the policy says we're ready
for range connectedCh {
if readyPolicy != nil && readyPolicy.Ready() {
break
}
}

// TODO(zariel): we probably dont need this any more as we verify that we
// can connect to one of the endpoints supplied by using the control conn.
// See if there are any connections in the pool
@ -320,7 +354,8 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
if h.IsUp() {
continue
}
s.handleNodeUp(h.ConnectAddress(), h.Port(), true)
// we let the pool call handleNodeUp to change the host state
s.pool.addHost(h)
}
case <-s.ctx.Done():
return
@ -806,6 +841,7 @@ type Query struct {
trace Tracer
observer QueryObserver
session *Session
conn *Conn
rt RetryPolicy
spec SpeculativeExecutionPolicy
binding func(q *QueryInfo) ([]interface{}, error)

@ -1094,12 +1130,17 @@ func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy {
return q.spec
}

// IsIdempotent returns whether the query is marked as idempotent.
// Non-idempotent queries won't be retried.
// See "Retries and speculative execution" in package docs for more details.
func (q *Query) IsIdempotent() bool {
return q.idempotent
}

// Idempotent marks the query as being idempotent or not depending on
// the value.
// Non-idempotent queries won't be retried.
// See "Retries and speculative execution" in package docs for more details.
func (q *Query) Idempotent(value bool) *Query {
q.idempotent = value
return q
@ -1164,6 +1205,11 @@ func (q *Query) Iter() *Iter {
if isUseStatement(q.stmt) {
return &Iter{err: ErrUseStmt}
}
// if the query was specifically run on a connection then re-use that
// connection when fetching the next results
if q.conn != nil {
return q.conn.executeQuery(q.Context(), q)
}
return q.session.executeQuery(q)
}
@ -1195,6 +1241,10 @@ func (q *Query) Scan(dest ...interface{}) error {
// statement containing an IF clause). If the transaction fails because
// the existing values did not match, the previous values will be stored
// in dest.
//
// As for INSERT .. IF NOT EXISTS, previous values will be returned as if
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
// column mismatching. Use MapScanCAS to capture them safely.
func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
q.disableSkipMetadata = true
iter := q.Iter()
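A sketch of a conditional UPDATE with ScanCAS, assuming an existing `session`; the table, columns, and `id` are hypothetical. When the condition fails, dest receives the existing values of the columns in the IF clause:

```go
var prevName string
applied, err := session.Query(
	`UPDATE users SET name = ? WHERE id = ? IF name = ?`,
	"bob", id, "alice",
).ScanCAS(&prevName)
if err != nil {
	// handle error
}
if !applied {
	// condition failed; prevName holds the current value of name
}
```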
@ -1423,7 +1473,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
}

if iter.next != nil && iter.pos >= iter.next.pos {
go iter.next.fetch()
iter.next.fetchAsync()
}

// currently only support scanning into an expand tuple, such that it's the same
@ -1517,16 +1567,31 @@ func (iter *Iter) NumRows() int {
return iter.numRows
}

// nextIter holds state for fetching a single page in an iterator.
// A single page might be attempted multiple times due to retries.
type nextIter struct {
qry *Query
pos int
once sync.Once
next *Iter
qry   *Query
pos   int
oncea sync.Once
once  sync.Once
next  *Iter
}

func (n *nextIter) fetchAsync() {
n.oncea.Do(func() {
go n.fetch()
})
}

func (n *nextIter) fetch() *Iter {
n.once.Do(func() {
n.next = n.qry.session.executeQuery(n.qry)
// if the query was specifically run on a connection then re-use that
// connection when fetching the next results
if n.qry.conn != nil {
n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry)
} else {
n.next = n.qry.session.executeQuery(n.qry)
}
})
return n.next
}
@ -1536,7 +1601,6 @@ type Batch struct {
Entries []BatchEntry
Cons Consistency
routingKey []byte
routingKeyBuffer []byte
CustomPayload map[string][]byte
rt RetryPolicy
spec SpeculativeExecutionPolicy

@ -1733,7 +1797,7 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch {

func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
latency := end.Sub(start)
_, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil)
attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil)

if b.observer == nil {
return
@ -1753,6 +1817,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host
Host: host,
Metrics: metricsForHost,
Err: iter.err,
Attempt: attempt,
})
}

@ -1968,7 +2033,6 @@ type ObservedQuery struct {
Err error

// Attempt is the index of attempt at executing this query.
// An attempt might be either retry or fetching next page of a query.
// The first attempt is number zero and any retries have non-zero attempt number.
Attempt int
}

@ -1999,6 +2063,10 @@ type ObservedBatch struct {

// The metrics per this host
Metrics *hostMetrics

// Attempt is the index of attempt at executing this query.
// The first attempt is number zero and any retries have non-zero attempt number.
Attempt int
}

// BatchObserver is the interface implemented by batch observers / stat collectors.
@ -153,7 +153,7 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
tokenRing.partitioner = randomPartitioner{}
} else {
return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
return nil, fmt.Errorf("unsupported partitioner '%s'", partitioner)
}

for _, host := range hosts {
@ -46,32 +46,35 @@ type placementStrategy interface {
replicationFactor(dc string) int
}

func getReplicationFactorFromOpts(keyspace string, val interface{}) int {
// TODO: dont really want to panic here, but is better
// than spamming
func getReplicationFactorFromOpts(val interface{}) (int, error) {
switch v := val.(type) {
case int:
if v <= 0 {
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace))
if v < 0 {
return 0, fmt.Errorf("invalid replication_factor %d", v)
}
return v
return v, nil
case string:
n, err := strconv.Atoi(v)
if err != nil {
panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err))
} else if n <= 0 {
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", n, keyspace))
return 0, fmt.Errorf("invalid replication_factor %q: %v", v, err)
} else if n < 0 {
return 0, fmt.Errorf("invalid replication_factor %d", n)
}
return n
return n, nil
default:
panic(fmt.Sprintf("unkown replication_factor type %T", v))
return 0, fmt.Errorf("unknown replication_factor type %T", v)
}
}

func getStrategy(ks *KeyspaceMetadata) placementStrategy {
switch {
case strings.Contains(ks.StrategyClass, "SimpleStrategy"):
return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])}
rf, err := getReplicationFactorFromOpts(ks.StrategyOptions["replication_factor"])
if err != nil {
Logger.Printf("parse rf for keyspace %q: %v", ks.Name, err)
return nil
}
return &simpleStrategy{rf: rf}
case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"):
dcs := make(map[string]int)
for dc, rf := range ks.StrategyOptions {

@ -79,14 +82,21 @@ func getStrategy(ks *KeyspaceMetadata) placementStrategy {
continue
}

dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf)
rf, err := getReplicationFactorFromOpts(rf)
if err != nil {
Logger.Printf("parse rf for keyspace %q, dc %q: %v", ks.Name, dc, err)
// skip DC if the rf is invalid/unsupported, so that we can at least work with other working DCs.
continue
}

dcs[dc] = rf
}
return &networkTopology{dcs: dcs}
case strings.Contains(ks.StrategyClass, "LocalStrategy"):
return nil
default:
// TODO: handle unknown replicas and just return the primary host for a token
panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass))
Logger.Printf("parse rf for keyspace %q: unsupported strategy class: %v", ks.Name, ks.StrategyClass)
return nil
}
}
@ -2,11 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gocql

// The uuid package can be used to generate and parse universally unique
// identifiers, a standardized format in the form of a 128 bit number.
//
// http://tools.ietf.org/html/rfc4122
package gocql

import (
"crypto/rand"
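A short sketch of the UUID helpers referred to by the doc comment above:

```go
// TimeUUID generates a version 1 (time-based) UUID; RandomUUID a version 4.
u := gocql.TimeUUID()
fmt.Println(u.String(), u.Time())

r, err := gocql.RandomUUID()
if err != nil {
	// handle error
}

// ParseUUID accepts the canonical hex string form.
parsed, err := gocql.ParseUUID(r.String())
if err != nil {
	// handle error
}
_ = parsed
```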
@ -402,7 +402,7 @@ github.com/go-stack/stack
github.com/go-test/deep
# github.com/go-yaml/yaml v2.1.0+incompatible
github.com/go-yaml/yaml
# github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e
# github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
## explicit
github.com/gocql/gocql
github.com/gocql/gocql/internal/lru