From 49bc2c407599d411dd08258d89b1d07fb38ae702 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Tue, 7 Jan 2014 21:35:44 -0800 Subject: [PATCH] Handle record updates --- consul/mdb_table.go | 35 ++++++++++- consul/mdb_table_test.go | 133 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/consul/mdb_table.go b/consul/mdb_table.go index 61c5f6f93..fb03fc6ce 100644 --- a/consul/mdb_table.go +++ b/consul/mdb_table.go @@ -211,6 +211,7 @@ func (t *MDBTable) objIndexKeys(obj interface{}) (map[string][]byte, error) { // Insert is used to insert or update an object func (t *MDBTable) Insert(obj interface{}) error { + var n int // Construct the indexes keys indexes, err := t.objIndexKeys(obj) if err != nil { @@ -227,8 +228,23 @@ func (t *MDBTable) Insert(obj interface{}) error { } defer tx.Abort() - // TODO: Handle updates + // Scan and check if this primary key already exists + primaryDbi := tx.dbis[t.Indexes["id"].dbiName] + _, err = tx.tx.Get(primaryDbi, indexes["id"]) + if err == mdb.NotFound { + goto AFTER_DELETE + } + // Delete the existing row{ + n, err = t.deleteWithIndex(tx, t.Indexes["id"], indexes["id"]) + if err != nil { + return err + } + if n != 1 { + return fmt.Errorf("unexpected number of updates: %d", n) + } + +AFTER_DELETE: // Insert with a new row ID rowId := t.nextRowID() encRowId := uint64ToBytes(rowId) @@ -311,6 +327,19 @@ func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) { } defer tx.Abort() + // Delete with the index + num, err = t.deleteWithIndex(tx, idx, key) + if err != nil { + return 0, err + } + + // Attempt a commit + return num, tx.Commit() +} + +// deleteWithIndex deletes all associated rows while scanning +// a given index for a key prefix. +func (t *MDBTable) deleteWithIndex(tx *MDBTxn, idx *MDBIndex, key []byte) (num int, err error) { // Handle an error while deleting defer func() { if r := recover(); r != nil { @@ -332,7 +361,7 @@ func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) { // Delete the indexes we are not iterating for name, otherIdx := range t.Indexes { - if name == index { + if name == idx.name { continue } dbi := tx.dbis[otherIdx.dbiName] @@ -355,7 +384,7 @@ func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) { } // Return the deleted count - return num, tx.Commit() + return num, nil } // Initializes an index and returns a potential error diff --git a/consul/mdb_table_test.go b/consul/mdb_table_test.go index 15fd5b356..2b53ed7be 100644 --- a/consul/mdb_table_test.go +++ b/consul/mdb_table_test.go @@ -374,3 +374,136 @@ func TestMDBTableDelete(t *testing.T) { t.Fatalf("expect 0 results: %#v", res) } } + +func TestMDBTableUpdate(t *testing.T) { + dir, env := testMDBEnv(t) + defer os.RemoveAll(dir) + defer env.Close() + + table := &MDBTable{ + Env: env, + Name: "test", + Indexes: map[string]*MDBIndex{ + "id": &MDBIndex{ + Unique: true, + Fields: []string{"Key"}, + }, + "name": &MDBIndex{ + Fields: []string{"First", "Last"}, + }, + "country": &MDBIndex{ + Fields: []string{"Country"}, + }, + }, + Encoder: MockEncoder, + Decoder: MockDecoder, + } + if err := table.Init(); err != nil { + t.Fatalf("err: %v", err) + } + + objs := []*MockData{ + &MockData{ + Key: "1", + First: "Kevin", + Last: "Smith", + Country: "USA", + }, + &MockData{ + Key: "2", + First: "Kevin", + Last: "Wang", + Country: "USA", + }, + &MockData{ + Key: "3", + First: "Bernardo", + Last: "Torres", + Country: "Mexico", + }, + &MockData{ + Key: "1", + First: "Roger", + Last: "Rodrigez", + Country: "Mexico", + }, + &MockData{ + Key: "2", + First: "Anna", + Last: "Smith", + Country: "UK", + }, + &MockData{ + Key: "3", + First: "Ahmad", + Last: "Badari", + Country: "Iran", + }, + } + + // Insert and update some mock objects + for _, obj := range objs { + if err := table.Insert(obj); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Verify with some gets + res, err := table.Get("id", "1") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 1 { + t.Fatalf("expect 1 result: %#v", res) + } + if !reflect.DeepEqual(res[0], objs[3]) { + t.Fatalf("bad: %#v", res[0]) + } + + res, err = table.Get("name", "Kevin") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 0 { + t.Fatalf("expect 0 result: %#v", res) + } + + res, err = table.Get("name", "Ahmad") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 1 { + t.Fatalf("expect 1 result: %#v", res) + } + if !reflect.DeepEqual(res[0], objs[5]) { + t.Fatalf("bad: %#v", res[0]) + } + + res, err = table.Get("country", "Mexico") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 1 { + t.Fatalf("expect 1 result: %#v", res) + } + if !reflect.DeepEqual(res[0], objs[3]) { + t.Fatalf("bad: %#v", res[0]) + } + + res, err = table.Get("id") + if err != nil { + t.Fatalf("err: %v", err) + } + if len(res) != 3 { + t.Fatalf("expect 3 result: %#v", res) + } + if !reflect.DeepEqual(res[0], objs[3]) { + t.Fatalf("bad: %#v", res[0]) + } + if !reflect.DeepEqual(res[1], objs[4]) { + t.Fatalf("bad: %#v", res[1]) + } + if !reflect.DeepEqual(res[2], objs[5]) { + t.Fatalf("bad: %#v", res[2]) + } +}