367 lines
9.7 KiB
Go
367 lines
9.7 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package cassandra
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
metrics "github.com/armon/go-metrics"
|
|
"github.com/gocql/gocql"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/sdk/helper/certutil"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
)
|
|
|
|
// CassandraBackend is a physical backend that stores data in Cassandra.
|
|
type CassandraBackend struct {
|
|
sess *gocql.Session
|
|
table string
|
|
|
|
logger log.Logger
|
|
}
|
|
|
|
// Verify CassandraBackend satisfies the correct interfaces
|
|
var _ physical.Backend = (*CassandraBackend)(nil)
|
|
|
|
// NewCassandraBackend constructs a Cassandra backend using a pre-existing
|
|
// keyspace and table.
|
|
func NewCassandraBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
|
splitArray := func(v string) []string {
|
|
return strings.FieldsFunc(v, func(r rune) bool {
|
|
return r == ','
|
|
})
|
|
}
|
|
|
|
var (
|
|
hosts = splitArray(conf["hosts"])
|
|
port = 9042
|
|
explicitPort = false
|
|
keyspace = conf["keyspace"]
|
|
table = conf["table"]
|
|
consistency = gocql.LocalQuorum
|
|
)
|
|
|
|
if len(hosts) == 0 {
|
|
hosts = []string{"localhost"}
|
|
}
|
|
for i, hp := range hosts {
|
|
h, ps, err := net.SplitHostPort(hp)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
p, err := strconv.Atoi(ps)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if explicitPort && p != port {
|
|
return nil, fmt.Errorf("all hosts must have the same port")
|
|
}
|
|
hosts[i], port = h, p
|
|
explicitPort = true
|
|
}
|
|
|
|
if keyspace == "" {
|
|
keyspace = "vault"
|
|
}
|
|
if table == "" {
|
|
table = "entries"
|
|
}
|
|
if cs, ok := conf["consistency"]; ok {
|
|
switch cs {
|
|
case "ANY":
|
|
consistency = gocql.Any
|
|
case "ONE":
|
|
consistency = gocql.One
|
|
case "TWO":
|
|
consistency = gocql.Two
|
|
case "THREE":
|
|
consistency = gocql.Three
|
|
case "QUORUM":
|
|
consistency = gocql.Quorum
|
|
case "ALL":
|
|
consistency = gocql.All
|
|
case "LOCAL_QUORUM":
|
|
consistency = gocql.LocalQuorum
|
|
case "EACH_QUORUM":
|
|
consistency = gocql.EachQuorum
|
|
case "LOCAL_ONE":
|
|
consistency = gocql.LocalOne
|
|
default:
|
|
return nil, fmt.Errorf("'consistency' must be one of {ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_QUORUM, EACH_QUORUM, LOCAL_ONE}")
|
|
}
|
|
}
|
|
|
|
connectStart := time.Now()
|
|
cluster := gocql.NewCluster(hosts...)
|
|
cluster.Port = port
|
|
cluster.Keyspace = keyspace
|
|
|
|
if retryCountStr, ok := conf["simple_retry_policy_retries"]; ok {
|
|
retryCount, err := strconv.Atoi(retryCountStr)
|
|
if err != nil || retryCount <= 0 {
|
|
return nil, fmt.Errorf("'simple_retry_policy_retries' must be a positive integer")
|
|
}
|
|
cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: retryCount}
|
|
}
|
|
|
|
cluster.ProtoVersion = 2
|
|
if protoVersionStr, ok := conf["protocol_version"]; ok {
|
|
protoVersion, err := strconv.Atoi(protoVersionStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("'protocol_version' must be an integer")
|
|
}
|
|
cluster.ProtoVersion = protoVersion
|
|
}
|
|
|
|
if username, ok := conf["username"]; ok {
|
|
if cluster.ProtoVersion < 2 {
|
|
return nil, fmt.Errorf("authentication is not supported with protocol version < 2")
|
|
}
|
|
authenticator := gocql.PasswordAuthenticator{Username: username}
|
|
if password, ok := conf["password"]; ok {
|
|
authenticator.Password = password
|
|
}
|
|
cluster.Authenticator = authenticator
|
|
}
|
|
|
|
if initialConnectionTimeoutStr, ok := conf["initial_connection_timeout"]; ok {
|
|
initialConnectionTimeout, err := strconv.Atoi(initialConnectionTimeoutStr)
|
|
if err != nil || initialConnectionTimeout <= 0 {
|
|
return nil, fmt.Errorf("'initial_connection_timeout' must be a positive integer")
|
|
}
|
|
cluster.ConnectTimeout = time.Duration(initialConnectionTimeout) * time.Second
|
|
}
|
|
|
|
if connTimeoutStr, ok := conf["connection_timeout"]; ok {
|
|
connectionTimeout, err := strconv.Atoi(connTimeoutStr)
|
|
if err != nil || connectionTimeout <= 0 {
|
|
return nil, fmt.Errorf("'connection_timeout' must be a positive integer")
|
|
}
|
|
cluster.Timeout = time.Duration(connectionTimeout) * time.Second
|
|
}
|
|
|
|
if err := setupCassandraTLS(conf, cluster); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sess, err := cluster.CreateSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
metrics.MeasureSince([]string{"cassandra", "connect"}, connectStart)
|
|
sess.SetConsistency(consistency)
|
|
|
|
impl := &CassandraBackend{
|
|
sess: sess,
|
|
table: table,
|
|
logger: logger,
|
|
}
|
|
return impl, nil
|
|
}
|
|
|
|
func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) error {
|
|
tlsOnStr, ok := conf["tls"]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
tlsOn, err := strconv.Atoi(tlsOnStr)
|
|
if err != nil {
|
|
return fmt.Errorf("'tls' must be an integer (0 or 1)")
|
|
}
|
|
|
|
if tlsOn == 0 {
|
|
return nil
|
|
}
|
|
|
|
tlsConfig := &tls.Config{}
|
|
if pemBundlePath, ok := conf["pem_bundle_file"]; ok {
|
|
pemBundleData, err := ioutil.ReadFile(pemBundlePath)
|
|
if err != nil {
|
|
return fmt.Errorf("error reading pem bundle from %q: %w", pemBundlePath, err)
|
|
}
|
|
pemBundle, err := certutil.ParsePEMBundle(string(pemBundleData))
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing 'pem_bundle': %w", err)
|
|
}
|
|
tlsConfig, err = pemBundle.GetTLSConfig(certutil.TLSClient)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if pemJSONPath, ok := conf["pem_json_file"]; ok {
|
|
pemJSONData, err := ioutil.ReadFile(pemJSONPath)
|
|
if err != nil {
|
|
return fmt.Errorf("error reading json bundle from %q: %w", pemJSONPath, err)
|
|
}
|
|
pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONData))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if tlsSkipVerifyStr, ok := conf["tls_skip_verify"]; ok {
|
|
tlsSkipVerify, err := strconv.Atoi(tlsSkipVerifyStr)
|
|
if err != nil {
|
|
return fmt.Errorf("'tls_skip_verify' must be an integer (0 or 1)")
|
|
}
|
|
if tlsSkipVerify == 0 {
|
|
tlsConfig.InsecureSkipVerify = false
|
|
} else {
|
|
tlsConfig.InsecureSkipVerify = true
|
|
}
|
|
}
|
|
|
|
if tlsMinVersion, ok := conf["tls_min_version"]; ok {
|
|
switch tlsMinVersion {
|
|
case "tls10":
|
|
tlsConfig.MinVersion = tls.VersionTLS10
|
|
case "tls11":
|
|
tlsConfig.MinVersion = tls.VersionTLS11
|
|
case "tls12":
|
|
tlsConfig.MinVersion = tls.VersionTLS12
|
|
case "tls13":
|
|
tlsConfig.MinVersion = tls.VersionTLS13
|
|
default:
|
|
return fmt.Errorf("'tls_min_version' must be one of `tls10`, `tls11`, `tls12` or `tls13`")
|
|
}
|
|
}
|
|
|
|
cluster.SslOpts = &gocql.SslOptions{
|
|
Config: tlsConfig,
|
|
EnableHostVerification: !tlsConfig.InsecureSkipVerify,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// bucketName sanitises a bucket name for Cassandra
|
|
func (c *CassandraBackend) bucketName(name string) string {
|
|
if name == "" {
|
|
name = "."
|
|
}
|
|
return strings.TrimRight(name, "/")
|
|
}
|
|
|
|
// bucket returns all the prefix buckets the key should be stored at
|
|
func (c *CassandraBackend) buckets(key string) []string {
|
|
vals := append([]string{""}, physical.Prefixes(key)...)
|
|
for i, v := range vals {
|
|
vals[i] = c.bucketName(v)
|
|
}
|
|
return vals
|
|
}
|
|
|
|
// bucket returns the most specific bucket for the key
|
|
func (c *CassandraBackend) bucket(key string) string {
|
|
bs := c.buckets(key)
|
|
return bs[len(bs)-1]
|
|
}
|
|
|
|
// Put is used to insert or update an entry
|
|
func (c *CassandraBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
|
defer metrics.MeasureSince([]string{"cassandra", "put"}, time.Now())
|
|
|
|
// Execute inserts to each key prefix simultaneously
|
|
stmt := fmt.Sprintf(`INSERT INTO "%s" (bucket, key, value) VALUES (?, ?, ?)`, c.table)
|
|
buckets := c.buckets(entry.Key)
|
|
results := make(chan error, len(buckets))
|
|
for i, _bucket := range buckets {
|
|
go func(i int, bucket string) {
|
|
var value []byte
|
|
if i == len(buckets)-1 {
|
|
// Only store the full value if this is the leaf bucket where the entry will actually be read
|
|
// otherwise this write is just to allow for list operations
|
|
value = entry.Value
|
|
}
|
|
results <- c.sess.Query(stmt, bucket, entry.Key, value).Exec()
|
|
}(i, _bucket)
|
|
}
|
|
for i := 0; i < len(buckets); i++ {
|
|
if err := <-results; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get is used to fetch an entry
|
|
func (c *CassandraBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
|
defer metrics.MeasureSince([]string{"cassandra", "get"}, time.Now())
|
|
|
|
v := []byte(nil)
|
|
stmt := fmt.Sprintf(`SELECT value FROM "%s" WHERE bucket = ? AND key = ? LIMIT 1`, c.table)
|
|
q := c.sess.Query(stmt, c.bucket(key), key)
|
|
if err := q.Scan(&v); err != nil {
|
|
if err == gocql.ErrNotFound {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return &physical.Entry{
|
|
Key: key,
|
|
Value: v,
|
|
}, nil
|
|
}
|
|
|
|
// Delete is used to permanently delete an entry
|
|
func (c *CassandraBackend) Delete(ctx context.Context, key string) error {
|
|
defer metrics.MeasureSince([]string{"cassandra", "delete"}, time.Now())
|
|
|
|
stmt := fmt.Sprintf(`DELETE FROM "%s" WHERE bucket = ? AND key = ?`, c.table)
|
|
buckets := c.buckets(key)
|
|
results := make(chan error, len(buckets))
|
|
|
|
for _, bucket := range buckets {
|
|
go func(bucket string) {
|
|
results <- c.sess.Query(stmt, bucket, key).Exec()
|
|
}(bucket)
|
|
}
|
|
|
|
for i := 0; i < len(buckets); i++ {
|
|
if err := <-results; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// List is used ot list all the keys under a given
|
|
// prefix, up to the next prefix.
|
|
func (c *CassandraBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
|
defer metrics.MeasureSince([]string{"cassandra", "list"}, time.Now())
|
|
|
|
stmt := fmt.Sprintf(`SELECT key FROM "%s" WHERE bucket = ?`, c.table)
|
|
q := c.sess.Query(stmt, c.bucketName(prefix))
|
|
iter := q.Iter()
|
|
k, keys := "", []string{}
|
|
for iter.Scan(&k) {
|
|
// Only return the next "component" (with a trailing slash if it has children)
|
|
k = strings.TrimPrefix(k, prefix)
|
|
if parts := strings.SplitN(k, "/", 2); len(parts) > 1 {
|
|
k = parts[0] + "/"
|
|
} else {
|
|
k = parts[0]
|
|
}
|
|
|
|
// Deduplicate; this works because the keys are sorted
|
|
if len(keys) > 0 && keys[len(keys)-1] == k {
|
|
continue
|
|
}
|
|
keys = append(keys, k)
|
|
}
|
|
return keys, iter.Close()
|
|
}
|