open-vault/physical/cassandra.go

297 lines
7.5 KiB
Go

package physical
import (
"crypto/tls"
"fmt"
"io/ioutil"
"strconv"
"strings"
"time"
log "github.com/mgutz/logxi/v1"
"github.com/armon/go-metrics"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
)
// CassandraBackend is a physical backend that stores data in Cassandra.
// Every entry is persisted as a (bucket, key, value) row of table; Put and
// List describe how the bucket column is used.
type CassandraBackend struct {
	// sess is the gocql session opened by newCassandraBackend.
	sess *gocql.Session

	// table is the name of the pre-existing table that holds entries.
	table string

	// logger is the logger supplied at construction time.
	logger log.Logger
}
// newCassandraBackend constructs a Cassandra backend using a pre-existing
// keyspace and table.
//
// Recognised conf keys (all optional):
//   - "hosts": comma-separated contact points, default "localhost". Whitespace
//     around each entry is ignored and empty entries are dropped, so
//     "a, b" and "a,b" are equivalent.
//   - "keyspace": keyspace name, default "vault".
//   - "table": table name, default "entries".
//   - "consistency": one of ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_QUORUM,
//     EACH_QUORUM, LOCAL_ONE; default LOCAL_QUORUM.
//   - "protocol_version": integer CQL protocol version, default 2.
//   - "username" / "password": enable password authentication; requires
//     protocol_version >= 2.
//   - "connection_timeout": connection timeout as an integer number of seconds.
//   - TLS-related keys, which are handled by setupCassandraTLS.
//
// The time spent creating the session is recorded under the
// {"cassandra", "connect"} metric.
func newCassandraBackend(conf map[string]string, logger log.Logger) (Backend, error) {
	// splitArray splits a comma-separated list, trimming whitespace around
	// each element and discarding empty elements. Without trimming, an entry
	// such as " host2" would be handed to gocql verbatim as a host name.
	splitArray := func(v string) []string {
		var out []string
		for _, item := range strings.Split(v, ",") {
			if item = strings.TrimSpace(item); item != "" {
				out = append(out, item)
			}
		}
		return out
	}

	var (
		hosts       = splitArray(conf["hosts"])
		keyspace    = conf["keyspace"]
		table       = conf["table"]
		consistency = gocql.LocalQuorum
	)

	if len(hosts) == 0 {
		hosts = []string{"localhost"}
	}
	if keyspace == "" {
		keyspace = "vault"
	}
	if table == "" {
		table = "entries"
	}

	if cs, ok := conf["consistency"]; ok {
		switch cs {
		case "ANY":
			consistency = gocql.Any
		case "ONE":
			consistency = gocql.One
		case "TWO":
			consistency = gocql.Two
		case "THREE":
			consistency = gocql.Three
		case "QUORUM":
			consistency = gocql.Quorum
		case "ALL":
			consistency = gocql.All
		case "LOCAL_QUORUM":
			consistency = gocql.LocalQuorum
		case "EACH_QUORUM":
			consistency = gocql.EachQuorum
		case "LOCAL_ONE":
			consistency = gocql.LocalOne
		default:
			return nil, fmt.Errorf("'consistency' must be one of {ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_QUORUM, EACH_QUORUM, LOCAL_ONE}")
		}
	}

	connectStart := time.Now()
	cluster := gocql.NewCluster(hosts...)
	cluster.Keyspace = keyspace

	cluster.ProtoVersion = 2
	if protoVersionStr, ok := conf["protocol_version"]; ok {
		protoVersion, err := strconv.Atoi(protoVersionStr)
		if err != nil {
			return nil, fmt.Errorf("'protocol_version' must be an integer")
		}
		cluster.ProtoVersion = protoVersion
	}

	if username, ok := conf["username"]; ok {
		// Checked after protocol_version is parsed so an explicit override
		// is taken into account.
		if cluster.ProtoVersion < 2 {
			return nil, fmt.Errorf("Authentication is not supported with protocol version < 2")
		}
		authenticator := gocql.PasswordAuthenticator{Username: username}
		if password, ok := conf["password"]; ok {
			authenticator.Password = password
		}
		cluster.Authenticator = authenticator
	}

	if connTimeoutStr, ok := conf["connection_timeout"]; ok {
		connectionTimeout, err := strconv.Atoi(connTimeoutStr)
		if err != nil {
			return nil, fmt.Errorf("'connection_timeout' must be an integer")
		}
		cluster.Timeout = time.Duration(connectionTimeout) * time.Second
	}

	if err := setupCassandraTLS(conf, cluster); err != nil {
		return nil, err
	}

	sess, err := cluster.CreateSession()
	if err != nil {
		return nil, err
	}
	metrics.MeasureSince([]string{"cassandra", "connect"}, connectStart)
	sess.SetConsistency(consistency)

	impl := &CassandraBackend{
		sess:   sess,
		table:  table,
		logger: logger,
	}
	return impl, nil
}
// setupCassandraTLS configures TLS on cluster according to conf. It does
// nothing unless conf["tls"] is present and parses to a non-zero integer.
//
// Recognised conf keys:
//   - "tls": integer; TLS is enabled when non-zero.
//   - "pem_bundle": path to a PEM bundle file used to build the client TLS
//     configuration.
//   - "pem_json": inline PKI JSON, consulted only when "pem_bundle" is absent.
//   - "tls_skip_verify": integer; non-zero disables server certificate
//     verification.
//   - "tls_min_version": one of "tls10", "tls11", "tls12".
//
// It returns an error if any of these values is malformed or if the PEM
// material cannot be read or parsed.
func setupCassandraTLS(conf map[string]string, cluster *gocql.ClusterConfig) error {
	tlsOnStr, ok := conf["tls"]
	if !ok {
		return nil
	}
	tlsOn, err := strconv.Atoi(tlsOnStr)
	if err != nil {
		return fmt.Errorf("'tls' must be an integer (0 or 1)")
	}
	if tlsOn == 0 {
		return nil
	}

	tlsConfig := &tls.Config{}
	if pemBundlePath, ok := conf["pem_bundle"]; ok {
		pemBundleData, err := ioutil.ReadFile(pemBundlePath)
		if err != nil {
			return fmt.Errorf("Error reading pem bundle from %s: %v", pemBundlePath, err)
		}
		pemBundle, err := certutil.ParsePEMBundle(string(pemBundleData))
		if err != nil {
			return fmt.Errorf("Error parsing 'pem_bundle': %v", err)
		}
		tlsConfig, err = pemBundle.GetTLSConfig(certutil.TLSClient)
		if err != nil {
			return err
		}
	} else if pemJSONStr, ok := conf["pem_json"]; ok {
		pemJSON, err := certutil.ParsePKIJSON([]byte(pemJSONStr))
		if err != nil {
			return err
		}
		tlsConfig, err = pemJSON.GetTLSConfig(certutil.TLSClient)
		if err != nil {
			return err
		}
	}

	if tlsSkipVerifyStr, ok := conf["tls_skip_verify"]; ok {
		tlsSkipVerify, err := strconv.Atoi(tlsSkipVerifyStr)
		if err != nil {
			return fmt.Errorf("'tls_skip_verify' must be an integer (0 or 1)")
		}
		tlsConfig.InsecureSkipVerify = tlsSkipVerify != 0
	}

	if tlsMinVersion, ok := conf["tls_min_version"]; ok {
		switch tlsMinVersion {
		case "tls10":
			tlsConfig.MinVersion = tls.VersionTLS10
		case "tls11":
			tlsConfig.MinVersion = tls.VersionTLS11
		case "tls12":
			tlsConfig.MinVersion = tls.VersionTLS12
		default:
			return fmt.Errorf("'tls_min_version' must be one of `tls10`, `tls11` or `tls12`")
		}
	}

	cluster.SslOpts = &gocql.SslOptions{
		Config: *tlsConfig.Clone(),
		// gocql decides whether to verify the server certificate from
		// EnableHostVerification. Leaving it at its zero value (false) makes
		// gocql skip verification regardless of InsecureSkipVerify, which
		// silently ignored tls_skip_verify=0. Derive it from the configured
		// InsecureSkipVerify so verification is on unless explicitly skipped.
		// NOTE(review): confirm against the vendored gocql version.
		EnableHostVerification: !tlsConfig.InsecureSkipVerify,
	}
	return nil
}
// bucketName converts a key prefix into the value stored in the bucket
// column. Trailing slashes are removed. A prefix that is empty after
// trimming (including "" and names made only of slashes, such as "/") maps
// to ".", the root bucket, so the empty string is never used as a bucket.
// Previously the emptiness check ran before trimming, so "/" produced "".
// NOTE(review): Cassandra rejects empty single-column partition keys, so
// those inputs previously failed at query time.
func (c *CassandraBackend) bucketName(name string) string {
	name = strings.TrimRight(name, "/")
	if name == "" {
		return "."
	}
	return name
}
// buckets returns every bucket that the entry for key is stored in. The
// first element is always the root bucket, followed by one bucket per
// parent prefix of key, in the order prefixes returns them. bucket treats
// the last element as the most specific.
func (c *CassandraBackend) buckets(key string) []string {
	parents := prefixes(key)
	result := make([]string, 0, len(parents)+1)
	result = append(result, c.bucketName(""))
	for _, parent := range parents {
		result = append(result, c.bucketName(parent))
	}
	return result
}
// bucket returns the most specific bucket for key: the last element of
// buckets(key). That slice always contains at least the root bucket, so
// the indexing below cannot go out of range.
func (c *CassandraBackend) bucket(key string) string {
	candidates := c.buckets(key)
	mostSpecific := candidates[len(candidates)-1]
	return mostSpecific
}
// Put inserts or overwrites entry. One INSERT per bucket in
// buckets(entry.Key) is sent as a single logged batch, which lets List find
// the key from the root bucket and from each of its parent prefixes.
func (c *CassandraBackend) Put(entry *Entry) error {
	defer metrics.MeasureSince([]string{"cassandra", "put"}, time.Now())

	insert := fmt.Sprintf(`INSERT INTO "%s" (bucket, key, value) VALUES (?, ?, ?)`, c.table)
	targets := c.buckets(entry.Key)

	batch := gocql.NewBatch(gocql.LoggedBatch)
	for i := range targets {
		batch.Entries = append(batch.Entries, gocql.BatchEntry{
			Stmt: insert,
			Args: []interface{}{targets[i], entry.Key, entry.Value},
		})
	}
	return c.sess.ExecuteBatch(batch)
}
// Get fetches the entry stored under key. It reads from the most specific
// bucket, bucket(key). A missing row (gocql.ErrNotFound) is reported as
// (nil, nil); any other query error is returned unchanged.
func (c *CassandraBackend) Get(key string) (*Entry, error) {
	defer metrics.MeasureSince([]string{"cassandra", "get"}, time.Now())

	query := fmt.Sprintf(`SELECT value FROM "%s" WHERE bucket = ? AND key = ? LIMIT 1`, c.table)

	var value []byte
	err := c.sess.Query(query, c.bucket(key), key).Scan(&value)
	switch {
	case err == gocql.ErrNotFound:
		return nil, nil
	case err != nil:
		return nil, err
	}

	return &Entry{Key: key, Value: value}, nil
}
// Delete permanently removes the entry for key. It sends one DELETE per
// bucket in buckets(key) as a single logged batch, mirroring the rows that
// Put writes.
func (c *CassandraBackend) Delete(key string) error {
	defer metrics.MeasureSince([]string{"cassandra", "delete"}, time.Now())

	remove := fmt.Sprintf(`DELETE FROM "%s" WHERE bucket = ? AND key = ?`, c.table)
	targets := c.buckets(key)

	batch := gocql.NewBatch(gocql.LoggedBatch)
	for i := range targets {
		batch.Entries = append(batch.Entries, gocql.BatchEntry{
			Stmt: remove,
			Args: []interface{}{targets[i], key},
		})
	}
	return c.sess.ExecuteBatch(batch)
}
// List is used to list all the keys under a given prefix, up to the next
// prefix. For every key in the prefix's bucket, only the next path component
// after prefix is returned. A component that has children is returned once,
// with a trailing "/".
//
// If the query or iteration fails, List returns a nil slice and the error
// rather than a partial listing.
func (c *CassandraBackend) List(prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"cassandra", "list"}, time.Now())

	stmt := fmt.Sprintf(`SELECT key FROM "%s" WHERE bucket = ?`, c.table)
	iter := c.sess.Query(stmt, c.bucketName(prefix)).Iter()

	k, keys := "", []string{}
	for iter.Scan(&k) {
		// Keep only the next component, including its trailing slash if the
		// key continues below it.
		k = strings.TrimPrefix(k, prefix)
		if i := strings.Index(k, "/"); i >= 0 {
			k = k[:i+1]
		}
		// Deduplicate against the previous element only. This relies on
		// rows arriving sorted by key within the bucket, which makes keys
		// that share a component contiguous.
		// NOTE(review): assumes key is a clustering column of the table.
		if n := len(keys); n > 0 && keys[n-1] == k {
			continue
		}
		keys = append(keys, k)
	}
	if err := iter.Close(); err != nil {
		return nil, err
	}
	return keys, nil
}