Validate hostnames when using TLS in Cassandra (#11365)

This commit is contained in:
Michael Golowka 2021-04-16 15:52:35 -06:00 committed by GitHub
parent 541ae8636c
commit 4279bc8b34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 2386 additions and 526 deletions

View File

@ -20,13 +20,18 @@ func TestBackend_basic(t *testing.T) {
t.Fatal(err)
}
cleanup, hostname := cassandra.PrepareTestContainer(t, "latest")
copyFromTo := map[string]string{
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.CopyFromTo(copyFromTo),
)
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, hostname),
testAccStepConfig(t, host.ConnectionURL()),
testAccStepRole(t),
testAccStepReadCreds(t, "test"),
},
@ -41,13 +46,17 @@ func TestBackend_roleCrud(t *testing.T) {
t.Fatal(err)
}
cleanup, hostname := cassandra.PrepareTestContainer(t, "latest")
copyFromTo := map[string]string{
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.CopyFromTo(copyFromTo))
defer cleanup()
logicaltest.Test(t, logicaltest.TestCase{
LogicalBackend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, hostname),
testAccStepConfig(t, host.ConnectionURL()),
testAccStepRole(t),
testAccStepRoleWithOptions(t),
testAccStepReadRole(t, "test", testRole),

3
changelog/11365.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:bug
secrets/database/cassandra: Fixed issue where hostnames were not being validated when using TLS
```

2
go.mod
View File

@ -49,7 +49,7 @@ require (
github.com/go-ole/go-ole v1.2.4 // indirect
github.com/go-sql-driver/mysql v1.5.0
github.com/go-test/deep v1.0.7
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
github.com/golang/protobuf v1.4.2
github.com/google/go-github v17.0.0+incompatible
github.com/google/go-metrics-stackdriver v0.2.0

6
go.sum
View File

@ -442,8 +442,8 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe
github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ=
github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0=
github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ=
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309 h1:8MHuCGYDXh0skFrLumkCMlt9C29hxhqNx39+Haemeqw=
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
@ -1084,6 +1084,7 @@ github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6So
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.6.2 h1:aIihoIOHCiLZHxyoNQ+ABL4NKhFTgKLBdMLyEAh98m0=
github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rs/zerolog v1.4.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
@ -1631,6 +1632,7 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=

View File

@ -2,9 +2,10 @@ package cassandra
import (
"context"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"testing"
"time"
@ -12,33 +13,75 @@ import (
"github.com/hashicorp/vault/helper/testhelpers/docker"
)
func PrepareTestContainer(t *testing.T, version string) (func(), string) {
type containerConfig struct {
version string
copyFromTo map[string]string
sslOpts *gocql.SslOptions
}
type ContainerOpt func(*containerConfig)
func Version(version string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.version = version
}
}
func CopyFromTo(copyFromTo map[string]string) ContainerOpt {
return func(cfg *containerConfig) {
cfg.copyFromTo = copyFromTo
}
}
func SslOpts(sslOpts *gocql.SslOptions) ContainerOpt {
return func(cfg *containerConfig) {
cfg.sslOpts = sslOpts
}
}
type Host struct {
Name string
Port string
}
func (h Host) ConnectionURL() string {
return net.JoinHostPort(h.Name, h.Port)
}
func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
t.Helper()
if os.Getenv("CASSANDRA_HOSTS") != "" {
return func() {}, os.Getenv("CASSANDRA_HOSTS")
host, port, err := net.SplitHostPort(os.Getenv("CASSANDRA_HOSTS"))
if err != nil {
t.Fatalf("Failed to split host & port from CASSANDRA_HOSTS (%s): %s", os.Getenv("CASSANDRA_HOSTS"), err)
}
h := Host{
Name: host,
Port: port,
}
return h, func() {}
}
if version == "" {
version = "3.11"
containerCfg := &containerConfig{
version: "3.11",
}
var copyFromTo map[string]string
cwd, _ := os.Getwd()
fixturePath := fmt.Sprintf("%s/test-fixtures/", cwd)
if _, err := os.Stat(fixturePath); err != nil {
if !errors.Is(err, os.ErrNotExist) {
// If it doesn't exist, no biggie
t.Fatal(err)
}
} else {
copyFromTo = map[string]string{
fixturePath: "/etc/cassandra",
for _, opt := range opts {
opt(containerCfg)
}
copyFromTo := map[string]string{}
for from, to := range containerCfg.copyFromTo {
absFrom, err := filepath.Abs(from)
if err != nil {
t.Fatalf("Unable to get absolute path for file %s", from)
}
copyFromTo[absFrom] = to
}
runner, err := docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "cassandra",
ImageTag: version,
ImageTag: containerCfg.version,
Ports: []string{"9042/tcp"},
CopyFromTo: copyFromTo,
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
@ -58,6 +101,8 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
clusterConfig.ProtoVersion = 4
clusterConfig.Port = port
clusterConfig.SslOpts = containerCfg.sslOpts
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("error creating session: %s", err)
@ -65,19 +110,19 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
defer session.Close()
// Create keyspace
q := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`)
if err := q.Exec(); err != nil {
query := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`)
if err := query.Exec(); err != nil {
t.Fatalf("could not create cassandra keyspace: %v", err)
}
// Create table
q = session.Query(`CREATE TABLE "vault"."entries" (
query = session.Query(`CREATE TABLE "vault"."entries" (
bucket text,
key text,
value blob,
PRIMARY KEY (bucket, key)
) WITH CLUSTERING ORDER BY (key ASC);`)
if err := q.Exec(); err != nil {
if err := query.Exec(); err != nil {
t.Fatalf("could not create cassandra table: %v", err)
}
return cfg, nil
@ -85,5 +130,14 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
if err != nil {
t.Fatalf("Could not start docker cassandra: %s", err)
}
return svc.Cleanup, svc.Config.Address()
host, port, err := net.SplitHostPort(svc.Config.Address())
if err != nil {
t.Fatalf("Failed to split host & port from address (%s): %s", svc.Config.Address(), err)
}
h := Host{
Name: host,
Port: port,
}
return h, svc.Cleanup
}

View File

@ -10,11 +10,10 @@ import (
"strings"
"time"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
metrics "github.com/armon/go-metrics"
"github.com/gocql/gocql"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/physical"
)
@ -180,20 +179,18 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
if err != nil {
return err
}
} else {
if pemJSONPath, ok := conf["pem_json_file"]; ok {
pemJSONData, err := ioutil.ReadFile(pemJSONPath)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err)
}
pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
if err != nil {
return err
}
tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
if err != nil {
return err
}
} else if pemJSONPath, ok := conf["pem_json_file"]; ok {
pemJSONData, err := ioutil.ReadFile(pemJSONPath)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err)
}
pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
if err != nil {
return err
}
tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
if err != nil {
return err
}
}
@ -225,7 +222,8 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
}
cluster.SslOpts = &gocql.SslOptions{
Config: tlsConfig.Clone(),
Config: tlsConfig,
EnableHostVerification: !tlsConfig.InsecureSkipVerify,
}
return nil
}

View File

