From 97e70906fab2bdc2878f5e05864165526dca79fe Mon Sep 17 00:00:00 2001 From: Levi Tamasi Date: Thu, 9 May 2024 12:25:19 -0700 Subject: [PATCH] Improve the sanity checks in (Multi)GetEntity and friends (#12630) Summary: Pull Request resolved: https://github.com/facebook/rocksdb/pull/12630 The patch cleans up, improves, and brings into sync (to the extent possible without API signature changes) the sanity checks around the `GetEntity` / `MultiGetEntity` family of APIs, including the read-your-own-writes (`WriteBatchWithIndex`) and transaction layers. The checks are centralized in two main sets of entry points, namely in `DB(Impl)` and the "main" `GetEntityFromBatchAndDB` / `MultiGetEntityFromBatchAndDB` overloads in `WriteBatchWithIndex`. This eliminates the need to duplicate the checks in the transaction classes. Reviewed By: jaykorean Differential Revision: D57125741 fbshipit-source-id: 4dd059ef644a9b173fbba767538943397e4cc6cd --- db/db_impl/db_impl.cc | 125 ++++++++++-- db/db_impl/db_impl.h | 5 + db/wide/db_wide_basic_test.cc | 138 +++++++++++++ .../utilities/write_batch_with_index.h | 15 +- .../optimistic_transaction_test.cc | 31 +++ utilities/transactions/transaction_base.cc | 14 +- utilities/transactions/transaction_test.cc | 31 +++ .../write_batch_with_index.cc | 181 ++++++++++-------- .../write_batch_with_index_test.cc | 122 ++++++++++++ 9 files changed, 545 insertions(+), 117 deletions(-) diff --git a/db/db_impl/db_impl.cc b/db/db_impl/db_impl.cc index 12b0afc2bb..4e28454d63 100644 --- a/db/db_impl/db_impl.cc +++ b/db/db_impl/db_impl.cc @@ -2099,7 +2099,7 @@ Status DBImpl::GetEntity(const ReadOptions& _read_options, if (_read_options.io_activity != Env::IOActivity::kUnknown && _read_options.io_activity != Env::IOActivity::kGetEntity) { return Status::InvalidArgument( - "Cannot call GetEntity with `ReadOptions::io_activity` != " + "Can only call GetEntity with `ReadOptions::io_activity` set to " "`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`"); } ReadOptions read_options(_read_options); @@ -2126,7 +2126,7 @@ Status DBImpl::GetEntity(const ReadOptions& _read_options, const Slice& key, if (_read_options.io_activity != Env::IOActivity::kUnknown && _read_options.io_activity != Env::IOActivity::kGetEntity) { s = Status::InvalidArgument( - "Cannot call GetEntity with `ReadOptions::io_activity` != " + "Can only call GetEntity with `ReadOptions::io_activity` set to " "`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`"); for (size_t i = 0; i < num_column_families; ++i) { (*result)[i].SetStatus(s); @@ -3185,22 +3185,55 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys, ColumnFamilyHandle** column_families, const Slice* keys, PinnableWideColumns* results, Status* statuses, bool sorted_input) { - if (_read_options.io_activity != Env::IOActivity::kUnknown && - _read_options.io_activity != Env::IOActivity::kMultiGetEntity) { - Status s = Status::InvalidArgument( - "Can only call MultiGetEntity with `ReadOptions::io_activity` is " - "`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`"); + assert(statuses); + + if (!column_families) { + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntity without column families"); for (size_t i = 0; i < num_keys; ++i) { - if (statuses[i].ok()) { - statuses[i] = s; - } + statuses[i] = s; } + return; } + + if (!keys) { + const Status s = + Status::InvalidArgument("Cannot call MultiGetEntity without keys"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + + return; + } + + if (!results) { + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntity without PinnableWideColumns objects"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + + return; + } + + if (_read_options.io_activity != Env::IOActivity::kUnknown && + _read_options.io_activity != Env::IOActivity::kMultiGetEntity) { + const Status s = Status::InvalidArgument( + "Can only call MultiGetEntity with `ReadOptions::io_activity` set to " + "`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + + return; + } + ReadOptions read_options(_read_options); if (read_options.io_activity == Env::IOActivity::kUnknown) { read_options.io_activity = Env::IOActivity::kMultiGetEntity; } + MultiGetCommon(read_options, num_keys, column_families, keys, /* values */ nullptr, results, /* timestamps */ nullptr, statuses, sorted_input); @@ -3210,22 +3243,54 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, ColumnFamilyHandle* column_family, size_t num_keys, const Slice* keys, PinnableWideColumns* results, Status* statuses, bool sorted_input) { + assert(statuses); + + if (!column_family) { + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntity without a column family handle"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + + return; + } + + if (!keys) { + const Status s = + Status::InvalidArgument("Cannot call MultiGetEntity without keys"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + + return; + } + + if (!results) { + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntity without PinnableWideColumns objects"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + + return; + } + if (_read_options.io_activity != Env::IOActivity::kUnknown && _read_options.io_activity != Env::IOActivity::kMultiGetEntity) { - Status s = Status::InvalidArgument( - "Can only call MultiGetEntity with `ReadOptions::io_activity` is " + const Status s = Status::InvalidArgument( + "Can only call MultiGetEntity with `ReadOptions::io_activity` set to " "`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`"); for (size_t i = 0; i < num_keys; ++i) { - if (statuses[i].ok()) { - statuses[i] = s; - } + statuses[i] = s; } return; } + ReadOptions read_options(_read_options); if (read_options.io_activity == Env::IOActivity::kUnknown) { read_options.io_activity = Env::IOActivity::kMultiGetEntity; } + MultiGetCommon(read_options, column_family, num_keys, keys, /* values */ nullptr, results, /* timestamps */ nullptr, statuses, sorted_input); @@ -3234,18 +3299,34 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys, const Slice* keys, PinnableAttributeGroups* results) { + assert(results); + + if (!keys) { + const Status s = + Status::InvalidArgument("Cannot call MultiGetEntity without keys"); + for (size_t i = 0; i < num_keys; ++i) { + for (size_t j = 0; j < results[i].size(); ++j) { + results[i][j].SetStatus(s); + } + } + + return; + } + if (_read_options.io_activity != Env::IOActivity::kUnknown && _read_options.io_activity != Env::IOActivity::kMultiGetEntity) { - Status s = Status::InvalidArgument( - "Can only call MultiGetEntity with ReadOptions::io_activity` is " + const Status s = Status::InvalidArgument( + "Can only call MultiGetEntity with `ReadOptions::io_activity` set to " "`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`"); for (size_t i = 0; i < num_keys; ++i) { for (size_t j = 0; j < results[i].size(); ++j) { results[i][j].SetStatus(s); } } + return; } + ReadOptions read_options(_read_options); if (read_options.io_activity == Env::IOActivity::kUnknown) { read_options.io_activity = Env::IOActivity::kMultiGetEntity; @@ -3263,6 +3344,7 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys, ++total_count; } } + std::vector statuses(total_count); std::vector columns(total_count); MultiGetCommon(read_options, total_count, column_families.data(), @@ -3283,6 +3365,15 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys, } } +void DBImpl::MultiGetEntityWithCallback( + const ReadOptions& read_options, ColumnFamilyHandle* column_family, + ReadCallback* callback, + autovector* sorted_keys) { + assert(read_options.io_activity == Env::IOActivity::kMultiGetEntity); + + MultiGetWithCallbackImpl(read_options, column_family, callback, sorted_keys); +} + Status DBImpl::WrapUpCreateColumnFamilies( const ReadOptions& read_options, const WriteOptions& write_options, const std::vector& cf_options) { diff --git a/db/db_impl/db_impl.h b/db/db_impl/db_impl.h index f4a95b52b8..a7f0ec0d16 100644 --- a/db/db_impl/db_impl.h +++ b/db/db_impl/db_impl.h @@ -291,6 +291,11 @@ class DBImpl : public DB { const Slice* keys, PinnableAttributeGroups* results) override; + void MultiGetEntityWithCallback( + const ReadOptions& read_options, ColumnFamilyHandle* column_family, + ReadCallback* callback, + autovector* sorted_keys); + Status CreateColumnFamily(const ColumnFamilyOptions& cf_options, const std::string& column_family, ColumnFamilyHandle** handle) override { diff --git a/db/wide/db_wide_basic_test.cc b/db/wide/db_wide_basic_test.cc index 6f3bc9be7e..886f71d745 100644 --- a/db/wide/db_wide_basic_test.cc +++ b/db/wide/db_wide_basic_test.cc @@ -1794,6 +1794,144 @@ TEST_F(DBWideBasicTest, PinnableWideColumnsMove) { test_move(/* fill_cache*/ true); } +TEST_F(DBWideBasicTest, SanityChecks) { + constexpr char foo[] = "foo"; + constexpr char bar[] = "bar"; + constexpr size_t num_keys = 2; + + { + constexpr ColumnFamilyHandle* column_family = nullptr; + PinnableWideColumns columns; + ASSERT_TRUE(db_->GetEntity(ReadOptions(), column_family, foo, &columns) + .IsInvalidArgument()); + } + + { + constexpr PinnableWideColumns* columns = nullptr; + ASSERT_TRUE( + db_->GetEntity(ReadOptions(), db_->DefaultColumnFamily(), foo, columns) + .IsInvalidArgument()); + } + + { + ReadOptions read_options; + read_options.io_activity = Env::IOActivity::kGet; + + PinnableWideColumns columns; + ASSERT_TRUE( + db_->GetEntity(read_options, db_->DefaultColumnFamily(), foo, &columns) + .IsInvalidArgument()); + } + + { + constexpr ColumnFamilyHandle* column_family = nullptr; + std::array keys{{foo, bar}}; + std::array results; + std::array statuses; + + db_->MultiGetEntity(ReadOptions(), column_family, num_keys, keys.data(), + results.data(), statuses.data()); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + constexpr Slice* keys = nullptr; + std::array results; + std::array statuses; + + db_->MultiGetEntity(ReadOptions(), db_->DefaultColumnFamily(), num_keys, + keys, results.data(), statuses.data()); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + std::array keys{{foo, bar}}; + constexpr PinnableWideColumns* results = nullptr; + std::array statuses; + + db_->MultiGetEntity(ReadOptions(), db_->DefaultColumnFamily(), num_keys, + keys.data(), results, statuses.data()); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + ReadOptions read_options; + read_options.io_activity = Env::IOActivity::kMultiGet; + + std::array keys{{foo, bar}}; + std::array results; + std::array statuses; + + db_->MultiGetEntity(read_options, db_->DefaultColumnFamily(), num_keys, + keys.data(), results.data(), statuses.data()); + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + constexpr ColumnFamilyHandle** column_families = nullptr; + std::array keys{{foo, bar}}; + std::array results; + std::array statuses; + + db_->MultiGetEntity(ReadOptions(), num_keys, column_families, keys.data(), + results.data(), statuses.data()); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + std::array column_families{ + {db_->DefaultColumnFamily(), db_->DefaultColumnFamily()}}; + constexpr Slice* keys = nullptr; + std::array results; + std::array statuses; + + db_->MultiGetEntity(ReadOptions(), num_keys, column_families.data(), keys, + results.data(), statuses.data()); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + std::array column_families{ + {db_->DefaultColumnFamily(), db_->DefaultColumnFamily()}}; + std::array keys{{foo, bar}}; + constexpr PinnableWideColumns* results = nullptr; + std::array statuses; + + db_->MultiGetEntity(ReadOptions(), num_keys, column_families.data(), + keys.data(), results, statuses.data()); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + ReadOptions read_options; + read_options.io_activity = Env::IOActivity::kMultiGet; + + std::array column_families{ + {db_->DefaultColumnFamily(), db_->DefaultColumnFamily()}}; + std::array keys{{foo, bar}}; + std::array results; + std::array statuses; + + db_->MultiGetEntity(read_options, num_keys, column_families.data(), + keys.data(), results.data(), statuses.data()); + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } +} + } // namespace ROCKSDB_NAMESPACE int main(int argc, char** argv) { diff --git a/include/rocksdb/utilities/write_batch_with_index.h b/include/rocksdb/utilities/write_batch_with_index.h index 251654fe25..ad66236478 100644 --- a/include/rocksdb/utilities/write_batch_with_index.h +++ b/include/rocksdb/utilities/write_batch_with_index.h @@ -284,7 +284,12 @@ class WriteBatchWithIndex : public WriteBatchBase { Status GetEntityFromBatchAndDB(DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, const Slice& key, - PinnableWideColumns* columns); + PinnableWideColumns* columns) { + constexpr ReadCallback* callback = nullptr; + + return GetEntityFromBatchAndDB(db, read_options, column_family, key, + columns, callback); + } void MultiGetFromBatchAndDB(DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, @@ -310,7 +315,13 @@ class WriteBatchWithIndex : public WriteBatchBase { ColumnFamilyHandle* column_family, size_t num_keys, const Slice* keys, PinnableWideColumns* results, - Status* statuses, bool sorted_input); + Status* statuses, bool sorted_input) { + constexpr ReadCallback* callback = nullptr; + + MultiGetEntityFromBatchAndDB(db, read_options, column_family, num_keys, + keys, results, statuses, sorted_input, + callback); + } // Records the state of the batch for future calls to RollbackToSavePoint(). // May be called multiple times to set multiple save points. diff --git a/utilities/transactions/optimistic_transaction_test.cc b/utilities/transactions/optimistic_transaction_test.cc index eb522d73db..c75321323d 100644 --- a/utilities/transactions/optimistic_transaction_test.cc +++ b/utilities/transactions/optimistic_transaction_test.cc @@ -1928,6 +1928,37 @@ TEST_P(OptimisticTransactionTest, PutEntityWriteConflictTxnTxn) { } } +TEST_P(OptimisticTransactionTest, EntityReadSanityChecks) { + constexpr char foo[] = "foo"; + + std::unique_ptr txn(txn_db->BeginTransaction(WriteOptions())); + ASSERT_NE(txn, nullptr); + + { + constexpr ColumnFamilyHandle* column_family = nullptr; + PinnableWideColumns columns; + ASSERT_TRUE(txn->GetEntity(ReadOptions(), column_family, foo, &columns) + .IsInvalidArgument()); + } + + { + constexpr PinnableWideColumns* columns = nullptr; + ASSERT_TRUE(txn->GetEntity(ReadOptions(), txn_db->DefaultColumnFamily(), + foo, columns) + .IsInvalidArgument()); + } + + { + ReadOptions read_options; + read_options.io_activity = Env::IOActivity::kGet; + + PinnableWideColumns columns; + ASSERT_TRUE(txn->GetEntity(read_options, txn_db->DefaultColumnFamily(), foo, + &columns) + .IsInvalidArgument()); + } +} + INSTANTIATE_TEST_CASE_P( InstanceOccGroup, OptimisticTransactionTest, testing::Values(OccValidationPolicy::kValidateSerial, diff --git a/utilities/transactions/transaction_base.cc b/utilities/transactions/transaction_base.cc index 51031f6364..046f848f62 100644 --- a/utilities/transactions/transaction_base.cc +++ b/utilities/transactions/transaction_base.cc @@ -289,22 +289,10 @@ Status TransactionBaseImpl::GetImpl(const ReadOptions& read_options, pinnable_val); } -Status TransactionBaseImpl::GetEntity(const ReadOptions& _read_options, +Status TransactionBaseImpl::GetEntity(const ReadOptions& read_options, ColumnFamilyHandle* column_family, const Slice& key, PinnableWideColumns* columns) { - if (_read_options.io_activity != Env::IOActivity::kUnknown && - _read_options.io_activity != Env::IOActivity::kGetEntity) { - return Status::InvalidArgument( - "Can only call GetEntity with `ReadOptions::io_activity` set to " - "`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`"); - } - - ReadOptions read_options(_read_options); - if (read_options.io_activity == Env::IOActivity::kUnknown) { - read_options.io_activity = Env::IOActivity::kGetEntity; - } - return GetEntityImpl(read_options, column_family, key, columns); } diff --git a/utilities/transactions/transaction_test.cc b/utilities/transactions/transaction_test.cc index 36d2715fb4..7dbdb788e3 100644 --- a/utilities/transactions/transaction_test.cc +++ b/utilities/transactions/transaction_test.cc @@ -7150,6 +7150,37 @@ TEST_P(TransactionTest, PutEntityWriteConflict) { } } +TEST_P(TransactionTest, EntityReadSanityChecks) { + constexpr char foo[] = "foo"; + + std::unique_ptr txn(db->BeginTransaction(WriteOptions())); + ASSERT_NE(txn, nullptr); + + { + constexpr ColumnFamilyHandle* column_family = nullptr; + PinnableWideColumns columns; + ASSERT_TRUE(txn->GetEntity(ReadOptions(), column_family, foo, &columns) + .IsInvalidArgument()); + } + + { + constexpr PinnableWideColumns* columns = nullptr; + ASSERT_TRUE( + txn->GetEntity(ReadOptions(), db->DefaultColumnFamily(), foo, columns) + .IsInvalidArgument()); + } + + { + ReadOptions read_options; + read_options.io_activity = Env::IOActivity::kGet; + + PinnableWideColumns columns; + ASSERT_TRUE( + txn->GetEntity(read_options, db->DefaultColumnFamily(), foo, &columns) + .IsInvalidArgument()); + } +} + TEST_F(TransactionDBTest, CollapseKey) { ASSERT_OK(ReOpen()); ASSERT_OK(db->Put({}, "hello", "world")); diff --git a/utilities/write_batch_with_index/write_batch_with_index.cc b/utilities/write_batch_with_index/write_batch_with_index.cc index d2e1816d4e..834c2fa491 100644 --- a/utilities/write_batch_with_index/write_batch_with_index.cc +++ b/utilities/write_batch_with_index/write_batch_with_index.cc @@ -823,11 +823,41 @@ void WriteBatchWithIndex::MultiGetFromBatchAndDB( } Status WriteBatchWithIndex::GetEntityFromBatchAndDB( - DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, + DB* db, const ReadOptions& _read_options, ColumnFamilyHandle* column_family, const Slice& key, PinnableWideColumns* columns, ReadCallback* callback) { - assert(db); - assert(column_family); - assert(columns); + if (!db) { + return Status::InvalidArgument( + "Cannot call GetEntityFromBatchAndDB without a DB object"); + } + + if (_read_options.io_activity != Env::IOActivity::kUnknown && + _read_options.io_activity != Env::IOActivity::kGetEntity) { + return Status::InvalidArgument( + "Can only call GetEntityFromBatchAndDB with `ReadOptions::io_activity` " + "set to `Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`"); + } + + ReadOptions read_options(_read_options); + if (read_options.io_activity == Env::IOActivity::kUnknown) { + read_options.io_activity = Env::IOActivity::kGetEntity; + } + + if (!column_family) { + return Status::InvalidArgument( + "Cannot call GetEntityFromBatchAndDB without a column family handle"); + } + + const Comparator* const ucmp = rep->comparator.GetComparator(column_family); + size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0; + if (ts_sz > 0 && !read_options.timestamp) { + return Status::InvalidArgument("Must specify timestamp"); + } + + if (!columns) { + return Status::InvalidArgument( + "Cannot call GetEntityFromBatchAndDB without a PinnableWideColumns " + "object"); + } columns->Reset(); @@ -872,46 +902,78 @@ Status WriteBatchWithIndex::GetEntityFromBatchAndDB( return s; } -Status WriteBatchWithIndex::GetEntityFromBatchAndDB( - DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, - const Slice& key, PinnableWideColumns* columns) { +void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB( + DB* db, const ReadOptions& _read_options, ColumnFamilyHandle* column_family, + size_t num_keys, const Slice* keys, PinnableWideColumns* results, + Status* statuses, bool sorted_input, ReadCallback* callback) { + assert(statuses); + if (!db) { - return Status::InvalidArgument( - "Cannot call GetEntityFromBatchAndDB without a DB object"); + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntityFromBatchAndDB without a DB object"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + return; + } + + if (_read_options.io_activity != Env::IOActivity::kUnknown && + _read_options.io_activity != Env::IOActivity::kMultiGetEntity) { + const Status s = Status::InvalidArgument( + "Can only call MultiGetEntityFromBatchAndDB with " + "`ReadOptions::io_activity` set to `Env::IOActivity::kUnknown` or " + "`Env::IOActivity::kMultiGetEntity`"); + for (size_t i = 0; i < num_keys; ++i) { + if (statuses[i].ok()) { + statuses[i] = s; + } + } + return; + } + + ReadOptions read_options(_read_options); + if (read_options.io_activity == Env::IOActivity::kUnknown) { + read_options.io_activity = Env::IOActivity::kMultiGetEntity; } if (!column_family) { - return Status::InvalidArgument( - "Cannot call GetEntityFromBatchAndDB without a column family handle"); + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntityFromBatchAndDB without a column family " + "handle"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + return; } const Comparator* const ucmp = rep->comparator.GetComparator(column_family); - size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0; + const size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0; if (ts_sz > 0 && !read_options.timestamp) { - return Status::InvalidArgument("Must specify timestamp"); + const Status s = Status::InvalidArgument("Must specify timestamp"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + return; } - if (!columns) { - return Status::InvalidArgument( - "Cannot call GetEntityFromBatchAndDB without a PinnableWideColumns " - "object"); + if (!keys) { + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntityFromBatchAndDB without keys"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + return; } - constexpr ReadCallback* callback = nullptr; - - return GetEntityFromBatchAndDB(db, read_options, column_family, key, columns, - callback); -} - -void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB( - DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, - size_t num_keys, const Slice* keys, PinnableWideColumns* results, - Status* statuses, bool sorted_input, ReadCallback* callback) { - assert(db); - assert(column_family); - assert(keys); - assert(results); - assert(statuses); + if (!results) { + const Status s = Status::InvalidArgument( + "Cannot call MultiGetEntityFromBatchAndDB without " + "PinnableWideColumns objects"); + for (size_t i = 0; i < num_keys; ++i) { + statuses[i] = s; + } + return; + } struct MergeTuple { MergeTuple(const Slice& _key, Status* _s, MergeContext&& _merge_context, @@ -990,8 +1052,8 @@ void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB( static_cast_with_check(db->GetRootDB()) ->PrepareMultiGetKeys(sorted_keys.size(), sorted_input, &sorted_keys); static_cast_with_check(db->GetRootDB()) - ->MultiGetWithCallback(read_options, column_family, callback, - &sorted_keys); + ->MultiGetEntityWithCallback(read_options, column_family, callback, + &sorted_keys); for (const auto& merge : merges) { if (merge.s->ok() || merge.s->IsNotFound()) { // DB lookup succeeded @@ -1001,57 +1063,6 @@ void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB( } } -void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB( - DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, - size_t num_keys, const Slice* keys, PinnableWideColumns* results, - Status* statuses, bool sorted_input) { - assert(statuses); - - if (!db) { - for (size_t i = 0; i < num_keys; ++i) { - statuses[i] = Status::InvalidArgument( - "Cannot call MultiGetEntityFromBatchAndDB without a DB object"); - } - } - - if (!column_family) { - for (size_t i = 0; i < num_keys; ++i) { - statuses[i] = Status::InvalidArgument( - "Cannot call MultiGetEntityFromBatchAndDB without a column family " - "handle"); - } - } - - const Comparator* const ucmp = rep->comparator.GetComparator(column_family); - const size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0; - if (ts_sz > 0 && !read_options.timestamp) { - for (size_t i = 0; i < num_keys; ++i) { - statuses[i] = Status::InvalidArgument("Must specify timestamp"); - } - return; - } - - if (!keys) { - for (size_t i = 0; i < num_keys; ++i) { - statuses[i] = Status::InvalidArgument( - "Cannot call MultiGetEntityFromBatchAndDB without keys"); - } - } - - if (!results) { - for (size_t i = 0; i < num_keys; ++i) { - statuses[i] = Status::InvalidArgument( - "Cannot call MultiGetEntityFromBatchAndDB without " - "PinnableWideColumns objects"); - } - } - - constexpr ReadCallback* callback = nullptr; - - MultiGetEntityFromBatchAndDB(db, read_options, column_family, num_keys, keys, - results, statuses, sorted_input, callback); -} - void WriteBatchWithIndex::SetSavePoint() { rep->write_batch.SetSavePoint(); } Status WriteBatchWithIndex::RollbackToSavePoint() { diff --git a/utilities/write_batch_with_index/write_batch_with_index_test.cc b/utilities/write_batch_with_index/write_batch_with_index_test.cc index bf72303313..d706682a5f 100644 --- a/utilities/write_batch_with_index/write_batch_with_index_test.cc +++ b/utilities/write_batch_with_index/write_batch_with_index_test.cc @@ -2973,6 +2973,128 @@ TEST_P(WriteBatchWithIndexTest, GetEntityFromBatch) { } } +TEST_P(WriteBatchWithIndexTest, EntityReadSanityChecks) { + ASSERT_OK(OpenDB()); + + constexpr char foo[] = "foo"; + constexpr char bar[] = "bar"; + constexpr size_t num_keys = 2; + + { + constexpr DB* db = nullptr; + PinnableWideColumns columns; + ASSERT_TRUE(batch_ + ->GetEntityFromBatchAndDB(db, ReadOptions(), + db_->DefaultColumnFamily(), foo, + &columns) + .IsInvalidArgument()); + } + + { + constexpr ColumnFamilyHandle* column_family = nullptr; + PinnableWideColumns columns; + ASSERT_TRUE(batch_ + ->GetEntityFromBatchAndDB(db_, ReadOptions(), column_family, + foo, &columns) + .IsInvalidArgument()); + } + + { + constexpr PinnableWideColumns* columns = nullptr; + ASSERT_TRUE(batch_ + ->GetEntityFromBatchAndDB(db_, ReadOptions(), + db_->DefaultColumnFamily(), foo, + columns) + .IsInvalidArgument()); + } + + { + ReadOptions read_options; + read_options.io_activity = Env::IOActivity::kGet; + + PinnableWideColumns columns; + ASSERT_TRUE(batch_ + ->GetEntityFromBatchAndDB(db_, read_options, + db_->DefaultColumnFamily(), foo, + &columns) + .IsInvalidArgument()); + } + + { + constexpr DB* db = nullptr; + std::array keys{{foo, bar}}; + std::array results; + std::array statuses; + constexpr bool sorted_input = false; + + batch_->MultiGetEntityFromBatchAndDB( + db, ReadOptions(), db_->DefaultColumnFamily(), num_keys, keys.data(), + results.data(), statuses.data(), sorted_input); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + constexpr ColumnFamilyHandle* column_family = nullptr; + std::array keys{{foo, bar}}; + std::array results; + std::array statuses; + constexpr bool sorted_input = false; + + batch_->MultiGetEntityFromBatchAndDB(db_, ReadOptions(), column_family, + num_keys, keys.data(), results.data(), + statuses.data(), sorted_input); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + constexpr Slice* keys = nullptr; + std::array results; + std::array statuses; + constexpr bool sorted_input = false; + + batch_->MultiGetEntityFromBatchAndDB( + db_, ReadOptions(), db_->DefaultColumnFamily(), num_keys, keys, + results.data(), statuses.data(), sorted_input); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + std::array keys{{foo, bar}}; + constexpr PinnableWideColumns* results = nullptr; + std::array statuses; + constexpr bool sorted_input = false; + + batch_->MultiGetEntityFromBatchAndDB( + db_, ReadOptions(), db_->DefaultColumnFamily(), num_keys, keys.data(), + results, statuses.data(), sorted_input); + + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } + + { + ReadOptions read_options; + read_options.io_activity = Env::IOActivity::kMultiGet; + + std::array keys{{foo, bar}}; + std::array results; + std::array statuses; + constexpr bool sorted_input = false; + + batch_->MultiGetEntityFromBatchAndDB( + db_, read_options, db_->DefaultColumnFamily(), num_keys, keys.data(), + results.data(), statuses.data(), sorted_input); + ASSERT_TRUE(statuses[0].IsInvalidArgument()); + ASSERT_TRUE(statuses[1].IsInvalidArgument()); + } +} + INSTANTIATE_TEST_CASE_P(WBWI, WriteBatchWithIndexTest, testing::Bool()); } // namespace ROCKSDB_NAMESPACE