Skip to content

Commit

Permalink
Add transfer cost for middle nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Yipeng1994 committed Jan 24, 2022
1 parent 2824513 commit d377f11
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
12 changes: 10 additions & 2 deletions oneflow/core/auto_parallel/boxing_collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,21 @@ void DfsSetNdSbp(std::vector<::oneflow::cfg::SbpParallel>& id2SbpParallel, int32
} // namespace

// Construct a boxing collector with given maximum number of axis
void BoxingCollector::Init(int32_t max_axis) {
void BoxingCollector::Init(int32_t max_axis, double transfer_cost_) {
// Set up at least two split for op graph.
// For a negative example: Resnet50 only have B, P, S(0)
SetTransferCost(transfer_cost_);
CollectUniverse(max_axis);
GenerateNdSbpList();
GenerateCombination(2);
}

// Set up the cost incresement after each middle node
void BoxingCollector::SetTransferCost(double transfer_cost_) {
transfer_cost = transfer_cost_;
if (transfer_cost < 0) { transfer_cost = 0.0; }
}

// Collect Sbp Parallel
void BoxingCollector::CollectUniverse(const cfg::SbpParallel& sbp) {
if (SbpParallelUniverse_.find(sbp) == SbpParallelUniverse_.end()) {
Expand Down Expand Up @@ -138,7 +145,7 @@ Maybe<void> BoxingCollector::GenerateCombination(int32_t max_middle_node_num) {
// k is the middle node, i -> k -> j
for (int32_t k = 0; k < n; k++) {
if (NotMiddleNode(i, j, k, middle_node_num_ik)) { continue; }
double curr_copy_cost = minimum_copy_cost[i][k] + minimum_copy_cost[k][j];
double curr_copy_cost = minimum_copy_cost[i][k] + minimum_copy_cost[k][j] + transfer_cost;
if (curr_copy_cost < minimum_copy_cost[i][j]) {
minimum_copy_cost[i][j] = curr_copy_cost;
}
Expand Down Expand Up @@ -317,6 +324,7 @@ Maybe<void> BoxingCollector::AskSbpCombination(const cfg::NdSbp& sbp_producer,
// Customized boxing collector and try the algorithm again
BoxingCollector customized_boxing_collector;
customized_boxing_collector.CollectUniverse(logical_blob_desc.shape().NumAxes());
customized_boxing_collector.SetTransferCost(transfer_cost);
customized_boxing_collector.GenerateNdSbpList();
// Filter out unsuitable middle nodes before computing minimum cost.
customized_boxing_collector.FilterNdSbpList4LogicalShape(logical_blob_desc, *parallel_hierarchy);
Expand Down
7 changes: 5 additions & 2 deletions oneflow/core/auto_parallel/boxing_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ class BoxingCollector final {

// Set default Sbp list
void CollectUniverse(int32_t max_axis);

// Set up the cost incresement after each middle node
void SetTransferCost(double transfer_cost_);
// Construct a boxing collector with given maximum number of axis
void Init(int32_t max_axis);
void Init(int32_t max_axis, double transfer_cost_);

// Generate nd sbp list
void GenerateNdSbpList();
Expand Down Expand Up @@ -72,6 +73,8 @@ class BoxingCollector final {
std::unordered_map<::oneflow::cfg::NdSbp, int32_t> NdSbpUniverse;
// Relationship between id and Nd Sbp
std::vector<cfg::NdSbp> nd_sbp_lists;
// The cost incresement after each middle node
double transfer_cost;
}; // class BoxingCollector

} // namespace oneflow
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Maybe<void> BoxingWithMiddleNodes(const OpGraph& op_graph, JobBuilder* job_build
// We assemble the boxing table from S(0) to S(5).
// Those splitting in higher axes are considered in the customized boxing.
constexpr int32_t kRegularMaxSplitAxes = 6;
boxing_collector.Init(kRegularMaxSplitAxes);
boxing_collector.Init(kRegularMaxSplitAxes,
job_builder->job().job_conf().auto_parallel_transfer_cost());
std::vector<cfg::NdSbp> middle_sbps;
HashMap<const OpNode*, OperatorConf> op_node2op_conf;
// Fill other unsupported combinations
Expand Down

0 comments on commit d377f11

Please sign in to comment.