From 04b365495d53b32afbbcdadf32f935c5291d5c09 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Fri, 25 Sep 2015 12:01:46 -0700 Subject: [PATCH] Completes state store for KV, sessions, tombstones, and nodes/services/checks (needs tests and integration). --- consul/acl_endpoint.go | 4 +- consul/fsm.go | 137 +++--- consul/rpc.go | 10 +- consul/state/delay.go | 54 +++ consul/state/graveyard.go | 103 +++++ consul/state/schema.go | 41 +- consul/state/state_store.go | 721 +++++++++++++++++++++++++------ consul/state/state_store_test.go | 14 +- consul/state/watch.go | 131 +++++- consul/structs/structs.go | 16 + 10 files changed, 990 insertions(+), 241 deletions(-) create mode 100644 consul/state/delay.go create mode 100644 consul/state/graveyard.go diff --git a/consul/acl_endpoint.go b/consul/acl_endpoint.go index 51fd1272c..0cebc9e88 100644 --- a/consul/acl_endpoint.go +++ b/consul/acl_endpoint.go @@ -123,7 +123,7 @@ func (a *ACL) Get(args *structs.ACLSpecificRequest, state := a.srv.fsm.StateNew() return a.srv.blockingRPCNew(&args.QueryOptions, &reply.QueryMeta, - state.GetWatchManager("acls"), + state.GetTableWatch("acls"), func() error { acl, err := state.ACLGet(args.ACL) if acl != nil { @@ -194,7 +194,7 @@ func (a *ACL) List(args *structs.DCSpecificRequest, state := a.srv.fsm.StateNew() return a.srv.blockingRPCNew(&args.QueryOptions, &reply.QueryMeta, - state.GetWatchManager("acls"), + state.GetTableWatch("acls"), func() error { var err error reply.Index, reply.ACLs, err = state.ACLList() diff --git a/consul/fsm.go b/consul/fsm.go index 1e0fb73f8..e8ff2d2dc 100644 --- a/consul/fsm.go +++ b/consul/fsm.go @@ -136,9 +136,9 @@ func (c *consulFSM) decodeRegister(buf []byte, index uint64) interface{} { } func (c *consulFSM) applyRegister(req *structs.RegisterRequest, index uint64) interface{} { - defer metrics.MeasureSince([]string{"consul", "fsm", "register"}, time.Now()) // Apply all updates in a single transaction - if err := c.state.EnsureRegistration(index, req); err != nil { + defer metrics.MeasureSince([]string{"consul", "fsm", "register"}, time.Now()) + if err := c.stateNew.EnsureRegistration(index, req); err != nil { c.logger.Printf("[INFO] consul.fsm: EnsureRegistration failed: %v", err) return err } @@ -154,17 +154,17 @@ func (c *consulFSM) applyDeregister(buf []byte, index uint64) interface{} { // Either remove the service entry or the whole node if req.ServiceID != "" { - if err := c.state.DeleteNodeService(index, req.Node, req.ServiceID); err != nil { + if err := c.stateNew.DeleteService(index, req.Node, req.ServiceID); err != nil { c.logger.Printf("[INFO] consul.fsm: DeleteNodeService failed: %v", err) return err } } else if req.CheckID != "" { - if err := c.state.DeleteNodeCheck(index, req.Node, req.CheckID); err != nil { + if err := c.stateNew.DeleteCheck(index, req.Node, req.CheckID); err != nil { c.logger.Printf("[INFO] consul.fsm: DeleteNodeCheck failed: %v", err) return err } } else { - if err := c.state.DeleteNode(index, req.Node); err != nil { + if err := c.stateNew.DeleteNode(index, req.Node); err != nil { c.logger.Printf("[INFO] consul.fsm: DeleteNode failed: %v", err) return err } @@ -180,34 +180,34 @@ func (c *consulFSM) applyKVSOperation(buf []byte, index uint64) interface{} { defer metrics.MeasureSince([]string{"consul", "fsm", "kvs", string(req.Op)}, time.Now()) switch req.Op { case structs.KVSSet: - return c.state.KVSSet(index, &req.DirEnt) + return c.stateNew.KVSSet(index, &req.DirEnt) case structs.KVSDelete: - return c.state.KVSDelete(index, req.DirEnt.Key) + 
return c.stateNew.KVSDelete(index, req.DirEnt.Key) case structs.KVSDeleteCAS: - act, err := c.state.KVSDeleteCheckAndSet(index, req.DirEnt.Key, req.DirEnt.ModifyIndex) + act, err := c.stateNew.KVSDeleteCAS(index, req.DirEnt.ModifyIndex, req.DirEnt.Key) if err != nil { return err } else { return act } case structs.KVSDeleteTree: - return c.state.KVSDeleteTree(index, req.DirEnt.Key) + return c.stateNew.KVSDeleteTree(index, req.DirEnt.Key) case structs.KVSCAS: - act, err := c.state.KVSCheckAndSet(index, &req.DirEnt) + act, err := c.stateNew.KVSSetCAS(index, &req.DirEnt) if err != nil { return err } else { return act } case structs.KVSLock: - act, err := c.state.KVSLock(index, &req.DirEnt) + act, err := c.stateNew.KVSLock(index, &req.DirEnt) if err != nil { return err } else { return act } case structs.KVSUnlock: - act, err := c.state.KVSUnlock(index, &req.DirEnt) + act, err := c.stateNew.KVSUnlock(index, &req.DirEnt) if err != nil { return err } else { @@ -228,13 +228,13 @@ func (c *consulFSM) applySessionOperation(buf []byte, index uint64) interface{} defer metrics.MeasureSince([]string{"consul", "fsm", "session", string(req.Op)}, time.Now()) switch req.Op { case structs.SessionCreate: - if err := c.state.SessionCreate(index, &req.Session); err != nil { + if err := c.stateNew.SessionCreate(index, &req.Session); err != nil { return err } else { return req.Session.ID } case structs.SessionDestroy: - return c.state.SessionDestroy(index, req.Session.ID) + return c.stateNew.SessionDestroy(index, req.Session.ID) default: c.logger.Printf("[WARN] consul.fsm: Invalid Session operation '%s'", req.Op) return fmt.Errorf("Invalid Session operation '%s'", req.Op) @@ -270,7 +270,7 @@ func (c *consulFSM) applyTombstoneOperation(buf []byte, index uint64) interface{ defer metrics.MeasureSince([]string{"consul", "fsm", "tombstone", string(req.Op)}, time.Now()) switch req.Op { case structs.TombstoneReap: - return c.state.ReapTombstones(req.ReapIndex) + return c.stateNew.ReapTombstones(req.ReapIndex) default: c.logger.Printf("[WARN] consul.fsm: Invalid Tombstone operation '%s'", req.Op) return fmt.Errorf("Invalid Tombstone operation '%s'", req.Op) @@ -300,12 +300,12 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { } // Create a new state store - state, err := NewStateStorePath(c.gc, tmpPath, c.logOutput) + store, err := NewStateStorePath(c.gc, tmpPath, c.logOutput) if err != nil { return err } c.state.Close() - c.state = state + c.state = store // Create a decoder dec := codec.NewDecoder(old, msgpackHandle) @@ -341,7 +341,7 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { if err := dec.Decode(&req); err != nil { return err } - if err := c.state.KVSRestore(&req); err != nil { + if err := c.stateNew.KVSRestore(&req); err != nil { return err } @@ -350,7 +350,7 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { if err := dec.Decode(&req); err != nil { return err } - if err := c.state.SessionRestore(&req); err != nil { + if err := c.stateNew.SessionRestore(&req); err != nil { return err } @@ -368,7 +368,15 @@ func (c *consulFSM) Restore(old io.ReadCloser) error { if err := dec.Decode(&req); err != nil { return err } - if err := c.state.TombstoneRestore(&req); err != nil { + + // For historical reasons, these are serialized in the + // snapshots as KV entries. We want to keep the snapshot + // format compatible with pre-0.6 versions for now. 
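+			// Rebuild the state store's Tombstone from the DirEntry's key and modify index.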
+ stone := &state.Tombstone{ + Key: req.Key, + Index: req.ModifyIndex, + } + if err := c.stateNew.TombstoneRestore(stone); err != nil { return err } @@ -387,7 +395,7 @@ func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { // Write the header header := snapshotHeader{ - LastIndex: s.state.LastIndex(), + LastIndex: s.stateNew.LastIndex(), } if err := encoder.Encode(&header); err != nil { sink.Cancel() @@ -423,8 +431,12 @@ func (s *consulSnapshot) Persist(sink raft.SnapshotSink) error { func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, encoder *codec.Encoder) error { + // Get all the nodes - nodes := s.state.Nodes() + nodes, err := s.stateNew.NodeDump() + if err != nil { + return err + } // Register each node var req structs.RegisterRequest @@ -441,8 +453,11 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, } // Register each service this node has - services := s.state.NodeServices(nodes[i].Node) - for _, srv := range services.Services { + services, err := s.stateNew.ServiceDump(nodes[i].Node) + if err != nil { + return err + } + for _, srv := range services { req.Service = srv sink.Write([]byte{byte(structs.RegisterRequestType)}) if err := encoder.Encode(&req); err != nil { @@ -452,7 +467,10 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, // Register each check this node has req.Service = nil - checks := s.state.NodeChecks(nodes[i].Node) + checks, err := s.stateNew.CheckDump(nodes[i].Node) + if err != nil { + return err + } for _, check := range checks { req.Check = check sink.Write([]byte{byte(structs.RegisterRequestType)}) @@ -466,7 +484,7 @@ func (s *consulSnapshot) persistNodes(sink raft.SnapshotSink, func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, encoder *codec.Encoder) error { - sessions, err := s.state.SessionList() + sessions, err := s.stateNew.SessionDump() if err != nil { return err } @@ -482,7 +500,7 @@ func (s *consulSnapshot) persistSessions(sink raft.SnapshotSink, func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, encoder *codec.Encoder) error { - acls, err := s.stateNew.ACLList() + acls, err := s.stateNew.ACLDump() if err != nil { return err } @@ -498,58 +516,47 @@ func (s *consulSnapshot) persistACLs(sink raft.SnapshotSink, func (s *consulSnapshot) persistKV(sink raft.SnapshotSink, encoder *codec.Encoder) error { - streamCh := make(chan interface{}, 256) - errorCh := make(chan error) - go func() { - if err := s.state.KVSDump(streamCh); err != nil { - errorCh <- err - } - }() + entries, err := s.stateNew.KVSDump() + if err != nil { + return err + } - for { - select { - case raw := <-streamCh: - if raw == nil { - return nil - } - sink.Write([]byte{byte(structs.KVSRequestType)}) - if err := encoder.Encode(raw); err != nil { - return err - } - - case err := <-errorCh: + for _, e := range entries { + sink.Write([]byte{byte(structs.KVSRequestType)}) + if err := encoder.Encode(e); err != nil { return err } } + return nil } func (s *consulSnapshot) persistTombstones(sink raft.SnapshotSink, encoder *codec.Encoder) error { - streamCh := make(chan interface{}, 256) - errorCh := make(chan error) - go func() { - if err := s.state.TombstoneDump(streamCh); err != nil { - errorCh <- err + stones, err := s.stateNew.TombstoneDump() + if err != nil { + return err + } + + for _, s := range stones { + sink.Write([]byte{byte(structs.TombstoneRequestType)}) + + // For historical reasons, these are serialized in the snapshots + // as KV entries. 
We want to keep the snapshot format compatible + // with pre-0.6 versions for now. + fake := &structs.DirEntry{ + Key: s.Key, + RaftIndex: structs.RaftIndex{ + ModifyIndex: s.Index, + }, } - }() - - for { - select { - case raw := <-streamCh: - if raw == nil { - return nil - } - sink.Write([]byte{byte(structs.TombstoneRequestType)}) - if err := encoder.Encode(raw); err != nil { - return err - } - - case err := <-errorCh: + if err := encoder.Encode(fake); err != nil { return err } } + return nil } func (s *consulSnapshot) Release() { s.state.Close() + s.stateNew.Close() } diff --git a/consul/rpc.go b/consul/rpc.go index 6860ddaea..c4c912ee4 100644 --- a/consul/rpc.go +++ b/consul/rpc.go @@ -400,7 +400,7 @@ RUN_QUERY: // TODO(slackpad) func (s *Server) blockingRPCNew(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta, - watch state.WatchManager, run func() error) error { + watch state.Watch, run func() error) error { var timeout *time.Timer var notifyCh chan struct{} @@ -409,9 +409,9 @@ func (s *Server) blockingRPCNew(queryOpts *structs.QueryOptions, queryMeta *stru goto RUN_QUERY } - // Make sure a watch manager was given if we were asked to block. + // Make sure a watch was given if we were asked to block. if watch == nil { - panic("no watch manager given for blocking query") + panic("no watch given for blocking query") } // Restrict the max query time, and ensure there is always one. @@ -433,13 +433,13 @@ func (s *Server) blockingRPCNew(queryOpts *structs.QueryOptions, queryMeta *stru // Ensure we tear down any watches on return. defer func() { timeout.Stop() - watch.Stop(notifyCh) + watch.Clear(notifyCh) }() REGISTER_NOTIFY: // Register the notification channel. This may be done multiple times if // we haven't reached the target wait index. - watch.Start(notifyCh) + watch.Wait(notifyCh) RUN_QUERY: // Update the query metadata. diff --git a/consul/state/delay.go b/consul/state/delay.go new file mode 100644 index 000000000..206fe4da6 --- /dev/null +++ b/consul/state/delay.go @@ -0,0 +1,54 @@ +package state + +import ( + "sync" + "time" +) + +// Delay is used to mark certain locks as unacquirable. When a lock is +// forcefully released (failing health check, destroyed session, etc.), it is +// subject to the LockDelay impossed by the session. This prevents another +// session from acquiring the lock for some period of time as a protection +// against split-brains. This is inspired by the lock-delay in Chubby. Because +// this relies on wall-time, we cannot assume all peers perceive time as flowing +// uniformly. This means KVSLock MUST ignore lockDelay, since the lockDelay may +// have expired on the leader, but not on the follower. Rejecting the lock could +// result in inconsistencies in the FSMs due to the rate time progresses. Instead, +// only the opinion of the leader is respected, and the Raft log is never +// questioned. +type Delay struct { + // delay has the set of active delay expiration times, organized by key. + delay map[string]time.Time + + // lock protects the delay map. + lock sync.RWMutex +} + +// NewDelay returns a new delay manager. +func NewDelay() *Delay { + return &Delay{delay: make(map[string]time.Time)} +} + +// GetExpiration returns the expiration time of a key lock delay. This must be +// checked on the leader node, and not in KVSLock due to the variability of +// clocks. 
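+// A zero time is returned if there is no delay in effect for the given key.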
+func (d *Delay) GetExpiration(key string) time.Time { + d.lock.RLock() + expires := d.delay[key] + d.lock.RUnlock() + return expires +} + +// SetExpiration sets the expiration time for the lock delay to the given +// delay from the given now time. +func (d *Delay) SetExpiration(key string, now time.Time, delay time.Duration) { + d.lock.Lock() + defer d.lock.Unlock() + + d.delay[key] = now.Add(delay) + time.AfterFunc(delay, func() { + d.lock.Lock() + delete(d.delay, key) + d.lock.Unlock() + }) +} diff --git a/consul/state/graveyard.go b/consul/state/graveyard.go new file mode 100644 index 000000000..a500159cd --- /dev/null +++ b/consul/state/graveyard.go @@ -0,0 +1,103 @@ +package state + +import ( + "fmt" + + "github.com/hashicorp/go-memdb" +) + +// tombstone is the internal type used to track tombstones. +type Tombstone struct { + Key string + Index uint64 +} + +// Graveyard manages a set of tombstones for a table. This is just used for +// KVS right now but we've broken it out for other table types later. +type Graveyard struct { + Table string +} + +// NewGraveyard returns a new graveyard. +func NewGraveyard(table string) *Graveyard { + return &Graveyard{Table: "tombstones_" + table} +} + +// InsertTxn adds a new tombstone. +func (g *Graveyard) InsertTxn(tx *memdb.Txn, context string, idx uint64) error { + stone := &Tombstone{Key: context, Index: idx} + if err := tx.Insert(g.Table, stone); err != nil { + return fmt.Errorf("failed inserting tombstone: %s", err) + } + + if err := tx.Insert("index", &IndexEntry{g.Table, idx}); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + +// GetMaxIndexTxn returns the highest index tombstone whose key matches the +// given context, using a prefix match. +func (g *Graveyard) GetMaxIndexTxn(tx *memdb.Txn, context string) (uint64, error) { + stones, err := tx.Get(g.Table, "id", context) + if err != nil { + return 0, fmt.Errorf("failed querying tombstones: %s", err) + } + + var lindex uint64 + for stone := stones.Next(); stone != nil; stone = stones.Next() { + r := stone.(*Tombstone) + if r.Index > lindex { + lindex = r.Index + } + } + return lindex, nil +} + +// DumpTxn returns all the tombstones. +func (g *Graveyard) DumpTxn(tx *memdb.Txn) ([]*Tombstone, error) { + stones, err := tx.Get(g.Table, "id", "") + if err != nil { + return nil, fmt.Errorf("failed querying tombstones: %s", err) + } + + var dump []*Tombstone + for stone := stones.Next(); stone != nil; stone = stones.Next() { + dump = append(dump, stone.(*Tombstone)) + } + return dump, nil +} + +// RestoreTxn is used when restoring from a snapshot. For general inserts, use +// InsertTxn. +func (g *Graveyard) RestoreTxn(tx *memdb.Txn, stone *Tombstone) error { + if err := tx.Insert(g.Table, stone); err != nil { + return fmt.Errorf("failed inserting tombstone: %s", err) + } + + if err := indexUpdateMaxTxn(tx, stone.Index, g.Table); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + return nil +} + +// ReapTxn cleans out all tombstones whose index values are less than or equal +// to the given idx. This prevents unbounded storage growth of the tombstones. +func (g *Graveyard) ReapTxn(tx *memdb.Txn, idx uint64) error { + // This does a full table scan since we currently can't index on a + // numeric value. Since this is all in-memory and done infrequently + // this pretty reasonable. 
+ stones, err := tx.Get(g.Table, "id", "") + if err != nil { + return fmt.Errorf("failed querying tombstones: %s", err) + } + + for stone := stones.Next(); stone != nil; stone = stones.Next() { + if stone.(*Tombstone).Index <= idx { + if err := tx.Delete(g.Table, stone); err != nil { + return fmt.Errorf("failed deleting tombstone: %s", err) + } + } + } + return nil +} diff --git a/consul/state/schema.go b/consul/state/schema.go index 81faccb41..cd8d32b68 100644 --- a/consul/state/schema.go +++ b/consul/state/schema.go @@ -25,7 +25,7 @@ func stateStoreSchema() *memdb.DBSchema { servicesTableSchema, checksTableSchema, kvsTableSchema, - tombstonesTableSchema, + func() *memdb.TableSchema { return tombstonesTableSchema("kvs") }, sessionsTableSchema, sessionChecksTableSchema, aclsTableSchema, @@ -177,7 +177,6 @@ func checksTableSchema() *memdb.TableSchema { Lowercase: true, }, }, - // TODO(slackpad): This one is new, where is it used? "node_service": &memdb.IndexSchema{ Name: "node_service", AllowMissing: true, @@ -231,11 +230,11 @@ func kvsTableSchema() *memdb.TableSchema { } // tombstonesTableSchema returns a new table schema used for -// storing tombstones during kvs delete operations to prevent -// the index from sliding backwards. -func tombstonesTableSchema() *memdb.TableSchema { +// storing tombstones during the given table's delete operations +// to prevent the index from sliding backwards. +func tombstonesTableSchema(table string) *memdb.TableSchema { return &memdb.TableSchema{ - Name: "tombstones", + Name: "tombstones_" + table, Indexes: map[string]*memdb.IndexSchema{ "id": &memdb.IndexSchema{ Name: "id", @@ -305,21 +304,10 @@ func sessionChecksTableSchema() *memdb.TableSchema { }, }, }, - // TODO(slackpad): Where did these come from? - "session": &memdb.IndexSchema{ - Name: "session", + "node_check": &memdb.IndexSchema{ + Name: "node_check", AllowMissing: false, - Unique: true, - Indexer: &memdb.StringFieldIndex{ - Field: "Session", - Lowercase: false, - }, - }, - // TODO(slackpad): Should this be called node_session? - "node": &memdb.IndexSchema{ - Name: "node", - AllowMissing: false, - Unique: true, + Unique: false, Indexer: &memdb.CompoundIndex{ Indexes: []memdb.Indexer{ &memdb.StringFieldIndex{ @@ -327,12 +315,21 @@ func sessionChecksTableSchema() *memdb.TableSchema { Lowercase: true, }, &memdb.StringFieldIndex{ - Field: "Session", - Lowercase: false, + Field: "CheckID", + Lowercase: true, }, }, }, }, + "session": &memdb.IndexSchema{ + Name: "session", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Session", + Lowercase: false, + }, + }, }, } } diff --git a/consul/state/state_store.go b/consul/state/state_store.go index 7de24c13f..b1a607e2e 100644 --- a/consul/state/state_store.go +++ b/consul/state/state_store.go @@ -6,6 +6,7 @@ import ( "io" "log" "strings" + "time" "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/go-memdb" @@ -34,15 +35,27 @@ var ( // pairs and more. The DB is entirely in-memory and is constructed // from the Raft log through the FSM. type StateStore struct { - logger *log.Logger // TODO(slackpad) - Delete if unused! - schema *memdb.DBSchema - db *memdb.MemDB - watches map[string]WatchManager + logger *log.Logger // TODO(slackpad) - Delete if unused! + schema *memdb.DBSchema + db *memdb.MemDB + + // tableWatches holds all the full table watches, indexed by table name. + tableWatches map[string]*FullTableWatch + + // kvsWatch holds the special prefix watch for the key value store. 
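+	// Watchers are registered against it on a per-prefix basis via GetKVSWatch.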
+ kvsWatch *PrefixWatch + + // kvsGraveyard manages tombstones for the key value store. + kvsGraveyard *Graveyard + + // lockDelay holds expiration times for locks associated with keys. + lockDelay *Delay } // StateSnapshot is used to provide a point-in-time snapshot. It // works by starting a read transaction against the whole state store. type StateSnapshot struct { + store *StateStore tx *memdb.Txn lastIndex uint64 } @@ -72,18 +85,25 @@ func NewStateStore(logOutput io.Writer) (*StateStore, error) { return nil, fmt.Errorf("Failed setting up state store: %s", err) } - // Build up the watch managers. - watches, err := newWatchManagers(schema) - if err != nil { - return nil, fmt.Errorf("Failed to build watch managers: %s", err) + // Build up the all-table watches. + tableWatches := make(map[string]*FullTableWatch) + for table, _ := range schema.Tables { + if table == "kvs" { + continue + } + + tableWatches[table] = NewFullTableWatch() } // Create and return the state store. s := &StateStore{ - logger: log.New(logOutput, "", log.LstdFlags), - schema: schema, - db: db, - watches: watches, + logger: log.New(logOutput, "", log.LstdFlags), + schema: schema, + db: db, + tableWatches: tableWatches, + kvsWatch: NewPrefixWatch(), + kvsGraveyard: NewGraveyard("kvs"), + lockDelay: NewDelay(), } return s, nil } @@ -98,7 +118,7 @@ func (s *StateStore) Snapshot() *StateSnapshot { } idx := maxIndexTxn(tx, tables...) - return &StateSnapshot{tx, idx} + return &StateSnapshot{s, tx, idx} } // LastIndex returns that last index that affects the snapshotted data. @@ -111,8 +131,92 @@ func (s *StateSnapshot) Close() { s.tx.Abort() } -// ACLList is used to pull all the ACLs from the snapshot. -func (s *StateSnapshot) ACLList() ([]*structs.ACL, error) { +// NodeDump is used to pull the full list of nodes for use during snapshots. +func (s *StateSnapshot) NodeDump() (structs.Nodes, error) { + nodes, err := s.tx.Get("nodes", "id") + if err != nil { + return nil, fmt.Errorf("failed node lookup: %s", err) + } + + var dump structs.Nodes + for node := nodes.Next(); node != nil; node = nodes.Next() { + dump = append(dump, node.(*structs.Node)) + } + return dump, nil +} + +// ServiceDump is used to pull the full list of services for a given node for use +// during snapshots. +func (s *StateSnapshot) ServiceDump(node string) ([]*structs.NodeService, error) { + services, err := s.tx.Get("services", "node", node) + if err != nil { + return nil, fmt.Errorf("failed service lookup: %s", err) + } + + var dump []*structs.NodeService + for service := services.Next(); service != nil; service = services.Next() { + s := service.(*structs.ServiceNode) + dump = append(dump, &structs.NodeService{ + ID: s.ServiceID, + Service: s.ServiceName, + Tags: s.ServiceTags, + Address: s.ServiceAddress, + Port: s.ServicePort, + }) + } + return dump, nil +} + +// CheckDump is used to pull the full list of checks for a given node for use +// during snapshots. +func (s *StateSnapshot) CheckDump(node string) (structs.HealthChecks, error) { + checks, err := s.tx.Get("checks", "node", node) + if err != nil { + return nil, fmt.Errorf("failed check lookup: %s", err) + } + + var dump structs.HealthChecks + for check := checks.Next(); check != nil; check = checks.Next() { + dump = append(dump, check.(*structs.HealthCheck)) + } + return dump, nil +} + +// KVSDump is used to pull the full list of KVS entries for use during snapshots. 
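+// The entries are read inside the snapshot's transaction, so the dump is point-in-time consistent.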
+func (s *StateSnapshot) KVSDump() (structs.DirEntries, error) { + entries, err := s.tx.Get("kvs", "id_prefix") + if err != nil { + return nil, fmt.Errorf("failed kvs lookup: %s", err) + } + + var dump structs.DirEntries + for entry := entries.Next(); entry != nil; entry = entries.Next() { + dump = append(dump, entry.(*structs.DirEntry)) + } + return dump, nil +} + +// TombstoneDump is used to pull all the tombstones from the graveyard. +func (s *StateSnapshot) TombstoneDump() ([]*Tombstone, error) { + return s.store.kvsGraveyard.DumpTxn(s.tx) +} + +// SessionDump is used to pull the full list of sessions for use during snapshots. +func (s *StateSnapshot) SessionDump() (structs.Sessions, error) { + sessions, err := s.tx.Get("sessions", "id") + if err != nil { + return nil, fmt.Errorf("failed session lookup: %s", err) + } + + var dump structs.Sessions + for session := sessions.Next(); session != nil; session = sessions.Next() { + dump = append(dump, session.(*structs.Session)) + } + return dump, nil +} + +// ACLDump is used to pull all the ACLs from the snapshot. +func (s *StateSnapshot) ACLDump() (structs.ACLs, error) { _, ret, err := aclListTxn(s.tx) return ret, err } @@ -132,7 +236,7 @@ func maxIndexTxn(tx *memdb.Txn, tables ...string) uint64 { for _, table := range tables { ti, err := tx.First("index", "id", table) if err != nil { - panic(fmt.Sprintf("unknown index: %s", table)) + panic(fmt.Sprintf("unknown index: %s err: %s", table, err)) } if idx, ok := ti.(*IndexEntry); ok && idx.Value > lindex { lindex = idx.Value @@ -144,21 +248,12 @@ func maxIndexTxn(tx *memdb.Txn, tables ...string) uint64 { // indexUpdateMaxTxn is used when restoring entries and sets the table's index to // the given idx only if it's greater than the current index. func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error { - raw, err := tx.First("index", "id", table) + ti, err := tx.First("index", "id", table) if err != nil { return fmt.Errorf("failed to retrieve existing index: %s", err) } - if raw == nil { - return fmt.Errorf("missing index for table %s", table) - } - - entry, ok := raw.(*IndexEntry) - if !ok { - return fmt.Errorf("unexpected index type for table %s", table) - } - - if idx > entry.Value { + if cur, ok := ti.(*IndexEntry); ok && idx > cur.Value { if err := tx.Insert("index", &IndexEntry{table, idx}); err != nil { return fmt.Errorf("failed updating index %s", err) } @@ -167,16 +262,69 @@ func indexUpdateMaxTxn(tx *memdb.Txn, idx uint64, table string) error { return nil } -// getWatchManager returns a watch manager for the given set of tables. The -// order of the tables is not important. -func (s *StateStore) GetWatchManager(tables ...string) WatchManager { - if len(tables) == 1 { - if manager, ok := s.watches[tables[0]]; ok { - return manager +// ReapTombstones is used to delete all the tombstones with an index +// less than or equal to the given index. This is used to prevent +// unbounded storage growth of the tombstones. +func (s *StateStore) ReapTombstones(index uint64) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.kvsGraveyard.ReapTxn(tx, index); err != nil { + return fmt.Errorf("failed to reap kvs tombstones: %s", err) + } + + tx.Commit() + return nil +} + +// GetTableWatch returns a watch for the given table. 
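+// The returned watch can be handed to blocking queries so they wake up whenever the table changes.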
+func (s *StateStore) GetTableWatch(table string) Watch { + if watch, ok := s.tableWatches[table]; ok { + return watch + } + + panic(fmt.Sprintf("Unknown watch for table %#s", table)) +} + +// GetKVSWatch returns a watch for the given prefix in the key value store. +func (s *StateStore) GetKVSWatch(prefix string) Watch { + return s.kvsWatch.GetSubwatch(prefix) +} + +// EnsureRegistration is used to make sure a node, service, and check +// registration is performed within a single transaction to avoid race +// conditions on state updates. +func (s *StateStore) EnsureRegistration(idx uint64, req *structs.RegisterRequest) error { + tx := s.db.Txn(true) + defer tx.Abort() + + // Add the node. + node := &structs.Node{Node: req.Node, Address: req.Address} + if err := s.ensureNodeTxn(tx, idx, node); err != nil { + return fmt.Errorf("failed inserting node: %s", err) + } + + // Add the service, if any. + if req.Service != nil { + if err := s.ensureServiceTxn(tx, idx, req.Node, req.Service); err != nil { + return fmt.Errorf("failed inserting service: %s", err) } } - panic(fmt.Sprintf("Unknown watch manager(s): %v", tables)) + // Add the checks, if any. + if req.Check != nil { + if err := s.ensureCheckTxn(tx, idx, req.Check); err != nil { + return fmt.Errorf("failed inserting check: %s", err) + } + } + for _, check := range req.Checks { + if err := s.ensureCheckTxn(tx, idx, check); err != nil { + return fmt.Errorf("failed inserting check: %s", err) + } + } + + tx.Commit() + return nil } // EnsureNode is used to upsert node registration or modification. @@ -185,7 +333,7 @@ func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { defer tx.Abort() // Call the node upsert - if err := ensureNodeTxn(tx, idx, node); err != nil { + if err := s.ensureNodeTxn(tx, idx, node); err != nil { return err } @@ -196,7 +344,7 @@ func (s *StateStore) EnsureNode(idx uint64, node *structs.Node) error { // ensureNodeTxn is the inner function called to actually create a node // registration or modify an existing one in the state store. It allows // passing in a memdb transaction so it may be part of a larger txn. -func ensureNodeTxn(tx *memdb.Txn, idx uint64, node *structs.Node) error { +func (s *StateStore) ensureNodeTxn(tx *memdb.Txn, idx uint64, node *structs.Node) error { // Check for an existing node existing, err := tx.First("nodes", "id", node.Node) if err != nil { @@ -219,6 +367,8 @@ func ensureNodeTxn(tx *memdb.Txn, idx uint64, node *structs.Node) error { if err := tx.Insert("index", &IndexEntry{"nodes", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } + + tx.Defer(func() { s.tableWatches["nodes"].Notify() }) return nil } @@ -269,7 +419,7 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error { defer tx.Abort() // Call the node deletion. - if err := deleteNodeTxn(tx, idx, nodeID); err != nil { + if err := s.deleteNodeTxn(tx, idx, nodeID); err != nil { return err } @@ -279,8 +429,8 @@ func (s *StateStore) DeleteNode(idx uint64, nodeID string) error { // deleteNodeTxn is the inner method used for removing a node from // the store within a given transaction. -func deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { - // Look up the node +func (s *StateStore) deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { + // Look up the node. 
node, err := tx.First("nodes", "id", nodeID) if err != nil { return fmt.Errorf("node lookup failed: %s", err) @@ -289,31 +439,36 @@ func deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { return nil } - // Delete all services associated with the node and update the service index + // Use a watch manager since the inner functions can perform multiple + // ops per table. + watches := NewDumbWatchManager(s.tableWatches) + watches.Arm("nodes") + + // Delete all services associated with the node and update the service index. services, err := tx.Get("services", "node", nodeID) if err != nil { return fmt.Errorf("failed service lookup: %s", err) } for service := services.Next(); service != nil; service = services.Next() { svc := service.(*structs.ServiceNode) - if err := deleteServiceTxn(tx, idx, nodeID, svc.ServiceID); err != nil { + if err := s.deleteServiceTxn(tx, idx, watches, nodeID, svc.ServiceID); err != nil { return err } } - // Delete all checks associated with the node and update the check index + // Delete all checks associated with the node and update the check index. checks, err := tx.Get("checks", "node", nodeID) if err != nil { return fmt.Errorf("failed check lookup: %s", err) } for check := checks.Next(); check != nil; check = checks.Next() { - chk := check.(*structs.HealthCheck) - if err := deleteCheckTxn(tx, idx, nodeID, chk.CheckID); err != nil { + hc := check.(*structs.HealthCheck) + if err := s.deleteCheckTxn(tx, idx, watches, nodeID, hc.CheckID); err != nil { return err } } - // Delete the node and update the index + // Delete the node and update the index. if err := tx.Delete("nodes", node); err != nil { return fmt.Errorf("failed deleting node: %s", err) } @@ -321,8 +476,19 @@ func deleteNodeTxn(tx *memdb.Txn, idx uint64, nodeID string) error { return fmt.Errorf("failed updating index: %s", err) } - // TODO: session invalidation - // TODO: watch trigger + // Invalidate any sessions for this node. + sessions, err := tx.Get("sessions", "node", nodeID) + if err != nil { + return fmt.Errorf("failed session lookup: %s", err) + } + for sess := sessions.Next(); sess != nil; sess = sessions.Next() { + session := sess.(*structs.Session).ID + if err := s.deleteSessionTxn(tx, idx, watches, session); err != nil { + return fmt.Errorf("failed session delete: %s", err) + } + } + + tx.Defer(func() { watches.Notify() }) return nil } @@ -332,7 +498,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer defer tx.Abort() // Call the service registration upsert - if err := ensureServiceTxn(tx, idx, node, svc); err != nil { + if err := s.ensureServiceTxn(tx, idx, node, svc); err != nil { return err } @@ -342,7 +508,7 @@ func (s *StateStore) EnsureService(idx uint64, node string, svc *structs.NodeSer // ensureServiceTxn is used to upsert a service registration within an // existing memdb transaction. 
-func ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error { +func (s *StateStore) ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeService) error { // Check for existing service existing, err := tx.First("services", "id", node, svc.Service) if err != nil { @@ -384,6 +550,8 @@ func ensureServiceTxn(tx *memdb.Txn, idx uint64, node string, svc *structs.NodeS if err := tx.Insert("index", &IndexEntry{"services", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } + + tx.Defer(func() { s.tableWatches["services"].Notify() }) return nil } @@ -448,17 +616,19 @@ func (s *StateStore) DeleteService(idx uint64, nodeID, serviceID string) error { defer tx.Abort() // Call the service deletion - if err := deleteServiceTxn(tx, idx, nodeID, serviceID); err != nil { + watches := NewDumbWatchManager(s.tableWatches) + if err := s.deleteServiceTxn(tx, idx, watches, nodeID, serviceID); err != nil { return err } + tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } // deleteServiceTxn is the inner method called to remove a service // registration within an existing transaction. -func deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeID, serviceID string) error { +func (s *StateStore) deleteServiceTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, nodeID, serviceID string) error { // Look up the service service, err := tx.First("services", "id", nodeID, serviceID) if err != nil { @@ -477,6 +647,7 @@ func deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeID, serviceID string) error if err := tx.Delete("checks", check); err != nil { return fmt.Errorf("failed deleting service check: %s", err) } + watches.Arm("checks") } if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) @@ -490,8 +661,7 @@ func deleteServiceTxn(tx *memdb.Txn, idx uint64, nodeID, serviceID string) error return fmt.Errorf("failed updating index: %s", err) } - // TODO: session invalidation - // TODO: watch trigger + watches.Arm("services") return nil } @@ -501,7 +671,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { defer tx.Abort() // Call the check registration - if err := ensureCheckTxn(tx, idx, hc); err != nil { + if err := s.ensureCheckTxn(tx, idx, hc); err != nil { return err } @@ -512,7 +682,7 @@ func (s *StateStore) EnsureCheck(idx uint64, hc *structs.HealthCheck) error { // ensureCheckTransaction is used as the inner method to handle inserting // a health check into the state store. It ensures safety against inserting // checks with no matching node or service. -func ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error { +func (s *StateStore) ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error { // Check if we have an existing health check existing, err := tx.First("checks", "id", hc.Node, hc.CheckID) if err != nil { @@ -557,9 +727,24 @@ func ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error { hc.ServiceName = service.(*structs.ServiceNode).ServiceName } - // TODO: invalidate sessions if status == critical + // Delete any sessions for this check if the health is critical. 
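+	// A session is invalidated as soon as any of its checks goes critical, so clean up any sessions registered against this check.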
+ if hc.Status == structs.HealthCritical { + mappings, err := tx.Get("session_checks", "node_check", hc.Node, hc.CheckID) + if err != nil { + return fmt.Errorf("failed session checks lookup: %s", err) + } - // Persist the check registration in the db + watches := NewDumbWatchManager(s.tableWatches) + for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { + session := mapping.(*sessionCheck).Session + if err := s.deleteSessionTxn(tx, idx, watches, session); err != nil { + return fmt.Errorf("failed deleting session: %s", err) + } + } + tx.Defer(func() { watches.Notify() }) + } + + // Persist the check registration in the db. if err := tx.Insert("checks", hc); err != nil { return fmt.Errorf("failed inserting service: %s", err) } @@ -567,8 +752,7 @@ func ensureCheckTxn(tx *memdb.Txn, idx uint64, hc *structs.HealthCheck) error { return fmt.Errorf("failed updating index: %s", err) } - // TODO: trigger watches - + tx.Defer(func() { s.tableWatches["checks"].Notify() }) return nil } @@ -615,12 +799,12 @@ func (s *StateStore) parseChecks(iter memdb.ResultIterator, err error) (uint64, // Track the highest index along the way. var results structs.HealthChecks var lindex uint64 - for hc := iter.Next(); hc != nil; hc = iter.Next() { - check := hc.(*structs.HealthCheck) - if check.ModifyIndex > lindex { - lindex = check.ModifyIndex + for check := iter.Next(); check != nil; check = iter.Next() { + hc := check.(*structs.HealthCheck) + if hc.ModifyIndex > lindex { + lindex = hc.ModifyIndex } - results = append(results, check) + results = append(results, hc) } return lindex, results, nil } @@ -631,36 +815,49 @@ func (s *StateStore) DeleteCheck(idx uint64, node, id string) error { defer tx.Abort() // Call the check deletion - if err := deleteCheckTxn(tx, idx, node, id); err != nil { + watches := NewDumbWatchManager(s.tableWatches) + if err := s.deleteCheckTxn(tx, idx, watches, node, id); err != nil { return err } + tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } // deleteCheckTxn is the inner method used to call a health // check deletion within an existing transaction. -func deleteCheckTxn(tx *memdb.Txn, idx uint64, node, id string) error { - // Try to retrieve the existing health check - check, err := tx.First("checks", "id", node, id) +func (s *StateStore) deleteCheckTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, node, id string) error { + // Try to retrieve the existing health check. + hc, err := tx.First("checks", "id", node, id) if err != nil { return fmt.Errorf("check lookup failed: %s", err) } - if check == nil { + if hc == nil { return nil } - // Delete the check from the DB and update the index - if err := tx.Delete("checks", check); err != nil { + // Delete the check from the DB and update the index. + if err := tx.Delete("checks", hc); err != nil { return fmt.Errorf("failed removing check: %s", err) } if err := tx.Insert("index", &IndexEntry{"checks", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } - // TODO: invalidate sessions - // TODO: watch triggers + // Delete any sessions for this check. 
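+	// Removing the check also invalidates any session that was registered against it.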
+ mappings, err := tx.Get("session_checks", "node_check", node, id) + if err != nil { + return fmt.Errorf("failed session checks lookup: %s", err) + } + for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { + session := mapping.(*sessionCheck).Session + if err := s.deleteSessionTxn(tx, idx, watches, session); err != nil { + return fmt.Errorf("failed deleting session: %s", err) + } + } + + watches.Arm("checks") return nil } @@ -816,11 +1013,11 @@ func (s *StateStore) parseNodes( return 0, nil, fmt.Errorf("failed node lookup: %s", err) } for check := checks.Next(); check != nil; check = checks.Next() { - chk := check.(*structs.HealthCheck) - if chk.ModifyIndex > lindex { - lindex = chk.ModifyIndex + hc := check.(*structs.HealthCheck) + if hc.ModifyIndex > lindex { + lindex = hc.ModifyIndex } - dump.Checks = append(dump.Checks, chk) + dump.Checks = append(dump.Checks, hc) } // Add the result to the slice @@ -833,13 +1030,19 @@ func (s *StateStore) parseNodes( func (s *StateStore) KVSSet(idx uint64, entry *structs.DirEntry) error { tx := s.db.Txn(true) defer tx.Abort() - return kvsSetTxn(tx, idx, entry) + + // Perform the actual set. + if err := s.kvsSetTxn(tx, idx, entry); err != nil { + return err + } + + tx.Commit() + return nil } // kvsSetTxn is used to insert or update a key/value pair in the state // store. It is the inner method used and handles only the actual storage. -func kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) error { - +func (s *StateStore) kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) error { // Retrieve an existing KV pair existing, err := tx.First("kvs", "id", entry.Key) if err != nil { @@ -863,7 +1066,7 @@ func kvsSetTxn(tx *memdb.Txn, idx uint64, entry *structs.DirEntry) error { return fmt.Errorf("failed updating index: %s", err) } - tx.Commit() + tx.Defer(func() { s.kvsWatch.Notify(entry.Key, false) }) return nil } @@ -905,6 +1108,15 @@ func (s *StateStore) KVSList(prefix string) (uint64, []string, error) { lindex = e.ModifyIndex } } + + // Check for the highest index in the graveyard. + gindex, err := s.kvsGraveyard.GetMaxIndexTxn(tx, prefix) + if err != nil { + return 0, nil, fmt.Errorf("failed graveyard lookup: %s", err) + } + if gindex > lindex { + lindex = gindex + } return lindex, keys, nil } @@ -956,6 +1168,15 @@ func (s *StateStore) KVSListKeys(prefix, sep string) (uint64, []string, error) { keys = append(keys, e.Key) } } + + // Check for the highest index in the graveyard. + gindex, err := s.kvsGraveyard.GetMaxIndexTxn(tx, prefix) + if err != nil { + return 0, nil, fmt.Errorf("failed graveyard lookup: %s", err) + } + if gindex > lindex { + lindex = gindex + } return lindex, keys, nil } @@ -966,7 +1187,7 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error { defer tx.Abort() // Perform the actual delete - if err := kvsDeleteTxn(tx, idx, key); err != nil { + if err := s.kvsDeleteTxn(tx, idx, key); err != nil { return err } @@ -976,8 +1197,8 @@ func (s *StateStore) KVSDelete(idx uint64, key string) error { // kvsDeleteTxn is the inner method used to perform the actual deletion // of a key/value pair within an existing transaction. -func kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error { - // Look up the entry in the state store +func (s *StateStore) kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error { + // Look up the entry in the state store. 
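+	// Deleting a missing key is a no-op and creates no tombstone.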
entry, err := tx.First("kvs", "id", key) if err != nil { return fmt.Errorf("failed kvs lookup: %s", err) @@ -986,13 +1207,20 @@ func kvsDeleteTxn(tx *memdb.Txn, idx uint64, key string) error { return nil } - // Delete the entry and update the index + // Create a tombstone. + if err := s.kvsGraveyard.InsertTxn(tx, key, idx); err != nil { + return fmt.Errorf("failed adding to graveyard: %s", err) + } + + // Delete the entry and update the index. if err := tx.Delete("kvs", entry); err != nil { return fmt.Errorf("failed deleting kvs entry: %s", err) } if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } + + tx.Defer(func() { s.kvsWatch.Notify(key, false) }) return nil } @@ -1004,7 +1232,7 @@ func (s *StateStore) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) { tx := s.db.Txn(true) defer tx.Abort() - // Retrieve the existing kvs entry, if any exists + // Retrieve the existing kvs entry, if any exists. entry, err := tx.First("kvs", "id", key) if err != nil { return false, fmt.Errorf("failed kvs lookup: %s", err) @@ -1018,8 +1246,8 @@ func (s *StateStore) KVSDeleteCAS(idx, cidx uint64, key string) (bool, error) { return entry == nil, nil } - // Call the actual deletion if the above passed - if err := kvsDeleteTxn(tx, idx, key); err != nil { + // Call the actual deletion if the above passed. + if err := s.kvsDeleteTxn(tx, idx, key); err != nil { return false, err } @@ -1035,7 +1263,7 @@ func (s *StateStore) KVSSetCAS(idx uint64, entry *structs.DirEntry) (bool, error tx := s.db.Txn(true) defer tx.Abort() - // Retrieve the existing entry + // Retrieve the existing entry. existing, err := tx.First("kvs", "id", entry.Key) if err != nil { return false, fmt.Errorf("failed kvs lookup: %s", err) @@ -1055,7 +1283,12 @@ func (s *StateStore) KVSSetCAS(idx uint64, entry *structs.DirEntry) (bool, error } // If we made it this far, we should perform the set. - return true, kvsSetTxn(tx, idx, entry) + if err := s.kvsSetTxn(tx, idx, entry); err != nil { + return false, err + } + + tx.Commit() + return true, nil } // KVSDeleteTree is used to do a recursive delete on a key prefix @@ -1065,18 +1298,23 @@ func (s *StateStore) KVSDeleteTree(idx uint64, prefix string) error { tx := s.db.Txn(true) defer tx.Abort() - // Get an iterator over all of the keys with the given prefix + // Get an iterator over all of the keys with the given prefix. entries, err := tx.Get("kvs", "id_prefix", prefix) if err != nil { return fmt.Errorf("failed kvs lookup: %s", err) } // Go over all of the keys and remove them. We call the delete - // directly so that we only update the index once. + // directly so that we only update the index once. We also add + // tombstones as we go. 
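+	// Only bump the table index and fire the prefix watch if something was actually deleted.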
var modified bool for entry := entries.Next(); entry != nil; entry = entries.Next() { - err := tx.Delete("kvs", entry.(*structs.DirEntry)) - if err != nil { + e := entry.(*structs.DirEntry) + if err := s.kvsGraveyard.InsertTxn(tx, e.Key, idx); err != nil { + return fmt.Errorf("failed adding to graveyard: %s", err) + } + + if err := tx.Delete("kvs", e); err != nil { return fmt.Errorf("failed deleting kvs entry: %s", err) } modified = true @@ -1084,6 +1322,7 @@ func (s *StateStore) KVSDeleteTree(idx uint64, prefix string) error { // Update the index if modified { + tx.Defer(func() { s.kvsWatch.Notify(prefix, true) }) if err := tx.Insert("index", &IndexEntry{"kvs", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } @@ -1093,13 +1332,146 @@ func (s *StateStore) KVSDeleteTree(idx uint64, prefix string) error { return nil } +// KVSLockDelay returns the expiration time for any lock delay associated with +// the given key. +func (s *StateStore) KVSLockDelay(key string) time.Time { + return s.lockDelay.GetExpiration(key) +} + +// KVSLock is similar to KVSSet but only performs the set if the lock can be +// acquired. +func (s *StateStore) KVSLock(idx uint64, entry *structs.DirEntry) (bool, error) { + tx := s.db.Txn(true) + defer tx.Abort() + + // Verify that a session is present. + if entry.Session == "" { + return false, fmt.Errorf("Missing session") + } + + // Verify that the session exists. + sess, err := tx.First("sessions", "id", entry.Session) + if err != nil { + return false, fmt.Errorf("failed session lookup: %s", err) + } + if sess == nil { + return false, fmt.Errorf("Invalid session %#v", entry.Session) + } + + // Retrieve the existing entry. + existing, err := tx.First("kvs", "id", entry.Key) + if err != nil { + return false, fmt.Errorf("failed kvs lookup: %s", err) + } + + // Set up the entry, using the existing entry if present. + if existing != nil { + // Bail if there's already a lock on this entry. + e := existing.(*structs.DirEntry) + if e.Session != "" { + return false, nil + } + + entry.CreateIndex = e.CreateIndex + entry.LockIndex = e.LockIndex + 1 + } else { + entry.CreateIndex = idx + entry.LockIndex = 1 + } + entry.ModifyIndex = idx + + // If we made it this far, we should perform the set. + if err := s.kvsSetTxn(tx, idx, entry); err != nil { + return false, err + } + + tx.Commit() + return true, nil +} + +// KVSUnlock is similar to KVSSet but only performs the set if the lock can be +// unlocked (the key must already exist and be locked). +func (s *StateStore) KVSUnlock(idx uint64, entry *structs.DirEntry) (bool, error) { + tx := s.db.Txn(true) + defer tx.Abort() + + // Verify that a session is present. + if entry.Session == "" { + return false, fmt.Errorf("Missing session") + } + + // Retrieve the existing entry. + existing, err := tx.First("kvs", "id", entry.Key) + if err != nil { + return false, fmt.Errorf("failed kvs lookup: %s", err) + } + + // Bail if there's no existing key. + if existing == nil { + return false, nil + } + + // Make sure the given session is the lock holder. + e := existing.(*structs.DirEntry) + if e.Session != entry.Session { + return false, nil + } + + // Clear the lock and update the entry. + entry.Session = "" + entry.LockIndex = e.LockIndex + entry.CreateIndex = e.CreateIndex + entry.ModifyIndex = idx + + // If we made it this far, we should perform the set. 
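+	// kvsSetTxn writes the entry, bumps the kvs table index, and queues a notification for the key's prefix watch.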
+ if err := s.kvsSetTxn(tx, idx, entry); err != nil { + return false, err + } + + tx.Commit() + return true, nil +} + +// KVSRestore is used when restoring from a snapshot. Use KVSSet for general +// inserts. +func (s *StateStore) KVSRestore(entry *structs.DirEntry) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := tx.Insert("kvs", entry); err != nil { + return fmt.Errorf("failed inserting kvs entry: %s", err) + } + + if err := indexUpdateMaxTxn(tx, entry.ModifyIndex, "kvs"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Defer(func() { s.kvsWatch.Notify(entry.Key, false) }) + tx.Commit() + return nil +} + +// Tombstone is used when restoring from a snapshot. For general inserts, use +// Graveyard.InsertTxn. +func (s *StateStore) TombstoneRestore(stone *Tombstone) error { + tx := s.db.Txn(true) + defer tx.Abort() + + if err := s.kvsGraveyard.RestoreTxn(tx, stone); err != nil { + return fmt.Errorf("failed restoring tombstone: %s", err) + } + + tx.Commit() + return nil +} + // SessionCreate is used to register a new session in the state store. func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error { tx := s.db.Txn(true) defer tx.Abort() // Call the session creation - if err := sessionCreateTxn(tx, idx, sess); err != nil { + if err := s.sessionCreateTxn(tx, idx, sess); err != nil { return err } @@ -1110,7 +1482,7 @@ func (s *StateStore) SessionCreate(idx uint64, sess *structs.Session) error { // sessionCreateTxn is the inner method used for creating session entries in // an open transaction. Any health checks registered with the session will be // checked for failing status. Returns any error encountered. -func sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error { +func (s *StateStore) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error { // Check that we have a session ID if sess.ID == "" { return ErrMissingSessionID @@ -1165,12 +1537,12 @@ func sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error { // Insert the check mappings for _, checkID := range sess.Checks { - check := &sessionCheck{ + mapping := &sessionCheck{ Node: sess.Node, CheckID: checkID, Session: sess.ID, } - if err := tx.Insert("session_checks", check); err != nil { + if err := tx.Insert("session_checks", mapping); err != nil { return fmt.Errorf("failed inserting session check mapping: %s", err) } } @@ -1179,6 +1551,8 @@ func sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error { if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } + + tx.Defer(func() { s.tableWatches["sessions"].Notify() }) return nil } @@ -1199,18 +1573,18 @@ func (s *StateStore) GetSession(sessionID string) (*structs.Session, error) { } // SessionList returns a slice containing all of the active sessions. -func (s *StateStore) SessionList() (uint64, []*structs.Session, error) { +func (s *StateStore) SessionList() (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() - // Query all of the active sessions + // Query all of the active sessions. sessions, err := tx.Get("sessions", "id") if err != nil { return 0, nil, fmt.Errorf("failed session lookup: %s", err) } - // Go over the sessions and create a slice of them - var result []*structs.Session + // Go over the sessions and create a slice of them. 
+ var result structs.Sessions var lindex uint64 for session := sessions.Next(); session != nil; session = sessions.Next() { sess := session.(*structs.Session) @@ -1227,7 +1601,7 @@ func (s *StateStore) SessionList() (uint64, []*structs.Session, error) { // NodeSessions returns a set of active sessions associated // with the given node ID. The returned index is the highest // index seen from the result set. -func (s *StateStore) NodeSessions(nodeID string) (uint64, []*structs.Session, error) { +func (s *StateStore) NodeSessions(nodeID string) (uint64, structs.Sessions, error) { tx := s.db.Txn(false) defer tx.Abort() @@ -1238,7 +1612,7 @@ func (s *StateStore) NodeSessions(nodeID string) (uint64, []*structs.Session, er } // Go over all of the sessions and return them as a slice - var result []*structs.Session + var result structs.Sessions var lindex uint64 for session := sessions.Next(); session != nil; session = sessions.Next() { sess := session.(*structs.Session) @@ -1259,19 +1633,21 @@ func (s *StateStore) SessionDestroy(idx uint64, sessionID string) error { tx := s.db.Txn(true) defer tx.Abort() - // Call the session deletion - if err := sessionDestroyTxn(tx, idx, sessionID); err != nil { + // Call the session deletion. + watches := NewDumbWatchManager(s.tableWatches) + if err := s.deleteSessionTxn(tx, idx, watches, sessionID); err != nil { return err } + tx.Defer(func() { watches.Notify() }) tx.Commit() return nil } -// sessionDestroyTxn is the inner method, which is used to do the actual +// deleteSessionTxn is the inner method, which is used to do the actual // session deletion and handle session invalidation, watch triggers, etc. -func sessionDestroyTxn(tx *memdb.Txn, idx uint64, sessionID string) error { - // Look up the session +func (s *StateStore) deleteSessionTxn(tx *memdb.Txn, idx uint64, watches *DumbWatchManager, sessionID string) error { + // Look up the session. sess, err := tx.First("sessions", "id", sessionID) if err != nil { return fmt.Errorf("failed session lookup: %s", err) @@ -1280,7 +1656,7 @@ func sessionDestroyTxn(tx *memdb.Txn, idx uint64, sessionID string) error { return nil } - // Delete the session and write the new index + // Delete the session and write the new index. if err := tx.Delete("sessions", sess); err != nil { return fmt.Errorf("failed deleting session: %s", err) } @@ -1288,8 +1664,104 @@ func sessionDestroyTxn(tx *memdb.Txn, idx uint64, sessionID string) error { return fmt.Errorf("failed updating index: %s", err) } - // TODO: invalidate session + // Enforce the max lock delay. + session := sess.(*structs.Session) + delay := session.LockDelay + if delay > structs.MaxLockDelay { + delay = structs.MaxLockDelay + } + // Snag the current now time so that all the expirations get calculated + // the same way. + now := time.Now() + + // Get an iterator over all of the keys with the given session. + entries, err := tx.Get("kvs", "session", sessionID) + if err != nil { + return fmt.Errorf("failed kvs lookup: %s", err) + } + + // TODO (slackpad) The operations below use the common inner functions + // that look up the entry by key; we could optimize this by splitting + // the inner functions and passing the entry we already have, though it + // makes the code a little more complex. + + // Invalidate any held locks. 
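+	// The session's behavior setting determines whether its keys are merely released or deleted outright.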
+ switch session.Behavior { + case structs.SessionKeysRelease: + for entry := entries.Next(); entry != nil; entry = entries.Next() { + e := entry.(*structs.DirEntry).Clone() + e.Session = "" + if err := s.kvsSetTxn(tx, idx, e); err != nil { + return fmt.Errorf("failed kvs update: %s", err) + } + + // Apply the lock delay if present. + if delay > 0 { + s.lockDelay.SetExpiration(e.Key, now, delay) + } + } + case structs.SessionKeysDelete: + for entry := entries.Next(); entry != nil; entry = entries.Next() { + e := entry.(*structs.DirEntry) + if err := s.kvsDeleteTxn(tx, idx, e.Key); err != nil { + return fmt.Errorf("failed kvs delete: %s", err) + } + + // Apply the lock delay if present. + if delay > 0 { + s.lockDelay.SetExpiration(e.Key, now, delay) + } + } + default: + return fmt.Errorf("unknown session behavior %#v", session.Behavior) + } + + // Delete any check mappings. + mappings, err := tx.Get("session_checks", "session", sessionID) + if err != nil { + return fmt.Errorf("failed session checks lookup: %s", err) + } + for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() { + if err := tx.Delete("session_checks", mapping); err != nil { + return fmt.Errorf("failed deleting session check: %s", err) + } + } + + watches.Arm("sessions") + return nil +} + +// SessionRestore is used when restoring from a snapshot. For general inserts, +// use SessionCreate. +func (s *StateStore) SessionRestore(sess *structs.Session) error { + tx := s.db.Txn(false) + defer tx.Abort() + + // Insert the session + if err := tx.Insert("sessions", sess); err != nil { + return fmt.Errorf("failed inserting session: %s", err) + } + + // Insert the check mappings + for _, checkID := range sess.Checks { + mapping := &sessionCheck{ + Node: sess.Node, + CheckID: checkID, + Session: sess.ID, + } + if err := tx.Insert("session_checks", mapping); err != nil { + return fmt.Errorf("failed inserting session check mapping: %s", err) + } + } + + // Update the index + if err := indexUpdateMaxTxn(tx, sess.ModifyIndex, "sessions"); err != nil { + return fmt.Errorf("failed updating index: %s", err) + } + + tx.Defer(func() { s.tableWatches["sessions"].Notify() }) + tx.Commit() return nil } @@ -1299,18 +1771,17 @@ func (s *StateStore) ACLSet(idx uint64, acl *structs.ACL) error { defer tx.Abort() // Call set on the ACL - if err := aclSetTxn(tx, idx, acl); err != nil { + if err := s.aclSetTxn(tx, idx, acl); err != nil { return err } - tx.Defer(func() { s.GetWatchManager("acls").Notify() }) tx.Commit() return nil } // aclSetTxn is the inner method used to insert an ACL rule with the // proper indexes into the state store. -func aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error { +func (s *StateStore) aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error { // Check that the ID is set if acl.ID == "" { return ErrMissingACLID @@ -1338,6 +1809,8 @@ func aclSetTxn(tx *memdb.Txn, idx uint64, acl *structs.ACL) error { if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil { return fmt.Errorf("failed updating index: %s", err) } + + tx.Defer(func() { s.tableWatches["acls"].Notify() }) return nil } @@ -1358,14 +1831,15 @@ func (s *StateStore) ACLGet(aclID string) (*structs.ACL, error) { } // ACLList is used to list out all of the ACLs in the state store. 
-func (s *StateStore) ACLList() (uint64, []*structs.ACL, error) {
+func (s *StateStore) ACLList() (uint64, structs.ACLs, error) {
 	tx := s.db.Txn(false)
 	defer tx.Abort()
 	return aclListTxn(tx)
 }
 
-// aclListTxn is used to list out all of the ACLs in the state store.
-func aclListTxn(tx *memdb.Txn) (uint64, []*structs.ACL, error) {
+// aclListTxn is used to list out all of the ACLs in the state store. This is a
+// function vs. a method so it can be called from the snapshotter.
+func aclListTxn(tx *memdb.Txn) (uint64, structs.ACLs, error) {
 	// Query all of the ACLs in the state store
 	acls, err := tx.Get("acls", "id")
 	if err != nil {
@@ -1373,7 +1847,7 @@ func aclListTxn(tx *memdb.Txn) (uint64, []*structs.ACL, error) {
 	}
 
 	// Go over all of the ACLs and build the response
-	var result []*structs.ACL
+	var result structs.ACLs
 	var lindex uint64
 	for acl := acls.Next(); acl != nil; acl = acls.Next() {
 		a := acl.(*structs.ACL)
@@ -1394,18 +1868,17 @@
 	defer tx.Abort()
 
 	// Call the ACL delete
-	if err := aclDeleteTxn(tx, idx, aclID); err != nil {
+	if err := s.aclDeleteTxn(tx, idx, aclID); err != nil {
 		return err
 	}
 
-	tx.Defer(func() { s.GetWatchManager("acls").Notify() })
 	tx.Commit()
 	return nil
 }
 
 // aclDeleteTxn is used to delete an ACL from the state store within
 // an existing transaction.
-func aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error {
+func (s *StateStore) aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error {
 	// Look up the existing ACL
 	acl, err := tx.First("acls", "id", aclID)
 	if err != nil {
@@ -1422,6 +1895,8 @@ func aclDeleteTxn(tx *memdb.Txn, idx uint64, aclID string) error {
 	if err := tx.Insert("index", &IndexEntry{"acls", idx}); err != nil {
 		return fmt.Errorf("failed updating index: %s", err)
 	}
+
+	tx.Defer(func() { s.tableWatches["acls"].Notify() })
 	return nil
 }
 
@@ -1436,10 +1911,10 @@ func (s *StateStore) ACLRestore(acl *structs.ACL) error {
 	}
 
 	if err := indexUpdateMaxTxn(tx, acl.ModifyIndex, "acls"); err != nil {
-		return err
+		return fmt.Errorf("failed updating index: %s", err)
 	}
 
-	tx.Defer(func() { s.GetWatchManager("acls").Notify() })
+	tx.Defer(func() { s.tableWatches["acls"].Notify() })
 	tx.Commit()
 	return nil
 }
diff --git a/consul/state/state_store_test.go b/consul/state/state_store_test.go
index 2eea4fb7a..e944e9bfd 100644
--- a/consul/state/state_store_test.go
+++ b/consul/state/state_store_test.go
@@ -1140,7 +1140,7 @@ func TestStateStore_KVSDelete(t *testing.T) {
 		t.Fatalf("bad index: %d", idx)
 	}
 
-	// Deleting a nonexistent key should be idempotent and note return an
+	// Deleting a nonexistent key should be idempotent and not return an
 	// error
 	if err := s.KVSDelete(4, "foo"); err != nil {
 		t.Fatalf("err: %s", err)
@@ -1519,7 +1519,7 @@ func TestStateStore_SessionList(t *testing.T) {
 	testRegisterNode(t, s, 3, "node3")
 
 	// Create some sessions in the state store
-	sessions := []*structs.Session{
+	sessions := structs.Sessions{
 		&structs.Session{
 			ID:   "session1",
 			Node: "node1",
@@ -1569,7 +1569,7 @@ func TestStateStore_NodeSessions(t *testing.T) {
 	testRegisterNode(t, s, 2, "node2")
 
 	// Register some sessions with the nodes
-	sessions1 := []*structs.Session{
+	sessions1 := structs.Sessions{
 		&structs.Session{
 			ID:   "session1",
 			Node: "node1",
@@ -1758,7 +1758,7 @@ func TestStateStore_ACLList(t *testing.T) {
 	}
 
 	// Insert some ACLs
-	acls := []*structs.ACL{
+	acls := structs.ACLs{
 		&structs.ACL{
 			ID:   "acl1",
 			Type: structs.ACLTypeClient,
@@ -1839,7 +1839,7 @@ func TestStateStore_ACL_Watches(t *testing.T) {
 	s := testStateStore(t)
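+	// The ACL table watch should fire once for each of the set, delete,
+	// and restore operations below.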
 
 	ch := make(chan struct{})
-	s.GetWatchManager("acls").Start(ch)
+	s.GetTableWatch("acls").Wait(ch)
 	go func() {
 		if err := s.ACLSet(1, &structs.ACL{ID: "acl1"}); err != nil {
 			t.Fatalf("err: %s", err)
@@ -1851,7 +1851,7 @@ func TestStateStore_ACL_Watches(t *testing.T) {
 		t.Fatalf("watch was not notified")
 	}
 
-	s.GetWatchManager("acls").Start(ch)
+	s.GetTableWatch("acls").Wait(ch)
 	go func() {
 		if err := s.ACLDelete(2, "acl1"); err != nil {
 			t.Fatalf("err: %s", err)
@@ -1863,7 +1863,7 @@ func TestStateStore_ACL_Watches(t *testing.T) {
 		t.Fatalf("watch was not notified")
 	}
 
-	s.GetWatchManager("acls").Start(ch)
+	s.GetTableWatch("acls").Wait(ch)
 	go func() {
 		if err := s.ACLRestore(&structs.ACL{ID: "acl1"}); err != nil {
 			t.Fatalf("err: %s", err)
diff --git a/consul/state/watch.go b/consul/state/watch.go
index 304c42ea2..04b071f26 100644
--- a/consul/state/watch.go
+++ b/consul/state/watch.go
@@ -1,35 +1,132 @@
 package state
 
 import (
-	"github.com/hashicorp/go-memdb"
+	"sync"
+
+	"github.com/armon/go-radix"
 )
 
-type WatchManager interface {
-	Start(notifyCh chan struct{})
-	Stop(notifyCh chan struct{})
-	Notify()
+// Watch is the external interface that's common to all the different flavors.
+type Watch interface {
+	// Wait registers the given channel and calls it back when the watch
+	// fires.
+	Wait(notifyCh chan struct{})
+
+	// Clear deregisters the given channel.
+	Clear(notifyCh chan struct{})
 }
 
+// FullTableWatch implements a single notify group for a table.
 type FullTableWatch struct {
-	notify NotifyGroup
+	group NotifyGroup
 }
 
-func (w *FullTableWatch) Start(notifyCh chan struct{}) {
-	w.notify.Wait(notifyCh)
+// NewFullTableWatch returns a new full table watch.
+func NewFullTableWatch() *FullTableWatch {
+	return &FullTableWatch{}
 }
 
-func (w *FullTableWatch) Stop(notifyCh chan struct{}) {
-	w.notify.Clear(notifyCh)
+// See Watch.
+func (w *FullTableWatch) Wait(notifyCh chan struct{}) {
+	w.group.Wait(notifyCh)
 }
 
+// See Watch.
+func (w *FullTableWatch) Clear(notifyCh chan struct{}) {
+	w.group.Clear(notifyCh)
+}
+
+// Notify wakes up all the watchers registered for this table.
 func (w *FullTableWatch) Notify() {
-	w.notify.Notify()
+	w.group.Notify()
 }
 
-func newWatchManagers(schema *memdb.DBSchema) (map[string]WatchManager, error) {
-	watches := make(map[string]WatchManager)
-	for table, _ := range schema.Tables {
-		watches[table] = &FullTableWatch{}
-	}
-	return watches, nil
+// DumbWatchManager is a wrapper that allows nested code to arm full table
+// watches multiple times but fire them only once.
+type DumbWatchManager struct {
+	tableWatches map[string]*FullTableWatch
+	armed        map[string]bool
+}
+
+// NewDumbWatchManager returns a new dumb watch manager.
+func NewDumbWatchManager(tableWatches map[string]*FullTableWatch) *DumbWatchManager {
+	return &DumbWatchManager{
+		tableWatches: tableWatches,
+		armed:        make(map[string]bool),
+	}
+}
+
+// Arm arms the given table's watch.
+func (d *DumbWatchManager) Arm(table string) {
+	if _, ok := d.armed[table]; !ok {
+		d.armed[table] = true
+	}
+}
+
+// Notify fires watches for all the armed tables.
+func (d *DumbWatchManager) Notify() {
+	for table, _ := range d.armed {
+		d.tableWatches[table].Notify()
+	}
+}
+
+// PrefixWatch maintains a notify group for each prefix, allowing for much more
+// fine-grained watches.
+type PrefixWatch struct {
+	// watches has the set of notify groups, organized by prefix.
+	watches *radix.Tree
+
+	// lock protects the watches tree.
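+	// The underlying go-radix tree is not safe for concurrent use on its
+	// own, so every read and write of the tree holds this lock.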
+	lock sync.Mutex
+}
+
+// NewPrefixWatch returns a new prefix watch.
+func NewPrefixWatch() *PrefixWatch {
+	return &PrefixWatch{watches: radix.New()}
+}
+
+// GetSubwatch returns the notify group for the given prefix.
+func (w *PrefixWatch) GetSubwatch(prefix string) *NotifyGroup {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+
+	if raw, ok := w.watches.Get(prefix); ok {
+		return raw.(*NotifyGroup)
+	}
+
+	group := &NotifyGroup{}
+	w.watches.Insert(prefix, group)
+	return group
+}
+
+// Notify wakes up all the watchers associated with the given prefix. If subtree
+// is true then we will also notify all the tree under the prefix, such as when
+// a key is being deleted.
+func (w *PrefixWatch) Notify(prefix string, subtree bool) {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+
+	var cleanup []string
+	fn := func(k string, v interface{}) bool {
+		group := v.(*NotifyGroup)
+		group.Notify()
+		if k != "" {
+			cleanup = append(cleanup, k)
+		}
+		return false
+	}
+
+	// Invoke any watcher on the path downward to the key.
+	w.watches.WalkPath(prefix, fn)
+
+	// If the entire prefix may be affected (e.g. delete tree),
+	// invoke the entire prefix.
+	if subtree {
+		w.watches.WalkPrefix(prefix, fn)
+	}
+
+	// Delete the old notify groups.
+	for i := len(cleanup) - 1; i >= 0; i-- {
+		w.watches.Delete(cleanup[i])
+	}
 }
diff --git a/consul/structs/structs.go b/consul/structs/structs.go
index cfc455e0f..185caa680 100644
--- a/consul/structs/structs.go
+++ b/consul/structs/structs.go
@@ -356,6 +356,22 @@ type DirEntry struct {
 	RaftIndex
 }
 
+
+// Returns a clone of the given directory entry.
+func (d *DirEntry) Clone() *DirEntry {
+	return &DirEntry{
+		LockIndex: d.LockIndex,
+		Key:       d.Key,
+		Flags:     d.Flags,
+		Value:     d.Value,
+		Session:   d.Session,
+		RaftIndex: RaftIndex{
+			CreateIndex: d.CreateIndex,
+			ModifyIndex: d.ModifyIndex,
+		},
+	}
+}
+
 type DirEntries []*DirEntry
 
 type KVSOp string
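
A short out-of-tree sketch of how the new watch primitives compose may help reviewers. This is illustrative only, not part of the patch; it assumes the consul/state package exports the NotifyGroup type referenced by watch.go (with its usual non-blocking Notify) and uses the consul/consul/state import path from this tree:

package main

import (
	"fmt"

	"github.com/hashicorp/consul/consul/state"
)

func main() {
	// One FullTableWatch per table, mirroring what the state store builds
	// for its schema tables.
	tableWatches := map[string]*state.FullTableWatch{
		"sessions": state.NewFullTableWatch(),
		"acls":     state.NewFullTableWatch(),
	}

	// A blocking query registers its channel on the table it reads.
	ch := make(chan struct{}, 1)
	tableWatches["sessions"].Wait(ch)

	// Write paths arm tables through a DumbWatchManager; arming the same
	// table twice still notifies it only once.
	watches := state.NewDumbWatchManager(tableWatches)
	watches.Arm("sessions")
	watches.Arm("sessions")
	watches.Notify()

	select {
	case <-ch:
		fmt.Println("sessions watch fired")
	default:
		fmt.Println("no notification")
	}

	// Prefix watches wake only the waiters registered under an affected
	// KV prefix.
	kv := state.NewPrefixWatch()
	kvCh := make(chan struct{}, 1)
	kv.GetSubwatch("foo/").Wait(kvCh)
	kv.Notify("foo/bar", false) // fires the "foo/" subwatch
}

The pattern in the patch is that the inner *Txn helpers only Arm tables, while the outer public methods defer a single watches.Notify() via tx.Defer, so a watch only fires after the transaction has committed.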