package state

import (
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/go-memdb"
)

// sessionsTableSchema returns a new table schema used for storing session
// information.
func sessionsTableSchema() *memdb.TableSchema {
	return &memdb.TableSchema{
		Name: "sessions",
		Indexes: map[string]*memdb.IndexSchema{
			"id": {
				Name:         "id",
				AllowMissing: false,
				Unique:       true,
				Indexer:      sessionIndexer(),
			},
			"node": {
				Name:         "node",
				AllowMissing: false,
				Unique:       false,
				Indexer:      nodeSessionsIndexer(),
			},
		},
	}
}

// sessionChecksTableSchema returns a new table schema used for storing session
// checks.
func sessionChecksTableSchema() *memdb.TableSchema {
	return &memdb.TableSchema{
		Name: "session_checks",
		Indexes: map[string]*memdb.IndexSchema{
			// The unique "id" index is the compound of the lowercased node
			// name, the check ID and the session UUID.
			"id": {
				Name:         "id",
				AllowMissing: false,
				Unique:       true,
				Indexer: &memdb.CompoundIndex{
					Indexes: []memdb.Indexer{
						&memdb.StringFieldIndex{
							Field:     "Node",
							Lowercase: true,
						},
						&CheckIDIndex{},
						&memdb.UUIDFieldIndex{
							Field: "Session",
						},
					},
				},
			},
			"node_check": {
				Name:         "node_check",
				AllowMissing: false,
				Unique:       false,
				Indexer:      nodeChecksIndexer(),
			},
			"session": {
				Name:         "session",
				AllowMissing: false,
				Unique:       false,
				Indexer: &memdb.UUIDFieldIndex{
					Field: "Session",
				},
			},
		},
	}
}

// CheckIDIndex is an indexer that builds index values from the "CheckID"
// field (of type structs.CheckID) of the indexed object. Values are
// lowercased and terminated with a null character.
type CheckIDIndex struct {
}

// FromObject extracts the index value from obj, which must be a struct (or a
// pointer to a struct) that has a "CheckID" field holding a structs.CheckID
// or a non-nil pointer to one. On success it returns true and the lowercased
// check ID followed by a null terminator. Any object that does not meet these
// requirements yields an error instead of a reflection panic.
func (index *CheckIDIndex) FromObject(obj interface{}) (bool, []byte, error) {
	v := reflect.ValueOf(obj)
	v = reflect.Indirect(v) // Dereference the pointer if any

	// FieldByName panics on anything but a struct; this also rejects a nil
	// obj and a nil pointer, both of which yield the zero reflect.Value.
	if v.Kind() != reflect.Struct {
		return false, nil, fmt.Errorf("object %#v is not a struct or a non-nil pointer to a struct", obj)
	}

	fv := v.FieldByName("CheckID")
	isPtr := fv.Kind() == reflect.Ptr
	fv = reflect.Indirect(fv)

	// IsValid must be checked before CanInterface: a missing field or a nil
	// pointer field gives the zero Value, on which CanInterface panics.
	if !fv.IsValid() || !fv.CanInterface() {
		return false, nil,
			fmt.Errorf("field 'CheckID' for %#v is invalid %v", obj, isPtr)
	}

	checkID, ok := fv.Interface().(structs.CheckID)
	if !ok {
		return false, nil, fmt.Errorf("field 'CheckID' is not of type structs.CheckID")
	}

	// Enforce lowercase and add null character as terminator
	id := strings.ToLower(string(checkID.ID)) + "\x00"

	return true, []byte(id), nil
}

// FromArgs builds the lookup value for this index from exactly one string
// argument. The string is lowercased and null-terminated so that it matches
// the values produced by FromObject.
func (index *CheckIDIndex) FromArgs(args ...interface{}) ([]byte, error) {
	if len(args) != 1 {
		return nil, fmt.Errorf("must provide only a single argument")
	}
	arg, ok := args[0].(string)
	if !ok {
		return nil, fmt.Errorf("argument must be a string: %#v", args[0])
	}

	arg = strings.ToLower(arg)

	// Add the null character as a terminator
	arg += "\x00"
	return []byte(arg), nil
}

// PrefixFromArgs builds a prefix-lookup value: the same bytes FromArgs
// produces, minus the trailing null terminator.
func (index *CheckIDIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
	val, err := index.FromArgs(args...)
	if err != nil {
		return nil, err
	}

	// Strip the null terminator, the rest is a prefix
	n := len(val)
	if n > 0 {
		return val[:n-1], nil
	}
	return val, nil
}

// init registers the session and session check table schemas.
func init() {
	registerSchema(sessionsTableSchema)
	registerSchema(sessionChecksTableSchema)
}

