memdb: Adding UUID indexer

This commit is contained in:
Armon Dadgar 2015-06-16 14:02:00 -07:00
parent ff6945a28b
commit 064b8c49d1
2 changed files with 191 additions and 1 deletions

View file

@ -1,6 +1,7 @@
package memdb
import (
"encoding/hex"
"fmt"
"reflect"
"strings"
@ -47,3 +48,90 @@ func (s *StringFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
}
return []byte(arg), nil
}
// UUIDFieldIndex is used to extract a field from an object
// using reflection and builds an index on that field by treating
// it as a UUID. This is an optimization to using a StringFieldIndex
// as the UUID can be more compactly represented in byte form.
type UUIDFieldIndex struct {
Field string
}
func (u *UUIDFieldIndex) FromObject(obj interface{}) (bool, []byte, error) {
v := reflect.ValueOf(obj)
v = reflect.Indirect(v) // Derefence the pointer if any
fv := v.FieldByName(u.Field)
if !fv.IsValid() {
return false, nil,
fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj)
}
val := fv.String()
if val == "" {
return false, nil, nil
}
buf, err := u.parseString(val)
return true, buf, err
}
func (u *UUIDFieldIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("must provide only a single argument")
}
switch arg := args[0].(type) {
case string:
return u.parseString(arg)
case []byte:
if len(arg) != 16 {
return nil, fmt.Errorf("byte slice must be 16 characters")
}
return arg, nil
default:
return nil,
fmt.Errorf("argument must be a string or byte slice: %#v", args[0])
}
}
func (u *UUIDFieldIndex) parseString(s string) ([]byte, error) {
// Verify the length
if len(s) != 36 {
return nil, fmt.Errorf("UUID must be 36 characters")
}
// Decode each of the parts
part1, err := hex.DecodeString(s[0:8])
if err != nil {
return nil, fmt.Errorf("Invalid UUID: %v", err)
}
part2, err := hex.DecodeString(s[9:13])
if err != nil {
return nil, fmt.Errorf("Invalid UUID: %v", err)
}
part3, err := hex.DecodeString(s[14:18])
if err != nil {
return nil, fmt.Errorf("Invalid UUID: %v", err)
}
part4, err := hex.DecodeString(s[19:23])
if err != nil {
return nil, fmt.Errorf("Invalid UUID: %v", err)
}
part5, err := hex.DecodeString(s[24:])
if err != nil {
return nil, fmt.Errorf("Invalid UUID: %v", err)
}
// Copy into a single buffer
buf := make([]byte, 16)
copy(buf[0:4], part1)
copy(buf[4:6], part2)
copy(buf[6:8], part3)
copy(buf[8:10], part4)
copy(buf[10:16], part5)
return buf, nil
}

View file

@ -1,6 +1,11 @@
package memdb
import "testing"
import (
"bytes"
crand "crypto/rand"
"fmt"
"testing"
)
type TestObject struct {
ID string
@ -92,3 +97,100 @@ func TestStringFieldIndex_FromArgs(t *testing.T) {
t.Fatalf("foo")
}
}
func TestUUIDFeldIndex_parseString(t *testing.T) {
u := &UUIDFieldIndex{}
_, err := u.parseString("invalid")
if err == nil {
t.Fatalf("should error")
}
buf, uuid := generateUUID()
out, err := u.parseString(uuid)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(out, buf) {
t.Fatalf("bad: %#v %#v", out, buf)
}
}
func TestUUIDFieldIndex_FromObject(t *testing.T) {
obj := testObj()
uuidBuf, uuid := generateUUID()
obj.Foo = uuid
indexer := &UUIDFieldIndex{"Foo"}
ok, val, err := indexer.FromObject(obj)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(uuidBuf, val) {
t.Fatalf("bad: %s", val)
}
if !ok {
t.Fatalf("should be ok")
}
badField := &UUIDFieldIndex{"NA"}
ok, val, err = badField.FromObject(obj)
if err == nil {
t.Fatalf("should get error")
}
emptyField := &UUIDFieldIndex{"Empty"}
ok, val, err = emptyField.FromObject(obj)
if err != nil {
t.Fatalf("err: %v", err)
}
if ok {
t.Fatalf("should not ok")
}
}
func TestUUIDFieldIndex_FromArgs(t *testing.T) {
indexer := &UUIDFieldIndex{"Foo"}
_, err := indexer.FromArgs()
if err == nil {
t.Fatalf("should get err")
}
_, err = indexer.FromArgs(42)
if err == nil {
t.Fatalf("should get err")
}
uuidBuf, uuid := generateUUID()
val, err := indexer.FromArgs(uuid)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(uuidBuf, val) {
t.Fatalf("foo")
}
val, err = indexer.FromArgs(uuidBuf)
if err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(uuidBuf, val) {
t.Fatalf("foo")
}
}
func generateUUID() ([]byte, string) {
buf := make([]byte, 16)
if _, err := crand.Read(buf); err != nil {
panic(fmt.Errorf("failed to read random bytes: %v", err))
}
uuid := fmt.Sprintf("%08x-%04x-%04x-%04x-%12x",
buf[0:4],
buf[4:6],
buf[6:8],
buf[8:10],
buf[10:16])
return buf, uuid
}