Improve the sanity checks in (Multi)GetEntity and friends (#12630)

Summary:
Pull Request resolved: https://github.com/facebook/rocksdb/pull/12630

The patch cleans up, improves, and brings into sync (to the extent possible without API signature changes) the sanity checks around the `GetEntity` / `MultiGetEntity` family of APIs, including the read-your-own-writes (`WriteBatchWithIndex`) and transaction layers. The checks are centralized in two main sets of entry points, namely in `DB(Impl)` and the "main" `GetEntityFromBatchAndDB` / `MultiGetEntityFromBatchAndDB` overloads in `WriteBatchWithIndex`. This eliminates the need to duplicate the checks in the transaction classes.

Reviewed By: jaykorean

Differential Revision: D57125741

fbshipit-source-id: 4dd059ef644a9b173fbba767538943397e4cc6cd
This commit is contained in:
Levi Tamasi 2024-05-09 12:25:19 -07:00 committed by Facebook GitHub Bot
parent 1a3357648f
commit 97e70906fa
9 changed files with 545 additions and 117 deletions

View File

@ -2099,7 +2099,7 @@ Status DBImpl::GetEntity(const ReadOptions& _read_options,
if (_read_options.io_activity != Env::IOActivity::kUnknown && if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kGetEntity) { _read_options.io_activity != Env::IOActivity::kGetEntity) {
return Status::InvalidArgument( return Status::InvalidArgument(
"Cannot call GetEntity with `ReadOptions::io_activity` != " "Can only call GetEntity with `ReadOptions::io_activity` set to "
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`"); "`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`");
} }
ReadOptions read_options(_read_options); ReadOptions read_options(_read_options);
@ -2126,7 +2126,7 @@ Status DBImpl::GetEntity(const ReadOptions& _read_options, const Slice& key,
if (_read_options.io_activity != Env::IOActivity::kUnknown && if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kGetEntity) { _read_options.io_activity != Env::IOActivity::kGetEntity) {
s = Status::InvalidArgument( s = Status::InvalidArgument(
"Cannot call GetEntity with `ReadOptions::io_activity` != " "Can only call GetEntity with `ReadOptions::io_activity` set to "
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`"); "`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`");
for (size_t i = 0; i < num_column_families; ++i) { for (size_t i = 0; i < num_column_families; ++i) {
(*result)[i].SetStatus(s); (*result)[i].SetStatus(s);
@ -3185,22 +3185,55 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys,
ColumnFamilyHandle** column_families, ColumnFamilyHandle** column_families,
const Slice* keys, PinnableWideColumns* results, const Slice* keys, PinnableWideColumns* results,
Status* statuses, bool sorted_input) { Status* statuses, bool sorted_input) {
if (_read_options.io_activity != Env::IOActivity::kUnknown && assert(statuses);
_read_options.io_activity != Env::IOActivity::kMultiGetEntity) {
Status s = Status::InvalidArgument( if (!column_families) {
"Can only call MultiGetEntity with `ReadOptions::io_activity` is " const Status s = Status::InvalidArgument(
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`"); "Cannot call MultiGetEntity without column families");
for (size_t i = 0; i < num_keys; ++i) { for (size_t i = 0; i < num_keys; ++i) {
if (statuses[i].ok()) {
statuses[i] = s; statuses[i] = s;
} }
}
return; return;
} }
if (!keys) {
const Status s =
Status::InvalidArgument("Cannot call MultiGetEntity without keys");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
}
if (!results) {
const Status s = Status::InvalidArgument(
"Cannot call MultiGetEntity without PinnableWideColumns objects");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
}
if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kMultiGetEntity) {
const Status s = Status::InvalidArgument(
"Can only call MultiGetEntity with `ReadOptions::io_activity` set to "
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
}
ReadOptions read_options(_read_options); ReadOptions read_options(_read_options);
if (read_options.io_activity == Env::IOActivity::kUnknown) { if (read_options.io_activity == Env::IOActivity::kUnknown) {
read_options.io_activity = Env::IOActivity::kMultiGetEntity; read_options.io_activity = Env::IOActivity::kMultiGetEntity;
} }
MultiGetCommon(read_options, num_keys, column_families, keys, MultiGetCommon(read_options, num_keys, column_families, keys,
/* values */ nullptr, results, /* timestamps */ nullptr, /* values */ nullptr, results, /* timestamps */ nullptr,
statuses, sorted_input); statuses, sorted_input);
@ -3210,22 +3243,54 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options,
ColumnFamilyHandle* column_family, size_t num_keys, ColumnFamilyHandle* column_family, size_t num_keys,
const Slice* keys, PinnableWideColumns* results, const Slice* keys, PinnableWideColumns* results,
Status* statuses, bool sorted_input) { Status* statuses, bool sorted_input) {
if (_read_options.io_activity != Env::IOActivity::kUnknown && assert(statuses);
_read_options.io_activity != Env::IOActivity::kMultiGetEntity) {
Status s = Status::InvalidArgument( if (!column_family) {
"Can only call MultiGetEntity with `ReadOptions::io_activity` is " const Status s = Status::InvalidArgument(
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`"); "Cannot call MultiGetEntity without a column family handle");
for (size_t i = 0; i < num_keys; ++i) { for (size_t i = 0; i < num_keys; ++i) {
if (statuses[i].ok()) {
statuses[i] = s; statuses[i] = s;
} }
return;
}
if (!keys) {
const Status s =
Status::InvalidArgument("Cannot call MultiGetEntity without keys");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
}
if (!results) {
const Status s = Status::InvalidArgument(
"Cannot call MultiGetEntity without PinnableWideColumns objects");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
}
if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kMultiGetEntity) {
const Status s = Status::InvalidArgument(
"Can only call MultiGetEntity with `ReadOptions::io_activity` set to "
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
} }
return; return;
} }
ReadOptions read_options(_read_options); ReadOptions read_options(_read_options);
if (read_options.io_activity == Env::IOActivity::kUnknown) { if (read_options.io_activity == Env::IOActivity::kUnknown) {
read_options.io_activity = Env::IOActivity::kMultiGetEntity; read_options.io_activity = Env::IOActivity::kMultiGetEntity;
} }
MultiGetCommon(read_options, column_family, num_keys, keys, MultiGetCommon(read_options, column_family, num_keys, keys,
/* values */ nullptr, results, /* timestamps */ nullptr, /* values */ nullptr, results, /* timestamps */ nullptr,
statuses, sorted_input); statuses, sorted_input);
@ -3234,18 +3299,34 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options,
void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys, void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys,
const Slice* keys, const Slice* keys,
PinnableAttributeGroups* results) { PinnableAttributeGroups* results) {
assert(results);
if (!keys) {
const Status s =
Status::InvalidArgument("Cannot call MultiGetEntity without keys");
for (size_t i = 0; i < num_keys; ++i) {
for (size_t j = 0; j < results[i].size(); ++j) {
results[i][j].SetStatus(s);
}
}
return;
}
if (_read_options.io_activity != Env::IOActivity::kUnknown && if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kMultiGetEntity) { _read_options.io_activity != Env::IOActivity::kMultiGetEntity) {
Status s = Status::InvalidArgument( const Status s = Status::InvalidArgument(
"Can only call MultiGetEntity with ReadOptions::io_activity` is " "Can only call MultiGetEntity with `ReadOptions::io_activity` set to "
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`"); "`Env::IOActivity::kUnknown` or `Env::IOActivity::kMultiGetEntity`");
for (size_t i = 0; i < num_keys; ++i) { for (size_t i = 0; i < num_keys; ++i) {
for (size_t j = 0; j < results[i].size(); ++j) { for (size_t j = 0; j < results[i].size(); ++j) {
results[i][j].SetStatus(s); results[i][j].SetStatus(s);
} }
} }
return; return;
} }
ReadOptions read_options(_read_options); ReadOptions read_options(_read_options);
if (read_options.io_activity == Env::IOActivity::kUnknown) { if (read_options.io_activity == Env::IOActivity::kUnknown) {
read_options.io_activity = Env::IOActivity::kMultiGetEntity; read_options.io_activity = Env::IOActivity::kMultiGetEntity;
@ -3263,6 +3344,7 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys,
++total_count; ++total_count;
} }
} }
std::vector<Status> statuses(total_count); std::vector<Status> statuses(total_count);
std::vector<PinnableWideColumns> columns(total_count); std::vector<PinnableWideColumns> columns(total_count);
MultiGetCommon(read_options, total_count, column_families.data(), MultiGetCommon(read_options, total_count, column_families.data(),
@ -3283,6 +3365,15 @@ void DBImpl::MultiGetEntity(const ReadOptions& _read_options, size_t num_keys,
} }
} }
void DBImpl::MultiGetEntityWithCallback(
const ReadOptions& read_options, ColumnFamilyHandle* column_family,
ReadCallback* callback,
autovector<KeyContext*, MultiGetContext::MAX_BATCH_SIZE>* sorted_keys) {
assert(read_options.io_activity == Env::IOActivity::kMultiGetEntity);
MultiGetWithCallbackImpl(read_options, column_family, callback, sorted_keys);
}
Status DBImpl::WrapUpCreateColumnFamilies( Status DBImpl::WrapUpCreateColumnFamilies(
const ReadOptions& read_options, const WriteOptions& write_options, const ReadOptions& read_options, const WriteOptions& write_options,
const std::vector<const ColumnFamilyOptions*>& cf_options) { const std::vector<const ColumnFamilyOptions*>& cf_options) {

View File

@ -291,6 +291,11 @@ class DBImpl : public DB {
const Slice* keys, const Slice* keys,
PinnableAttributeGroups* results) override; PinnableAttributeGroups* results) override;
void MultiGetEntityWithCallback(
const ReadOptions& read_options, ColumnFamilyHandle* column_family,
ReadCallback* callback,
autovector<KeyContext*, MultiGetContext::MAX_BATCH_SIZE>* sorted_keys);
Status CreateColumnFamily(const ColumnFamilyOptions& cf_options, Status CreateColumnFamily(const ColumnFamilyOptions& cf_options,
const std::string& column_family, const std::string& column_family,
ColumnFamilyHandle** handle) override { ColumnFamilyHandle** handle) override {

View File

@ -1794,6 +1794,144 @@ TEST_F(DBWideBasicTest, PinnableWideColumnsMove) {
test_move(/* fill_cache*/ true); test_move(/* fill_cache*/ true);
} }
TEST_F(DBWideBasicTest, SanityChecks) {
constexpr char foo[] = "foo";
constexpr char bar[] = "bar";
constexpr size_t num_keys = 2;
{
constexpr ColumnFamilyHandle* column_family = nullptr;
PinnableWideColumns columns;
ASSERT_TRUE(db_->GetEntity(ReadOptions(), column_family, foo, &columns)
.IsInvalidArgument());
}
{
constexpr PinnableWideColumns* columns = nullptr;
ASSERT_TRUE(
db_->GetEntity(ReadOptions(), db_->DefaultColumnFamily(), foo, columns)
.IsInvalidArgument());
}
{
ReadOptions read_options;
read_options.io_activity = Env::IOActivity::kGet;
PinnableWideColumns columns;
ASSERT_TRUE(
db_->GetEntity(read_options, db_->DefaultColumnFamily(), foo, &columns)
.IsInvalidArgument());
}
{
constexpr ColumnFamilyHandle* column_family = nullptr;
std::array<Slice, num_keys> keys{{foo, bar}};
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(ReadOptions(), column_family, num_keys, keys.data(),
results.data(), statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
constexpr Slice* keys = nullptr;
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(ReadOptions(), db_->DefaultColumnFamily(), num_keys,
keys, results.data(), statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
std::array<Slice, num_keys> keys{{foo, bar}};
constexpr PinnableWideColumns* results = nullptr;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(ReadOptions(), db_->DefaultColumnFamily(), num_keys,
keys.data(), results, statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
ReadOptions read_options;
read_options.io_activity = Env::IOActivity::kMultiGet;
std::array<Slice, num_keys> keys{{foo, bar}};
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(read_options, db_->DefaultColumnFamily(), num_keys,
keys.data(), results.data(), statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
constexpr ColumnFamilyHandle** column_families = nullptr;
std::array<Slice, num_keys> keys{{foo, bar}};
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(ReadOptions(), num_keys, column_families, keys.data(),
results.data(), statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
std::array<ColumnFamilyHandle*, num_keys> column_families{
{db_->DefaultColumnFamily(), db_->DefaultColumnFamily()}};
constexpr Slice* keys = nullptr;
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(ReadOptions(), num_keys, column_families.data(), keys,
results.data(), statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
std::array<ColumnFamilyHandle*, num_keys> column_families{
{db_->DefaultColumnFamily(), db_->DefaultColumnFamily()}};
std::array<Slice, num_keys> keys{{foo, bar}};
constexpr PinnableWideColumns* results = nullptr;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(ReadOptions(), num_keys, column_families.data(),
keys.data(), results, statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
ReadOptions read_options;
read_options.io_activity = Env::IOActivity::kMultiGet;
std::array<ColumnFamilyHandle*, num_keys> column_families{
{db_->DefaultColumnFamily(), db_->DefaultColumnFamily()}};
std::array<Slice, num_keys> keys{{foo, bar}};
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
db_->MultiGetEntity(read_options, num_keys, column_families.data(),
keys.data(), results.data(), statuses.data());
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
}
} // namespace ROCKSDB_NAMESPACE } // namespace ROCKSDB_NAMESPACE
int main(int argc, char** argv) { int main(int argc, char** argv) {

View File

@ -284,7 +284,12 @@ class WriteBatchWithIndex : public WriteBatchBase {
Status GetEntityFromBatchAndDB(DB* db, const ReadOptions& read_options, Status GetEntityFromBatchAndDB(DB* db, const ReadOptions& read_options,
ColumnFamilyHandle* column_family, ColumnFamilyHandle* column_family,
const Slice& key, const Slice& key,
PinnableWideColumns* columns); PinnableWideColumns* columns) {
constexpr ReadCallback* callback = nullptr;
return GetEntityFromBatchAndDB(db, read_options, column_family, key,
columns, callback);
}
void MultiGetFromBatchAndDB(DB* db, const ReadOptions& read_options, void MultiGetFromBatchAndDB(DB* db, const ReadOptions& read_options,
ColumnFamilyHandle* column_family, ColumnFamilyHandle* column_family,
@ -310,7 +315,13 @@ class WriteBatchWithIndex : public WriteBatchBase {
ColumnFamilyHandle* column_family, ColumnFamilyHandle* column_family,
size_t num_keys, const Slice* keys, size_t num_keys, const Slice* keys,
PinnableWideColumns* results, PinnableWideColumns* results,
Status* statuses, bool sorted_input); Status* statuses, bool sorted_input) {
constexpr ReadCallback* callback = nullptr;
MultiGetEntityFromBatchAndDB(db, read_options, column_family, num_keys,
keys, results, statuses, sorted_input,
callback);
}
// Records the state of the batch for future calls to RollbackToSavePoint(). // Records the state of the batch for future calls to RollbackToSavePoint().
// May be called multiple times to set multiple save points. // May be called multiple times to set multiple save points.

View File

@ -1928,6 +1928,37 @@ TEST_P(OptimisticTransactionTest, PutEntityWriteConflictTxnTxn) {
} }
} }
TEST_P(OptimisticTransactionTest, EntityReadSanityChecks) {
constexpr char foo[] = "foo";
std::unique_ptr<Transaction> txn(txn_db->BeginTransaction(WriteOptions()));
ASSERT_NE(txn, nullptr);
{
constexpr ColumnFamilyHandle* column_family = nullptr;
PinnableWideColumns columns;
ASSERT_TRUE(txn->GetEntity(ReadOptions(), column_family, foo, &columns)
.IsInvalidArgument());
}
{
constexpr PinnableWideColumns* columns = nullptr;
ASSERT_TRUE(txn->GetEntity(ReadOptions(), txn_db->DefaultColumnFamily(),
foo, columns)
.IsInvalidArgument());
}
{
ReadOptions read_options;
read_options.io_activity = Env::IOActivity::kGet;
PinnableWideColumns columns;
ASSERT_TRUE(txn->GetEntity(read_options, txn_db->DefaultColumnFamily(), foo,
&columns)
.IsInvalidArgument());
}
}
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
InstanceOccGroup, OptimisticTransactionTest, InstanceOccGroup, OptimisticTransactionTest,
testing::Values(OccValidationPolicy::kValidateSerial, testing::Values(OccValidationPolicy::kValidateSerial,

View File

@ -289,22 +289,10 @@ Status TransactionBaseImpl::GetImpl(const ReadOptions& read_options,
pinnable_val); pinnable_val);
} }
Status TransactionBaseImpl::GetEntity(const ReadOptions& _read_options, Status TransactionBaseImpl::GetEntity(const ReadOptions& read_options,
ColumnFamilyHandle* column_family, ColumnFamilyHandle* column_family,
const Slice& key, const Slice& key,
PinnableWideColumns* columns) { PinnableWideColumns* columns) {
if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kGetEntity) {
return Status::InvalidArgument(
"Can only call GetEntity with `ReadOptions::io_activity` set to "
"`Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`");
}
ReadOptions read_options(_read_options);
if (read_options.io_activity == Env::IOActivity::kUnknown) {
read_options.io_activity = Env::IOActivity::kGetEntity;
}
return GetEntityImpl(read_options, column_family, key, columns); return GetEntityImpl(read_options, column_family, key, columns);
} }

View File

@ -7150,6 +7150,37 @@ TEST_P(TransactionTest, PutEntityWriteConflict) {
} }
} }
TEST_P(TransactionTest, EntityReadSanityChecks) {
constexpr char foo[] = "foo";
std::unique_ptr<Transaction> txn(db->BeginTransaction(WriteOptions()));
ASSERT_NE(txn, nullptr);
{
constexpr ColumnFamilyHandle* column_family = nullptr;
PinnableWideColumns columns;
ASSERT_TRUE(txn->GetEntity(ReadOptions(), column_family, foo, &columns)
.IsInvalidArgument());
}
{
constexpr PinnableWideColumns* columns = nullptr;
ASSERT_TRUE(
txn->GetEntity(ReadOptions(), db->DefaultColumnFamily(), foo, columns)
.IsInvalidArgument());
}
{
ReadOptions read_options;
read_options.io_activity = Env::IOActivity::kGet;
PinnableWideColumns columns;
ASSERT_TRUE(
txn->GetEntity(read_options, db->DefaultColumnFamily(), foo, &columns)
.IsInvalidArgument());
}
}
TEST_F(TransactionDBTest, CollapseKey) { TEST_F(TransactionDBTest, CollapseKey) {
ASSERT_OK(ReOpen()); ASSERT_OK(ReOpen());
ASSERT_OK(db->Put({}, "hello", "world")); ASSERT_OK(db->Put({}, "hello", "world"));

View File

@ -823,11 +823,41 @@ void WriteBatchWithIndex::MultiGetFromBatchAndDB(
} }
Status WriteBatchWithIndex::GetEntityFromBatchAndDB( Status WriteBatchWithIndex::GetEntityFromBatchAndDB(
DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, DB* db, const ReadOptions& _read_options, ColumnFamilyHandle* column_family,
const Slice& key, PinnableWideColumns* columns, ReadCallback* callback) { const Slice& key, PinnableWideColumns* columns, ReadCallback* callback) {
assert(db); if (!db) {
assert(column_family); return Status::InvalidArgument(
assert(columns); "Cannot call GetEntityFromBatchAndDB without a DB object");
}
if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kGetEntity) {
return Status::InvalidArgument(
"Can only call GetEntityFromBatchAndDB with `ReadOptions::io_activity` "
"set to `Env::IOActivity::kUnknown` or `Env::IOActivity::kGetEntity`");
}
ReadOptions read_options(_read_options);
if (read_options.io_activity == Env::IOActivity::kUnknown) {
read_options.io_activity = Env::IOActivity::kGetEntity;
}
if (!column_family) {
return Status::InvalidArgument(
"Cannot call GetEntityFromBatchAndDB without a column family handle");
}
const Comparator* const ucmp = rep->comparator.GetComparator(column_family);
size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0;
if (ts_sz > 0 && !read_options.timestamp) {
return Status::InvalidArgument("Must specify timestamp");
}
if (!columns) {
return Status::InvalidArgument(
"Cannot call GetEntityFromBatchAndDB without a PinnableWideColumns "
"object");
}
columns->Reset(); columns->Reset();
@ -872,46 +902,78 @@ Status WriteBatchWithIndex::GetEntityFromBatchAndDB(
return s; return s;
} }
Status WriteBatchWithIndex::GetEntityFromBatchAndDB( void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB(
DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, DB* db, const ReadOptions& _read_options, ColumnFamilyHandle* column_family,
const Slice& key, PinnableWideColumns* columns) { size_t num_keys, const Slice* keys, PinnableWideColumns* results,
Status* statuses, bool sorted_input, ReadCallback* callback) {
assert(statuses);
if (!db) { if (!db) {
return Status::InvalidArgument( const Status s = Status::InvalidArgument(
"Cannot call GetEntityFromBatchAndDB without a DB object"); "Cannot call MultiGetEntityFromBatchAndDB without a DB object");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
}
if (_read_options.io_activity != Env::IOActivity::kUnknown &&
_read_options.io_activity != Env::IOActivity::kMultiGetEntity) {
const Status s = Status::InvalidArgument(
"Can only call MultiGetEntityFromBatchAndDB with "
"`ReadOptions::io_activity` set to `Env::IOActivity::kUnknown` or "
"`Env::IOActivity::kMultiGetEntity`");
for (size_t i = 0; i < num_keys; ++i) {
if (statuses[i].ok()) {
statuses[i] = s;
}
}
return;
}
ReadOptions read_options(_read_options);
if (read_options.io_activity == Env::IOActivity::kUnknown) {
read_options.io_activity = Env::IOActivity::kMultiGetEntity;
} }
if (!column_family) { if (!column_family) {
return Status::InvalidArgument( const Status s = Status::InvalidArgument(
"Cannot call GetEntityFromBatchAndDB without a column family handle"); "Cannot call MultiGetEntityFromBatchAndDB without a column family "
"handle");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
} }
const Comparator* const ucmp = rep->comparator.GetComparator(column_family); const Comparator* const ucmp = rep->comparator.GetComparator(column_family);
size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0; const size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0;
if (ts_sz > 0 && !read_options.timestamp) { if (ts_sz > 0 && !read_options.timestamp) {
return Status::InvalidArgument("Must specify timestamp"); const Status s = Status::InvalidArgument("Must specify timestamp");
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
} }
if (!columns) { if (!keys) {
return Status::InvalidArgument( const Status s = Status::InvalidArgument(
"Cannot call GetEntityFromBatchAndDB without a PinnableWideColumns " "Cannot call MultiGetEntityFromBatchAndDB without keys");
"object"); for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
}
return;
} }
constexpr ReadCallback* callback = nullptr; if (!results) {
const Status s = Status::InvalidArgument(
return GetEntityFromBatchAndDB(db, read_options, column_family, key, columns, "Cannot call MultiGetEntityFromBatchAndDB without "
callback); "PinnableWideColumns objects");
} for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = s;
void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB( }
DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, return;
size_t num_keys, const Slice* keys, PinnableWideColumns* results, }
Status* statuses, bool sorted_input, ReadCallback* callback) {
assert(db);
assert(column_family);
assert(keys);
assert(results);
assert(statuses);
struct MergeTuple { struct MergeTuple {
MergeTuple(const Slice& _key, Status* _s, MergeContext&& _merge_context, MergeTuple(const Slice& _key, Status* _s, MergeContext&& _merge_context,
@ -990,7 +1052,7 @@ void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB(
static_cast_with_check<DBImpl>(db->GetRootDB()) static_cast_with_check<DBImpl>(db->GetRootDB())
->PrepareMultiGetKeys(sorted_keys.size(), sorted_input, &sorted_keys); ->PrepareMultiGetKeys(sorted_keys.size(), sorted_input, &sorted_keys);
static_cast_with_check<DBImpl>(db->GetRootDB()) static_cast_with_check<DBImpl>(db->GetRootDB())
->MultiGetWithCallback(read_options, column_family, callback, ->MultiGetEntityWithCallback(read_options, column_family, callback,
&sorted_keys); &sorted_keys);
for (const auto& merge : merges) { for (const auto& merge : merges) {
@ -1001,57 +1063,6 @@ void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB(
} }
} }
void WriteBatchWithIndex::MultiGetEntityFromBatchAndDB(
DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
size_t num_keys, const Slice* keys, PinnableWideColumns* results,
Status* statuses, bool sorted_input) {
assert(statuses);
if (!db) {
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = Status::InvalidArgument(
"Cannot call MultiGetEntityFromBatchAndDB without a DB object");
}
}
if (!column_family) {
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = Status::InvalidArgument(
"Cannot call MultiGetEntityFromBatchAndDB without a column family "
"handle");
}
}
const Comparator* const ucmp = rep->comparator.GetComparator(column_family);
const size_t ts_sz = ucmp ? ucmp->timestamp_size() : 0;
if (ts_sz > 0 && !read_options.timestamp) {
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = Status::InvalidArgument("Must specify timestamp");
}
return;
}
if (!keys) {
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = Status::InvalidArgument(
"Cannot call MultiGetEntityFromBatchAndDB without keys");
}
}
if (!results) {
for (size_t i = 0; i < num_keys; ++i) {
statuses[i] = Status::InvalidArgument(
"Cannot call MultiGetEntityFromBatchAndDB without "
"PinnableWideColumns objects");
}
}
constexpr ReadCallback* callback = nullptr;
MultiGetEntityFromBatchAndDB(db, read_options, column_family, num_keys, keys,
results, statuses, sorted_input, callback);
}
void WriteBatchWithIndex::SetSavePoint() { rep->write_batch.SetSavePoint(); } void WriteBatchWithIndex::SetSavePoint() { rep->write_batch.SetSavePoint(); }
Status WriteBatchWithIndex::RollbackToSavePoint() { Status WriteBatchWithIndex::RollbackToSavePoint() {

View File

@ -2973,6 +2973,128 @@ TEST_P(WriteBatchWithIndexTest, GetEntityFromBatch) {
} }
} }
TEST_P(WriteBatchWithIndexTest, EntityReadSanityChecks) {
ASSERT_OK(OpenDB());
constexpr char foo[] = "foo";
constexpr char bar[] = "bar";
constexpr size_t num_keys = 2;
{
constexpr DB* db = nullptr;
PinnableWideColumns columns;
ASSERT_TRUE(batch_
->GetEntityFromBatchAndDB(db, ReadOptions(),
db_->DefaultColumnFamily(), foo,
&columns)
.IsInvalidArgument());
}
{
constexpr ColumnFamilyHandle* column_family = nullptr;
PinnableWideColumns columns;
ASSERT_TRUE(batch_
->GetEntityFromBatchAndDB(db_, ReadOptions(), column_family,
foo, &columns)
.IsInvalidArgument());
}
{
constexpr PinnableWideColumns* columns = nullptr;
ASSERT_TRUE(batch_
->GetEntityFromBatchAndDB(db_, ReadOptions(),
db_->DefaultColumnFamily(), foo,
columns)
.IsInvalidArgument());
}
{
ReadOptions read_options;
read_options.io_activity = Env::IOActivity::kGet;
PinnableWideColumns columns;
ASSERT_TRUE(batch_
->GetEntityFromBatchAndDB(db_, read_options,
db_->DefaultColumnFamily(), foo,
&columns)
.IsInvalidArgument());
}
{
constexpr DB* db = nullptr;
std::array<Slice, num_keys> keys{{foo, bar}};
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
constexpr bool sorted_input = false;
batch_->MultiGetEntityFromBatchAndDB(
db, ReadOptions(), db_->DefaultColumnFamily(), num_keys, keys.data(),
results.data(), statuses.data(), sorted_input);
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
constexpr ColumnFamilyHandle* column_family = nullptr;
std::array<Slice, num_keys> keys{{foo, bar}};
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
constexpr bool sorted_input = false;
batch_->MultiGetEntityFromBatchAndDB(db_, ReadOptions(), column_family,
num_keys, keys.data(), results.data(),
statuses.data(), sorted_input);
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
constexpr Slice* keys = nullptr;
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
constexpr bool sorted_input = false;
batch_->MultiGetEntityFromBatchAndDB(
db_, ReadOptions(), db_->DefaultColumnFamily(), num_keys, keys,
results.data(), statuses.data(), sorted_input);
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
std::array<Slice, num_keys> keys{{foo, bar}};
constexpr PinnableWideColumns* results = nullptr;
std::array<Status, num_keys> statuses;
constexpr bool sorted_input = false;
batch_->MultiGetEntityFromBatchAndDB(
db_, ReadOptions(), db_->DefaultColumnFamily(), num_keys, keys.data(),
results, statuses.data(), sorted_input);
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
{
ReadOptions read_options;
read_options.io_activity = Env::IOActivity::kMultiGet;
std::array<Slice, num_keys> keys{{foo, bar}};
std::array<PinnableWideColumns, num_keys> results;
std::array<Status, num_keys> statuses;
constexpr bool sorted_input = false;
batch_->MultiGetEntityFromBatchAndDB(
db_, read_options, db_->DefaultColumnFamily(), num_keys, keys.data(),
results.data(), statuses.data(), sorted_input);
ASSERT_TRUE(statuses[0].IsInvalidArgument());
ASSERT_TRUE(statuses[1].IsInvalidArgument());
}
}
INSTANTIATE_TEST_CASE_P(WBWI, WriteBatchWithIndexTest, testing::Bool()); INSTANTIATE_TEST_CASE_P(WBWI, WriteBatchWithIndexTest, testing::Bool());
} // namespace ROCKSDB_NAMESPACE } // namespace ROCKSDB_NAMESPACE