diff --git a/consul/mdb_table.go b/consul/mdb_table.go index 23856fc5f..37eb52503 100644 --- a/consul/mdb_table.go +++ b/consul/mdb_table.go @@ -389,10 +389,32 @@ func (t *MDBTable) GetTxn(tx *MDBTxn, index string, parts ...string) ([]interfac // Accumulate the results var results []interface{} - err = idx.iterate(tx, key, func(encRowId, res []byte) bool { + err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { obj := t.Decoder(res) results = append(results, obj) - return false + return false, false + }) + + return results, err +} + +// GetTxnLimit is like GetTxn limits the maximum number of +// rows it will return +func (t *MDBTable) GetTxnLimit(tx *MDBTxn, limit int, index string, parts ...string) ([]interface{}, error) { + // Get the associated index + idx, key, err := t.getIndex(index, parts) + if err != nil { + return nil, err + } + + // Accumulate the results + var results []interface{} + num := 0 + err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { + num++ + obj := t.Decoder(res) + results = append(results, obj) + return false, num == limit }) return results, err @@ -412,10 +434,10 @@ func (t *MDBTable) StreamTxn(stream chan<- interface{}, tx *MDBTxn, index string } // Stream the results - err = idx.iterate(tx, key, func(encRowId, res []byte) bool { + err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { obj := t.Decoder(res) stream <- obj - return false + return false, false }) return err @@ -508,7 +530,7 @@ func (t *MDBTable) innerDeleteWithIndex(tx *MDBTxn, idx *MDBIndex, key []byte) ( }() // Delete everything as we iterate - err = idx.iterate(tx, key, func(encRowId, res []byte) bool { + err = idx.iterate(tx, key, func(encRowId, res []byte) (bool, bool) { // Get the object obj := t.Decoder(res) @@ -542,7 +564,7 @@ func (t *MDBTable) innerDeleteWithIndex(tx *MDBTxn, idx *MDBIndex, key []byte) ( // Delete the object num++ - return true + return true, false }) if err != nil { return 0, err @@ -644,7 +666,7 @@ func (i *MDBIndex) keyFromParts(parts ...string) []byte { // and invoking the cb with each row. We dereference the rowid, // and only return the object row func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, - cb func(encRowId, res []byte) bool) error { + cb func(encRowId, res []byte) (bool, bool)) error { table := tx.dbis[i.table.Name] // If virtual, use the correct DBI @@ -667,8 +689,9 @@ func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, var key, encRowId, objBytes []byte first := true + shouldStop := false shouldDelete := false - for { + for !shouldStop { if first && len(prefix) > 0 { first = false key, encRowId, err = cursor.Get(prefix, mdb.SET_RANGE) @@ -708,7 +731,8 @@ func (i *MDBIndex) iterate(tx *MDBTxn, prefix []byte, } // Invoke the cb - if shouldDelete = cb(encRowId, objBytes); shouldDelete { + shouldDelete, shouldStop = cb(encRowId, objBytes) + if shouldDelete { if err := cursor.Del(0); err != nil { return fmt.Errorf("delete failed: %v", err) } diff --git a/consul/mdb_table_test.go b/consul/mdb_table_test.go index e70c9131f..73e4001d1 100644 --- a/consul/mdb_table_test.go +++ b/consul/mdb_table_test.go @@ -2,12 +2,13 @@ package consul import ( "bytes" - "github.com/armon/gomdb" - "github.com/hashicorp/go-msgpack/codec" "io/ioutil" "os" "reflect" "testing" + + "github.com/armon/gomdb" + "github.com/hashicorp/go-msgpack/codec" ) type MockData struct { @@ -970,3 +971,78 @@ func TestMDBTableStream(t *testing.T) { t.Fatalf("bad index: %d", idx) } } + +func TestMDBTableGetTxnLimit(t *testing.T) { + dir, env := testMDBEnv(t) + defer os.RemoveAll(dir) + defer env.Close() + + table := &MDBTable{ + Env: env, + Name: "test", + Indexes: map[string]*MDBIndex{ + "id": &MDBIndex{ + Unique: true, + Fields: []string{"Key"}, + }, + "name": &MDBIndex{ + Fields: []string{"First", "Last"}, + }, + "country": &MDBIndex{ + Fields: []string{"Country"}, + }, + }, + Encoder: MockEncoder, + Decoder: MockDecoder, + } + if err := table.Init(); err != nil { + t.Fatalf("err: %v", err) + } + + objs := []*MockData{ + &MockData{ + Key: "1", + First: "Kevin", + Last: "Smith", + Country: "USA", + }, + &MockData{ + Key: "2", + First: "Kevin", + Last: "Wang", + Country: "USA", + }, + &MockData{ + Key: "3", + First: "Bernardo", + Last: "Torres", + Country: "Mexico", + }, + } + + // Insert some mock objects + for idx, obj := range objs { + if err := table.Insert(obj); err != nil { + t.Fatalf("err: %v", err) + } + if err := table.SetLastIndex(uint64(idx + 1)); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Start a readonly txn + tx, err := table.StartTxn(true, nil) + if err != nil { + panic(err) + } + defer tx.Abort() + + // Verify with some gets + res, err := table.GetTxnLimit(tx, 2, "id") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 2 { + t.Fatalf("expect 2 result: %#v", res) + } +}