diff --git a/utilities/transactions/transaction_db_impl.cc b/utilities/transactions/transaction_db_impl.cc index f8a47b948c..91440ae657 100644 --- a/utilities/transactions/transaction_db_impl.cc +++ b/utilities/transactions/transaction_db_impl.cc @@ -24,7 +24,7 @@ TransactionDBImpl::TransactionDBImpl(DB* db, const TransactionDBOptions& txn_db_options) : TransactionDB(db), txn_db_options_(txn_db_options), - lock_mgr_(txn_db_options_.num_stripes, txn_db_options.max_num_locks, + lock_mgr_(this, txn_db_options_.num_stripes, txn_db_options.max_num_locks, txn_db_options_.custom_mutex_factory ? txn_db_options_.custom_mutex_factory : std::shared_ptr( @@ -278,5 +278,29 @@ Status TransactionDBImpl::Write(const WriteOptions& opts, WriteBatch* updates) { return s; } +void TransactionDBImpl::InsertExpirableTransaction(TransactionID tx_id, + TransactionImpl* tx) { + assert(tx->GetExpirationTime() > 0); + std::lock_guard lock(map_mutex_); + expirable_transactions_map_.insert({tx_id, tx}); +} + +void TransactionDBImpl::RemoveExpirableTransaction(TransactionID tx_id) { + std::lock_guard lock(map_mutex_); + expirable_transactions_map_.erase(tx_id); +} + +bool TransactionDBImpl::TryStealingExpiredTransactionLocks( + TransactionID tx_id) { + std::lock_guard lock(map_mutex_); + + auto tx_it = expirable_transactions_map_.find(tx_id); + if (tx_it == expirable_transactions_map_.end()) { + return true; + } + TransactionImpl& tx = *(tx_it->second); + return tx.TryStealingLocks(); +} + } // namespace rocksdb #endif // ROCKSDB_LITE diff --git a/utilities/transactions/transaction_db_impl.h b/utilities/transactions/transaction_db_impl.h index 5a9d8b474e..0605101360 100644 --- a/utilities/transactions/transaction_db_impl.h +++ b/utilities/transactions/transaction_db_impl.h @@ -6,7 +6,9 @@ #pragma once #ifndef ROCKSDB_LITE +#include #include +#include #include "rocksdb/db.h" #include "rocksdb/options.h" @@ -66,6 +68,15 @@ class TransactionDBImpl : public TransactionDB { return txn_db_options_; } + void InsertExpirableTransaction(TransactionID tx_id, TransactionImpl* tx); + void RemoveExpirableTransaction(TransactionID tx_id); + + // If transaction is no longer available, locks can be stolen + // If transaction is available, try stealing locks directly from transaction + // It is the caller's responsibility to ensure that the referred transaction + // is expirable (GetExpirationTime() > 0) and that it is expired. + bool TryStealingExpiredTransactionLocks(TransactionID tx_id); + private: const TransactionDBOptions txn_db_options_; TransactionLockMgr lock_mgr_; @@ -74,6 +85,13 @@ class TransactionDBImpl : public TransactionDB { InstrumentedMutex column_family_mutex_; Transaction* BeginInternalTransaction(const WriteOptions& options); Status WriteHelper(WriteBatch* updates, TransactionImpl* txn_impl); + + // Used to ensure that no locks are stolen from an expirable transaction + // that has started a commit. Only transactions with an expiration time + // should be in this map. + std::mutex map_mutex_; + std::unordered_map + expirable_transactions_map_; }; } // namespace rocksdb diff --git a/utilities/transactions/transaction_impl.cc b/utilities/transactions/transaction_impl.cc index 3f25ff77d9..2602d30e66 100644 --- a/utilities/transactions/transaction_impl.cc +++ b/utilities/transactions/transaction_impl.cc @@ -20,6 +20,7 @@ #include "rocksdb/status.h" #include "rocksdb/utilities/transaction_db.h" #include "util/string_util.h" +#include "util/sync_point.h" #include "utilities/transactions/transaction_db_impl.h" #include "utilities/transactions/transaction_util.h" @@ -42,7 +43,8 @@ TransactionImpl::TransactionImpl(TransactionDB* txn_db, expiration_time_(txn_options.expiration >= 0 ? start_time_ + txn_options.expiration * 1000 : 0), - lock_timeout_(txn_options.lock_timeout * 1000) { + lock_timeout_(txn_options.lock_timeout * 1000), + exec_status_(STARTED) { txn_db_impl_ = dynamic_cast(txn_db); assert(txn_db_impl_); @@ -55,10 +57,16 @@ TransactionImpl::TransactionImpl(TransactionDB* txn_db, if (txn_options.set_snapshot) { SetSnapshot(); } + if (expiration_time_ > 0) { + txn_db_impl_->InsertExpirableTransaction(txn_id_, this); + } } TransactionImpl::~TransactionImpl() { txn_db_impl_->UnLock(this, &GetTrackedKeys()); + if (expiration_time_ > 0) { + txn_db_impl_->RemoveExpirableTransaction(txn_id_); + } } void TransactionImpl::Clear() { @@ -103,18 +111,27 @@ Status TransactionImpl::DoCommit(WriteBatch* batch) { Status s; if (expiration_time_ > 0) { - // We cannot commit a transaction that is expired as its locks might have - // been released. - // To avoid race conditions, we need to use a WriteCallback to check the - // expiration time once we're on the writer thread. - TransactionCallback callback(this); + if (IsExpired()) { + return Status::Expired(); + } - // Do write directly on base db as TransctionDB::Write() would attempt to - // do conflict checking that we've already done. - assert(dynamic_cast(db_) != nullptr); - auto db_impl = reinterpret_cast(db_); + // Transaction should only be committed if the thread succeeds + // changing its execution status to COMMITTING. This is because + // A different transaction may consider this one expired and attempt + // to steal its locks between the IsExpired() check and the beginning + // of a commit. + ExecutionStatus expected = STARTED; + bool can_commit = std::atomic_compare_exchange_strong( + &exec_status_, &expected, COMMITTING); - s = db_impl->WriteWithCallback(write_options_, batch, &callback); + TEST_SYNC_POINT("TransactionTest::ExpirableTransactionDataRace:1"); + + if (can_commit) { + s = db_->Write(write_options_, batch); + } else { + assert(exec_status_ == LOCKS_STOLEN); + return Status::Expired(); + } } else { s = db_->Write(write_options_, batch); } @@ -316,6 +333,13 @@ Status TransactionImpl::ValidateSnapshot(ColumnFamilyHandle* column_family, false /* cache_only */); } +bool TransactionImpl::TryStealingLocks() { + assert(IsExpired()); + ExecutionStatus expected = STARTED; + return std::atomic_compare_exchange_strong(&exec_status_, &expected, + LOCKS_STOLEN); +} + } // namespace rocksdb #endif // ROCKSDB_LITE diff --git a/utilities/transactions/transaction_impl.h b/utilities/transactions/transaction_impl.h index 0fa087d67f..caed15d3ab 100644 --- a/utilities/transactions/transaction_impl.h +++ b/utilities/transactions/transaction_impl.h @@ -66,11 +66,16 @@ class TransactionImpl : public TransactionBaseImpl { lock_timeout_ = timeout * 1000; } + // Returns true if locks were stolen successfully, false otherwise. + bool TryStealingLocks(); + protected: Status TryLock(ColumnFamilyHandle* column_family, const Slice& key, bool untracked = false) override; private: + enum ExecutionStatus { STARTED, COMMITTING, LOCKS_STOLEN }; + TransactionDBImpl* txn_db_impl_; // Used to create unique ids for transactions. @@ -86,6 +91,9 @@ class TransactionImpl : public TransactionBaseImpl { // Timeout in microseconds when locking a key or -1 if there is no timeout. int64_t lock_timeout_; + // Execution status of the transaction. + std::atomic exec_status_; + void Clear() override; Status ValidateSnapshot(ColumnFamilyHandle* column_family, const Slice& key, @@ -102,24 +110,6 @@ class TransactionImpl : public TransactionBaseImpl { void operator=(const TransactionImpl&); }; -// Used at commit time to check whether transaction is committing before its -// expiration time. -class TransactionCallback : public WriteCallback { - public: - explicit TransactionCallback(TransactionImpl* txn) : txn_(txn) {} - - Status Callback(DB* db) override { - if (txn_->IsExpired()) { - return Status::Expired(); - } else { - return Status::OK(); - } - } - - private: - TransactionImpl* txn_; -}; - } // namespace rocksdb #endif // ROCKSDB_LITE diff --git a/utilities/transactions/transaction_lock_mgr.cc b/utilities/transactions/transaction_lock_mgr.cc index 80e4fb8d9b..51b8d4a826 100644 --- a/utilities/transactions/transaction_lock_mgr.cc +++ b/utilities/transactions/transaction_lock_mgr.cc @@ -25,6 +25,7 @@ #include "util/autovector.h" #include "util/murmurhash.h" #include "util/thread_local.h" +#include "utilities/transactions/transaction_db_impl.h" namespace rocksdb { @@ -99,12 +100,16 @@ void UnrefLockMapsCache(void* ptr) { } // anonymous namespace TransactionLockMgr::TransactionLockMgr( - size_t default_num_stripes, int64_t max_num_locks, + TransactionDB* txn_db, size_t default_num_stripes, int64_t max_num_locks, std::shared_ptr mutex_factory) - : default_num_stripes_(default_num_stripes), + : txn_db_impl_(nullptr), + default_num_stripes_(default_num_stripes), max_num_locks_(max_num_locks), mutex_factory_(mutex_factory), - lock_maps_cache_(new ThreadLocalPtr(&UnrefLockMapsCache)) {} + lock_maps_cache_(new ThreadLocalPtr(&UnrefLockMapsCache)) { + txn_db_impl_ = dynamic_cast(txn_db); + assert(txn_db_impl_); +} TransactionLockMgr::~TransactionLockMgr() {} @@ -197,6 +202,11 @@ bool TransactionLockMgr::IsLockExpired(const LockInfo& lock_info, Env* env, // return how many microseconds until lock will be expired *expire_time = lock_info.expiration_time; } else { + bool success = + txn_db_impl_->TryStealingExpiredTransactionLocks(lock_info.txn_id); + if (!success) { + expired = false; + } *expire_time = 0; } diff --git a/utilities/transactions/transaction_lock_mgr.h b/utilities/transactions/transaction_lock_mgr.h index 8f640d4ca2..fa46c62be0 100644 --- a/utilities/transactions/transaction_lock_mgr.h +++ b/utilities/transactions/transaction_lock_mgr.h @@ -24,10 +24,12 @@ struct LockMap; struct LockMapStripe; class Slice; +class TransactionDBImpl; class TransactionLockMgr { public: - TransactionLockMgr(size_t default_num_stripes, int64_t max_num_locks, + TransactionLockMgr(TransactionDB* txn_db, size_t default_num_stripes, + int64_t max_num_locks, std::shared_ptr factory); ~TransactionLockMgr(); @@ -53,6 +55,8 @@ class TransactionLockMgr { const std::string& key, Env* env); private: + TransactionDBImpl* txn_db_impl_; + // Default number of lock map stripes per column family const size_t default_num_stripes_; diff --git a/utilities/transactions/transaction_test.cc b/utilities/transactions/transaction_test.cc index 911212317f..859b02bce3 100644 --- a/utilities/transactions/transaction_test.cc +++ b/utilities/transactions/transaction_test.cc @@ -14,6 +14,7 @@ #include "rocksdb/utilities/transaction_db.h" #include "table/mock_table.h" #include "util/logging.h" +#include "util/sync_point.h" #include "util/testharness.h" #include "util/testutil.h" #include "utilities/merge_operators.h" @@ -2483,6 +2484,51 @@ TEST_F(TransactionTest, ToggleAutoCompactionTest) { } } +TEST_F(TransactionTest, ExpiredTransactionDataRace1) { + // In this test, txn1 should succeed committing, + // as the callback is called after txn1 starts committing. + rocksdb::SyncPoint::GetInstance()->LoadDependency( + {{"TransactionTest::ExpirableTransactionDataRace:1"}}); + rocksdb::SyncPoint::GetInstance()->SetCallBack( + "TransactionTest::ExpirableTransactionDataRace:1", [&](void* arg) { + WriteOptions write_options; + TransactionOptions txn_options; + + // Force txn1 to expire + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + + Transaction* txn2 = db->BeginTransaction(write_options, txn_options); + Status s; + s = txn2->Put("X", "2"); + ASSERT_TRUE(s.IsTimedOut()); + s = txn2->Commit(); + ASSERT_OK(s); + delete txn2; + }); + + rocksdb::SyncPoint::GetInstance()->EnableProcessing(); + + WriteOptions write_options; + TransactionOptions txn_options; + + txn_options.expiration = 100; + Transaction* txn1 = db->BeginTransaction(write_options, txn_options); + + Status s; + s = txn1->Put("X", "1"); + ASSERT_OK(s); + s = txn1->Commit(); + ASSERT_OK(s); + + ReadOptions read_options; + string value; + s = db->Get(read_options, "X", &value); + ASSERT_EQ("1", value); + + delete txn1; +} + } // namespace rocksdb int main(int argc, char** argv) {