Validate hostnames when using TLS in Cassandra (#11365)

This commit is contained in:
Michael Golowka 2021-04-16 15:52:35 -06:00 committed by GitHub
parent 541ae8636c
commit 4279bc8b34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 2386 additions and 526 deletions

View File

@ -20,13 +20,18 @@ func TestBackend_basic(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
cleanup, hostname := cassandra.PrepareTestContainer(t, "latest") copyFromTo := map[string]string{
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.CopyFromTo(copyFromTo),
)
defer cleanup() defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b, LogicalBackend: b,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepConfig(t, hostname), testAccStepConfig(t, host.ConnectionURL()),
testAccStepRole(t), testAccStepRole(t),
testAccStepReadCreds(t, "test"), testAccStepReadCreds(t, "test"),
}, },
@ -41,13 +46,17 @@ func TestBackend_roleCrud(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
cleanup, hostname := cassandra.PrepareTestContainer(t, "latest") copyFromTo := map[string]string{
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.CopyFromTo(copyFromTo))
defer cleanup() defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b, LogicalBackend: b,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testAccStepConfig(t, hostname), testAccStepConfig(t, host.ConnectionURL()),
testAccStepRole(t), testAccStepRole(t),
testAccStepRoleWithOptions(t), testAccStepRoleWithOptions(t),
testAccStepReadRole(t, "test", testRole), testAccStepReadRole(t, "test", testRole),

3
changelog/11365.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
secrets/database/cassandra: Fixed issue where hostnames were not being validated when using TLS
```

2
go.mod
View File

@ -49,7 +49,7 @@ require (
github.com/go-ole/go-ole v1.2.4 // indirect github.com/go-ole/go-ole v1.2.4 // indirect
github.com/go-sql-driver/mysql v1.5.0 github.com/go-sql-driver/mysql v1.5.0
github.com/go-test/deep v1.0.7 github.com/go-test/deep v1.0.7
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
github.com/golang/protobuf v1.4.2 github.com/golang/protobuf v1.4.2
github.com/google/go-github v17.0.0+incompatible github.com/google/go-github v17.0.0+incompatible
github.com/google/go-metrics-stackdriver v0.2.0 github.com/google/go-metrics-stackdriver v0.2.0

6
go.sum
View File

@ -442,8 +442,8 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe
github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ=
github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0=
github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ= github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309 h1:8MHuCGYDXh0skFrLumkCMlt9C29hxhqNx39+Haemeqw=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
@ -1084,6 +1084,7 @@ github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6So
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.6.2 h1:aIihoIOHCiLZHxyoNQ+ABL4NKhFTgKLBdMLyEAh98m0=
github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rs/zerolog v1.4.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.4.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
@ -1631,6 +1632,7 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=

View File

@ -2,9 +2,10 @@ package cassandra
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net"
"os" "os"
"path/filepath"
"testing" "testing"
"time" "time"
@ -12,33 +13,75 @@ import (
"github.com/hashicorp/vault/helper/testhelpers/docker" "github.com/hashicorp/vault/helper/testhelpers/docker"
) )
func PrepareTestContainer(t *testing.T, version string) (func(), string) { type containerConfig struct {
version string
copyFromTo map[string]string
sslOpts *gocql.SslOptions
}
type ContainerOpt func(*containerConfig)
func Version(version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.version = version
}
}
func CopyFromTo(copyFromTo map[string]string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.copyFromTo = copyFromTo
}
}
func SslOpts(sslOpts *gocql.SslOptions) ContainerOpt {
return func(cfg *containerConfig) {
cfg.sslOpts = sslOpts
}
}
type Host struct {
Name string
Port string
}
func (h Host) ConnectionURL() string {
return net.JoinHostPort(h.Name, h.Port)
}
func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
t.Helper() t.Helper()
if os.Getenv("CASSANDRA_HOSTS") != "" { if os.Getenv("CASSANDRA_HOSTS") != "" {
return func() {}, os.Getenv("CASSANDRA_HOSTS") host, port, err := net.SplitHostPort(os.Getenv("CASSANDRA_HOSTS"))
if err != nil {
t.Fatalf("Failed to split host & port from CASSANDRA_HOSTS (%s): %s", os.Getenv("CASSANDRA_HOSTS"), err)
}
h := Host{
Name: host,
Port: port,
}
return h, func() {}
} }
if version == "" { containerCfg := &containerConfig{
version = "3.11" version: "3.11",
} }
var copyFromTo map[string]string for _, opt := range opts {
cwd, _ := os.Getwd() opt(containerCfg)
fixturePath := fmt.Sprintf("%s/test-fixtures/", cwd)
if _, err := os.Stat(fixturePath); err != nil {
if !errors.Is(err, os.ErrNotExist) {
// If it doesn't exist, no biggie
t.Fatal(err)
} }
} else {
copyFromTo = map[string]string{ copyFromTo := map[string]string{}
fixturePath: "/etc/cassandra", for from, to := range containerCfg.copyFromTo {
absFrom, err := filepath.Abs(from)
if err != nil {
t.Fatalf("Unable to get absolute path for file %s", from)
} }
copyFromTo[absFrom] = to
} }
runner, err := docker.NewServiceRunner(docker.RunOptions{ runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "cassandra", ImageRepo: "cassandra",
ImageTag: version, ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"}, Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo, CopyFromTo: copyFromTo,
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"}, Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
@ -58,6 +101,8 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
clusterConfig.ProtoVersion = 4 clusterConfig.ProtoVersion = 4
clusterConfig.Port = port clusterConfig.Port = port
clusterConfig.SslOpts = containerCfg.sslOpts
session, err := clusterConfig.CreateSession() session, err := clusterConfig.CreateSession()
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating session: %s", err) return nil, fmt.Errorf("error creating session: %s", err)
@ -65,19 +110,19 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
defer session.Close() defer session.Close()
// Create keyspace // Create keyspace
q := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`) query := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`)
if err := q.Exec(); err != nil { if err := query.Exec(); err != nil {
t.Fatalf("could not create cassandra keyspace: %v", err) t.Fatalf("could not create cassandra keyspace: %v", err)
} }
// Create table // Create table
q = session.Query(`CREATE TABLE "vault"."entries" ( query = session.Query(`CREATE TABLE "vault"."entries" (
bucket text, bucket text,
key text, key text,
value blob, value blob,
PRIMARY KEY (bucket, key) PRIMARY KEY (bucket, key)
) WITH CLUSTERING ORDER BY (key ASC);`) ) WITH CLUSTERING ORDER BY (key ASC);`)
if err := q.Exec(); err != nil { if err := query.Exec(); err != nil {
t.Fatalf("could not create cassandra table: %v", err) t.Fatalf("could not create cassandra table: %v", err)
} }
return cfg, nil return cfg, nil
@ -85,5 +130,14 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
if err != nil { if err != nil {
t.Fatalf("Could not start docker cassandra: %s", err) t.Fatalf("Could not start docker cassandra: %s", err)
} }
return svc.Cleanup, svc.Config.Address()
host, port, err := net.SplitHostPort(svc.Config.Address())
if err != nil {
t.Fatalf("Failed to split host & port from address (%s): %s", svc.Config.Address(), err)
}
h := Host{
Name: host,
Port: port,
}
return h, svc.Cleanup
} }

View File

@ -10,11 +10,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
metrics "github.com/armon/go-metrics" metrics "github.com/armon/go-metrics"
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/sdk/physical"
) )
@ -180,8 +179,7 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
if err != nil { if err != nil {
return err return err
} }
} else { } else if pemJSONPath, ok := conf["pem_json_file"]; ok {
if pemJSONPath, ok := conf["pem_json_file"]; ok {
pemJSONData, err := ioutil.ReadFile(pemJSONPath) pemJSONData, err := ioutil.ReadFile(pemJSONPath)
if err != nil { if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err) return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err)
@ -195,7 +193,6 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
return err return err
} }
} }
}
if tlsSkipVerifyStr, ok := conf["tls_skip_verify"]; ok { if tlsSkipVerifyStr, ok := conf["tls_skip_verify"]; ok {
tlsSkipVerify, err := strconv.Atoi(tlsSkipVerifyStr) tlsSkipVerify, err := strconv.Atoi(tlsSkipVerifyStr)
@ -225,7 +222,8 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
} }
cluster.SslOpts = &gocql.SslOptions{ cluster.SslOpts = &gocql.SslOptions{
Config: tlsConfig.Clone(), Config: tlsConfig,
EnableHostVerification: !tlsConfig.InsecureSkipVerify,
} }
return nil return nil
} }

View File

@ -19,13 +19,13 @@ func TestCassandraBackend(t *testing.T) {
t.Skip("skipping race test in CI pending https://github.com/gocql/gocql/pull/1474") t.Skip("skipping race test in CI pending https://github.com/gocql/gocql/pull/1474")
} }
cleanup, hosts := cassandra.PrepareTestContainer(t, "") host, cleanup := cassandra.PrepareTestContainer(t)
defer cleanup() defer cleanup()
// Run vault tests // Run vault tests
logger := logging.NewVaultLogger(log.Debug) logger := logging.NewVaultLogger(log.Debug)
b, err := NewCassandraBackend(map[string]string{ b, err := NewCassandraBackend(map[string]string{
"hosts": hosts, "hosts": host.ConnectionURL(),
"protocol_version": "3", "protocol_version": "3",
}, logger) }, logger)
if err != nil { if err != nil {

View File

@ -3,7 +3,6 @@ package cassandra
import ( import (
"context" "context"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
@ -17,14 +16,16 @@ import (
) )
func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()) { func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()) {
cleanup, connURL := cassandra.PrepareTestContainer(t, "latest") host, cleanup := cassandra.PrepareTestContainer(t,
pieces := strings.Split(connURL, ":") cassandra.Version("latest"),
cassandra.CopyFromTo(insecureFileMounts),
)
db := new() db := new()
initReq := dbplugin.InitializeRequest{ initReq := dbplugin.InitializeRequest{
Config: map[string]interface{}{ Config: map[string]interface{}{
"hosts": connURL, "hosts": host.ConnectionURL(),
"port": pieces[1], "port": host.Port,
"username": "cassandra", "username": "cassandra",
"password": "cassandra", "password": "cassandra",
"protocol_version": protocolVersion, "protocol_version": protocolVersion,
@ -34,8 +35,8 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
} }
expectedConfig := map[string]interface{}{ expectedConfig := map[string]interface{}{
"hosts": connURL, "hosts": host.ConnectionURL(),
"port": pieces[1], "port": host.Port,
"username": "cassandra", "username": "cassandra",
"password": "cassandra", "password": "cassandra",
"protocol_version": protocolVersion, "protocol_version": protocolVersion,
@ -53,7 +54,7 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
return db, cleanup return db, cleanup
} }
func TestCassandra_Initialize(t *testing.T) { func TestInitialize(t *testing.T) {
db, cleanup := getCassandra(t, 4) db, cleanup := getCassandra(t, 4)
defer cleanup() defer cleanup()
@ -66,7 +67,7 @@ func TestCassandra_Initialize(t *testing.T) {
defer cleanup() defer cleanup()
} }
func TestCassandra_CreateUser(t *testing.T) { func TestCreateUser(t *testing.T) {
type testCase struct { type testCase struct {
// Config will have the hosts & port added to it during the test // Config will have the hosts & port added to it during the test
config map[string]interface{} config map[string]interface{}
@ -126,15 +127,17 @@ func TestCassandra_CreateUser(t *testing.T) {
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
cleanup, connURL := cassandra.PrepareTestContainer(t, "latest") host, cleanup := cassandra.PrepareTestContainer(t,
pieces := strings.Split(connURL, ":") cassandra.Version("latest"),
cassandra.CopyFromTo(insecureFileMounts),
)
defer cleanup() defer cleanup()
db := new() db := new()
config := test.config config := test.config
config["hosts"] = connURL config["hosts"] = host.ConnectionURL()
config["port"] = pieces[1] config["port"] = host.Port
initReq := dbplugin.InitializeRequest{ initReq := dbplugin.InitializeRequest{
Config: config, Config: config,
@ -162,7 +165,7 @@ func TestCassandra_CreateUser(t *testing.T) {
} }
} }
func TestMyCassandra_UpdateUserPassword(t *testing.T) { func TestUpdateUserPassword(t *testing.T) {
db, cleanup := getCassandra(t, 4) db, cleanup := getCassandra(t, 4)
defer cleanup() defer cleanup()
@ -198,7 +201,7 @@ func TestMyCassandra_UpdateUserPassword(t *testing.T) {
assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second) assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
} }
func TestCassandra_DeleteUser(t *testing.T) { func TestDeleteUser(t *testing.T) {
db, cleanup := getCassandra(t, 4) db, cleanup := getCassandra(t, 4)
defer cleanup() defer cleanup()

View File

@ -8,14 +8,13 @@ import (
"sync" "sync"
"time" "time"
"github.com/gocql/gocql"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/parseutil" "github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil" "github.com/hashicorp/vault/sdk/helper/tlsutil"
"github.com/gocql/gocql"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
@ -40,9 +39,7 @@ type cassandraConnectionProducer struct {
connectTimeout time.Duration connectTimeout time.Duration
socketKeepAlive time.Duration socketKeepAlive time.Duration
certificate string certBundle *certutil.CertBundle
privateKey string
issuingCA string
rawConfig map[string]interface{} rawConfig map[string]interface{}
Initialized bool Initialized bool
@ -99,9 +96,7 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
if err != nil { if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err) return fmt.Errorf("error marshaling PEM information: %w", err)
} }
c.certificate = certBundle.Certificate c.certBundle = certBundle
c.privateKey = certBundle.PrivateKey
c.issuingCA = certBundle.IssuingCA
c.TLS = true c.TLS = true
case len(c.PemBundle) != 0: case len(c.PemBundle) != 0:
@ -113,9 +108,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
if err != nil { if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err) return fmt.Errorf("error marshaling PEM information: %w", err)
} }
c.certificate = certBundle.Certificate c.certBundle = certBundle
c.privateKey = certBundle.PrivateKey c.TLS = true
c.issuingCA = certBundle.IssuingCA }
if c.InsecureTLS {
c.TLS = true c.TLS = true
} }
@ -185,49 +182,13 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
clusterConfig.Timeout = c.connectTimeout clusterConfig.Timeout = c.connectTimeout
clusterConfig.SocketKeepalive = c.socketKeepAlive clusterConfig.SocketKeepalive = c.socketKeepAlive
if c.TLS { if c.TLS {
var tlsConfig *tls.Config sslOpts, err := getSslOpts(c.certBundle, c.TLSMinVersion, c.InsecureTLS)
if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
if len(c.certificate) > 0 && len(c.privateKey) == 0 {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}
certBundle := &certutil.CertBundle{}
if len(c.certificate) > 0 {
certBundle.Certificate = c.certificate
certBundle.PrivateKey = c.privateKey
}
if len(c.issuingCA) > 0 {
certBundle.IssuingCA = c.issuingCA
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err) return nil, err
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = c.InsecureTLS
if c.TLSMinVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
}
clusterConfig.SslOpts = &gocql.SslOptions{
Config: tlsConfig,
} }
clusterConfig.SslOpts = sslOpts
} }
if c.LocalDatacenter != "" { if c.LocalDatacenter != "" {
@ -269,6 +230,48 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
return session, nil return session, nil
} }
func getSslOpts(certBundle *certutil.CertBundle, minTLSVersion string, insecureSkipVerify bool) (*gocql.SslOptions, error) {
tlsConfig := &tls.Config{}
if certBundle != nil {
if certBundle.Certificate == "" && certBundle.PrivateKey != "" {
return nil, fmt.Errorf("found private key for TLS authentication but no certificate")
}
if certBundle.Certificate != "" && certBundle.PrivateKey == "" {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
}
tlsConfig.InsecureSkipVerify = insecureSkipVerify
if minTLSVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[minTLSVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
opts := &gocql.SslOptions{
Config: tlsConfig,
EnableHostVerification: !insecureSkipVerify,
}
return opts, nil
}
func (c *cassandraConnectionProducer) secretValues() map[string]string { func (c *cassandraConnectionProducer) secretValues() map[string]string {
return map[string]string{ return map[string]string{
c.Password: "[password]", c.Password: "[password]",

View File

@ -0,0 +1,95 @@
package cassandra
import (
"context"
"crypto/tls"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/testhelpers/cassandra"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
)
var (
insecureFileMounts = map[string]string{
"test-fixtures/no_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
secureFileMounts = map[string]string{
"test-fixtures/with_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
"test-fixtures/with_tls/keystore.jks": "/etc/cassandra/keystore.jks",
"test-fixtures/with_tls/.cassandra": "/root/.cassandra/",
}
)
func TestTLSConnection(t *testing.T) {
type testCase struct {
config map[string]interface{}
expectErr bool
}
tests := map[string]testCase{
"tls not specified": {
config: map[string]interface{}{},
expectErr: true,
},
"unrecognized certificate": {
config: map[string]interface{}{
"tls": "true",
},
expectErr: true,
},
"insecure TLS": {
config: map[string]interface{}{
"tls": "true",
"insecure_tls": true,
},
expectErr: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.Version("3.11.9"),
cassandra.CopyFromTo(secureFileMounts),
cassandra.SslOpts(&gocql.SslOptions{
Config: &tls.Config{InsecureSkipVerify: true},
EnableHostVerification: false,
}),
)
defer cleanup()
// Set values that we don't know until the cassandra container is started
config := map[string]interface{}{
"hosts": host.ConnectionURL(),
"port": host.Port,
"username": "cassandra",
"password": "cassandra",
"protocol_version": "3",
"connect_timeout": "20s",
}
// Then add any values specified in the test config. Generally for these tests they shouldn't overlap
for k, v := range test.config {
config[k] = v
}
db := new()
initReq := dbplugin.InitializeRequest{
Config: config,
VerifyConnection: true,
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := db.Initialize(ctx, initReq)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}

View File

@ -0,0 +1,3 @@
[ssl]
validate = false
version = SSLv23

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,46 @@
#!/bin/sh
################################################################
# Usage: ./gencert.sh
#
# Generates a keystore.jks file that can be used with a
# Cassandra server for TLS connections. This does not update
# a cassandra config file.
################################################################
set -e
KEYFILE="key.pem"
CERTFILE="cert.pem"
PKCSFILE="keystore.p12"
JKSFILE="keystore.jks"
HOST="127.0.0.1"
NAME="cassandra"
ALIAS="cassandra"
PASSWORD="cassandra"
echo "# Generating certificate keypair..."
go run /usr/local/go/src/crypto/tls/generate_cert.go --host=${HOST}
echo "# Creating keystore..."
openssl pkcs12 -export -in ${CERTFILE} -inkey ${KEYFILE} -name ${NAME} -password pass:${PASSWORD} > ${PKCSFILE}
echo "# Creating Java key store"
if [ -e "${JKSFILE}" ]; then
echo "# Removing old key store"
rm ${JKSFILE}
fi
set +e
keytool -importkeystore \
-srckeystore ${PKCSFILE} \
-srcstoretype PKCS12 \
-srcstorepass ${PASSWORD} \
-destkeystore ${JKSFILE} \
-deststorepass ${PASSWORD} \
-destkeypass ${PASSWORD} \
-alias ${ALIAS}
echo "# Removing intermediate files"
rm ${KEYFILE} ${CERTFILE} ${PKCSFILE}

Binary file not shown.

View File

@ -31,8 +31,10 @@ env:
AUTH=false AUTH=false
go: go:
- 1.13.x - 1.15.x
- 1.14.x - 1.16.x
go_import_path: github.com/gocql/gocql
install: install:
- ./install_test_deps.sh $TRAVIS_REPO_SLUG - ./install_test_deps.sh $TRAVIS_REPO_SLUG

View File

@ -115,3 +115,8 @@ Pavel Buchinchik <p.buchinchik@gmail.com>
Rintaro Okamura <rintaro.okamura@gmail.com> Rintaro Okamura <rintaro.okamura@gmail.com>
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com> Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
Jorge Bay <jorgebg@apache.org> Jorge Bay <jorgebg@apache.org>
Dmitriy Kozlov <hummerd@mail.ru>
Alexey Romanovsky <alexus1024+gocql@gmail.com>
Jaume Marhuenda Beltran <jaumemarhuenda@gmail.com>
Piotr Dulikowski <piodul@scylladb.com>
Árni Dagur <arni@dagur.eu>

View File

@ -19,8 +19,8 @@ The following matrix shows the versions of Go and Cassandra that are tested with
Go/Cassandra | 2.1.x | 2.2.x | 3.x.x Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
-------------| -------| ------| --------- -------------| -------| ------| ---------
1.13 | yes | yes | yes 1.15 | yes | yes | yes
1.14 | yes | yes | yes 1.16 | yes | yes | yes
Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project. Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.
@ -114,73 +114,7 @@ statement.
Example Example
------- -------
```go See [package documentation](https://pkg.go.dev/github.com/gocql/gocql#pkg-examples).
/* Before you execute the program, Launch `cqlsh` and execute:
create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id));
create index on example.tweet(timeline);
*/
package main
import (
"fmt"
"log"
"github.com/gocql/gocql"
)
func main() {
// connect to the cluster
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Keyspace = "example"
cluster.Consistency = gocql.Quorum
session, _ := cluster.CreateSession()
defer session.Close()
// insert a tweet
if err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
"me", gocql.TimeUUID(), "hello world").Exec(); err != nil {
log.Fatal(err)
}
var id gocql.UUID
var text string
/* Search for a specific set of records whose 'timeline' column matches
* the value 'me'. The secondary index that we created earlier will be
* used for optimizing the search */
if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
"me").Consistency(gocql.One).Scan(&id, &text); err != nil {
log.Fatal(err)
}
fmt.Println("Tweet:", id, text)
// list all tweets
iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").Iter()
for iter.Scan(&id, &text) {
fmt.Println("Tweet:", id, text)
}
if err := iter.Close(); err != nil {
log.Fatal(err)
}
}
```
Authentication
-------
```go
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: "user",
Password: "password"
}
cluster.Keyspace = "example"
cluster.Consistency = gocql.Quorum
session, _ := cluster.CreateSession()
defer session.Close()
```
Data Binding Data Binding
------------ ------------

View File

@ -44,8 +44,8 @@ func approve(authenticator string) bool {
return false return false
} }
//JoinHostPort is a utility to return a address string that can be used // JoinHostPort is a utility to return an address string that can be used
//gocql.Conn to form a connection with a host. // by `gocql.Conn` to form a connection with a host.
func JoinHostPort(addr string, port int) string { func JoinHostPort(addr string, port int) string {
addr = strings.TrimSpace(addr) addr = strings.TrimSpace(addr)
if _, _, err := net.SplitHostPort(addr); err != nil { if _, _, err := net.SplitHostPort(addr); err != nil {
@ -80,6 +80,19 @@ func (p PasswordAuthenticator) Success(data []byte) error {
return nil return nil
} }
// SslOptions configures TLS use.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | false | do not verify host
// Config is nil | true | verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
type SslOptions struct { type SslOptions struct {
*tls.Config *tls.Config
@ -89,9 +102,12 @@ type SslOptions struct {
CertPath string CertPath string
KeyPath string KeyPath string
CaPath string //optional depending on server config CaPath string //optional depending on server config
// If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this
// This option is basically the inverse of InSecureSkipVerify // on.
// See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info // This option is basically the inverse of tls.Config.InsecureSkipVerify.
// See InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info.
//
// See SslOptions documentation to see how EnableHostVerification interacts with the provided tls.Config.
EnableHostVerification bool EnableHostVerification bool
} }
@ -125,7 +141,7 @@ func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
// which may be serving more queries just fine. // which may be serving more queries just fine.
// Default is 0, should not be changed concurrently with queries. // Default is 0, should not be changed concurrently with queries.
// //
// depreciated // Deprecated.
var TimeoutLimit int64 = 0 var TimeoutLimit int64 = 0
// Conn is a single connection to a Cassandra node. It can be used to execute // Conn is a single connection to a Cassandra node. It can be used to execute
@ -213,14 +229,26 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
dialer = d dialer = d
} }
conn, err := dialer.DialContext(ctx, "tcp", host.HostnameAndPort()) addr := host.HostnameAndPort()
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cfg.tlsConfig != nil { if cfg.tlsConfig != nil {
// the TLS config is safe to be reused by connections but it must not // the TLS config is safe to be reused by connections but it must not
// be modified after being used. // be modified after being used.
tconn := tls.Client(conn, cfg.tlsConfig) tlsConfig := cfg.tlsConfig
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
tconn := tls.Client(conn, tlsConfig)
if err := tconn.Handshake(); err != nil { if err := tconn.Handshake(); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
@ -845,6 +873,10 @@ func (w *writeCoalescer) writeFlusher(interval time.Duration) {
} }
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) { func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
// TODO: move tracer onto conn // TODO: move tracer onto conn
stream, ok := c.streams.GetStream() stream, ok := c.streams.GetStream()
if !ok { if !ok {
@ -1173,12 +1205,16 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
} }
if x.meta.morePages() && !qry.disableAutoPage { if x.meta.morePages() && !qry.disableAutoPage {
newQry := new(Query)
*newQry = *qry
newQry.pageState = copyBytes(x.meta.pagingState)
newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
iter.next = &nextIter{ iter.next = &nextIter{
qry: qry, qry: newQry,
pos: int((1 - qry.prefetch) * float64(x.numRows)), pos: int((1 - qry.prefetch) * float64(x.numRows)),
} }
iter.next.qry.pageState = copyBytes(x.meta.pagingState)
if iter.next.pos < 1 { if iter.next.pos < 1 {
iter.next.pos = 1 iter.next.pos = 1
} }
@ -1359,10 +1395,11 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
} }
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) { func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One) q := c.session.Query(statement, values...).Consistency(One).Trace(nil)
q.trace = nil
q.skipPrepare = true q.skipPrepare = true
q.disableSkipMetadata = true q.disableSkipMetadata = true
// we want to keep the query on this connection
q.conn = c
return c.executeQuery(ctx, q) return c.executeQuery(ctx, q)
} }

View File

@ -28,14 +28,31 @@ type SetPartitioner interface {
} }
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) { func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | true | verify host
// Config is nil | false | do not verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
var tlsConfig *tls.Config
if sslOpts.Config == nil { if sslOpts.Config == nil {
sslOpts.Config = &tls.Config{} tlsConfig = &tls.Config{
InsecureSkipVerify: !sslOpts.EnableHostVerification,
}
} else {
// use clone to avoid race.
tlsConfig = sslOpts.Config.Clone()
}
if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
tlsConfig.InsecureSkipVerify = false
} }
// ca cert is optional // ca cert is optional
if sslOpts.CaPath != "" { if sslOpts.CaPath != "" {
if sslOpts.RootCAs == nil { if tlsConfig.RootCAs == nil {
sslOpts.RootCAs = x509.NewCertPool() tlsConfig.RootCAs = x509.NewCertPool()
} }
pem, err := ioutil.ReadFile(sslOpts.CaPath) pem, err := ioutil.ReadFile(sslOpts.CaPath)
@ -43,7 +60,7 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err) return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
} }
if !sslOpts.RootCAs.AppendCertsFromPEM(pem) { if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("connectionpool: failed parsing or CA certs") return nil, errors.New("connectionpool: failed parsing or CA certs")
} }
} }
@ -53,13 +70,10 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err) return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
} }
sslOpts.Certificates = append(sslOpts.Certificates, mycert) tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
} }
sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification return tlsConfig, nil
// return clone to avoid race
return sslOpts.Config.Clone(), nil
} }
type policyConnPool struct { type policyConnPool struct {
@ -238,12 +252,6 @@ func (p *policyConnPool) removeHost(ip net.IP) {
go pool.Close() go pool.Close()
} }
func (p *policyConnPool) hostUp(host *HostInfo) {
// TODO(zariel): have a set of up hosts and down hosts, we can internally
// detect down hosts, then try to reconnect to them.
p.addHost(host)
}
func (p *policyConnPool) hostDown(ip net.IP) { func (p *policyConnPool) hostDown(ip net.IP) {
// TODO(zariel): mark host as down so we can try to connect to it later, for // TODO(zariel): mark host as down so we can try to connect to it later, for
// now just treat it has removed. // now just treat it has removed.
@ -429,6 +437,8 @@ func (pool *hostConnPool) fill() {
} }
return return
} }
// notify the session that this node is connected
go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)
// filled one // filled one
fillCount-- fillCount--
@ -440,6 +450,11 @@ func (pool *hostConnPool) fill() {
// mark the end of filling // mark the end of filling
pool.fillingStopped(err != nil) pool.fillingStopped(err != nil)
if err == nil && startCount > 0 {
// notify the session that this node is connected again
go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)
}
}() }()
} }

View File

@ -125,7 +125,7 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(ips) == 0 { } else if len(ips) == 0 {
return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr) return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr)
} }
// Filter to v4 addresses if any present // Filter to v4 addresses if any present
@ -177,7 +177,7 @@ func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
return conn, nil return conn, nil
} }
Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err) Logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
} }
return nil, err return nil, err
@ -285,8 +285,6 @@ func (c *controlConn) setupConn(conn *Conn) error {
} }
c.conn.Store(ch) c.conn.Store(ch)
c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
return nil return nil
} }
@ -452,6 +450,8 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
for { for {
iter = c.withConn(func(conn *Conn) *Iter { iter = c.withConn(func(conn *Conn) *Iter {
// we want to keep the query on the control connection
q.conn = conn
return conn.executeQuery(context.TODO(), q) return conn.executeQuery(context.TODO(), q)
}) })

317
vendor/github.com/gocql/gocql/doc.go generated vendored
View File

@ -4,6 +4,319 @@
// Package gocql implements a fast and robust Cassandra driver for the // Package gocql implements a fast and robust Cassandra driver for the
// Go programming language. // Go programming language.
//
// Connecting to the cluster
//
// Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
//
// Port can be specified as part of the address, the above is equivalent to:
//
// cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
//
// It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
// an IP address not a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address
// then the driver may connect multiple times to the same host, and will not mark the node being down or up from events.
//
// Then you can customize more options (see ClusterConfig):
//
// cluster.Keyspace = "example"
// cluster.Consistency = gocql.Quorum
// cluster.ProtoVersion = 4
//
// The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
// protocol version explicitly, as it's not defined which version will be used in certain situations (for example
// during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
//
// When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
//
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// Authentication
//
// CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
// client response pairs. The details of the exchanged messages depend on the authenticator used.
//
// To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider.
//
// PasswordAuthenticator is provided to use for username/password authentication:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.Authenticator = gocql.PasswordAuthenticator{
// Username: "user",
// Password: "password"
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// Transport layer security
//
// It is possible to secure traffic between the client and server with TLS.
//
// To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly.
// There are also helpers to load keys/certificates from files.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | false | do not verify host
// Config is nil | true | verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
//
// For example:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.SslOpts = &gocql.SslOptions{
// EnableHostVerification: true,
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// Executing queries
//
// Create queries with Session.Query. Query values must not be reused between different executions and must not be
// modified after starting execution of the query.
//
// To execute a query without reading results, use Query.Exec:
//
// err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
//
// Single row can be read by calling Query.Scan:
//
// err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
// "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
//
// Multiple rows can be read using Iter.Scanner:
//
// scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
// "me").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var (
// id gocql.UUID
// text string
// )
// err = scanner.Scan(&id, &text)
// if err != nil {
// log.Fatal(err)
// }
// fmt.Println("Tweet:", id, text)
// }
// // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
// if err := scanner.Err(); err != nil {
// log.Fatal(err)
// }
//
// See Example for complete example.
//
// Prepared statements
//
// The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
// of prepared statements.
// CQL protocol does not support preparing other query types.
//
// When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column.
// This will cause the database to ignore writing the column.
// The main advantage is the ability to keep the same prepared statement even when you don't
// want to update some fields, where before you needed to make another prepared statement.
//
// Executing multiple queries concurrently
//
// Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
// from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries
// are executed asynchronously at the protocol level.
//
// results := make(chan error, 2)
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 1").Exec()
// }()
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 2").Exec()
// }()
//
// Nulls
//
// Null values are are unmarshalled as zero value of the type. If you need to distinguish for example between text
// column being null and empty string, you can unmarshal into *string variable instead of string.
//
// var text *string
// err := scanner.Scan(&text)
// if err != nil {
// // handle error
// }
// if text != nil {
// // not null
// }
// else {
// // null
// }
//
// See Example_nulls for full example.
//
// Reusing slices
//
// The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not
// need to be allocated for every processed row. However, you need to be careful when storing the slices to other
// memory structures.
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// var myInts []int
// for scanner.Next() {
// // This scan reuses backing store of myInts for each row.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the
// slice variable within the scanner loop:
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var myInts []int
// // This scan always gets pointer to fresh myInts slice, so does not reuse memory.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// Paging
//
// The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch,
// Query.PageSize, and Query.Prefetch.
//
// It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch).
// Manual paging is useful if you want to store the page state externally, for example in a URL to allow users
// browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since
// it contains data from primary keys.
//
// Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that
// should not be modified. If you send paging state from different query or protocol version, then the behaviour
// is not defined (you might get unexpected results or an error from the server). For example, do not send paging state
// returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4,
// paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880).
//
// The driver does not check whether the paging state is from the same protocol version/statement.
// You might want to validate yourself as this could be a problem if you store paging state externally.
// For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster.
//
// Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by
// Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned
// by Iter.PageState is zero, there are no more pages available (or an error occurred).
//
// Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low.
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors
// explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't
// rely on the page having the exact count of items.
//
// See Example_paging for an example of manual paging.
//
// Dynamic list of columns
//
// There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied
// by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case.
//
// See Example_dynamicColumns.
//
// Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
// overhead to ensure this property.
// Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity.
// Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type.
// A counter batch can only contain statements to update counters.
//
// For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should
// involve only a single partition).
// Multi-partition batch needs to be split by the coordinator node and re-sent to
// correct nodes.
// With single-partition batches you can send the batch directly to the node for the partition without incurring the
// additional network hop.
//
// It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec.
// There are differences how those are executed.
// BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement.
// Session.ExecuteBatch prepares individual statements in the batch.
// If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient.
//
// See Example_batch for an example.
//
// Lightweight transactions
//
// Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an
// INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS.
//
// Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional
// statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and
// Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for
// Session.MapExecuteBatchCAS.
//
// Retries and speculative execution
//
// Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed
// multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative
// execution.
//
// Idempotent queries are retried in case of errors based on the configured RetryPolicy.
//
// Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can
// cause the driver to retry on a different node if the query is taking longer than a specified delay even before the
// driver receives an error or timeout from the server. When a query is speculatively executed, the original execution
// is still executing. The two parallel executions of the query race to return a result, the first received result will
// be returned.
//
// User-defined types
//
// UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing
// UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces).
//
// For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field:
//
// type MyUDT struct {
// FieldA int32 `cql:"a"`
// FieldB string `cql:"b"`
// }
//
// See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler.
//
// Metrics and tracing
//
// It is possible to provide observer implementations that could be used to gather metrics:
//
// - QueryObserver for monitoring individual queries.
// - BatchObserver for monitoring batch queries.
// - ConnectObserver for monitoring new connections from the driver to the database.
// - FrameHeaderObserver for monitoring individual protocol frames.
//
// CQL protocol also supports tracing of queries. When enabled, the database will write information about
// internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive
// the session ID that the database used to store the trace information in system_traces.sessions and
// system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer.
// Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead,
// so this feature should not be used on production systems with very high load unless you know what you are doing.
package gocql // import "github.com/gocql/gocql" package gocql // import "github.com/gocql/gocql"
// TODO(tux21b): write more docs.

View File

@ -164,55 +164,43 @@ func (s *Session) handleNodeEvent(frames []frame) {
switch f.change { switch f.change {
case "NEW_NODE": case "NEW_NODE":
s.handleNewNode(f.host, f.port, true) s.handleNewNode(f.host, f.port)
case "REMOVED_NODE": case "REMOVED_NODE":
s.handleRemovedNode(f.host, f.port) s.handleRemovedNode(f.host, f.port)
case "MOVED_NODE": case "MOVED_NODE":
// java-driver handles this, not mentioned in the spec // java-driver handles this, not mentioned in the spec
// TODO(zariel): refresh token map // TODO(zariel): refresh token map
case "UP": case "UP":
s.handleNodeUp(f.host, f.port, true) s.handleNodeUp(f.host, f.port)
case "DOWN": case "DOWN":
s.handleNodeDown(f.host, f.port) s.handleNodeDown(f.host, f.port)
} }
} }
} }
func (s *Session) addNewNode(host *HostInfo) { func (s *Session) addNewNode(ip net.IP, port int) {
if s.cfg.filterHost(host) {
return
}
host.setState(NodeUp)
s.pool.addHost(host)
s.policy.AddHost(host)
}
func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}
ip, port = s.cfg.translateAddressPort(ip, port)
// Get host info and apply any filters to the host // Get host info and apply any filters to the host
hostInfo, err := s.hostSource.getHostInfo(ip, port) hostInfo, err := s.hostSource.getHostInfo(ip, port)
if err != nil { if err != nil {
Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err) Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
return return
} else if hostInfo == nil { } else if hostInfo == nil {
// If hostInfo is nil, this host was filtered out by cfg.HostFilter // ignore if it's null because we couldn't find it
return return
} }
if t := hostInfo.Version().nodeUpDelay(); t > 0 && waitForBinary { if t := hostInfo.Version().nodeUpDelay(); t > 0 {
time.Sleep(t) time.Sleep(t)
} }
// should this handle token moving? // should this handle token moving?
hostInfo = s.ring.addOrUpdate(hostInfo) hostInfo = s.ring.addOrUpdate(hostInfo)
s.addNewNode(hostInfo) if !s.cfg.filterHost(hostInfo) {
// we let the pool call handleNodeUp to change the host state
s.pool.addHost(hostInfo)
s.policy.AddHost(hostInfo)
}
if s.control != nil && !s.cfg.IgnorePeerAddr { if s.control != nil && !s.cfg.IgnorePeerAddr {
// TODO(zariel): debounce ring refresh // TODO(zariel): debounce ring refresh
@ -220,6 +208,22 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
} }
} }
func (s *Session) handleNewNode(ip net.IP, port int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}
ip, port = s.cfg.translateAddressPort(ip, port)
// if we already have the host and it's already up, then do nothing
host := s.ring.getHost(ip)
if host != nil && host.IsUp() {
return
}
s.addNewNode(ip, port)
}
func (s *Session) handleRemovedNode(ip net.IP, port int) { func (s *Session) handleRemovedNode(ip net.IP, port int) {
if gocqlDebug { if gocqlDebug {
Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port) Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
@ -232,45 +236,37 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
if host == nil { if host == nil {
host = &HostInfo{connectAddress: ip, port: port} host = &HostInfo{connectAddress: ip, port: port}
} }
s.ring.removeHost(ip)
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
host.setState(NodeDown) host.setState(NodeDown)
if !s.cfg.filterHost(host) {
s.policy.RemoveHost(host) s.policy.RemoveHost(host)
s.pool.removeHost(ip) s.pool.removeHost(ip)
s.ring.removeHost(ip) }
if !s.cfg.IgnorePeerAddr { if !s.cfg.IgnorePeerAddr {
s.hostSource.refreshRing() s.hostSource.refreshRing()
} }
} }
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) { func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
if gocqlDebug { if gocqlDebug {
Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort) Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
} }
ip, _ := s.cfg.translateAddressPort(eventIp, eventPort) ip, port := s.cfg.translateAddressPort(eventIp, eventPort)
host := s.ring.getHost(ip) host := s.ring.getHost(ip)
if host == nil { if host == nil {
// TODO(zariel): avoid the need to translate twice in this s.addNewNode(ip, port)
// case
s.handleNewNode(eventIp, eventPort, waitForBinary)
return return
} }
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) { host.setState(NodeUp)
return
}
if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary { if !s.cfg.filterHost(host) {
time.Sleep(t) s.policy.HostUp(host)
} }
s.addNewNode(host)
} }
func (s *Session) handleNodeDown(ip net.IP, port int) { func (s *Session) handleNodeDown(ip net.IP, port int) {
@ -283,11 +279,11 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
host = &HostInfo{connectAddress: ip, port: port} host = &HostInfo{connectAddress: ip, port: port}
} }
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) { host.setState(NodeDown)
if s.cfg.filterHost(host) {
return return
} }
host.setState(NodeDown)
s.policy.HostDown(host) s.policy.HostDown(host)
s.pool.hostDown(ip) s.pool.hostDown(ip)
} }

View File

@ -311,26 +311,10 @@ var (
const maxFrameHeaderSize = 9 const maxFrameHeaderSize = 9
func writeInt(p []byte, n int32) {
p[0] = byte(n >> 24)
p[1] = byte(n >> 16)
p[2] = byte(n >> 8)
p[3] = byte(n)
}
func readInt(p []byte) int32 { func readInt(p []byte) int32 {
return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3])
} }
func writeShort(p []byte, n uint16) {
p[0] = byte(n >> 8)
p[1] = byte(n)
}
func readShort(p []byte) uint16 {
return uint16(p[0])<<8 | uint16(p[1])
}
type frameHeader struct { type frameHeader struct {
version protoVersion version protoVersion
flags byte flags byte
@ -854,7 +838,7 @@ func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error {
if f.proto > protoVersion4 { if f.proto > protoVersion4 {
flags |= flagWithPreparedKeyspace flags |= flagWithPreparedKeyspace
} else { } else {
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher")) panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
} }
} }
if f.proto > protoVersion4 { if f.proto > protoVersion4 {
@ -1502,7 +1486,7 @@ func (f *framer) writeQueryParams(opts *queryParams) {
if f.proto > protoVersion4 { if f.proto > protoVersion4 {
flags |= flagWithKeyspace flags |= flagWithKeyspace
} else { } else {
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher")) panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
} }
} }
@ -1792,16 +1776,6 @@ func (f *framer) readShort() (n uint16) {
return return
} }
func (f *framer) readLong() (n int64) {
if len(f.rbuf) < 8 {
panic(fmt.Errorf("not enough bytes in buffer to read long require 8 got: %d", len(f.rbuf)))
}
n = int64(f.rbuf[0])<<56 | int64(f.rbuf[1])<<48 | int64(f.rbuf[2])<<40 | int64(f.rbuf[3])<<32 |
int64(f.rbuf[4])<<24 | int64(f.rbuf[5])<<16 | int64(f.rbuf[6])<<8 | int64(f.rbuf[7])
f.rbuf = f.rbuf[8:]
return
}
func (f *framer) readString() (s string) { func (f *framer) readString() (s string) {
size := f.readShort() size := f.readShort()
@ -1915,19 +1889,6 @@ func (f *framer) readConsistency() Consistency {
return Consistency(f.readShort()) return Consistency(f.readShort())
} }
func (f *framer) readStringMap() map[string]string {
size := f.readShort()
m := make(map[string]string, size)
for i := 0; i < int(size); i++ {
k := f.readString()
v := f.readString()
m[k] = v
}
return m
}
func (f *framer) readBytesMap() map[string][]byte { func (f *framer) readBytesMap() map[string][]byte {
size := f.readShort() size := f.readShort()
m := make(map[string][]byte, size) m := make(map[string][]byte, size)
@ -2037,10 +1998,6 @@ func (f *framer) writeLongString(s string) {
f.wbuf = append(f.wbuf, s...) f.wbuf = append(f.wbuf, s...)
} }
func (f *framer) writeUUID(u *UUID) {
f.wbuf = append(f.wbuf, u[:]...)
}
func (f *framer) writeStringList(l []string) { func (f *framer) writeStringList(l []string) {
f.writeShort(uint16(len(l))) f.writeShort(uint16(len(l)))
for _, s := range l { for _, s := range l {
@ -2073,18 +2030,6 @@ func (f *framer) writeShortBytes(p []byte) {
f.wbuf = append(f.wbuf, p...) f.wbuf = append(f.wbuf, p...)
} }
func (f *framer) writeInet(ip net.IP, port int) {
f.wbuf = append(f.wbuf,
byte(len(ip)),
)
f.wbuf = append(f.wbuf,
[]byte(ip)...,
)
f.writeInt(int32(port))
}
func (f *framer) writeConsistency(cons Consistency) { func (f *framer) writeConsistency(cons Consistency) {
f.writeShort(uint16(cons)) f.writeShort(uint16(cons))
} }

View File

@ -270,15 +270,6 @@ func getApacheCassandraType(class string) Type {
} }
} }
func typeCanBeNull(typ TypeInfo) bool {
switch typ.(type) {
case CollectionType, UDTTypeInfo, TupleTypeInfo:
return false
}
return true
}
func (r *RowData) rowMap(m map[string]interface{}) { func (r *RowData) rowMap(m map[string]interface{}) {
for i, column := range r.Columns { for i, column := range r.Columns {
val := dereference(r.Values[i]) val := dereference(r.Values[i])
@ -372,7 +363,7 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
// iter := session.Query(`SELECT * FROM mytable`).Iter() // iter := session.Query(`SELECT * FROM mytable`).Iter()
// for { // for {
// // New map each iteration // // New map each iteration
// row = make(map[string]interface{}) // row := make(map[string]interface{})
// if !iter.MapScan(row) { // if !iter.MapScan(row) {
// break // break
// } // }

View File

@ -147,13 +147,6 @@ func (h *HostInfo) Peer() net.IP {
return h.peer return h.peer
} }
func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.peer = peer
return h
}
func (h *HostInfo) invalidConnectAddr() bool { func (h *HostInfo) invalidConnectAddr() bool {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
@ -233,13 +226,6 @@ func (h *HostInfo) DataCenter() string {
return dc return dc
} }
func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dataCenter
return h
}
func (h *HostInfo) Rack() string { func (h *HostInfo) Rack() string {
h.mu.RLock() h.mu.RLock()
rack := h.rack rack := h.rack
@ -247,26 +233,12 @@ func (h *HostInfo) Rack() string {
return rack return rack
} }
func (h *HostInfo) setRack(rack string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.rack = rack
return h
}
func (h *HostInfo) HostID() string { func (h *HostInfo) HostID() string {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
return h.hostId return h.hostId
} }
func (h *HostInfo) setHostID(hostID string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.hostId = hostID
return h
}
func (h *HostInfo) WorkLoad() string { func (h *HostInfo) WorkLoad() string {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
@ -303,13 +275,6 @@ func (h *HostInfo) Version() cassVersion {
return h.version return h.version
} }
func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.version = cassVersion{major, minor, patch}
return h
}
func (h *HostInfo) State() nodeState { func (h *HostInfo) State() nodeState {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
@ -329,26 +294,12 @@ func (h *HostInfo) Tokens() []string {
return h.tokens return h.tokens
} }
func (h *HostInfo) setTokens(tokens []string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.tokens = tokens
return h
}
func (h *HostInfo) Port() int { func (h *HostInfo) Port() int {
h.mu.RLock() h.mu.RLock()
defer h.mu.RUnlock() defer h.mu.RUnlock()
return h.port return h.port
} }
func (h *HostInfo) setPort(port int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.port = port
return h
}
func (h *HostInfo) update(from *HostInfo) { func (h *HostInfo) update(from *HostInfo) {
if h == from { if h == from {
return return
@ -689,7 +640,7 @@ func (r *ringDescriber) refreshRing() error {
// TODO: move this to session // TODO: move this to session
for _, h := range hosts { for _, h := range hosts {
if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) { if r.session.cfg.filterHost(h) {
continue continue
} }

View File

@ -8,9 +8,3 @@ git clone https://github.com/pcmanus/ccm.git
pushd ccm pushd ccm
./setup.py install --user ./setup.py install --user
popd popd
if [ "$1" != "gocql/gocql" ]; then
USER=$(echo $1 | cut -f1 -d'/')
cd ../..
mv ${USER} gocql
fi

View File

@ -44,6 +44,52 @@ type Unmarshaler interface {
// Marshal returns the CQL encoding of the value for the Cassandra // Marshal returns the CQL encoding of the value for the Cassandra
// internal type described by the info parameter. // internal type described by the info parameter.
//
// nil is serialized as CQL null.
// If value implements Marshaler, its MarshalCQL method is called to marshal the data.
// If value is a pointer, the pointed-to value is marshaled.
//
// Supported conversions are as follows, other type combinations may be added in the future:
//
// CQL type | Go type (value) | Note
// varchar, ascii, blob, text | string, []byte |
// boolean | bool |
// tinyint, smallint, int | integer types |
// tinyint, smallint, int | string | formatted as base 10 number
// bigint, counter | integer types |
// bigint, counter | big.Int |
// bigint, counter | string | formatted as base 10 number
// float | float32 |
// double | float64 |
// decimal | inf.Dec |
// time | int64 | nanoseconds since start of day
// time | time.Duration | duration since start of day
// timestamp | int64 | milliseconds since Unix epoch
// timestamp | time.Time |
// list, set | slice, array |
// list, set | map[X]struct{} |
// map | map[X]Y |
// uuid, timeuuid | gocql.UUID |
// uuid, timeuuid | [16]byte | raw UUID bytes
// uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes
// uuid, timeuuid | string | hex representation, see ParseUUID
// varint | integer types |
// varint | big.Int |
// varint | string | value of number in decimal notation
// inet | net.IP |
// inet | string | IPv4 or IPv6 address string
// tuple | slice, array |
// tuple | struct | fields are marshaled in order of declaration
// user-defined type | gocql.UDTMarshaler | MarshalUDT is called
// user-defined type | map[string]interface{} |
// user-defined type | struct | struct fields' cql tags are used for column names
// date | int64 | milliseconds since Unix epoch to start of day (in UTC)
// date | time.Time | start of day (in UTC)
// date | string | parsed using "2006-01-02" format
// duration | int64 | duration in nanoseconds
// duration | time.Duration |
// duration | gocql.Duration |
// duration | string | parsed with time.ParseDuration
func Marshal(info TypeInfo, value interface{}) ([]byte, error) { func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
if info.Version() < protoVersion1 { if info.Version() < protoVersion1 {
panic("protocol version not set") panic("protocol version not set")
@ -118,6 +164,44 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
// Unmarshal parses the CQL encoded data based on the info parameter that // Unmarshal parses the CQL encoded data based on the info parameter that
// describes the Cassandra internal data type and stores the result in the // describes the Cassandra internal data type and stores the result in the
// value pointed by value. // value pointed by value.
//
// If value implements Unmarshaler, it's UnmarshalCQL method is called to
// unmarshal the data.
// If value is a pointer to pointer, it is set to nil if the CQL value is
// null. Otherwise, nulls are unmarshalled as zero value.
//
// Supported conversions are as follows, other type combinations may be added in the future:
//
// CQL type | Go type (value) | Note
// varchar, ascii, blob, text | *string |
// varchar, ascii, blob, text | *[]byte | non-nil buffer is reused
// bool | *bool |
// tinyint, smallint, int, bigint, counter | *integer types |
// tinyint, smallint, int, bigint, counter | *big.Int |
// tinyint, smallint, int, bigint, counter | *string | formatted as base 10 number
// float | *float32 |
// double | *float64 |
// decimal | *inf.Dec |
// time | *int64 | nanoseconds since start of day
// time | *time.Duration |
// timestamp | *int64 | milliseconds since Unix epoch
// timestamp | *time.Time |
// list, set | *slice, *array |
// map | *map[X]Y |
// uuid, timeuuid | *string | see UUID.String
// uuid, timeuuid | *[]byte | raw UUID bytes
// uuid, timeuuid | *gocql.UUID |
// timeuuid | *time.Time | timestamp of the UUID
// inet | *net.IP |
// inet | *string | IPv4 or IPv6 address string
// tuple | *slice, *array |
// tuple | *struct | struct fields are set in order of declaration
// user-defined types | gocql.UDTUnmarshaler | UnmarshalUDT is called
// user-defined types | *map[string]interface{} |
// user-defined types | *struct | cql tag is used to determine field name
// date | *time.Time | time of beginning of the day (in UTC)
// date | *string | formatted with 2006-01-02 format
// duration | *gocql.Duration |
func Unmarshal(info TypeInfo, data []byte, value interface{}) error { func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
if v, ok := value.(Unmarshaler); ok { if v, ok := value.(Unmarshaler); ok {
return v.UnmarshalCQL(info, data) return v.UnmarshalCQL(info, data)
@ -1690,6 +1774,8 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
return nil, nil return nil, nil
case UUID: case UUID:
return val.Bytes(), nil return val.Bytes(), nil
case [16]byte:
return val[:], nil
case []byte: case []byte:
if len(val) != 16 { if len(val) != 16 {
return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info) return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info)
@ -1711,7 +1797,7 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
} }
func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
if data == nil || len(data) == 0 { if len(data) == 0 {
switch v := value.(type) { switch v := value.(type) {
case *string: case *string:
*v = "" *v = ""
@ -1726,9 +1812,22 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
return nil return nil
} }
if len(data) != 16 {
return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long")
}
switch v := value.(type) {
case *[16]byte:
copy((*v)[:], data)
return nil
case *UUID:
copy((*v)[:], data)
return nil
}
u, err := UUIDFromBytes(data) u, err := UUIDFromBytes(data)
if err != nil { if err != nil {
return unmarshalErrorf("Unable to parse UUID: %s", err) return unmarshalErrorf("unable to parse UUID: %s", err)
} }
switch v := value.(type) { switch v := value.(type) {
@ -1738,9 +1837,6 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
case *[]byte: case *[]byte:
*v = u[:] *v = u[:]
return nil return nil
case *UUID:
*v = u
return nil
} }
return unmarshalErrorf("can not unmarshal X %s into %T", info, value) return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
} }
@ -1942,7 +2038,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
for i, elem := range tuple.Elems { for i, elem := range tuple.Elems {
// each element inside data is a [bytes] // each element inside data is a [bytes]
var p []byte var p []byte
if len(data) > 4 { if len(data) >= 4 {
p, data = readBytes(data) p, data = readBytes(data)
} }
err := Unmarshal(elem, p, v[i]) err := Unmarshal(elem, p, v[i])
@ -1971,7 +2067,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
for i, elem := range tuple.Elems { for i, elem := range tuple.Elems {
var p []byte var p []byte
if len(data) > 4 { if len(data) >= 4 {
p, data = readBytes(data) p, data = readBytes(data)
} }
@ -1982,7 +2078,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
switch rv.Field(i).Kind() { switch rv.Field(i).Kind() {
case reflect.Ptr: case reflect.Ptr:
if p != nil {
rv.Field(i).Set(reflect.ValueOf(v)) rv.Field(i).Set(reflect.ValueOf(v))
} else {
rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v)))
}
default: default:
rv.Field(i).Set(reflect.ValueOf(v).Elem()) rv.Field(i).Set(reflect.ValueOf(v).Elem())
} }
@ -2001,7 +2101,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
for i, elem := range tuple.Elems { for i, elem := range tuple.Elems {
var p []byte var p []byte
if len(data) > 4 { if len(data) >= 4 {
p, data = readBytes(data) p, data = readBytes(data)
} }
@ -2012,7 +2112,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
switch rv.Index(i).Kind() { switch rv.Index(i).Kind() {
case reflect.Ptr: case reflect.Ptr:
if p != nil {
rv.Index(i).Set(reflect.ValueOf(v)) rv.Index(i).Set(reflect.ValueOf(v))
} else {
rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v)))
}
default: default:
rv.Index(i).Set(reflect.ValueOf(v).Elem()) rv.Index(i).Set(reflect.ValueOf(v).Elem())
} }
@ -2050,7 +2154,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
case Marshaler: case Marshaler:
return v.MarshalCQL(info) return v.MarshalCQL(info)
case unsetColumn: case unsetColumn:
return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for user defined types") return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types")
case UDTMarshaler: case UDTMarshaler:
var buf []byte var buf []byte
for _, e := range udt.Elements { for _, e := range udt.Elements {

View File

@ -324,10 +324,10 @@ func compileMetadata(
keyspace.Functions[functions[i].Name] = &functions[i] keyspace.Functions[functions[i].Name] = &functions[i]
} }
keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates)) keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates))
for _, aggregate := range aggregates { for i, _ := range aggregates {
aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc] aggregates[i].FinalFunc = *keyspace.Functions[aggregates[i].finalFunc]
aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc] aggregates[i].StateFunc = *keyspace.Functions[aggregates[i].stateFunc]
keyspace.Aggregates[aggregate.Name] = &aggregate keyspace.Aggregates[aggregates[i].Name] = &aggregates[i]
} }
keyspace.Views = make(map[string]*ViewMetadata, len(views)) keyspace.Views = make(map[string]*ViewMetadata, len(views))
for i := range views { for i := range views {
@ -347,9 +347,9 @@ func compileMetadata(
keyspace.UserTypes[types[i].Name] = &types[i] keyspace.UserTypes[types[i].Name] = &types[i]
} }
keyspace.MaterializedViews = make(map[string]*MaterializedViewMetadata, len(materializedViews)) keyspace.MaterializedViews = make(map[string]*MaterializedViewMetadata, len(materializedViews))
for _, materializedView := range materializedViews { for i, _ := range materializedViews {
materializedView.BaseTable = keyspace.Tables[materializedView.baseTableName] materializedViews[i].BaseTable = keyspace.Tables[materializedViews[i].baseTableName]
keyspace.MaterializedViews[materializedView.Name] = &materializedView keyspace.MaterializedViews[materializedViews[i].Name] = &materializedViews[i]
} }
// add columns from the schema data // add columns from the schema data
@ -559,7 +559,7 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
iter.Scan(&keyspace.DurableWrites, &replication) iter.Scan(&keyspace.DurableWrites, &replication)
err := iter.Close() err := iter.Close()
if err != nil { if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err) return nil, fmt.Errorf("error querying keyspace schema: %v", err)
} }
keyspace.StrategyClass = replication["class"] keyspace.StrategyClass = replication["class"]
@ -585,13 +585,13 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON) iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
err := iter.Close() err := iter.Close()
if err != nil { if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err) return nil, fmt.Errorf("error querying keyspace schema: %v", err)
} }
err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions) err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"Invalid JSON value '%s' as strategy_options for in keyspace '%s': %v", "invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
strategyOptionsJSON, keyspace.Name, err, strategyOptionsJSON, keyspace.Name, err,
) )
} }
@ -703,7 +703,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
if err != nil { if err != nil {
iter.Close() iter.Close()
return nil, fmt.Errorf( return nil, fmt.Errorf(
"Invalid JSON value '%s' as key_aliases for in table '%s': %v", "invalid JSON value '%s' as key_aliases for in table '%s': %v",
keyAliasesJSON, table.Name, err, keyAliasesJSON, table.Name, err,
) )
} }
@ -716,7 +716,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
if err != nil { if err != nil {
iter.Close() iter.Close()
return nil, fmt.Errorf( return nil, fmt.Errorf(
"Invalid JSON value '%s' as column_aliases for in table '%s': %v", "invalid JSON value '%s' as column_aliases for in table '%s': %v",
columnAliasesJSON, table.Name, err, columnAliasesJSON, table.Name, err,
) )
} }
@ -728,7 +728,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
err := iter.Close() err := iter.Close()
if err != nil && err != ErrNotFound { if err != nil && err != ErrNotFound {
return nil, fmt.Errorf("Error querying table schema: %v", err) return nil, fmt.Errorf("error querying table schema: %v", err)
} }
return tables, nil return tables, nil
@ -777,7 +777,7 @@ func (s *Session) scanColumnMetadataV1(keyspace string) ([]ColumnMetadata, error
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options) err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v", "invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
indexOptionsJSON, indexOptionsJSON,
column.Name, column.Name,
column.Table, column.Table,
@ -837,7 +837,7 @@ func (s *Session) scanColumnMetadataV2(keyspace string) ([]ColumnMetadata, error
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options) err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v", "invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
indexOptionsJSON, indexOptionsJSON,
column.Name, column.Name,
column.Table, column.Table,
@ -915,7 +915,7 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata,
} }
if err != nil && err != ErrNotFound { if err != nil && err != ErrNotFound {
return nil, fmt.Errorf("Error querying column schema: %v", err) return nil, fmt.Errorf("error querying column schema: %v", err)
} }
return columns, nil return columns, nil

View File

@ -1,9 +1,12 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved. // Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//This file will be the future home for more policies
package gocql package gocql
//This file will be the future home for more policies
import ( import (
"context" "context"
"errors" "errors"
@ -37,12 +40,6 @@ func (c *cowHostList) get() []*HostInfo {
return *l return *l
} }
func (c *cowHostList) set(list []*HostInfo) {
c.mu.Lock()
c.list.Store(&list)
c.mu.Unlock()
}
// add will add a host if it not already in the list // add will add a host if it not already in the list
func (c *cowHostList) add(host *HostInfo) bool { func (c *cowHostList) add(host *HostInfo) bool {
c.mu.Lock() c.mu.Lock()
@ -68,33 +65,6 @@ func (c *cowHostList) add(host *HostInfo) bool {
return true return true
} }
func (c *cowHostList) update(host *HostInfo) {
c.mu.Lock()
l := c.get()
if len(l) == 0 {
c.mu.Unlock()
return
}
found := false
newL := make([]*HostInfo, len(l))
for i := range l {
if host.Equal(l[i]) {
newL[i] = host
found = true
} else {
newL[i] = l[i]
}
}
if found {
c.list.Store(&newL)
}
c.mu.Unlock()
}
func (c *cowHostList) remove(ip net.IP) bool { func (c *cowHostList) remove(ip net.IP) bool {
c.mu.Lock() c.mu.Lock()
l := c.get() l := c.get()
@ -304,7 +274,10 @@ type HostSelectionPolicy interface {
KeyspaceChanged(KeyspaceUpdateEvent) KeyspaceChanged(KeyspaceUpdateEvent)
Init(*Session) Init(*Session)
IsLocal(host *HostInfo) bool IsLocal(host *HostInfo) bool
//Pick returns an iteration function over selected hosts // Pick returns an iteration function over selected hosts.
// Multiple attempts of a single query execution won't call the returned NextHost function concurrently,
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
// a different instance of NextHost.
Pick(ExecutableQuery) NextHost Pick(ExecutableQuery) NextHost
} }
@ -880,6 +853,51 @@ func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get()) return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
} }
// ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After
// each host connects during session initialization, the Ready method will be
// called. If you only need a single Host to be up you can wrap a
// HostSelectionPolicy policy with SingleHostReadyPolicy.
type ReadyPolicy interface {
Ready() bool
}
// SingleHostReadyPolicy wraps a HostSelectionPolicy and returns Ready after a
// single host has been added via HostUp
func SingleHostReadyPolicy(p HostSelectionPolicy) *singleHostReadyPolicy {
return &singleHostReadyPolicy{
HostSelectionPolicy: p,
}
}
type singleHostReadyPolicy struct {
HostSelectionPolicy
ready bool
readyMux sync.Mutex
}
func (s *singleHostReadyPolicy) HostUp(host *HostInfo) {
s.HostSelectionPolicy.HostUp(host)
s.readyMux.Lock()
s.ready = true
s.readyMux.Unlock()
}
func (s *singleHostReadyPolicy) Ready() bool {
s.readyMux.Lock()
ready := s.ready
s.readyMux.Unlock()
if !ready {
return false
}
// in case the wrapped policy is also a ReadyPolicy, defer to that
if rdy, ok := s.HostSelectionPolicy.(ReadyPolicy); ok {
return rdy.Ready()
}
return true
}
// ConvictionPolicy interface is used by gocql to determine if a host should be // ConvictionPolicy interface is used by gocql to determine if a host should be
// marked as DOWN based on the error and host info // marked as DOWN based on the error and host info
type ConvictionPolicy interface { type ConvictionPolicy interface {

View File

@ -14,18 +14,6 @@ type preparedLRU struct {
lru *lru.Cache lru *lru.Cache
} }
// Max adjusts the maximum size of the cache and cleans up the oldest records if
// the new max is lower than the previous value. Not concurrency safe.
func (p *preparedLRU) max(max int) {
p.mu.Lock()
defer p.mu.Unlock()
for p.lru.Len() > max {
p.lru.RemoveOldest()
}
p.lru.MaxEntries = max
}
func (p *preparedLRU) clear() { func (p *preparedLRU) clear() {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()

View File

@ -2,6 +2,7 @@ package gocql
import ( import (
"context" "context"
"sync"
"time" "time"
) )
@ -34,14 +35,15 @@ func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, c
return iter return iter
} }
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, results chan *Iter) *Iter { func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy,
hostIter NextHost, results chan *Iter) *Iter {
ticker := time.NewTicker(sp.Delay()) ticker := time.NewTicker(sp.Delay())
defer ticker.Stop() defer ticker.Stop()
for i := 0; i < sp.Attempts(); i++ { for i := 0; i < sp.Attempts(); i++ {
select { select {
case <-ticker.C: case <-ticker.C:
go q.run(ctx, qry, results) go q.run(ctx, qry, hostIter, results)
case <-ctx.Done(): case <-ctx.Done():
return &Iter{err: ctx.Err()} return &Iter{err: ctx.Err()}
case iter := <-results: case iter := <-results:
@ -53,11 +55,23 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
} }
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)
// check if the query is not marked as idempotent, if // check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative // it is, we force the policy to NonSpeculative
sp := qry.speculativeExecutionPolicy() sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 { if !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry), nil return q.do(qry.Context(), qry, hostIter), nil
}
// When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below.
// To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it.
var mu sync.Mutex
origHostIter := hostIter
hostIter = func() SelectedHost {
mu.Lock()
defer mu.Unlock()
return origHostIter()
} }
ctx, cancel := context.WithCancel(qry.Context()) ctx, cancel := context.WithCancel(qry.Context())
@ -66,12 +80,12 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
results := make(chan *Iter, 1) results := make(chan *Iter, 1)
// Launch the main execution // Launch the main execution
go q.run(ctx, qry, results) go q.run(ctx, qry, hostIter, results)
// The speculative executions are launched _in addition_ to the main // The speculative executions are launched _in addition_ to the main
// execution, on a timer. So Speculation{2} would make 3 executions running // execution, on a timer. So Speculation{2} would make 3 executions running
// in total. // in total.
if iter := q.speculate(ctx, qry, sp, results); iter != nil { if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil {
return iter, nil return iter, nil
} }
@ -83,8 +97,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
} }
} }
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter { func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
hostIter := q.policy.Pick(qry)
selectedHost := hostIter() selectedHost := hostIter()
rt := qry.retryPolicy() rt := qry.retryPolicy()
@ -153,9 +166,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
return &Iter{err: ErrNoConnections} return &Iter{err: ErrNoConnections}
} }
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan<- *Iter) { func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) {
select { select {
case results <- q.do(ctx, qry): case results <- q.do(ctx, qry, hostIter):
case <-ctx.Done(): case <-ctx.Done():
} }
} }

View File

@ -63,29 +63,6 @@ func (r *ring) currentHosts() map[string]*HostInfo {
return hosts return hosts
} }
func (r *ring) addHost(host *HostInfo) bool {
// TODO(zariel): key all host info by HostID instead of
// ip addresses
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
ip := host.ConnectAddress().String()
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
_, ok := r.hosts[ip]
if !ok {
r.hostList = append(r.hostList, host)
}
r.hosts[ip] = host
r.mu.Unlock()
return ok
}
func (r *ring) addOrUpdate(host *HostInfo) *HostInfo { func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
if existingHost, ok := r.addHostIfMissing(host); ok { if existingHost, ok := r.addHostIfMissing(host); ok {
existingHost.update(host) existingHost.update(host)

View File

@ -27,7 +27,7 @@ import (
// scenario is to have one global session object to interact with the // scenario is to have one global session object to interact with the
// whole Cassandra cluster. // whole Cassandra cluster.
// //
// This type extends the Node interface by adding a convinient query builder // This type extends the Node interface by adding a convenient query builder
// and automatically sets a default consistency level on all operations // and automatically sets a default consistency level on all operations
// that do not have a consistency level set. // that do not have a consistency level set.
type Session struct { type Session struct {
@ -62,7 +62,6 @@ type Session struct {
schemaEvents *eventDebouncer schemaEvents *eventDebouncer
// ring metadata // ring metadata
hosts []HostInfo
useSystemSchema bool useSystemSchema bool
hasAggregatesAndFunctions bool hasAggregatesAndFunctions bool
@ -227,18 +226,44 @@ func (s *Session) init() error {
} }
hosts = hosts[:0] hosts = hosts[:0]
// each host will increment left and decrement it after connecting and once
// there's none left, we'll close hostCh
var left int64
// we will receive up to len(hostMap) of messages so create a buffer so we
// don't end up stuck in a goroutine if we stopped listening
connectedCh := make(chan struct{}, len(hostMap))
// we add one here because we don't want to end up closing hostCh until we're
// done looping and the decerement code might be reached before we've looped
// again
atomic.AddInt64(&left, 1)
for _, host := range hostMap { for _, host := range hostMap {
host = s.ring.addOrUpdate(host) host := s.ring.addOrUpdate(host)
if s.cfg.filterHost(host) { if s.cfg.filterHost(host) {
continue continue
} }
host.setState(NodeUp) atomic.AddInt64(&left, 1)
go func() {
s.pool.addHost(host) s.pool.addHost(host)
connectedCh <- struct{}{}
// if there are no hosts left, then close the hostCh to unblock the loop
// below if its still waiting
if atomic.AddInt64(&left, -1) == 0 {
close(connectedCh)
}
}()
hosts = append(hosts, host) hosts = append(hosts, host)
} }
// once we're done looping we subtract the one we initially added and check
// to see if we should close
if atomic.AddInt64(&left, -1) == 0 {
close(connectedCh)
}
// before waiting for them to connect, add them all to the policy so we can
// utilize efficiencies by calling AddHosts if the policy supports it
type bulkAddHosts interface { type bulkAddHosts interface {
AddHosts([]*HostInfo) AddHosts([]*HostInfo)
} }
@ -250,6 +275,15 @@ func (s *Session) init() error {
} }
} }
readyPolicy, _ := s.policy.(ReadyPolicy)
// now loop over connectedCh until it's closed (meaning we've connected to all)
// or until the policy says we're ready
for range connectedCh {
if readyPolicy != nil && readyPolicy.Ready() {
break
}
}
// TODO(zariel): we probably dont need this any more as we verify that we // TODO(zariel): we probably dont need this any more as we verify that we
// can connect to one of the endpoints supplied by using the control conn. // can connect to one of the endpoints supplied by using the control conn.
// See if there are any connections in the pool // See if there are any connections in the pool
@ -320,7 +354,8 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
if h.IsUp() { if h.IsUp() {
continue continue
} }
s.handleNodeUp(h.ConnectAddress(), h.Port(), true) // we let the pool call handleNodeUp to change the host state
s.pool.addHost(h)
} }
case <-s.ctx.Done(): case <-s.ctx.Done():
return return
@ -806,6 +841,7 @@ type Query struct {
trace Tracer trace Tracer
observer QueryObserver observer QueryObserver
session *Session session *Session
conn *Conn
rt RetryPolicy rt RetryPolicy
spec SpeculativeExecutionPolicy spec SpeculativeExecutionPolicy
binding func(q *QueryInfo) ([]interface{}, error) binding func(q *QueryInfo) ([]interface{}, error)
@ -1094,12 +1130,17 @@ func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy {
return q.spec return q.spec
} }
// IsIdempotent returns whether the query is marked as idempotent.
// Non-idempotent query won't be retried.
// See "Retries and speculative execution" in package docs for more details.
func (q *Query) IsIdempotent() bool { func (q *Query) IsIdempotent() bool {
return q.idempotent return q.idempotent
} }
// Idempotent marks the query as being idempotent or not depending on // Idempotent marks the query as being idempotent or not depending on
// the value. // the value.
// Non-idempotent query won't be retried.
// See "Retries and speculative execution" in package docs for more details.
func (q *Query) Idempotent(value bool) *Query { func (q *Query) Idempotent(value bool) *Query {
q.idempotent = value q.idempotent = value
return q return q
@ -1164,6 +1205,11 @@ func (q *Query) Iter() *Iter {
if isUseStatement(q.stmt) { if isUseStatement(q.stmt) {
return &Iter{err: ErrUseStmt} return &Iter{err: ErrUseStmt}
} }
// if the query was specifically run on a connection then re-use that
// connection when fetching the next results
if q.conn != nil {
return q.conn.executeQuery(q.Context(), q)
}
return q.session.executeQuery(q) return q.session.executeQuery(q)
} }
@ -1195,6 +1241,10 @@ func (q *Query) Scan(dest ...interface{}) error {
// statement containing an IF clause). If the transaction fails because // statement containing an IF clause). If the transaction fails because
// the existing values did not match, the previous values will be stored // the existing values did not match, the previous values will be stored
// in dest. // in dest.
//
// As for INSERT .. IF NOT EXISTS, previous values will be returned as if
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
// column mismatching. Use MapScanCAS to capture them safely.
func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) { func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
q.disableSkipMetadata = true q.disableSkipMetadata = true
iter := q.Iter() iter := q.Iter()
@ -1423,7 +1473,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
} }
if iter.next != nil && iter.pos >= iter.next.pos { if iter.next != nil && iter.pos >= iter.next.pos {
go iter.next.fetch() iter.next.fetchAsync()
} }
// currently only support scanning into an expand tuple, such that its the same // currently only support scanning into an expand tuple, such that its the same
@ -1517,16 +1567,31 @@ func (iter *Iter) NumRows() int {
return iter.numRows return iter.numRows
} }
// nextIter holds state for fetching a single page in an iterator.
// single page might be attempted multiple times due to retries.
type nextIter struct { type nextIter struct {
qry *Query qry *Query
pos int pos int
oncea sync.Once
once sync.Once once sync.Once
next *Iter next *Iter
} }
func (n *nextIter) fetchAsync() {
n.oncea.Do(func() {
go n.fetch()
})
}
func (n *nextIter) fetch() *Iter { func (n *nextIter) fetch() *Iter {
n.once.Do(func() { n.once.Do(func() {
// if the query was specifically run on a connection then re-use that
// connection when fetching the next results
if n.qry.conn != nil {
n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry)
} else {
n.next = n.qry.session.executeQuery(n.qry) n.next = n.qry.session.executeQuery(n.qry)
}
}) })
return n.next return n.next
} }
@ -1536,7 +1601,6 @@ type Batch struct {
Entries []BatchEntry Entries []BatchEntry
Cons Consistency Cons Consistency
routingKey []byte routingKey []byte
routingKeyBuffer []byte
CustomPayload map[string][]byte CustomPayload map[string][]byte
rt RetryPolicy rt RetryPolicy
spec SpeculativeExecutionPolicy spec SpeculativeExecutionPolicy
@ -1733,7 +1797,7 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch {
func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
latency := end.Sub(start) latency := end.Sub(start)
_, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil) attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil)
if b.observer == nil { if b.observer == nil {
return return
@ -1753,6 +1817,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host
Host: host, Host: host,
Metrics: metricsForHost, Metrics: metricsForHost,
Err: iter.err, Err: iter.err,
Attempt: attempt,
}) })
} }
@ -1968,7 +2033,6 @@ type ObservedQuery struct {
Err error Err error
// Attempt is the index of attempt at executing this query. // Attempt is the index of attempt at executing this query.
// An attempt might be either retry or fetching next page of a query.
// The first attempt is number zero and any retries have non-zero attempt number. // The first attempt is number zero and any retries have non-zero attempt number.
Attempt int Attempt int
} }
@ -1999,6 +2063,10 @@ type ObservedBatch struct {
// The metrics per this host // The metrics per this host
Metrics *hostMetrics Metrics *hostMetrics
// Attempt is the index of attempt at executing this query.
// The first attempt is number zero and any retries have non-zero attempt number.
Attempt int
} }
// BatchObserver is the interface implemented by batch observers / stat collectors. // BatchObserver is the interface implemented by batch observers / stat collectors.

View File

@ -153,7 +153,7 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
} else if strings.HasSuffix(partitioner, "RandomPartitioner") { } else if strings.HasSuffix(partitioner, "RandomPartitioner") {
tokenRing.partitioner = randomPartitioner{} tokenRing.partitioner = randomPartitioner{}
} else { } else {
return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner) return nil, fmt.Errorf("unsupported partitioner '%s'", partitioner)
} }
for _, host := range hosts { for _, host := range hosts {

View File

@ -46,32 +46,35 @@ type placementStrategy interface {
replicationFactor(dc string) int replicationFactor(dc string) int
} }
func getReplicationFactorFromOpts(keyspace string, val interface{}) int { func getReplicationFactorFromOpts(val interface{}) (int, error) {
// TODO: dont really want to panic here, but is better
// than spamming
switch v := val.(type) { switch v := val.(type) {
case int: case int:
if v <= 0 { if v < 0 {
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace)) return 0, fmt.Errorf("invalid replication_factor %d", v)
} }
return v return v, nil
case string: case string:
n, err := strconv.Atoi(v) n, err := strconv.Atoi(v)
if err != nil { if err != nil {
panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err)) return 0, fmt.Errorf("invalid replication_factor %q: %v", v, err)
} else if n <= 0 { } else if n < 0 {
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", n, keyspace)) return 0, fmt.Errorf("invalid replication_factor %d", n)
} }
return n return n, nil
default: default:
panic(fmt.Sprintf("unkown replication_factor type %T", v)) return 0, fmt.Errorf("unknown replication_factor type %T", v)
} }
} }
func getStrategy(ks *KeyspaceMetadata) placementStrategy { func getStrategy(ks *KeyspaceMetadata) placementStrategy {
switch { switch {
case strings.Contains(ks.StrategyClass, "SimpleStrategy"): case strings.Contains(ks.StrategyClass, "SimpleStrategy"):
return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])} rf, err := getReplicationFactorFromOpts(ks.StrategyOptions["replication_factor"])
if err != nil {
Logger.Printf("parse rf for keyspace %q: %v", ks.Name, err)
return nil
}
return &simpleStrategy{rf: rf}
case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"): case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"):
dcs := make(map[string]int) dcs := make(map[string]int)
for dc, rf := range ks.StrategyOptions { for dc, rf := range ks.StrategyOptions {
@ -79,14 +82,21 @@ func getStrategy(ks *KeyspaceMetadata) placementStrategy {
continue continue
} }
dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf) rf, err := getReplicationFactorFromOpts(rf)
if err != nil {
Logger.Println("parse rf for keyspace %q, dc %q: %v", err)
// skip DC if the rf is invalid/unsupported, so that we can at least work with other working DCs.
continue
}
dcs[dc] = rf
} }
return &networkTopology{dcs: dcs} return &networkTopology{dcs: dcs}
case strings.Contains(ks.StrategyClass, "LocalStrategy"): case strings.Contains(ks.StrategyClass, "LocalStrategy"):
return nil return nil
default: default:
// TODO: handle unknown replicas and just return the primary host for a token Logger.Printf("parse rf for keyspace %q: unsupported strategy class: %v", ks.StrategyClass)
panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass)) return nil
} }
} }

View File

@ -2,11 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package gocql
// The uuid package can be used to generate and parse universally unique // The uuid package can be used to generate and parse universally unique
// identifiers, a standardized format in the form of a 128 bit number. // identifiers, a standardized format in the form of a 128 bit number.
// //
// http://tools.ietf.org/html/rfc4122 // http://tools.ietf.org/html/rfc4122
package gocql
import ( import (
"crypto/rand" "crypto/rand"

2
vendor/modules.txt vendored
View File

@ -402,7 +402,7 @@ github.com/go-stack/stack
github.com/go-test/deep github.com/go-test/deep
# github.com/go-yaml/yaml v2.1.0+incompatible # github.com/go-yaml/yaml v2.1.0+incompatible
github.com/go-yaml/yaml github.com/go-yaml/yaml
# github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e # github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
## explicit ## explicit
github.com/gocql/gocql github.com/gocql/gocql
github.com/gocql/gocql/internal/lru github.com/gocql/gocql/internal/lru