diff --git a/nomad/memdb/index.go b/nomad/memdb/index.go index 3d70bfc79..b11494ffc 100644 --- a/nomad/memdb/index.go +++ b/nomad/memdb/index.go @@ -1,6 +1,7 @@ package memdb import ( + "encoding/hex" "fmt" "reflect" "strings" @@ -47,3 +48,90 @@ func (s *StringFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { } return []byte(arg), nil } + +// UUIDFieldIndex is used to extract a field from an object +// using reflection and builds an index on that field by treating +// it as a UUID. This is an optimization to using a StringFieldIndex +// as the UUID can be more compactly represented in byte form. +type UUIDFieldIndex struct { + Field string +} + +func (u *UUIDFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Derefence the pointer if any + + fv := v.FieldByName(u.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj) + } + + val := fv.String() + if val == "" { + return false, nil, nil + } + + buf, err := u.parseString(val) + return true, buf, err +} + +func (u *UUIDFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + switch arg := args[0].(type) { + case string: + return u.parseString(arg) + case []byte: + if len(arg) != 16 { + return nil, fmt.Errorf("byte slice must be 16 characters") + } + return arg, nil + default: + return nil, + fmt.Errorf("argument must be a string or byte slice: %#v", args[0]) + } +} + +func (u *UUIDFieldIndex) parseString(s string) ([]byte, error) { + // Verify the length + if len(s) != 36 { + return nil, fmt.Errorf("UUID must be 36 characters") + } + + // Decode each of the parts + part1, err := hex.DecodeString(s[0:8]) + if err != nil { + return nil, fmt.Errorf("Invalid UUID: %v", err) + } + + part2, err := hex.DecodeString(s[9:13]) + if err != nil { + return nil, fmt.Errorf("Invalid UUID: %v", err) + } + + part3, err := hex.DecodeString(s[14:18]) + if err != nil { + return nil, fmt.Errorf("Invalid UUID: %v", err) + } + + part4, err := hex.DecodeString(s[19:23]) + if err != nil { + return nil, fmt.Errorf("Invalid UUID: %v", err) + } + + part5, err := hex.DecodeString(s[24:]) + if err != nil { + return nil, fmt.Errorf("Invalid UUID: %v", err) + } + + // Copy into a single buffer + buf := make([]byte, 16) + copy(buf[0:4], part1) + copy(buf[4:6], part2) + copy(buf[6:8], part3) + copy(buf[8:10], part4) + copy(buf[10:16], part5) + return buf, nil +} diff --git a/nomad/memdb/index_test.go b/nomad/memdb/index_test.go index 41d755df8..714512512 100644 --- a/nomad/memdb/index_test.go +++ b/nomad/memdb/index_test.go @@ -1,6 +1,11 @@ package memdb -import "testing" +import ( + "bytes" + crand "crypto/rand" + "fmt" + "testing" +) type TestObject struct { ID string @@ -92,3 +97,100 @@ func TestStringFieldIndex_FromArgs(t *testing.T) { t.Fatalf("foo") } } + +func TestUUIDFeldIndex_parseString(t *testing.T) { + u := &UUIDFieldIndex{} + _, err := u.parseString("invalid") + if err == nil { + t.Fatalf("should error") + } + + buf, uuid := generateUUID() + + out, err := u.parseString(uuid) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !bytes.Equal(out, buf) { + t.Fatalf("bad: %#v %#v", out, buf) + } +} + +func TestUUIDFieldIndex_FromObject(t *testing.T) { + obj := testObj() + uuidBuf, uuid := generateUUID() + obj.Foo = uuid + indexer := &UUIDFieldIndex{"Foo"} + + ok, val, err := indexer.FromObject(obj) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(uuidBuf, val) { + t.Fatalf("bad: %s", val) + } + if !ok { + t.Fatalf("should be ok") + } + + badField := &UUIDFieldIndex{"NA"} + ok, val, err = badField.FromObject(obj) + if err == nil { + t.Fatalf("should get error") + } + + emptyField := &UUIDFieldIndex{"Empty"} + ok, val, err = emptyField.FromObject(obj) + if err != nil { + t.Fatalf("err: %v", err) + } + if ok { + t.Fatalf("should not ok") + } +} + +func TestUUIDFieldIndex_FromArgs(t *testing.T) { + indexer := &UUIDFieldIndex{"Foo"} + _, err := indexer.FromArgs() + if err == nil { + t.Fatalf("should get err") + } + + _, err = indexer.FromArgs(42) + if err == nil { + t.Fatalf("should get err") + } + + uuidBuf, uuid := generateUUID() + + val, err := indexer.FromArgs(uuid) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(uuidBuf, val) { + t.Fatalf("foo") + } + + val, err = indexer.FromArgs(uuidBuf) + if err != nil { + t.Fatalf("err: %v", err) + } + if !bytes.Equal(uuidBuf, val) { + t.Fatalf("foo") + } +} + +func generateUUID() ([]byte, string) { + buf := make([]byte, 16) + if _, err := crand.Read(buf); err != nil { + panic(fmt.Errorf("failed to read random bytes: %v", err)) + } + uuid := fmt.Sprintf("%08x-%04x-%04x-%04x-%12x", + buf[0:4], + buf[4:6], + buf[6:8], + buf[8:10], + buf[10:16]) + return buf, uuid +}