Skip to content

Commit

Permalink
Clear empty thread when graph destroy (#7633)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengtbf authored Mar 22, 2022
1 parent 174bc3d commit 3e8585e
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 35 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ jobs:
working-directory: ${{ env.ONEFLOW_SRC }}
run: |
docker run -d --rm --privileged --shm-size=8g \
--pids-limit -1 \
--pids-limit 1000 \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--runtime=nvidia \
-v /dataset:/dataset:ro -v /model_zoo:/model_zoo:ro \
Expand Down Expand Up @@ -691,6 +691,14 @@ jobs:
EXTRA_DOCKER_ARGS+=" --env ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE=1"
EXTRA_DOCKER_ARGS+=" --env ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER=1"
echo "EXTRA_DOCKER_ARGS=${EXTRA_DOCKER_ARGS}" >> $GITHUB_ENV
- name: Set Thread Limit (CPU)
if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cpu' }}
run: |
echo "THREAD_LIMIT=8000" >> $GITHUB_ENV
- name: Set Thread Limit (CUDA)
if: ${{ !fromJson(matrix.cache-hit) && matrix.device == 'cuda' }}
run: |
echo "THREAD_LIMIT=3000" >> $GITHUB_ENV
- name: Enable ONEFLOW_TEST_VERBOSE
if: ${{ contains(github.event.pull_request.labels.*.name, 'need-test-verbose') }}
run: |
Expand All @@ -710,6 +718,7 @@ jobs:
working-directory: ${{ env.ONEFLOW_SRC }}
run: |
docker run -d --rm --privileged --shm-size=8g \
--pids-limit ${{ env.THREAD_LIMIT }} \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--runtime=nvidia \
-v /dataset:/dataset:ro -v /model_zoo:/model_zoo:ro \
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/framework/multi_client_session_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ Maybe<void> MultiClientSessionContext::TryInit(const std::string& config_proto_s
}

Maybe<void> MultiClientSessionContext::UpdateResource(const Resource& reso_proto) {
CHECK_OR_RETURN(is_inited_);
CHECK_NOTNULL_OR_RETURN((Global<ResourceDesc, ForSession>::Get()));
Global<ResourceDesc, ForSession>::Get()->Update(reso_proto);
return Maybe<void>::Ok();
Expand Down
39 changes: 20 additions & 19 deletions oneflow/core/framework/nn_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,12 @@ NNGraph::~NNGraph() {
Maybe<void> NNGraph::Close() {
if (!is_closed_) {
VLOG(1) << "Try to close c nn graph name " << name_ << "." << std::endl;
CloseRuntimeBuffers();
runtime_.reset();
if (runtime_inited_) {
CloseRuntimeBuffers();
runtime_.reset();
}
session_ctx_->RemoveGraphFreeEagerTensors(name_);
is_closed_ = true;
VLOG(1) << "Finish close c nn graph name " << name_ << "." << std::endl;

session_ctx_.reset();
Expand Down Expand Up @@ -428,25 +431,23 @@ void NNGraph::NewRuntimeBuffers() {
}

void NNGraph::CloseRuntimeBuffers() {
if (runtime_inited_) {
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();
for (const std::string& output_op_name : outputs_op_names_) {
buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close();
}
for (const std::string& input_op_name : inputs_op_names_) {
buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close();
}
buffer_mgr->Get(GetOutputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetOutputCriticalSectionWaitBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionWaitBufferName(name_))->Close();
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<CriticalSectionInstance>>>::Get();
for (const std::string& output_op_name : outputs_op_names_) {
buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close();
}
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close();
buffer_mgr->Get(GetSourceTickBufferName(name_))->Close();
for (const std::string& input_op_name : inputs_op_names_) {
buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close();
}
buffer_mgr->Get(GetOutputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetOutputCriticalSectionWaitBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionCallbackBufferName(name_))->Close();
buffer_mgr->Get(GetInputCriticalSectionWaitBufferName(name_))->Close();
}
{
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close();
buffer_mgr->Get(GetSourceTickBufferName(name_))->Close();
}
}

Expand Down
15 changes: 14 additions & 1 deletion oneflow/core/job/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ bool HasNonCtrlConsumedRegstDescId(const TaskProto& task) {
} // namespace

Runtime::Runtime(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob) {
DumpThreadIdsFromPlan(plan);
{
// NOTE(chengcheng): All runtime Global objects AddPlan
Global<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob);
Global<ThreadMgr>::Get()->AddPlan(plan);
Global<ThreadMgr>::Get()->AddThreads(thread_ids_);
Global<RuntimeJobDescs>::Get()->AddPlan(plan);
collective_boxing_scheduler_plan_token_ =
Global<boxing::collective::Scheduler>::Get()->AddPlan(plan);
Expand Down Expand Up @@ -106,7 +107,19 @@ Runtime::~Runtime() {
Global<RuntimeCtx>::Get()->WaitUntilCntEqualZero(GetRunningActorCountKeyByJobId(pair.first));
}
OF_SESSION_BARRIER();
Global<ThreadMgr>::Get()->TryDeleteThreads(thread_ids_);
Global<boxing::collective::Scheduler>::Get()->DeletePlan(collective_boxing_scheduler_plan_token_);
}

void Runtime::DumpThreadIdsFromPlan(const Plan& plan) {
const int64_t this_rank = GlobalProcessCtx::Rank();
for (const TaskProto& task : plan.task()) {
TaskId task_id = DecodeTaskIdFromInt64(task.task_id());
StreamId stream_id = task_id.stream_id();
if (stream_id.rank() != this_rank) { continue; }
int64_t thrd_id = EncodeStreamIdToInt64(stream_id);
thread_ids_.insert(thrd_id);
}
}

} // namespace oneflow
3 changes: 3 additions & 0 deletions oneflow/core/job/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class Runtime final {
Runtime(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob);

private:
void DumpThreadIdsFromPlan(const Plan& plan);

HashMap<int64_t, int64_t> job_id2actor_size_;
HashSet<int64_t> thread_ids_;

boxing::collective::SchedulerPlanToken* collective_boxing_scheduler_plan_token_;
};
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/thread/thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ void Thread::AddTask(const TaskProto& task) {
CHECK(id2task_.emplace(task.task_id(), task).second);
}

bool Thread::Empty() const { return id2actor_ptr_.empty(); }

void Thread::PollMsgChannel() {
while (true) {
if (local_msg_queue_.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/thread/thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Thread {
virtual ~Thread();

void AddTask(const TaskProto&);
// NOTE(chengcheng): Indicates whether all actors in the thread have been destructed.
bool Empty() const;

Channel<ActorMsg>* GetMsgChannelPtr() { return &msg_channel_; }

Expand Down
47 changes: 34 additions & 13 deletions oneflow/core/thread/thread_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@ limitations under the License.
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/global_for.h"

namespace oneflow {

ThreadMgr::~ThreadMgr() {
for (auto& thread_pair : threads_) {
ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
thread_pair.second->GetMsgChannelPtr()->Send(msg);
thread_pair.second.reset();
VLOG(3) << "actor thread " << thread_pair.first << " finish";
}
CHECK(threads_.empty()) << " Runtime Error! num = " << threads_.size()
<< " threads did not destroy with graph.";
}

Thread* ThreadMgr::GetThrd(int64_t thrd_id) {
Expand All @@ -36,20 +31,46 @@ Thread* ThreadMgr::GetThrd(int64_t thrd_id) {
return iter->second.get();
}

void ThreadMgr::AddPlan(const Plan& plan) {
void ThreadMgr::AddThreads(const HashSet<int64_t>& thread_ids) {
const int64_t this_rank = GlobalProcessCtx::Rank();
for (const TaskProto& task : plan.task()) {
TaskId task_id = DecodeTaskIdFromInt64(task.task_id());
StreamId stream_id = task_id.stream_id();
for (int64_t thrd_id : thread_ids) {
const auto& it = threads_.find(thrd_id);
if (it != threads_.end()) {
// NOTE(chengcheng): check thread is not null.
CHECK(it->second) << " Runtime Error! Thread: " << thrd_id << " in manager must be NOT null.";
continue;
}
StreamId stream_id = DecodeStreamIdFromInt64(thrd_id);
if (stream_id.rank() != this_rank) { continue; }
int64_t thrd_id = EncodeStreamIdToInt64(stream_id);
if (threads_.find(thrd_id) != threads_.end()) { continue; }
Thread* thread = new Thread(stream_id);
CHECK_NOTNULL(thread);
threads_[thrd_id].reset(thread);
}
}

void ThreadMgr::TryDeleteThreads(const HashSet<int64_t>& thread_ids) {
std::unique_lock<std::mutex> lock(mutex4del_threads_);
for (int64_t thrd_id : thread_ids) {
const auto& it = threads_.find(thrd_id);
if (it == threads_.end()) { continue; }
auto& thread = it->second;
CHECK(thread) << " actor thread " << thrd_id << " non-existent but want to delete";
if (thread->Empty()) {
// NOTE(chengcheng): Only delete thread when it is empty.
// We need send Stop msg to exit the main loop of the thread. Here we can safely call reset
// directly, because the thread destructor will specify actor_thread_.join() to blocking
// wait for the end of the thread real execution.
ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
thread->GetMsgChannelPtr()->Send(msg);
thread.reset();
VLOG(2) << " actor thread " << thrd_id << " finish.";
threads_.erase(it);
} else {
LOG(INFO) << " actor thread " << thrd_id << " not delete because it is not empty.";
}
}
}

void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
FOR_RANGE(size_t, i, 0, num) { Callback(i); }
}
Expand Down
5 changes: 4 additions & 1 deletion oneflow/core/thread/thread_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_
#define ONEFLOW_CORE_THREAD_THREAD_MANAGER_H_

#include <mutex>
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/auto_registration_factory.h"
Expand All @@ -36,13 +37,15 @@ class ThreadMgr final {
ThreadMgr() = default;
~ThreadMgr();

void AddPlan(const Plan& plan);
void AddThreads(const HashSet<int64_t>& thread_ids);
void TryDeleteThreads(const HashSet<int64_t>& thread_ids);
Thread* GetThrd(int64_t thrd_id);

private:
friend class Global<ThreadMgr>;

HashMap<int64_t, std::unique_ptr<Thread>> threads_;
std::mutex mutex4del_threads_;
};

void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback);
Expand Down

0 comments on commit 3e8585e

Please sign in to comment.