
Feat/graph logical op debug repr #8131

Merged: 149 commits, Jun 14, 2022
Changes from 124 commits
Commits
149 commits
a470ff4
add zero limit
strint Apr 8, 2022
9447157
add debug
strint Apr 12, 2022
5eefdf8
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint Apr 12, 2022
e3acaa9
add mix zero test
strint Apr 12, 2022
b481a7e
refactor zero api
strint Apr 13, 2022
3b56468
zero test with mp
strint Apr 14, 2022
66c8ac3
add 2d test
strint Apr 14, 2022
4f56df2
add zero nd
strint Apr 15, 2022
635ac69
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint Apr 19, 2022
2834289
add nd zero
strint Apr 21, 2022
b805256
add sbp cast
strint Apr 21, 2022
2ede354
test passed soft limit consumer
strint Apr 22, 2022
0227f54
refine size api
strint Apr 22, 2022
5989506
add module config
xiacijie Apr 26, 2022
dd08951
save nn.Module info in job.proto for better debugging
xiacijie Apr 28, 2022
090a7f4
add new line
xiacijie Apr 28, 2022
e577d82
Merge branch 'master' into add-ModuleBlock.ops()-method
xiacijie Apr 28, 2022
9560f56
add ModuleBlock.ops_proto() API
xiacijie Apr 28, 2022
01c19f8
Merge branch 'add-ModuleBlock.ops()-method' of github.com:Oneflow-Inc…
xiacijie Apr 28, 2022
7036e04
zero use stage 2
strint Apr 28, 2022
e5e637d
Merge branch 'master' into add-ModuleBlock.ops()-method
xiacijie Apr 28, 2022
9eb7a5a
print operators' info when print ModuleBlock
xiacijie Apr 29, 2022
b727d97
Merge branch 'add-ModuleBlock.ops()-method' of github.com:Oneflow-Inc…
xiacijie Apr 29, 2022
2269b9e
handle VariableOpConf
xiacijie Apr 29, 2022
7ea7fc1
update
xiacijie Apr 29, 2022
2dfd997
Merge branch 'master' into add-ModuleBlock.ops()-method
xiacijie Apr 29, 2022
048965f
update
xiacijie Apr 29, 2022
35d23b0
fix
xiacijie Apr 29, 2022
18fce6c
Merge branch 'add-ModuleBlock.ops()-method' of github.com:Oneflow-Inc…
xiacijie Apr 29, 2022
8bc590f
move operators repr method to graph util
xiacijie Apr 29, 2022
c26763e
add limit consumer api
strint Apr 29, 2022
d84e8a9
add new api
strint Apr 29, 2022
2555ee1
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint Apr 29, 2022
ac0b9d2
refine zero s select
strint Apr 29, 2022
d2f9f35
Merge branch 'add-ModuleBlock.ops()-method' of https://github.com/One…
strint Apr 29, 2022
55bb6df
add module block
strint Apr 29, 2022
5039557
fix
strint Apr 29, 2022
8e0abb7
refact for rm op in module conf
strint Apr 29, 2022
511e25b
fix
strint Apr 29, 2022
9101fb7
add sbp debug
strint Apr 30, 2022
8cb036a
add sbp repr
strint May 2, 2022
0ab75a2
add shape
strint May 2, 2022
c69e35f
refine
strint May 2, 2022
8eb62fe
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint May 2, 2022
1110066
add sys op in repr
strint May 2, 2022
4306ac7
add full op debug
strint May 2, 2022
dd0a865
fix index out of range
strint May 5, 2022
51f3559
Merge branch 'feat/zero_mix_with_mp' of https://github.com/Oneflow-In…
strint May 6, 2022
e0304a7
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint May 6, 2022
73da0b7
Merge branch 'feat/zero_mix_with_mp' of https://github.com/Oneflow-In…
strint May 6, 2022
501518f
rm zero limit on device type
strint May 6, 2022
0e2f9a2
Merge branch 'feat/zero_mix_with_mp' of https://github.com/Oneflow-In…
strint May 6, 2022
8a67dd4
add no scope op to graph
strint May 6, 2022
e3eed8c
zero test with activation checkpointing
strint May 7, 2022
f966b4f
Merge branch 'feat/zero_mix_with_mp' of https://github.com/Oneflow-In…
strint May 7, 2022
a6b16cd
merge zero
strint May 7, 2022
ab20d2e
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint May 9, 2022
1599ee6
fix order
strint May 19, 2022
5144e32
Merge branch 'feat/op_level_debug_backward_sbp' of https://github.com…
strint May 19, 2022
ebc9ff9
add indentity when dp sequence len is 1
strint May 21, 2022
e77dd89
add debug repr
strint May 21, 2022
cc40e14
refine repr of op
strint May 26, 2022
9a61d3b
refine and fix
strint May 26, 2022
51b9657
rm useless log
strint May 26, 2022
2011e2c
move to base with master
strint May 26, 2022
b7f4fed
fix confict
strint May 26, 2022
1ba26df
merge op level debug
strint May 26, 2022
ffe2094
fix
strint May 26, 2022
b58b48a
fix
strint May 26, 2022
6975f33
fix
strint May 26, 2022
32bc1d1
Merge branch 'feat/logical_nccl_send_recv' into feat/zero_mix_with_mp
strint May 26, 2022
70d793e
Merge branch 'feat/zero_mix_with_mp' into feat/op_level_debug_backwar…
strint May 26, 2022
5209b02
fix proto
strint May 26, 2022
98dcf2d
refine test
strint May 26, 2022
484aff0
fix type
strint May 26, 2022
cce8efd
add test
strint May 27, 2022
a30b0c0
debug bad case
strint May 27, 2022
c73013f
refine test for eager and graph boxing
strint May 27, 2022
08b1f69
test case ready
strint May 27, 2022
821a8f4
simplify
strint May 30, 2022
29079a0
refine test
strint May 30, 2022
e49d380
fix buff size
strint May 30, 2022
9bd521f
Merge branch 'feat/logical_nccl_send_recv' into feat/zero_mix_with_mp
strint May 30, 2022
f82a317
Merge branch 'feat/zero_mix_with_mp' into feat/op_level_debug_backwar…
strint May 30, 2022
b374505
merge master
strint Jun 1, 2022
3fc1821
fix conflict
strint Jun 1, 2022
79e1290
refine zero nd
strint Jun 1, 2022
3225045
refine
strint Jun 1, 2022
bbe7114
Merge branch 'feat/zero_mix_with_mp' into feat/op_level_debug_backwar…
strint Jun 2, 2022
c751435
add full test
strint Jun 2, 2022
5c78921
revert change
strint Jun 2, 2022
bfa726c
refine split check
strint Jun 2, 2022
0bcbf30
fix typo
strint Jun 6, 2022
14c8520
rm log
strint Jun 6, 2022
56754bc
spit long func
strint Jun 6, 2022
459d6f5
Merge branch 'feat/zero_mix_with_mp' into feat/op_level_debug_backwar…
strint Jun 7, 2022
b78c4bd
refine
strint Jun 7, 2022
567af33
restore test
strint Jun 7, 2022
a1c0ff2
Merge branch 'feat/zero_mix_with_mp' into feat/op_level_debug_backwar…
strint Jun 7, 2022
c1508e3
merge master
strint Jun 7, 2022
b5bdbef
refine pass and mem debug
strint Jun 7, 2022
886914c
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint Jun 7, 2022
3957515
Merge branch 'feat/zero_mix_with_mp' into feat/op_level_debug_backwar…
strint Jun 7, 2022
b7dad59
merge master
strint Jun 7, 2022
a6aa236
repr dtype
strint Jun 7, 2022
9840be2
add placement
strint Jun 7, 2022
84ca778
Update optimizer_placement_optimization_pass.cpp
strint Jun 9, 2022
7095ec3
auto format by CI
oneflow-ci-bot Jun 9, 2022
5a5b9c5
Merge branch 'master' into feat/zero_mix_with_mp
strint Jun 9, 2022
b401e66
auto format by CI
oneflow-ci-bot Jun 9, 2022
3c7d2c5
Merge branch 'master' into feat/zero_mix_with_mp
strint Jun 9, 2022
2b0324e
fix static check
strint Jun 9, 2022
7d611c4
add tips for zero api change
strint Jun 10, 2022
04548ac
Merge branch 'master' into feat/zero_mix_with_mp
strint Jun 10, 2022
640487b
auto format by CI
oneflow-ci-bot Jun 10, 2022
9928cdd
Merge branch 'master' into feat/zero_mix_with_mp
mergify[bot] Jun 10, 2022
dc4e40d
Merge branch 'master' into feat/zero_mix_with_mp
mergify[bot] Jun 10, 2022
a54ff83
Merge branch 'feat/zero_mix_with_mp' of https://github.com/Oneflow-In…
strint Jun 10, 2022
e40e732
fix merge
strint Jun 10, 2022
451bb22
merge new update
strint Jun 10, 2022
3b3b1a9
merge master
strint Jun 10, 2022
ee7ea67
auto format by CI
oneflow-ci-bot Jun 10, 2022
9b2cf2d
auto format by CI
oneflow-ci-bot Jun 10, 2022
7b5c6f3
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint Jun 13, 2022
bc556d4
refine get job api
strint Jun 13, 2022
ca8d852
refine graph util import order
strint Jun 13, 2022
1b3bbaa
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
strint Jun 13, 2022
66923eb
auto format by CI
oneflow-ci-bot Jun 13, 2022
7bb0c85
fix static check
strint Jun 13, 2022
79b1f3d
Merge branch 'feat/op_level_debug_backward_sbp' of https://github.com…
strint Jun 13, 2022
7fdcc67
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint Jun 13, 2022
95cebef
auto format by CI
oneflow-ci-bot Jun 13, 2022
0e0101e
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint Jun 13, 2022
fe1013c
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 13, 2022
45c2f37
fix special case
strint Jun 13, 2022
51931ca
Merge branch 'feat/op_level_debug_backward_sbp' of https://github.com…
strint Jun 13, 2022
c010589
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint Jun 13, 2022
9966b89
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint Jun 13, 2022
6b01da2
refine level print and add full dtype repr
strint Jun 14, 2022
c882457
rm useless
strint Jun 14, 2022
b7b7429
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 14, 2022
d637e20
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 14, 2022
49c8562
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint Jun 14, 2022
22ea17a
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 14, 2022
1461e30
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 14, 2022
e145a0c
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 14, 2022
37999e7
Merge branch 'master' into feat/op_level_debug_backward_sbp
strint Jun 14, 2022
a865cc5
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 14, 2022
9ecad29
Merge branch 'master' into feat/op_level_debug_backward_sbp
mergify[bot] Jun 14, 2022
5 changes: 4 additions & 1 deletion oneflow/api/python/framework/dtype.cpp
@@ -38,7 +38,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
[](int t) { // __setstate__
return CHECK_JUST(DType::Get(DataType(t)));
}))
.def_property_readonly("bytes", [](const Symbol<DType>& dtype) { return dtype->bytes(); });
.def_property_readonly("bytes", [](const Symbol<DType>& dtype) { return dtype->bytes(); })
.def("get", [](const int data_type_enum) {
return CHECK_JUST(DType::Get(static_cast<DataType>(data_type_enum)));
});

m.attr("bool") = &CHECK_JUST(DType::Get(DataType::kBool));
m.attr("char") = &CHECK_JUST(DType::Get(DataType::kChar));
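The new `get` binding above maps a `DataType` enum value back to its `DType` symbol, the reverse of what `__setstate__` does. A minimal pure-Python sketch of that lookup; the numeric enum values here are illustrative assumptions, not OneFlow's real `DataType` numbering:

```python
# Illustrative enum -> dtype registry; the numbers are assumptions,
# not OneFlow's actual DataType values.
_DATA_TYPE_ENUM_TO_DTYPE = {
    2: "oneflow.float32",
    5: "oneflow.int64",
    9: "oneflow.bool",
}

def dtype_get(data_type_enum: int) -> str:
    """Mirror of the pybind `get` method: resolve an enum value to a dtype symbol."""
    if data_type_enum not in _DATA_TYPE_ENUM_TO_DTYPE:
        raise ValueError(f"unknown DataType enum: {data_type_enum}")
    return _DATA_TYPE_ENUM_TO_DTYPE[data_type_enum]
```

This is what lets Python-side debug code turn the raw enum integers stored in a serialized job back into printable dtype names.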
9 changes: 8 additions & 1 deletion oneflow/api/python/framework/nn_graph.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
@@ -41,6 +42,11 @@ Maybe<py::object> APINNGraphAdditionalVarTensors(const std::shared_ptr<NNGraph>&
py::list tensor_list = py::cast(tensors);
return py::cast<py::object>(tensor_list);
}

Maybe<py::bytes> APINNGraphGetCompiledSerializedJob(const std::shared_ptr<NNGraph>& graph) {
const auto job = *JUST(graph->GetCompiledJob());
return py::bytes(job.SerializeAsString());
}
} // namespace

ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
@@ -75,7 +81,8 @@ ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
&NNGraph::RegisterAdditionalVarOpNamesAndTensorsToBeLoaded)
.def_property_readonly("additional_var_names", &APINNGraphAdditionalVarNames)
.def_property_readonly("additional_var_tensors", &APINNGraphAdditionalVarTensors)
.def("complie_and_init_runtime", &NNGraph::CompileAndInitRuntime);
.def("complie_and_init_runtime", &NNGraph::CompileAndInitRuntime)
.def("get_compiled_job_str", &APINNGraphGetCompiledSerializedJob);

m.def("RunLazyNNGraph", &RunLazyNNGraph);
m.def("SoftSyncNNGraphBuffers", &SoftSyncNNGraphBuffers);
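`get_compiled_job_str` hands the compiled `Job` back to Python as serialized protobuf bytes, which the Python side can deserialize for repr/debug output. A toy stand-in for that round trip; the class names are invented and string joining stands in for real protobuf binary serialization:

```python
class FakeJob:
    """Stand-in for the Job protobuf message (illustrative only)."""
    def __init__(self, name, op_names):
        self.name = name
        self.op_names = list(op_names)

    def SerializeAsString(self) -> bytes:
        # Real protobuf output is binary; a joined string stands in here.
        return ";".join([self.name] + self.op_names).encode()


class FakeGraph:
    """Mimics NNGraph exposing its compiled job as serialized bytes."""
    def __init__(self, compiled_job):
        self._compiled_job = compiled_job

    def get_compiled_job_str(self) -> bytes:
        return self._compiled_job.SerializeAsString()
```

In the real binding the caller parses the returned bytes back into a `Job` message on the Python side before walking its ops.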
21 changes: 21 additions & 0 deletions oneflow/api/python/symbol/placement_symbol.cpp
@@ -17,6 +17,7 @@ limitations under the License.
#include <pybind11/stl.h>
#include <pybind11/operators.h>

#include "oneflow/core/common/maybe.h"
#include "oneflow/extension/python/numpy.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
@@ -63,6 +64,18 @@ struct PlacementSymbolExportUtil {
return parallel_desc;
}

static Maybe<ParallelDesc> CreateParallelDesc(const std::string& proto_str) {
ParallelConf parallel_conf;
CHECK_OR_RETURN(TxtString2PbMessage(proto_str, &parallel_conf));
std::shared_ptr<ParallelDesc> parallel_desc;
JUST(PhysicalRun([&parallel_desc, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
parallel_desc = JUST(builder->GetParallelDescSymbol(parallel_conf));
return Maybe<void>::Ok();
}));

return parallel_desc;
}

static Maybe<std::vector<std::string>> ParseAndFormatRanks(const py::dict& device_ids) {
std::vector<std::pair<int64_t, int64_t>> machine_device_id_vec;
for (const auto& pair : device_ids) {
@@ -137,6 +150,10 @@ struct PlacementSymbolExportUtil {
return SymbolOf(*JUST(CreateParallelDesc(type, *formated_machine_device_ids, shape)));
}

static Maybe<Symbol<ParallelDesc>> CreateParallelDescSymbol(const std::string& proto_str) {
return SymbolOf(*JUST(CreateParallelDesc(proto_str)));
}

static Maybe<Symbol<ParallelDesc>> AllDevicePlacement(const std::string& type) {
static thread_local HashMap<std::string, Symbol<ParallelDesc>> device_tag2placement;
CHECK_NOTNULL((Global<ResourceDesc, ForEnv>::Get()));
@@ -213,6 +230,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
return PlacementSymbolExportUtil::CreateParallelDescSymbol(type, ranks).GetOrThrow();
}),
py::arg("type"), py::arg("ranks"))
.def(py::init([](const std::string& proto_str) {
return PlacementSymbolExportUtil::CreateParallelDescSymbol(proto_str).GetOrThrow();
}),
py::arg("proto_str"))
.def_property_readonly(
"device_type",
[](Symbol<ParallelDesc> p) {
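The new `py::init` overload above builds a placement directly from a `ParallelConf` text-proto string via `TxtString2PbMessage`. A rough sketch of the idea in Python; this toy parser handles only scalar fields and the repeated `device_name` field, nothing like the full proto text format:

```python
def parse_parallel_conf(proto_str: str) -> dict:
    """Toy 'key: value' parser standing in for TxtString2PbMessage."""
    conf = {"device_name": []}
    for line in proto_str.strip().splitlines():
        key, _, value = line.partition(":")
        key, value = key.strip(), value.strip().strip('"')
        if key == "device_name":
            conf["device_name"].append(value)  # repeated field
        elif key:
            conf[key] = value  # scalar field
    return conf
```

With the real binding, a placement serialized out of a compiled job can be reconstructed in Python as `flow.placement(proto_str)`-style calls that route through this constructor.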
1 change: 1 addition & 0 deletions oneflow/core/framework/nn_graph.h
@@ -72,6 +72,7 @@ class NNGraph final : public NNGraphIf {
Maybe<std::vector<std::string>> GetAdditionalVarOpNames() const;
Maybe<std::vector<std::shared_ptr<one::Tensor>>> GetAdditionalVarOpTensors() const;
Maybe<void> CompileAndInitRuntime();
Maybe<Job> GetCompiledJob() { return job_; }
Maybe<void> Close();

private:
7 changes: 5 additions & 2 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -196,7 +196,7 @@ void JobBuildAndInferCtx::AddOpAndUpdateJobParallelViewConf(const OperatorConf&
(*module_name2module_conf)[module_name].set_name(scope.scope_proto().module_name());
}

(*module_name2module_conf)[module_name].add_ops()->CopyFrom(operator_conf);
*((*module_name2module_conf)[module_name].add_ops()) = operator_conf.name();
}
}

@@ -999,7 +999,7 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
int32_t pass_cnt = 0;
const int64_t prev_v = FLAGS_v;
auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe<void> {
VLOG(1) << job_name << " is compiling with pass"
VLOG(1) << job_name << " start compiling with pass"
<< " pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name
<< (cnt > 0 ? std::to_string(cnt) : "");
if (unlikely(NeedLogJob(pass_name))) {
@@ -1013,6 +1013,9 @@
std::string cnt_str = cnt > 0 ? std::to_string(cnt) : "";
LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-after");
}
VLOG(1) << job_name << " finish compiling with pass"
<< " pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name
<< (cnt > 0 ? std::to_string(cnt) : "");
++pass_cnt;
return Maybe<void>::Ok();
};
44 changes: 44 additions & 0 deletions oneflow/core/job/job_builder.cpp
@@ -19,7 +19,10 @@ limitations under the License.
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/framework/scope_util.h"

namespace oneflow {

@@ -170,6 +173,7 @@ Maybe<void> JobBuilder::AddOp(const ParallelConf& parallel_conf, const OperatorC
OperatorConf* mut_op_conf = job_->mutable_net()->add_op();
*mut_op_conf = op_conf;
CHECK_OR_RETURN(op_name2op_conf_.emplace(op_conf.name(), mut_op_conf).second);
AddOpToModuleConf(op_conf);
AddOpNamesToPlacementGroup({op_conf.name()}, parallel_conf);
return Maybe<void>::Ok();
}
@@ -185,10 +189,35 @@ void JobBuilder::AddOps(const ParallelConf& parallel_conf,
*mut_op_conf = op_conf;
CHECK(op_name2op_conf_.emplace(op_conf.name(), mut_op_conf).second);
op_names.emplace_back(op_conf.name());
AddOpToModuleConf(op_conf);
}
AddOpNamesToPlacementGroup(op_names, parallel_conf);
}

void JobBuilder::AddOpToModuleConf(const OperatorConf& op_conf) {
// set up the module config
if (Global<symbol::Storage<Scope>>::Get()->Has(op_conf.scope_symbol_id())) {
const auto& scope = Global<symbol::Storage<Scope>>::Get()->Get(op_conf.scope_symbol_id());
if (scope.scope_proto().has_module_name()) {
const auto& module_name = scope.scope_proto().module_name();
auto* module_name2module_conf = job_->mutable_module_name2module_conf();
if (!(*module_name2module_conf)[module_name].has_name()) {
(*module_name2module_conf)[module_name].set_name(scope.scope_proto().module_name());
}

*((*module_name2module_conf)[module_name].add_ops()) = op_conf.name();
return;
}
}
const auto& module_name = job_->job_conf().job_name();
auto* module_name2module_conf = job_->mutable_module_name2module_conf();
if (!(*module_name2module_conf)[module_name].has_name()) {
(*module_name2module_conf)[module_name].set_name(module_name);
}

*((*module_name2module_conf)[module_name].add_ops()) = op_conf.name();
}

void JobBuilder::AddOpNamesToPlacementGroup(const std::vector<std::string>& op_names,
const ParallelConf& parallel_conf) {
PlacementGroup* placement_group = nullptr;
@@ -230,6 +259,21 @@ void JobBuilder::RemoveOpByName(const std::unordered_set<std::string>& removing_
for (const OperatorConf& op_conf : net.op()) {
if (removing_names.count(op_conf.name()) == 0) { *(job_->mutable_net()->add_op()) = op_conf; }
}
// Update module conf
auto module_confs_map = job_->module_name2module_conf();
job_->clear_module_name2module_conf();
for (const auto& module_conf_pair : module_confs_map) {
const auto& module_name = module_conf_pair.first;
auto* module_name2module_conf = job_->mutable_module_name2module_conf();
if (!(*module_name2module_conf)[module_name].has_name()) {
(*module_name2module_conf)[module_name].set_name(module_name);
}
for (const auto& op_name : module_conf_pair.second.ops()) {
if (removing_names.count(op_name) == 0) {
*((*module_name2module_conf)[module_name].add_ops()) = op_name;
}
}
}
// Update placement
auto placement_group = job_->placement().placement_group();
job_->mutable_placement()->clear_placement_group();
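`AddOpToModuleConf` files each op name under the module named by its scope, and falls back to the job name when the scope carries no module (e.g. system ops). The same bookkeeping can be sketched in a few lines of Python; the dict shapes are illustrative stand-ins for the `ModuleConf` proto:

```python
def add_op_to_module_conf(module_confs, op_name, scope_module_name, job_name):
    """Record op_name under its module; fall back to the job name when
    the op's scope has no module (mirrors JobBuilder::AddOpToModuleConf)."""
    module_name = scope_module_name or job_name
    conf = module_confs.setdefault(module_name, {"name": module_name, "ops": []})
    conf["ops"].append(op_name)
    return module_confs
```

Because `ModuleConf.ops` now stores op names rather than whole `OperatorConf` messages, this grouping stays cheap and the actual op definitions live only once in the job's net.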
1 change: 1 addition & 0 deletions oneflow/core/job/job_builder.h
@@ -81,6 +81,7 @@ class JobBuilder final {
private:
void AddOpNamesToPlacementGroup(const std::vector<std::string>& op_names,
const ParallelConf& parallel_conf);
void AddOpToModuleConf(const OperatorConf& op_conf);

Job* job_;
HashMap<std::string, OperatorConf*> op_name2op_conf_;
4 changes: 1 addition & 3 deletions oneflow/core/job/module_conf.proto
@@ -1,9 +1,7 @@
syntax = "proto2";
package oneflow;

import "oneflow/core/operator/op_conf.proto";

message ModuleConf {
required string name = 1;
repeated OperatorConf ops = 2;
repeated string ops = 2;
}
45 changes: 25 additions & 20 deletions oneflow/core/job/plan_util.cpp
@@ -861,8 +861,9 @@ namespace {
struct MemBlockMemoryInfo {
int64_t mem_block_id;
int64_t mem_block_mem_size;
bool is_reused;
std::vector<std::string> ordered_op_names;
MemBlockMemoryInfo() : mem_block_id(-1), mem_block_mem_size(-1) {}
MemBlockMemoryInfo() : mem_block_id(-1), mem_block_mem_size(-1), is_reused(false) {}
};

struct ChunkMemoryInfo {
@@ -924,7 +925,10 @@ void PlanUtil::PlanMemoryLog(Plan* plan, const std::string& plan_name) {
if (mem_block.mem_case().has_device_cuda_mem()) {
if (mem_block.has_chunk_id()) {
rank_memory_info.chunk_info.mem_block_ids.push_back(mem_block_id);
info.is_reused = true;
} else {
rank_memory_info.chunk_info.mem_block_ids.push_back(mem_block_id);
Review thread on the line above (translated from Chinese):
@chengtbf (Contributor), Jun 21, 2022: Why are mem blocks without a chunk also placed into chunk info? What is the benefit? Chunk info now picks up a large number of the model's blocks, which makes it hard to read. It is also semantically incorrect: these mem blocks do not participate in memory reuse and are not inside the chunk.
@chengtbf (Contributor): If they must be recorded, they should go into rank memory info, not chunk info.
@strint (Contributor, Author): This is for debugging ops whose memory falls outside the Chunk; the data here looked abnormal once before, so they were included for debugging as well.
info.is_reused = false;
rank_memory_info.not_reused_mem_size += mem_block.mem_size();
rank_memory_info.total_mem_size += mem_block.mem_size();
if (mem_block.has_variable_op_name()) {
@@ -968,25 +972,26 @@ void PlanUtil::PlanMemoryLog(Plan* plan, const std::string& plan_name) {
<< B2MiB(rank_memory_info.eager_variable_total_mem_size) << " MiB ].";
}

if (IsInDebugMode()) {
for (const auto& rank_memory_info : rank_device_memory_infos) {
int64_t chunk_id = rank_memory_info.chunk_info.chunk_id;
VLOG(2) << " For detail: Chunk id: " << chunk_id << " has "
<< rank_memory_info.chunk_info.mem_block_ids.size() << " MemBlocks.";
for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) {
CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end());
const auto& mem_block_info = mem_block_id2info.at(mem_block_id);
VLOG(2) << " In Chunk id: " << chunk_id << " MemBlock id: " << mem_block_id
<< " has num = " << mem_block_info.ordered_op_names.size()
<< " ops with mem size = " << B2MiB(mem_block_info.mem_block_mem_size);
}
for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) {
CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end());
const auto& mem_block_info = mem_block_id2info.at(mem_block_id);
for (int64_t i = 0; i < mem_block_info.ordered_op_names.size(); ++i) {
VLOG(3) << " In Chunk id: " << chunk_id << " MemBlock id: " << mem_block_id
<< " order: " << i << " op_name: " << mem_block_info.ordered_op_names.at(i);
}
for (const auto& rank_memory_info : rank_device_memory_infos) {
int64_t chunk_id = rank_memory_info.chunk_info.chunk_id;
int64_t device_id = rank_memory_info.device_id;
int64_t not_reuse_size = rank_memory_info.not_reused_mem_size;
VLOG(2) << " For detail: Chunk id: " << chunk_id << " has "
<< rank_memory_info.chunk_info.mem_block_ids.size() << " MemBlocks"
<< " not reused size = " << B2MiB(not_reuse_size);
for (int64_t mem_block_id : rank_memory_info.chunk_info.mem_block_ids) {
CHECK(mem_block_id2info.find(mem_block_id) != mem_block_id2info.end());
const auto& mem_block_info = mem_block_id2info.at(mem_block_id);
VLOG(2) << " In Device: " << device_id << " Chunk id: " << chunk_id
<< " MemBlock id: " << mem_block_id
<< " has num = " << mem_block_info.ordered_op_names.size()
<< " ops with mem size = " << B2MiB(mem_block_info.mem_block_mem_size)
<< " is reused " << mem_block_info.is_reused;
for (int64_t i = 0; i < mem_block_info.ordered_op_names.size(); ++i) {
VLOG(3) << " In Device: " << device_id << " Chunk id: " << chunk_id
<< " In MemBlock id: " << mem_block_id << " order: " << i << " is reused "
<< mem_block_info.is_reused
<< " op_name: " << mem_block_info.ordered_op_names.at(i);
}
}
}
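The revised memory log tallies, per rank, which mem blocks sit inside the reuse chunk and how much memory falls outside it (`not_reused_mem_size`). A compact sketch of that bookkeeping; the field names are assumptions mirroring the structs above, not a real OneFlow API:

```python
def summarize_rank_memory(mem_blocks):
    """Collect mem block ids and sum the size of blocks outside the chunk,
    i.e. memory that is not reused (mirrors not_reused_mem_size)."""
    summary = {"mem_block_ids": [], "not_reused_mem_size": 0}
    for block in mem_blocks:
        summary["mem_block_ids"].append(block["id"])
        if block.get("chunk_id") is None:  # no chunk => memory is not reused
            summary["not_reused_mem_size"] += block["size"]
    return summary
```

This mirrors the PR's choice (questioned in the review thread above) of listing both reused and non-reused blocks together while flagging each with `is_reused`.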
49 changes: 34 additions & 15 deletions python/oneflow/framework/graph_build_util.py
@@ -27,7 +27,6 @@
import oneflow.framework.scope_util as scope_util
import oneflow.framework.session_context as session_context
from oneflow.framework.tensor import Tensor

import oneflow._oneflow_internal._C as _C

lazy_mode = oneflow._oneflow_internal.lazy_mode
@@ -42,9 +41,11 @@ def graph_build_context(config_proto, session):
config_proto_str, oneflow.placement("cpu", [0]), False, # is_mirrored
)

graph_scope = _make_new_graph_scope(new_scope, config_proto.job_name)

with lazy_mode.guard(True):
with JobBuildAndInferCtx(config_proto):
with BlockScopeContext(prev_scope, new_scope):
with BlockScopeContext(prev_scope, graph_scope):
yield


@@ -118,6 +119,36 @@ def __exit__(self, exc_type, exc_val, exc_tb):
)


def _make_new_scope(prev_scope, scope_proto_str_setter):
new_scope = None

def build_scope(builder):
nonlocal new_scope
new_scope = builder.BuildScopeByProtoStrSetter(
prev_scope, scope_proto_str_setter
)
assert new_scope is not None

oneflow._oneflow_internal.deprecated.PhysicalRun(build_scope)
oneflow._oneflow_internal.eager.Sync()
return new_scope


def _make_new_graph_scope(prev_scope, graph_name):
assert prev_scope is not None
attr_dict = dict()
name2default = session_context.GetDefaultSession().scope_attr_name2default_val

def scope_proto_str_setter(serialized_scope_proto: str):
scope_proto = text_format.Parse(
serialized_scope_proto, scope_pb2_util.ScopeProto()
)
scope_proto.module_name = graph_name
return str(text_format.MessageToString(scope_proto))

return _make_new_scope(prev_scope, scope_proto_str_setter)


def make_new_block_scope(prev_scope, block):
assert prev_scope is not None
assert block is not None
@@ -147,21 +178,9 @@ def scope_proto_str_setter(serialized_scope_proto: str):
# set module name
if isinstance(block, oneflow.nn.graph.block.ModuleBlock):
scope_proto.module_name = block.name_prefix + block.name

return str(text_format.MessageToString(scope_proto))

new_scope = None

def build_scope(builder):
nonlocal new_scope
new_scope = builder.BuildScopeByProtoStrSetter(
prev_scope, scope_proto_str_setter
)
assert new_scope is not None

oneflow._oneflow_internal.deprecated.PhysicalRun(build_scope)
oneflow._oneflow_internal.eager.Sync()
return new_scope
return _make_new_scope(prev_scope, scope_proto_str_setter)


def scope_to_proto(scope):
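After this refactor, `_make_new_graph_scope` and `make_new_block_scope` share `_make_new_scope` and differ only in the setter that rewrites the serialized scope proto. The pattern can be sketched as follows, with JSON standing in for the protobuf text format used by the real code:

```python
import json

def make_module_name_setter(module_name):
    """Return a setter that parses a serialized scope, stamps module_name,
    and re-serializes it (JSON stands in for text-format protobuf)."""
    def scope_proto_str_setter(serialized_scope_proto: str) -> str:
        scope_proto = json.loads(serialized_scope_proto)
        scope_proto["module_name"] = module_name
        return json.dumps(scope_proto)
    return scope_proto_str_setter
```

Stamping the graph name as `module_name` at graph-build time is what lets ops created outside any `ModuleBlock` still land in a module conf for the debug repr.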
13 changes: 8 additions & 5 deletions python/oneflow/nn/graph/block.py
@@ -75,6 +75,7 @@ def __init__(
self._origin = None
self._scope = None
self._prev_scope = None
assert belonged_graph is None or isinstance(belonged_graph, weakref.ProxyTypes)
self._belonged_graph = belonged_graph
self.config = BlockConfig()

@@ -563,11 +564,13 @@ def _ops_repr(self):
)

if self._belonged_graph.is_compiled:
module_conf = self._belonged_graph._graph_proto.module_name2module_conf[
self.name_prefix + self.name
]

return operators_repr(module_conf.ops)
if self._belonged_graph._compiled_graph_proto is not None:
module_conf = self._belonged_graph._compiled_graph_proto.module_name2module_conf[
self.name_prefix + self.name
]
return operators_repr(
module_conf.ops, self._belonged_graph._compiled_graph_proto
)

return []