// Sessions is used to pull the full list of sessions for use during snapshots.
func (s *Snapshot) Sessions() (memdb.ResultIterator, error) {
	iter, err := s.tx.Get("sessions", "id")
	if err != nil {
		return nil, err
	}
	return iter, nil
}

// Session is used when restoring from a snapshot. For general inserts, use
// SessionCreate.
func (s *Restore) Session(sess *structs.Session) error {
	if err := insertSessionTxn(s.tx, sess, sess.ModifyIndex, true); err != nil {
		return fmt.Errorf("failed inserting session: %s", err)
	}

	return nil
}

// SessionCreate is used to register a new session in the state store.
func (s *Store) SessionCreate(idx uint64, sess *structs.Session) error {
	tx := s.db.WriteTxn(idx)
	defer tx.Abort()

	// This code is technically able to (incorrectly) update an existing
	// session but we never do that in practice. The upstream endpoint code
	// always adds a unique ID when doing a create operation so we never hit
	// an existing session again. It isn't worth the overhead to verify
	// that here, but it's worth noting that we should never do this in the
	// future.

	// Call the session creation
	if err := sessionCreateTxn(tx, idx, sess); err != nil {
		return err
	}

	return tx.Commit()
}

// sessionCreateTxn is the inner method used for creating session entries in
// an open transaction. It requires a non-empty session ID, a valid session
// behavior (an empty behavior defaults to release), an existing node, and
// session checks that pass validateSessionChecksTxn. Note that it mutates
// sess: Behavior may be defaulted and CreateIndex/ModifyIndex are set to idx.
// Returns any error encountered.
func sessionCreateTxn(tx *txn, idx uint64, sess *structs.Session) error {
	// Check that we have a session ID
	if sess.ID == "" {
		return ErrMissingSessionID
	}

	// Verify the session behavior is valid
	switch sess.Behavior {
	case "":
		// Release by default to preserve backwards compatibility
		sess.Behavior = structs.SessionKeysRelease
	case structs.SessionKeysRelease:
	case structs.SessionKeysDelete:
	default:
		return fmt.Errorf("Invalid session behavior: %s", sess.Behavior)
	}

	// Assign the indexes. ModifyIndex likely will not be used but
	// we set it here anyways for sanity.
	sess.CreateIndex = idx
	sess.ModifyIndex = idx

	// Check that the node exists
	node, err := tx.First("nodes", "id", sess.Node)
	if err != nil {
		return fmt.Errorf("failed node lookup: %s", err)
	}
	if node == nil {
		return ErrMissingNode
	}

	// Verify that all session checks exist
	if err := validateSessionChecksTxn(tx, sess); err != nil {
		return err
	}

	// Insert the session
	if err := insertSessionTxn(tx, sess, idx, false); err != nil {
		return fmt.Errorf("failed inserting session: %s", err)
	}

	return nil
}

// SessionGet is used to retrieve an active session from the state store.
// The session is nil when no session with the given ID exists; the lookup's
// watch channel is added to ws either way.
func (s *Store) SessionGet(ws memdb.WatchSet,
	sessionID string, entMeta *structs.EnterpriseMeta) (uint64, *structs.Session, error) {

	tx := s.db.Txn(false)
	defer tx.Abort()

	// Get the table index.
	idx := sessionMaxIndex(tx, entMeta)

	// Look up the session by its ID
	watchCh, session, err := firstWatchWithTxn(tx, "sessions", "id", sessionID, entMeta)
	if err != nil {
		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
	}
	ws.Add(watchCh)

	if session != nil {
		return idx, session.(*structs.Session), nil
	}
	return idx, nil, nil
}

// SessionList returns a slice containing all of the active sessions.
func (s *Store) SessionList(ws memdb.WatchSet, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) {
	tx := s.db.Txn(false)
	defer tx.Abort()

	// Get the table index.
	idx := sessionMaxIndex(tx, entMeta)

	// Query all of the active sessions.
	sessions, err := getWithTxn(tx, "sessions", "id_prefix", "", entMeta)
	if err != nil {
		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
	}
	ws.Add(sessions.WatchCh())

	// Go over the sessions and create a slice of them.
	var result structs.Sessions
	for session := sessions.Next(); session != nil; session = sessions.Next() {
		result = append(result, session.(*structs.Session))
	}
	return idx, result, nil
}

// NodeSessions returns a set of active sessions associated
// with the given node ID. The returned index is the one reported
// by sessionMaxIndex.
func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string, entMeta *structs.EnterpriseMeta) (uint64, structs.Sessions, error) {
	tx := s.db.Txn(false)
	defer tx.Abort()

	// Get the table index.
	idx := sessionMaxIndex(tx, entMeta)

	// Get all of the sessions which belong to the node
	result, err := nodeSessionsTxn(tx, ws, nodeID, entMeta)
	if err != nil {
		return 0, nil, err
	}
	return idx, result, nil
}

