Block cache simulator: Add pysim to simulate caches using reinforcement learning. (#5610)

Summary:
This PR implements cache eviction using reinforcement learning. It includes two implementations:
1. An implementation of Thompson Sampling for the Bernoulli Bandit [1].
2. An implementation of LinUCB with disjoint linear models [2].

The idea is that a cache uses multiple eviction policies, e.g., MRU, LRU, and LFU. The cache learns which eviction policy is the best and uses it upon a cache miss.
Thompson Sampling is contextless and does not include any features.
LinUCB includes features such as level, block type, caller, column family id to decide which eviction policy to use.

[1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband, and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found. Trends Mach. Learn. 11, 1 (July 2018), 1-96. DOI: https://doi.org/10.1561/2200000070
[2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010. A contextual-bandit approach to personalized news article recommendation. In Proceedings of the 19th international conference on World wide web (WWW '10). ACM, New York, NY, USA, 661-670. DOI=http://dx.doi.org/10.1145/1772690.1772758
Pull Request resolved: https://github.com/facebook/rocksdb/pull/5610

Differential Revision: D16435067

Pulled By: HaoyuHuang

fbshipit-source-id: 6549239ae14115c01cb1e70548af9e46d8dc21bb
This commit is contained in:
haoyuhuang 2019-07-26 14:36:16 -07:00 committed by Facebook Github Bot
parent 41df734830
commit 70c7302fb5
14 changed files with 1345 additions and 20 deletions

1
.gitignore vendored
View File

@ -34,6 +34,7 @@ manifest_dump
sst_dump
blob_dump
block_cache_trace_analyzer
tools/block_cache_analyzer/*.pyc
column_aware_encoding_exp
util/build_version.cc
build_tools/VALGRIND_LOGS/

View File

@ -626,7 +626,7 @@ set(SOURCES
test_util/sync_point_impl.cc
test_util/testutil.cc
test_util/transaction_test_util.cc
tools/block_cache_trace_analyzer.cc
tools/block_cache_analyzer/block_cache_trace_analyzer.cc
tools/db_bench_tool.cc
tools/dump/db_dump_tool.cc
tools/ldb_cmd.cc
@ -976,7 +976,7 @@ if(WITH_TESTS)
table/merger_test.cc
table/sst_file_reader_test.cc
table/table_test.cc
tools/block_cache_trace_analyzer_test.cc
tools/block_cache_analyzer/block_cache_trace_analyzer_test.cc
tools/ldb_cmd_test.cc
tools/reduce_levels_test.cc
tools/sst_dump_test.cc

View File

@ -1114,7 +1114,7 @@ db_bench: tools/db_bench.o $(BENCHTOOLOBJECTS)
trace_analyzer: tools/trace_analyzer.o $(ANALYZETOOLOBJECTS) $(LIBOBJECTS)
$(AM_LINK)
block_cache_trace_analyzer: tools/block_cache_trace_analyzer_tool.o $(ANALYZETOOLOBJECTS) $(LIBOBJECTS)
block_cache_trace_analyzer: tools/block_cache_analyzer/block_cache_trace_analyzer_tool.o $(ANALYZETOOLOBJECTS) $(LIBOBJECTS)
$(AM_LINK)
cache_bench: cache/cache_bench.o $(LIBOBJECTS) $(TESTUTIL)
@ -1614,7 +1614,7 @@ db_secondary_test: db/db_impl/db_secondary_test.o db/db_test_util.o $(LIBOBJECTS
block_cache_tracer_test: trace_replay/block_cache_tracer_test.o trace_replay/block_cache_tracer.o $(LIBOBJECTS) $(TESTHARNESS)
$(AM_LINK)
block_cache_trace_analyzer_test: tools/block_cache_trace_analyzer_test.o tools/block_cache_trace_analyzer.o $(LIBOBJECTS) $(TESTHARNESS)
block_cache_trace_analyzer_test: tools/block_cache_analyzer/block_cache_trace_analyzer_test.o tools/block_cache_analyzer/block_cache_trace_analyzer.o $(LIBOBJECTS) $(TESTHARNESS)
$(AM_LINK)
#-------------------------------------------------

View File

@ -351,7 +351,7 @@ cpp_library(
"test_util/fault_injection_test_env.cc",
"test_util/testharness.cc",
"test_util/testutil.cc",
"tools/block_cache_trace_analyzer.cc",
"tools/block_cache_analyzer/block_cache_trace_analyzer.cc",
"tools/trace_analyzer_tool.cc",
"utilities/cassandra/test_utils.cc",
],
@ -369,7 +369,7 @@ cpp_library(
name = "rocksdb_tools_lib",
srcs = [
"test_util/testutil.cc",
"tools/block_cache_trace_analyzer.cc",
"tools/block_cache_analyzer/block_cache_trace_analyzer.cc",
"tools/db_bench_tool.cc",
"tools/trace_analyzer_tool.cc",
],
@ -430,7 +430,7 @@ ROCKS_TESTS = [
],
[
"block_cache_trace_analyzer_test",
"tools/block_cache_trace_analyzer_test.cc",
"tools/block_cache_analyzer/block_cache_trace_analyzer_test.cc",
"serial",
],
[

6
src.mk
View File

@ -246,7 +246,7 @@ TOOL_LIB_SOURCES = \
utilities/blob_db/blob_dump_tool.cc \
ANALYZER_LIB_SOURCES = \
tools/block_cache_trace_analyzer.cc \
tools/block_cache_analyzer/block_cache_trace_analyzer.cc \
tools/trace_analyzer_tool.cc \
MOCK_LIB_SOURCES = \
@ -374,8 +374,8 @@ MAIN_SOURCES = \
table/table_reader_bench.cc \
table/table_test.cc \
third-party/gtest-1.7.0/fused-src/gtest/gtest-all.cc \
tools/block_cache_trace_analyzer_test.cc \
tools/block_cache_trace_analyzer_tool.cc \
tools/block_cache_analyzer/block_cache_trace_analyzer_test.cc \
tools/block_cache_analyzer/block_cache_trace_analyzer_tool.cc \
tools/db_bench.cc \
tools/db_bench_tool_test.cc \
tools/db_sanity_test.cc \

View File

@ -0,0 +1,2 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

View File

@ -0,0 +1,864 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import gc
import random
import sys
import time
from os import path
import numpy as np
# Number of hash entries sampled from the cache when performing eviction.
kSampleSize = 16
# Time-unit conversion factors.
kMicrosInSecond = 1000 * 1000
kSecondsInMinute = 60
kSecondsInHour = 60 * kSecondsInMinute
class TraceRecord:
    """
    One block access taken from a trace.
    The fields mirror the BlockCacheTraceRecord struct declared in
    trace_replay/block_cache_tracer.h
    """

    def __init__(
        self,
        access_time,
        block_id,
        block_type,
        block_size,
        cf_id,
        cf_name,
        level,
        fd,
        caller,
        no_insert,
        get_id,
        key_id,
        kv_size,
        is_hit,
    ):
        # When the access happened and who issued it.
        self.access_time = access_time
        self.caller = caller
        # The block that was accessed.
        self.block_id = block_id
        self.block_type = block_type
        self.block_size = block_size
        # Where the block lives.
        self.cf_id = cf_id
        self.cf_name = cf_name
        self.level = level
        self.fd = fd
        # The get request and row key this access belongs to.
        self.get_id = get_id
        self.key_id = key_id
        self.kv_size = kv_size
        # The two flags arrive as integers; only the value 1 means true.
        self.no_insert = no_insert == 1
        self.is_hit = is_hit == 1
class CacheEntry:
    """
    A cache entry stored in the cache.

    value_size: The size charged against the cache capacity.
    cf_id, level, block_type: Copied from the access that inserted the entry.
    access_number: The access counter value at insertion time; stored as
    'last_access_number'.
    """

    def __init__(self, value_size, cf_id, level, block_type, access_number):
        self.value_size = value_size
        self.last_access_number = access_number
        self.num_hits = 0
        # Keep the caller-supplied column family id (it used to be dropped
        # and replaced by a constant 0).
        self.cf_id = cf_id
        self.level = level
        self.block_type = block_type

    def __repr__(self):
        """Debug string."""
        return "s={},last={},hits={},cf={},l={},bt={}".format(
            self.value_size,
            self.last_access_number,
            self.num_hits,
            self.cf_id,
            self.level,
            self.block_type,
        )
class HashEntry:
    """A (key, hash, value) triple kept in one chain of a hash table."""

    def __init__(self, key, hash, value):
        self.key, self.hash, self.value = key, hash, value

    def __repr__(self):
        """Debug string."""
        fields = (self.key, self.hash, self.value)
        return "k={},h={},v=[{}]".format(*fields)
class HashTable:
    """
    A custom implementation of hash table to support fast random sampling.

    Hash conflicts are resolved by chaining: every slot of 'self.table' is
    either None or a list of HashEntry objects. A deleted entry leaves a None
    placeholder in its chain; the placeholder is reused by a later insert and
    dropped when the table is resized. The table grows/shrinks upon
    insertion/deletion to support fast lookups and random samplings.
    """

    def __init__(self):
        self.table = [None] * 32
        # Number of live entries stored in the table.
        self.elements = 0

    def random_sample(self, sample_size):
        """
        Return up to 'sample_size' hash entries, collected by scanning the
        slots once starting from a randomly chosen slot. Fewer entries are
        returned when the table holds fewer than 'sample_size' entries.
        """
        samples = []
        if sample_size <= 0 or self.elements == 0:
            return samples
        table_size = len(self.table)
        # randrange excludes the upper bound, so 'start' is always a valid
        # slot. (randint includes it, which made the scan miss its stop
        # condition and spin forever on sparsely populated tables.)
        start = random.randrange(table_size)
        for offset in range(table_size):
            chain = self.table[(start + offset) % table_size]
            if chain is None:
                continue
            for entry in chain:
                if entry is None:
                    continue
                samples.append(entry)
                if len(samples) >= sample_size:
                    return samples
        return samples

    def insert(self, key, hash, value):
        """
        Insert a hash entry in the table. Replace the old entry if it already
        exists.
        """
        self.grow()
        index = hash % len(self.table)
        if self.table[index] is None:
            self.table[index] = []
        chain = self.table[index]
        free_slot = None
        # Scan the whole chain before reusing a free slot: the key may live
        # behind a placeholder left by an earlier delete, and inserting into
        # the placeholder first would store the key twice.
        for i, entry in enumerate(chain):
            if entry is None:
                if free_slot is None:
                    free_slot = i
            elif entry.hash == hash and entry.key == key:
                # The entry already exists in the table.
                chain[i] = HashEntry(key, hash, value)
                return
        new_entry = HashEntry(key, hash, value)
        if free_slot is None:
            chain.append(new_entry)
        else:
            chain[free_slot] = new_entry
        self.elements += 1

    def resize(self, new_size):
        """
        Rehash all live entries into a table of 'new_size' slots. Does nothing
        when the size is unchanged or zero, or when the table holds fewer than
        100 entries.
        """
        if new_size == len(self.table):
            return
        if new_size == 0:
            return
        if self.elements < 100:
            return
        new_table = [None] * new_size
        # Copy 'self.table' to new_table, dropping deleted placeholders.
        for chain in self.table:
            if chain is None:
                continue
            for entry in chain:
                if entry is None:
                    continue
                index = entry.hash % new_size
                if new_table[index] is None:
                    new_table[index] = []
                new_table[index].append(entry)
        self.table = new_table
        del new_table
        # Manually call python gc here to free the memory as 'self.table'
        # might be very large.
        gc.collect()

    def grow(self):
        """Enlarge the table by 20% once it holds one entry per slot."""
        if self.elements < len(self.table):
            return
        new_size = int(len(self.table) * 1.2)
        self.resize(new_size)

    def delete(self, key, hash):
        """Remove the entry matching 'key' and 'hash', if present."""
        index = hash % len(self.table)
        chain = self.table[index]
        if chain is None:
            return
        for i, entry in enumerate(chain):
            if entry is not None and entry.hash == hash and entry.key == key:
                # Leave a placeholder so the other chain positions stay put.
                chain[i] = None
                self.elements -= 1
                self.shrink()
                return

    def shrink(self):
        """Reduce the table to 70% of its size once it is under half full."""
        if self.elements * 2 >= len(self.table):
            return
        new_size = int(len(self.table) * 0.7)
        self.resize(new_size)

    def lookup(self, key, hash):
        """Return the value stored for 'key' and 'hash', or None if absent."""
        index = hash % len(self.table)
        chain = self.table[index]
        if chain is None:
            return None
        for entry in chain:
            if entry is not None and entry.hash == hash and entry.key == key:
                return entry.value
        return None
class MissRatioStats:
    """
    Counts accesses and misses, both in total and per time bucket.

    time_unit: The width of one time bucket in seconds. Access times are
    given in microseconds and mapped to a bucket number by integer division.
    """

    def __init__(self, time_unit):
        self.num_misses = 0
        self.num_accesses = 0
        self.time_unit = time_unit
        self.time_misses = {}
        self.time_accesses = {}

    def _to_bucket(self, micros):
        """Map a time in microseconds to its integer time bucket."""
        # Floor division keeps the bucket an integer on Python 3 as well; true
        # division would give every access its own fractional bucket and make
        # range(start, end) in the writers fail.
        return int(micros // (kMicrosInSecond * self.time_unit))

    def _write_header(self, header_file_path, start, end):
        """Write the "time,<bucket>,..." header unless the file exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            header = "time"
            for trace_time in range(start, end):
                header += ",{}".format(trace_time)
            header_file.write(header + "\n")

    def update_metrics(self, access_time, is_hit):
        """Record one access at 'access_time' (microseconds)."""
        bucket = self._to_bucket(access_time)
        self.num_accesses += 1
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        if not is_hit:
            self.num_misses += 1
            self.time_misses[bucket] = self.time_misses.get(bucket, 0) + 1

    def reset_counter(self):
        """Reset the total counters. Per-bucket counters are kept."""
        self.num_misses = 0
        self.num_accesses = 0

    def miss_ratio(self):
        """Return the miss ratio in percent, or 0.0 before any access."""
        if self.num_accesses == 0:
            # Nothing recorded yet (e.g. right after reset_counter).
            return 0.0
        return float(self.num_misses) * 100.0 / float(self.num_accesses)

    def write_miss_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write the number of misses per time bucket in [start, end), both given
        in microseconds, as one comma separated row into 'result_dir'. A
        header file listing the buckets is created if it does not exist.
        """
        start = self._to_bucket(start)
        end = self._to_bucket(end)
        header_file_path = "{}/header-ml-miss-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        self._write_header(header_file_path, start, end)
        file_path = "{}/data-ml-miss-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as file:
            row = "{}".format(cache_type)
            for trace_time in range(start, end):
                row += ",{}".format(self.time_misses.get(trace_time, 0))
            file.write(row + "\n")

    def write_miss_ratio_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write the miss ratio (percent, two decimals) per time bucket in
        [start, end), both given in microseconds, as one comma separated row
        into 'result_dir'. Buckets without accesses are reported as 0.
        """
        start = self._to_bucket(start)
        end = self._to_bucket(end)
        header_file_path = "{}/header-ml-miss-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        self._write_header(header_file_path, start, end)
        file_path = "{}/data-ml-miss-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as file:
            row = "{}".format(cache_type)
            for trace_time in range(start, end):
                naccesses = self.time_accesses.get(trace_time, 0)
                miss_ratio = 0
                if naccesses > 0:
                    miss_ratio = float(
                        self.time_misses.get(trace_time, 0) * 100.0
                    ) / float(naccesses)
                row += ",{0:.2f}".format(miss_ratio)
            file.write(row + "\n")
class PolicyStats:
    """
    Counts, per time bucket, how often each eviction policy was selected.

    time_unit: The width of one time bucket in seconds. Access times are
    given in microseconds and mapped to a bucket number by integer division.
    policies: The list of policies; a policy is identified by its index in
    this list and reported under its policy_name().
    """

    def __init__(self, time_unit, policies):
        self.time_selected_polices = {}
        self.time_accesses = {}
        self.policy_names = {}
        self.time_unit = time_unit
        for i in range(len(policies)):
            self.policy_names[i] = policies[i].policy_name()

    def _to_bucket(self, micros):
        """Map a time in microseconds to its integer time bucket."""
        # Floor division keeps the bucket an integer on Python 3 as well; true
        # division would give every access its own fractional bucket and make
        # range(start, end) in the writers fail.
        return int(micros // (kMicrosInSecond * self.time_unit))

    def _write_header(self, header_file_path, start, end):
        """Write the "time,<bucket>,..." header unless the file exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            header = "time"
            for trace_time in range(start, end):
                header += ",{}".format(trace_time)
            header_file.write(header + "\n")

    def update_metrics(self, access_time, selected_policy):
        """
        Record that the policy with index 'selected_policy' was chosen for an
        access at 'access_time' (microseconds).
        """
        bucket = self._to_bucket(access_time)
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        policy_name = self.policy_names[selected_policy]
        selected = self.time_selected_polices.setdefault(bucket, {})
        selected[policy_name] = selected.get(policy_name, 0) + 1

    def write_policy_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write, one row per policy, the number of times the policy was selected
        in every time bucket of [start, end) (microseconds) into 'result_dir'.
        A header file listing the buckets is created if it does not exist.
        """
        start = self._to_bucket(start)
        end = self._to_bucket(end)
        header_file_path = "{}/header-ml-policy-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        self._write_header(header_file_path, start, end)
        file_path = "{}/data-ml-policy-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as file:
            # Sorted by policy index so that the row order is deterministic.
            for policy in sorted(self.policy_names):
                policy_name = self.policy_names[policy]
                row = "{}-{}".format(cache_type, policy_name)
                for trace_time in range(start, end):
                    row += ",{}".format(
                        self.time_selected_polices.get(trace_time, {}).get(
                            policy_name, 0
                        )
                    )
                file.write(row + "\n")

    def write_policy_ratio_timeline(
        self, cache_type, cache_size, file_path, start, end
    ):
        """
        Write, one row per policy, the percentage of accesses for which the
        policy was selected in every time bucket of [start, end)
        (microseconds). Buckets without accesses are reported as 0.

        NOTE: 'file_path' is the result directory. The parameter keeps its
        historical name so that existing callers continue to work.
        """
        # The body used to read a 'result_dir' name that is not a parameter
        # and only resolved when the script's __main__ block had defined a
        # global of that name. Use the argument instead.
        result_dir = file_path
        start = self._to_bucket(start)
        end = self._to_bucket(end)
        header_file_path = "{}/header-ml-policy-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        self._write_header(header_file_path, start, end)
        data_file_path = "{}/data-ml-policy-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(data_file_path, "w+") as file:
            # Sorted by policy index so that the row order is deterministic.
            for policy in sorted(self.policy_names):
                policy_name = self.policy_names[policy]
                row = "{}-{}".format(cache_type, policy_name)
                for trace_time in range(start, end):
                    naccesses = self.time_accesses.get(trace_time, 0)
                    ratio = 0
                    if naccesses > 0:
                        ratio = float(
                            self.time_selected_polices.get(trace_time, {}).get(
                                policy_name, 0
                            )
                            * 100.0
                        ) / float(naccesses)
                    row += ",{0:.2f}".format(ratio)
                file.write(row + "\n")
class Policy(object):
    """
    Base class of eviction policies. A policy remembers the keys it has
    evicted. Asked for a reward on a missing key, it rewards itself with 1
    when it is not the one that evicted the key and with 0 otherwise.
    """

    def __init__(self):
        self.evicted_keys = {}

    def evict(self, key, max_size):
        """Remember that this policy evicted 'key'."""
        self.evicted_keys[key] = 0

    def delete(self, key):
        """Forget 'key'; a no-op when the key was never evicted."""
        if key in self.evicted_keys:
            del self.evicted_keys[key]

    def prioritize_samples(self, samples):
        """Order 'samples' for eviction. Subclasses must override."""
        raise NotImplementedError

    def policy_name(self):
        """The short name of the policy. Subclasses must override."""
        raise NotImplementedError

    def generate_reward(self, key):
        """Return 0 if this policy evicted 'key', otherwise 1."""
        return 0 if key in self.evicted_keys else 1
class LRUPolicy(Policy):
    """Evicts the least recently used entries first."""

    def prioritize_samples(self, samples):
        """Return 'samples' ordered from least to most recently accessed."""
        # sorted() lost its 'cmp' argument in Python 3. A key function gives
        # the same stable ordering on both Python 2 and Python 3.
        return sorted(samples, key=lambda e: e.value.last_access_number)

    def policy_name(self):
        return "lru"
class MRUPolicy(Policy):
    """Evicts the most recently used entries first."""

    def prioritize_samples(self, samples):
        """Return 'samples' ordered from most to least recently accessed."""
        # sorted() lost its 'cmp' argument in Python 3. A key function with
        # reverse=True gives the same ordering on both Python 2 and Python 3;
        # entries with equal access numbers keep their original order.
        return sorted(
            samples, key=lambda e: e.value.last_access_number, reverse=True
        )

    def policy_name(self):
        return "mru"
class LFUPolicy(Policy):
    """Evicts the least frequently used entries first."""

    def prioritize_samples(self, samples):
        """Return 'samples' ordered from fewest to most hits."""
        # sorted() lost its 'cmp' argument in Python 3. A key function gives
        # the same stable ordering on both Python 2 and Python 3.
        return sorted(samples, key=lambda e: e.value.num_hits)

    def policy_name(self):
        return "lfu"
class MLCache(object):
    """
    Base class of the simulated caches that pick one of several eviction
    policies on every insertion. Subclasses implement _select_policy and
    cache_name.
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        self.cache_size = cache_size
        self.used_size = 0
        self.policies = policies
        self.enable_cache_row_key = enable_cache_row_key
        self.table = HashTable()
        # Per get request state, keyed by get id. See access().
        self.get_id_row_key_map = {}
        # Statistics in per-minute and per-hour time buckets.
        self.miss_ratio_stats = MissRatioStats(kSecondsInMinute)
        self.per_hour_miss_ratio_stats = MissRatioStats(kSecondsInHour)
        self.policy_stats = PolicyStats(kSecondsInMinute, policies)
        self.per_hour_policy_stats = PolicyStats(kSecondsInHour, policies)

    def _lookup(self, key, hash):
        """Return whether 'key' is cached; a hit refreshes its entry."""
        entry = self.table.lookup(key, hash)
        if entry is None:
            return False
        entry.last_access_number = self.miss_ratio_stats.num_accesses
        entry.num_hits += 1
        return True

    def _select_policy(self, trace_record, key):
        """Return the index of the policy to evict with. Subclass hook."""
        raise NotImplementedError

    def cache_name(self):
        """Return the display name of the cache. Subclass hook."""
        raise NotImplementedError

    def _evict(self, policy_index, value_size):
        """
        Sample entries, let the chosen policy order them, and evict them in
        that order until 'value_size' fits or the samples are used up.
        """
        policy = self.policies[policy_index]
        # Randomly sample n entries.
        candidates = policy.prioritize_samples(
            self.table.random_sample(kSampleSize)
        )
        for victim in candidates:
            self.used_size -= victim.value.value_size
            self.table.delete(victim.key, victim.hash)
            policy.evict(key=victim.key, max_size=self.table.elements)
            if self.used_size + value_size <= self.cache_size:
                break

    def _insert(self, trace_record, key, hash, value_size):
        """
        Insert 'key', evicting with the selected policy until it fits. Values
        larger than the whole cache are not inserted.
        """
        if value_size > self.cache_size:
            return
        policy_index = self._select_policy(trace_record, key)
        self.policies[policy_index].delete(key)
        for stats in (self.policy_stats, self.per_hour_policy_stats):
            stats.update_metrics(trace_record.access_time, policy_index)
        while self.used_size + value_size > self.cache_size:
            self._evict(policy_index, value_size)
        new_entry = CacheEntry(
            value_size,
            trace_record.cf_id,
            trace_record.level,
            trace_record.block_type,
            self.miss_ratio_stats.num_accesses,
        )
        self.table.insert(key, hash, new_entry)
        self.used_size += value_size

    def _access_kv(self, trace_record, key, hash, value_size, no_insert):
        """
        Look up 'key' and return whether it was a hit. On a miss the key is
        inserted unless 'no_insert' is set or 'value_size' is not positive.
        """
        hit = self._lookup(key, hash)
        if not hit and not no_insert and value_size > 0:
            self._insert(trace_record, key, hash, value_size)
        return hit

    def _update_stats(self, access_time, is_hit):
        """Record one access in the per-minute and per-hour statistics."""
        for stats in (self.miss_ratio_stats, self.per_hour_miss_ratio_stats):
            stats.update_metrics(access_time, is_hit)

    def _access_block(self, trace_record):
        """Access the block of 'trace_record' and record the outcome."""
        is_hit = self._access_kv(
            trace_record,
            key="b{}".format(trace_record.block_id),
            hash=trace_record.block_id,
            value_size=trace_record.block_size,
            no_insert=trace_record.no_insert,
        )
        self._update_stats(trace_record.access_time, is_hit)

    def _access_row_key(self, trace_record):
        """Access the row key-value pair of 'trace_record'; return is-hit."""
        return self._access_kv(
            trace_record,
            key="g{}".format(trace_record.key_id),
            hash=trace_record.key_id,
            value_size=trace_record.kv_size,
            no_insert=False,
        )

    def access(self, trace_record):
        """
        Simulate one trace record. With row key caching enabled, a get
        request first consults the cached row key-value pair and only then
        its block; every other access goes to the block directly.
        """
        assert self.used_size <= self.cache_size
        is_get_with_row_key = (
            self.enable_cache_row_key
            and trace_record.caller == 1
            and trace_record.key_id != 0
            and trace_record.get_id != 0
        )
        if not is_get_with_row_key:
            # Access the block.
            self._access_block(trace_record)
            return

        # This is a get request. Its state maps "h" to whether the request
        # already hit a row key, and each key id seen so far to whether the
        # row key-value pair is marked as inserted.
        get_state = self.get_id_row_key_map.setdefault(
            trace_record.get_id, {"h": False}
        )
        if not get_state["h"] and trace_record.key_id not in get_state:
            # First time seen this key.
            row_key_hit = self._access_row_key(trace_record)
            get_state[trace_record.key_id] = trace_record.kv_size > 0
            get_state["h"] = row_key_hit
        if get_state["h"]:
            # We treat future accesses as hits since this get request
            # completes.
            self._update_stats(trace_record.access_time, is_hit=True)
            return

        # Access its blocks.
        self._access_block(trace_record)
        if trace_record.kv_size > 0 and not get_state[trace_record.key_id]:
            # Insert the row key-value pair.
            self._access_row_key(trace_record)
            # Mark as inserted.
            get_state[trace_record.key_id] = True
class ThompsonSamplingCache(MLCache):
    """
    An implementation of Thompson Sampling for the Bernoulli Bandit [1].
    [1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband,
    and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found.
    Trends Mach. Learn. 11, 1 (July 2018), 1-96.
    DOI: https://doi.org/10.1561/2200000070

    init_a, init_b: The initial parameters of the Beta distribution kept for
    every policy.
    """

    def __init__(self, cache_size, enable_cache_row_key, policies, init_a=1, init_b=1):
        super(ThompsonSamplingCache, self).__init__(
            cache_size, enable_cache_row_key, policies
        )
        # One Beta(a, b) distribution per policy, indexed like self.policies.
        # Built once; the lists used to be rebuilt once per policy, after
        # first being initialized as dictionaries.
        self._as = [init_a] * len(self.policies)
        self._bs = [init_b] * len(self.policies)

    def _select_policy(self, trace_record, key):
        """
        Draw one sample per policy from its Beta distribution, select the
        policy with the largest sample, and update its distribution with the
        reward it generates for 'key'. Returns the index of that policy.
        """
        samples = [
            np.random.beta(self._as[x], self._bs[x]) for x in range(len(self.policies))
        ]
        selected_policy = max(range(len(self.policies)), key=lambda x: samples[x])
        reward = self.policies[selected_policy].generate_reward(key)
        assert reward <= 1 and reward >= 0
        # A reward of 1 counts towards 'a', a reward of 0 towards 'b'.
        self._as[selected_policy] += reward
        self._bs[selected_policy] += 1 - reward
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid ThompsonSampling (ts_hybrid)"
        return "ThompsonSampling (ts)"
class LinUCBCache(MLCache):
    """
    An implementation of LinUCB with disjoint linear models [2].
    [2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010.
    A contextual-bandit approach to personalized news article recommendation.
    In Proceedings of the 19th international conference on World wide web
    (WWW '10). ACM, New York, NY, USA, 661-670.
    DOI=http://dx.doi.org/10.1145/1772690.1772758
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        super(LinUCBCache, self).__init__(cache_size, enable_cache_row_key, policies)
        self.nfeatures = 4  # Block type, caller, level, cf.
        npolicies = len(self.policies)
        self.th = np.zeros((npolicies, self.nfeatures))
        self.eps = 0.2
        self.b = np.zeros_like(self.th)
        self.A = np.zeros((npolicies, self.nfeatures, self.nfeatures))
        self.A_inv = np.zeros((npolicies, self.nfeatures, self.nfeatures))
        for i in range(npolicies):
            self.A[i] = np.identity(self.nfeatures)
            # Keep A_inv consistent with A from the start: the inverse of the
            # identity is the identity. With an all-zero A_inv, a policy that
            # has not been selected yet gets a zero-width confidence interval
            # and can hardly ever win against a policy that has been tried.
            self.A_inv[i] = np.identity(self.nfeatures)
        self.th_hat = np.zeros_like(self.th)
        self.p = np.zeros(npolicies)
        self.alph = 0.2

    def _select_policy(self, trace_record, key):
        """
        Score every policy by its estimated reward for the context of
        'trace_record' plus an upper confidence term, select the best one,
        and update its model with the reward it generates for 'key'.
        Returns the index of the selected policy.
        """
        x_i = np.zeros(self.nfeatures)  # The current context vector
        x_i[0] = trace_record.block_type
        x_i[1] = trace_record.caller
        x_i[2] = trace_record.level
        x_i[3] = trace_record.cf_id
        p = np.zeros(len(self.policies))
        for a in range(len(self.policies)):
            self.th_hat[a] = self.A_inv[a].dot(self.b[a])
            ta = x_i.dot(self.A_inv[a]).dot(x_i)
            a_upper_ci = self.alph * np.sqrt(ta)
            a_mean = self.th_hat[a].dot(x_i)
            p[a] = a_mean + a_upper_ci
        # Add a tiny random perturbation so that ties are broken randomly.
        p = p + (np.random.random(len(p)) * 0.000001)
        selected_policy = p.argmax()
        reward = self.policies[selected_policy].generate_reward(key)
        assert reward <= 1 and reward >= 0
        self.A[selected_policy] += np.outer(x_i, x_i)
        self.b[selected_policy] += reward * x_i
        self.A_inv[selected_policy] = np.linalg.inv(self.A[selected_policy])
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid LinUCB (linucb_hybrid)"
        return "LinUCB (linucb)"
def parse_cache_size(cs):
    """
    Convert a cache size string to a number of bytes.

    cs: A non-negative integer optionally followed by one of the suffixes
    K, M, G or T (powers of 1024, upper or lower case), e.g. "16M" or "4G".
    Surrounding whitespace, such as a trailing newline, is ignored.

    Returns the size in bytes as an int.
    Raises ValueError if 'cs' is empty or its numeric part is not an integer.
    """
    multipliers = {"K": 1024, "M": 1024 ** 2, "G": 1024 ** 3, "T": 1024 ** 4}
    text = cs.strip()
    if not text:
        # Report the problem instead of failing on text[-1] below.
        raise ValueError("Empty cache size: {!r}".format(cs))
    suffix = text[-1].upper()
    if suffix in multipliers:
        return int(text[:-1]) * multipliers[suffix]
    return int(text)
def create_cache(cache_type, cache_size, downsample_size):
    """
    Create the cache simulator named by 'cache_type'.

    cache_type: One of "ts", "linucb", "ts_hybrid" or "linucb_hybrid". The
    "_hybrid" variants enable caching of row keys.
    cache_size: The cache capacity in bytes.
    downsample_size: The capacity is scaled down by this factor.

    Every cache is given an LRU, an MRU and an LFU policy, in this order.
    An unknown cache type is reported and fails an assertion.
    """
    policies = [LRUPolicy(), MRUPolicy(), LFUPolicy()]
    # Floor division keeps the capacity an integer on Python 3 as well.
    cache_size = cache_size // downsample_size
    enable_cache_row_key = False
    if "hybrid" in cache_type:
        enable_cache_row_key = True
        # Strip the "_hybrid" suffix to get the underlying cache type.
        cache_type = cache_type[: -len("_hybrid")]
    if cache_type == "ts":
        return ThompsonSamplingCache(cache_size, enable_cache_row_key, policies)
    elif cache_type == "linucb":
        return LinUCBCache(cache_size, enable_cache_row_key, policies)
    else:
        print("Unknown cache type {}".format(cache_type))
        assert False
    return None
def run(trace_file_path, cache_type, cache, warmup_seconds):
    """
    Replay every record of the trace file against 'cache'.

    Each line of the trace file holds the comma separated fields of one
    TraceRecord. The miss ratio counters of the cache are reset once, when
    the trace has run for more than 'warmup_seconds'. Progress is printed
    periodically and a summary is printed at the end.

    Returns the timestamp of the first record and the time between the first
    and the last record, in the time unit of the trace timestamps.
    """
    warmup_complete = False
    num = 0
    trace_start_time = 0
    trace_duration = 0
    start_time = time.time()
    time_interval = 1
    trace_miss_ratio_stats = MissRatioStats(kSecondsInMinute)

    def progress_report(now):
        # Reads the current values of the counters of the enclosing function.
        return (
            "Take {} seconds to process {} trace records with trace "
            "duration of {} seconds. Throughput: {} records/second. "
            "Trace miss ratio {}".format(
                now - start_time,
                num,
                trace_duration / 1000000,
                num / (now - start_time),
                trace_miss_ratio_stats.miss_ratio(),
            )
        )

    with open(trace_file_path, "r") as trace_file:
        for num, line in enumerate(trace_file, 1):
            if num % 1000000 == 0:
                # Force a python gc periodically to reduce memory usage.
                gc.collect()
            fields = line.split(",")
            timestamp = int(fields[0])
            if trace_start_time == 0:
                trace_start_time = timestamp
            trace_duration = timestamp - trace_start_time
            if not warmup_complete and trace_duration > warmup_seconds * 1000000:
                cache.miss_ratio_stats.reset_counter()
                warmup_complete = True
            record = TraceRecord(
                access_time=timestamp,
                block_id=int(fields[1]),
                block_type=int(fields[2]),
                block_size=int(fields[3]),
                cf_id=int(fields[4]),
                cf_name=fields[5],
                level=int(fields[6]),
                fd=int(fields[7]),
                caller=int(fields[8]),
                no_insert=int(fields[9]),
                get_id=int(fields[10]),
                key_id=int(fields[11]),
                kv_size=int(fields[12]),
                is_hit=int(fields[13]),
            )
            trace_miss_ratio_stats.update_metrics(
                record.access_time, is_hit=record.is_hit
            )
            cache.access(record)
            del record
            # Only look at the clock on every 100th record.
            if num % 100 == 0:
                # Report progress every 10 seconds.
                now = time.time()
                if now - start_time > time_interval * 10:
                    print(progress_report(now))
                    time_interval += 1
    print(
        "{},0,0,{},{},{}".format(
            cache_type,
            cache.cache_size,
            cache.miss_ratio_stats.miss_ratio(),
            cache.miss_ratio_stats.num_accesses,
        )
    )
    print(progress_report(time.time()))
    return trace_start_time, trace_duration
def report_stats(
    cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
):
    """
    Write the results of a finished simulation into 'result_dir': one file
    with the overall miss ratio, followed by the policy and miss timelines of
    the per-minute and then the per-hour statistics of 'cache'.
    """
    mrc_path = "{}/data-ml-mrc-{}".format(
        result_dir, "{}-{}".format(cache_type, cache_size)
    )
    with open(mrc_path, "w+") as mrc_file:
        mrc_file.write(
            "{},0,0,{},{},{}\n".format(
                cache_type,
                cache_size,
                cache.miss_ratio_stats.miss_ratio(),
                cache.miss_ratio_stats.num_accesses,
            )
        )
    # Every timeline writer takes the same arguments.
    timeline_args = (
        cache_type,
        cache_size,
        result_dir,
        trace_start_time,
        trace_end_time,
    )
    stats_by_granularity = (
        (cache.policy_stats, cache.miss_ratio_stats),
        (cache.per_hour_policy_stats, cache.per_hour_miss_ratio_stats),
    )
    for policy_stats, miss_stats in stats_by_granularity:
        policy_stats.write_policy_timeline(*timeline_args)
        policy_stats.write_policy_ratio_timeline(*timeline_args)
        miss_stats.write_miss_timeline(*timeline_args)
        miss_stats.write_miss_ratio_timeline(*timeline_args)
# Command line entry point: simulate one cache over one trace file and write
# the results into the result directory.
if __name__ == "__main__":
    if len(sys.argv) <= 6:
        usage = (
            "Must provide 6 arguments. "
            "1) cache_type (ts, ts_hybrid, linucb, linucb_hybrid). "
            "2) cache size (xM, xG, xT). "
            "3) The sampling frequency used to collect the trace. (The "
            "simulation scales down the cache size by the sampling frequency). "
            "4) Warmup seconds (The number of seconds used for warmup). "
            "5) Trace file path. "
            "6) Result directory (A directory that saves generated results)"
        )
        print(usage)
        exit(1)
    # The names below stay at module level on purpose: code elsewhere in the
    # file may read them as globals.
    cache_type = sys.argv[1]
    cache_size = parse_cache_size(sys.argv[2])
    downsample_size = int(sys.argv[3])
    warmup_seconds = int(sys.argv[4])
    trace_file_path, result_dir = sys.argv[5], sys.argv[6]
    cache = create_cache(cache_type, cache_size, downsample_size)
    trace_start_time, trace_duration = run(
        trace_file_path, cache_type, cache, warmup_seconds
    )
    trace_end_time = trace_start_time + trace_duration
    report_stats(
        cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
    )

View File

@ -0,0 +1,118 @@
#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# A shell script to run a batch of pysims and combine individual pysim output files.
#
# Usage: bash block_cache_pysim.sh trace_file_path result_dir downsample_size warmup_seconds max_jobs
# trace_file_path: The file path that stores the traces.
# result_dir: The directory to store pysim results. The output files from a pysim are stored in result_dir/ml
# downsample_size: The downsample size used to collect the trace.
# warmup_seconds: The number of seconds used for warmup.
# max_jobs: The max number of concurrent pysims to run.

if [ $# -ne 5 ]; then
  echo "Usage: ./block_cache_pysim.sh trace_file_path result_dir downsample_size warmup_seconds max_jobs"
  # A wrong invocation is an error; report it through the exit status.
  exit 1
fi

trace_file="$1"
result_dir="$2"
downsample_size="$3"
warmup_seconds="$4"
max_jobs="$5"
current_jobs=0

ml_tmp_result_dir="$result_dir/ml"
rm -rf "$ml_tmp_result_dir"
mkdir -p "$result_dir"
mkdir -p "$ml_tmp_result_dir"

for cache_type in "ts" "linucb" "ts_hybrid" "linucb_hybrid"
do
  for cache_size in "16M" "256M" "1G" "2G" "4G" "8G" "12G" "16G"
  do
    # Count only the background jobs started by this script that are still
    # running. Grepping the output of ps would also count unrelated processes
    # of any user whose command line mentions pysim and python.
    current_jobs=$(( $(jobs -rp | wc -l) ))
    while [ "$current_jobs" -ge "$max_jobs" ]
    do
      echo "Waiting jobs to complete. Number of running jobs: $current_jobs"
      sleep 10
      current_jobs=$(( $(jobs -rp | wc -l) ))
    done
    output="log-ml-$cache_type-$cache_size"
    echo "Running simulation for $cache_type and cache size $cache_size. Number of running jobs: $current_jobs. "
    nohup python block_cache_pysim.py "$cache_type" "$cache_size" "$downsample_size" "$warmup_seconds" "$trace_file" "$ml_tmp_result_dir" > "$ml_tmp_result_dir/$output" 2>&1 &
  done
done

# Wait for all jobs to complete.
echo "Waiting jobs to complete. Number of running jobs: $(( $(jobs -rp | wc -l) ))"
wait

echo "Combine individual pysim output files"
# The glob must stay outside the quotes, otherwise it is never expanded and
# the results of a previous run are kept and appended to.
rm -rf "$result_dir"/ml_*
mrc_file="$result_dir/ml_mrc"
# Two passes: headers first so that each combined file starts with its header,
# then the data rows.
for header in "header-" "data-"
do
  for fpath in "$ml_tmp_result_dir"/*
  do
    # Match on the file name only; the loop variable holds the full path.
    fn=$(basename "$fpath")
    sum_file=""
    time_unit=""
    capacity=""
    if [[ $fn == *"timeline"* ]]; then
      # A timeline file is named
      #   <header|data>-ml-<metric>-timeline-<time_unit>-<cache_type>-<capacity>
      # The time unit follows the "timeline" element and the capacity comes
      # two elements after the time unit.
      IFS='-' read -ra elements <<< "$fn"
      time_unit_index=0
      capacity_index=0
      for i in "${elements[@]}"
      do
        if [[ $i == "timeline" ]]; then
          break
        fi
        time_unit_index=$((time_unit_index+1))
      done
      time_unit_index=$((time_unit_index+1))
      capacity_index=$((time_unit_index+2))
      time_unit="${elements[$time_unit_index]}_"
      capacity="${elements[$capacity_index]}_"
    fi

    if [[ $fn == "${header}ml-policy-timeline"* ]]; then
      sum_file="$result_dir/ml_${capacity}${time_unit}policy_timeline"
    fi
    if [[ $fn == "${header}ml-policy-ratio-timeline"* ]]; then
      sum_file="$result_dir/ml_${capacity}${time_unit}policy_ratio_timeline"
    fi
    if [[ $fn == "${header}ml-miss-timeline"* ]]; then
      sum_file="$result_dir/ml_${capacity}${time_unit}miss_timeline"
    fi
    if [[ $fn == "${header}ml-miss-ratio-timeline"* ]]; then
      sum_file="$result_dir/ml_${capacity}${time_unit}miss_ratio_timeline"
    fi
    if [[ $fn == "${header}ml-mrc"* ]]; then
      sum_file="$mrc_file"
    fi

    if [[ $sum_file == "" ]]; then
      continue
    fi
    if [[ $header == "header-" ]]; then
      # Only the first header found is written to a combined file.
      if [ -e "$sum_file" ]; then
        continue
      fi
    fi
    cat "$fpath" >> "$sum_file"
  done
done

echo "Done"
# Sort MRC file by cache_type and cache_size.
if [ -e "$mrc_file" ]; then
  tmp_file="$result_dir/tmp_mrc"
  sort -t ',' -k1,1 -k4,4n "$mrc_file" > "$tmp_file"
  cat "$tmp_file" > "$mrc_file"
  rm -rf "$tmp_file"
fi

View File

@ -0,0 +1,340 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import random
from block_cache_pysim import (
HashTable,
LFUPolicy,
LinUCBCache,
LRUPolicy,
MRUPolicy,
ThompsonSamplingCache,
TraceRecord,
kSampleSize,
)
def test_hash_table():
    """Exercise HashTable with a sequential pass and a randomized pass.

    The sequential pass inserts, looks up and deletes 10000 distinct keys and
    checks that every key is found before the deletes and missing afterwards.
    The randomized pass applies one million random insert/lookup/delete
    operations on key ids in [0, 100] and compares the table with a plain
    dict before every operation.
    """
    print("Test hash table")
    table = HashTable()
    data_size = 10000
    for idx in range(data_size):
        table.insert("k{}".format(idx), idx, "v{}".format(idx))
    for idx in range(data_size):
        assert table.lookup("k{}".format(idx), idx) is not None
    for idx in range(data_size):
        table.delete("k{}".format(idx), idx)
    for idx in range(data_size):
        assert table.lookup("k{}".format(idx), idx) is None

    # Randomized cross-check against a dict that mirrors the table contents.
    truth_map = {}
    num_ops = 1000000
    max_key_id = 100
    for step in range(num_ops):
        key_id = random.randint(0, max_key_id)
        key = "k{}".format(key_id)
        value = "v{}".format(key_id)
        action = random.randint(0, 2)
        # The element count must agree with the mirror before each operation.
        assert len(truth_map) == table.elements, "{} {} {}".format(
            len(truth_map), table.elements, step
        )
        if action == 0:
            # Insert (or overwrite) the key.
            table.insert(key, key_id, value)
            truth_map[key] = value
            continue
        if action == 1:
            # Lookup: the result must agree with the mirror.
            if key not in truth_map:
                assert table.lookup(key, key_id) is None
                continue
            assert table.lookup(key, key_id) is not None
            assert truth_map[key] == table.lookup(key, key_id)
            continue
        # Delete: a no-op on the mirror when the key is absent.
        table.delete(key, key_id)
        truth_map.pop(key, None)
    print("Test hash table: Success")
def assert_metrics(cache, expected_value):
    """Assert that `cache` matches the expected counters and table contents.

    expected_value is indexed as follows:
      [0] expected cache.used_size
      [1] expected cache.miss_ratio_stats.num_accesses
      [2] expected cache.miss_ratio_stats.num_misses
      [3] ids that must be present in cache.table under the key "b<id>"
      [4] ids that must be present in cache.table under the key "g<id>"

    The table must hold exactly len([3]) + len([4]) elements and every entry
    that is looked up must have value_size == 1.

    Raises:
      AssertionError: on the first expectation that does not hold.
    """

    def _check_equal(actual, expected):
        assert actual == expected, "Expected {}, Actual {}".format(expected, actual)

    _check_equal(cache.used_size, expected_value[0])
    _check_equal(cache.miss_ratio_stats.num_accesses, expected_value[1])
    _check_equal(cache.miss_ratio_stats.num_misses, expected_value[2])
    _check_equal(
        cache.table.elements, len(expected_value[3]) + len(expected_value[4])
    )
    # "b" keys first, then "g" keys; the id doubles as the second lookup argument.
    for prefix, expected_ids in (("b", expected_value[3]), ("g", expected_value[4])):
        for expected_id in expected_ids:
            entry = cache.table.lookup("{}{}".format(prefix, expected_id), expected_id)
            assert entry is not None
            assert entry.value_size == 1
# Access k1, k1, k2, k3, k3, k3, k4
def test_cache(policies, expected_value):
    """Replay a fixed access sequence and check the cache after every access.

    A ThompsonSamplingCache is built as ThompsonSamplingCache(3, False,
    policies). Blocks k1, k1, k2, k3, k3, k3 are accessed first, with the
    metrics verified after each access; k4 is accessed last and the resulting
    state is compared with `expected_value`, which therefore encodes which
    block the given policies evict.

    Args:
      policies: list of eviction policy objects handed to the cache.
      expected_value: metrics expected after the final access to k4, in the
        layout understood by assert_metrics.
    """
    cache = ThompsonSamplingCache(3, False, policies)
    # Four records that differ only in access time (0..3) and block id (1..4).
    k1, k2, k3, k4 = [
        TraceRecord(
            access_time=offset,
            block_id=offset + 1,
            block_type=1,
            block_size=1,
            cf_id=0,
            cf_name="",
            level=0,
            fd=0,
            caller=1,
            no_insert=0,
            get_id=1,
            key_id=1,
            kv_size=5,
            is_hit=1,
        )
        for offset in range(4)
    ]
    # Each step pairs the record to access with the metrics expected afterwards.
    steps = [
        (k1, [1, 1, 1, [1], []]),  # Access k1, miss.
        (k1, [1, 2, 1, [1], []]),  # Access k1, hit.
        (k2, [2, 3, 2, [1, 2], []]),  # Access k2, miss.
        (k3, [3, 4, 3, [1, 2, 3], []]),  # Access k3, miss.
        (k3, [3, 5, 3, [1, 2, 3], []]),  # Access k3, hit.
        (k3, [3, 6, 3, [1, 2, 3], []]),  # Access k3, hit.
    ]
    for record, expected_after in steps:
        cache.access(record)
        assert_metrics(cache, expected_after)
    # The final access is the one whose outcome depends on the policy.
    cache.access(k4)
    assert_metrics(cache, expected_value)
def test_lru_cache():
    """Run the shared access sequence with LRUPolicy as the only policy."""
    print("Test LRU cache")
    # Access k4, miss. evict k1
    test_cache([LRUPolicy()], [3, 7, 4, [2, 3, 4], []])
    print("Test LRU cache: Success")
def test_mru_cache():
    """Run the shared access sequence with MRUPolicy as the only policy."""
    print("Test MRU cache")
    # Access k4, miss. evict k3
    test_cache([MRUPolicy()], [3, 7, 4, [1, 2, 4], []])
    print("Test MRU cache: Success")
def test_lfu_cache():
    """Run the shared access sequence with LFUPolicy as the only policy."""
    print("Test LFU cache")
    # Access k4, miss. evict k2
    test_cache([LFUPolicy()], [3, 7, 4, [1, 3, 4], []])
    print("Test LFU cache: Success")
def test_mix(cache):
    """Smoke-test `cache` with 100000 random accesses.

    Each access targets a random id in [0, 199] (used as block id, get id and
    key id) with a random block size in [0, 10]. The only check is that the
    miss ratio reported at the end is greater than zero.

    Args:
      cache: cache object providing cache_name(), access() and
        miss_ratio_stats.miss_ratio().
    """
    print("Test Mix {} cache".format(cache.cache_name()))
    num_accesses = 100000
    max_key_id = 199
    for access_time in range(num_accesses):
        # Draw the id first and the size second.
        key_id = random.randint(0, max_key_id)
        block_size = random.randint(0, 10)
        record = TraceRecord(
            access_time=access_time,
            block_id=key_id,
            block_type=1,
            block_size=block_size,
            cf_id=0,
            cf_name="",
            level=0,
            fd=0,
            caller=1,
            no_insert=0,
            get_id=key_id,
            key_id=key_id,
            kv_size=5,
            is_hit=1,
        )
        cache.access(record)
    assert cache.miss_ratio_stats.miss_ratio() > 0
    print("Test Mix {} cache: Success".format(cache.cache_name()))
def test_hybrid(cache):
    """Drive `cache` through a scripted series of accesses and check its metrics.

    A single TraceRecord is mutated in place between accesses (access time,
    get id, block id, key id, kv size) and assert_metrics is called after
    most of the accesses. The expected values depend on the order of the
    statements below, so the sequence must not be reordered.

    Args:
      cache: the cache under test. The caller in this file builds it as
        <CacheClass>(kSampleSize, True, [LRUPolicy()]); NOTE(review): the
        second argument presumably enables hybrid (block + row key) caching,
        confirm against block_cache_pysim.
    """
    print("Test {} cache".format(cache.cache_name()))
    k = TraceRecord(
        access_time=0,
        block_id=1,
        block_type=1,
        block_size=1,
        cf_id=0,
        cf_name="",
        level=0,
        fd=0,
        caller=1,
        no_insert=0,
        get_id=1,  # the first get request.
        key_id=1,
        kv_size=0,  # no size.
        is_hit=1,
    )
    cache.access(k)  # Expect a miss.
    # Expected-value layout (see assert_metrics): used size, num accesses,
    # num misses, ids of cached blocks ("b" keys), ids of cached get keys
    # ("g" keys). The hash table size is checked as the sum of the two lists.
    assert_metrics(cache, [1, 1, 1, [1], []])
    k.access_time += 1
    k.kv_size = 1
    k.block_id = 2
    cache.access(k)  # k should be inserted.
    assert_metrics(cache, [3, 2, 2, [1, 2], [1]])
    k.access_time += 1
    k.block_id = 3
    cache.access(k)  # k should not be inserted again.
    assert_metrics(cache, [4, 3, 3, [1, 2, 3], [1]])
    # A second get request referencing the same key.
    k.access_time += 1
    k.get_id = 2
    k.block_id = 4
    k.kv_size = 0
    cache.access(k)  # k should observe a hit. No block access.
    assert_metrics(cache, [4, 4, 3, [1, 2, 3], [1]])
    # A third get request searches three files, three different keys.
    # And the second key observes a hit.
    k.access_time += 1
    k.kv_size = 1
    k.get_id = 3
    k.block_id = 3
    k.key_id = 2
    cache.access(k)  # k should observe a miss. block 3 observes a hit.
    assert_metrics(cache, [5, 5, 3, [1, 2, 3], [1, 2]])
    k.access_time += 1
    k.kv_size = 1
    k.get_id = 3
    k.block_id = 4
    k.kv_size = 1  # Repeated assignment; the value is already 1.
    k.key_id = 1
    cache.access(k)  # k1 should observe a hit.
    assert_metrics(cache, [5, 6, 3, [1, 2, 3], [1, 2]])
    k.access_time += 1
    k.kv_size = 1
    k.get_id = 3
    k.block_id = 4
    k.kv_size = 1  # Repeated assignment; the value is already 1.
    k.key_id = 3
    # k3 should observe a miss.
    # However, as the get has already completed, we should not access k3 any more.
    cache.access(k)
    assert_metrics(cache, [5, 7, 3, [1, 2, 3], [1, 2]])
    # A fourth get request searches one file and two blocks. One row key.
    k.access_time += 1
    k.get_id = 4
    k.block_id = 5
    k.key_id = 4
    k.kv_size = 1
    cache.access(k)
    assert_metrics(cache, [7, 8, 4, [1, 2, 3, 5], [1, 2, 4]])
    # A bunch of insertions which evict cached row keys.
    # Blocks 6..99 are accessed with get_id 0; key_id and kv_size keep the
    # values set above.
    for i in range(6, 100):
        k.access_time += 1
        k.get_id = 0
        k.block_id = i
        cache.access(k)
    # Final access; access_time is left at the value reached in the loop.
    k.get_id = 4
    k.block_id = 100  # A different block.
    k.key_id = 4  # Same row key and should not be inserted again.
    k.kv_size = 1
    cache.access(k)
    # Only the last kSampleSize block ids (up to 100) and no get keys are
    # expected to remain.
    assert_metrics(cache, [16, 103, 99, [i for i in range(101 - kSampleSize, 101)], []])
    print("Test {} cache: Success".format(cache.cache_name()))
if __name__ == "__main__":
    # Policy objects shared by every cache built for the mixed-workload tests.
    policies = [MRUPolicy(), LRUPolicy(), LFUPolicy()]
    test_hash_table()
    test_lru_cache()
    test_mru_cache()
    test_lfu_cache()
    # Mixed workload: each cache class, first with the second constructor
    # argument False, then True.
    for cache_class in (ThompsonSamplingCache, LinUCBCache):
        for second_arg in (False, True):
            test_mix(cache_class(100, second_arg, policies))
    # Scripted scenario: one LRU policy, kSampleSize as the first argument.
    for cache_class in (ThompsonSamplingCache, LinUCBCache):
        test_hybrid(cache_class(kSampleSize, True, [LRUPolicy()]))

View File

@ -5,7 +5,7 @@
#ifndef ROCKSDB_LITE
#ifdef GFLAGS
#include "tools/block_cache_trace_analyzer.h"
#include "tools/block_cache_analyzer/block_cache_trace_analyzer.h"
#include <algorithm>
#include <cinttypes>
@ -1395,13 +1395,12 @@ Status BlockCacheTraceAnalyzer::WriteHumanReadableTraceRecord(
}
int ret = snprintf(
trace_record_buffer_, sizeof(trace_record_buffer_),
"%" PRIu64 ",%" PRIu64 ",%u,%" PRIu64 ",%" PRIu64 ",%" PRIu32 ",%" PRIu64
""
",%u,%u,%" PRIu64 ",%" PRIu64 ",%" PRIu64 ",%u\n",
"%" PRIu64 ",%" PRIu64 ",%u,%" PRIu64 ",%" PRIu64 ",%s,%" PRIu32
",%" PRIu64 ",%u,%u,%" PRIu64 ",%" PRIu64 ",%" PRIu64 ",%u\n",
access.access_timestamp, block_id, access.block_type, access.block_size,
access.cf_id, access.level, access.sst_fd_number, access.caller,
access.no_insert, access.get_id, get_key_id, access.referenced_data_size,
access.is_cache_hit);
access.cf_id, access.cf_name.c_str(), access.level, access.sst_fd_number,
access.caller, access.no_insert, access.get_id, get_key_id,
access.referenced_data_size, access.is_cache_hit);
if (ret < 0) {
return Status::IOError("failed to format the output");
}
@ -2134,6 +2133,7 @@ int block_cache_trace_analyzer_tool(int argc, char** argv) {
analyzer.WriteAccessTimeline(label, kSecondInHour, false);
} else {
analyzer.WriteAccessTimeline(label, kSecondInMinute, false);
analyzer.WriteAccessTimeline(label, kSecondInHour, false);
}
}
}

View File

@ -23,7 +23,7 @@ int main() {
#include "rocksdb/trace_reader_writer.h"
#include "test_util/testharness.h"
#include "test_util/testutil.h"
#include "tools/block_cache_trace_analyzer.h"
#include "tools/block_cache_analyzer/block_cache_trace_analyzer.h"
#include "trace_replay/block_cache_tracer.h"
namespace rocksdb {
@ -343,7 +343,7 @@ TEST_F(BlockCacheTracerTest, BlockCacheAnalyzer) {
std::string l;
ASSERT_TRUE(getline(ss, l, ','));
if (l.find("block") == std::string::npos) {
if (unit != "_60" || user_access_only != "all_access_") {
if (user_access_only != "all_access_") {
continue;
}
}

View File

@ -11,7 +11,7 @@ int main() {
return 1;
}
#else // GFLAGS
#include "tools/block_cache_trace_analyzer.h"
#include "tools/block_cache_analyzer/block_cache_trace_analyzer.h"
// Entry point used in the GFLAGS branch of this file: forwards the command
// line arguments to rocksdb::block_cache_trace_analyzer_tool and returns its
// result as the process exit status.
int main(int argc, char** argv) {
  return rocksdb::block_cache_trace_analyzer_tool(argc, argv);
}