open-consul/agent/consul/state/session.go

453 lines
12 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package state
import (
"fmt"
"reflect"
"strings"
"time"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
)
const (
	// tableSessions is the memdb table that stores sessions.
	tableSessions = "sessions"

	// tableSessionChecks is the memdb table that stores session-to-check mappings.
	tableSessionChecks = "session_checks"

	// indexNodeCheck names the node/check index on tableSessionChecks.
	indexNodeCheck = "node_check"
)
// indexFromSession builds the ID index key for a session from its lower-cased
// ID. A session with an empty ID yields errMissingValueForIndex.
func indexFromSession(e *structs.Session) ([]byte, error) {
	id := strings.ToLower(e.ID)
	if len(id) == 0 {
		return nil, errMissingValueForIndex
	}

	var key indexBuilder
	key.String(id)
	return key.Bytes(), nil
}
// sessionsTableSchema describes the memdb table that stores sessions. Rows are
// indexed uniquely by session ID and non-uniquely by node.
func sessionsTableSchema() *memdb.TableSchema {
	indexes := map[string]*memdb.IndexSchema{
		indexID: {
			Name:         indexID,
			AllowMissing: false,
			Unique:       true,
			Indexer:      sessionIndexer(),
		},
		indexNode: {
			Name:         indexNode,
			AllowMissing: false,
			Unique:       false,
			Indexer:      nodeSessionsIndexer(),
		},
	}

	return &memdb.TableSchema{Name: tableSessions, Indexes: indexes}
}
// sessionChecksTableSchema describes the memdb table that maps sessions to the
// health checks they are bound to. Rows are indexed uniquely by ID, and
// non-uniquely by node/check and by session.
func sessionChecksTableSchema() *memdb.TableSchema {
	indexes := map[string]*memdb.IndexSchema{
		indexID: {
			Name:         indexID,
			AllowMissing: false,
			Unique:       true,
			Indexer:      idCheckIndexer(),
		},
		indexNodeCheck: {
			Name:         indexNodeCheck,
			AllowMissing: false,
			Unique:       false,
			Indexer:      nodeChecksIndexer(),
		},
		indexSession: {
			Name:         indexSession,
			AllowMissing: false,
			Unique:       false,
			Indexer:      sessionCheckIndexer(),
		},
	}

	return &memdb.TableSchema{Name: tableSessionChecks, Indexes: indexes}
}
// indexNodeFromSession builds the node index key for a session from its
// lower-cased node name. A session with an empty Node yields
// errMissingValueForIndex.
func indexNodeFromSession(e *structs.Session) ([]byte, error) {
	node := strings.ToLower(e.Node)
	if len(node) == 0 {
		return nil, errMissingValueForIndex
	}

	var key indexBuilder
	key.String(node)
	return key.Bytes(), nil
}
// indexFromNodeCheckIDSession builds a compound index key for a sessionCheck
// from its node name, check ID and session ID, in that order, each lower-cased.
// If any component is empty, errMissingValueForIndex is returned.
func indexFromNodeCheckIDSession(e *sessionCheck) ([]byte, error) {
	components := [...]string{e.Node, string(e.CheckID.ID), e.Session}

	var key indexBuilder
	for _, component := range components {
		lowered := strings.ToLower(component)
		if lowered == "" {
			return nil, errMissingValueForIndex
		}
		key.String(lowered)
	}
	return key.Bytes(), nil
}
// indexSessionCheckFromSession builds the session index key for a sessionCheck
// from its lower-cased session ID. An empty Session yields
// errMissingValueForIndex.
func indexSessionCheckFromSession(e *sessionCheck) ([]byte, error) {
	session := strings.ToLower(e.Session)
	if len(session) == 0 {
		return nil, errMissingValueForIndex
	}

	var key indexBuilder
	key.String(session)
	return key.Bytes(), nil
}
// CheckIDIndex is a stateless memdb indexer that keys objects by the
// lower-cased ID of their CheckID field, terminated by a null byte.
type CheckIDIndex struct{}
// FromObject builds the index key for obj. obj must be a struct, or a pointer
// to one, with a CheckID field of type structs.CheckID or *structs.CheckID.
// The key is the lower-cased check ID followed by a null terminator.
//
// Returns:
//   - (false, nil, nil) if CheckID is a nil pointer, meaning the value is missing.
//   - an error if obj is not a struct, has no usable CheckID field, or the
//     field is not a structs.CheckID.
func (index *CheckIDIndex) FromObject(obj interface{}) (bool, []byte, error) {
	v := reflect.Indirect(reflect.ValueOf(obj)) // Dereference the pointer if any
	if v.Kind() != reflect.Struct {
		// FieldByName would panic on a non-struct (or nil) value.
		return false, nil, fmt.Errorf("object %#v is not a struct", obj)
	}

	fv := v.FieldByName("CheckID")
	isPtr := fv.Kind() == reflect.Ptr
	fv = reflect.Indirect(fv)
	if !fv.IsValid() {
		if isPtr {
			// A nil *CheckID counts as a missing value, mirroring go-memdb's
			// StringFieldIndex. Falling through would panic in CanInterface,
			// which is not defined on a zero reflect.Value.
			return false, nil, nil
		}
		return false, nil, fmt.Errorf("field 'CheckID' for %#v is invalid (pointer: %v)", obj, isPtr)
	}
	if !fv.CanInterface() {
		return false, nil, fmt.Errorf("field 'CheckID' for %#v is invalid (pointer: %v)", obj, isPtr)
	}

	checkID, ok := fv.Interface().(structs.CheckID)
	if !ok {
		return false, nil, fmt.Errorf("field 'CheckID' is not of type structs.CheckID")
	}

	// Enforce lowercase and add null character as terminator
	id := strings.ToLower(string(checkID.ID)) + "\x00"
	return true, []byte(id), nil
}
// FromArgs builds a lookup key from a single string argument. The key is the
// lower-cased argument followed by a null terminator, the same layout that
// FromObject produces. Any other argument count or type is an error.
func (index *CheckIDIndex) FromArgs(args ...interface{}) ([]byte, error) {
	if len(args) != 1 {
		return nil, fmt.Errorf("must provide only a single argument")
	}

	switch id := args[0].(type) {
	case string:
		// Keys are stored lower-cased and null-terminated.
		return []byte(strings.ToLower(id) + "\x00"), nil
	default:
		return nil, fmt.Errorf("argument must be a string: %#v", args[0])
	}
}
// PrefixFromArgs builds the same key as FromArgs and then drops the trailing
// null terminator. The result can be used for prefix lookups.
func (index *CheckIDIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
	key, err := index.FromArgs(args...)
	if err != nil {
		return nil, err
	}

	if n := len(key); n > 0 {
		// Strip the null terminator; what remains is the prefix.
		key = key[:n-1]
	}
	return key, nil
}
// Sessions returns an iterator over every session, walking the ID index. It is
// used when writing snapshots.
func (s *Snapshot) Sessions() (memdb.ResultIterator, error) {
	it, err := s.tx.Get(tableSessions, indexID)
	if err == nil {
		return it, nil
	}
	return nil, err
}
// Session inserts a session read from a snapshot, keeping the session's own
// ModifyIndex. It is intended only for restores; regular inserts should go
// through SessionCreate.
//
// NOTE(review): the two boolean flags passed to insertSessionTxn are true here
// and false in sessionCreateTxn. They presumably select restore behavior;
// confirm against insertSessionTxn's definition.
func (s *Restore) Session(sess *structs.Session) error {
	err := insertSessionTxn(s.tx, sess, sess.ModifyIndex, true, true)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed inserting session: %s", err)
}
// SessionCreate registers a new session at index idx inside its own write
// transaction. The transaction is committed only if sessionCreateTxn succeeds.
//
// This code could technically (and incorrectly) update an existing session.
// In practice that never happens: the upstream endpoint always assigns a fresh
// unique ID on create. Verifying that here isn't worth the overhead, but future
// callers must not rely on create overwriting an existing session.
func (s *Store) SessionCreate(idx uint64, sess *structs.Session) error {
	tx := s.db.WriteTxn(idx)
	defer tx.Abort()

	err := sessionCreateTxn(tx, idx, sess)
	if err != nil {
		return err
	}
	return tx.Commit()
}
// sessionCreateTxn creates a session entry inside an already-open write
// transaction. It proceeds in this order:
//   - defaults and validates sess.Behavior;
//   - stamps CreateIndex and ModifyIndex with idx;
//   - checks that sess.Node exists in the session's partition;
//   - validates the session's checks with validateSessionChecksTxn;
//   - inserts the session.
//
// sess is modified in place. Any error encountered is returned.
func sessionCreateTxn(tx WriteTxn, idx uint64, sess *structs.Session) error {
	if sess.ID == "" {
		return ErrMissingSessionID
	}

	// Only release and delete behaviors are accepted. An unset behavior
	// defaults to release to preserve backwards compatibility.
	switch sess.Behavior {
	case "":
		sess.Behavior = structs.SessionKeysRelease
	case structs.SessionKeysRelease, structs.SessionKeysDelete:
		// Valid as given.
	default:
		return fmt.Errorf("Invalid session behavior: %s", sess.Behavior)
	}

	// ModifyIndex is unlikely to be read for a brand-new session, but it is
	// set for consistency.
	sess.CreateIndex, sess.ModifyIndex = idx, idx

	// The session's node must already be registered.
	nodeQuery := Query{
		Value:          sess.Node,
		EnterpriseMeta: *structs.DefaultEnterpriseMetaInPartition(sess.PartitionOrDefault()),
	}
	node, err := tx.First(tableNodes, indexID, nodeQuery)
	if err != nil {
		return fmt.Errorf("failed node lookup: %s", err)
	}
	if node == nil {
		return ErrMissingNode
	}

	if err := validateSessionChecksTxn(tx, sess); err != nil {
		return err
	}

	if err := insertSessionTxn(tx, sess, idx, false, false); err != nil {
		return fmt.Errorf("failed inserting session: %s", err)
	}
	return nil
}
// SessionGet looks up a single session by ID. It returns:
//   - the index reported by maxIndexTxnSessions;
//   - the session, or nil if none matches;
//   - any lookup error.
//
// A nil entMeta selects the default enterprise meta in the default partition.
// The lookup's watch channel is added to ws.
func (s *Store) SessionGet(ws memdb.WatchSet,
	sessionID string, entMeta *acl.EnterpriseMeta) (uint64, *structs.Session, error) {
	tx := s.db.Txn(false)
	defer tx.Abort()

	// The max index uses the caller's entMeta as given, before the nil
	// default below is applied.
	idx := maxIndexTxnSessions(tx, entMeta)

	meta := entMeta
	if meta == nil {
		meta = structs.DefaultEnterpriseMetaInDefaultPartition()
	}
	watchCh, raw, err := tx.FirstWatch(tableSessions, indexID, Query{Value: sessionID, EnterpriseMeta: *meta})
	if err != nil {
		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
	}
	ws.Add(watchCh)

	if raw == nil {
		return idx, nil, nil
	}
	return idx, raw.(*structs.Session), nil
}
// NodeSessions returns the sessions associated with the node nodeID, as
// gathered by nodeSessionsTxn, together with the index reported by
// maxIndexTxnSessions. ws is passed through to nodeSessionsTxn. On error the
// index is 0.
func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string, entMeta *acl.EnterpriseMeta) (uint64, structs.Sessions, error) {
	tx := s.db.Txn(false)
	defer tx.Abort()

	maxIdx := maxIndexTxnSessions(tx, entMeta)

	sessions, err := nodeSessionsTxn(tx, ws, nodeID, entMeta)
	if err != nil {
		return 0, nil, err
	}
	return maxIdx, sessions, nil
}
// SessionDestroy removes a session at index idx inside its own write
// transaction. It applies the session's invalidation behavior through
// deleteSessionTxn. The transaction is committed only if deletion succeeds.
func (s *Store) SessionDestroy(idx uint64, sessionID string, entMeta *acl.EnterpriseMeta) error {
	tx := s.db.WriteTxn(idx)
	defer tx.Abort()

	err := s.deleteSessionTxn(tx, idx, sessionID, entMeta)
	if err != nil {
		return err
	}
	return tx.Commit()
}
// deleteSessionTxn removes session sessionID inside the open write
// transaction tx and applies its invalidation behavior at index idx:
//
//   - the session itself is deleted via sessionDeleteWithSession;
//   - KV entries locked by the session are released (SessionKeysRelease) or
//     deleted (SessionKeysDelete). The session's LockDelay, capped at
//     structs.MaxLockDelay, is then applied to each affected key;
//   - the session's check mappings are removed from tableSessionChecks;
//   - prepared queries found via the "session" index are deleted.
//
// A session that does not exist is not an error. A nil entMeta selects the
// default enterprise meta in the default partition.
func (s *Store) deleteSessionTxn(tx WriteTxn, idx uint64, sessionID string, entMeta *acl.EnterpriseMeta) error {
	// entMeta is defaulted once here. Every lookup and lock-delay call below
	// relies on it being non-nil.
	if entMeta == nil {
		entMeta = structs.DefaultEnterpriseMetaInDefaultPartition()
	}

	// Look up the session.
	sess, err := tx.First(tableSessions, indexID, Query{Value: sessionID, EnterpriseMeta: *entMeta})
	if err != nil {
		return fmt.Errorf("failed session lookup: %s", err)
	}
	if sess == nil {
		return nil
	}

	// Delete the session and write the new index.
	session := sess.(*structs.Session)
	if err := sessionDeleteWithSession(tx, session, idx); err != nil {
		return fmt.Errorf("failed deleting session: %s", err)
	}

	// Enforce the max lock delay.
	delay := session.LockDelay
	if delay > structs.MaxLockDelay {
		delay = structs.MaxLockDelay
	}

	// Capture the time once so every lock-delay expiration is calculated
	// from the same instant.
	now := time.Now()

	// Collect all KV entries held by this session before modifying any of them.
	entries, err := tx.Get(tableKVs, indexSession, sessionID)
	if err != nil {
		return fmt.Errorf("failed kvs lookup: %s", err)
	}
	var kvs []interface{}
	for entry := entries.Next(); entry != nil; entry = entries.Next() {
		kvs = append(kvs, entry)
	}

	// Invalidate any held locks.
	switch session.Behavior {
	case structs.SessionKeysRelease:
		for _, obj := range kvs {
			// Note that we clone here since we are modifying the
			// returned object and want to make sure our set op
			// respects the transaction we are in.
			e := obj.(*structs.DirEntry).Clone()
			e.Session = ""
			if err := kvsSetTxn(tx, idx, e, true); err != nil {
				return fmt.Errorf("failed kvs update: %s", err)
			}

			// Apply the lock delay if present.
			if delay > 0 {
				s.lockDelay.SetExpiration(e.Key, now, delay, entMeta)
			}
		}
	case structs.SessionKeysDelete:
		for _, obj := range kvs {
			e := obj.(*structs.DirEntry)
			if err := s.kvsDeleteTxn(tx, idx, e.Key, entMeta); err != nil {
				return fmt.Errorf("failed kvs delete: %s", err)
			}

			// Apply the lock delay if present.
			if delay > 0 {
				s.lockDelay.SetExpiration(e.Key, now, delay, entMeta)
			}
		}
	default:
		return fmt.Errorf("unknown session behavior %#v", session.Behavior)
	}

	// Delete any check mappings.
	mappings, err := tx.Get(tableSessionChecks, indexSession, Query{Value: sessionID, EnterpriseMeta: *entMeta})
	if err != nil {
		return fmt.Errorf("failed session checks lookup: %s", err)
	}
	{
		var objs []interface{}
		for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
			objs = append(objs, mapping)
		}

		// Do the delete in a separate loop so we don't trash the iterator.
		for _, obj := range objs {
			if err := tx.Delete(tableSessionChecks, obj); err != nil {
				return fmt.Errorf("failed deleting session check: %s", err)
			}
		}
	}

	// Delete any prepared queries.
	queries, err := tx.Get("prepared-queries", "session", sessionID)
	if err != nil {
		return fmt.Errorf("failed prepared query lookup: %s", err)
	}
	{
		var ids []string
		for wrapped := queries.Next(); wrapped != nil; wrapped = queries.Next() {
			ids = append(ids, toPreparedQuery(wrapped).ID)
		}

		// Do the delete in a separate loop so we don't trash the iterator.
		for _, id := range ids {
			if err := preparedQueryDeleteTxn(tx, idx, id); err != nil {
				return fmt.Errorf("failed prepared query delete: %s", err)
			}
		}
	}

	return nil
}