Handle record updates

This commit is contained in:
Armon Dadgar 2014-01-07 21:35:44 -08:00
parent ed87aae1cf
commit 49bc2c4075
2 changed files with 165 additions and 3 deletions

View file

@ -211,6 +211,7 @@ func (t *MDBTable) objIndexKeys(obj interface{}) (map[string][]byte, error) {
// Insert is used to insert or update an object // Insert is used to insert or update an object
func (t *MDBTable) Insert(obj interface{}) error { func (t *MDBTable) Insert(obj interface{}) error {
var n int
// Construct the indexes keys // Construct the indexes keys
indexes, err := t.objIndexKeys(obj) indexes, err := t.objIndexKeys(obj)
if err != nil { if err != nil {
@ -227,8 +228,23 @@ func (t *MDBTable) Insert(obj interface{}) error {
} }
defer tx.Abort() defer tx.Abort()
// TODO: Handle updates // Scan and check if this primary key already exists
primaryDbi := tx.dbis[t.Indexes["id"].dbiName]
_, err = tx.tx.Get(primaryDbi, indexes["id"])
if err == mdb.NotFound {
goto AFTER_DELETE
}
// Delete the existing row{
n, err = t.deleteWithIndex(tx, t.Indexes["id"], indexes["id"])
if err != nil {
return err
}
if n != 1 {
return fmt.Errorf("unexpected number of updates: %d", n)
}
AFTER_DELETE:
// Insert with a new row ID // Insert with a new row ID
rowId := t.nextRowID() rowId := t.nextRowID()
encRowId := uint64ToBytes(rowId) encRowId := uint64ToBytes(rowId)
@ -311,6 +327,19 @@ func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) {
} }
defer tx.Abort() defer tx.Abort()
// Delete with the index
num, err = t.deleteWithIndex(tx, idx, key)
if err != nil {
return 0, err
}
// Attempt a commit
return num, tx.Commit()
}
// deleteWithIndex deletes all associated rows while scanning
// a given index for a key prefix.
func (t *MDBTable) deleteWithIndex(tx *MDBTxn, idx *MDBIndex, key []byte) (num int, err error) {
// Handle an error while deleting // Handle an error while deleting
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -332,7 +361,7 @@ func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) {
// Delete the indexes we are not iterating // Delete the indexes we are not iterating
for name, otherIdx := range t.Indexes { for name, otherIdx := range t.Indexes {
if name == index { if name == idx.name {
continue continue
} }
dbi := tx.dbis[otherIdx.dbiName] dbi := tx.dbis[otherIdx.dbiName]
@ -355,7 +384,7 @@ func (t *MDBTable) Delete(index string, parts ...string) (num int, err error) {
} }
// Return the deleted count // Return the deleted count
return num, tx.Commit() return num, nil
} }
// Initializes an index and returns a potential error // Initializes an index and returns a potential error

View file

@ -374,3 +374,136 @@ func TestMDBTableDelete(t *testing.T) {
t.Fatalf("expect 0 results: %#v", res) t.Fatalf("expect 0 results: %#v", res)
} }
} }
func TestMDBTableUpdate(t *testing.T) {
dir, env := testMDBEnv(t)
defer os.RemoveAll(dir)
defer env.Close()
table := &MDBTable{
Env: env,
Name: "test",
Indexes: map[string]*MDBIndex{
"id": &MDBIndex{
Unique: true,
Fields: []string{"Key"},
},
"name": &MDBIndex{
Fields: []string{"First", "Last"},
},
"country": &MDBIndex{
Fields: []string{"Country"},
},
},
Encoder: MockEncoder,
Decoder: MockDecoder,
}
if err := table.Init(); err != nil {
t.Fatalf("err: %v", err)
}
objs := []*MockData{
&MockData{
Key: "1",
First: "Kevin",
Last: "Smith",
Country: "USA",
},
&MockData{
Key: "2",
First: "Kevin",
Last: "Wang",
Country: "USA",
},
&MockData{
Key: "3",
First: "Bernardo",
Last: "Torres",
Country: "Mexico",
},
&MockData{
Key: "1",
First: "Roger",
Last: "Rodrigez",
Country: "Mexico",
},
&MockData{
Key: "2",
First: "Anna",
Last: "Smith",
Country: "UK",
},
&MockData{
Key: "3",
First: "Ahmad",
Last: "Badari",
Country: "Iran",
},
}
// Insert and update some mock objects
for _, obj := range objs {
if err := table.Insert(obj); err != nil {
t.Fatalf("err: %v", err)
}
}
// Verify with some gets
res, err := table.Get("id", "1")
if err != nil {
t.Fatalf("err: %v", err)
}
if len(res) != 1 {
t.Fatalf("expect 1 result: %#v", res)
}
if !reflect.DeepEqual(res[0], objs[3]) {
t.Fatalf("bad: %#v", res[0])
}
res, err = table.Get("name", "Kevin")
if err != nil {
t.Fatalf("err: %v", err)
}
if len(res) != 0 {
t.Fatalf("expect 0 result: %#v", res)
}
res, err = table.Get("name", "Ahmad")
if err != nil {
t.Fatalf("err: %v", err)
}
if len(res) != 1 {
t.Fatalf("expect 1 result: %#v", res)
}
if !reflect.DeepEqual(res[0], objs[5]) {
t.Fatalf("bad: %#v", res[0])
}
res, err = table.Get("country", "Mexico")
if err != nil {
t.Fatalf("err: %v", err)
}
if len(res) != 1 {
t.Fatalf("expect 1 result: %#v", res)
}
if !reflect.DeepEqual(res[0], objs[3]) {
t.Fatalf("bad: %#v", res[0])
}
res, err = table.Get("id")
if err != nil {
t.Fatalf("err: %v", err)
}
if len(res) != 3 {
t.Fatalf("expect 3 result: %#v", res)
}
if !reflect.DeepEqual(res[0], objs[3]) {
t.Fatalf("bad: %#v", res[0])
}
if !reflect.DeepEqual(res[1], objs[4]) {
t.Fatalf("bad: %#v", res[1])
}
if !reflect.DeepEqual(res[2], objs[5]) {
t.Fatalf("bad: %#v", res[2])
}
}