-
Notifications
You must be signed in to change notification settings - Fork 786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor vm stream desc #6989
Refactor vm stream desc #6989
Changes from 4 commits
caf225f
41b4507
d34e9e3
01fe89a
68fdc0f
54b9a7c
f1eda54
5fecfd5
ca41527
8e692a4
8ae6bb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,9 +74,8 @@ intrusive::shared_ptr<StreamDesc> AsyncCudaStreamType::MakeStreamDesc( | |
std::size_t device_num = resource.gpu_device_num(); | ||
auto ret = intrusive::make_shared<StreamDesc>(); | ||
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<AsyncCudaStreamType>()); | ||
ret->set_num_machines(1); | ||
ret->set_num_streams_per_machine(device_num); | ||
ret->set_num_streams_per_thread(1); | ||
ret->set_num_streams_per_thread(device_num); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. num_threads = num_streams_per_machine / num_streams_per_thread |
||
return ret; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and | |
limitations under the License. | ||
*/ | ||
#include "oneflow/core/vm/transport_stream_type.h" | ||
#include "oneflow/core/common/multi_client.h" | ||
|
||
namespace oneflow { | ||
namespace vm { | ||
|
@@ -51,16 +52,19 @@ template<typename DerivedT> | |
intrusive::shared_ptr<StreamDesc> TransportStreamType::MakeTransportStreamDesc( | ||
const Resource& resource, int64_t this_machine_id) const { | ||
std::size_t device_num = 0; | ||
if (resource.has_cpu_device_num()) { | ||
device_num = std::max<std::size_t>(device_num, resource.cpu_device_num()); | ||
} | ||
if (resource.has_gpu_device_num()) { | ||
device_num = std::max<std::size_t>(device_num, resource.gpu_device_num()); | ||
if (!CHECK_JUST(IsMultiClient())) { | ||
if (resource.has_cpu_device_num()) { | ||
device_num = std::max<std::size_t>(device_num, resource.cpu_device_num()); | ||
} | ||
if (resource.has_gpu_device_num()) { | ||
device_num = std::max<std::size_t>(device_num, resource.gpu_device_num()); | ||
} | ||
} else { | ||
// Keep device_num = 0. TransportStreamType is not used in multi-client mode. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个是目前的现状吗?TransportStreamType只在 single-client eager 模式下有用到吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的。 |
||
} | ||
auto ret = intrusive::make_shared<StreamDesc>(); | ||
ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex<DerivedT>()); | ||
// TODO(lixinqi): remove this ugly field | ||
ret->set_num_machines(1); | ||
ret->set_num_streams_per_machine(device_num); | ||
// TODO(lixinqi): refactor to a num_threads_per_machine field | ||
ret->set_num_streams_per_thread(1); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前这个字段完全没有用,直接去掉。