27bb03bbc0
* adding copyright header * fix fmt and a test
264 lines
6.8 KiB
Go
264 lines
6.8 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package cassandra
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/hashicorp/vault/sdk/helper/template"
|
|
|
|
"github.com/gocql/gocql"
|
|
multierror "github.com/hashicorp/go-multierror"
|
|
"github.com/hashicorp/go-secure-stdlib/strutil"
|
|
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
|
)
|
|
|
|
const (
	// cassandraTypeName is the database type name reported by Type and
	// stamped onto the connection producer.
	cassandraTypeName = "cassandra"

	// defaultUserNameTemplate is the username template used when the
	// configuration does not supply a "username_template" value.
	defaultUserNameTemplate = `{{ printf "v_%s_%s_%s_%s" (.DisplayName | truncate 15) (.RoleName | truncate 15) (random 20) (unix_time) | truncate 100 | replace "-" "_" | lowercase }}`

	// Fallback CQL statements, used whenever a request carries no commands
	// of its own. The {{username}} and {{password}} placeholders are
	// substituted before execution.
	defaultUserCreationCQL   = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
	defaultUserDeletionCQL   = `DROP USER '{{username}}';`
	defaultChangePasswordCQL = `ALTER USER '{{username}}' WITH PASSWORD '{{password}}';`
)
|
|
|
|
// Compile-time assertion that *Cassandra satisfies dbplugin.Database.
var _ dbplugin.Database = (*Cassandra)(nil)
|
|
|
|
// Cassandra is an implementation of Database interface
|
|
type Cassandra struct {
|
|
*cassandraConnectionProducer
|
|
|
|
usernameProducer template.StringTemplate
|
|
}
|
|
|
|
// New returns a new Cassandra instance
|
|
func New() (interface{}, error) {
|
|
db := new()
|
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
|
|
|
|
return dbType, nil
|
|
}
|
|
|
|
func new() *Cassandra {
|
|
connProducer := &cassandraConnectionProducer{}
|
|
connProducer.Type = cassandraTypeName
|
|
|
|
return &Cassandra{
|
|
cassandraConnectionProducer: connProducer,
|
|
}
|
|
}
|
|
|
|
// Type returns the TypeName for this backend
|
|
func (c *Cassandra) Type() (string, error) {
|
|
return cassandraTypeName, nil
|
|
}
|
|
|
|
func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
|
|
session, err := c.Connection(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session.(*gocql.Session), nil
|
|
}
|
|
|
|
func (c *Cassandra) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
|
|
usernameTemplate, err := strutil.GetString(req.Config, "username_template")
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve username_template: %w", err)
|
|
}
|
|
if usernameTemplate == "" {
|
|
usernameTemplate = defaultUserNameTemplate
|
|
}
|
|
|
|
up, err := template.NewTemplate(template.Template(usernameTemplate))
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("unable to initialize username template: %w", err)
|
|
}
|
|
c.usernameProducer = up
|
|
|
|
_, err = c.usernameProducer.Generate(dbplugin.UsernameMetadata{})
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err)
|
|
}
|
|
|
|
err = c.cassandraConnectionProducer.Initialize(ctx, req)
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to initialize: %w", err)
|
|
}
|
|
|
|
resp := dbplugin.InitializeResponse{
|
|
Config: req.Config,
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// NewUser generates the username/password on the underlying Cassandra secret backend as instructed by
|
|
// the statements provided.
|
|
func (c *Cassandra) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
session, err := c.getConnection(ctx)
|
|
if err != nil {
|
|
return dbplugin.NewUserResponse{}, err
|
|
}
|
|
|
|
creationCQL := req.Statements.Commands
|
|
if len(creationCQL) == 0 {
|
|
creationCQL = []string{defaultUserCreationCQL}
|
|
}
|
|
|
|
rollbackCQL := req.RollbackStatements.Commands
|
|
if len(rollbackCQL) == 0 {
|
|
rollbackCQL = []string{defaultUserDeletionCQL}
|
|
}
|
|
|
|
username, err := c.usernameProducer.Generate(req.UsernameConfig)
|
|
if err != nil {
|
|
return dbplugin.NewUserResponse{}, err
|
|
}
|
|
|
|
for _, stmt := range creationCQL {
|
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
|
query = strings.TrimSpace(query)
|
|
if len(query) == 0 {
|
|
continue
|
|
}
|
|
|
|
m := map[string]string{
|
|
"username": username,
|
|
"password": req.Password,
|
|
}
|
|
err = session.
|
|
Query(dbutil.QueryHelper(query, m)).
|
|
WithContext(ctx).
|
|
Exec()
|
|
if err != nil {
|
|
rollbackErr := rollbackUser(ctx, session, username, rollbackCQL)
|
|
if rollbackErr != nil {
|
|
err = multierror.Append(err, rollbackErr)
|
|
}
|
|
return dbplugin.NewUserResponse{}, err
|
|
}
|
|
}
|
|
}
|
|
|
|
resp := dbplugin.NewUserResponse{
|
|
Username: username,
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func rollbackUser(ctx context.Context, session *gocql.Session, username string, rollbackCQL []string) error {
|
|
for _, stmt := range rollbackCQL {
|
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
|
query = strings.TrimSpace(query)
|
|
if len(query) == 0 {
|
|
continue
|
|
}
|
|
|
|
m := map[string]string{
|
|
"username": username,
|
|
}
|
|
err := session.
|
|
Query(dbutil.QueryHelper(query, m)).
|
|
WithContext(ctx).
|
|
Exec()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to roll back user %s: %w", username, err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Cassandra) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
|
|
if req.Password == nil && req.Expiration == nil {
|
|
return dbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
|
|
}
|
|
|
|
if req.Password != nil {
|
|
err := c.changeUserPassword(ctx, req.Username, req.Password)
|
|
return dbplugin.UpdateUserResponse{}, err
|
|
}
|
|
// Expiration is no-op
|
|
return dbplugin.UpdateUserResponse{}, nil
|
|
}
|
|
|
|
func (c *Cassandra) changeUserPassword(ctx context.Context, username string, changePass *dbplugin.ChangePassword) error {
|
|
session, err := c.getConnection(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rotateCQL := changePass.Statements.Commands
|
|
if len(rotateCQL) == 0 {
|
|
rotateCQL = []string{defaultChangePasswordCQL}
|
|
}
|
|
|
|
var result *multierror.Error
|
|
for _, stmt := range rotateCQL {
|
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
|
query = strings.TrimSpace(query)
|
|
if len(query) == 0 {
|
|
continue
|
|
}
|
|
|
|
m := map[string]string{
|
|
"username": username,
|
|
"password": changePass.NewPassword,
|
|
}
|
|
err := session.
|
|
Query(dbutil.QueryHelper(query, m)).
|
|
WithContext(ctx).
|
|
Exec()
|
|
result = multierror.Append(result, err)
|
|
}
|
|
}
|
|
|
|
return result.ErrorOrNil()
|
|
}
|
|
|
|
// DeleteUser attempts to drop the specified user.
|
|
func (c *Cassandra) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
session, err := c.getConnection(ctx)
|
|
if err != nil {
|
|
return dbplugin.DeleteUserResponse{}, err
|
|
}
|
|
|
|
revocationCQL := req.Statements.Commands
|
|
if len(revocationCQL) == 0 {
|
|
revocationCQL = []string{defaultUserDeletionCQL}
|
|
}
|
|
|
|
var result *multierror.Error
|
|
for _, stmt := range revocationCQL {
|
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
|
query = strings.TrimSpace(query)
|
|
if len(query) == 0 {
|
|
continue
|
|
}
|
|
|
|
m := map[string]string{
|
|
"username": req.Username,
|
|
}
|
|
err := session.
|
|
Query(dbutil.QueryHelper(query, m)).
|
|
WithContext(ctx).
|
|
Exec()
|
|
|
|
result = multierror.Append(result, err)
|
|
}
|
|
}
|
|
|
|
return dbplugin.DeleteUserResponse{}, result.ErrorOrNil()
|
|
}
|