-
Notifications
You must be signed in to change notification settings - Fork 796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Run nn.Graph by VM #9884
Run nn.Graph by VM #9884
Changes from 9 commits
6658783
69decc3
0314b66
96addf4
f543907
f527293
0e7097c
ed42968
2179698
fa3e6f1
6a99a3c
c0ec050
b4a18e6
51846ce
2904eaa
21e8b3e
b015649
26d5aaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include "oneflow/core/common/container_util.h" | ||
#include "oneflow/core/framework/nn_graph.h" | ||
#include "oneflow/core/framework/op_builder.h" | ||
#include "oneflow/core/framework/op_interpreter.h" | ||
#include "oneflow/core/functional/functional_api.yaml.h" | ||
#include "oneflow/core/job/job.pb.h" | ||
#include "oneflow/core/profiler/profiler.h" | ||
#include "oneflow/core/framework/local_tensor_infer_cache.h" | ||
|
||
namespace oneflow { | ||
namespace one { | ||
|
||
using Env = std::map<std::string, std::shared_ptr<Tensor>>; | ||
|
||
Maybe<Env> InitEnv(const one::TensorTuple& graph_inputs, const std::shared_ptr<NNGraph>& graph) { | ||
Env env; | ||
for (const auto& [name, tensor] : graph->variable_op_name2tensor()) { | ||
env.emplace(name + "/out", tensor); | ||
} | ||
for (size_t i = 0; i < graph->inputs_op_names().size(); ++i) { | ||
const auto& name = graph->inputs_op_names()[i]; | ||
env.emplace(name + "/out", JUST(VectorAt(graph_inputs, i))); | ||
} | ||
return env; | ||
} | ||
|
||
Maybe<UserOpExpr> OpConfToUserOpExpr(const OperatorConf& op_conf) { | ||
CHECK_OR_RETURN(op_conf.has_user_conf()); | ||
const auto& user_conf = op_conf.user_conf(); | ||
auto builder = OpBuilder(user_conf.op_type_name()); | ||
for (const auto& pair : user_conf.attr()) { builder.Attr(pair.first, pair.second); } | ||
for (const auto& pair : user_conf.input()) { | ||
// ignore "UserSourceOpTickInput" | ||
if (pair.first == "UserSourceOpTickInput") { continue; } | ||
builder.Input(pair.first, pair.second.s_size()); | ||
} | ||
for (const auto& pair : user_conf.output()) { builder.Output(pair.first, pair.second.s_size()); } | ||
return JUST(builder.Build()); | ||
} | ||
|
||
template<typename Func> | ||
Maybe<TensorTuple> GetInputTensors(const UserOpConf& user_conf, const Env& env, | ||
const Func& preprocess) { | ||
TensorTuple inputs; | ||
for (const auto& pair : user_conf.input()) { | ||
if (pair.first == "UserSourceOpTickInput") { continue; } | ||
for (const auto& name : pair.second.s()) { | ||
inputs.emplace_back(preprocess(JUST(MapAt(env, name)))); | ||
} | ||
} | ||
return inputs; | ||
} | ||
|
||
OpArgsVector<std::string> GetOutputNamesOfOp(const UserOpConf& user_conf) { | ||
OpArgsVector<std::string> output_names; | ||
for (const auto& pair : user_conf.output()) { | ||
for (const auto& name : pair.second.s()) { output_names.emplace_back(name); } | ||
} | ||
return output_names; | ||
} | ||
|
||
// Only support a limited subset of view ops for now | ||
bool IsViewOp(const std::shared_ptr<UserOpExpr>& op) { | ||
return op->op_type_name() == "reshape" || op->op_type_name() == "expand_dims"; | ||
} | ||
|
||
Maybe<void> RunViewOp(const std::shared_ptr<UserOpExpr>& op, Env& env, const TensorTuple& inputs, | ||
const OpArgsVector<std::string>& output_names) { | ||
// eliminate the memcpy of view ops | ||
CHECK_OR_RETURN(IsViewOp(op)); | ||
const std::shared_ptr<const LocalTensorInferResult> result = | ||
JUST([&]() -> Maybe<const LocalTensorInferResult> { | ||
LocalTensorMetaInferArgs infer_args; | ||
JUST(infer_args.Init(op->base_attrs(), JUST(inputs[0]->device()), inputs)); | ||
return JUST(op->mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); | ||
}()); | ||
const auto& output_shape = result->output_tensor_metas()[0]->shape(); | ||
const auto output = | ||
JUST(view::BasicView(inputs[0], output_shape, JUST(inputs[0]->storage_offset()))); | ||
env.emplace(output_names[0], output); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> RunNormalOp(const std::shared_ptr<UserOpExpr>& op, Env& env, const TensorTuple& inputs, | ||
const OpArgsVector<std::string>& output_names) { | ||
TensorTuple outputs(output_names.size()); | ||
static EagerLocalInterpreter it; | ||
static AttrMap empty_attr_map; | ||
JUST(it.Apply(*op, inputs, &outputs, empty_attr_map)); | ||
for (size_t i = 0; i < output_names.size(); ++i) { | ||
env.emplace(output_names[i], JUST(VectorAt(outputs, i))); | ||
} | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
// tensors in dead_tensors[i] will not be accessed any more after i-th op | ||
// so they can be released once i-th op's execution finishes. | ||
std::vector<std::vector<std::string>> GetDeadTensorVector(const Job& job) { | ||
std::vector<std::vector<std::string>> dead_tensors(job.net().op_size()); | ||
std::set<std::string> visited; | ||
for (int i = job.net().op_size() - 1; i >= 0; --i) { | ||
const auto& op_conf = job.net().op(i); | ||
// do not release the graph output tensors | ||
if (op_conf.has_output_conf()) { | ||
const auto& output_conf = op_conf.output_conf(); | ||
visited.insert(output_conf.in()); | ||
} else if (op_conf.has_user_conf()) { | ||
const auto& user_conf = op_conf.user_conf(); | ||
for (const auto& pair : user_conf.input()) { | ||
if (pair.first == "UserSourceOpTickInput") { continue; } | ||
for (const auto& name : pair.second.s()) { | ||
if (visited.find(name) == visited.end()) { | ||
dead_tensors[i].push_back(name); | ||
visited.insert(name); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
return dead_tensors; | ||
} | ||
|
||
Maybe<void> InitOpExprs(const std::shared_ptr<NNGraph>& graph) { | ||
CHECK_OR_RETURN(graph->cached_op_exprs.empty()); | ||
|
||
const auto& job = graph->job(); | ||
for (int i = 0; i < job.net().op_size(); i++) { | ||
const auto& op_conf = job.net().op(i); | ||
if (op_conf.has_user_conf()) { | ||
const auto op_expr = JUST(OpConfToUserOpExpr(op_conf)); | ||
graph->cached_op_exprs.push_back(op_expr); | ||
} else { | ||
graph->cached_op_exprs.push_back(nullptr); | ||
} | ||
} | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<one::TensorTuple> InterpretJob(const one::TensorTuple& graph_inputs, | ||
const std::shared_ptr<NNGraph>& graph) { | ||
if (graph->cached_op_exprs.empty()) { JUST(InitOpExprs(graph)); } | ||
|
||
const auto& job = graph->job(); | ||
auto env = *JUST(InitEnv(graph_inputs, graph)); | ||
|
||
// See comments above GetDeadTensorVector's definition for more details | ||
const auto dead_tensors = GetDeadTensorVector(job); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dead tensor 的含义是什么意思呢,可以注释下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,在 GetDeadTensorVector 的定义处有一个注释,我再在这里指明一下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已添加 |
||
|
||
one::TensorTuple graph_outputs; | ||
for (int i = 0; i < job.net().op_size(); i++) { | ||
const auto& op_conf = job.net().op(i); | ||
if (op_conf.has_user_conf()) { | ||
auto op = CHECK_NOTNULL(graph->cached_op_exprs[i]); | ||
const auto& user_conf = op_conf.user_conf(); | ||
OF_PROFILER_RANGE_GUARD(user_conf.op_type_name()); | ||
TensorTuple inputs = | ||
*JUST(GetInputTensors(user_conf, env, [&op_conf](const std::shared_ptr<Tensor>& tensor) { | ||
return CHECK_JUST(functional::To(tensor, op_conf.device_tag())); | ||
})); | ||
OpArgsVector<std::string> output_names = GetOutputNamesOfOp(user_conf); | ||
if (IsViewOp(op)) { | ||
JUST(RunViewOp(op, env, inputs, output_names)); | ||
} else { | ||
JUST(RunNormalOp(op, env, inputs, output_names)); | ||
} | ||
for (const auto& name : dead_tensors[i]) { CHECK_EQ_OR_RETURN(env.erase(name), 1); } | ||
} else if (op_conf.has_output_conf()) { | ||
const auto& output_conf = op_conf.output_conf(); | ||
graph_outputs.emplace_back(JUST(MapAt(env, output_conf.in()))); | ||
} | ||
} | ||
return graph_outputs; | ||
} | ||
} // namespace one | ||
} // namespace oneflow |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include "oneflow/core/common/maybe.h" | ||
#include "oneflow/core/job/job.pb.h" | ||
|
||
namespace oneflow { | ||
class NNGraph; | ||
namespace one { | ||
class TensorTuple; | ||
Maybe<one::TensorTuple> InterpretJob(const one::TensorTuple& inputs, | ||
const std::shared_ptr<NNGraph>& graph); | ||
} // namespace one | ||
} // namespace oneflow |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import os | ||
|
||
os.environ["ONEFLOW_RUN_GRAPH_BY_VM"] = "1" | ||
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" | ||
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" | ||
|
||
import oneflow as flow | ||
import numpy as np | ||
|
||
|
||
class Graph(flow.nn.Graph): | ||
def __init__(self, m): | ||
super().__init__() | ||
self.m = m | ||
|
||
def build(self, x): | ||
return self.m(x) | ||
|
||
|
||
class M(flow.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.w = flow.nn.Parameter(flow.randn(4)) | ||
|
||
def forward(self, x): | ||
# these broadcast_sub and cast ops will be | ||
# eliminated by nn.Graph | ||
w1 = self.w - self.w - self.w | ||
x = x * w1.to(flow.float32) | ||
return x | ||
|
||
|
||
def test_run_graph_by_vm(capsys): | ||
m = M().eval() | ||
g = Graph(m) | ||
|
||
input = flow.randn(4) | ||
graph_output = g(input) | ||
eager_output = m(input) | ||
assert graph_output.shape == (4,) | ||
assert np.allclose(graph_output, eager_output) | ||
|
||
input = flow.randn(3, 4) | ||
graph_output = g(input) | ||
eager_output = m(input) | ||
assert graph_output.shape == (3, 4) | ||
assert np.allclose(graph_output, eager_output) | ||
|
||
# Test the optimization in graph works. | ||
# broadcast_sub and cast ops are pruned. | ||
print(g) | ||
assert "broadcast_sub" not in capsys.readouterr().out | ||
assert "cast" not in capsys.readouterr().out | ||
assert "broadcast_mul" not in capsys.readouterr().out | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个怎么看起来不像标准的 unittest,ci 能跑到这个 case 么 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以,是 pytest 的写法,比 python 自带的 unittest 好用不少,CI 已经在用 pytest 跑了 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dead tensor 看起来主要就是会 outdated 的 activation tensor ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,dead_tensors[i] 表示第 i 个 op 之后会变为 dead 的 tensors,如果有更好的名字也可以提出
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OudatedTensorAfterOp?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以 :good: 已修改