b747b27afd
registerSchema creates some indirection which is not necessary in this case. newDBSchema can call each of the tables. Enterprise tables can be added from the existing withEnterpriseSchema shim.
410 lines
11 KiB
Go
410 lines
11 KiB
Go
package state
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-memdb"
|
|
|
|
"github.com/hashicorp/consul/agent/structs"
|
|
)
|
|
|
|
// sessionsTableSchema returns a new table schema used for storing session
|
|
// information.
|
|
func sessionsTableSchema() *memdb.TableSchema {
|
|
return &memdb.TableSchema{
|
|
Name: "sessions",
|
|
Indexes: map[string]*memdb.IndexSchema{
|
|
"id": {
|
|
Name: "id",
|
|
AllowMissing: false,
|
|
Unique: true,
|
|
Indexer: sessionIndexer(),
|
|
},
|
|
"node": {
|
|
Name: "node",
|
|
AllowMissing: false,
|
|
Unique: false,
|
|
Indexer: nodeSessionsIndexer(),
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// sessionChecksTableSchema returns a new table schema used for storing session
|
|
// checks.
|
|
func sessionChecksTableSchema() *memdb.TableSchema {
|
|
return &memdb.TableSchema{
|
|
Name: "session_checks",
|
|
Indexes: map[string]*memdb.IndexSchema{
|
|
"id": {
|
|
Name: "id",
|
|
AllowMissing: false,
|
|
Unique: true,
|
|
Indexer: &memdb.CompoundIndex{
|
|
Indexes: []memdb.Indexer{
|
|
&memdb.StringFieldIndex{
|
|
Field: "Node",
|
|
Lowercase: true,
|
|
},
|
|
&CheckIDIndex{},
|
|
&memdb.UUIDFieldIndex{
|
|
Field: "Session",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
"node_check": {
|
|
Name: "node_check",
|
|
AllowMissing: false,
|
|
Unique: false,
|
|
Indexer: nodeChecksIndexer(),
|
|
},
|
|
"session": {
|
|
Name: "session",
|
|
AllowMissing: false,
|
|
Unique: false,
|
|
Indexer: &memdb.UUIDFieldIndex{
|
|
Field: "Session",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
type CheckIDIndex struct {
|
|
}
|
|
|
|
func (index *CheckIDIndex) FromObject(obj interface{}) (bool, []byte, error) {
|
|
v := reflect.ValueOf(obj)
|
|
v = reflect.Indirect(v) // Dereference the pointer if any
|
|
|
|
fv := v.FieldByName("CheckID")
|
|
isPtr := fv.Kind() == reflect.Ptr
|
|
fv = reflect.Indirect(fv)
|
|
if !isPtr && !fv.IsValid() || !fv.CanInterface() {
|
|
return false, nil,
|
|
fmt.Errorf("field 'EnterpriseMeta' for %#v is invalid %v ", obj, isPtr)
|
|
}
|
|
|
|
checkID, ok := fv.Interface().(structs.CheckID)
|
|
if !ok {
|
|
return false, nil, fmt.Errorf("Field 'EnterpriseMeta' is not of type structs.EnterpriseMeta")
|
|
}
|
|
|
|
// Enforce lowercase and add null character as terminator
|
|
id := strings.ToLower(string(checkID.ID)) + "\x00"
|
|
|
|
return true, []byte(id), nil
|
|
}
|
|
|
|
func (index *CheckIDIndex) FromArgs(args ...interface{}) ([]byte, error) {
|
|
if len(args) != 1 {
|
|
return nil, fmt.Errorf("must provide only a single argument")
|
|
}
|
|
arg, ok := args[0].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
|
|
}
|
|
|
|
arg = strings.ToLower(arg)
|
|
|
|
// Add the null character as a terminator
|
|
arg += "\x00"
|
|
return []byte(arg), nil
|
|
}
|
|
|
|
func (index *CheckIDIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
|
|
val, err := index.FromArgs(args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Strip the null terminator, the rest is a prefix
|
|
n := len(val)
|
|
if n > 0 {
|
|
return val[:n-1], nil
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
// Sessions is used to pull the full list of sessions for use during snapshots.
|
|
func (s *Snapshot) Sessions() (memdb.ResultIterator, error) {
|
|
iter, err := s.tx.Get("sessions", "id")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return iter, nil
|
|
}
|
|
|
|
// Session is used when restoring from a snapshot. For general inserts, use
|
|
// SessionCreate.
|
|
func (s *Restore) Session(sess *structs.Session) error {
|
|
if err := insertSessionTxn(s.tx, sess, sess.ModifyIndex, true, true); err != nil {
|
|
return fmt.Errorf("failed inserting session: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SessionCreate is used to register a new session in the state store.
|
|
func (s *Store) SessionCreate(idx uint64, sess *structs.Session) error {
|
|
tx := s.db.WriteTxn(idx)
|
|
defer tx.Abort()
|
|
|
|
// This code is technically able to (incorrectly) update an existing
|
|
// session but we never do that in practice. The upstream endpoint code
|
|
// always adds a unique ID when doing a create operation so we never hit
|
|
// an existing session again. It isn't worth the overhead to verify
|
|
// that here, but it's worth noting that we should never do this in the
|
|
// future.
|
|
|
|
// Call the session creation
|
|
if err := sessionCreateTxn(tx, idx, sess); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// sessionCreateTxn is the inner method used for creating session entries in
|
|
// an open transaction. Any health checks registered with the session will be
|
|
// checked for failing status. Returns any error encountered.
|
|
func sessionCreateTxn(tx *txn, idx uint64, sess *structs.Session) error {
|
|
// Check that we have a session ID
|
|
if sess.ID == "" {
|
|
return ErrMissingSessionID
|
|
}
|
|
|
|
// Verify the session behavior is valid
|
|
switch sess.Behavior {
|
|
case "":
|
|
// Release by default to preserve backwards compatibility
|
|
sess.Behavior = structs.SessionKeysRelease
|
|
case structs.SessionKeysRelease:
|
|
case structs.SessionKeysDelete:
|
|
default:
|
|
return fmt.Errorf("Invalid session behavior: %s", sess.Behavior)
|
|
}
|
|
|
|
// Assign the indexes. ModifyIndex likely will not be used but
|
|
// we set it here anyways for sanity.
|
|
sess.CreateIndex = idx
|
|
sess.ModifyIndex = idx
|
|
|
|
// Check that the node exists
|
|
node, err := tx.First("nodes", "id", sess.Node)
|
|
if err != nil {
|
|
return fmt.Errorf("failed node lookup: %s", err)
|
|
}
|
|
if node == nil {
|
|
return ErrMissingNode
|
|
}
|
|
|
|
// Verify that all session checks exist
|
|
if err := validateSessionChecksTxn(tx, sess); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Insert the session
|
|
if err := insertSessionTxn(tx, sess, idx, false, false); err != nil {
|
|
return fmt.Errorf("failed inserting session: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SessionGet is used to retrieve an active session from the state store.
|
|
func (s *Store) SessionGet(ws memdb.WatchSet,
|
|
sessionID string, entMeta *structs.EnterpriseMeta) (uint64, *structs.Session, error) {
|
|
|
|
tx := s.db.Txn(false)
|
|
defer tx.Abort()
|
|
|
|
// Get the table index.
|
|
idx := sessionMaxIndex(tx, entMeta)
|
|
|
|
// Look up the session by its ID
|
|
watchCh, session, err := firstWatchWithTxn(tx, "sessions", "id", sessionID, entMeta)
|
|
if err != nil {
|
|
return 0, nil, fmt.Errorf("failed session lookup: %s", err)
|
|
}
|
|
ws.Add(watchCh)
|
|
|
|
if session != nil {
|
|
return idx, session.(*structs.Session), nil
|
|
}
|
|
return idx, nil, nil
|
|
}
|
|
|
|
// SessionList returns a slice containing all of the active sessions.
|
|
func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) {
|
|
tx := s.db.Txn(false)
|
|
defer tx.Abort()
|
|
|
|
// Get the table index.
|
|
idx := sessionMaxIndex(tx, entMeta)
|
|
|
|
// Query all of the active sessions.
|
|
sessions, err := getWithTxn(tx, "sessions", "id_prefix", "", entMeta)
|
|
if err != nil {
|
|
return 0, nil, fmt.Errorf("failed session lookup: %s", err)
|
|
}
|
|
ws.Add(sessions.WatchCh())
|
|
|
|
// Go over the sessions and create a slice of them.
|
|
var result structs.Sessions
|
|
for session := sessions.Next(); session != nil; session = sessions.Next() {
|
|
result = append(result, session.(*structs.Session))
|
|
}
|
|
return idx, result, nil
|
|
}
|
|
|
|
// NodeSessions returns a set of active sessions associated
|
|
// with the given node ID. The returned index is the highest
|
|
// index seen from the result set.
|
|
func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) {
|
|
tx := s.db.Txn(false)
|
|
defer tx.Abort()
|
|
|
|
// Get the table index.
|
|
idx := sessionMaxIndex(tx, entMeta)
|
|
|
|
// Get all of the sessions which belong to the node
|
|
result, err := nodeSessionsTxn(tx, ws, nodeID, entMeta)
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
return idx, result, nil
|
|
}
|
|
|
|
// SessionDestroy is used to remove an active session. This will
|
|
// implicitly invalidate the session and invoke the specified
|
|
// session destroy behavior.
|
|
func (s *Store) SessionDestroy(idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error {
|
|
tx := s.db.WriteTxn(idx)
|
|
defer tx.Abort()
|
|
|
|
// Call the session deletion.
|
|
if err := s.deleteSessionTxn(tx, idx, sessionID, entMeta); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// deleteSessionTxn is the inner method, which is used to do the actual
|
|
// session deletion and handle session invalidation, etc.
|
|
func (s *Store) deleteSessionTxn(tx WriteTxn, idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error {
|
|
// Look up the session.
|
|
sess, err := firstWithTxn(tx, "sessions", "id", sessionID, entMeta)
|
|
if err != nil {
|
|
return fmt.Errorf("failed session lookup: %s", err)
|
|
}
|
|
if sess == nil {
|
|
return nil
|
|
}
|
|
|
|
// Delete the session and write the new index.
|
|
session := sess.(*structs.Session)
|
|
if err := sessionDeleteWithSession(tx, session, idx); err != nil {
|
|
return fmt.Errorf("failed deleting session: %v", err)
|
|
}
|
|
|
|
// Enforce the max lock delay.
|
|
delay := session.LockDelay
|
|
if delay > structs.MaxLockDelay {
|
|
delay = structs.MaxLockDelay
|
|
}
|
|
|
|
// Snag the current now time so that all the expirations get calculated
|
|
// the same way.
|
|
now := time.Now()
|
|
|
|
// Get an iterator over all of the keys with the given session.
|
|
entries, err := tx.Get("kvs", "session", sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed kvs lookup: %s", err)
|
|
}
|
|
var kvs []interface{}
|
|
for entry := entries.Next(); entry != nil; entry = entries.Next() {
|
|
kvs = append(kvs, entry)
|
|
}
|
|
|
|
// Invalidate any held locks.
|
|
switch session.Behavior {
|
|
case structs.SessionKeysRelease:
|
|
for _, obj := range kvs {
|
|
// Note that we clone here since we are modifying the
|
|
// returned object and want to make sure our set op
|
|
// respects the transaction we are in.
|
|
e := obj.(*structs.DirEntry).Clone()
|
|
e.Session = ""
|
|
if err := kvsSetTxn(tx, idx, e, true); err != nil {
|
|
return fmt.Errorf("failed kvs update: %s", err)
|
|
}
|
|
|
|
// Apply the lock delay if present.
|
|
if delay > 0 {
|
|
s.lockDelay.SetExpiration(e.Key, now, delay, entMeta)
|
|
}
|
|
}
|
|
case structs.SessionKeysDelete:
|
|
for _, obj := range kvs {
|
|
e := obj.(*structs.DirEntry)
|
|
if err := s.kvsDeleteTxn(tx, idx, e.Key, entMeta); err != nil {
|
|
return fmt.Errorf("failed kvs delete: %s", err)
|
|
}
|
|
|
|
// Apply the lock delay if present.
|
|
if delay > 0 {
|
|
s.lockDelay.SetExpiration(e.Key, now, delay, entMeta)
|
|
}
|
|
}
|
|
default:
|
|
return fmt.Errorf("unknown session behavior %#v", session.Behavior)
|
|
}
|
|
|
|
// Delete any check mappings.
|
|
mappings, err := tx.Get("session_checks", "session", sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed session checks lookup: %s", err)
|
|
}
|
|
{
|
|
var objs []interface{}
|
|
for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
|
|
objs = append(objs, mapping)
|
|
}
|
|
|
|
// Do the delete in a separate loop so we don't trash the iterator.
|
|
for _, obj := range objs {
|
|
if err := tx.Delete("session_checks", obj); err != nil {
|
|
return fmt.Errorf("failed deleting session check: %s", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Delete any prepared queries.
|
|
queries, err := tx.Get("prepared-queries", "session", sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed prepared query lookup: %s", err)
|
|
}
|
|
{
|
|
var ids []string
|
|
for wrapped := queries.Next(); wrapped != nil; wrapped = queries.Next() {
|
|
ids = append(ids, toPreparedQuery(wrapped).ID)
|
|
}
|
|
|
|
// Do the delete in a separate loop so we don't trash the iterator.
|
|
for _, id := range ids {
|
|
if err := preparedQueryDeleteTxn(tx, idx, id); err != nil {
|
|
return fmt.Errorf("failed prepared query delete: %s", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|