Validate hostnames when using TLS in Cassandra (#11365)
This commit is contained in:
parent
541ae8636c
commit
4279bc8b34
|
@ -20,13 +20,18 @@ func TestBackend_basic(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup, hostname := cassandra.PrepareTestContainer(t, "latest")
|
copyFromTo := map[string]string{
|
||||||
|
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
|
||||||
|
}
|
||||||
|
host, cleanup := cassandra.PrepareTestContainer(t,
|
||||||
|
cassandra.CopyFromTo(copyFromTo),
|
||||||
|
)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
logicaltest.Test(t, logicaltest.TestCase{
|
logicaltest.Test(t, logicaltest.TestCase{
|
||||||
LogicalBackend: b,
|
LogicalBackend: b,
|
||||||
Steps: []logicaltest.TestStep{
|
Steps: []logicaltest.TestStep{
|
||||||
testAccStepConfig(t, hostname),
|
testAccStepConfig(t, host.ConnectionURL()),
|
||||||
testAccStepRole(t),
|
testAccStepRole(t),
|
||||||
testAccStepReadCreds(t, "test"),
|
testAccStepReadCreds(t, "test"),
|
||||||
},
|
},
|
||||||
|
@ -41,13 +46,17 @@ func TestBackend_roleCrud(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup, hostname := cassandra.PrepareTestContainer(t, "latest")
|
copyFromTo := map[string]string{
|
||||||
|
"test-fixtures/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
|
||||||
|
}
|
||||||
|
host, cleanup := cassandra.PrepareTestContainer(t,
|
||||||
|
cassandra.CopyFromTo(copyFromTo))
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
logicaltest.Test(t, logicaltest.TestCase{
|
logicaltest.Test(t, logicaltest.TestCase{
|
||||||
LogicalBackend: b,
|
LogicalBackend: b,
|
||||||
Steps: []logicaltest.TestStep{
|
Steps: []logicaltest.TestStep{
|
||||||
testAccStepConfig(t, hostname),
|
testAccStepConfig(t, host.ConnectionURL()),
|
||||||
testAccStepRole(t),
|
testAccStepRole(t),
|
||||||
testAccStepRoleWithOptions(t),
|
testAccStepRoleWithOptions(t),
|
||||||
testAccStepReadRole(t, "test", testRole),
|
testAccStepReadRole(t, "test", testRole),
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:bug
|
||||||
|
secrets/database/cassandra: Fixed issue where hostnames were not being validated when using TLS
|
||||||
|
```
|
2
go.mod
2
go.mod
|
@ -49,7 +49,7 @@ require (
|
||||||
github.com/go-ole/go-ole v1.2.4 // indirect
|
github.com/go-ole/go-ole v1.2.4 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.5.0
|
github.com/go-sql-driver/mysql v1.5.0
|
||||||
github.com/go-test/deep v1.0.7
|
github.com/go-test/deep v1.0.7
|
||||||
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e
|
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
|
||||||
github.com/golang/protobuf v1.4.2
|
github.com/golang/protobuf v1.4.2
|
||||||
github.com/google/go-github v17.0.0+incompatible
|
github.com/google/go-github v17.0.0+incompatible
|
||||||
github.com/google/go-metrics-stackdriver v0.2.0
|
github.com/google/go-metrics-stackdriver v0.2.0
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -442,8 +442,8 @@ github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWe
|
||||||
github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ=
|
github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ=
|
||||||
github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0=
|
github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0=
|
||||||
github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
|
github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
|
||||||
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e h1:SroDcndcOU9BVAduPf/PXihXoR2ZYTQYLXbupbqxAyQ=
|
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309 h1:8MHuCGYDXh0skFrLumkCMlt9C29hxhqNx39+Haemeqw=
|
||||||
github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
|
github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
|
||||||
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
|
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
|
||||||
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
||||||
|
@ -1084,6 +1084,7 @@ github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6So
|
||||||
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
|
github.com/rogpeppe/go-internal v1.6.2 h1:aIihoIOHCiLZHxyoNQ+ABL4NKhFTgKLBdMLyEAh98m0=
|
||||||
github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||||
github.com/rs/zerolog v1.4.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
github.com/rs/zerolog v1.4.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||||
|
@ -1631,6 +1632,7 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8
|
||||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
|
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
|
||||||
|
gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
|
||||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||||
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
|
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
|
||||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||||
|
|
|
@ -2,9 +2,10 @@ package cassandra
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -12,33 +13,75 @@ import (
|
||||||
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
||||||
)
|
)
|
||||||
|
|
||||||
func PrepareTestContainer(t *testing.T, version string) (func(), string) {
|
type containerConfig struct {
|
||||||
|
version string
|
||||||
|
copyFromTo map[string]string
|
||||||
|
sslOpts *gocql.SslOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContainerOpt func(*containerConfig)
|
||||||
|
|
||||||
|
func Version(version string) ContainerOpt {
|
||||||
|
return func(cfg *containerConfig) {
|
||||||
|
cfg.version = version
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CopyFromTo(copyFromTo map[string]string) ContainerOpt {
|
||||||
|
return func(cfg *containerConfig) {
|
||||||
|
cfg.copyFromTo = copyFromTo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SslOpts(sslOpts *gocql.SslOptions) ContainerOpt {
|
||||||
|
return func(cfg *containerConfig) {
|
||||||
|
cfg.sslOpts = sslOpts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Host struct {
|
||||||
|
Name string
|
||||||
|
Port string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Host) ConnectionURL() string {
|
||||||
|
return net.JoinHostPort(h.Name, h.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareTestContainer(t *testing.T, opts ...ContainerOpt) (Host, func()) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if os.Getenv("CASSANDRA_HOSTS") != "" {
|
if os.Getenv("CASSANDRA_HOSTS") != "" {
|
||||||
return func() {}, os.Getenv("CASSANDRA_HOSTS")
|
host, port, err := net.SplitHostPort(os.Getenv("CASSANDRA_HOSTS"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to split host & port from CASSANDRA_HOSTS (%s): %s", os.Getenv("CASSANDRA_HOSTS"), err)
|
||||||
|
}
|
||||||
|
h := Host{
|
||||||
|
Name: host,
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
return h, func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
if version == "" {
|
containerCfg := &containerConfig{
|
||||||
version = "3.11"
|
version: "3.11",
|
||||||
}
|
}
|
||||||
|
|
||||||
var copyFromTo map[string]string
|
for _, opt := range opts {
|
||||||
cwd, _ := os.Getwd()
|
opt(containerCfg)
|
||||||
fixturePath := fmt.Sprintf("%s/test-fixtures/", cwd)
|
}
|
||||||
if _, err := os.Stat(fixturePath); err != nil {
|
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
copyFromTo := map[string]string{}
|
||||||
// If it doesn't exist, no biggie
|
for from, to := range containerCfg.copyFromTo {
|
||||||
t.Fatal(err)
|
absFrom, err := filepath.Abs(from)
|
||||||
}
|
if err != nil {
|
||||||
} else {
|
t.Fatalf("Unable to get absolute path for file %s", from)
|
||||||
copyFromTo = map[string]string{
|
|
||||||
fixturePath: "/etc/cassandra",
|
|
||||||
}
|
}
|
||||||
|
copyFromTo[absFrom] = to
|
||||||
}
|
}
|
||||||
|
|
||||||
runner, err := docker.NewServiceRunner(docker.RunOptions{
|
runner, err := docker.NewServiceRunner(docker.RunOptions{
|
||||||
ImageRepo: "cassandra",
|
ImageRepo: "cassandra",
|
||||||
ImageTag: version,
|
ImageTag: containerCfg.version,
|
||||||
Ports: []string{"9042/tcp"},
|
Ports: []string{"9042/tcp"},
|
||||||
CopyFromTo: copyFromTo,
|
CopyFromTo: copyFromTo,
|
||||||
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
|
Env: []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
|
||||||
|
@ -58,6 +101,8 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
|
||||||
clusterConfig.ProtoVersion = 4
|
clusterConfig.ProtoVersion = 4
|
||||||
clusterConfig.Port = port
|
clusterConfig.Port = port
|
||||||
|
|
||||||
|
clusterConfig.SslOpts = containerCfg.sslOpts
|
||||||
|
|
||||||
session, err := clusterConfig.CreateSession()
|
session, err := clusterConfig.CreateSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating session: %s", err)
|
return nil, fmt.Errorf("error creating session: %s", err)
|
||||||
|
@ -65,19 +110,19 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
// Create keyspace
|
// Create keyspace
|
||||||
q := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`)
|
query := session.Query(`CREATE KEYSPACE "vault" WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };`)
|
||||||
if err := q.Exec(); err != nil {
|
if err := query.Exec(); err != nil {
|
||||||
t.Fatalf("could not create cassandra keyspace: %v", err)
|
t.Fatalf("could not create cassandra keyspace: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create table
|
// Create table
|
||||||
q = session.Query(`CREATE TABLE "vault"."entries" (
|
query = session.Query(`CREATE TABLE "vault"."entries" (
|
||||||
bucket text,
|
bucket text,
|
||||||
key text,
|
key text,
|
||||||
value blob,
|
value blob,
|
||||||
PRIMARY KEY (bucket, key)
|
PRIMARY KEY (bucket, key)
|
||||||
) WITH CLUSTERING ORDER BY (key ASC);`)
|
) WITH CLUSTERING ORDER BY (key ASC);`)
|
||||||
if err := q.Exec(); err != nil {
|
if err := query.Exec(); err != nil {
|
||||||
t.Fatalf("could not create cassandra table: %v", err)
|
t.Fatalf("could not create cassandra table: %v", err)
|
||||||
}
|
}
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
|
@ -85,5 +130,14 @@ func PrepareTestContainer(t *testing.T, version string) (func(), string) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not start docker cassandra: %s", err)
|
t.Fatalf("Could not start docker cassandra: %s", err)
|
||||||
}
|
}
|
||||||
return svc.Cleanup, svc.Config.Address()
|
|
||||||
|
host, port, err := net.SplitHostPort(svc.Config.Address())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to split host & port from address (%s): %s", svc.Config.Address(), err)
|
||||||
|
}
|
||||||
|
h := Host{
|
||||||
|
Name: host,
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
return h, svc.Cleanup
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,11 +10,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/errwrap"
|
|
||||||
log "github.com/hashicorp/go-hclog"
|
|
||||||
|
|
||||||
metrics "github.com/armon/go-metrics"
|
metrics "github.com/armon/go-metrics"
|
||||||
"github.com/gocql/gocql"
|
"github.com/gocql/gocql"
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
|
log "github.com/hashicorp/go-hclog"
|
||||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||||
"github.com/hashicorp/vault/sdk/physical"
|
"github.com/hashicorp/vault/sdk/physical"
|
||||||
)
|
)
|
||||||
|
@ -180,20 +179,18 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else if pemJSONPath, ok := conf["pem_json_file"]; ok {
|
||||||
if pemJSONPath, ok := conf["pem_json_file"]; ok {
|
pemJSONData, err := ioutil.ReadFile(pemJSONPath)
|
||||||
pemJSONData, err := ioutil.ReadFile(pemJSONPath)
|
if err != nil {
|
||||||
if err != nil {
|
return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err)
|
||||||
return errwrap.Wrapf(fmt.Sprintf("error reading json bundle from %q: {{err}}", pemJSONPath), err)
|
}
|
||||||
}
|
pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
|
||||||
pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
|
||||||
tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,7 +222,8 @@ func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) err
|
||||||
}
|
}
|
||||||
|
|
||||||
cluster.SslOpts = &gocql.SslOptions{
|
cluster.SslOpts = &gocql.SslOptions{
|
||||||
Config: tlsConfig.Clone(),
|
Config: tlsConfig,
|
||||||
|
EnableHostVerification: !tlsConfig.InsecureSkipVerify,
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,13 +19,13 @@ func TestCassandraBackend(t *testing.T) {
|
||||||
t.Skip("skipping race test in CI pending https://github.com/gocql/gocql/pull/1474")
|
t.Skip("skipping race test in CI pending https://github.com/gocql/gocql/pull/1474")
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanup, hosts := cassandra.PrepareTestContainer(t, "")
|
host, cleanup := cassandra.PrepareTestContainer(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
// Run vault tests
|
// Run vault tests
|
||||||
logger := logging.NewVaultLogger(log.Debug)
|
logger := logging.NewVaultLogger(log.Debug)
|
||||||
b, err := NewCassandraBackend(map[string]string{
|
b, err := NewCassandraBackend(map[string]string{
|
||||||
"hosts": hosts,
|
"hosts": host.ConnectionURL(),
|
||||||
"protocol_version": "3",
|
"protocol_version": "3",
|
||||||
}, logger)
|
}, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -3,7 +3,6 @@ package cassandra
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -17,14 +16,16 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()) {
|
func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()) {
|
||||||
cleanup, connURL := cassandra.PrepareTestContainer(t, "latest")
|
host, cleanup := cassandra.PrepareTestContainer(t,
|
||||||
pieces := strings.Split(connURL, ":")
|
cassandra.Version("latest"),
|
||||||
|
cassandra.CopyFromTo(insecureFileMounts),
|
||||||
|
)
|
||||||
|
|
||||||
db := new()
|
db := new()
|
||||||
initReq := dbplugin.InitializeRequest{
|
initReq := dbplugin.InitializeRequest{
|
||||||
Config: map[string]interface{}{
|
Config: map[string]interface{}{
|
||||||
"hosts": connURL,
|
"hosts": host.ConnectionURL(),
|
||||||
"port": pieces[1],
|
"port": host.Port,
|
||||||
"username": "cassandra",
|
"username": "cassandra",
|
||||||
"password": "cassandra",
|
"password": "cassandra",
|
||||||
"protocol_version": protocolVersion,
|
"protocol_version": protocolVersion,
|
||||||
|
@ -34,8 +35,8 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedConfig := map[string]interface{}{
|
expectedConfig := map[string]interface{}{
|
||||||
"hosts": connURL,
|
"hosts": host.ConnectionURL(),
|
||||||
"port": pieces[1],
|
"port": host.Port,
|
||||||
"username": "cassandra",
|
"username": "cassandra",
|
||||||
"password": "cassandra",
|
"password": "cassandra",
|
||||||
"protocol_version": protocolVersion,
|
"protocol_version": protocolVersion,
|
||||||
|
@ -53,7 +54,7 @@ func getCassandra(t *testing.T, protocolVersion interface{}) (*Cassandra, func()
|
||||||
return db, cleanup
|
return db, cleanup
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCassandra_Initialize(t *testing.T) {
|
func TestInitialize(t *testing.T) {
|
||||||
db, cleanup := getCassandra(t, 4)
|
db, cleanup := getCassandra(t, 4)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
|
@ -66,7 +67,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCassandra_CreateUser(t *testing.T) {
|
func TestCreateUser(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
// Config will have the hosts & port added to it during the test
|
// Config will have the hosts & port added to it during the test
|
||||||
config map[string]interface{}
|
config map[string]interface{}
|
||||||
|
@ -126,15 +127,17 @@ func TestCassandra_CreateUser(t *testing.T) {
|
||||||
|
|
||||||
for name, test := range tests {
|
for name, test := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
cleanup, connURL := cassandra.PrepareTestContainer(t, "latest")
|
host, cleanup := cassandra.PrepareTestContainer(t,
|
||||||
pieces := strings.Split(connURL, ":")
|
cassandra.Version("latest"),
|
||||||
|
cassandra.CopyFromTo(insecureFileMounts),
|
||||||
|
)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
db := new()
|
db := new()
|
||||||
|
|
||||||
config := test.config
|
config := test.config
|
||||||
config["hosts"] = connURL
|
config["hosts"] = host.ConnectionURL()
|
||||||
config["port"] = pieces[1]
|
config["port"] = host.Port
|
||||||
|
|
||||||
initReq := dbplugin.InitializeRequest{
|
initReq := dbplugin.InitializeRequest{
|
||||||
Config: config,
|
Config: config,
|
||||||
|
@ -162,7 +165,7 @@ func TestCassandra_CreateUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMyCassandra_UpdateUserPassword(t *testing.T) {
|
func TestUpdateUserPassword(t *testing.T) {
|
||||||
db, cleanup := getCassandra(t, 4)
|
db, cleanup := getCassandra(t, 4)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
|
@ -198,7 +201,7 @@ func TestMyCassandra_UpdateUserPassword(t *testing.T) {
|
||||||
assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
|
assertCreds(t, db.Hosts, db.Port, createResp.Username, newPassword, 5*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCassandra_DeleteUser(t *testing.T) {
|
func TestDeleteUser(t *testing.T) {
|
||||||
db, cleanup := getCassandra(t, 4)
|
db, cleanup := getCassandra(t, 4)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
|
|
|
@ -8,14 +8,13 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gocql/gocql"
|
||||||
|
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/certutil"
|
"github.com/hashicorp/vault/sdk/helper/certutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/tlsutil"
|
"github.com/hashicorp/vault/sdk/helper/tlsutil"
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,9 +39,7 @@ type cassandraConnectionProducer struct {
|
||||||
|
|
||||||
connectTimeout time.Duration
|
connectTimeout time.Duration
|
||||||
socketKeepAlive time.Duration
|
socketKeepAlive time.Duration
|
||||||
certificate string
|
certBundle *certutil.CertBundle
|
||||||
privateKey string
|
|
||||||
issuingCA string
|
|
||||||
rawConfig map[string]interface{}
|
rawConfig map[string]interface{}
|
||||||
|
|
||||||
Initialized bool
|
Initialized bool
|
||||||
|
@ -99,9 +96,7 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error marshaling PEM information: %w", err)
|
return fmt.Errorf("error marshaling PEM information: %w", err)
|
||||||
}
|
}
|
||||||
c.certificate = certBundle.Certificate
|
c.certBundle = certBundle
|
||||||
c.privateKey = certBundle.PrivateKey
|
|
||||||
c.issuingCA = certBundle.IssuingCA
|
|
||||||
c.TLS = true
|
c.TLS = true
|
||||||
|
|
||||||
case len(c.PemBundle) != 0:
|
case len(c.PemBundle) != 0:
|
||||||
|
@ -113,9 +108,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplug
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error marshaling PEM information: %w", err)
|
return fmt.Errorf("error marshaling PEM information: %w", err)
|
||||||
}
|
}
|
||||||
c.certificate = certBundle.Certificate
|
c.certBundle = certBundle
|
||||||
c.privateKey = certBundle.PrivateKey
|
c.TLS = true
|
||||||
c.issuingCA = certBundle.IssuingCA
|
}
|
||||||
|
|
||||||
|
if c.InsecureTLS {
|
||||||
c.TLS = true
|
c.TLS = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,49 +182,13 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
|
||||||
|
|
||||||
clusterConfig.Timeout = c.connectTimeout
|
clusterConfig.Timeout = c.connectTimeout
|
||||||
clusterConfig.SocketKeepalive = c.socketKeepAlive
|
clusterConfig.SocketKeepalive = c.socketKeepAlive
|
||||||
|
|
||||||
if c.TLS {
|
if c.TLS {
|
||||||
var tlsConfig *tls.Config
|
sslOpts, err := getSslOpts(c.certBundle, c.TLSMinVersion, c.InsecureTLS)
|
||||||
if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
|
if err != nil {
|
||||||
if len(c.certificate) > 0 && len(c.privateKey) == 0 {
|
return nil, err
|
||||||
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
certBundle := &certutil.CertBundle{}
|
|
||||||
if len(c.certificate) > 0 {
|
|
||||||
certBundle.Certificate = c.certificate
|
|
||||||
certBundle.PrivateKey = c.privateKey
|
|
||||||
}
|
|
||||||
if len(c.issuingCA) > 0 {
|
|
||||||
certBundle.IssuingCA = c.issuingCA
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
|
||||||
if err != nil || tlsConfig == nil {
|
|
||||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
|
|
||||||
}
|
|
||||||
tlsConfig.InsecureSkipVerify = c.InsecureTLS
|
|
||||||
|
|
||||||
if c.TLSMinVersion != "" {
|
|
||||||
var ok bool
|
|
||||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// MinVersion was not being set earlier. Reset it to
|
|
||||||
// zero to gracefully handle upgrades.
|
|
||||||
tlsConfig.MinVersion = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
|
||||||
Config: tlsConfig,
|
|
||||||
}
|
}
|
||||||
|
clusterConfig.SslOpts = sslOpts
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.LocalDatacenter != "" {
|
if c.LocalDatacenter != "" {
|
||||||
|
@ -269,6 +230,48 @@ func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getSslOpts(certBundle *certutil.CertBundle, minTLSVersion string, insecureSkipVerify bool) (*gocql.SslOptions, error) {
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
if certBundle != nil {
|
||||||
|
if certBundle.Certificate == "" && certBundle.PrivateKey != "" {
|
||||||
|
return nil, fmt.Errorf("found private key for TLS authentication but no certificate")
|
||||||
|
}
|
||||||
|
if certBundle.Certificate != "" && certBundle.PrivateKey == "" {
|
||||||
|
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse certificate bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%w", tlsConfig, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.InsecureSkipVerify = insecureSkipVerify
|
||||||
|
|
||||||
|
if minTLSVersion != "" {
|
||||||
|
var ok bool
|
||||||
|
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[minTLSVersion]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// MinVersion was not being set earlier. Reset it to
|
||||||
|
// zero to gracefully handle upgrades.
|
||||||
|
tlsConfig.MinVersion = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := &gocql.SslOptions{
|
||||||
|
Config: tlsConfig,
|
||||||
|
EnableHostVerification: !insecureSkipVerify,
|
||||||
|
}
|
||||||
|
return opts, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *cassandraConnectionProducer) secretValues() map[string]string {
|
func (c *cassandraConnectionProducer) secretValues() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
c.Password: "[password]",
|
c.Password: "[password]",
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
package cassandra
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gocql/gocql"
|
||||||
|
"github.com/hashicorp/vault/helper/testhelpers/cassandra"
|
||||||
|
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
insecureFileMounts = map[string]string{
|
||||||
|
"test-fixtures/no_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
|
||||||
|
}
|
||||||
|
secureFileMounts = map[string]string{
|
||||||
|
"test-fixtures/with_tls/cassandra.yaml": "/etc/cassandra/cassandra.yaml",
|
||||||
|
"test-fixtures/with_tls/keystore.jks": "/etc/cassandra/keystore.jks",
|
||||||
|
"test-fixtures/with_tls/.cassandra": "/root/.cassandra/",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTLSConnection(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
config map[string]interface{}
|
||||||
|
expectErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := map[string]testCase{
|
||||||
|
"tls not specified": {
|
||||||
|
config: map[string]interface{}{},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
"unrecognized certificate": {
|
||||||
|
config: map[string]interface{}{
|
||||||
|
"tls": "true",
|
||||||
|
},
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
"insecure TLS": {
|
||||||
|
config: map[string]interface{}{
|
||||||
|
"tls": "true",
|
||||||
|
"insecure_tls": true,
|
||||||
|
},
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
host, cleanup := cassandra.PrepareTestContainer(t,
|
||||||
|
cassandra.Version("3.11.9"),
|
||||||
|
cassandra.CopyFromTo(secureFileMounts),
|
||||||
|
cassandra.SslOpts(&gocql.SslOptions{
|
||||||
|
Config: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
EnableHostVerification: false,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Set values that we don't know until the cassandra container is started
|
||||||
|
config := map[string]interface{}{
|
||||||
|
"hosts": host.ConnectionURL(),
|
||||||
|
"port": host.Port,
|
||||||
|
"username": "cassandra",
|
||||||
|
"password": "cassandra",
|
||||||
|
"protocol_version": "3",
|
||||||
|
"connect_timeout": "20s",
|
||||||
|
}
|
||||||
|
// Then add any values specified in the test config. Generally for these tests they shouldn't overlap
|
||||||
|
for k, v := range test.config {
|
||||||
|
config[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new()
|
||||||
|
initReq := dbplugin.InitializeRequest{
|
||||||
|
Config: config,
|
||||||
|
VerifyConnection: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := db.Initialize(ctx, initReq)
|
||||||
|
if test.expectErr && err == nil {
|
||||||
|
t.Fatalf("err expected, got nil")
|
||||||
|
}
|
||||||
|
if !test.expectErr && err != nil {
|
||||||
|
t.Fatalf("no error expected, got: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
[ssl]
|
||||||
|
validate = false
|
||||||
|
version = SSLv23
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,46 @@
|
||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
################################################################
|
||||||
|
# Usage: ./gencert.sh
|
||||||
|
#
|
||||||
|
# Generates a keystore.jks file that can be used with a
|
||||||
|
# Cassandra server for TLS connections. This does not update
|
||||||
|
# a cassandra config file.
|
||||||
|
################################################################
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
KEYFILE="key.pem"
|
||||||
|
CERTFILE="cert.pem"
|
||||||
|
PKCSFILE="keystore.p12"
|
||||||
|
JKSFILE="keystore.jks"
|
||||||
|
|
||||||
|
HOST="127.0.0.1"
|
||||||
|
NAME="cassandra"
|
||||||
|
ALIAS="cassandra"
|
||||||
|
PASSWORD="cassandra"
|
||||||
|
|
||||||
|
echo "# Generating certificate keypair..."
|
||||||
|
go run /usr/local/go/src/crypto/tls/generate_cert.go --host=${HOST}
|
||||||
|
|
||||||
|
echo "# Creating keystore..."
|
||||||
|
openssl pkcs12 -export -in ${CERTFILE} -inkey ${KEYFILE} -name ${NAME} -password pass:${PASSWORD} > ${PKCSFILE}
|
||||||
|
|
||||||
|
echo "# Creating Java key store"
|
||||||
|
if [ -e "${JKSFILE}" ]; then
|
||||||
|
echo "# Removing old key store"
|
||||||
|
rm ${JKSFILE}
|
||||||
|
fi
|
||||||
|
|
||||||
|
set +e
|
||||||
|
keytool -importkeystore \
|
||||||
|
-srckeystore ${PKCSFILE} \
|
||||||
|
-srcstoretype PKCS12 \
|
||||||
|
-srcstorepass ${PASSWORD} \
|
||||||
|
-destkeystore ${JKSFILE} \
|
||||||
|
-deststorepass ${PASSWORD} \
|
||||||
|
-destkeypass ${PASSWORD} \
|
||||||
|
-alias ${ALIAS}
|
||||||
|
|
||||||
|
echo "# Removing intermediate files"
|
||||||
|
rm ${KEYFILE} ${CERTFILE} ${PKCSFILE}
|
BIN
plugins/database/cassandra/test-fixtures/with_tls/keystore.jks (Stored with Git LFS)
Normal file
BIN
plugins/database/cassandra/test-fixtures/with_tls/keystore.jks (Stored with Git LFS)
Normal file
Binary file not shown.
|
@ -31,8 +31,10 @@ env:
|
||||||
AUTH=false
|
AUTH=false
|
||||||
|
|
||||||
go:
|
go:
|
||||||
- 1.13.x
|
- 1.15.x
|
||||||
- 1.14.x
|
- 1.16.x
|
||||||
|
|
||||||
|
go_import_path: github.com/gocql/gocql
|
||||||
|
|
||||||
install:
|
install:
|
||||||
- ./install_test_deps.sh $TRAVIS_REPO_SLUG
|
- ./install_test_deps.sh $TRAVIS_REPO_SLUG
|
||||||
|
|
|
@ -115,3 +115,8 @@ Pavel Buchinchik <p.buchinchik@gmail.com>
|
||||||
Rintaro Okamura <rintaro.okamura@gmail.com>
|
Rintaro Okamura <rintaro.okamura@gmail.com>
|
||||||
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
|
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
|
||||||
Jorge Bay <jorgebg@apache.org>
|
Jorge Bay <jorgebg@apache.org>
|
||||||
|
Dmitriy Kozlov <hummerd@mail.ru>
|
||||||
|
Alexey Romanovsky <alexus1024+gocql@gmail.com>
|
||||||
|
Jaume Marhuenda Beltran <jaumemarhuenda@gmail.com>
|
||||||
|
Piotr Dulikowski <piodul@scylladb.com>
|
||||||
|
Árni Dagur <arni@dagur.eu>
|
|
@ -19,8 +19,8 @@ The following matrix shows the versions of Go and Cassandra that are tested with
|
||||||
|
|
||||||
Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
|
Go/Cassandra | 2.1.x | 2.2.x | 3.x.x
|
||||||
-------------| -------| ------| ---------
|
-------------| -------| ------| ---------
|
||||||
1.13 | yes | yes | yes
|
1.15 | yes | yes | yes
|
||||||
1.14 | yes | yes | yes
|
1.16 | yes | yes | yes
|
||||||
|
|
||||||
Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.
|
Gocql has been tested in production against many different versions of Cassandra. Due to limits in our CI setup we only test against the latest 3 major releases, which coincide with the official support from the Apache project.
|
||||||
|
|
||||||
|
@ -114,73 +114,7 @@ statement.
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
|
|
||||||
```go
|
See [package documentation](https://pkg.go.dev/github.com/gocql/gocql#pkg-examples).
|
||||||
/* Before you execute the program, Launch `cqlsh` and execute:
|
|
||||||
create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
|
|
||||||
create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id));
|
|
||||||
create index on example.tweet(timeline);
|
|
||||||
*/
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// connect to the cluster
|
|
||||||
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
|
||||||
cluster.Keyspace = "example"
|
|
||||||
cluster.Consistency = gocql.Quorum
|
|
||||||
session, _ := cluster.CreateSession()
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
// insert a tweet
|
|
||||||
if err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
|
||||||
"me", gocql.TimeUUID(), "hello world").Exec(); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var id gocql.UUID
|
|
||||||
var text string
|
|
||||||
|
|
||||||
/* Search for a specific set of records whose 'timeline' column matches
|
|
||||||
* the value 'me'. The secondary index that we created earlier will be
|
|
||||||
* used for optimizing the search */
|
|
||||||
if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
|
|
||||||
"me").Consistency(gocql.One).Scan(&id, &text); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
fmt.Println("Tweet:", id, text)
|
|
||||||
|
|
||||||
// list all tweets
|
|
||||||
iter := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").Iter()
|
|
||||||
for iter.Scan(&id, &text) {
|
|
||||||
fmt.Println("Tweet:", id, text)
|
|
||||||
}
|
|
||||||
if err := iter.Close(); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
Authentication
|
|
||||||
-------
|
|
||||||
|
|
||||||
```go
|
|
||||||
cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
|
||||||
cluster.Authenticator = gocql.PasswordAuthenticator{
|
|
||||||
Username: "user",
|
|
||||||
Password: "password"
|
|
||||||
}
|
|
||||||
cluster.Keyspace = "example"
|
|
||||||
cluster.Consistency = gocql.Quorum
|
|
||||||
session, _ := cluster.CreateSession()
|
|
||||||
defer session.Close()
|
|
||||||
```
|
|
||||||
|
|
||||||
Data Binding
|
Data Binding
|
||||||
------------
|
------------
|
||||||
|
|
|
@ -44,8 +44,8 @@ func approve(authenticator string) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
//JoinHostPort is a utility to return a address string that can be used
|
// JoinHostPort is a utility to return an address string that can be used
|
||||||
//gocql.Conn to form a connection with a host.
|
// by `gocql.Conn` to form a connection with a host.
|
||||||
func JoinHostPort(addr string, port int) string {
|
func JoinHostPort(addr string, port int) string {
|
||||||
addr = strings.TrimSpace(addr)
|
addr = strings.TrimSpace(addr)
|
||||||
if _, _, err := net.SplitHostPort(addr); err != nil {
|
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||||
|
@ -80,6 +80,19 @@ func (p PasswordAuthenticator) Success(data []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SslOptions configures TLS use.
|
||||||
|
//
|
||||||
|
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
|
||||||
|
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
|
||||||
|
// SslOptions and Config.InsecureSkipVerify interact as follows:
|
||||||
|
//
|
||||||
|
// Config.InsecureSkipVerify | EnableHostVerification | Result
|
||||||
|
// Config is nil | false | do not verify host
|
||||||
|
// Config is nil | true | verify host
|
||||||
|
// false | false | verify host
|
||||||
|
// true | false | do not verify host
|
||||||
|
// false | true | verify host
|
||||||
|
// true | true | verify host
|
||||||
type SslOptions struct {
|
type SslOptions struct {
|
||||||
*tls.Config
|
*tls.Config
|
||||||
|
|
||||||
|
@ -89,9 +102,12 @@ type SslOptions struct {
|
||||||
CertPath string
|
CertPath string
|
||||||
KeyPath string
|
KeyPath string
|
||||||
CaPath string //optional depending on server config
|
CaPath string //optional depending on server config
|
||||||
// If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
|
// If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this
|
||||||
// This option is basically the inverse of InSecureSkipVerify
|
// on.
|
||||||
// See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info
|
// This option is basically the inverse of tls.Config.InsecureSkipVerify.
|
||||||
|
// See InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info.
|
||||||
|
//
|
||||||
|
// See SslOptions documentation to see how EnableHostVerification interacts with the provided tls.Config.
|
||||||
EnableHostVerification bool
|
EnableHostVerification bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,7 +141,7 @@ func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
|
||||||
// which may be serving more queries just fine.
|
// which may be serving more queries just fine.
|
||||||
// Default is 0, should not be changed concurrently with queries.
|
// Default is 0, should not be changed concurrently with queries.
|
||||||
//
|
//
|
||||||
// depreciated
|
// Deprecated.
|
||||||
var TimeoutLimit int64 = 0
|
var TimeoutLimit int64 = 0
|
||||||
|
|
||||||
// Conn is a single connection to a Cassandra node. It can be used to execute
|
// Conn is a single connection to a Cassandra node. It can be used to execute
|
||||||
|
@ -213,14 +229,26 @@ func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *
|
||||||
dialer = d
|
dialer = d
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", host.HostnameAndPort())
|
addr := host.HostnameAndPort()
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if cfg.tlsConfig != nil {
|
if cfg.tlsConfig != nil {
|
||||||
// the TLS config is safe to be reused by connections but it must not
|
// the TLS config is safe to be reused by connections but it must not
|
||||||
// be modified after being used.
|
// be modified after being used.
|
||||||
tconn := tls.Client(conn, cfg.tlsConfig)
|
tlsConfig := cfg.tlsConfig
|
||||||
|
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
|
||||||
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
|
if colonPos == -1 {
|
||||||
|
colonPos = len(addr)
|
||||||
|
}
|
||||||
|
hostname := addr[:colonPos]
|
||||||
|
// clone config to avoid modifying the shared one.
|
||||||
|
tlsConfig = tlsConfig.Clone()
|
||||||
|
tlsConfig.ServerName = hostname
|
||||||
|
}
|
||||||
|
tconn := tls.Client(conn, tlsConfig)
|
||||||
if err := tconn.Handshake(); err != nil {
|
if err := tconn.Handshake(); err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -845,6 +873,10 @@ func (w *writeCoalescer) writeFlusher(interval time.Duration) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
|
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
|
||||||
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
|
return nil, ctxErr
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: move tracer onto conn
|
// TODO: move tracer onto conn
|
||||||
stream, ok := c.streams.GetStream()
|
stream, ok := c.streams.GetStream()
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -1173,12 +1205,16 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
|
||||||
}
|
}
|
||||||
|
|
||||||
if x.meta.morePages() && !qry.disableAutoPage {
|
if x.meta.morePages() && !qry.disableAutoPage {
|
||||||
|
newQry := new(Query)
|
||||||
|
*newQry = *qry
|
||||||
|
newQry.pageState = copyBytes(x.meta.pagingState)
|
||||||
|
newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}
|
||||||
|
|
||||||
iter.next = &nextIter{
|
iter.next = &nextIter{
|
||||||
qry: qry,
|
qry: newQry,
|
||||||
pos: int((1 - qry.prefetch) * float64(x.numRows)),
|
pos: int((1 - qry.prefetch) * float64(x.numRows)),
|
||||||
}
|
}
|
||||||
|
|
||||||
iter.next.qry.pageState = copyBytes(x.meta.pagingState)
|
|
||||||
if iter.next.pos < 1 {
|
if iter.next.pos < 1 {
|
||||||
iter.next.pos = 1
|
iter.next.pos = 1
|
||||||
}
|
}
|
||||||
|
@ -1359,10 +1395,11 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
|
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
|
||||||
q := c.session.Query(statement, values...).Consistency(One)
|
q := c.session.Query(statement, values...).Consistency(One).Trace(nil)
|
||||||
q.trace = nil
|
|
||||||
q.skipPrepare = true
|
q.skipPrepare = true
|
||||||
q.disableSkipMetadata = true
|
q.disableSkipMetadata = true
|
||||||
|
// we want to keep the query on this connection
|
||||||
|
q.conn = c
|
||||||
return c.executeQuery(ctx, q)
|
return c.executeQuery(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,14 +28,31 @@ type SetPartitioner interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
|
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
|
||||||
|
// Config.InsecureSkipVerify | EnableHostVerification | Result
|
||||||
|
// Config is nil | true | verify host
|
||||||
|
// Config is nil | false | do not verify host
|
||||||
|
// false | false | verify host
|
||||||
|
// true | false | do not verify host
|
||||||
|
// false | true | verify host
|
||||||
|
// true | true | verify host
|
||||||
|
var tlsConfig *tls.Config
|
||||||
if sslOpts.Config == nil {
|
if sslOpts.Config == nil {
|
||||||
sslOpts.Config = &tls.Config{}
|
tlsConfig = &tls.Config{
|
||||||
|
InsecureSkipVerify: !sslOpts.EnableHostVerification,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// use clone to avoid race.
|
||||||
|
tlsConfig = sslOpts.Config.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
|
||||||
|
tlsConfig.InsecureSkipVerify = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ca cert is optional
|
// ca cert is optional
|
||||||
if sslOpts.CaPath != "" {
|
if sslOpts.CaPath != "" {
|
||||||
if sslOpts.RootCAs == nil {
|
if tlsConfig.RootCAs == nil {
|
||||||
sslOpts.RootCAs = x509.NewCertPool()
|
tlsConfig.RootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
|
|
||||||
pem, err := ioutil.ReadFile(sslOpts.CaPath)
|
pem, err := ioutil.ReadFile(sslOpts.CaPath)
|
||||||
|
@ -43,7 +60,7 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
|
||||||
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
|
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !sslOpts.RootCAs.AppendCertsFromPEM(pem) {
|
if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
|
||||||
return nil, errors.New("connectionpool: failed parsing or CA certs")
|
return nil, errors.New("connectionpool: failed parsing or CA certs")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,13 +70,10 @@ func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
|
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
|
||||||
}
|
}
|
||||||
sslOpts.Certificates = append(sslOpts.Certificates, mycert)
|
tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
|
||||||
}
|
}
|
||||||
|
|
||||||
sslOpts.InsecureSkipVerify = !sslOpts.EnableHostVerification
|
return tlsConfig, nil
|
||||||
|
|
||||||
// return clone to avoid race
|
|
||||||
return sslOpts.Config.Clone(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type policyConnPool struct {
|
type policyConnPool struct {
|
||||||
|
@ -238,12 +252,6 @@ func (p *policyConnPool) removeHost(ip net.IP) {
|
||||||
go pool.Close()
|
go pool.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *policyConnPool) hostUp(host *HostInfo) {
|
|
||||||
// TODO(zariel): have a set of up hosts and down hosts, we can internally
|
|
||||||
// detect down hosts, then try to reconnect to them.
|
|
||||||
p.addHost(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *policyConnPool) hostDown(ip net.IP) {
|
func (p *policyConnPool) hostDown(ip net.IP) {
|
||||||
// TODO(zariel): mark host as down so we can try to connect to it later, for
|
// TODO(zariel): mark host as down so we can try to connect to it later, for
|
||||||
// now just treat it has removed.
|
// now just treat it has removed.
|
||||||
|
@ -429,6 +437,8 @@ func (pool *hostConnPool) fill() {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// notify the session that this node is connected
|
||||||
|
go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)
|
||||||
|
|
||||||
// filled one
|
// filled one
|
||||||
fillCount--
|
fillCount--
|
||||||
|
@ -440,6 +450,11 @@ func (pool *hostConnPool) fill() {
|
||||||
|
|
||||||
// mark the end of filling
|
// mark the end of filling
|
||||||
pool.fillingStopped(err != nil)
|
pool.fillingStopped(err != nil)
|
||||||
|
|
||||||
|
if err == nil && startCount > 0 {
|
||||||
|
// notify the session that this node is connected again
|
||||||
|
go pool.session.handleNodeUp(pool.host.ConnectAddress(), pool.port)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if len(ips) == 0 {
|
} else if len(ips) == 0 {
|
||||||
return nil, fmt.Errorf("No IP's returned from DNS lookup for %q", addr)
|
return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter to v4 addresses if any present
|
// Filter to v4 addresses if any present
|
||||||
|
@ -177,7 +177,7 @@ func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
Logger.Printf("gocql: unable to dial control conn %v: %v\n", host.ConnectAddress(), err)
|
Logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -285,8 +285,6 @@ func (c *controlConn) setupConn(conn *Conn) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn.Store(ch)
|
c.conn.Store(ch)
|
||||||
c.session.handleNodeUp(host.ConnectAddress(), host.Port(), false)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -452,6 +450,8 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter
|
||||||
|
|
||||||
for {
|
for {
|
||||||
iter = c.withConn(func(conn *Conn) *Iter {
|
iter = c.withConn(func(conn *Conn) *Iter {
|
||||||
|
// we want to keep the query on the control connection
|
||||||
|
q.conn = conn
|
||||||
return conn.executeQuery(context.TODO(), q)
|
return conn.executeQuery(context.TODO(), q)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,319 @@
|
||||||
|
|
||||||
// Package gocql implements a fast and robust Cassandra driver for the
|
// Package gocql implements a fast and robust Cassandra driver for the
|
||||||
// Go programming language.
|
// Go programming language.
|
||||||
|
//
|
||||||
|
// Connecting to the cluster
|
||||||
|
//
|
||||||
|
// Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
|
||||||
|
//
|
||||||
|
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||||
|
//
|
||||||
|
// Port can be specified as part of the address, the above is equivalent to:
|
||||||
|
//
|
||||||
|
// cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
|
||||||
|
//
|
||||||
|
// It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
|
||||||
|
// an IP address not a domain name. This is because events from Cassandra will use the configured IP
|
||||||
|
// address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address
|
||||||
|
// then the driver may connect multiple times to the same host, and will not mark the node being down or up from events.
|
||||||
|
//
|
||||||
|
// Then you can customize more options (see ClusterConfig):
|
||||||
|
//
|
||||||
|
// cluster.Keyspace = "example"
|
||||||
|
// cluster.Consistency = gocql.Quorum
|
||||||
|
// cluster.ProtoVersion = 4
|
||||||
|
//
|
||||||
|
// The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
|
||||||
|
// protocol version explicitly, as it's not defined which version will be used in certain situations (for example
|
||||||
|
// during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
|
||||||
|
//
|
||||||
|
// When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
|
||||||
|
//
|
||||||
|
// session, err := cluster.CreateSession()
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// defer session.Close()
|
||||||
|
//
|
||||||
|
// Authentication
|
||||||
|
//
|
||||||
|
// CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
|
||||||
|
// client response pairs. The details of the exchanged messages depend on the authenticator used.
|
||||||
|
//
|
||||||
|
// To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider.
|
||||||
|
//
|
||||||
|
// PasswordAuthenticator is provided to use for username/password authentication:
|
||||||
|
//
|
||||||
|
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||||
|
// cluster.Authenticator = gocql.PasswordAuthenticator{
|
||||||
|
// Username: "user",
|
||||||
|
// Password: "password"
|
||||||
|
// }
|
||||||
|
// session, err := cluster.CreateSession()
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// defer session.Close()
|
||||||
|
//
|
||||||
|
// Transport layer security
|
||||||
|
//
|
||||||
|
// It is possible to secure traffic between the client and server with TLS.
|
||||||
|
//
|
||||||
|
// To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly.
|
||||||
|
// There are also helpers to load keys/certificates from files.
|
||||||
|
//
|
||||||
|
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
|
||||||
|
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
|
||||||
|
// SslOptions and Config.InsecureSkipVerify interact as follows:
|
||||||
|
//
|
||||||
|
// Config.InsecureSkipVerify | EnableHostVerification | Result
|
||||||
|
// Config is nil | false | do not verify host
|
||||||
|
// Config is nil | true | verify host
|
||||||
|
// false | false | verify host
|
||||||
|
// true | false | do not verify host
|
||||||
|
// false | true | verify host
|
||||||
|
// true | true | verify host
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||||
|
// cluster.SslOpts = &gocql.SslOptions{
|
||||||
|
// EnableHostVerification: true,
|
||||||
|
// }
|
||||||
|
// session, err := cluster.CreateSession()
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// defer session.Close()
|
||||||
|
//
|
||||||
|
// Executing queries
|
||||||
|
//
|
||||||
|
// Create queries with Session.Query. Query values must not be reused between different executions and must not be
|
||||||
|
// modified after starting execution of the query.
|
||||||
|
//
|
||||||
|
// To execute a query without reading results, use Query.Exec:
|
||||||
|
//
|
||||||
|
// err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
||||||
|
// "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
|
||||||
|
//
|
||||||
|
// Single row can be read by calling Query.Scan:
|
||||||
|
//
|
||||||
|
// err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
|
||||||
|
// "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
|
||||||
|
//
|
||||||
|
// Multiple rows can be read using Iter.Scanner:
|
||||||
|
//
|
||||||
|
// scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
|
||||||
|
// "me").WithContext(ctx).Iter().Scanner()
|
||||||
|
// for scanner.Next() {
|
||||||
|
// var (
|
||||||
|
// id gocql.UUID
|
||||||
|
// text string
|
||||||
|
// )
|
||||||
|
// err = scanner.Scan(&id, &text)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// fmt.Println("Tweet:", id, text)
|
||||||
|
// }
|
||||||
|
// // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
|
||||||
|
// if err := scanner.Err(); err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// See Example for complete example.
|
||||||
|
//
|
||||||
|
// Prepared statements
|
||||||
|
//
|
||||||
|
// The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
|
||||||
|
// of prepared statements.
|
||||||
|
// CQL protocol does not support preparing other query types.
|
||||||
|
//
|
||||||
|
// When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column.
|
||||||
|
// This will cause the database to ignore writing the column.
|
||||||
|
// The main advantage is the ability to keep the same prepared statement even when you don't
|
||||||
|
// want to update some fields, where before you needed to make another prepared statement.
|
||||||
|
//
|
||||||
|
// Executing multiple queries concurrently
|
||||||
|
//
|
||||||
|
// Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
|
||||||
|
// from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries
|
||||||
|
// are executed asynchronously at the protocol level.
|
||||||
|
//
|
||||||
|
// results := make(chan error, 2)
|
||||||
|
// go func() {
|
||||||
|
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
||||||
|
// "me", gocql.TimeUUID(), "hello world 1").Exec()
|
||||||
|
// }()
|
||||||
|
// go func() {
|
||||||
|
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
||||||
|
// "me", gocql.TimeUUID(), "hello world 2").Exec()
|
||||||
|
// }()
|
||||||
|
//
|
||||||
|
// Nulls
|
||||||
|
//
|
||||||
|
// Null values are are unmarshalled as zero value of the type. If you need to distinguish for example between text
|
||||||
|
// column being null and empty string, you can unmarshal into *string variable instead of string.
|
||||||
|
//
|
||||||
|
// var text *string
|
||||||
|
// err := scanner.Scan(&text)
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// if text != nil {
|
||||||
|
// // not null
|
||||||
|
// }
|
||||||
|
// else {
|
||||||
|
// // null
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// See Example_nulls for full example.
|
||||||
|
//
|
||||||
|
// Reusing slices
|
||||||
|
//
|
||||||
|
// The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not
|
||||||
|
// need to be allocated for every processed row. However, you need to be careful when storing the slices to other
|
||||||
|
// memory structures.
|
||||||
|
//
|
||||||
|
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
|
||||||
|
// var myInts []int
|
||||||
|
// for scanner.Next() {
|
||||||
|
// // This scan reuses backing store of myInts for each row.
|
||||||
|
// err = scanner.Scan(&myInts)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the
|
||||||
|
// slice variable within the scanner loop:
|
||||||
|
//
|
||||||
|
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
|
||||||
|
// for scanner.Next() {
|
||||||
|
// var myInts []int
|
||||||
|
// // This scan always gets pointer to fresh myInts slice, so does not reuse memory.
|
||||||
|
// err = scanner.Scan(&myInts)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Paging
|
||||||
|
//
|
||||||
|
// The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch,
|
||||||
|
// Query.PageSize, and Query.Prefetch.
|
||||||
|
//
|
||||||
|
// It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch).
|
||||||
|
// Manual paging is useful if you want to store the page state externally, for example in a URL to allow users
|
||||||
|
// browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since
|
||||||
|
// it contains data from primary keys.
|
||||||
|
//
|
||||||
|
// Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that
|
||||||
|
// should not be modified. If you send paging state from different query or protocol version, then the behaviour
|
||||||
|
// is not defined (you might get unexpected results or an error from the server). For example, do not send paging state
|
||||||
|
// returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4,
|
||||||
|
// paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880).
|
||||||
|
//
|
||||||
|
// The driver does not check whether the paging state is from the same protocol version/statement.
|
||||||
|
// You might want to validate yourself as this could be a problem if you store paging state externally.
|
||||||
|
// For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster.
|
||||||
|
//
|
||||||
|
// Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by
|
||||||
|
// Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned
|
||||||
|
// by Iter.PageState is zero, there are no more pages available (or an error occurred).
|
||||||
|
//
|
||||||
|
// Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low.
|
||||||
|
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors
|
||||||
|
// explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't
|
||||||
|
// rely on the page having the exact count of items.
|
||||||
|
//
|
||||||
|
// See Example_paging for an example of manual paging.
|
||||||
|
//
|
||||||
|
// Dynamic list of columns
|
||||||
|
//
|
||||||
|
// There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied
|
||||||
|
// by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case.
|
||||||
|
//
|
||||||
|
// See Example_dynamicColumns.
|
||||||
|
//
|
||||||
|
// Batches
|
||||||
|
//
|
||||||
|
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
|
||||||
|
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
|
||||||
|
// Then execute the batch with Session.ExecuteBatch.
|
||||||
|
//
|
||||||
|
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
|
||||||
|
// overhead to ensure this property.
|
||||||
|
// Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity.
|
||||||
|
// Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type.
|
||||||
|
// A counter batch can only contain statements to update counters.
|
||||||
|
//
|
||||||
|
// For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should
|
||||||
|
// involve only a single partition).
|
||||||
|
// Multi-partition batch needs to be split by the coordinator node and re-sent to
|
||||||
|
// correct nodes.
|
||||||
|
// With single-partition batches you can send the batch directly to the node for the partition without incurring the
|
||||||
|
// additional network hop.
|
||||||
|
//
|
||||||
|
// It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec.
|
||||||
|
// There are differences how those are executed.
|
||||||
|
// BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement.
|
||||||
|
// Session.ExecuteBatch prepares individual statements in the batch.
|
||||||
|
// If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient.
|
||||||
|
//
|
||||||
|
// See Example_batch for an example.
|
||||||
|
//
|
||||||
|
// Lightweight transactions
|
||||||
|
//
|
||||||
|
// Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an
|
||||||
|
// INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS.
|
||||||
|
//
|
||||||
|
// Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional
|
||||||
|
// statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and
|
||||||
|
// Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for
|
||||||
|
// Session.MapExecuteBatchCAS.
|
||||||
|
//
|
||||||
|
// Retries and speculative execution
|
||||||
|
//
|
||||||
|
// Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed
|
||||||
|
// multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative
|
||||||
|
// execution.
|
||||||
|
//
|
||||||
|
// Idempotent queries are retried in case of errors based on the configured RetryPolicy.
|
||||||
|
//
|
||||||
|
// Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can
|
||||||
|
// cause the driver to retry on a different node if the query is taking longer than a specified delay even before the
|
||||||
|
// driver receives an error or timeout from the server. When a query is speculatively executed, the original execution
|
||||||
|
// is still executing. The two parallel executions of the query race to return a result, the first received result will
|
||||||
|
// be returned.
|
||||||
|
//
|
||||||
|
// User-defined types
|
||||||
|
//
|
||||||
|
// UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing
|
||||||
|
// UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces).
|
||||||
|
//
|
||||||
|
// For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field:
|
||||||
|
//
|
||||||
|
// type MyUDT struct {
|
||||||
|
// FieldA int32 `cql:"a"`
|
||||||
|
// FieldB string `cql:"b"`
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler.
|
||||||
|
//
|
||||||
|
// Metrics and tracing
|
||||||
|
//
|
||||||
|
// It is possible to provide observer implementations that could be used to gather metrics:
|
||||||
|
//
|
||||||
|
// - QueryObserver for monitoring individual queries.
|
||||||
|
// - BatchObserver for monitoring batch queries.
|
||||||
|
// - ConnectObserver for monitoring new connections from the driver to the database.
|
||||||
|
// - FrameHeaderObserver for monitoring individual protocol frames.
|
||||||
|
//
|
||||||
|
// CQL protocol also supports tracing of queries. When enabled, the database will write information about
|
||||||
|
// internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive
|
||||||
|
// the session ID that the database used to store the trace information in system_traces.sessions and
|
||||||
|
// system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer.
|
||||||
|
// Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead,
|
||||||
|
// so this feature should not be used on production systems with very high load unless you know what you are doing.
|
||||||
package gocql // import "github.com/gocql/gocql"
|
package gocql // import "github.com/gocql/gocql"
|
||||||
|
|
||||||
// TODO(tux21b): write more docs.
|
|
||||||
|
|
|
@ -164,55 +164,43 @@ func (s *Session) handleNodeEvent(frames []frame) {
|
||||||
|
|
||||||
switch f.change {
|
switch f.change {
|
||||||
case "NEW_NODE":
|
case "NEW_NODE":
|
||||||
s.handleNewNode(f.host, f.port, true)
|
s.handleNewNode(f.host, f.port)
|
||||||
case "REMOVED_NODE":
|
case "REMOVED_NODE":
|
||||||
s.handleRemovedNode(f.host, f.port)
|
s.handleRemovedNode(f.host, f.port)
|
||||||
case "MOVED_NODE":
|
case "MOVED_NODE":
|
||||||
// java-driver handles this, not mentioned in the spec
|
// java-driver handles this, not mentioned in the spec
|
||||||
// TODO(zariel): refresh token map
|
// TODO(zariel): refresh token map
|
||||||
case "UP":
|
case "UP":
|
||||||
s.handleNodeUp(f.host, f.port, true)
|
s.handleNodeUp(f.host, f.port)
|
||||||
case "DOWN":
|
case "DOWN":
|
||||||
s.handleNodeDown(f.host, f.port)
|
s.handleNodeDown(f.host, f.port)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) addNewNode(host *HostInfo) {
|
func (s *Session) addNewNode(ip net.IP, port int) {
|
||||||
if s.cfg.filterHost(host) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
host.setState(NodeUp)
|
|
||||||
s.pool.addHost(host)
|
|
||||||
s.policy.AddHost(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
|
|
||||||
if gocqlDebug {
|
|
||||||
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, port = s.cfg.translateAddressPort(ip, port)
|
|
||||||
|
|
||||||
// Get host info and apply any filters to the host
|
// Get host info and apply any filters to the host
|
||||||
hostInfo, err := s.hostSource.getHostInfo(ip, port)
|
hostInfo, err := s.hostSource.getHostInfo(ip, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
|
Logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
|
||||||
return
|
return
|
||||||
} else if hostInfo == nil {
|
} else if hostInfo == nil {
|
||||||
// If hostInfo is nil, this host was filtered out by cfg.HostFilter
|
// ignore if it's null because we couldn't find it
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if t := hostInfo.Version().nodeUpDelay(); t > 0 && waitForBinary {
|
if t := hostInfo.Version().nodeUpDelay(); t > 0 {
|
||||||
time.Sleep(t)
|
time.Sleep(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// should this handle token moving?
|
// should this handle token moving?
|
||||||
hostInfo = s.ring.addOrUpdate(hostInfo)
|
hostInfo = s.ring.addOrUpdate(hostInfo)
|
||||||
|
|
||||||
s.addNewNode(hostInfo)
|
if !s.cfg.filterHost(hostInfo) {
|
||||||
|
// we let the pool call handleNodeUp to change the host state
|
||||||
|
s.pool.addHost(hostInfo)
|
||||||
|
s.policy.AddHost(hostInfo)
|
||||||
|
}
|
||||||
|
|
||||||
if s.control != nil && !s.cfg.IgnorePeerAddr {
|
if s.control != nil && !s.cfg.IgnorePeerAddr {
|
||||||
// TODO(zariel): debounce ring refresh
|
// TODO(zariel): debounce ring refresh
|
||||||
|
@ -220,6 +208,22 @@ func (s *Session) handleNewNode(ip net.IP, port int, waitForBinary bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Session) handleNewNode(ip net.IP, port int) {
|
||||||
|
if gocqlDebug {
|
||||||
|
Logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, port = s.cfg.translateAddressPort(ip, port)
|
||||||
|
|
||||||
|
// if we already have the host and it's already up, then do nothing
|
||||||
|
host := s.ring.getHost(ip)
|
||||||
|
if host != nil && host.IsUp() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.addNewNode(ip, port)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Session) handleRemovedNode(ip net.IP, port int) {
|
func (s *Session) handleRemovedNode(ip net.IP, port int) {
|
||||||
if gocqlDebug {
|
if gocqlDebug {
|
||||||
Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
|
Logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
|
||||||
|
@ -232,45 +236,37 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
|
||||||
if host == nil {
|
if host == nil {
|
||||||
host = &HostInfo{connectAddress: ip, port: port}
|
host = &HostInfo{connectAddress: ip, port: port}
|
||||||
}
|
}
|
||||||
|
s.ring.removeHost(ip)
|
||||||
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
host.setState(NodeDown)
|
host.setState(NodeDown)
|
||||||
s.policy.RemoveHost(host)
|
if !s.cfg.filterHost(host) {
|
||||||
s.pool.removeHost(ip)
|
s.policy.RemoveHost(host)
|
||||||
s.ring.removeHost(ip)
|
s.pool.removeHost(ip)
|
||||||
|
}
|
||||||
|
|
||||||
if !s.cfg.IgnorePeerAddr {
|
if !s.cfg.IgnorePeerAddr {
|
||||||
s.hostSource.refreshRing()
|
s.hostSource.refreshRing()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int, waitForBinary bool) {
|
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
|
||||||
if gocqlDebug {
|
if gocqlDebug {
|
||||||
Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
|
Logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, _ := s.cfg.translateAddressPort(eventIp, eventPort)
|
ip, port := s.cfg.translateAddressPort(eventIp, eventPort)
|
||||||
|
|
||||||
host := s.ring.getHost(ip)
|
host := s.ring.getHost(ip)
|
||||||
if host == nil {
|
if host == nil {
|
||||||
// TODO(zariel): avoid the need to translate twice in this
|
s.addNewNode(ip, port)
|
||||||
// case
|
|
||||||
s.handleNewNode(eventIp, eventPort, waitForBinary)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
|
host.setState(NodeUp)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if t := host.Version().nodeUpDelay(); t > 0 && waitForBinary {
|
if !s.cfg.filterHost(host) {
|
||||||
time.Sleep(t)
|
s.policy.HostUp(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.addNewNode(host)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) handleNodeDown(ip net.IP, port int) {
|
func (s *Session) handleNodeDown(ip net.IP, port int) {
|
||||||
|
@ -283,11 +279,11 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {
|
||||||
host = &HostInfo{connectAddress: ip, port: port}
|
host = &HostInfo{connectAddress: ip, port: port}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.cfg.HostFilter != nil && !s.cfg.HostFilter.Accept(host) {
|
host.setState(NodeDown)
|
||||||
|
if s.cfg.filterHost(host) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
host.setState(NodeDown)
|
|
||||||
s.policy.HostDown(host)
|
s.policy.HostDown(host)
|
||||||
s.pool.hostDown(ip)
|
s.pool.hostDown(ip)
|
||||||
}
|
}
|
||||||
|
|
|
@ -311,26 +311,10 @@ var (
|
||||||
|
|
||||||
const maxFrameHeaderSize = 9
|
const maxFrameHeaderSize = 9
|
||||||
|
|
||||||
func writeInt(p []byte, n int32) {
|
|
||||||
p[0] = byte(n >> 24)
|
|
||||||
p[1] = byte(n >> 16)
|
|
||||||
p[2] = byte(n >> 8)
|
|
||||||
p[3] = byte(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readInt(p []byte) int32 {
|
func readInt(p []byte) int32 {
|
||||||
return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3])
|
return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3])
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeShort(p []byte, n uint16) {
|
|
||||||
p[0] = byte(n >> 8)
|
|
||||||
p[1] = byte(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readShort(p []byte) uint16 {
|
|
||||||
return uint16(p[0])<<8 | uint16(p[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
type frameHeader struct {
|
type frameHeader struct {
|
||||||
version protoVersion
|
version protoVersion
|
||||||
flags byte
|
flags byte
|
||||||
|
@ -854,7 +838,7 @@ func (w *writePrepareFrame) writeFrame(f *framer, streamID int) error {
|
||||||
if f.proto > protoVersion4 {
|
if f.proto > protoVersion4 {
|
||||||
flags |= flagWithPreparedKeyspace
|
flags |= flagWithPreparedKeyspace
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
|
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if f.proto > protoVersion4 {
|
if f.proto > protoVersion4 {
|
||||||
|
@ -1502,7 +1486,7 @@ func (f *framer) writeQueryParams(opts *queryParams) {
|
||||||
if f.proto > protoVersion4 {
|
if f.proto > protoVersion4 {
|
||||||
flags |= flagWithKeyspace
|
flags |= flagWithKeyspace
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Errorf("The keyspace can only be set with protocol 5 or higher"))
|
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1792,16 +1776,6 @@ func (f *framer) readShort() (n uint16) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framer) readLong() (n int64) {
|
|
||||||
if len(f.rbuf) < 8 {
|
|
||||||
panic(fmt.Errorf("not enough bytes in buffer to read long require 8 got: %d", len(f.rbuf)))
|
|
||||||
}
|
|
||||||
n = int64(f.rbuf[0])<<56 | int64(f.rbuf[1])<<48 | int64(f.rbuf[2])<<40 | int64(f.rbuf[3])<<32 |
|
|
||||||
int64(f.rbuf[4])<<24 | int64(f.rbuf[5])<<16 | int64(f.rbuf[6])<<8 | int64(f.rbuf[7])
|
|
||||||
f.rbuf = f.rbuf[8:]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *framer) readString() (s string) {
|
func (f *framer) readString() (s string) {
|
||||||
size := f.readShort()
|
size := f.readShort()
|
||||||
|
|
||||||
|
@ -1915,19 +1889,6 @@ func (f *framer) readConsistency() Consistency {
|
||||||
return Consistency(f.readShort())
|
return Consistency(f.readShort())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framer) readStringMap() map[string]string {
|
|
||||||
size := f.readShort()
|
|
||||||
m := make(map[string]string, size)
|
|
||||||
|
|
||||||
for i := 0; i < int(size); i++ {
|
|
||||||
k := f.readString()
|
|
||||||
v := f.readString()
|
|
||||||
m[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *framer) readBytesMap() map[string][]byte {
|
func (f *framer) readBytesMap() map[string][]byte {
|
||||||
size := f.readShort()
|
size := f.readShort()
|
||||||
m := make(map[string][]byte, size)
|
m := make(map[string][]byte, size)
|
||||||
|
@ -2037,10 +1998,6 @@ func (f *framer) writeLongString(s string) {
|
||||||
f.wbuf = append(f.wbuf, s...)
|
f.wbuf = append(f.wbuf, s...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framer) writeUUID(u *UUID) {
|
|
||||||
f.wbuf = append(f.wbuf, u[:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *framer) writeStringList(l []string) {
|
func (f *framer) writeStringList(l []string) {
|
||||||
f.writeShort(uint16(len(l)))
|
f.writeShort(uint16(len(l)))
|
||||||
for _, s := range l {
|
for _, s := range l {
|
||||||
|
@ -2073,18 +2030,6 @@ func (f *framer) writeShortBytes(p []byte) {
|
||||||
f.wbuf = append(f.wbuf, p...)
|
f.wbuf = append(f.wbuf, p...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framer) writeInet(ip net.IP, port int) {
|
|
||||||
f.wbuf = append(f.wbuf,
|
|
||||||
byte(len(ip)),
|
|
||||||
)
|
|
||||||
|
|
||||||
f.wbuf = append(f.wbuf,
|
|
||||||
[]byte(ip)...,
|
|
||||||
)
|
|
||||||
|
|
||||||
f.writeInt(int32(port))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *framer) writeConsistency(cons Consistency) {
|
func (f *framer) writeConsistency(cons Consistency) {
|
||||||
f.writeShort(uint16(cons))
|
f.writeShort(uint16(cons))
|
||||||
}
|
}
|
||||||
|
|
|
@ -270,15 +270,6 @@ func getApacheCassandraType(class string) Type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func typeCanBeNull(typ TypeInfo) bool {
|
|
||||||
switch typ.(type) {
|
|
||||||
case CollectionType, UDTTypeInfo, TupleTypeInfo:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RowData) rowMap(m map[string]interface{}) {
|
func (r *RowData) rowMap(m map[string]interface{}) {
|
||||||
for i, column := range r.Columns {
|
for i, column := range r.Columns {
|
||||||
val := dereference(r.Values[i])
|
val := dereference(r.Values[i])
|
||||||
|
@ -372,7 +363,7 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
|
||||||
// iter := session.Query(`SELECT * FROM mytable`).Iter()
|
// iter := session.Query(`SELECT * FROM mytable`).Iter()
|
||||||
// for {
|
// for {
|
||||||
// // New map each iteration
|
// // New map each iteration
|
||||||
// row = make(map[string]interface{})
|
// row := make(map[string]interface{})
|
||||||
// if !iter.MapScan(row) {
|
// if !iter.MapScan(row) {
|
||||||
// break
|
// break
|
||||||
// }
|
// }
|
||||||
|
|
|
@ -147,13 +147,6 @@ func (h *HostInfo) Peer() net.IP {
|
||||||
return h.peer
|
return h.peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HostInfo) setPeer(peer net.IP) *HostInfo {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
h.peer = peer
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HostInfo) invalidConnectAddr() bool {
|
func (h *HostInfo) invalidConnectAddr() bool {
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
defer h.mu.RUnlock()
|
defer h.mu.RUnlock()
|
||||||
|
@ -233,13 +226,6 @@ func (h *HostInfo) DataCenter() string {
|
||||||
return dc
|
return dc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HostInfo) setDataCenter(dataCenter string) *HostInfo {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
h.dataCenter = dataCenter
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HostInfo) Rack() string {
|
func (h *HostInfo) Rack() string {
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
rack := h.rack
|
rack := h.rack
|
||||||
|
@ -247,26 +233,12 @@ func (h *HostInfo) Rack() string {
|
||||||
return rack
|
return rack
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HostInfo) setRack(rack string) *HostInfo {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
h.rack = rack
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HostInfo) HostID() string {
|
func (h *HostInfo) HostID() string {
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
defer h.mu.RUnlock()
|
defer h.mu.RUnlock()
|
||||||
return h.hostId
|
return h.hostId
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HostInfo) setHostID(hostID string) *HostInfo {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
h.hostId = hostID
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HostInfo) WorkLoad() string {
|
func (h *HostInfo) WorkLoad() string {
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
defer h.mu.RUnlock()
|
defer h.mu.RUnlock()
|
||||||
|
@ -303,13 +275,6 @@ func (h *HostInfo) Version() cassVersion {
|
||||||
return h.version
|
return h.version
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HostInfo) setVersion(major, minor, patch int) *HostInfo {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
h.version = cassVersion{major, minor, patch}
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HostInfo) State() nodeState {
|
func (h *HostInfo) State() nodeState {
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
defer h.mu.RUnlock()
|
defer h.mu.RUnlock()
|
||||||
|
@ -329,26 +294,12 @@ func (h *HostInfo) Tokens() []string {
|
||||||
return h.tokens
|
return h.tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HostInfo) setTokens(tokens []string) *HostInfo {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
h.tokens = tokens
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HostInfo) Port() int {
|
func (h *HostInfo) Port() int {
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
defer h.mu.RUnlock()
|
defer h.mu.RUnlock()
|
||||||
return h.port
|
return h.port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HostInfo) setPort(port int) *HostInfo {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
h.port = port
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *HostInfo) update(from *HostInfo) {
|
func (h *HostInfo) update(from *HostInfo) {
|
||||||
if h == from {
|
if h == from {
|
||||||
return
|
return
|
||||||
|
@ -689,7 +640,7 @@ func (r *ringDescriber) refreshRing() error {
|
||||||
|
|
||||||
// TODO: move this to session
|
// TODO: move this to session
|
||||||
for _, h := range hosts {
|
for _, h := range hosts {
|
||||||
if filter := r.session.cfg.HostFilter; filter != nil && !filter.Accept(h) {
|
if r.session.cfg.filterHost(h) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,9 +8,3 @@ git clone https://github.com/pcmanus/ccm.git
|
||||||
pushd ccm
|
pushd ccm
|
||||||
./setup.py install --user
|
./setup.py install --user
|
||||||
popd
|
popd
|
||||||
|
|
||||||
if [ "$1" != "gocql/gocql" ]; then
|
|
||||||
USER=$(echo $1 | cut -f1 -d'/')
|
|
||||||
cd ../..
|
|
||||||
mv ${USER} gocql
|
|
||||||
fi
|
|
||||||
|
|
|
@ -44,6 +44,52 @@ type Unmarshaler interface {
|
||||||
|
|
||||||
// Marshal returns the CQL encoding of the value for the Cassandra
|
// Marshal returns the CQL encoding of the value for the Cassandra
|
||||||
// internal type described by the info parameter.
|
// internal type described by the info parameter.
|
||||||
|
//
|
||||||
|
// nil is serialized as CQL null.
|
||||||
|
// If value implements Marshaler, its MarshalCQL method is called to marshal the data.
|
||||||
|
// If value is a pointer, the pointed-to value is marshaled.
|
||||||
|
//
|
||||||
|
// Supported conversions are as follows, other type combinations may be added in the future:
|
||||||
|
//
|
||||||
|
// CQL type | Go type (value) | Note
|
||||||
|
// varchar, ascii, blob, text | string, []byte |
|
||||||
|
// boolean | bool |
|
||||||
|
// tinyint, smallint, int | integer types |
|
||||||
|
// tinyint, smallint, int | string | formatted as base 10 number
|
||||||
|
// bigint, counter | integer types |
|
||||||
|
// bigint, counter | big.Int |
|
||||||
|
// bigint, counter | string | formatted as base 10 number
|
||||||
|
// float | float32 |
|
||||||
|
// double | float64 |
|
||||||
|
// decimal | inf.Dec |
|
||||||
|
// time | int64 | nanoseconds since start of day
|
||||||
|
// time | time.Duration | duration since start of day
|
||||||
|
// timestamp | int64 | milliseconds since Unix epoch
|
||||||
|
// timestamp | time.Time |
|
||||||
|
// list, set | slice, array |
|
||||||
|
// list, set | map[X]struct{} |
|
||||||
|
// map | map[X]Y |
|
||||||
|
// uuid, timeuuid | gocql.UUID |
|
||||||
|
// uuid, timeuuid | [16]byte | raw UUID bytes
|
||||||
|
// uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes
|
||||||
|
// uuid, timeuuid | string | hex representation, see ParseUUID
|
||||||
|
// varint | integer types |
|
||||||
|
// varint | big.Int |
|
||||||
|
// varint | string | value of number in decimal notation
|
||||||
|
// inet | net.IP |
|
||||||
|
// inet | string | IPv4 or IPv6 address string
|
||||||
|
// tuple | slice, array |
|
||||||
|
// tuple | struct | fields are marshaled in order of declaration
|
||||||
|
// user-defined type | gocql.UDTMarshaler | MarshalUDT is called
|
||||||
|
// user-defined type | map[string]interface{} |
|
||||||
|
// user-defined type | struct | struct fields' cql tags are used for column names
|
||||||
|
// date | int64 | milliseconds since Unix epoch to start of day (in UTC)
|
||||||
|
// date | time.Time | start of day (in UTC)
|
||||||
|
// date | string | parsed using "2006-01-02" format
|
||||||
|
// duration | int64 | duration in nanoseconds
|
||||||
|
// duration | time.Duration |
|
||||||
|
// duration | gocql.Duration |
|
||||||
|
// duration | string | parsed with time.ParseDuration
|
||||||
func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
|
func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
|
||||||
if info.Version() < protoVersion1 {
|
if info.Version() < protoVersion1 {
|
||||||
panic("protocol version not set")
|
panic("protocol version not set")
|
||||||
|
@ -118,6 +164,44 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
|
||||||
// Unmarshal parses the CQL encoded data based on the info parameter that
|
// Unmarshal parses the CQL encoded data based on the info parameter that
|
||||||
// describes the Cassandra internal data type and stores the result in the
|
// describes the Cassandra internal data type and stores the result in the
|
||||||
// value pointed by value.
|
// value pointed by value.
|
||||||
|
//
|
||||||
|
// If value implements Unmarshaler, it's UnmarshalCQL method is called to
|
||||||
|
// unmarshal the data.
|
||||||
|
// If value is a pointer to pointer, it is set to nil if the CQL value is
|
||||||
|
// null. Otherwise, nulls are unmarshalled as zero value.
|
||||||
|
//
|
||||||
|
// Supported conversions are as follows, other type combinations may be added in the future:
|
||||||
|
//
|
||||||
|
// CQL type | Go type (value) | Note
|
||||||
|
// varchar, ascii, blob, text | *string |
|
||||||
|
// varchar, ascii, blob, text | *[]byte | non-nil buffer is reused
|
||||||
|
// bool | *bool |
|
||||||
|
// tinyint, smallint, int, bigint, counter | *integer types |
|
||||||
|
// tinyint, smallint, int, bigint, counter | *big.Int |
|
||||||
|
// tinyint, smallint, int, bigint, counter | *string | formatted as base 10 number
|
||||||
|
// float | *float32 |
|
||||||
|
// double | *float64 |
|
||||||
|
// decimal | *inf.Dec |
|
||||||
|
// time | *int64 | nanoseconds since start of day
|
||||||
|
// time | *time.Duration |
|
||||||
|
// timestamp | *int64 | milliseconds since Unix epoch
|
||||||
|
// timestamp | *time.Time |
|
||||||
|
// list, set | *slice, *array |
|
||||||
|
// map | *map[X]Y |
|
||||||
|
// uuid, timeuuid | *string | see UUID.String
|
||||||
|
// uuid, timeuuid | *[]byte | raw UUID bytes
|
||||||
|
// uuid, timeuuid | *gocql.UUID |
|
||||||
|
// timeuuid | *time.Time | timestamp of the UUID
|
||||||
|
// inet | *net.IP |
|
||||||
|
// inet | *string | IPv4 or IPv6 address string
|
||||||
|
// tuple | *slice, *array |
|
||||||
|
// tuple | *struct | struct fields are set in order of declaration
|
||||||
|
// user-defined types | gocql.UDTUnmarshaler | UnmarshalUDT is called
|
||||||
|
// user-defined types | *map[string]interface{} |
|
||||||
|
// user-defined types | *struct | cql tag is used to determine field name
|
||||||
|
// date | *time.Time | time of beginning of the day (in UTC)
|
||||||
|
// date | *string | formatted with 2006-01-02 format
|
||||||
|
// duration | *gocql.Duration |
|
||||||
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
|
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
|
||||||
if v, ok := value.(Unmarshaler); ok {
|
if v, ok := value.(Unmarshaler); ok {
|
||||||
return v.UnmarshalCQL(info, data)
|
return v.UnmarshalCQL(info, data)
|
||||||
|
@ -1690,6 +1774,8 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
case UUID:
|
case UUID:
|
||||||
return val.Bytes(), nil
|
return val.Bytes(), nil
|
||||||
|
case [16]byte:
|
||||||
|
return val[:], nil
|
||||||
case []byte:
|
case []byte:
|
||||||
if len(val) != 16 {
|
if len(val) != 16 {
|
||||||
return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info)
|
return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info)
|
||||||
|
@ -1711,7 +1797,7 @@ func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
|
func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
|
||||||
if data == nil || len(data) == 0 {
|
if len(data) == 0 {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case *string:
|
case *string:
|
||||||
*v = ""
|
*v = ""
|
||||||
|
@ -1726,9 +1812,22 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(data) != 16 {
|
||||||
|
return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case *[16]byte:
|
||||||
|
copy((*v)[:], data)
|
||||||
|
return nil
|
||||||
|
case *UUID:
|
||||||
|
copy((*v)[:], data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
u, err := UUIDFromBytes(data)
|
u, err := UUIDFromBytes(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return unmarshalErrorf("Unable to parse UUID: %s", err)
|
return unmarshalErrorf("unable to parse UUID: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
|
@ -1738,9 +1837,6 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
|
||||||
case *[]byte:
|
case *[]byte:
|
||||||
*v = u[:]
|
*v = u[:]
|
||||||
return nil
|
return nil
|
||||||
case *UUID:
|
|
||||||
*v = u
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
|
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
|
||||||
}
|
}
|
||||||
|
@ -1942,7 +2038,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
|
||||||
for i, elem := range tuple.Elems {
|
for i, elem := range tuple.Elems {
|
||||||
// each element inside data is a [bytes]
|
// each element inside data is a [bytes]
|
||||||
var p []byte
|
var p []byte
|
||||||
if len(data) > 4 {
|
if len(data) >= 4 {
|
||||||
p, data = readBytes(data)
|
p, data = readBytes(data)
|
||||||
}
|
}
|
||||||
err := Unmarshal(elem, p, v[i])
|
err := Unmarshal(elem, p, v[i])
|
||||||
|
@ -1971,7 +2067,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
|
||||||
|
|
||||||
for i, elem := range tuple.Elems {
|
for i, elem := range tuple.Elems {
|
||||||
var p []byte
|
var p []byte
|
||||||
if len(data) > 4 {
|
if len(data) >= 4 {
|
||||||
p, data = readBytes(data)
|
p, data = readBytes(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1982,7 +2078,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
|
||||||
|
|
||||||
switch rv.Field(i).Kind() {
|
switch rv.Field(i).Kind() {
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
rv.Field(i).Set(reflect.ValueOf(v))
|
if p != nil {
|
||||||
|
rv.Field(i).Set(reflect.ValueOf(v))
|
||||||
|
} else {
|
||||||
|
rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v)))
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
rv.Field(i).Set(reflect.ValueOf(v).Elem())
|
rv.Field(i).Set(reflect.ValueOf(v).Elem())
|
||||||
}
|
}
|
||||||
|
@ -2001,7 +2101,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
|
||||||
|
|
||||||
for i, elem := range tuple.Elems {
|
for i, elem := range tuple.Elems {
|
||||||
var p []byte
|
var p []byte
|
||||||
if len(data) > 4 {
|
if len(data) >= 4 {
|
||||||
p, data = readBytes(data)
|
p, data = readBytes(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2012,7 +2112,11 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error {
|
||||||
|
|
||||||
switch rv.Index(i).Kind() {
|
switch rv.Index(i).Kind() {
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
rv.Index(i).Set(reflect.ValueOf(v))
|
if p != nil {
|
||||||
|
rv.Index(i).Set(reflect.ValueOf(v))
|
||||||
|
} else {
|
||||||
|
rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v)))
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
rv.Index(i).Set(reflect.ValueOf(v).Elem())
|
rv.Index(i).Set(reflect.ValueOf(v).Elem())
|
||||||
}
|
}
|
||||||
|
@ -2050,7 +2154,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) {
|
||||||
case Marshaler:
|
case Marshaler:
|
||||||
return v.MarshalCQL(info)
|
return v.MarshalCQL(info)
|
||||||
case unsetColumn:
|
case unsetColumn:
|
||||||
return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for user defined types")
|
return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types")
|
||||||
case UDTMarshaler:
|
case UDTMarshaler:
|
||||||
var buf []byte
|
var buf []byte
|
||||||
for _, e := range udt.Elements {
|
for _, e := range udt.Elements {
|
||||||
|
|
|
@ -324,10 +324,10 @@ func compileMetadata(
|
||||||
keyspace.Functions[functions[i].Name] = &functions[i]
|
keyspace.Functions[functions[i].Name] = &functions[i]
|
||||||
}
|
}
|
||||||
keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates))
|
keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates))
|
||||||
for _, aggregate := range aggregates {
|
for i, _ := range aggregates {
|
||||||
aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc]
|
aggregates[i].FinalFunc = *keyspace.Functions[aggregates[i].finalFunc]
|
||||||
aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc]
|
aggregates[i].StateFunc = *keyspace.Functions[aggregates[i].stateFunc]
|
||||||
keyspace.Aggregates[aggregate.Name] = &aggregate
|
keyspace.Aggregates[aggregates[i].Name] = &aggregates[i]
|
||||||
}
|
}
|
||||||
keyspace.Views = make(map[string]*ViewMetadata, len(views))
|
keyspace.Views = make(map[string]*ViewMetadata, len(views))
|
||||||
for i := range views {
|
for i := range views {
|
||||||
|
@ -347,9 +347,9 @@ func compileMetadata(
|
||||||
keyspace.UserTypes[types[i].Name] = &types[i]
|
keyspace.UserTypes[types[i].Name] = &types[i]
|
||||||
}
|
}
|
||||||
keyspace.MaterializedViews = make(map[string]*MaterializedViewMetadata, len(materializedViews))
|
keyspace.MaterializedViews = make(map[string]*MaterializedViewMetadata, len(materializedViews))
|
||||||
for _, materializedView := range materializedViews {
|
for i, _ := range materializedViews {
|
||||||
materializedView.BaseTable = keyspace.Tables[materializedView.baseTableName]
|
materializedViews[i].BaseTable = keyspace.Tables[materializedViews[i].baseTableName]
|
||||||
keyspace.MaterializedViews[materializedView.Name] = &materializedView
|
keyspace.MaterializedViews[materializedViews[i].Name] = &materializedViews[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// add columns from the schema data
|
// add columns from the schema data
|
||||||
|
@ -559,7 +559,7 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
|
||||||
iter.Scan(&keyspace.DurableWrites, &replication)
|
iter.Scan(&keyspace.DurableWrites, &replication)
|
||||||
err := iter.Close()
|
err := iter.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
|
return nil, fmt.Errorf("error querying keyspace schema: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
keyspace.StrategyClass = replication["class"]
|
keyspace.StrategyClass = replication["class"]
|
||||||
|
@ -585,13 +585,13 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
|
||||||
iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
|
iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
|
||||||
err := iter.Close()
|
err := iter.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
|
return nil, fmt.Errorf("error querying keyspace schema: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions)
|
err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"Invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
|
"invalid JSON value '%s' as strategy_options for in keyspace '%s': %v",
|
||||||
strategyOptionsJSON, keyspace.Name, err,
|
strategyOptionsJSON, keyspace.Name, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -703,7 +703,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
|
||||||
if err != nil {
|
if err != nil {
|
||||||
iter.Close()
|
iter.Close()
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"Invalid JSON value '%s' as key_aliases for in table '%s': %v",
|
"invalid JSON value '%s' as key_aliases for in table '%s': %v",
|
||||||
keyAliasesJSON, table.Name, err,
|
keyAliasesJSON, table.Name, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -716,7 +716,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
|
||||||
if err != nil {
|
if err != nil {
|
||||||
iter.Close()
|
iter.Close()
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"Invalid JSON value '%s' as column_aliases for in table '%s': %v",
|
"invalid JSON value '%s' as column_aliases for in table '%s': %v",
|
||||||
columnAliasesJSON, table.Name, err,
|
columnAliasesJSON, table.Name, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -728,7 +728,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
|
||||||
|
|
||||||
err := iter.Close()
|
err := iter.Close()
|
||||||
if err != nil && err != ErrNotFound {
|
if err != nil && err != ErrNotFound {
|
||||||
return nil, fmt.Errorf("Error querying table schema: %v", err)
|
return nil, fmt.Errorf("error querying table schema: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tables, nil
|
return tables, nil
|
||||||
|
@ -777,7 +777,7 @@ func (s *Session) scanColumnMetadataV1(keyspace string) ([]ColumnMetadata, error
|
||||||
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
|
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
|
"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
|
||||||
indexOptionsJSON,
|
indexOptionsJSON,
|
||||||
column.Name,
|
column.Name,
|
||||||
column.Table,
|
column.Table,
|
||||||
|
@ -837,7 +837,7 @@ func (s *Session) scanColumnMetadataV2(keyspace string) ([]ColumnMetadata, error
|
||||||
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
|
err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"Invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
|
"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
|
||||||
indexOptionsJSON,
|
indexOptionsJSON,
|
||||||
column.Name,
|
column.Name,
|
||||||
column.Table,
|
column.Table,
|
||||||
|
@ -915,7 +915,7 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil && err != ErrNotFound {
|
if err != nil && err != ErrNotFound {
|
||||||
return nil, fmt.Errorf("Error querying column schema: %v", err)
|
return nil, fmt.Errorf("error querying column schema: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return columns, nil
|
return columns, nil
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
//This file will be the future home for more policies
|
|
||||||
package gocql
|
package gocql
|
||||||
|
|
||||||
|
//This file will be the future home for more policies
|
||||||
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -37,12 +40,6 @@ func (c *cowHostList) get() []*HostInfo {
|
||||||
return *l
|
return *l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cowHostList) set(list []*HostInfo) {
|
|
||||||
c.mu.Lock()
|
|
||||||
c.list.Store(&list)
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// add will add a host if it not already in the list
|
// add will add a host if it not already in the list
|
||||||
func (c *cowHostList) add(host *HostInfo) bool {
|
func (c *cowHostList) add(host *HostInfo) bool {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
|
@ -68,33 +65,6 @@ func (c *cowHostList) add(host *HostInfo) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cowHostList) update(host *HostInfo) {
|
|
||||||
c.mu.Lock()
|
|
||||||
l := c.get()
|
|
||||||
|
|
||||||
if len(l) == 0 {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
newL := make([]*HostInfo, len(l))
|
|
||||||
for i := range l {
|
|
||||||
if host.Equal(l[i]) {
|
|
||||||
newL[i] = host
|
|
||||||
found = true
|
|
||||||
} else {
|
|
||||||
newL[i] = l[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
c.list.Store(&newL)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cowHostList) remove(ip net.IP) bool {
|
func (c *cowHostList) remove(ip net.IP) bool {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
l := c.get()
|
l := c.get()
|
||||||
|
@ -304,7 +274,10 @@ type HostSelectionPolicy interface {
|
||||||
KeyspaceChanged(KeyspaceUpdateEvent)
|
KeyspaceChanged(KeyspaceUpdateEvent)
|
||||||
Init(*Session)
|
Init(*Session)
|
||||||
IsLocal(host *HostInfo) bool
|
IsLocal(host *HostInfo) bool
|
||||||
//Pick returns an iteration function over selected hosts
|
// Pick returns an iteration function over selected hosts.
|
||||||
|
// Multiple attempts of a single query execution won't call the returned NextHost function concurrently,
|
||||||
|
// so it's safe to have internal state without additional synchronization as long as every call to Pick returns
|
||||||
|
// a different instance of NextHost.
|
||||||
Pick(ExecutableQuery) NextHost
|
Pick(ExecutableQuery) NextHost
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -880,6 +853,51 @@ func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
|
||||||
return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
|
return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After
|
||||||
|
// each host connects during session initialization, the Ready method will be
|
||||||
|
// called. If you only need a single Host to be up you can wrap a
|
||||||
|
// HostSelectionPolicy policy with SingleHostReadyPolicy.
|
||||||
|
type ReadyPolicy interface {
|
||||||
|
Ready() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SingleHostReadyPolicy wraps a HostSelectionPolicy and returns Ready after a
|
||||||
|
// single host has been added via HostUp
|
||||||
|
func SingleHostReadyPolicy(p HostSelectionPolicy) *singleHostReadyPolicy {
|
||||||
|
return &singleHostReadyPolicy{
|
||||||
|
HostSelectionPolicy: p,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type singleHostReadyPolicy struct {
|
||||||
|
HostSelectionPolicy
|
||||||
|
ready bool
|
||||||
|
readyMux sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *singleHostReadyPolicy) HostUp(host *HostInfo) {
|
||||||
|
s.HostSelectionPolicy.HostUp(host)
|
||||||
|
|
||||||
|
s.readyMux.Lock()
|
||||||
|
s.ready = true
|
||||||
|
s.readyMux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *singleHostReadyPolicy) Ready() bool {
|
||||||
|
s.readyMux.Lock()
|
||||||
|
ready := s.ready
|
||||||
|
s.readyMux.Unlock()
|
||||||
|
if !ready {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// in case the wrapped policy is also a ReadyPolicy, defer to that
|
||||||
|
if rdy, ok := s.HostSelectionPolicy.(ReadyPolicy); ok {
|
||||||
|
return rdy.Ready()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// ConvictionPolicy interface is used by gocql to determine if a host should be
|
// ConvictionPolicy interface is used by gocql to determine if a host should be
|
||||||
// marked as DOWN based on the error and host info
|
// marked as DOWN based on the error and host info
|
||||||
type ConvictionPolicy interface {
|
type ConvictionPolicy interface {
|
||||||
|
|
|
@ -14,18 +14,6 @@ type preparedLRU struct {
|
||||||
lru *lru.Cache
|
lru *lru.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// Max adjusts the maximum size of the cache and cleans up the oldest records if
|
|
||||||
// the new max is lower than the previous value. Not concurrency safe.
|
|
||||||
func (p *preparedLRU) max(max int) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
for p.lru.Len() > max {
|
|
||||||
p.lru.RemoveOldest()
|
|
||||||
}
|
|
||||||
p.lru.MaxEntries = max
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *preparedLRU) clear() {
|
func (p *preparedLRU) clear() {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
|
|
|
@ -2,6 +2,7 @@ package gocql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,14 +35,15 @@ func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, c
|
||||||
return iter
|
return iter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, results chan *Iter) *Iter {
|
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy,
|
||||||
|
hostIter NextHost, results chan *Iter) *Iter {
|
||||||
ticker := time.NewTicker(sp.Delay())
|
ticker := time.NewTicker(sp.Delay())
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for i := 0; i < sp.Attempts(); i++ {
|
for i := 0; i < sp.Attempts(); i++ {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
go q.run(ctx, qry, results)
|
go q.run(ctx, qry, hostIter, results)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return &Iter{err: ctx.Err()}
|
return &Iter{err: ctx.Err()}
|
||||||
case iter := <-results:
|
case iter := <-results:
|
||||||
|
@ -53,11 +55,23 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
|
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
|
||||||
|
hostIter := q.policy.Pick(qry)
|
||||||
|
|
||||||
// check if the query is not marked as idempotent, if
|
// check if the query is not marked as idempotent, if
|
||||||
// it is, we force the policy to NonSpeculative
|
// it is, we force the policy to NonSpeculative
|
||||||
sp := qry.speculativeExecutionPolicy()
|
sp := qry.speculativeExecutionPolicy()
|
||||||
if !qry.IsIdempotent() || sp.Attempts() == 0 {
|
if !qry.IsIdempotent() || sp.Attempts() == 0 {
|
||||||
return q.do(qry.Context(), qry), nil
|
return q.do(qry.Context(), qry, hostIter), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below.
|
||||||
|
// To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it.
|
||||||
|
var mu sync.Mutex
|
||||||
|
origHostIter := hostIter
|
||||||
|
hostIter = func() SelectedHost {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
return origHostIter()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(qry.Context())
|
ctx, cancel := context.WithCancel(qry.Context())
|
||||||
|
@ -66,12 +80,12 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
|
||||||
results := make(chan *Iter, 1)
|
results := make(chan *Iter, 1)
|
||||||
|
|
||||||
// Launch the main execution
|
// Launch the main execution
|
||||||
go q.run(ctx, qry, results)
|
go q.run(ctx, qry, hostIter, results)
|
||||||
|
|
||||||
// The speculative executions are launched _in addition_ to the main
|
// The speculative executions are launched _in addition_ to the main
|
||||||
// execution, on a timer. So Speculation{2} would make 3 executions running
|
// execution, on a timer. So Speculation{2} would make 3 executions running
|
||||||
// in total.
|
// in total.
|
||||||
if iter := q.speculate(ctx, qry, sp, results); iter != nil {
|
if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil {
|
||||||
return iter, nil
|
return iter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,8 +97,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
|
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
|
||||||
hostIter := q.policy.Pick(qry)
|
|
||||||
selectedHost := hostIter()
|
selectedHost := hostIter()
|
||||||
rt := qry.retryPolicy()
|
rt := qry.retryPolicy()
|
||||||
|
|
||||||
|
@ -153,9 +166,9 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery) *Iter {
|
||||||
return &Iter{err: ErrNoConnections}
|
return &Iter{err: ErrNoConnections}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, results chan<- *Iter) {
|
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) {
|
||||||
select {
|
select {
|
||||||
case results <- q.do(ctx, qry):
|
case results <- q.do(ctx, qry, hostIter):
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,29 +63,6 @@ func (r *ring) currentHosts() map[string]*HostInfo {
|
||||||
return hosts
|
return hosts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ring) addHost(host *HostInfo) bool {
|
|
||||||
// TODO(zariel): key all host info by HostID instead of
|
|
||||||
// ip addresses
|
|
||||||
if host.invalidConnectAddr() {
|
|
||||||
panic(fmt.Sprintf("invalid host: %v", host))
|
|
||||||
}
|
|
||||||
ip := host.ConnectAddress().String()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
if r.hosts == nil {
|
|
||||||
r.hosts = make(map[string]*HostInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := r.hosts[ip]
|
|
||||||
if !ok {
|
|
||||||
r.hostList = append(r.hostList, host)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.hosts[ip] = host
|
|
||||||
r.mu.Unlock()
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
|
func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
|
||||||
if existingHost, ok := r.addHostIfMissing(host); ok {
|
if existingHost, ok := r.addHostIfMissing(host); ok {
|
||||||
existingHost.update(host)
|
existingHost.update(host)
|
||||||
|
|
|
@ -27,7 +27,7 @@ import (
|
||||||
// scenario is to have one global session object to interact with the
|
// scenario is to have one global session object to interact with the
|
||||||
// whole Cassandra cluster.
|
// whole Cassandra cluster.
|
||||||
//
|
//
|
||||||
// This type extends the Node interface by adding a convinient query builder
|
// This type extends the Node interface by adding a convenient query builder
|
||||||
// and automatically sets a default consistency level on all operations
|
// and automatically sets a default consistency level on all operations
|
||||||
// that do not have a consistency level set.
|
// that do not have a consistency level set.
|
||||||
type Session struct {
|
type Session struct {
|
||||||
|
@ -62,7 +62,6 @@ type Session struct {
|
||||||
schemaEvents *eventDebouncer
|
schemaEvents *eventDebouncer
|
||||||
|
|
||||||
// ring metadata
|
// ring metadata
|
||||||
hosts []HostInfo
|
|
||||||
useSystemSchema bool
|
useSystemSchema bool
|
||||||
hasAggregatesAndFunctions bool
|
hasAggregatesAndFunctions bool
|
||||||
|
|
||||||
|
@ -227,18 +226,44 @@ func (s *Session) init() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts = hosts[:0]
|
hosts = hosts[:0]
|
||||||
|
// each host will increment left and decrement it after connecting and once
|
||||||
|
// there's none left, we'll close hostCh
|
||||||
|
var left int64
|
||||||
|
// we will receive up to len(hostMap) of messages so create a buffer so we
|
||||||
|
// don't end up stuck in a goroutine if we stopped listening
|
||||||
|
connectedCh := make(chan struct{}, len(hostMap))
|
||||||
|
// we add one here because we don't want to end up closing hostCh until we're
|
||||||
|
// done looping and the decerement code might be reached before we've looped
|
||||||
|
// again
|
||||||
|
atomic.AddInt64(&left, 1)
|
||||||
for _, host := range hostMap {
|
for _, host := range hostMap {
|
||||||
host = s.ring.addOrUpdate(host)
|
host := s.ring.addOrUpdate(host)
|
||||||
if s.cfg.filterHost(host) {
|
if s.cfg.filterHost(host) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
host.setState(NodeUp)
|
atomic.AddInt64(&left, 1)
|
||||||
s.pool.addHost(host)
|
go func() {
|
||||||
|
s.pool.addHost(host)
|
||||||
|
connectedCh <- struct{}{}
|
||||||
|
|
||||||
|
// if there are no hosts left, then close the hostCh to unblock the loop
|
||||||
|
// below if its still waiting
|
||||||
|
if atomic.AddInt64(&left, -1) == 0 {
|
||||||
|
close(connectedCh)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
hosts = append(hosts, host)
|
hosts = append(hosts, host)
|
||||||
}
|
}
|
||||||
|
// once we're done looping we subtract the one we initially added and check
|
||||||
|
// to see if we should close
|
||||||
|
if atomic.AddInt64(&left, -1) == 0 {
|
||||||
|
close(connectedCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// before waiting for them to connect, add them all to the policy so we can
|
||||||
|
// utilize efficiencies by calling AddHosts if the policy supports it
|
||||||
type bulkAddHosts interface {
|
type bulkAddHosts interface {
|
||||||
AddHosts([]*HostInfo)
|
AddHosts([]*HostInfo)
|
||||||
}
|
}
|
||||||
|
@ -250,6 +275,15 @@ func (s *Session) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
readyPolicy, _ := s.policy.(ReadyPolicy)
|
||||||
|
// now loop over connectedCh until it's closed (meaning we've connected to all)
|
||||||
|
// or until the policy says we're ready
|
||||||
|
for range connectedCh {
|
||||||
|
if readyPolicy != nil && readyPolicy.Ready() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(zariel): we probably dont need this any more as we verify that we
|
// TODO(zariel): we probably dont need this any more as we verify that we
|
||||||
// can connect to one of the endpoints supplied by using the control conn.
|
// can connect to one of the endpoints supplied by using the control conn.
|
||||||
// See if there are any connections in the pool
|
// See if there are any connections in the pool
|
||||||
|
@ -320,7 +354,8 @@ func (s *Session) reconnectDownedHosts(intv time.Duration) {
|
||||||
if h.IsUp() {
|
if h.IsUp() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.handleNodeUp(h.ConnectAddress(), h.Port(), true)
|
// we let the pool call handleNodeUp to change the host state
|
||||||
|
s.pool.addHost(h)
|
||||||
}
|
}
|
||||||
case <-s.ctx.Done():
|
case <-s.ctx.Done():
|
||||||
return
|
return
|
||||||
|
@ -806,6 +841,7 @@ type Query struct {
|
||||||
trace Tracer
|
trace Tracer
|
||||||
observer QueryObserver
|
observer QueryObserver
|
||||||
session *Session
|
session *Session
|
||||||
|
conn *Conn
|
||||||
rt RetryPolicy
|
rt RetryPolicy
|
||||||
spec SpeculativeExecutionPolicy
|
spec SpeculativeExecutionPolicy
|
||||||
binding func(q *QueryInfo) ([]interface{}, error)
|
binding func(q *QueryInfo) ([]interface{}, error)
|
||||||
|
@ -1094,12 +1130,17 @@ func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy {
|
||||||
return q.spec
|
return q.spec
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsIdempotent returns whether the query is marked as idempotent.
|
||||||
|
// Non-idempotent query won't be retried.
|
||||||
|
// See "Retries and speculative execution" in package docs for more details.
|
||||||
func (q *Query) IsIdempotent() bool {
|
func (q *Query) IsIdempotent() bool {
|
||||||
return q.idempotent
|
return q.idempotent
|
||||||
}
|
}
|
||||||
|
|
||||||
// Idempotent marks the query as being idempotent or not depending on
|
// Idempotent marks the query as being idempotent or not depending on
|
||||||
// the value.
|
// the value.
|
||||||
|
// Non-idempotent query won't be retried.
|
||||||
|
// See "Retries and speculative execution" in package docs for more details.
|
||||||
func (q *Query) Idempotent(value bool) *Query {
|
func (q *Query) Idempotent(value bool) *Query {
|
||||||
q.idempotent = value
|
q.idempotent = value
|
||||||
return q
|
return q
|
||||||
|
@ -1164,6 +1205,11 @@ func (q *Query) Iter() *Iter {
|
||||||
if isUseStatement(q.stmt) {
|
if isUseStatement(q.stmt) {
|
||||||
return &Iter{err: ErrUseStmt}
|
return &Iter{err: ErrUseStmt}
|
||||||
}
|
}
|
||||||
|
// if the query was specifically run on a connection then re-use that
|
||||||
|
// connection when fetching the next results
|
||||||
|
if q.conn != nil {
|
||||||
|
return q.conn.executeQuery(q.Context(), q)
|
||||||
|
}
|
||||||
return q.session.executeQuery(q)
|
return q.session.executeQuery(q)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1195,6 +1241,10 @@ func (q *Query) Scan(dest ...interface{}) error {
|
||||||
// statement containing an IF clause). If the transaction fails because
|
// statement containing an IF clause). If the transaction fails because
|
||||||
// the existing values did not match, the previous values will be stored
|
// the existing values did not match, the previous values will be stored
|
||||||
// in dest.
|
// in dest.
|
||||||
|
//
|
||||||
|
// As for INSERT .. IF NOT EXISTS, previous values will be returned as if
|
||||||
|
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
|
||||||
|
// column mismatching. Use MapScanCAS to capture them safely.
|
||||||
func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
|
func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
|
||||||
q.disableSkipMetadata = true
|
q.disableSkipMetadata = true
|
||||||
iter := q.Iter()
|
iter := q.Iter()
|
||||||
|
@ -1423,7 +1473,7 @@ func (iter *Iter) Scan(dest ...interface{}) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
if iter.next != nil && iter.pos >= iter.next.pos {
|
if iter.next != nil && iter.pos >= iter.next.pos {
|
||||||
go iter.next.fetch()
|
iter.next.fetchAsync()
|
||||||
}
|
}
|
||||||
|
|
||||||
// currently only support scanning into an expand tuple, such that its the same
|
// currently only support scanning into an expand tuple, such that its the same
|
||||||
|
@ -1517,16 +1567,31 @@ func (iter *Iter) NumRows() int {
|
||||||
return iter.numRows
|
return iter.numRows
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nextIter holds state for fetching a single page in an iterator.
|
||||||
|
// single page might be attempted multiple times due to retries.
|
||||||
type nextIter struct {
|
type nextIter struct {
|
||||||
qry *Query
|
qry *Query
|
||||||
pos int
|
pos int
|
||||||
once sync.Once
|
oncea sync.Once
|
||||||
next *Iter
|
once sync.Once
|
||||||
|
next *Iter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nextIter) fetchAsync() {
|
||||||
|
n.oncea.Do(func() {
|
||||||
|
go n.fetch()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nextIter) fetch() *Iter {
|
func (n *nextIter) fetch() *Iter {
|
||||||
n.once.Do(func() {
|
n.once.Do(func() {
|
||||||
n.next = n.qry.session.executeQuery(n.qry)
|
// if the query was specifically run on a connection then re-use that
|
||||||
|
// connection when fetching the next results
|
||||||
|
if n.qry.conn != nil {
|
||||||
|
n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry)
|
||||||
|
} else {
|
||||||
|
n.next = n.qry.session.executeQuery(n.qry)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
return n.next
|
return n.next
|
||||||
}
|
}
|
||||||
|
@ -1536,7 +1601,6 @@ type Batch struct {
|
||||||
Entries []BatchEntry
|
Entries []BatchEntry
|
||||||
Cons Consistency
|
Cons Consistency
|
||||||
routingKey []byte
|
routingKey []byte
|
||||||
routingKeyBuffer []byte
|
|
||||||
CustomPayload map[string][]byte
|
CustomPayload map[string][]byte
|
||||||
rt RetryPolicy
|
rt RetryPolicy
|
||||||
spec SpeculativeExecutionPolicy
|
spec SpeculativeExecutionPolicy
|
||||||
|
@ -1733,7 +1797,7 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch {
|
||||||
|
|
||||||
func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
|
func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) {
|
||||||
latency := end.Sub(start)
|
latency := end.Sub(start)
|
||||||
_, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil)
|
attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil)
|
||||||
|
|
||||||
if b.observer == nil {
|
if b.observer == nil {
|
||||||
return
|
return
|
||||||
|
@ -1753,6 +1817,7 @@ func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host
|
||||||
Host: host,
|
Host: host,
|
||||||
Metrics: metricsForHost,
|
Metrics: metricsForHost,
|
||||||
Err: iter.err,
|
Err: iter.err,
|
||||||
|
Attempt: attempt,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1968,7 +2033,6 @@ type ObservedQuery struct {
|
||||||
Err error
|
Err error
|
||||||
|
|
||||||
// Attempt is the index of attempt at executing this query.
|
// Attempt is the index of attempt at executing this query.
|
||||||
// An attempt might be either retry or fetching next page of a query.
|
|
||||||
// The first attempt is number zero and any retries have non-zero attempt number.
|
// The first attempt is number zero and any retries have non-zero attempt number.
|
||||||
Attempt int
|
Attempt int
|
||||||
}
|
}
|
||||||
|
@ -1999,6 +2063,10 @@ type ObservedBatch struct {
|
||||||
|
|
||||||
// The metrics per this host
|
// The metrics per this host
|
||||||
Metrics *hostMetrics
|
Metrics *hostMetrics
|
||||||
|
|
||||||
|
// Attempt is the index of attempt at executing this query.
|
||||||
|
// The first attempt is number zero and any retries have non-zero attempt number.
|
||||||
|
Attempt int
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchObserver is the interface implemented by batch observers / stat collectors.
|
// BatchObserver is the interface implemented by batch observers / stat collectors.
|
||||||
|
|
|
@ -153,7 +153,7 @@ func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
|
||||||
} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
|
} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
|
||||||
tokenRing.partitioner = randomPartitioner{}
|
tokenRing.partitioner = randomPartitioner{}
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
|
return nil, fmt.Errorf("unsupported partitioner '%s'", partitioner)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, host := range hosts {
|
for _, host := range hosts {
|
||||||
|
|
|
@ -46,32 +46,35 @@ type placementStrategy interface {
|
||||||
replicationFactor(dc string) int
|
replicationFactor(dc string) int
|
||||||
}
|
}
|
||||||
|
|
||||||
func getReplicationFactorFromOpts(keyspace string, val interface{}) int {
|
func getReplicationFactorFromOpts(val interface{}) (int, error) {
|
||||||
// TODO: dont really want to panic here, but is better
|
|
||||||
// than spamming
|
|
||||||
switch v := val.(type) {
|
switch v := val.(type) {
|
||||||
case int:
|
case int:
|
||||||
if v <= 0 {
|
if v < 0 {
|
||||||
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", v, keyspace))
|
return 0, fmt.Errorf("invalid replication_factor %d", v)
|
||||||
}
|
}
|
||||||
return v
|
return v, nil
|
||||||
case string:
|
case string:
|
||||||
n, err := strconv.Atoi(v)
|
n, err := strconv.Atoi(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("invalid replication_factor. Is the %q keyspace configured correctly? %v", keyspace, err))
|
return 0, fmt.Errorf("invalid replication_factor %q: %v", v, err)
|
||||||
} else if n <= 0 {
|
} else if n < 0 {
|
||||||
panic(fmt.Sprintf("invalid replication_factor %d. Is the %q keyspace configured correctly?", n, keyspace))
|
return 0, fmt.Errorf("invalid replication_factor %d", n)
|
||||||
}
|
}
|
||||||
return n
|
return n, nil
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unkown replication_factor type %T", v))
|
return 0, fmt.Errorf("unknown replication_factor type %T", v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStrategy(ks *KeyspaceMetadata) placementStrategy {
|
func getStrategy(ks *KeyspaceMetadata) placementStrategy {
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(ks.StrategyClass, "SimpleStrategy"):
|
case strings.Contains(ks.StrategyClass, "SimpleStrategy"):
|
||||||
return &simpleStrategy{rf: getReplicationFactorFromOpts(ks.Name, ks.StrategyOptions["replication_factor"])}
|
rf, err := getReplicationFactorFromOpts(ks.StrategyOptions["replication_factor"])
|
||||||
|
if err != nil {
|
||||||
|
Logger.Printf("parse rf for keyspace %q: %v", ks.Name, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &simpleStrategy{rf: rf}
|
||||||
case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"):
|
case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"):
|
||||||
dcs := make(map[string]int)
|
dcs := make(map[string]int)
|
||||||
for dc, rf := range ks.StrategyOptions {
|
for dc, rf := range ks.StrategyOptions {
|
||||||
|
@ -79,14 +82,21 @@ func getStrategy(ks *KeyspaceMetadata) placementStrategy {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
dcs[dc] = getReplicationFactorFromOpts(ks.Name+":dc="+dc, rf)
|
rf, err := getReplicationFactorFromOpts(rf)
|
||||||
|
if err != nil {
|
||||||
|
Logger.Println("parse rf for keyspace %q, dc %q: %v", err)
|
||||||
|
// skip DC if the rf is invalid/unsupported, so that we can at least work with other working DCs.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dcs[dc] = rf
|
||||||
}
|
}
|
||||||
return &networkTopology{dcs: dcs}
|
return &networkTopology{dcs: dcs}
|
||||||
case strings.Contains(ks.StrategyClass, "LocalStrategy"):
|
case strings.Contains(ks.StrategyClass, "LocalStrategy"):
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
// TODO: handle unknown replicas and just return the primary host for a token
|
Logger.Printf("parse rf for keyspace %q: unsupported strategy class: %v", ks.StrategyClass)
|
||||||
panic(fmt.Sprintf("unsupported strategy class: %v", ks.StrategyClass))
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,12 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package gocql
|
||||||
|
|
||||||
// The uuid package can be used to generate and parse universally unique
|
// The uuid package can be used to generate and parse universally unique
|
||||||
// identifiers, a standardized format in the form of a 128 bit number.
|
// identifiers, a standardized format in the form of a 128 bit number.
|
||||||
//
|
//
|
||||||
// http://tools.ietf.org/html/rfc4122
|
// http://tools.ietf.org/html/rfc4122
|
||||||
package gocql
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
|
|
@ -402,7 +402,7 @@ github.com/go-stack/stack
|
||||||
github.com/go-test/deep
|
github.com/go-test/deep
|
||||||
# github.com/go-yaml/yaml v2.1.0+incompatible
|
# github.com/go-yaml/yaml v2.1.0+incompatible
|
||||||
github.com/go-yaml/yaml
|
github.com/go-yaml/yaml
|
||||||
# github.com/gocql/gocql v0.0.0-20200624222514-34081eda590e
|
# github.com/gocql/gocql v0.0.0-20210401103645-80ab1e13e309
|
||||||
## explicit
|
## explicit
|
||||||
github.com/gocql/gocql
|
github.com/gocql/gocql
|
||||||
github.com/gocql/gocql/internal/lru
|
github.com/gocql/gocql/internal/lru
|
||||||
|
|
Loading…
Reference in New Issue