@ -19,13 +19,13 @@ func TestCassandraBackend(t *testing.T) {
t.Skip("skipping race test in CI pending https://github.com/gocql/gocql/pull/1474")
}
cleanup, hosts := cassandra.PrepareTestContainer(t, "")
host, cleanup := cassandra.PrepareTestContainer(t)
defer cleanup()
// Run vault tests
logger := logging.NewVaultLogger(log.Debug)
b, err := NewCassandraBackend(map[string]string{
"hosts": hosts,
"hosts": host.ConnectionURL(),
"protocol_version": "3",
}, logger)
if err != nil {

View File

@ -3,7 +3,6 @@ package cassandra
import (
"context"
"reflect"
"strings"
"testing"
"time"
@ -17,14 +16,16 @@ import (
)
func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()) {
cleanup, connURL := cassandra.PrepareTestContainer(t, "latest")
pieces := strings.Split(connURL, ":")
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.Version("latest"),
cassandra.CopyFromTo(insecureFileMounts),
)
db := new()
initReq := dbplugin.InitializeRequest{
Config: map[string]interface{}{
"hosts": connURL,
"port": pieces[1],
"hosts": host.ConnectionURL(),
"port": host.Port,
"username": "cassandra",
"password": "cassandra",
"protocol_version": protocolVersion,
@ -34,8 +35,8 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
}
expectedConfig := map[string]interface{}{
"hosts": connURL,
"port": pieces[1],
"hosts": host.ConnectionURL(),
"port": host.Port,
"username": "cassandra",
"password": "cassandra",
"protocol_version": protocolVersion,
@ -53,7 +54,7 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
return db, cleanup
}
func TestCassandra_Initialize(t *testing.T) {
func TestInitialize(t *testing.T) {
db, cleanup := getCassandra(t, 4)
defer cleanup()
@ -66,7 +67,7 @@ func TestCassandra_Initialize(t *testing.T) {
defer cleanup()
}
func TestCassandra_CreateUser(t *testing.T) {
func TestCreateUser(t *testing.T) {
type testCase struct {
// Config will have the hosts & port added to it during the test
config map[string]interface{}
@ -126,15 +127,17 @@ func TestCassandra_CreateUser(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
cleanup, connURL := cassandra.PrepareTestContainer(t, "latest")
pieces := strings.Split(connURL, ":")
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.Version("latest"),
cassandra.CopyFromTo(insecureFileMounts),
)
defer cleanup()
db := new()
config := test.config
config["hosts"] = connURL
config["port"] = pieces[1]
config["hosts"] = host.ConnectionURL()
config["port"] = host.Port
initReq := dbplugin.InitializeRequest{
Config: config,
@ -162,7 +165,7 @@ func TestCassandra_CreateUser(t *testing.T) {
}
}
func TestMyCassandra_UpdateUserPassword(t *testing.T) {
func TestUpdateUserPassword(t *testing.T) {
db, cleanup := getCassandra(t, 4)
defer cleanup()
@ -198,7 +201,7 @@ func TestMyCassandra_UpdateUserPassword(t *testing.T) {
assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
}
func TestCassandra_DeleteUser(t *testing.T) {
func TestDeleteUser(t *testing.T) {
db, cleanup := getCassandra(t, 4)
defer cleanup()

View File

@ -8,14 +8,13 @@ import (
"sync"
"time"
"github.com/gocql/gocql"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/certutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil"
"github.com/gocql/gocql"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/mitchellh/mapstructure"
)
@ -40,9 +39,7 @@ type cassandraConnectionProducer struct {
connectTimeout time.Duration
socketKeepAlive time.Duration
certificate string
privateKey string
issuingCA string
certBundle *certutil.CertBundle
rawConfig map[string]interface{}
Initialized bool
@ -99,9 +96,7 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
}
c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey
c.issuingCA = certBundle.IssuingCA
c.certBundle = certBundle
c.TLS = true
case len(c.PemBundle) != 0:
@ -113,9 +108,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
if err != nil {
return fmt.Errorf("error marshaling PEM information: %w", err)
}
c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey
c.issuingCA = certBundle.IssuingCA
c.certBundle = certBundle
c.TLS = true
}
if c.InsecureTLS {
c.TLS = true
}
@ -185,49 +182,13 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
clusterConfig.Timeout = c.connectTimeout
clusterConfig.SocketKeepalive = c.socketKeepAlive
if c.TLS {
var tlsConfig *tls.Config
if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
if len(c.certificate) > 0 && len(c.privateKey) == 0 {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}
certBundle := &certutil.CertBundle{}
if len(c.certificate) > 0 {
certBundle.Certificate = c.certificate
certBundle.PrivateKey = c.privateKey
}
if len(c.issuingCA) > 0 {
certBundle.IssuingCA = c.issuingCA
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = c.InsecureTLS
if c.TLSMinVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
}
clusterConfig.SslOpts = &gocql.SslOptions{
Config: tlsConfig,
sslOpts, err := getSslOpts(c.certBundle, c.TLSMinVersion, c.InsecureTLS)
if err != nil {
return nil, err
}
clusterConfig.SslOpts = sslOpts
}
if c.LocalDatacenter != "" {
@ -269,6 +230,48 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
return session, nil
}
func getSslOpts(certBundle *certutil.CertBundle, minTLSVersion string, insecureSkipVerify bool) (*gocql.SslOptions, error) {
tlsConfig := &tls.Config{}
if certBundle != nil {
if certBundle.Certificate == "" && certBundle.PrivateKey != "" {
return nil, fmt.Errorf("found private key for TLS authentication but no certificate")
}
if certBundle.Certificate != "" && certBundle.PrivateKey == "" {
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
}
}
tlsConfig.InsecureSkipVerify = insecureSkipVerify
if minTLSVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[minTLSVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
opts := &gocql.SslOptions{
Config: tlsConfig,
EnableHostVerification: !insecureSkipVerify,
}
return opts, nil
}
func (c *cassandraConnectionProducer) secretValues() map[string]string {
return map[string]string{
c.Password: "[password]",

View File

@ -0,0 +1,95 @@
package cassandra
import (
"context"
"crypto/tls"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/testhelpers/cassandra"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
)
var (
insecureFileMounts = map[string]string{
"test-fixtures/no_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
}
secureFileMounts = map[string]string{
"test-fixtures/with_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
"test-fixtures/with_tls/keystore.jks": "/etc/cassandra/keystore.jks",
"test-fixtures/with_tls/.cassandra": "/root/.cassandra/",
}
)
func TestTLSConnection(t *testing.T) {
type testCase struct {
config map[string]interface{}
expectErr bool
}
tests := map[string]testCase{
"tls not specified": {
config: map[string]interface{}{},
expectErr: true,
},
"unrecognized certificate": {
config: map[string]interface{}{
"tls": "true",
},
expectErr: true,
},
"insecure TLS": {
config: map[string]interface{}{
"tls": "true",
"insecure_tls": true,
},
expectErr: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
host, cleanup := cassandra.PrepareTestContainer(t,
cassandra.Version("3.11.9"),
cassandra.CopyFromTo(secureFileMounts),
cassandra.SslOpts(&gocql.SslOptions{
Config: &tls.Config{InsecureSkipVerify: true},
EnableHostVerification: false,
}),
)
defer cleanup()
// Set values that we don't know until the cassandra container is started
config := map[string]interface{}{
"hosts": host.ConnectionURL(),
"port": host.Port,
"username": "cassandra",
"password": "cassandra",
"protocol_version": "3",
"connect_timeout": "20s",
}
// Then add any values specified in the test config. Generally for these tests they shouldn't overlap
for k, v := range test.config {
config[k] = v
}
db := new()
initReq := dbplugin.InitializeRequest{
Config: config,
VerifyConnection: true,
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := db.Initialize(ctx, initReq)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}

View File

@ -0,0 +1,3 @@
[ssl]
validate = false
version = SSLv23

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,46 @@
#!/bin/sh
################################################################
# Usage: ./gencert.sh
#
# Generates a keystore.jks file that can be used with a
# Cassandra server for TLS connections. This does not update
# a cassandra config file.
################################################################
set -e
KEYFILE="key.pem"
CERTFILE="cert.pem"
PKCSFILE="keystore.p12"
JKSFILE="keystore.jks"
HOST="127.0.0.1"
NAME="cassandra"
ALIAS="cassandra"
PASSWORD="cassandra"
echo "# Generating certificate keypair..."
go run /usr/local/go/src/crypto/tls/generate_cert.go --host=${HOST}
echo "# Creating keystore..."
openssl pkcs12 -export -in ${CERTFILE} -inkey ${KEYFILE} -name ${NAME} -password pass:${PASSWORD} > ${PKCSFILE}
echo "# Creating Java key store"
if [ -e "${JKSFILE}" ]; then
echo "# Removing old key store"
rm ${JKSFILE}
fi
set +e
keytool -importkeystore \
-srckeystore ${PKCSFILE} \
-srcstoretype PKCS12 \
-srcstorepass ${PASSWORD} \
-destkeystore ${JKSFILE} \
-deststorepass ${PASSWORD} \
-destkeypass ${PASSWORD} \
-alias ${ALIAS}
echo "# Removing intermediate files"
rm ${KEYFILE} ${CERTFILE} ${PKCSFILE}

Binary file not shown.

View File

@ -31,8 +31,10 @@ env:
AUTH=false
go:
- 1.13.x
- 1.14.x
- 1.15.x
- 1.16.x
go_import_path: github.com/gocql/gocql
install:
- ./install_test_deps.sh $TRAVIS_REPO_SLUG

View File

@ -115,3 +115,8 @@ Pavel Buchinchik <p.buchinchik@gmail.com>
Rintaro Okamura <rintaro.okamura@gmail.com>
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
Jorge Bay <jorgebg@apache.org>
Dmitriy Kozlov <hummerd@mail.ru>
Alexey Romanovsky <alexus1024+gocql@gmail.com>
Jaume Marhuenda Beltran <jaumemarhuenda@gmail.com>
Piotr Dulikowski <piodul@scylladb.com>
Árni Dagur <arni@dagur.eu>

View File

@ -19,8 +19,8 @@ The following matrix shows the versions of Go and Cassandra that are tested with
Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
-------------| -------| ------| ---------
1.13 | yes | yes | yes
1.14 | yes | yes | yes
1.15 | yes | yes | yes
1.16 | yes | yes | yes
Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.
@ -114,73 +114,7 @@ statement.
Example
-------
```go
/* Before you execute the program, Launch `cqlsh` and execute:
create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id));
create index on example.tweet(timeline);
*/
package main
import (
"fmt"
"log"
"github.com/gocql/gocql"
)
func main() {
// connect to the cluster
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Keyspace = "example"
cluster.Consistency = gocql.Quorum
session, _ := cluster.CreateSession()
defer session.Close()
// insert a tweet
if err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
"me", gocql.TimeUUID(), "hello world").Exec(); err != nil {
log.Fatal(err)
}
var id gocql.UUID
var text string
/* Search for a specific set of records whose 'timeline' column matches
* the value 'me'. The secondary index that we created earlier will be
* used for optimizing the search */
if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
"me").Consistency(gocql.One).Scan(&id, &text); err != nil {
log.Fatal(err)
}
fmt.Println("Tweet:", id, text)
// list all tweets
iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").Iter()
for iter.Scan(&id, &text) {
fmt.Println("Tweet:", id, text)
}
if err := iter.Close(); err != nil {
log.Fatal(err)
}
}
```
Authentication
-------
```go
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: "user",
Password: "password"
}
cluster.Keyspace = "example"
cluster.Consistency = gocql.Quorum
session, _ := cluster.CreateSession()
defer session.Close()
```
See [package documentation](https://pkg.go.dev/github.com/gocql/gocql#pkg-examples).
Data Binding
------------

View File

@ -44,8 +44,8 @@ func approve(authenticator string) bool {
return false
}
//JoinHostPort is a utility to return a address string that can be used
//gocql.Conn to form a connection with a host.
// JoinHostPort is a utility to return an address string that can be used
// by `gocql.Conn` to form a connection with a host.
func JoinHostPort(addr string, port int) string {
addr = strings.TrimSpace(addr)
if _, _, err := net.SplitHostPort(addr); err != nil {
@ -80,6 +80,19 @@ func (p PasswordAuthenticator) Success(data []byte) error {
return nil
}
// SslOptions configures TLS use.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | false | do not verify host
// Config is nil | true | verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
type SslOptions struct {
*tls.Config
@ -89,9 +102,12 @@ type SslOptions struct {
CertPath string
KeyPath string
CaPath string //optional depending on server config
// If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
// This option is basically the inverse of InSecureSkipVerify
// See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info
// If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this
// on.
// This option is basically the inverse of tls.Config.InsecureSkipVerify.
// See InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info.
//
// See SslOptions documentation to see how EnableHostVerification interacts with the provided tls.Config.
EnableHostVerification bool
}
@ -125,7 +141,7 @@ func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
// which may be serving more queries just fine.
// Default is 0, should not be changed concurrently with queries.
//
// depreciated
// Deprecated.
var TimeoutLimit int64 = 0
// Conn is a single connection to a Cassandra node. It can be used to execute
@ -213,14 +229,26 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
dialer = d
}
conn, err := dialer.DialContext(ctx, "tcp", host.HostnameAndPort())
addr := host.HostnameAndPort()
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
if cfg.tlsConfig != nil {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
tconn := tls.Client(conn, cfg.tlsConfig)
tlsConfig := cfg.tlsConfig
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
tconn := tls.Client(conn, tlsConfig)
if err := tconn.Handshake(); err != nil {
conn.Close()
return nil, err
@ -845,6 +873,10 @@ func (w *writeCoalescer) writeFlusher(interval time.Duration) {
}
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, ctxErr
}
// TODO: move tracer onto conn
stream, ok := c.streams.GetStream()
if !ok {
@ -1173,12 +1205,16 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
}
if x.meta.morePages() && !qry.disableAutoPage {
newQry := new(Query)
*newQry = *qry
newQry.pageState = copyBytes(x.meta.pagingState)
newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
iter.next = &nextIter{
qry: qry,
qry: newQry,
pos: int((1 - qry.prefetch) * float64(x.numRows)),
}
iter.next.qry.pageState = copyBytes(x.meta.pagingState)
if iter.next.pos < 1 {
iter.next.pos = 1
}
@ -1359,10 +1395,11 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
}
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One)
q.trace = nil
q := c.session.Query(statement, values...).Consistency(One).Trace(nil)
q.skipPrepare = true
q.disableSkipMetadata = true
// we want to keep the query on this connection
q.conn = c
return c.executeQuery(ctx, q)
}

View File

@ -28,14 +28,31 @@ type SetPartitioner interface {
}
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | true | verify host
// Config is nil | false | do not verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
var tlsConfig *tls.Config
if sslOpts.Config == nil {
sslOpts.Config = &tls.Config{}
tlsConfig = &tls.Config{
InsecureSkipVerify: !sslOpts.EnableHostVerification,
}
} else {
// use clone to avoid race.
tlsConfig = sslOpts.Config.Clone()
}
if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
tlsConfig.InsecureSkipVerify = false
}
// ca cert is optional
if sslOpts.CaPath != "" {
if sslOpts.RootCAs == nil {
sslOpts.RootCAs = x509.NewCertPool()
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
pem, err := ioutil.ReadFile(sslOpts.CaPath)
@ -43,7 +60,7 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
}
if !sslOpts.RootCAs.AppendCertsFromPEM(pem) {
if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("connectionpool: failed parsing or CA certs")
}
}
@ -53,13 +70,10 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
}
sslOpts.Certificates = append(sslOpts.Certificates, mycert)
tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
}
sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification
// return clone to avoid race
return sslOpts.Config.Clone(), nil
return tlsConfig, nil
}
type policyConnPool struct {
@ -238,12 +252,6 @@ func (p *policyConnPool) removeHost(ip net.IP) {
go pool.Close()
}
func (p *policyConnPool) hostUp(host *HostInfo) {
// TODO(zariel): have a set of up hosts and down hosts, we can internally
// detect down hosts, then try to reconnect to them.
p.addHost(host)
}
func (p *policyConnPool) hostDown(ip net.IP) {
// TODO(zariel): mark host as down so we can try to connect to it later, for
// now just treat it has removed.
@ -429,6 +437,8 @@ func (pool *hostConnPool) fill() {
}
return
}
// notify the session that this node is connected
go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)
// filled one
fillCount--
@ -440,6 +450,11 @@ func (pool *hostConnPool) fill() {
// mark the end of filling
pool.fillingStopped(err != nil)
if err == nil && startCount > 0 {
// notify the session that this node is connected again
go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)
}
}()
}

View File

@ -125,7 +125,7 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr)
}
// Filter to v4 addresses if any present
@ -177,7 +177,7 @@ func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
return conn, nil
}
Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err)
Logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
}
return nil, err
@ -285,8 +285,6 @@ func (c *controlConn) setupConn(conn *Conn) error {
}
c.conn.Store(ch)
c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
return nil
}
@ -452,6 +450,8 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
for {
iter = c.withConn(func(conn *Conn) *Iter {
// we want to keep the query on the control connection
q.conn = conn
return conn.executeQuery(context.TODO(), q)
})

317
vendor/github.com/gocql/gocql/doc.go generated vendored
View File

@ -4,6 +4,319 @@
// Package gocql implements a fast and robust Cassandra driver for the
// Go programming language.
//
// Connecting to the cluster
//
// Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
//
// Port can be specified as part of the address, the above is equivalent to:
//
// cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
//
// It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
// an IP address not a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address
// then the driver may connect multiple times to the same host, and will not mark the node being down or up from events.
//
// Then you can customize more options (see ClusterConfig):
//
// cluster.Keyspace = "example"
// cluster.Consistency = gocql.Quorum
// cluster.ProtoVersion = 4
//
// The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
// protocol version explicitly, as it's not defined which version will be used in certain situations (for example
// during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
//
// When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
//
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// Authentication
//
// CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
// client response pairs. The details of the exchanged messages depend on the authenticator used.
//
// To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider.
//
// PasswordAuthenticator is provided to use for username/password authentication:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.Authenticator = gocql.PasswordAuthenticator{
// Username: "user",
// Password: "password"
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// Transport layer security
//
// It is possible to secure traffic between the client and server with TLS.
//
// To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly.
// There are also helpers to load keys/certificates from files.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | false | do not verify host
// Config is nil | true | verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
//
// For example:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.SslOpts = &gocql.SslOptions{
// EnableHostVerification: true,
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// Executing queries
//
// Create queries with Session.Query. Query values must not be reused between different executions and must not be
// modified after starting execution of the query.
//
// To execute a query without reading results, use Query.Exec:
//
// err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
//
// Single row can be read by calling Query.Scan:
//
// err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
// "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
//
// Multiple rows can be read using Iter.Scanner:
//
// scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
// "me").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var (
// id gocql.UUID
// text string
// )
// err = scanner.Scan(&id, &text)
// if err != nil {
// log.Fatal(err)
// }
// fmt.Println("Tweet:", id, text)
// }
// // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
// if err := scanner.Err(); err != nil {
// log.Fatal(err)
// }
//
// See Example for complete example.
//
// Prepared statements
//
// The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
// of prepared statements.
// CQL protocol does not support preparing other query types.
//
// When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column.
// This will cause the database to ignore writing the column.
// The main advantage is the ability to keep the same prepared statement even when you don't
// want to update some fields, where before you needed to make another prepared statement.
//
// Executing multiple queries concurrently
//
// Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
// from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries
// are executed asynchronously at the protocol level.
//
// results := make(chan error, 2)
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 1").Exec()
// }()
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 2").Exec()
// }()
//
// Nulls
//
// Null values are are unmarshalled as zero value of the type. If you need to distinguish for example between text
// column being null and empty string, you can unmarshal into *string variable instead of string.
//
// var text *string
// err := scanner.Scan(&text)
// if err != nil {
// // handle error
// }
// if text != nil {
// // not null
// }
// else {
// // null
// }
//
// See Example_nulls for full example.
//
// Reusing slices
//
// The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not
// need to be allocated for every processed row. However, you need to be careful when storing the slices to other
// memory structures.
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// var myInts []int
// for scanner.Next() {
// // This scan reuses backing store of myInts for each row.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the
// slice variable within the scanner loop:
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var myInts []int
// // This scan always gets pointer to fresh myInts slice, so does not reuse memory.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// Paging
//
// The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch,
// Query.PageSize, and Query.Prefetch.
//
// It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch).
// Manual paging is useful if you want to store the page state externally, for example in a URL to allow users
// browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since
// it contains data from primary keys.
//
// Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that
// should not be modified. If you send paging state from different query or protocol version, then the behaviour
// is not defined (you might get unexpected results or an error from the server). For example, do not send paging state
// returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4,
// paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880).
//
// The driver does not check whether the paging state is from the same protocol version/statement.
// You might want to validate yourself as this could be a problem if you store paging state externally.
// For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster.
//
// Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by
// Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned
// by Iter.PageState is zero, there are no more pages available (or an error occurred).
//
// Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low.
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors
// explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't
// rely on the page having the exact count of items.
//
// See Example_paging for an example of manual paging.
//
// Dynamic list of columns
//
// There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied
// by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case.
//
// See Example_dynamicColumns.
//
// Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
// overhead to ensure this property.
// Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity.
// Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type.
// A counter batch can only contain statements to update counters.
//
// For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should
// involve only a single partition).
// Multi-partition batch needs to be split by the coordinator node and re-sent to
// correct nodes.
// With single-partition batches you can send the batch directly to the node for the partition without incurring the
// additional network hop.
//
// It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec.
// There are differences how those are executed.
// BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement.
// Session.ExecuteBatch prepares individual statements in the batch.
// If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient.
//
// See Example_batch for an example.
//
// Lightweight transactions
//
// Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an
// INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS.
//
// Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional
// statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and
// Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for
// Session.MapExecuteBatchCAS.
//
// Retries and speculative execution
//
// Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed
// multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative
// execution.
//
// Idempotent queries are retried in case of errors based on the configured RetryPolicy.
//
// Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can
// cause the driver to retry on a different node if the query is taking longer than a specified delay even before the
// driver receives an error or timeout from the server. When a query is speculatively executed, the original execution
// is still executing. The two parallel executions of the query race to return a result, the first received result will
// be returned.
//
// User-defined types
//
// UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing
// UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces).
//
// For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field:
//
// type MyUDT struct {
// FieldA int32 `cql:"a"`
// FieldB string `cql:"b"`
// }
//
// See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler.
//
// Metrics and tracing
//
// It is possible to provide observer implementations that could be used to gather metrics:
//
// - QueryObserver for monitoring individual queries.
// - BatchObserver for monitoring batch queries.
// - ConnectObserver for monitoring new connections from the driver to the database.
// - FrameHeaderObserver for monitoring individual protocol frames.
//
// CQL protocol also supports tracing of queries. When enabled, the database will write information about
// internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive
// the session ID that the database used to store the trace information in system_traces.sessions and
// system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer.
// Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead,
// so this feature should not be used on production systems with very high load unless you know what you are doing.
package gocql // import "github.com/gocql/gocql"
// TODO(tux21b): write more docs.

View File

@ -164,55 +164,43 @@ func (s *Session) handleNodeEvent(frames []frame) {
switch f.change {
case "NEW_NODE":
s.handleNewNode(f.host, f.port, true)
s.handleNewNode(f.host, f.port)
case "REMOVED_NODE":
s.handleRemovedNode(f.host, f.port)
case "MOVED_NODE":
// java-driver handles this, not mentioned in the spec
// TODO(zariel): refresh token map
case "UP":
s.handleNodeUp(f.host, f.port, true)
s.handleNodeUp(f.host, f.port)
case "DOWN":
s.handleNodeDown(f.host, f.port)
}
}
}
func (s *Session) addNewNode(host *HostInfo) {
if s.cfg.filterHost(host) {
return
}
host.setState(NodeUp)
s.pool.addHost(host)
s.policy.AddHost(host)
}
func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}
ip, port = s.cfg.translateAddressPort(ip, port)
func (s *Session) addNewNode(ip net.IP, port int) {
// Get host info and apply any filters to the host
hostInfo, err := s.hostSource.getHostInfo(ip, port)
if err != nil {
Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
return
} else if hostInfo == nil {
// If hostInfo is nil, this host was filtered out by cfg.HostFilter
// ignore if it's null because we couldn't find it
return
}
if t := hostInfo.Version().nodeUpDelay(); t > 0 && waitForBinary {
if t := hostInfo.Version().nodeUpDelay(); t > 0 {
time.Sleep(t)
}
// should this handle token moving?
hostInfo = s.ring.addOrUpdate(hostInfo)
s.addNewNode(hostInfo)
if !s.cfg.filterHost(hostInfo) {
// we let the pool call handleNodeUp to change the host state
s.pool.addHost(hostInfo)
s.policy.AddHost(hostInfo)
}
if s.control != nil && !s.cfg.IgnorePeerAddr {
// TODO(zariel): debounce ring refresh
@ -220,6 +208,22 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
}
}
func (s *Session) handleNewNode(ip net.IP, port int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}
ip, port = s.cfg.translateAddressPort(ip, port)
// if we already have the host and it's already up, then do nothing
host := s.ring.getHost(ip)
if host != nil && host.IsUp() {
return
}
s.addNewNode(ip, port)
}
func (s *Session) handleRemovedNode(ip net.IP, port int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
@ -232,45 +236,37 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
if host == nil {
host = &HostInfo{connectAddress: ip, port: port}
}
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
s.ring.removeHost(ip)
host.setState(NodeDown)
s.policy.RemoveHost(host)
s.pool.removeHost(ip)
s.ring.removeHost(ip)
if !s.cfg.filterHost(host) {
s.policy.RemoveHost(host)
s.pool.removeHost(ip)
}
if !s.cfg.IgnorePeerAddr {
s.hostSource.refreshRing()
}
}
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) {
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
if gocqlDebug {
Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
}
ip, _ := s.cfg.translateAddressPort(eventIp, eventPort)
ip, port := s.cfg.translateAddressPort(eventIp, eventPort)
host := s.ring.getHost(ip)
if host == nil {
// TODO(zariel): avoid the need to translate twice in this
// case
s.handleNewNode(eventIp, eventPort, waitForBinary)
s.addNewNode(ip, port)
return
}
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
return
}
host.setState(NodeUp)
if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
time.Sleep(t)
if !s.cfg.filterHost(host) {
s.policy.HostUp(host)
}
s.addNewNode(host)
}
func (s *Session) handleNodeDown(ip net.IP, port int) {
@ -283,11 +279,11 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
host = &HostInfo{connectAddress: ip, port: port}
}
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
host.setState(NodeDown)
if s.cfg.filterHost(host) {
return
}
host.setState(NodeDown)
s.policy.HostDown(host)
s.pool.hostDown(ip)
}

View File

@ -311,26 +311,10 @@ var (
const maxFrameHeaderSize = 9
func writeInt(p []byte, n int32) {
p[0] = byte(n >> 24)
p[1] = byte(n >> 16)
p[2] = byte(n >> 8)
p[3] = byte(n)
}
func readInt(p []byte) int32 {
return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3])
}
func writeShort(p []byte, n uint16) {
p[0] = byte(n >> 8)
p[1] = byte(n)
}
func readShort(p []byte) uint16 {
return uint16(p[0])<<8 | uint16(p[1])
}
type frameHeader struct {
version protoVersion
flags byte
@ -854,7 +838,7 @@ func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error {
if f.proto > protoVersion4 {
flags |= flagWithPreparedKeyspace
} else {
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
}
}
if f.proto > protoVersion4 {
@ -1502,7 +1486,7 @@ func (f *framer) writeQueryParams(opts *queryParams) {
if f.proto > protoVersion4 {
flags |= flagWithKeyspace
} else {
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
}
}
@ -1792,16 +1776,6 @@ func (f *framer) readShort() (n uint16) {
return
}
func (f *framer) readLong() (n int64) {
if len(f.rbuf) < 8 {
panic(fmt.Errorf("not enough bytes in buffer to read long require 8 got: %d", len(f.rbuf)))
}
n = int64(f.rbuf[0])<<56 | int64(f.rbuf[1])<<48 | int64(f.rbuf[2])<<40 | int64(f.rbuf[3])<<32 |
int64(f.rbuf[4])<<24 | int64(f.rbuf[5])<<16 | int64(f.rbuf[6])<<8 | int64(f.rbuf[7])
f.rbuf = f.rbuf[8:]
return
}
func (f *framer) readString() (s string) {
size := f.readShort()
@ -1915,19 +1889,6 @@ func (f *framer) readConsistency() Consistency {
return Consistency(f.readShort())
}
func (f *framer) readStringMap() map[string]string {
size := f.readShort()
m := make(map[string]string, size)
for i := 0; i < int(size); i++ {
k := f.readString()
v := f.readString()
m[k] = v
}
return m
}
func (f *framer) readBytesMap() map[string][]byte {
size := f.readShort()
m := make(map[string][]byte, size)
@ -2037,10 +1998,6 @@ func (f *framer) writeLongString(s string) {
f.wbuf = append(f.wbuf, s...)
}
func (f *framer) writeUUID(u *UUID) {
f.wbuf = append(f.wbuf, u[:]...)
}
func (f *framer) writeStringList(l []string) {
f.writeShort(uint16(len(l)))
for _, s := range l {
@ -2073,18 +2030,6 @@ func (f *framer) writeShortBytes(p []byte) {
f.wbuf = append(f.wbuf, p...)
}
func (f *framer) writeInet(ip net.IP, port int) {
f.wbuf = append(f.wbuf,
byte(len(ip)),
)
f.wbuf = append(f.wbuf,
[]byte(ip)...,
)
f.writeInt(int32(port))
}
func (f *framer) writeConsistency(cons Consistency) {
f.writeShort(uint16(cons))
}

View File

@ -270,15 +270,6 @@ func getApacheCassandraType(class string) Type {
}
}
func typeCanBeNull(typ TypeInfo) bool {
switch typ.(type) {
case CollectionType, UDTTypeInfo, TupleTypeInfo:
return false
}
return true
}
func (r *RowData) rowMap(m map[string]interface{}) {
for i, column := range r.Columns {
val := dereference(r.Values[i])
@ -372,7 +363,7 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
// iter := session.Query(`SELECT * FROM mytable`).Iter()
// for {
// // New map each iteration
// row = make(map[string]interface{})
// row := make(map[string]interface{})
// if !iter.MapScan(row) {
// break
// }

View File

@ -147,13 +147,6 @@ func (h *HostInfo) Peer() net.IP {
return h.peer
}
func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.peer = peer
return h
}
func (h *HostInfo) invalidConnectAddr() bool {
h.mu.RLock()
defer h.mu.RUnlock()
@ -233,13 +226,6 @@ func (h *HostInfo) DataCenter() string {
return dc
}
func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dataCenter
return h
}
func (h *HostInfo) Rack() string {
h.mu.RLock()
rack := h.rack
@ -247,26 +233,12 @@ func (h *HostInfo) Rack() string {
return rack
}
func (h *HostInfo) setRack(rack string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.rack = rack
return h
}
func (h *HostInfo) HostID() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.hostId
}
func (h *HostInfo) setHostID(hostID string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.hostId = hostID
return h
}
func (h *HostInfo) WorkLoad() string {
h.mu.RLock()
defer h.mu.RUnlock()
@ -303,13 +275,6 @@ func (h *HostInfo) Version() cassVersion {
return h.version
}
func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.version = cassVersion{major, minor, patch}
return h
}
func (h *HostInfo) State() nodeState {
h.mu.RLock()
defer h.mu.RUnlock()
@ -329,26 +294,12 @@ func (h *HostInfo) Tokens() []string {
return h.tokens
}
func (h *HostInfo) setTokens(tokens []string) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.tokens = tokens
return h
}
func (h *HostInfo) Port() int {
h.mu.RLock()
defer h.mu.RUnlock()
return h.port
}
func (h *HostInfo) setPort(port int) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.port = port
return h
}
func (h *HostInfo) update(from *HostInfo) {
if h == from {
return
@ -689,7 +640,7 @@ func (r *ringDescriber) refreshRing() error {
// TODO: move this to session
for _, h := range hosts {
if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
if r.session.cfg.filterHost(h) {
continue
}

View File

@ -8,9 +8,3 @@ git clone https://github.com/pcmanus/ccm.git
pushd ccm
./setup.py install --user
popd
if [ "$1" != "gocql/gocql" ]; then
USER=$(echo $1 | cut -f1 -d'/')
cd ../..
mv ${USER} gocql
fi

View File

@ -44,6 +44,52 @@ type Unmarshaler interface {
// Marshal returns the CQL encoding of the value for the Cassandra
// internal type described by the info parameter.
//
// nil is serialized as CQL null.
// If value implements Marshaler, its MarshalCQL method is called to marshal the data.
// If value is a pointer, the pointed-to value is marshaled.
//
// Supported conversions are as follows, other type combinations may be added in the future:
//
// CQL type | Go type (value) | Note
// varchar, ascii, blob, text | string, []byte |
// boolean | bool |
// tinyint, smallint, int | integer types |
// tinyint, smallint, int | string | formatted as base 10 number
// bigint, counter | integer types |
// bigint, counter | big.Int |
// bigint, counter | string | formatted as base 10 number
// float | float32 |
// double | float64 |
// decimal | inf.Dec |
// time | int64 | nanoseconds since start of day
// time | time.Duration | duration since start of day
// timestamp | int64 | milliseconds since Unix epoch
// timestamp | time.Time |
// list, set | slice, array |
// list, set | map[X]struct{} |
// map | map[X]Y |
// uuid, timeuuid | gocql.UUID |
// uuid, timeuuid | [16]byte | raw UUID bytes
// uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes
// uuid, timeuuid | string | hex representation, see ParseUUID
// varint | integer types |
// varint | big.Int |
// varint | string | value of number in decimal notation
// inet | net.IP |
// inet | string | IPv4 or IPv6 address string
// tuple | slice, array |
// tuple | struct | fields are marshaled in order of declaration
// user-defined type | gocql.UDTMarshaler | MarshalUDT is called
// user-defined type | map[string]interface{} |
// user-defined type | struct | struct fields' cql tags are used for column names
// date | int64 | milliseconds since Unix epoch to start of day (in UTC)
// date | time.Time | start of day (in UTC)
// date | string | parsed using "2006-01-02" format
// duration | int64 | duration in nanoseconds
// duration | time.Duration |
// duration | gocql.Duration |
// duration | string | parsed with time.ParseDuration
func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
if info.Version() < protoVersion1 {
panic("protocol version not set")
@ -118,6 +164,44 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
// Unmarshal parses the CQL encoded data based on the info parameter that
// describes the Cassandra internal data type and stores the result in the
// value pointed by value.
//
// If value implements Unmarshaler, it's UnmarshalCQL method is called to
// unmarshal the data.
// If value is a pointer to pointer, it is set to nil if the CQL value is
// null. Otherwise, nulls are unmarshalled as zero value.
//
// Supported conversions are as follows, other type combinations may be added in the future:
//
// CQL type | Go type (value) | Note
// varchar, ascii, blob, text | *string |
// varchar, ascii, blob, text | *[]byte | non-nil buffer is reused
// bool | *bool |
// tinyint, smallint, int, bigint, counter | *integer types |
// tinyint, smallint, int, bigint, counter | *big.Int |
// tinyint, smallint, int, bigint, counter | *string | formatted as base 10 number
// float | *float32 |
// double | *float64 |
// decimal | *inf.Dec |
// time | *int64 | nanoseconds since start of day
// time | *time.Duration |
// timestamp | *int64 | milliseconds since Unix epoch
// timestamp | *time.Time |
// list, set | *slice, *array |
// map | *map[X]Y |
// uuid, timeuuid | *string | see UUID.String
// uuid, timeuuid | *[]byte | raw UUID bytes
// uuid, timeuuid | *gocql.UUID |
// timeuuid | *time.Time | timestamp of the UUID
// inet | *net.IP |
// inet | *string | IPv4 or IPv6 address string
// tuple | *slice, *array |
// tuple | *struct | struct fields are set in order of declaration
// user-defined types | gocql.UDTUnmarshaler | UnmarshalUDT is called
// user-defined types | *map[string]interface{} |
// user-defined types | *struct | cql tag is used to determine field name
// date | *time.Time | time of beginning of the day (in UTC)
// date | *string | formatted with 2006-01-02 format
// duration | *gocql.Duration |
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
if v, ok := value.(Unmarshaler); ok {
return v.UnmarshalCQL(info, data)
@ -1690,6 +1774,8 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
return nil, nil
case UUID:
return val.Bytes(), nil
case [16]byte:
return val[:], nil
case []byte:
if len(val) != 16 {
return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info)
@ -1711,7 +1797,7 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
}
func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
if data == nil || len(data) == 0 {
if len(data) == 0 {
switch v := value.(type) {
case *string:
*v = ""
@ -1726,9 +1812,22 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
return nil
}
if len(data) != 16 {
return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long")
}
switch v := value.(type) {
case *[16]byte:
copy((*v)[:], data)
return nil
case *UUID:
copy((*v)[:], data)
return nil
}
u, err := UUIDFromBytes(data)
if err != nil {
return unmarshalErrorf("Unable to parse UUID: %s", err)
return unmarshalErrorf("unable to parse UUID: %s", err)
}
switch v := value.(type) {
@ -1738,9 +1837,6 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
case *[]byte:
*v = u[:]
return nil
case *UUID:
*v = u
return nil
}
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
}
@ -1942,7 +2038,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
for i, elem := range tuple.Elems {
// each element inside data is a [bytes]
var p []byte
if len(data) > 4 {
if len(data) >= 4 {
p, data = readBytes(data)
}
err := Unmarshal(elem, p, v[i])
@ -1971,7 +2067,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
for i, elem := range tuple.Elems {
var p []byte
if len(data) > 4 {
if len(data) >= 4 {
p, data = readBytes(data)
}
@ -1982,7 +2078,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
switch rv.Field(i).Kind() {
case reflect.Ptr:
rv.Field(i).Set(reflect.ValueOf(v))
if p != nil {
rv.Field(i).Set(reflect.ValueOf(v))
} else {
rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v)))
}
default:
rv.Field(i).Set(reflect.ValueOf(v).Elem())
}
@ -2001,7 +2101,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
for i, elem := range tuple.Elems {
var p []byte
if len(data) > 4 {
if len(data) >= 4 {
p, data = readBytes(data)
}
@ -2012,7 +2112,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
switch rv.Index(i).Kind() {
case reflect.Ptr:
rv.Index(i).Set(reflect.ValueOf(v))
if p != nil {
rv.Index(i).Set(reflect.ValueOf(v))
} else {
rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v)))
}
default:
rv.Index(i).Set(reflect.ValueOf(v).Elem())
}
@ -2050,7 +2154,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
case Marshaler:
return v.MarshalCQL(info)
case unsetColumn:
return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for user defined types")
return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types")
case UDTMarshaler:
var buf []byte
for _, e := range udt.Elements {

View File

@ -324,10 +324,10 @@ func compileMetadata(
keyspace.Functions[functions[i].Name] = &functions[i]
}
keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates))
for _, aggregate := range aggregates {
aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc]
aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc]
keyspace.Aggregates[aggregate.Name] = &aggregate
for i, _ := range aggregates {
aggregates[i].FinalFunc = *keyspace.Functions[aggregates[i].finalFunc]
aggregates[i].StateFunc = *keyspace.Functions[aggregates[i].stateFunc]
keyspace.Aggregates[aggregates[i].Name] = &aggregates[i]
}
keyspace.Views = make(map[string]*ViewMetadata, len(views))
for i := range views {
@ -347,9 +347,9 @@ func compileMetadata(
keyspace.UserTypes[types[i].Name] = &types[i]
}
keyspace.MaterializedViews = make(map[string]*MaterializedViewMetadata, len(materializedViews))
for _, materializedView := range materializedViews {
materializedView.BaseTable = keyspace.Tables[materializedView.baseTableName]
keyspace.MaterializedViews[materializedView.Name] = &materializedView
for i, _ := range materializedViews {
materializedViews[i].BaseTable = keyspace.Tables[materializedViews[i].baseTableName]
keyspace.MaterializedViews[materializedViews[i].Name] = &materializedViews[i]
}
// add columns from the schema data
@ -559,7 +559,7 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
iter.Scan(&keyspace.DurableWrites, &replication)
err := iter.Close()
if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
return nil, fmt.Errorf("error querying keyspace schema: %v", err)
}
keyspace.StrategyClass = replication["class"]
@ -585,13 +585,13 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
err := iter.Close()
if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
return nil, fmt.Errorf("error querying keyspace schema: %v", err)
}
err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions)
if err != nil {
return nil, fmt.Errorf(
"Invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
"invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
strategyOptionsJSON, keyspace.Name, err,
)
}
@ -703,7 +703,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
if err != nil {
iter.Close()
return nil, fmt.Errorf(
"Invalid JSON value '%s' as key_aliases for in table '%s': %v",
"invalid JSON value '%s' as key_aliases for in table '%s': %v",
keyAliasesJSON, table.Name, err,
)
}
@ -716,7 +716,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
if err != nil {
iter.Close()
return nil, fmt.Errorf(
"Invalid JSON value '%s' as column_aliases for in table '%s': %v",
"invalid JSON value '%s' as column_aliases for in table '%s': %v",
columnAliasesJSON, table.Name, err,
)
}
@ -728,7 +728,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
err := iter.Close()
if err != nil && err != ErrNotFound {
return nil, fmt.Errorf("Error querying table schema: %v", err)
return nil, fmt.Errorf("error querying table schema: %v", err)
}
return tables, nil
@ -777,7 +777,7 @@ func (s *Session) scanColumnMetadataV1(keyspace string) ([]ColumnMetadata, error
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
if err != nil {
return nil, fmt.Errorf(
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
indexOptionsJSON,
column.Name,
column.Table,
@ -837,7 +837,7 @@ func (s *Session) scanColumnMetadataV2(keyspace string) ([]ColumnMetadata, error
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
if err != nil {
return nil, fmt.Errorf(
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
indexOptionsJSON,
column.Name,
column.Table,
@ -915,7 +915,7 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata,
}
if err != nil && err != ErrNotFound {
return nil, fmt.Errorf("Error querying column schema: %v", err)
return nil, fmt.Errorf("error querying column schema: %v", err)
}
return columns, nil

View File

@ -1,9 +1,12 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//This file will be the future home for more policies
package gocql
//This file will be the future home for more policies
import (
"context"
"errors"
@ -37,12 +40,6 @@ func (c *cowHostList) get() []*HostInfo {
return *l
}
func (c *cowHostList) set(list []*HostInfo) {
c.mu.Lock()
c.list.Store(&list)
c.mu.Unlock()
}
// add will add a host if it not already in the list
func (c *cowHostList) add(host *HostInfo) bool {
c.mu.Lock()
@ -68,33 +65,6 @@ func (c *cowHostList) add(host *HostInfo) bool {
return true
}
func (c *cowHostList) update(host *HostInfo) {
c.mu.Lock()
l := c.get()
if len(l) == 0 {
c.mu.Unlock()
return
}
found := false
newL := make([]*HostInfo, len(l))
for i := range l {
if host.Equal(l[i]) {
newL[i] = host
found = true
} else {
newL[i] = l[i]
}
}
if found {
c.list.Store(&newL)
}
c.mu.Unlock()
}
func (c *cowHostList) remove(ip net.IP) bool {
c.mu.Lock()
l := c.get()
@ -304,7 +274,10 @@ type HostSelectionPolicy interface {
KeyspaceChanged(KeyspaceUpdateEvent)
Init(*Session)
IsLocal(host *HostInfo) bool
//Pick returns an iteration function over selected hosts
// Pick returns an iteration function over selected hosts.
// Multiple attempts of a single query execution won't call the returned NextHost function concurrently,
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
// a different instance of NextHost.
Pick(ExecutableQuery) NextHost
}
@ -880,6 +853,51 @@ func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
}
// ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After
// each host connects during session initialization, the Ready method will be
// called. If you only need a single Host to be up you can wrap a
// HostSelectionPolicy policy with SingleHostReadyPolicy.
type ReadyPolicy interface {
Ready() bool
}
// SingleHostReadyPolicy wraps a HostSelectionPolicy and returns Ready after a
// single host has been added via HostUp
func SingleHostReadyPolicy(p HostSelectionPolicy) *singleHostReadyPolicy {
return &singleHostReadyPolicy{
HostSelectionPolicy: p,
}
}
type singleHostReadyPolicy struct {
HostSelectionPolicy
ready bool
readyMux sync.Mutex
}
func (s *singleHostReadyPolicy) HostUp(host *HostInfo) {
s.HostSelectionPolicy.HostUp(host)
s.readyMux.Lock()
s.ready = true
s.readyMux.Unlock()
}
func (s *singleHostReadyPolicy) Ready() bool {
s.readyMux.Lock()
ready := s.ready
s.readyMux.Unlock()
if !ready {
return false
}
// in case the wrapped policy is also a ReadyPolicy, defer to that
if rdy, ok := s.HostSelectionPolicy.(ReadyPolicy); ok {
return rdy.Ready()
}
return true
}
// ConvictionPolicy interface is used by gocql to determine if a host should be
// marked as DOWN based on the error and host info
type ConvictionPolicy interface {

View File

@ -14,18 +14,6 @@ type preparedLRU struct {
lru *lru.Cache
}
// Max adjusts the maximum size of the cache and cleans up the oldest records if
// the new max is lower than the previous value. Not concurrency safe.
func (p *preparedLRU) max(max int) {
p.mu.Lock()
defer p.mu.Unlock()
for p.lru.Len() > max {
p.lru.RemoveOldest()
}
p.lru.MaxEntries = max
}
func (p *preparedLRU) clear() {
p.mu.Lock()
defer p.mu.Unlock()

View File

@ -2,6 +2,7 @@ package gocql
import (
"context"
"sync"
"time"
)
@ -34,14 +35,15 @@ func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, c
return iter
}
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, results chan *Iter) *Iter {
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy,
hostIter NextHost, results chan *Iter) *Iter {
ticker := time.NewTicker(sp.Delay())
defer ticker.Stop()
for i := 0; i < sp.Attempts(); i++ {
select {
case <-ticker.C:
go q.run(ctx, qry, results)
go q.run(ctx, qry, hostIter, results)
case <-ctx.Done():
return &Iter{err: ctx.Err()}
case iter := <-results:
@ -53,11 +55,23 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
}
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)
// check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative
sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry), nil
return q.do(qry.Context(), qry, hostIter), nil
}
// When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below.
// To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it.
var mu sync.Mutex
origHostIter := hostIter
hostIter = func() SelectedHost {
mu.Lock()
defer mu.Unlock()
return origHostIter()
}
ctx, cancel := context.WithCancel(qry.Context())
@ -66,12 +80,12 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
results := make(chan *Iter, 1)
// Launch the main execution
go q.run(ctx, qry, results)
go q.run(ctx, qry, hostIter, results)
// The speculative executions are launched _in addition_ to the main
// execution, on a timer. So Speculation{2} would make 3 executions running
// in total.
if iter := q.speculate(ctx, qry, sp, results); iter != nil {
if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil {
return iter, nil
}
@ -83,8 +97,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
}
}
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
hostIter := q.policy.Pick(qry)
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
selectedHost := hostIter()
rt := qry.retryPolicy()
@ -153,9 +166,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
return &Iter{err: ErrNoConnections}
}
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan<- *Iter) {
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) {
select {
case results <- q.do(ctx, qry):
case results <- q.do(ctx, qry, hostIter):
case <-ctx.Done():
}
}

View File

@ -63,29 +63,6 @@ func (r *ring) currentHosts() map[string]*HostInfo {
return hosts
}
func (r *ring) addHost(host *HostInfo) bool {
// TODO(zariel): key all host info by HostID instead of
// ip addresses
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
ip := host.ConnectAddress().String()
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
_, ok := r.hosts[ip]
if !ok {
r.hostList = append(r.hostList, host)
}
r.hosts[ip] = host
r.mu.Unlock()
return ok
}
func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
if existingHost, ok := r.addHostIfMissing(host); ok {
existingHost.update(host)

View File

@ -27,7 +27,7 @@ import (
// scenario is to have one global session object to interact with the
// whole Cassandra cluster.
//
// This type extends the Node interface by adding a convinient query builder
// This type extends the Node interface by adding a convenient query builder
// and automatically sets a default consistency level on all operations
// that do not have a consistency level set.
type Session struct {
@ -62,7 +62,6 @@ type Session struct {
schemaEvents *eventDebouncer
// ring metadata
hosts []HostInfo
useSystemSchema bool
hasAggregatesAndFunctions bool
@ -227,18 +226,44 @@ func (s *Session) init() error {
}
hosts = hosts[:0]
// each host will increment left and decrement it after connecting and once
// there's none left, we'll close hostCh
var left int64
// we will receive up to len(hostMap) of messages so create a buffer so we
// don't end up stuck in a goroutine if we stopped listening
connectedCh := make(chan struct{}, len(hostMap))
// we add one here because we don't want to end up closing hostCh until we're
// done looping and the decerement code might be reached before we've looped
// again
atomic.AddInt64(&left, 1)
for _, host := range hostMap {
host = s.ring.addOrUpdate(host)
host := s.ring.addOrUpdate(host)
if s.cfg.filterHost(host) {
continue
}
host.setState(NodeUp)
s.pool.addHost(host)
atomic.AddInt64(&left, 1)
go func() {
s.pool.addHost(host)
connectedCh <- struct{}{}
// if there are no hosts left, then close the hostCh to unblock the loop
// below if its still waiting
if atomic.AddInt64(&left, -1) == 0 {
close(connectedCh)
}
}()
hosts = append(hosts, host)
}
// once we're done looping we subtract the one we initially added and check
// to see if we should close
if atomic.AddInt64(&left, -1) == 0 {
close(connectedCh)
}
// before waiting for them to connect, add them all to the policy so we can
// utilize efficiencies by calling AddHosts if the policy supports it
type bulkAddHosts interface {
AddHosts([]*HostInfo)
}
@ -250,6 +275,15 @@ func (s *Session) init() error {
}
}
readyPolicy, _ := s.policy.(ReadyPolicy)
// now loop over connectedCh until it's closed (meaning we've connected to all)
// or until the policy says we're ready
for range connectedCh {
if readyPolicy != nil && readyPolicy.Ready() {
break
}
}
// TODO(zariel): we probably dont need this any more as we verify that we
// can connect to one of the endpoints supplied by using the control conn.
// See if there are any connections in the pool
@ -320,7 +354,8 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
if h.IsUp() {
continue
}
s.handleNodeUp(h.ConnectAddress(), h.Port(), true)
// we let the pool call handleNodeUp to change the host state
s.pool.addHost(h)
}
case <-s.ctx.Done():
return
@ -806,6 +841,7 @@ type Query struct {
trace Tracer
observer QueryObserver
session *Session
conn *Conn
rt RetryPolicy
spec SpeculativeExecutionPolicy
binding func(q *QueryInfo) ([]interface{}, error)
@ -1094,12 +1130,17 @@ func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy {
return q.spec
}
// IsIdempotent returns whether the query is marked as idempotent.
// Non-idempotent query won't be retried.
// See "Retries and speculative execution" in package docs for more details.
func (q *Query) IsIdempotent() bool {
return q.idempotent
}
// Idempotent marks the query as being idempotent or not depending on
// the value.
// Non-idempotent query won't be retried.
// See "Retries and speculative execution" in package docs for more details.
func (q *Query) Idempotent(value bool) *Query {
q.idempotent = value
return q
@ -1164,6 +1205,11 @@ func (q *Query) Iter() *Iter {
if isUseStatement(q.stmt) {
return &Iter{err: ErrUseStmt}
}
// if the query was specifically run on a connection then re-use that
// connection when fetching the next results
if q.conn != nil {
return q.conn.executeQuery(q.Context(), q)
}
return q.session.executeQuery(q)
}
@ -1195,6 +1241,10 @@ func (q *Query) Scan(dest ...interface{}) error {
// statement containing an IF clause). If the transaction fails because
// the existing values did not match, the previous values will be stored
// in dest.
//
// As for INSERT .. IF NOT EXISTS, previous values will be returned as if
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
// column mismatching. Use MapScanCAS to capture them safely.
func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
q.disableSkipMetadata = true
iter := q.Iter()
@ -1423,7 +1473,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
}
if iter.next != nil && iter.pos >= iter.next.pos {
go iter.next.fetch()
iter.next.fetchAsync()
}
// currently only support scanning into an expand tuple, such that its the same
@ -1517,16 +1567,31 @@ func (iter *Iter) NumRows() int {
return iter.numRows
}
// nextIter holds state for fetching a single page in an iterator.
// single page might be attempted multiple times due to retries.
type nextIter struct {
qry *Query
pos int
once sync.Once
next *Iter
qry *Query
pos int
oncea sync.Once
once sync.Once
next *Iter
}
func (n *nextIter) fetchAsync() {
n.oncea.Do(func() {
go n.fetch()
})
}
func (n *nextIter) fetch() *Iter {
n.once.Do(func() {
n.next = n.qry.session.executeQuery(n.qry)
// if the query was specifically run on a connection then re-use that
// connection when fetching the next results
if n.qry.conn != nil {
n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry)
} else {
n.next = n.qry.session.executeQuery(n.qry)
}
})
return n.next
}
@ -1536,7 +1601,6 @@ type Batch struct {
Entries []BatchEntry
Cons Consistency
routingKey []byte
routingKeyBuffer []byte
CustomPayload map[string][]byte
rt RetryPolicy
spec SpeculativeExecutionPolicy
@ -1733,7 +1797,7 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch {
func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
latency := end.Sub(start)
_, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil)
attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil)
if b.observer == nil {
return
@ -1753,6 +1817,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host
Host: host,
Metrics: metricsForHost,
Err: iter.err,
Attempt: attempt,
})
}
@ -1968,7 +2033,6 @@ type ObservedQuery struct {
Err error
// Attempt is the index of attempt at executing this query.
// An attempt might be either retry or fetching next page of a query.
// The first attempt is number zero and any retries have non-zero attempt number.
Attempt int
}
@ -1999,6 +2063,10 @@ type ObservedBatch struct {
// The metrics per this host
Metrics *hostMetrics
// Attempt is the index of attempt at executing this query.
// The first attempt is number zero and any retries have non-zero attempt number.
Attempt int
}
// BatchObserver is the interface implemented by batch observers / stat collectors.

View File

@ -153,7 +153,7 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
tokenRing.partitioner = randomPartitioner{}
} else {
return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
return nil, fmt.Errorf("unsupported partitioner '%s'", partitioner)
}
for _, host := range hosts {

View File

@ -46,32 +46,35 @@ type placementStrategy interface {
replicationFactor(dc string) int
}
func getReplicationFactorFromOpts(keyspace string, val interface{}) int {
// TODO: dont really want to panic here, but is better
// than spamming
func getReplicationFactorFromOpts(val interface{}) (int, error) {
switch v := val.(type) {
case int:
if v <= 0 {
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace))
if v < 0 {
return 0, fmt.Errorf("invalid replication_factor %d", v)
}
return v
return v, nil
case string:
n, err := strconv.Atoi(v)
if err != nil {
panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err))
} else if n <= 0 {
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", n, keyspace))
return 0, fmt.Errorf("invalid replication_factor %q: %v", v, err)
} else if n < 0 {
return 0, fmt.Errorf("invalid replication_factor %d", n)
}
return n
return n, nil
default:
panic(fmt.Sprintf("unkown replication_factor type %T", v))
return 0, fmt.Errorf("unknown replication_factor type %T", v)
}
}
func getStrategy(ks *KeyspaceMetadata) placementStrategy {
switch {
case strings.Contains(ks.StrategyClass, "SimpleStrategy"):
return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])}
rf, err := getReplicationFactorFromOpts(ks.StrategyOptions["replication_factor"])
if err != nil {
Logger.Printf("parse rf for keyspace %q: %v", ks.Name, err)
return nil
}
return &simpleStrategy{rf: rf}
case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"):
dcs := make(map[string]int)
for dc, rf := range ks.StrategyOptions {
@ -79,14 +82,21 @@ func getStrategy(ks *KeyspaceMetadata) placementStrategy {
continue
}
dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf)
rf, err := getReplicationFactorFromOpts(rf)
if err != nil {
Logger.Println("parse rf for keyspace %q, dc %q: %v", err)
// skip DC if the rf is invalid/unsupported, so that we can at least work with other working DCs.
continue
}
dcs[dc] = rf
}
return &networkTopology{dcs: dcs}
case strings.Contains(ks.StrategyClass, "LocalStrategy"):
return nil
default:
// TODO: handle unknown replicas and just return the primary host for a token
panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass))
Logger.Printf("parse rf for keyspace %q: unsupported strategy class: %v", ks.StrategyClass)
return nil
}
}

View File

@ -2,11 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
// The uuid package can be used to generate and parse universally unique
// identifiers, a standardized format in the form of a 128 bit number.
//
// http://tools.ietf.org/html/rfc4122
package gocql
import (
"crypto/rand"

2
vendor/modules.txt vendored
View File

@ -402,7 +402,7 @@ github.com/go-stack/stack
github.com/go-test/deep
# github.com/go-yaml/yaml v2.1.0+incompatible
github.com/go-yaml/yaml
# github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e
# github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
## explicit
github.com/gocql/gocql
github.com/gocql/gocql/internal/lru