// SessionDestroy is used to remove an active session. This will
// implicitly invalidate the session and invoke the specified
// session destroy behavior.
func (s *Store) SessionDestroy(idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error {
	tx := s.db.WriteTxn(idx)
	defer tx.Abort()

	// Call the session deletion.
	if err := s.deleteSessionTxn(tx, idx, sessionID, entMeta); err != nil {
		return err
	}

	return tx.Commit()
}

// deleteSessionTxn is the inner method, which is used to do the actual
// session deletion and handle session invalidation, etc. It is a no-op
// (returning nil) when the session does not exist. Otherwise it deletes the
// session, releases or deletes the KV entries held by it according to the
// session behavior (applying the capped lock delay to each key), and removes
// the session's check mappings and prepared queries.
func (s *Store) deleteSessionTxn(tx *txn, idx uint64, sessionID string, entMeta *structs.EnterpriseMeta) error {
	// Look up the session.
	sess, err := firstWithTxn(tx, "sessions", "id", sessionID, entMeta)
	if err != nil {
		return fmt.Errorf("failed session lookup: %s", err)
	}
	if sess == nil {
		return nil
	}

	// Delete the session and write the new index.
	session := sess.(*structs.Session)
	if err := sessionDeleteWithSession(tx, session, idx); err != nil {
		return fmt.Errorf("failed deleting session: %v", err)
	}

	// Enforce the max lock delay.
	delay := session.LockDelay
	if delay > structs.MaxLockDelay {
		delay = structs.MaxLockDelay
	}

	// Snag the current now time so that all the expirations get calculated
	// the same way.
	now := time.Now()

	// Get an iterator over all of the keys with the given session.
	entries, err := tx.Get("kvs", "session", sessionID)
	if err != nil {
		return fmt.Errorf("failed kvs lookup: %s", err)
	}
	// Collect the entries first so the loops below can write to the "kvs"
	// table without modifying it while iterating.
	var kvs []interface{}
	for entry := entries.Next(); entry != nil; entry = entries.Next() {
		kvs = append(kvs, entry)
	}

	// Invalidate any held locks.
	switch session.Behavior {
	case structs.SessionKeysRelease:
		for _, obj := range kvs {
			// Note that we clone here since we are modifying the
			// returned object and want to make sure our set op
			// respects the transaction we are in.
			e := obj.(*structs.DirEntry).Clone()
			e.Session = ""
			if err := kvsSetTxn(tx, idx, e, true); err != nil {
				return fmt.Errorf("failed kvs update: %s", err)
			}

			// Apply the lock delay if present.
			if delay > 0 {
				s.lockDelay.SetExpiration(e.Key, now, delay, entMeta)
			}
		}
	case structs.SessionKeysDelete:
		for _, obj := range kvs {
			e := obj.(*structs.DirEntry)
			if err := s.kvsDeleteTxn(tx, idx, e.Key, entMeta); err != nil {
				return fmt.Errorf("failed kvs delete: %s", err)
			}

			// Apply the lock delay if present.
			if delay > 0 {
				s.lockDelay.SetExpiration(e.Key, now, delay, entMeta)
			}
		}
	default:
		return fmt.Errorf("unknown session behavior %#v", session.Behavior)
	}

	// Delete any check mappings.
	mappings, err := tx.Get("session_checks", "session", sessionID)
	if err != nil {
		return fmt.Errorf("failed session checks lookup: %s", err)
	}
	{
		var objs []interface{}
		for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
			objs = append(objs, mapping)
		}

		// Do the delete in a separate loop so we don't trash the iterator.
		for _, obj := range objs {
			if err := tx.Delete("session_checks", obj); err != nil {
				return fmt.Errorf("failed deleting session check: %s", err)
			}
		}
	}

	// Delete any prepared queries.
	queries, err := tx.Get("prepared-queries", "session", sessionID)
	if err != nil {
		return fmt.Errorf("failed prepared query lookup: %s", err)
	}
	{
		var ids []string
		for wrapped := queries.Next(); wrapped != nil; wrapped = queries.Next() {
			ids = append(ids, toPreparedQuery(wrapped).ID)
		}

		// Do the delete in a separate loop so we don't trash the iterator.
		for _, id := range ids {
			if err := preparedQueryDeleteTxn(tx, idx, id); err != nil {
				return fmt.Errorf("failed prepared query delete: %s", err)
			}
		}
	}

	return nil
}