Added Equal method to Comparator interface

Summary:
In some cases, equality comparisons can be done more efficiently than three-way
comparisons. There are quite a few places in the code where we only care about
equality. This patch adds an Equal() method that defaults to using the
Compare() method.

Test Plan: make clean all check

Reviewers: rven, anthony, yhchiang, igor, sdong

Reviewed By: igor

Subscribers: dhruba, leveldb

Differential Revision: https://reviews.facebook.net/D46233
This commit is contained in:
Andres Noetzli 2015-09-08 15:30:49 -07:00
parent 7a31960ee9
commit 6bdc484fd8
16 changed files with 56 additions and 44 deletions

View File

@ -10,6 +10,7 @@
### Public API Changes ### Public API Changes
* Removed class Env::RandomRWFile and Env::NewRandomRWFile(). * Removed class Env::RandomRWFile and Env::NewRandomRWFile().
* Renamed DBOptions.num_subcompactions to DBOptions.max_subcompactions to make the name better match the actual functionality of the option. * Renamed DBOptions.num_subcompactions to DBOptions.max_subcompactions to make the name better match the actual functionality of the option.
* Added Equal() method to the Comparator interface that can optionally be overridden in cases where equality comparisons can be done more efficiently than three-way comparisons.
## 3.13.0 (8/6/2015) ## 3.13.0 (8/6/2015)
### New Features ### New Features

View File

