Skip to content
This repository has been archived by the owner on May 17, 2022. It is now read-only.

Commit

Permalink
fix intrusive_ptr
Browse files: browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Nov 18, 2020
1 parent 0032855 commit c9173c1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
10 changes: 5 additions & 5 deletions include/torch_ucc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class ProcessGroupUCC : public ProcessGroup {
public:
WorkColl(
torch_ucc_coll_ops_t ops,
std::list<std::shared_ptr<WorkColl>>& list)
std::list<c10::intrusive_ptr<WorkColl>>& list)
: coll_ops(ops),
work_list(list),
external_progress(false),
Expand All @@ -63,8 +63,8 @@ class ProcessGroupUCC : public ProcessGroup {

protected:
torch_ucc_coll_ops_t coll_ops;
std::list<std::shared_ptr<WorkColl>>& work_list;
std::list<std::shared_ptr<WorkColl>>::iterator work_list_entry;
std::list<c10::intrusive_ptr<WorkColl>>& work_list;
std::list<c10::intrusive_ptr<WorkColl>>::iterator work_list_entry;
bool external_progress;
char* scratch;
std::vector<at::Tensor> src;
Expand Down Expand Up @@ -174,12 +174,12 @@ class ProcessGroupUCC : public ProcessGroup {
std::mutex pg_mutex;
std::thread progress_thread;
bool stop_progress_loop;
std::list<std::shared_ptr<WorkColl>> progress_list;
std::list<c10::intrusive_ptr<WorkColl>> progress_list;
std::condition_variable queue_produce_cv;
std::condition_variable queue_consume_cv;

void progress_loop();
std::shared_ptr<ProcessGroup::Work> enqueue_request(
c10::intrusive_ptr<ProcessGroup::Work> enqueue_request(
torch_ucc_coll_request_t* req,
void* scratch);
torch_ucc_coll_comm_t* get_coll_comm();
Expand Down
13 changes: 7 additions & 6 deletions src/torch_ucc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,21 +218,22 @@ void ProcessGroupUCC::progress_loop() {
}
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupUCC::enqueue_request(
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::enqueue_request(
torch_ucc_coll_request_t* req,
void* scratch) {
std::unique_lock<std::mutex> lock(pg_mutex);

auto iter = progress_list.emplace(
progress_list.end(),
std::make_shared<ProcessGroupUCC::WorkColl>(coll_ops, progress_list));
c10::make_intrusive<ProcessGroupUCC::WorkColl>(coll_ops, progress_list));
(*iter)->work_list_entry = iter;
(*iter)->coll_req = req;
(*iter)->external_progress = config.enable_progress_thread;
(*iter)->scratch = (char*)scratch;
auto workreq = (*iter);
lock.unlock();
queue_produce_cv.notify_one();
return *iter;
return workreq;
}

ProcessGroupUCC::~ProcessGroupUCC() {
Expand Down Expand Up @@ -424,7 +425,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::send(
throw std::runtime_error("ProcessGroupUCC: failed to send msg");
}

return std::make_shared<ProcessGroupUCC::WorkUCX>(req, ucx_comm);
return c10::make_intrusive<ProcessGroupUCC::WorkUCX>(req, ucx_comm);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
Expand All @@ -442,7 +443,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recv(
throw std::runtime_error("ProcessGroupUCC: failed to recv msg");
}

return std::make_shared<ProcessGroupUCC::WorkUCX>(req, ucx_comm);
return c10::make_intrusive<ProcessGroupUCC::WorkUCX>(req, ucx_comm);
}

c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::recvAnysource(
Expand All @@ -456,7 +457,7 @@ c10::intrusive_ptr<ProcessGroup> ProcessGroupUCC::createProcessGroupUCC(
int rank,
int size,
const std::chrono::duration<float>& timeout) {
return std::make_shared<ProcessGroupUCC>(store, rank, size);
return c10::make_intrusive<ProcessGroupUCC>(store, rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand Down

0 comments on commit c9173c1

Please sign in to comment.