diff --git a/db/db_test.cc b/db/db_test.cc
index fbcff5b48b..8909863431 100644
--- a/db/db_test.cc
+++ b/db/db_test.cc
@@ -1826,21 +1826,30 @@ TEST_F(DBTest, GetApproximateMemTableStats) {
   uint64_t count;
   uint64_t size;
 
+  // Because Random::GetTLSInstance() seed is reset in DBTestBase,
+  // this test is deterministic.
+
   std::string start = Key(50);
   std::string end = Key(60);
   Range r(start, end);
   db_->GetApproximateMemTableStats(r, &count, &size);
-  ASSERT_GT(count, 0);
-  ASSERT_LE(count, N);
-  ASSERT_GT(size, 6000);
-  ASSERT_LT(size, 204800);
+  // When actual count is <= 10, it returns that as the minimum
+  EXPECT_EQ(count, 10);
+  EXPECT_EQ(size, 10440);
+
+  start = Key(20);
+  end = Key(100);
+  r = Range(start, end);
+  db_->GetApproximateMemTableStats(r, &count, &size);
+  EXPECT_EQ(count, 72);
+  EXPECT_EQ(size, 75168);
 
   start = Key(500);
   end = Key(600);
   r = Range(start, end);
   db_->GetApproximateMemTableStats(r, &count, &size);
-  ASSERT_EQ(count, 0);
-  ASSERT_EQ(size, 0);
+  EXPECT_EQ(count, 0);
+  EXPECT_EQ(size, 0);
 
   ASSERT_OK(Flush());
 
@@ -1848,8 +1857,8 @@ TEST_F(DBTest, GetApproximateMemTableStats) {
   end = Key(60);
   r = Range(start, end);
   db_->GetApproximateMemTableStats(r, &count, &size);
-  ASSERT_EQ(count, 0);
-  ASSERT_EQ(size, 0);
+  EXPECT_EQ(count, 0);
+  EXPECT_EQ(size, 0);
 
   for (int i = 0; i < N; i++) {
     ASSERT_OK(Put(Key(1000 + i), rnd.RandomString(1024)));
@@ -1857,10 +1866,11 @@ TEST_F(DBTest, GetApproximateMemTableStats) {
 
   start = Key(100);
   end = Key(1020);
+  // Actually 20 keys in the range ^^
   r = Range(start, end);
   db_->GetApproximateMemTableStats(r, &count, &size);
-  ASSERT_GT(count, 20);
-  ASSERT_GT(size, 6000);
+  EXPECT_EQ(count, 20);
+  EXPECT_EQ(size, 20880);
 }
 
 TEST_F(DBTest, ApproximateSizes) {
diff --git a/db_stress_tool/db_stress_gflags.cc b/db_stress_tool/db_stress_gflags.cc
index 9f165cf977..a2632dfa3e 100644
--- a/db_stress_tool/db_stress_gflags.cc
+++ b/db_stress_tool/db_stress_gflags.cc
@@ -1031,8 +1031,9 @@ DEFINE_int32(continuous_verification_interval, 1000,
              "disables continuous verification.");
 
 DEFINE_int32(approximate_size_one_in, 64,
-             "If non-zero, DB::GetApproximateSizes() will be called against"
-             " random key ranges.");
+             "If non-zero, DB::GetApproximateSizes() and "
+             "DB::GetApproximateMemTableStats() will be called against "
+             "random key ranges.");
 
 DEFINE_int32(read_fault_one_in, 1000,
              "On non-zero, enables fault injection on read");
diff --git a/db_stress_tool/db_stress_test_base.cc b/db_stress_tool/db_stress_test_base.cc
index 2b6db414f9..40d4c589a0 100644
--- a/db_stress_tool/db_stress_test_base.cc
+++ b/db_stress_tool/db_stress_test_base.cc
@@ -2427,22 +2427,31 @@ Status StressTest::TestApproximateSize(
   std::string key1_str = Key(key1);
   std::string key2_str = Key(key2);
   Range range{Slice(key1_str), Slice(key2_str)};
-  SizeApproximationOptions sao;
-  sao.include_memtables = thread->rand.OneIn(2);
-  if (sao.include_memtables) {
-    sao.include_files = thread->rand.OneIn(2);
-  }
-  if (thread->rand.OneIn(2)) {
-    if (thread->rand.OneIn(2)) {
-      sao.files_size_error_margin = 0.0;
-    } else {
-      sao.files_size_error_margin =
-          static_cast<double>(thread->rand.Uniform(3));
+  if (thread->rand.OneIn(3)) {
+    // Call GetApproximateMemTableStats instead
+    uint64_t count, size;
+    db_->GetApproximateMemTableStats(column_families_[rand_column_families[0]],
+                                     range, &count, &size);
+    return Status::OK();
+  } else {
+    // Call GetApproximateSizes
+    SizeApproximationOptions sao;
+    sao.include_memtables = thread->rand.OneIn(2);
+    if (sao.include_memtables) {
+      sao.include_files = thread->rand.OneIn(2);
     }
+    if (thread->rand.OneIn(2)) {
+      if (thread->rand.OneIn(2)) {
+        sao.files_size_error_margin = 0.0;
+      } else {
+        sao.files_size_error_margin =
+            static_cast<double>(thread->rand.Uniform(3));
+      }
+    }
+    uint64_t result;
+    return db_->GetApproximateSizes(
+        sao, column_families_[rand_column_families[0]], &range, 1, &result);
   }
-  uint64_t result;
-  return db_->GetApproximateSizes(
-      sao, column_families_[rand_column_families[0]], &range, 1, &result);
 }
 
 Status StressTest::TestCheckpoint(ThreadState* thread,
diff --git a/memtable/inlineskiplist.h b/memtable/inlineskiplist.h
index 06ef0397a2..9fdf618fa5 100644
--- a/memtable/inlineskiplist.h
+++ b/memtable/inlineskiplist.h
@@ -141,8 +141,9 @@ class InlineSkipList {
   // Returns true iff an entry that compares equal to key is in the list.
   bool Contains(const char* key) const;
 
-  // Return estimated number of entries smaller than `key`.
-  uint64_t EstimateCount(const char* key) const;
+  // Return estimated number of entries from `start_ikey` to `end_ikey`.
+  uint64_t ApproximateNumEntries(const Slice& start_ikey,
+                                 const Slice& end_ikey) const;
 
   // Validate correctness of the skip-list.
   void TEST_Validate() const;
@@ -673,31 +674,88 @@ InlineSkipList<Comparator>::FindRandomEntry() const {
 }
 
 template <class Comparator>
-uint64_t InlineSkipList<Comparator>::EstimateCount(const char* key) const {
-  uint64_t count = 0;
+uint64_t InlineSkipList<Comparator>::ApproximateNumEntries(
+    const Slice& start_ikey, const Slice& end_ikey) const {
+  // The number of entries at a given level for the given range, in terms of
+  // the actual number of entries in that range (level 0), follows a binomial
+  // distribution, which is very well approximated by the Poisson distribution.
+  // That has stddev sqrt(x) where x is the expected number of entries (mean)
+  // at this level, and the best predictor of x is the number of observed
+  // entries (at this level). To predict the number of entries on level 0 we use
+  // x * kBranching ^ level. From the standard deviation, the P99+ relative
+  // error is roughly 3 * sqrt(x) / x. Thus, a reasonable approach would be to
+  // find the smallest level with at least some moderate constant number of
+  // entries in range. E.g. with at least ~40 entries, we expect P99+ relative
+  // error (approximation accuracy) of ~ 50% = 3 * sqrt(40) / 40; P95 error of
+  // ~30%; P75 error of < 20%.
+  //
+  // However, there are two issues with this approach, and an observation:
+  // * Pointer chasing on the larger (bottom) levels is much slower because of
+  // cache hierarchy effects, so when the result is smaller, getting the result
+  // will be substantially slower, despite traversing a similar number of
+  // entries. (We could be clever about pipelining our pointer chasing but
+  // that's complicated.)
+  // * The larger (bottom) levels also have lower variance because there's a
+  // chance (or certainty) that we reach level 0 and return the exact answer.
+  // * For applications in query planning, we can also tolerate more variance on
+  // small results because the impact of misestimating is likely smaller.
+  //
+  // These factors point us to an approach in which we have a higher minimum
+  // threshold number of samples for higher levels and lower for lower levels
+  // (see sufficient_samples below). This seems to yield roughly consistent
+  // relative error (stddev around 20%, less for large results) and roughly
+  // consistent query time around the time of two memtable point queries.
+  //
+  // Engineering observation: it is tempting to think that taking into account
+  // what we already found in how many entries occur on higher levels, not just
+  // the first iterated level with a sufficient number of samples, would yield
+  // a more accurate estimate. But that doesn't work because of the particular
+  // correlations and independences of the data: each level higher is just an
+  // independently probabilistic filtering of the level below it. That
+  // filtering from level l to l+1 has no more information about levels
+  // 0 .. l-1 than we can get from level l. The structure of RandomHeight() is
+  // a clue to these correlations and independences.
 
-  Node* x = head_;
-  int level = GetMaxHeight() - 1;
-  const DecodedKey key_decoded = compare_.decode_key(key);
-  while (true) {
-    assert(x == head_ || compare_(x->Key(), key_decoded) < 0);
-    Node* next = x->Next(level);
-    if (next != nullptr) {
-      PREFETCH(next->Next(level), 0, 1);
+  Node* lb = head_;
+  Node* ub = nullptr;
+  uint64_t count = 0;
+  for (int level = GetMaxHeight() - 1; level >= 0; level--) {
+    auto sufficient_samples = static_cast<uint64_t>(level) * kBranching_ + 10U;
+    if (count >= sufficient_samples) {
+      // No more counting; apply powers of kBranching and avoid floating point
+      count *= kBranching_;
+      continue;
     }
-    if (next == nullptr || compare_(next->Key(), key_decoded) >= 0) {
-      if (level == 0) {
-        return count;
-      } else {
-        // Switch to next list
-        count *= kBranching_;
-        level--;
+    count = 0;
+    Node* next;
+    // Get a more precise lower bound (for start key)
+    for (;;) {
+      next = lb->Next(level);
+      if (next == ub) {
+        break;
+      }
+      assert(next != nullptr);
+      if (compare_(next->Key(), start_ikey) >= 0) {
+        break;
+      }
+      lb = next;
+    }
+    // Count entries on this level until upper bound (for end key)
+    for (;;) {
+      if (next == ub) {
+        break;
+      }
+      assert(next != nullptr);
+      if (compare_(next->Key(), end_ikey) >= 0) {
+        // Save refined upper bound to potentially save key comparison
+        ub = next;
+        break;
       }
-    } else {
-      x = next;
       count++;
+      next = next->Next(level);
     }
   }
+  return count;
 }
 
 template <class Comparator>
diff --git a/memtable/skiplist.h b/memtable/skiplist.h
index e3cecd30c1..f2e2a829de 100644
--- a/memtable/skiplist.h
+++ b/memtable/skiplist.h
@@ -64,8 +64,9 @@ class SkipList {
   // Returns true iff an entry that compares equal to key is in the list.
   bool Contains(const Key& key) const;
 
-  // Return estimated number of entries smaller than `key`.
-  uint64_t EstimateCount(const Key& key) const;
+  // Return estimated number of entries from `start_ikey` to `end_ikey`.
+  uint64_t ApproximateNumEntries(const Slice& start_ikey,
+                                 const Slice& end_ikey) const;
 
   // Iteration over the contents of a skip list
   class Iterator {
@@ -383,27 +384,49 @@ typename SkipList<Key, Comparator>::Node* SkipList<Key, Comparator>::FindLast()
 }
 
 template <typename Key, class Comparator>
-uint64_t SkipList<Key, Comparator>::EstimateCount(const Key& key) const {
+uint64_t SkipList<Key, Comparator>::ApproximateNumEntries(
+    const Slice& start_ikey, const Slice& end_ikey) const {
+  // See InlineSkipList::ApproximateNumEntries() (copy-paste)
+  Node* lb = head_;
+  Node* ub = nullptr;
   uint64_t count = 0;
-
-  Node* x = head_;
-  int level = GetMaxHeight() - 1;
-  while (true) {
-    assert(x == head_ || compare_(x->key, key) < 0);
-    Node* next = x->Next(level);
-    if (next == nullptr || compare_(next->key, key) >= 0) {
-      if (level == 0) {
-        return count;
-      } else {
-        // Switch to next list
-        count *= kBranching_;
-        level--;
+  for (int level = GetMaxHeight() - 1; level >= 0; level--) {
+    auto sufficient_samples = static_cast<uint64_t>(level) * kBranching_ + 10U;
+    if (count >= sufficient_samples) {
+      // No more counting; apply powers of kBranching and avoid floating point
+      count *= kBranching_;
+      continue;
+    }
+    count = 0;
+    Node* next;
+    // Get a more precise lower bound (for start key)
+    for (;;) {
+      next = lb->Next(level);
+      if (next == ub) {
+        break;
+      }
+      assert(next != nullptr);
+      if (compare_(next->key, start_ikey) >= 0) {
+        break;
+      }
+      lb = next;
+    }
+    // Count entries on this level until upper bound (for end key)
+    for (;;) {
+      if (next == ub) {
+        break;
+      }
+      assert(next != nullptr);
+      if (compare_(next->key, end_ikey) >= 0) {
+        // Save refined upper bound to potentially save key comparison
+        ub = next;
+        break;
      }
-    } else {
-      x = next;
       count++;
+      next = next->Next(level);
     }
   }
+  return count;
 }
 
 template <typename Key, class Comparator>
diff --git a/memtable/skiplistrep.cc b/memtable/skiplistrep.cc
index 3b2f3f4d8d..73bb64d184 100644
--- a/memtable/skiplistrep.cc
+++ b/memtable/skiplistrep.cc
@@ -108,11 +108,7 @@ class SkipListRep : public MemTableRep {
 
   uint64_t ApproximateNumEntries(const Slice& start_ikey,
                                  const Slice& end_ikey) override {
-    std::string tmp;
-    uint64_t start_count =
-        skip_list_.EstimateCount(EncodeKey(&tmp, start_ikey));
-    uint64_t end_count = skip_list_.EstimateCount(EncodeKey(&tmp, end_ikey));
-    return (end_count >= start_count) ? (end_count - start_count) : 0;
+    return skip_list_.ApproximateNumEntries(start_ikey, end_ikey);
   }
 
   void UniqueRandomSample(const uint64_t num_entries,
diff --git a/tools/db_bench_tool.cc b/tools/db_bench_tool.cc
index 713aaaa412..1baa76cbd4 100644
--- a/tools/db_bench_tool.cc
+++ b/tools/db_bench_tool.cc
@@ -153,10 +153,11 @@ DEFINE_string(
     "randomtransaction,"
    "randomreplacekeys,"
    "timeseries,"
-    "getmergeoperands,",
+    "getmergeoperands,"
    "readrandomoperands,"
    "backup,"
-    "restore"
+    "restore,"
+    "approximatememtablestats",

    "Comma-separated list of operations to run in the specified"
    " order. Available benchmarks:\n"
@@ -243,9 +244,14 @@ DEFINE_string(
    "operation includes a rare but possible retry in case it got "
    "`Status::Incomplete()`. This happens upon encountering more keys than "
    "have ever been seen by the thread (or eight initially)\n"
-    "\tbackup -- Create a backup of the current DB and verify that a new backup is corrected. "
+    "\tbackup -- Create a backup of the current DB and verify that a new "
+    "backup is corrected. "
" "Rate limit can be specified through --backup_rate_limit\n" - "\trestore -- Restore the DB from the latest backup available, rate limit can be specified through --restore_rate_limit\n"); + "\trestore -- Restore the DB from the latest backup available, rate limit " + "can be specified through --restore_rate_limit\n" + "\tapproximatememtablestats -- Tests accuracy of " + "GetApproximateMemTableStats, ideally\n" + "after fillrandom, where actual answer is batch_size"); DEFINE_int64(num, 1000000, "Number of key/values to place in database"); @@ -3621,6 +3627,8 @@ class Benchmark { fprintf(stderr, "entries_per_batch = %" PRIi64 "\n", entries_per_batch_); method = &Benchmark::ApproximateSizeRandom; + } else if (name == "approximatememtablestats") { + method = &Benchmark::ApproximateMemtableStats; } else if (name == "mixgraph") { method = &Benchmark::MixGraph; } else if (name == "readmissing") { @@ -6298,6 +6306,35 @@ class Benchmark { thread->stats.AddMessage(msg); } + void ApproximateMemtableStats(ThreadState* thread) { + const size_t batch_size = entries_per_batch_; + std::unique_ptr skey_guard; + Slice skey = AllocateKey(&skey_guard); + std::unique_ptr ekey_guard; + Slice ekey = AllocateKey(&ekey_guard); + Duration duration(FLAGS_duration, reads_); + if (FLAGS_num < static_cast(batch_size)) { + std::terminate(); + } + uint64_t range = static_cast(FLAGS_num) - batch_size; + auto count_hist = std::make_shared(); + while (!duration.Done(1)) { + DB* db = SelectDB(thread); + uint64_t start_key = thread->rand.Uniform(range); + GenerateKeyFromInt(start_key, FLAGS_num, &skey); + uint64_t end_key = start_key + batch_size; + GenerateKeyFromInt(end_key, FLAGS_num, &ekey); + uint64_t count = UINT64_MAX; + uint64_t size = UINT64_MAX; + db->GetApproximateMemTableStats({skey, ekey}, &count, &size); + count_hist->Add(count); + thread->stats.FinishedOps(nullptr, db, 1, kOthers); + } + thread->stats.AddMessage("\nReported entry count stats (expected " + + std::to_string(batch_size) + "):"); + thread->stats.AddMessage("\n" + count_hist->ToString()); + } + // Calls ApproximateSize over random key ranges. void ApproximateSizeRandom(ThreadState* thread) { int64_t size_sum = 0; diff --git a/unreleased_history/bug_fixes/memtable_stats.md b/unreleased_history/bug_fixes/memtable_stats.md new file mode 100644 index 0000000000..047dfbc3d0 --- /dev/null +++ b/unreleased_history/bug_fixes/memtable_stats.md @@ -0,0 +1 @@ +* `GetApproximateMemTableStats()` could return disastrously bad estimates 5-25% of the time. The function has been re-engineered to return much better estimates with similar CPU cost.