This commit is contained in:
Alex Dadgar 2017-02-03 16:28:27 -08:00
parent e6863e4c01
commit 94263b9648
12 changed files with 869 additions and 107 deletions

View file

@ -2,6 +2,7 @@ package iradix
import (
"bytes"
"strings"
"github.com/hashicorp/golang-lru/simplelru"
)
@ -11,7 +12,9 @@ const (
// cache used per transaction. This is used to cache the updates
// to the nodes near the root, while the leaves do not need to be
// cached. This is important for very large transactions to prevent
// the modified cache from growing to be enormous.
// the modified cache from growing to be enormous. This is also used
// to set the max size of the mutation notify maps since those should
// also be bounded in a similar way.
defaultModifiedCache = 8192
)
@ -27,7 +30,11 @@ type Tree struct {
// New returns an empty Tree
func New() *Tree {
t := &Tree{root: &Node{}}
t := &Tree{
root: &Node{
mutateCh: make(chan struct{}),
},
}
return t
}
@ -40,75 +47,148 @@ func (t *Tree) Len() int {
// atomically and returns a new tree when committed. A transaction
// is not thread safe, and should only be used by a single goroutine.
type Txn struct {
// root is the modified root for the transaction.
root *Node
// snap is a snapshot of the root node for use if we have to run the
// slow notify algorithm.
snap *Node
// size tracks the size of the tree as it is modified during the
// transaction.
size int
modified *simplelru.LRU
// writable is a cache of writable nodes that have been created during
// the course of the transaction. This allows us to re-use the same
// nodes for further writes and avoid unnecessary copies of nodes that
// have never been exposed outside the transaction. This will only hold
// up to defaultModifiedCache number of entries.
writable *simplelru.LRU
// trackChannels is used to hold channels that need to be notified to
// signal mutation of the tree. This will only hold up to
// defaultModifiedCache number of entries, after which we will set the
// trackOverflow flag, which will cause us to use a more expensive
// algorithm to perform the notifications. Mutation tracking is only
// performed if trackMutate is true.
trackChannels map[*chan struct{}]struct{}
trackOverflow bool
trackMutate bool
}
// Txn starts a new transaction that can be used to mutate the tree
func (t *Tree) Txn() *Txn {
txn := &Txn{
root: t.root,
snap: t.root,
size: t.size,
}
return txn
}
// writeNode returns a node to be modified, if the current
// node as already been modified during the course of
// the transaction, it is used in-place.
func (t *Txn) writeNode(n *Node) *Node {
// Ensure the modified set exists
if t.modified == nil {
// TrackMutate can be used to toggle if mutations are tracked. If this is enabled
// then notifications will be issued for affected internal nodes and leaves when
// the transaction is committed.
func (t *Txn) TrackMutate(track bool) {
t.trackMutate = track
}
// trackChannel safely attempts to track the given mutation channel, setting the
// overflow flag if we can no longer track any more. This limits the amount of
// state that will accumulate during a transaction and we have a slower algorithm
// to switch to if we overflow.
func (t *Txn) trackChannel(ch *chan struct{}) {
// In overflow, make sure we don't store any more objects.
if t.trackOverflow {
return
}
// Create the map on the fly when we need it.
if t.trackChannels == nil {
t.trackChannels = make(map[*chan struct{}]struct{})
}
// If this would overflow the state we reject it and set the flag (since
// we aren't tracking everything that's required any longer).
if len(t.trackChannels) >= defaultModifiedCache {
t.trackOverflow = true
return
}
// Otherwise we are good to track it.
t.trackChannels[ch] = struct{}{}
}
// writeNode returns a node to be modified, if the current node has already been
// modified during the course of the transaction, it is used in-place. Set
// forLeafUpdate to true if you are getting a write node to update the leaf,
// which will set leaf mutation tracking appropriately as well.
func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node {
// Ensure the writable set exists.
if t.writable == nil {
lru, err := simplelru.NewLRU(defaultModifiedCache, nil)
if err != nil {
panic(err)
}
t.modified = lru
t.writable = lru
}
// If this node has already been modified, we can
// continue to use it during this transaction.
if _, ok := t.modified.Get(n); ok {
// If this node has already been modified, we can continue to use it
// during this transaction. If a node gets kicked out of cache then we
// *may* notify for its mutation if we end up copying the node again,
// but we don't make any guarantees about notifying for intermediate
// mutations that were never exposed outside of a transaction.
if _, ok := t.writable.Get(n); ok {
return n
}
// Copy the existing node
nc := new(Node)
// Mark this node as being mutated.
if t.trackMutate {
t.trackChannel(&(n.mutateCh))
}
// Mark its leaf as being mutated, if appropriate.
if t.trackMutate && forLeafUpdate && n.leaf != nil {
t.trackChannel(&(n.leaf.mutateCh))
}
// Copy the existing node.
nc := &Node{
mutateCh: make(chan struct{}),
leaf: n.leaf,
}
if n.prefix != nil {
nc.prefix = make([]byte, len(n.prefix))
copy(nc.prefix, n.prefix)
}
if n.leaf != nil {
nc.leaf = new(leafNode)
*nc.leaf = *n.leaf
}
if len(n.edges) != 0 {
nc.edges = make([]edge, len(n.edges))
copy(nc.edges, n.edges)
}
// Mark this node as modified
t.modified.Add(nc, nil)
// Mark this node as writable.
t.writable.Add(nc, nil)
return nc
}
// insert does a recursive insertion
func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface{}, bool) {
// Handle key exhaution
// Handle key exhaustion
if len(search) == 0 {
nc := t.writeNode(n)
var oldVal interface{}
didUpdate := false
if n.isLeaf() {
old := nc.leaf.val
nc.leaf.val = v
return nc, old, true
} else {
oldVal = n.leaf.val
didUpdate = true
}
nc := t.writeNode(n, true)
nc.leaf = &leafNode{
mutateCh: make(chan struct{}),
key: k,
val: v,
}
return nc, nil, false
}
return nc, oldVal, didUpdate
}
// Look for the edge
@ -119,14 +199,16 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
e := edge{
label: search[0],
node: &Node{
mutateCh: make(chan struct{}),
leaf: &leafNode{
mutateCh: make(chan struct{}),
key: k,
val: v,
},
prefix: search,
},
}
nc := t.writeNode(n)
nc := t.writeNode(n, false)
nc.addEdge(e)
return nc, nil, false
}
@ -137,7 +219,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
search = search[commonPrefix:]
newChild, oldVal, didUpdate := t.insert(child, k, search, v)
if newChild != nil {
nc := t.writeNode(n)
nc := t.writeNode(n, false)
nc.edges[idx].node = newChild
return nc, oldVal, didUpdate
}
@ -145,8 +227,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
}
// Split the node
nc := t.writeNode(n)
nc := t.writeNode(n, false)
splitNode := &Node{
mutateCh: make(chan struct{}),
prefix: search[:commonPrefix],
}
nc.replaceEdge(edge{
@ -155,7 +238,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
})
// Restore the existing child node
modChild := t.writeNode(child)
modChild := t.writeNode(child, false)
splitNode.addEdge(edge{
label: modChild.prefix[commonPrefix],
node: modChild,
@ -164,6 +247,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
// Create a new leaf node
leaf := &leafNode{
mutateCh: make(chan struct{}),
key: k,
val: v,
}
@ -179,6 +263,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
splitNode.addEdge(edge{
label: search[0],
node: &Node{
mutateCh: make(chan struct{}),
leaf: leaf,
prefix: search,
},
@ -188,14 +273,14 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
// delete does a recursive deletion
func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) {
// Check for key exhaution
// Check for key exhaustion
if len(search) == 0 {
if !n.isLeaf() {
return nil, nil
}
// Remove the leaf node
nc := t.writeNode(n)
nc := t.writeNode(n, true)
nc.leaf = nil
// Check if this node should be merged
@ -219,8 +304,11 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) {
return nil, nil
}
// Copy this node
nc := t.writeNode(n)
// Copy this node. WATCH OUT - it's safe to pass "false" here because we
// will only ADD a leaf via nc.mergeChilde() if there isn't one due to
// the !nc.isLeaf() check in the logic just below. This is pretty subtle,
// so be careful if you change any of the logic here.
nc := t.writeNode(n, false)
// Delete the edge if the node has no edges
if newChild.leaf == nil && len(newChild.edges) == 0 {
@ -274,10 +362,109 @@ func (t *Txn) Get(k []byte) (interface{}, bool) {
return t.root.Get(k)
}
// Commit is used to finalize the transaction and return a new tree
// GetWatch is used to lookup a specific key, returning
// the watch channel, value and if it was found
func (t *Txn) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) {
return t.root.GetWatch(k)
}
// Commit is used to finalize the transaction and return a new tree. If mutation
// tracking is turned on then notifications will also be issued.
func (t *Txn) Commit() *Tree {
t.modified = nil
return &Tree{t.root, t.size}
nt := t.commit()
if t.trackMutate {
t.notify()
}
return nt
}
// commit is an internal helper for Commit(), useful for unit tests.
func (t *Txn) commit() *Tree {
nt := &Tree{t.root, t.size}
t.writable = nil
return nt
}
// slowNotify does a complete comparison of the before and after trees in order
// to trigger notifications. This doesn't require any additional state but it
// is very expensive to compute.
func (t *Txn) slowNotify() {
snapIter := t.snap.rawIterator()
rootIter := t.root.rawIterator()
for snapIter.Front() != nil || rootIter.Front() != nil {
// If we've exhausted the nodes in the old snapshot, we know
// there's nothing remaining to notify.
if snapIter.Front() == nil {
return
}
snapElem := snapIter.Front()
// If we've exhausted the nodes in the new root, we know we need
// to invalidate everything that remains in the old snapshot. We
// know from the loop condition there's something in the old
// snapshot.
if rootIter.Front() == nil {
close(snapElem.mutateCh)
if snapElem.isLeaf() {
close(snapElem.leaf.mutateCh)
}
snapIter.Next()
continue
}
// Do one string compare so we can check the various conditions
// below without repeating the compare.
cmp := strings.Compare(snapIter.Path(), rootIter.Path())
// If the snapshot is behind the root, then we must have deleted
// this node during the transaction.
if cmp < 0 {
close(snapElem.mutateCh)
if snapElem.isLeaf() {
close(snapElem.leaf.mutateCh)
}
snapIter.Next()
continue
}
// If the snapshot is ahead of the root, then we must have added
// this node during the transaction.
if cmp > 0 {
rootIter.Next()
continue
}
// If we have the same path, then we need to see if we mutated a
// node and possibly the leaf.
rootElem := rootIter.Front()
if snapElem != rootElem {
close(snapElem.mutateCh)
if snapElem.leaf != nil && (snapElem.leaf != rootElem.leaf) {
close(snapElem.leaf.mutateCh)
}
}
snapIter.Next()
rootIter.Next()
}
}
// notify is used along with TrackMutate to trigger notifications. This should
// only be done once a transaction is committed.
func (t *Txn) notify() {
// If we've overflowed the tracking state we can't use it in any way and
// need to do a full tree compare.
if t.trackOverflow {
t.slowNotify()
} else {
for ch := range t.trackChannels {
close(*ch)
}
}
// Clean up the tracking state so that a re-notify is safe (will trigger
// the else clause above which will be a no-op).
t.trackChannels = nil
t.trackOverflow = false
}
// Insert is used to add or update a given key. The return provides

View file

@ -9,11 +9,13 @@ type Iterator struct {
stack []edges
}
// SeekPrefix is used to seek the iterator to a given prefix
func (i *Iterator) SeekPrefix(prefix []byte) {
// SeekPrefixWatch is used to seek the iterator to a given prefix
// and returns the watch channel of the finest granularity
func (i *Iterator) SeekPrefixWatch(prefix []byte) (watch <-chan struct{}) {
// Wipe the stack
i.stack = nil
n := i.node
watch = n.mutateCh
search := prefix
for {
// Check for key exhaution
@ -29,6 +31,9 @@ func (i *Iterator) SeekPrefix(prefix []byte) {
return
}
// Update to the finest granularity as the search makes progress
watch = n.mutateCh
// Consume the search prefix
if bytes.HasPrefix(search, n.prefix) {
search = search[len(n.prefix):]
@ -43,6 +48,11 @@ func (i *Iterator) SeekPrefix(prefix []byte) {
}
}
// SeekPrefix is used to seek the iterator to a given prefix
func (i *Iterator) SeekPrefix(prefix []byte) {
i.SeekPrefixWatch(prefix)
}
// Next returns the next node in order
func (i *Iterator) Next() ([]byte, interface{}, bool) {
// Initialize our stack if needed

View file

@ -12,6 +12,7 @@ type WalkFn func(k []byte, v interface{}) bool
// leafNode is used to represent a value
type leafNode struct {
mutateCh chan struct{}
key []byte
val interface{}
}
@ -24,6 +25,9 @@ type edge struct {
// Node is an immutable node in the radix tree
type Node struct {
// mutateCh is closed if this node is modified
mutateCh chan struct{}
// leaf is used to store possible leaf
leaf *leafNode
@ -105,13 +109,14 @@ func (n *Node) mergeChild() {
}
}
func (n *Node) Get(k []byte) (interface{}, bool) {
func (n *Node) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) {
search := k
watch := n.mutateCh
for {
// Check for key exhaution
// Check for key exhaustion
if len(search) == 0 {
if n.isLeaf() {
return n.leaf.val, true
return n.leaf.mutateCh, n.leaf.val, true
}
break
}
@ -122,6 +127,9 @@ func (n *Node) Get(k []byte) (interface{}, bool) {
break
}
// Update to the finest granularity as the search makes progress
watch = n.mutateCh
// Consume the search prefix
if bytes.HasPrefix(search, n.prefix) {
search = search[len(n.prefix):]
@ -129,7 +137,12 @@ func (n *Node) Get(k []byte) (interface{}, bool) {
break
}
}
return nil, false
return watch, nil, false
}
func (n *Node) Get(k []byte) (interface{}, bool) {
_, val, ok := n.GetWatch(k)
return val, ok
}
// LongestPrefix is like Get, but instead of an
@ -204,6 +217,14 @@ func (n *Node) Iterator() *Iterator {
return &Iterator{node: n}
}
// rawIterator is used to return a raw iterator at the given node to walk the
// tree.
func (n *Node) rawIterator() *rawIterator {
iter := &rawIterator{node: n}
iter.Next()
return iter
}
// Walk is used to walk the tree
func (n *Node) Walk(fn WalkFn) {
recursiveWalk(n, fn)

View file

@ -0,0 +1,78 @@
package iradix
// rawIterator visits each of the nodes in the tree, even the ones that are not
// leaves. It keeps track of the effective path (what a leaf at a given node
// would be called), which is useful for comparing trees.
type rawIterator struct {
// node is the starting node in the tree for the iterator.
node *Node
// stack keeps track of edges in the frontier.
stack []rawStackEntry
// pos is the current position of the iterator.
pos *Node
// path is the effective path of the current iterator position,
// regardless of whether the current node is a leaf.
path string
}
// rawStackEntry is used to keep track of the cumulative common path as well as
// its associated edges in the frontier.
type rawStackEntry struct {
path string
edges edges
}
// Front returns the current node that has been iterated to.
func (i *rawIterator) Front() *Node {
return i.pos
}
// Path returns the effective path of the current node, even if it's not actually
// a leaf.
func (i *rawIterator) Path() string {
return i.path
}
// Next advances the iterator to the next node.
func (i *rawIterator) Next() {
// Initialize our stack if needed.
if i.stack == nil && i.node != nil {
i.stack = []rawStackEntry{
rawStackEntry{
edges: edges{
edge{node: i.node},
},
},
}
}
for len(i.stack) > 0 {
// Inspect the last element of the stack.
n := len(i.stack)
last := i.stack[n-1]
elem := last.edges[0].node
// Update the stack.
if len(last.edges) > 1 {
i.stack[n-1].edges = last.edges[1:]
} else {
i.stack = i.stack[:n-1]
}
// Push the edges onto the frontier.
if len(elem.edges) > 0 {
path := last.path + string(elem.prefix)
i.stack = append(i.stack, rawStackEntry{path, elem.edges})
}
i.pos = elem
i.path = last.path + string(elem.prefix)
return
}
i.pos = nil
i.path = ""
}

View file

@ -19,7 +19,7 @@ The database provides the following:
* Rich Indexing - Tables can support any number of indexes, which can be simple like
a single field index, or more advanced compound field indexes. Certain types like
UUID can be efficiently compressed from strings into byte indexes for reduces
UUID can be efficiently compressed from strings into byte indexes for reduced
storage requirements.
For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix).

View file

@ -9,15 +9,27 @@ import (
// Indexer is an interface used for defining indexes
type Indexer interface {
// FromObject is used to extract an index value from an
// object or to indicate that the index value is missing.
FromObject(raw interface{}) (bool, []byte, error)
// ExactFromArgs is used to build an exact index lookup
// based on arguments
FromArgs(args ...interface{}) ([]byte, error)
}
// SingleIndexer is an interface used for defining indexes
// generating a single entry per object
type SingleIndexer interface {
// FromObject is used to extract an index value from an
// object or to indicate that the index value is missing.
FromObject(raw interface{}) (bool, []byte, error)
}
// MultiIndexer is an interface used for defining indexes
// generating multiple entries per object
type MultiIndexer interface {
// FromObject is used to extract index values from an
// object or to indicate that the index value is missing.
FromObject(raw interface{}) (bool, [][]byte, error)
}
// PrefixIndexer can optionally be implemented for any
// indexes that support prefix based iteration. This may
// not apply to all indexes.
@ -88,6 +100,155 @@ func (s *StringFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
return val, nil
}
// StringSliceFieldIndex is used to extract a field from an object
// using reflection and builds an index on that field.
type StringSliceFieldIndex struct {
Field string
Lowercase bool
}
func (s *StringSliceFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(s.Field)
if !fv.IsValid() {
return false, nil,
fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj)
}
if fv.Kind() != reflect.Slice || fv.Type().Elem().Kind() != reflect.String {
return false, nil, fmt.Errorf("field '%s' is not a string slice", s.Field)
}
length := fv.Len()
vals := make([][]byte, 0, length)
for i := 0; i < fv.Len(); i++ {
val := fv.Index(i).String()
if val == "" {
continue
}
if s.Lowercase {
val = strings.ToLower(val)
}
// Add the null character as a terminator
val += "\x00"
vals = append(vals, []byte(val))
}
if len(vals) == 0 {
return false, nil, nil
}
return true, vals, nil
}
func (s *StringSliceFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
arg, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
}
if s.Lowercase {
arg = strings.ToLower(arg)
}
// Add the null character as a terminator
arg += "\x00"
return []byte(arg), nil
}
func (s *StringSliceFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) {
val, err := s.FromArgs(args...)
if err != nil {
return nil, err
}
// Strip the null terminator, the rest is a prefix
n := len(val)
if n > 0 {
return val[:n-1], nil
}
return val, nil
}
// StringMapFieldIndex is used to extract a field of type map[string]string
// from an object using reflection and builds an index on that field.
type StringMapFieldIndex struct {
Field string
Lowercase bool
}
var MapType = reflect.MapOf(reflect.TypeOf(""), reflect.TypeOf("")).Kind()
func (s *StringMapFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Dereference the pointer if any
fv := v.FieldByName(s.Field)
if !fv.IsValid() {
return false, nil, fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj)
}
if fv.Kind() != MapType {
return false, nil, fmt.Errorf("field '%s' is not a map[string]string", s.Field)
}
length := fv.Len()
vals := make([][]byte, 0, length)
for _, key := range fv.MapKeys() {
k := key.String()
if k == "" {
continue
}
val := fv.MapIndex(key).String()
if s.Lowercase {
k = strings.ToLower(k)
val = strings.ToLower(val)
}
// Add the null character as a terminator
k += "\x00" + val + "\x00"
vals = append(vals, []byte(k))
}
if len(vals) == 0 {
return false, nil, nil
}
return true, vals, nil
}
func (s *StringMapFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) > 2 || len(args) == 0 {
return nil, fmt.Errorf("must provide one or two arguments")
}
key, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[0])
}
if s.Lowercase {
key = strings.ToLower(key)
}
// Add the null character as a terminator
key += "\x00"
if len(args) == 2 {
val, ok := args[1].(string)
if !ok {
return nil, fmt.Errorf("argument must be a string: %#v", args[1])
}
if s.Lowercase {
val = strings.ToLower(val)
}
// Add the null character as a terminator
key += val + "\x00"
}
return []byte(key), nil
}
// UUIDFieldIndex is used to extract a field from an object
// using reflection and builds an index on that field by treating
// it as a UUID. This is an optimization to using a StringFieldIndex
@ -270,7 +431,11 @@ type CompoundIndex struct {
func (c *CompoundIndex) FromObject(raw interface{}) (bool, []byte, error) {
var out []byte
for i, idx := range c.Indexes {
for i, idxRaw := range c.Indexes {
idx, ok := idxRaw.(SingleIndexer)
if !ok {
return false, nil, fmt.Errorf("sub-index %d error: %s", i, "sub-index must be a SingleIndexer")
}
ok, val, err := idx.FromObject(raw)
if err != nil {
return false, nil, fmt.Errorf("sub-index %d error: %v", i, err)

View file

@ -15,6 +15,7 @@ import (
type MemDB struct {
schema *DBSchema
root unsafe.Pointer // *iradix.Tree underneath
primary bool
// There can only be a single writter at once
writer sync.Mutex
@ -31,6 +32,7 @@ func NewMemDB(schema *DBSchema) (*MemDB, error) {
db := &MemDB{
schema: schema,
root: unsafe.Pointer(iradix.New()),
primary: true,
}
if err := db.initialize(); err != nil {
return nil, err
@ -65,6 +67,7 @@ func (db *MemDB) Snapshot() *MemDB {
clone := &MemDB{
schema: db.schema,
root: unsafe.Pointer(db.getRoot()),
primary: false,
}
return clone
}

View file

@ -38,7 +38,7 @@ func (s *TableSchema) Validate() error {
return fmt.Errorf("missing table name")
}
if len(s.Indexes) == 0 {
return fmt.Errorf("missing table schemas for '%s'", s.Name)
return fmt.Errorf("missing table indexes for '%s'", s.Name)
}
if _, ok := s.Indexes["id"]; !ok {
return fmt.Errorf("must have id index")
@ -46,6 +46,9 @@ func (s *TableSchema) Validate() error {
if !s.Indexes["id"].Unique {
return fmt.Errorf("id index must be unique")
}
if _, ok := s.Indexes["id"].Indexer.(SingleIndexer); !ok {
return fmt.Errorf("id index must be a SingleIndexer")
}
for name, index := range s.Indexes {
if name != index.Name {
return fmt.Errorf("index name mis-match for '%s'", name)
@ -72,5 +75,11 @@ func (s *IndexSchema) Validate() error {
if s.Indexer == nil {
return fmt.Errorf("missing index function for '%s'", s.Name)
}
switch s.Indexer.(type) {
case SingleIndexer:
case MultiIndexer:
default:
return fmt.Errorf("indexer for '%s' must be a SingleIndexer or MultiIndexer", s.Name)
}
return nil
}

View file

@ -70,6 +70,11 @@ func (txn *Txn) writableIndex(table, index string) *iradix.Txn {
raw, _ := txn.rootTxn.Get(path)
indexTxn := raw.(*iradix.Tree).Txn()
// If we are the primary DB, enable mutation tracking. Snapshots should
// not notify, otherwise we will trigger watches on the primary DB when
// the writes will not be visible.
indexTxn.TrackMutate(txn.db.primary)
// Keep this open for the duration of the txn
txn.modified[key] = indexTxn
return indexTxn
@ -148,7 +153,8 @@ func (txn *Txn) Insert(table string, obj interface{}) error {
// Get the primary ID of the object
idSchema := tableSchema.Indexes[id]
ok, idVal, err := idSchema.Indexer.FromObject(obj)
idIndexer := idSchema.Indexer.(SingleIndexer)
ok, idVal, err := idIndexer.FromObject(obj)
if err != nil {
return fmt.Errorf("failed to build primary index: %v", err)
}
@ -167,7 +173,19 @@ func (txn *Txn) Insert(table string, obj interface{}) error {
indexTxn := txn.writableIndex(table, name)
// Determine the new index value
ok, val, err := indexSchema.Indexer.FromObject(obj)
var (
ok bool
vals [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var val []byte
ok, val, err = indexer.FromObject(obj)
vals = [][]byte{val}
case MultiIndexer:
ok, vals, err = indexer.FromObject(obj)
}
if err != nil {
return fmt.Errorf("failed to build index '%s': %v", name, err)
}
@ -176,16 +194,31 @@ func (txn *Txn) Insert(table string, obj interface{}) error {
// This is done by appending the primary key which must
// be unique anyways.
if ok && !indexSchema.Unique {
val = append(val, idVal...)
for i := range vals {
vals[i] = append(vals[i], idVal...)
}
}
// Handle the update by deleting from the index first
if update {
okExist, valExist, err := indexSchema.Indexer.FromObject(existing)
var (
okExist bool
valsExist [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var valExist []byte
okExist, valExist, err = indexer.FromObject(existing)
valsExist = [][]byte{valExist}
case MultiIndexer:
okExist, valsExist, err = indexer.FromObject(existing)
}
if err != nil {
return fmt.Errorf("failed to build index '%s': %v", name, err)
}
if okExist {
for i, valExist := range valsExist {
// Handle non-unique index by computing a unique index.
// This is done by appending the primary key which must
// be unique anyways.
@ -196,11 +229,12 @@ func (txn *Txn) Insert(table string, obj interface{}) error {
// If we are writing to the same index with the same value,
// we can avoid the delete as the insert will overwrite the
// value anyways.
if !bytes.Equal(valExist, val) {
if i >= len(vals) || !bytes.Equal(valExist, vals[i]) {
indexTxn.Delete(valExist)
}
}
}
}
// If there is no index value, either this is an error or an expected
// case and we can skip updating
@ -213,8 +247,10 @@ func (txn *Txn) Insert(table string, obj interface{}) error {
}
// Update the value of the index
for _, val := range vals {
indexTxn.Insert(val, obj)
}
}
return nil
}
@ -233,7 +269,8 @@ func (txn *Txn) Delete(table string, obj interface{}) error {
// Get the primary ID of the object
idSchema := tableSchema.Indexes[id]
ok, idVal, err := idSchema.Indexer.FromObject(obj)
idIndexer := idSchema.Indexer.(SingleIndexer)
ok, idVal, err := idIndexer.FromObject(obj)
if err != nil {
return fmt.Errorf("failed to build primary index: %v", err)
}
@ -253,7 +290,19 @@ func (txn *Txn) Delete(table string, obj interface{}) error {
indexTxn := txn.writableIndex(table, name)
// Handle the update by deleting from the index first
ok, val, err := indexSchema.Indexer.FromObject(existing)
var (
ok bool
vals [][]byte
err error
)
switch indexer := indexSchema.Indexer.(type) {
case SingleIndexer:
var val []byte
ok, val, err = indexer.FromObject(existing)
vals = [][]byte{val}
case MultiIndexer:
ok, vals, err = indexer.FromObject(existing)
}
if err != nil {
return fmt.Errorf("failed to build index '%s': %v", name, err)
}
@ -261,12 +310,14 @@ func (txn *Txn) Delete(table string, obj interface{}) error {
// Handle non-unique index by computing a unique index.
// This is done by appending the primary key which must
// be unique anyways.
for _, val := range vals {
if !indexSchema.Unique {
val = append(val, idVal...)
}
indexTxn.Delete(val)
}
}
}
return nil
}
@ -306,13 +357,13 @@ func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error)
return num, nil
}
// First is used to return the first matching object for
// the given constraints on the index
func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) {
// FirstWatch is used to return the first matching object for
// the given constraints on the index along with the watch channel
func (txn *Txn) FirstWatch(table, index string, args ...interface{}) (<-chan struct{}, interface{}, error) {
// Get the index value
indexSchema, val, err := txn.getIndexValue(table, index, args...)
if err != nil {
return nil, err
return nil, nil, err
}
// Get the index itself
@ -320,18 +371,25 @@ func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, er
// Do an exact lookup
if indexSchema.Unique && val != nil && indexSchema.Name == index {
obj, ok := indexTxn.Get(val)
watch, obj, ok := indexTxn.GetWatch(val)
if !ok {
return nil, nil
return watch, nil, nil
}
return obj, nil
return watch, obj, nil
}
// Handle non-unique index by using an iterator and getting the first value
iter := indexTxn.Root().Iterator()
iter.SeekPrefix(val)
watch := iter.SeekPrefixWatch(val)
_, value, _ := iter.Next()
return value, nil
return watch, value, nil
}
// First is used to return the first matching object for
// the given constraints on the index
func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) {
_, val, err := txn.FirstWatch(table, index, args...)
return val, err
}
// LongestPrefix is used to fetch the longest prefix match for the given
@ -422,6 +480,7 @@ func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexS
// ResultIterator is used to iterate over a list of results
// from a Get query on a table.
type ResultIterator interface {
WatchCh() <-chan struct{}
Next() interface{}
}
@ -442,11 +501,12 @@ func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, e
indexIter := indexRoot.Iterator()
// Seek the iterator to the appropriate sub-set
indexIter.SeekPrefix(val)
watchCh := indexIter.SeekPrefixWatch(val)
// Create an iterator
iter := &radixIterator{
iter: indexIter,
watchCh: watchCh,
}
return iter, nil
}
@ -460,10 +520,15 @@ func (txn *Txn) Defer(fn func()) {
}
// radixIterator is used to wrap an underlying iradix iterator.
// This is much mroe efficient than a sliceIterator as we are not
// This is much more efficient than a sliceIterator as we are not
// materializing the entire view.
type radixIterator struct {
iter *iradix.Iterator
watchCh <-chan struct{}
}
func (r *radixIterator) WatchCh() <-chan struct{} {
return r.watchCh
}
func (r *radixIterator) Next() interface{} {

108
vendor/github.com/hashicorp/go-memdb/watch.go generated vendored Normal file
View file

@ -0,0 +1,108 @@
package memdb
import "time"
// WatchSet is a collection of watch channels.
type WatchSet map[<-chan struct{}]struct{}
// NewWatchSet constructs a new watch set.
func NewWatchSet() WatchSet {
return make(map[<-chan struct{}]struct{})
}
// Add appends a watchCh to the WatchSet if non-nil.
func (w WatchSet) Add(watchCh <-chan struct{}) {
if w == nil {
return
}
if _, ok := w[watchCh]; !ok {
w[watchCh] = struct{}{}
}
}
// AddWithLimit appends a watchCh to the WatchSet if non-nil, and if the given
// softLimit hasn't been exceeded. Otherwise, it will watch the given alternate
// channel. It's expected that the altCh will be the same on many calls to this
// function, so you will exceed the soft limit a little bit if you hit this, but
// not by much.
//
// This is useful if you want to track individual items up to some limit, after
// which you watch a higher-level channel (usually a channel from start start of
// an iterator higher up in the radix tree) that will watch a superset of items.
func (w WatchSet) AddWithLimit(softLimit int, watchCh <-chan struct{}, altCh <-chan struct{}) {
// This is safe for a nil WatchSet so we don't need to check that here.
if len(w) < softLimit {
w.Add(watchCh)
} else {
w.Add(altCh)
}
}
// Watch is used to wait for either the watch set to trigger or a timeout.
// Returns true on timeout.
func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool {
if w == nil {
return false
}
if n := len(w); n <= aFew {
idx := 0
chunk := make([]<-chan struct{}, aFew)
for watchCh := range w {
chunk[idx] = watchCh
idx++
}
return watchFew(chunk, timeoutCh)
} else {
return w.watchMany(timeoutCh)
}
}
// watchMany is used if there are many watchers.
func (w WatchSet) watchMany(timeoutCh <-chan time.Time) bool {
// Make a fake timeout channel we can feed into watchFew to cancel all
// the blocking goroutines.
doneCh := make(chan time.Time)
defer close(doneCh)
// Set up a goroutine for each watcher.
triggerCh := make(chan struct{}, 1)
watcher := func(chunk []<-chan struct{}) {
if timeout := watchFew(chunk, doneCh); !timeout {
select {
case triggerCh <- struct{}{}:
default:
}
}
}
// Apportion the watch channels into chunks we can feed into the
// watchFew helper.
idx := 0
chunk := make([]<-chan struct{}, aFew)
for watchCh := range w {
subIdx := idx % aFew
chunk[subIdx] = watchCh
idx++
// Fire off this chunk and start a fresh one.
if idx%aFew == 0 {
go watcher(chunk)
chunk = make([]<-chan struct{}, aFew)
}
}
// Make sure to watch any residual channels in the last chunk.
if idx%aFew != 0 {
go watcher(chunk)
}
// Wait for a channel to trigger or timeout.
select {
case <-triggerCh:
return false
case <-timeoutCh:
return true
}
}

116
vendor/github.com/hashicorp/go-memdb/watch_few.go generated vendored Normal file
View file

@ -0,0 +1,116 @@
//go:generate sh -c "go run watch-gen/main.go >watch_few.go"
package memdb
import(
"time"
)
// aFew gives how many watchers this function is wired to support. You must
// always pass a full slice of this length, but unused channels can be nil.
const aFew = 32
// watchFew is used if there are only a few watchers as a performance
// optimization.
func watchFew(ch []<-chan struct{}, timeoutCh <-chan time.Time) bool {
select {
case <-ch[0]:
return false
case <-ch[1]:
return false
case <-ch[2]:
return false
case <-ch[3]:
return false
case <-ch[4]:
return false
case <-ch[5]:
return false
case <-ch[6]:
return false
case <-ch[7]:
return false
case <-ch[8]:
return false
case <-ch[9]:
return false
case <-ch[10]:
return false
case <-ch[11]:
return false
case <-ch[12]:
return false
case <-ch[13]:
return false
case <-ch[14]:
return false
case <-ch[15]:
return false
case <-ch[16]:
return false
case <-ch[17]:
return false
case <-ch[18]:
return false
case <-ch[19]:
return false
case <-ch[20]:
return false
case <-ch[21]:
return false
case <-ch[22]:
return false
case <-ch[23]:
return false
case <-ch[24]:
return false
case <-ch[25]:
return false
case <-ch[26]:
return false
case <-ch[27]:
return false
case <-ch[28]:
return false
case <-ch[29]:
return false
case <-ch[30]:
return false
case <-ch[31]:
return false
case <-timeoutCh:
return true
}
}

12
vendor/vendor.json vendored
View file

@ -672,16 +672,16 @@
"revision": "3142ddc1d627a166970ddd301bc09cb510c74edc"
},
{
"checksumSHA1": "qmE9mO0WW6ALLpUU81rXDyspP5M=",
"checksumSHA1": "jPxyofQxI1PRPq6LPc6VlcRn5fI=",
"path": "github.com/hashicorp/go-immutable-radix",
"revision": "afc5a0dbb18abdf82c277a7bc01533e81fa1d6b8",
"revisionTime": "2016-06-09T02:05:29Z"
"revision": "76b5f4e390910df355bfb9b16b41899538594a05",
"revisionTime": "2017-01-13T02:29:29Z"
},
{
"checksumSHA1": "/V57CyN7x2NUlHoOzVL5GgGXX84=",
"checksumSHA1": "K8Fsgt1llTXP0EwqdBzvSGdKOKc=",
"path": "github.com/hashicorp/go-memdb",
"revision": "98f52f52d7a476958fa9da671354d270c50661a7",
"revisionTime": "2016-03-01T23:01:42Z"
"revision": "c01f56b44823e8ba697e23c18d12dca984b85aca",
"revisionTime": "2017-01-23T15:32:28Z"
},
{
"path": "github.com/hashicorp/go-msgpack/codec",