Skip to content

Commit

Permalink
Refactor vm stream desc (#6989)
Browse files Browse the repository at this point in the history
* remove StreamDesc::num_machines

* Prepare one thread for one stream_type

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
lixinqi and oneflow-ci-bot authored Dec 11, 2021
1 parent ef8b06c commit 9abccf4
Show file tree
Hide file tree
Showing 14 changed files with 24 additions and 34 deletions.
1 change: 0 additions & 1 deletion oneflow/core/eager/critical_section_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ intrusive::shared_ptr<StreamDesc> CriticalSectionStreamType::MakeStreamDesc(
const Resource& resource, int64_t this_machine_id) const {
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<CriticalSectionStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(1);
ret->set_num_streams_per_thread(1);
return ret;
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/eager/lazy_job_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ intrusive::shared_ptr<StreamDesc> LazyJobStreamType::MakeStreamDesc(const Resour
int64_t this_machine_id) const {
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<LazyJobStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(1);
ret->set_num_streams_per_thread(1);
return ret;
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/vm/async_cuda_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ intrusive::shared_ptr<StreamDesc> AsyncCudaStreamType::MakeStreamDesc(
std::size_t device_num = resource.gpu_device_num();
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<AsyncCudaStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_num_streams_per_thread(device_num);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/control_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ intrusive::shared_ptr<StreamDesc> ControlStreamType::MakeStreamDesc(const Resour
int64_t this_machine_id) const {
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<ControlStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(1);
ret->set_num_streams_per_thread(1);
return ret;
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/vm/cpu_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,8 @@ intrusive::shared_ptr<StreamDesc> CpuStreamType::MakeStreamDesc(const Resource&
std::size_t device_num = resource.cpu_device_num();
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<CpuStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_num_streams_per_thread(device_num);
return ret;
}

Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/vm/cuda_copy_d2h_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ intrusive::shared_ptr<StreamDesc> CudaCopyD2HStreamType::MakeStreamDesc(
std::size_t device_num = resource.gpu_device_num();
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<CudaCopyD2HStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_num_streams_per_thread(device_num);
return ret;
}

Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/vm/cuda_copy_h2d_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ intrusive::shared_ptr<StreamDesc> CudaCopyH2DStreamType::MakeStreamDesc(
std::size_t device_num = resource.gpu_device_num();
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<CudaCopyH2DStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_num_streams_per_thread(device_num);
return ret;
}

Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/vm/cuda_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ intrusive::shared_ptr<StreamDesc> CudaStreamType::MakeStreamDesc(const Resource&
std::size_t device_num = resource.gpu_device_num();
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<CudaStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_num_streams_per_thread(device_num);
return ret;
}

Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/vm/device_helper_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ intrusive::shared_ptr<StreamDesc> DeviceHelperStreamType::MakeStreamDesc(
CHECK_GT(device_num, 0);
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<DeviceHelperStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
ret->set_num_streams_per_thread(1);
ret->set_num_streams_per_thread(device_num);
return ret;
}

Expand Down
1 change: 0 additions & 1 deletion oneflow/core/vm/host_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ intrusive::shared_ptr<StreamDesc> HostStreamType::MakeStreamDesc(const Resource&
int64_t this_machine_id) const {
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<HostStreamType>());
ret->set_num_machines(1);
ret->set_num_streams_per_machine(1);
ret->set_num_streams_per_thread(1);
return ret;
Expand Down
8 changes: 4 additions & 4 deletions oneflow/core/vm/stream_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ limitations under the License.
namespace oneflow {
namespace vm {

void StreamDesc::__Init__(const StreamTypeId& stream_type_id, int32_t num_machines,
int32_t num_streams_per_machine, int32_t num_streams_per_thread) {
void StreamDesc::__Init__(const StreamTypeId& stream_type_id, int32_t num_streams_per_machine,
int32_t num_streams_per_thread) {
mut_stream_type_id()->CopyFrom(stream_type_id);
set_num_machines(num_machines);
set_num_streams_per_machine(num_streams_per_machine);
set_num_streams_per_thread(num_streams_per_thread);
}

int32_t StreamDesc::num_threads() const {
int32_t num_devices = num_machines() * num_streams_per_machine();
int32_t num_devices = num_streams_per_machine();
if (num_devices == 0) { return 0; }
CHECK_EQ(num_devices % num_streams_per_thread(), 0);
return num_devices / num_streams_per_thread();
}
Expand Down
10 changes: 3 additions & 7 deletions oneflow/core/vm/stream_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,36 +59,32 @@ class StreamId final {
class StreamDesc final : public intrusive::Base {
public:
// Getters
int32_t num_machines() const { return num_machines_; }
int32_t num_streams_per_machine() const { return num_streams_per_machine_; }
int32_t num_streams_per_thread() const { return num_streams_per_thread_; }
const StreamTypeId& stream_type_id() const { return stream_type_id_.key().Get(); }
// Setters
void set_num_machines(int32_t val) { num_machines_ = val; }
void set_num_streams_per_machine(int32_t val) { num_streams_per_machine_ = val; }
void set_num_streams_per_thread(int32_t val) { num_streams_per_thread_ = val; }
StreamTypeId* mut_stream_type_id() { return stream_type_id_.mut_key()->Mutable(); }

// methods
void __Init__() {}
void __Init__(const StreamTypeId& stream_type_id, int32_t num_machines,
int32_t num_streams_per_machine, int32_t num_streams_per_thread);
void __Init__(const StreamTypeId& stream_type_id, int32_t num_streams_per_machine,
int32_t num_streams_per_thread);
int32_t num_threads() const;
int32_t parallel_num() const { return num_machines() * num_streams_per_machine(); }
int32_t parallel_num() const { return num_streams_per_machine(); }

private:
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }

StreamDesc()
: intrusive_ref_(),
num_machines_(),
num_streams_per_machine_(),
num_streams_per_thread_(),
stream_type_id_() {}
intrusive::Ref intrusive_ref_;
// fields
int32_t num_machines_;
int32_t num_streams_per_machine_;
int32_t num_streams_per_thread_;

Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/vm/test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ void TestUtil::AddStreamDescByInstrNames(VmDesc* vm_desc, int64_t parallel_num,
const std::vector<std::string>& instr_names) {
auto Insert = [&](const std::string& instr_name) {
const auto& stream_type_id = LookupInstrTypeId(instr_name).stream_type_id();
auto stream_desc = intrusive::make_shared<StreamDesc>(stream_type_id, 1, parallel_num, 1);
auto stream_desc = intrusive::make_shared<StreamDesc>(stream_type_id, parallel_num, 1);
vm_desc->mut_stream_type_id2desc()->Insert(stream_desc.Mutable());
};
for (const auto& instr_name : instr_names) {
Expand Down
16 changes: 10 additions & 6 deletions oneflow/core/vm/transport_stream_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/vm/transport_stream_type.h"
#include "oneflow/core/common/multi_client.h"

namespace oneflow {
namespace vm {
Expand Down Expand Up @@ -51,16 +52,19 @@ template<typename DerivedT>
intrusive::shared_ptr<StreamDesc> TransportStreamType::MakeTransportStreamDesc(
const Resource& resource, int64_t this_machine_id) const {
std::size_t device_num = 0;
if (resource.has_cpu_device_num()) {
device_num = std::max<std::size_t>(device_num, resource.cpu_device_num());
}
if (resource.has_gpu_device_num()) {
device_num = std::max<std::size_t>(device_num, resource.gpu_device_num());
if (!CHECK_JUST(IsMultiClient())) {
if (resource.has_cpu_device_num()) {
device_num = std::max<std::size_t>(device_num, resource.cpu_device_num());
}
if (resource.has_gpu_device_num()) {
device_num = std::max<std::size_t>(device_num, resource.gpu_device_num());
}
} else {
// Keep device_num = 0. TransportStreamType is not used in multi-client mode.
}
auto ret = intrusive::make_shared<StreamDesc>();
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<DerivedT>());
// TODO(lixinqi): remove this ugly field
ret->set_num_machines(1);
ret->set_num_streams_per_machine(device_num);
// TODO(lixinqi): refactor to a num_threads_per_machine field
ret->set_num_streams_per_thread(1);
Expand Down

0 comments on commit 9abccf4

Please sign in to comment.