@ -149,8 +149,8 @@ Status BuildTable(
// first key), then we skip it, since it is an older version. // first key), then we skip it, since it is an older version.
// Otherwise we output the key and mark it as the "new" previous key. // Otherwise we output the key and mark it as the "new" previous key.
if (!has_current_user_key || if (!has_current_user_key ||
internal_comparator.user_comparator()->Compare( !internal_comparator.user_comparator()->Equal(
ikey.user_key, current_user_key.GetKey()) != 0) { ikey.user_key, current_user_key.GetKey())) {
// First occurrence of this user key // First occurrence of this user key
current_user_key.SetKey(ikey.user_key); current_user_key.SetKey(ikey.user_key);
has_current_user_key = true; has_current_user_key = true;

View File

@ -634,8 +634,8 @@ void CompactionJob::ProcessKeyValueCompaction(SubCompactionState* sub_compact) {
} }
if (!has_current_user_key || if (!has_current_user_key ||
cfd->user_comparator()->Compare(ikey.user_key, !cfd->user_comparator()->Equal(ikey.user_key,
current_user_key.GetKey()) != 0) { current_user_key.GetKey())) {
// First occurrence of this user key // First occurrence of this user key
current_user_key.SetKey(ikey.user_key); current_user_key.SetKey(ikey.user_key);
has_current_user_key = true; has_current_user_key = true;

View File

@ -298,7 +298,7 @@ void DBIter::MergeValuesNewToOld() {
continue; continue;
} }
if (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0) { if (!user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) {
// hit the next user key, stop right here // hit the next user key, stop right here
break; break;
} }
@ -400,7 +400,7 @@ void DBIter::PrevInternal() {
return; return;
} }
FindParseableKey(&ikey, kReverse); FindParseableKey(&ikey, kReverse);
if (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0) { if (user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) {
FindPrevUserKey(); FindPrevUserKey();
} }
return; return;
@ -409,8 +409,7 @@ void DBIter::PrevInternal() {
break; break;
} }
FindParseableKey(&ikey, kReverse); FindParseableKey(&ikey, kReverse);
if (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0) { if (user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) {
FindPrevUserKey(); FindPrevUserKey();
} }
} }
@ -434,7 +433,7 @@ bool DBIter::FindValueForCurrentKey() {
size_t num_skipped = 0; size_t num_skipped = 0;
while (iter_->Valid() && ikey.sequence <= sequence_ && while (iter_->Valid() && ikey.sequence <= sequence_ &&
(user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0)) { user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) {
// We iterate too much: let's use Seek() to avoid too many key comparisons // We iterate too much: let's use Seek() to avoid too many key comparisons
if (num_skipped >= max_skip_) { if (num_skipped >= max_skip_) {
return FindValueForCurrentKeyUsingSeek(); return FindValueForCurrentKeyUsingSeek();
@ -461,7 +460,7 @@ bool DBIter::FindValueForCurrentKey() {
} }
PERF_COUNTER_ADD(internal_key_skipped_count, 1); PERF_COUNTER_ADD(internal_key_skipped_count, 1);
assert(user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0); assert(user_comparator_->Equal(ikey.user_key, saved_key_.GetKey()));
iter_->Prev(); iter_->Prev();
++num_skipped; ++num_skipped;
FindParseableKey(&ikey, kReverse); FindParseableKey(&ikey, kReverse);
@ -531,7 +530,7 @@ bool DBIter::FindValueForCurrentKeyUsingSeek() {
// in operands // in operands
std::deque<std::string> operands; std::deque<std::string> operands;
while (iter_->Valid() && while (iter_->Valid() &&
(user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0) && user_comparator_->Equal(ikey.user_key, saved_key_.GetKey()) &&
ikey.type == kTypeMerge) { ikey.type == kTypeMerge) {
operands.push_front(iter_->value().ToString()); operands.push_front(iter_->value().ToString());
iter_->Next(); iter_->Next();
@ -539,7 +538,7 @@ bool DBIter::FindValueForCurrentKeyUsingSeek() {
} }
if (!iter_->Valid() || if (!iter_->Valid() ||
(user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0) || !user_comparator_->Equal(ikey.user_key, saved_key_.GetKey()) ||
ikey.type == kTypeDeletion) { ikey.type == kTypeDeletion) {
{ {
StopWatchNano timer(env_, statistics_ != nullptr); StopWatchNano timer(env_, statistics_ != nullptr);
@ -550,7 +549,7 @@ bool DBIter::FindValueForCurrentKeyUsingSeek() {
} }
// Make iter_ valid and point to saved_key_ // Make iter_ valid and point to saved_key_
if (!iter_->Valid() || if (!iter_->Valid() ||
(user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0)) { !user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) {
iter_->Seek(last_key); iter_->Seek(last_key);
RecordTick(statistics_, NUMBER_OF_RESEEKS_IN_ITERATION); RecordTick(statistics_, NUMBER_OF_RESEEKS_IN_ITERATION);
} }
@ -581,7 +580,7 @@ void DBIter::FindNextUserKey() {
ParsedInternalKey ikey; ParsedInternalKey ikey;
FindParseableKey(&ikey, kForward); FindParseableKey(&ikey, kForward);
while (iter_->Valid() && while (iter_->Valid() &&
user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0) { !user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) {
iter_->Next(); iter_->Next();
FindParseableKey(&ikey, kForward); FindParseableKey(&ikey, kForward);
} }

View File

@ -404,8 +404,8 @@ static bool SaveValue(void* arg, const char* entry) {
// all entries with overly large sequence numbers. // all entries with overly large sequence numbers.
uint32_t key_length; uint32_t key_length;
const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
if (s->mem->GetInternalKeyComparator().user_comparator()->Compare( if (s->mem->GetInternalKeyComparator().user_comparator()->Equal(
Slice(key_ptr, key_length - 8), s->key->user_key()) == 0) { Slice(key_ptr, key_length - 8), s->key->user_key())) {
// Correct user key // Correct user key
const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8);
ValueType type; ValueType type;
@ -563,8 +563,8 @@ void MemTable::Update(SequenceNumber seq,
const char* entry = iter->key(); const char* entry = iter->key();
uint32_t key_length = 0; uint32_t key_length = 0;
const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
if (comparator_.comparator.user_comparator()->Compare( if (comparator_.comparator.user_comparator()->Equal(
Slice(key_ptr, key_length - 8), lkey.user_key()) == 0) { Slice(key_ptr, key_length - 8), lkey.user_key())) {
// Correct user key // Correct user key
const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8);
ValueType type; ValueType type;
@ -624,8 +624,8 @@ bool MemTable::UpdateCallback(SequenceNumber seq,
const char* entry = iter->key(); const char* entry = iter->key();
uint32_t key_length = 0; uint32_t key_length = 0;
const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
if (comparator_.comparator.user_comparator()->Compare( if (comparator_.comparator.user_comparator()->Equal(
Slice(key_ptr, key_length - 8), lkey.user_key()) == 0) { Slice(key_ptr, key_length - 8), lkey.user_key())) {
// Correct user key // Correct user key
const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8);
ValueType type; ValueType type;
@ -695,8 +695,8 @@ size_t MemTable::CountSuccessiveMergeEntries(const LookupKey& key) {
const char* entry = iter->key(); const char* entry = iter->key();
uint32_t key_length = 0; uint32_t key_length = 0;
const char* iter_key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); const char* iter_key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
if (comparator_.comparator.user_comparator()->Compare( if (!comparator_.comparator.user_comparator()->Equal(
Slice(iter_key_ptr, key_length - 8), key.user_key()) != 0) { Slice(iter_key_ptr, key_length - 8), key.user_key())) {
break; break;
} }

View File

@ -312,9 +312,10 @@ class ReadBenchmarkThread : public BenchmarkThread {
assert(callback_args != nullptr); assert(callback_args != nullptr);
uint32_t key_length; uint32_t key_length;
const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
if ((callback_args->comparator)->user_comparator()->Compare( if ((callback_args->comparator)
Slice(key_ptr, key_length - 8), callback_args->key->user_key()) == ->user_comparator()
0) { ->Equal(Slice(key_ptr, key_length - 8),
callback_args->key->user_key())) {
callback_args->found = true; callback_args->found = true;
} }
return false; return false;

View File

@ -91,8 +91,7 @@ Status MergeHelper::MergeUntil(Iterator* iter, const SequenceNumber stop_before,
assert(!"corrupted internal key is not expected"); assert(!"corrupted internal key is not expected");
} }
break; break;
} else if (user_comparator_->Compare(ikey.user_key, orig_ikey.user_key) != } else if (!user_comparator_->Equal(ikey.user_key, orig_ikey.user_key)) {
0) {
// hit a different user key, stop right here // hit a different user key, stop right here
hit_the_next_user_key = true; hit_the_next_user_key = true;
break; break;

View File

@ -1581,7 +1581,7 @@ bool VersionStorageInfo::HasOverlappingUserKey(
files[last_file].largest_key); files[last_file].largest_key);
const Slice first_key_after = ExtractUserKey( const Slice first_key_after = ExtractUserKey(
files[last_file+1].smallest_key); files[last_file+1].smallest_key);
if (user_cmp->Compare(last_key_in_input, first_key_after) == 0) { if (user_cmp->Equal(last_key_in_input, first_key_after)) {
// The last user key in input overlaps with the next file's first key // The last user key in input overlaps with the next file's first key
return true; return true;
} }
@ -1596,7 +1596,7 @@ bool VersionStorageInfo::HasOverlappingUserKey(
files[first_file].smallest_key); files[first_file].smallest_key);
const Slice& last_key_before = ExtractUserKey( const Slice& last_key_before = ExtractUserKey(
files[first_file-1].largest_key); files[first_file-1].largest_key);
if (user_cmp->Compare(first_key_in_input, last_key_before) == 0) { if (user_cmp->Equal(first_key_in_input, last_key_before)) {
// The first user key in input overlaps with the previous file's last key // The first user key in input overlaps with the previous file's last key
return true; return true;
} }

View File

@ -29,6 +29,15 @@ class Comparator {
// > 0 iff "a" > "b" // > 0 iff "a" > "b"
virtual int Compare(const Slice& a, const Slice& b) const = 0; virtual int Compare(const Slice& a, const Slice& b) const = 0;
// Compares two slices for equality. The following invariant should always
// hold (and is the default implementation):
// Equal(a, b) iff Compare(a, b) == 0
// Override only if equality comparisons can be done more efficiently than
// three-way comparisons.
virtual bool Equal(const Slice& a, const Slice& b) const {
return Compare(a, b) == 0;
}
// The name of the comparator. Used to check for comparator // The name of the comparator. Used to check for comparator
// mismatches (i.e., a DB created with one comparator is // mismatches (i.e., a DB created with one comparator is
// accessed using a different comparator. // accessed using a different comparator.

View File

@ -137,13 +137,13 @@ Status CuckooTableReader::Get(const ReadOptions& readOptions, const Slice& key,
const char* bucket = &file_data_.data()[offset]; const char* bucket = &file_data_.data()[offset];
for (uint32_t block_idx = 0; block_idx < cuckoo_block_size_; for (uint32_t block_idx = 0; block_idx < cuckoo_block_size_;
++block_idx, bucket += bucket_length_) { ++block_idx, bucket += bucket_length_) {
if (ucomp_->Compare(Slice(unused_key_.data(), user_key.size()), if (ucomp_->Equal(Slice(unused_key_.data(), user_key.size()),
Slice(bucket, user_key.size())) == 0) { Slice(bucket, user_key.size()))) {
return Status::OK(); return Status::OK();
} }
// Here, we compare only the user key part as we support only one entry // Here, we compare only the user key part as we support only one entry
// per user key and we don't support snapshots. // per user key and we don't support snapshots.
if (ucomp_->Compare(user_key, Slice(bucket, user_key.size())) == 0) { if (ucomp_->Equal(user_key, Slice(bucket, user_key.size()))) {
Slice value(bucket + key_length_, value_length_); Slice value(bucket + key_length_, value_length_);
if (is_last_level_) { if (is_last_level_) {
get_context->SaveValue(value); get_context->SaveValue(value);

View File

@ -71,7 +71,7 @@ bool GetContext::SaveValue(const ParsedInternalKey& parsed_key,
const Slice& value) { const Slice& value) {
assert((state_ != kMerge && parsed_key.type != kTypeMerge) || assert((state_ != kMerge && parsed_key.type != kTypeMerge) ||
merge_context_ != nullptr); merge_context_ != nullptr);
if (ucmp_->Compare(parsed_key.user_key, user_key_) == 0) { if (ucmp_->Equal(parsed_key.user_key, user_key_)) {
appendToReplayLog(replay_log_, parsed_key.type, value); appendToReplayLog(replay_log_, parsed_key.type, value);
// Key matches. Process it // Key matches. Process it

View File

@ -131,8 +131,7 @@ class MergingIterator : public Iterator {
for (auto& child : children_) { for (auto& child : children_) {
if (&child != current_) { if (&child != current_) {
child.Seek(key()); child.Seek(key());
if (child.Valid() && if (child.Valid() && comparator_->Equal(key(), child.key())) {
comparator_->Compare(key(), child.key()) == 0) {
child.Next(); child.Next();
} }
} }

View File

@ -32,6 +32,10 @@ class BytewiseComparatorImpl : public Comparator {
return a.compare(b); return a.compare(b);
} }
virtual bool Equal(const Slice& a, const Slice& b) const override {
return a == b;
}
virtual void FindShortestSeparator(std::string* start, virtual void FindShortestSeparator(std::string* start,
const Slice& limit) const override { const Slice& limit) const override {
// Find length of common prefix // Find length of common prefix

View File

@ -553,7 +553,7 @@ std::string DBTestBase::AllEntriesFor(const Slice& user_key, int cf) {
if (!ParseInternalKey(iter->key(), &ikey)) { if (!ParseInternalKey(iter->key(), &ikey)) {
result += "CORRUPTED"; result += "CORRUPTED";
} else { } else {
if (last_options_.comparator->Compare(ikey.user_key, user_key) != 0) { if (!last_options_.comparator->Equal(ikey.user_key, user_key)) {
break; break;
} }
if (!first) { if (!first) {

View File

@ -299,8 +299,8 @@ void HashCuckooRep::Get(const LookupKey& key, void* callback_args,
const char* bucket = const char* bucket =
cuckoo_array_[GetHash(user_key, hid)].load(std::memory_order_acquire); cuckoo_array_[GetHash(user_key, hid)].load(std::memory_order_acquire);
if (bucket != nullptr) { if (bucket != nullptr) {
auto bucket_user_key = UserKey(bucket); Slice bucket_user_key = UserKey(bucket);
if (user_key.compare(bucket_user_key) == 0) { if (user_key == bucket_user_key) {
callback_func(callback_args, bucket); callback_func(callback_args, bucket);
break; break;
} }
@ -466,10 +466,10 @@ bool HashCuckooRep::FindCuckooPath(const char* internal_key,
} }
// again, we can perform no barrier load safely here as the current // again, we can perform no barrier load safely here as the current
// thread is the only writer. // thread is the only writer.
auto bucket_user_key = Slice bucket_user_key =
UserKey(cuckoo_array_[step.bucket_id_].load(std::memory_order_relaxed)); UserKey(cuckoo_array_[step.bucket_id_].load(std::memory_order_relaxed));
if (step.prev_step_id_ != CuckooStep::kNullStep) { if (step.prev_step_id_ != CuckooStep::kNullStep) {
if (bucket_user_key.compare(user_key) == 0) { if (bucket_user_key == user_key) {
// then there is a loop in the current path, stop discovering this path. // then there is a loop in the current path, stop discovering this path.
continue; continue;
} }

View File

@ -92,8 +92,8 @@ class BaseDeltaIterator : public Iterator {
AdvanceBase(); AdvanceBase();
} }
if (DeltaValid() && BaseValid()) { if (DeltaValid() && BaseValid()) {
if (comparator_->Compare(delta_iterator_->Entry().key, if (comparator_->Equal(delta_iterator_->Entry().key,
base_iterator_->key()) == 0) { base_iterator_->key())) {
equal_keys_ = true; equal_keys_ = true;
} }
} }
@ -127,8 +127,8 @@ class BaseDeltaIterator : public Iterator {
AdvanceBase(); AdvanceBase();
} }
if (DeltaValid() && BaseValid()) { if (DeltaValid() && BaseValid()) {
if (comparator_->Compare(delta_iterator_->Entry().key, if (comparator_->Equal(delta_iterator_->Entry().key,
base_iterator_->key()) == 0) { base_iterator_->key())) {
equal_keys_ = true; equal_keys_ = true;
} }
} }