Skip to content

Commit

Permalink
[AutoScheduler] Fix the occasional crash caused by split memo (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and Trevor Morris committed Dec 4, 2020
1 parent dab8783 commit 3288531
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 68 deletions.
6 changes: 4 additions & 2 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(

PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, State* state,
std::mt19937* rand_gen) const {
SplitFactorizationMemo split_memo;
int max_innermost_split_factor =
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

Expand All @@ -470,8 +471,9 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p

ICHECK(ps->extent);
int extent = GetIntImm(ps->extent.value());
const auto& candidate_lens = policy->split_memo.GetFactorizationSchemes(
extent, ps->lengths.size(), max_innermost_split_factor);
const auto& candidate_lens = split_memo.GetFactorizationSchemes(extent, ps->lengths.size(),
max_innermost_split_factor);
ICHECK(!candidate_lens.empty());
const auto& candidate_lengths = candidate_lens[(*rand_gen)() % candidate_lens.size()];

pstate->transform_steps.Set(
Expand Down
42 changes: 3 additions & 39 deletions src/auto_scheduler/search_policy/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,55 +413,19 @@ void PruneInvalidState(const SearchTask& task, Array<State>* states) {
}

/********** SplitFactorizationMemo **********/

void SplitFactorizationMemo::ReadWriteLock::GetRead() {
std::unique_lock<std::mutex> lock(cv_mutex_);
// Wake up and get the mutex lock if there's no writing thread
cv_.wait(lock, [this]() { return !this->is_writing_; });
read_count_++;
}

void SplitFactorizationMemo::ReadWriteLock::GetWrite() {
std::unique_lock<std::mutex> lock(cv_mutex_);
// Wake up and get the mutex lock if there's no reading or writing threads
cv_.wait(lock, [this]() { return this->read_count_ == 0 && !this->is_writing_; });
is_writing_ = true;
}

void SplitFactorizationMemo::ReadWriteLock::UnlockRead() {
std::lock_guard<std::mutex> lock(cv_mutex_);
read_count_--;
// Notify the other blocked threads if this is the last reading thread
if (read_count_ == 0) {
cv_.notify_one();
}
}

void SplitFactorizationMemo::ReadWriteLock::UnlockWrite() {
std::lock_guard<std::mutex> lock(cv_mutex_);
is_writing_ = false;
// Notify the other blocked threads
cv_.notify_one();
}

const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
int extent, int n_lengths, int max_innermost_factor) {
QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
const auto& const_memory = memory_;
lock_.GetRead();
const auto& it = const_memory.find(key);
const auto& memory_end = const_memory.end();
lock_.UnlockRead();
if (it != memory_end) {
const auto& it = memory_.find(key);
if (it != memory_.end()) {
return it->second;
}

lock_.GetWrite();
tmp_stack_ = Array<Integer>(n_lengths, Integer());
results_ = &memory_[key];
n_lengths_ = n_lengths;

DfsEnumerate(0, extent, max_innermost_factor);
lock_.UnlockWrite();

return *results_;
}
Expand Down
27 changes: 0 additions & 27 deletions src/auto_scheduler/search_policy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,33 +677,6 @@ class SplitFactorizationMemo {
private:
void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);

/*!
* \brief A simple implementation of read-write lock.
* The guarded block can be read by multiple threads at the same time, while other operations will
* be blocked if one thread is writing.
* \note Writing threads will wait until all reading threads have finshed. If there're multiple
* writing threads, the process order of them is not guaranteed.
*/
class ReadWriteLock {
public:
/*! \brief The method to get the read lock. One thread can process read if there's on other
* writing threads. */
void GetRead();
/*! \brief The method to get the write lock. One thread can process write if there's on other
* reading or writing threads. */
void GetWrite();
/*! \brief The method to release the read lock. */
void UnlockRead();
/*! \brief The method to release the write lock. */
void UnlockWrite();

private:
uint32_t read_count_ = 0;
bool is_writing_ = false;
std::mutex cv_mutex_;
std::condition_variable cv_;
} lock_;

std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;

int n_lengths_;
Expand Down

0 comments on commit 3288531

Please sign in to comment.