Run nn.Graph by VM #9884

Merged · 18 commits · Feb 25, 2023
Changes from 6 commits
2 changes: 2 additions & 0 deletions oneflow/api/python/framework/nn_graph.cpp
@@ -26,6 +26,7 @@ limitations under the License.
#include "oneflow/core/register/blob.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/job_ir.h"
#include "oneflow/core/job/job_interpreter.h"

namespace py = pybind11;

@@ -114,6 +115,7 @@ ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
      .def("get_current_job_str", &APINNGraphGetCurrentSerializedJob);

  m.def("RunLazyNNGraph", &RunLazyNNGraph);
  m.def("RunLazyNNGraphByVM", &one::InterpretJob);
  m.def("SoftSyncNNGraphBuffers", &SoftSyncNNGraphBuffers);
  m.def("AddTensorAsGraphLoss", &AddTensorAsGraphLoss);
  m.def("MarkVariableGradients", [](const std::vector<std::shared_ptr<one::Tensor>>& variables,
3 changes: 3 additions & 0 deletions oneflow/core/framework/nn_graph.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/multi_client_session_context.h"
@@ -98,6 +99,8 @@ class NNGraph final : public NNGraphIf {
  Maybe<void> InitRuntime();
  Maybe<void> CompileAndInitRuntime();
  Maybe<void> Close();
  const auto variable_op_name2tensor() const { return variable_op_name2tensor_; }
  std::vector<std::shared_ptr<one::UserOpExpr>> cached_op_exprs;

 private:
  Maybe<void> RegisterFreeEagerTensorsToVariableOpNames();
174 changes: 174 additions & 0 deletions oneflow/core/job/job_interpreter.cpp
@@ -0,0 +1,174 @@
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/nn_graph.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/framework/local_tensor_infer_cache.h"

namespace oneflow {
namespace one {

using Env = std::map<std::string, std::shared_ptr<Tensor>>;

Maybe<Env> InitEnv(const one::TensorTuple& graph_inputs, const std::shared_ptr<NNGraph>& graph) {
  Env env;
  for (const auto& [name, tensor] : graph->variable_op_name2tensor()) {
    env.emplace(name + "/out", tensor);
  }
  for (size_t i = 0; i < graph->inputs_op_names().size(); ++i) {
    const auto& name = graph->inputs_op_names()[i];
    env.emplace(name + "/out", JUST(VectorAt(graph_inputs, i)));
  }
  return env;
}
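For intuition: the interpreter's environment is just a dictionary from logical blob names ("<op_name>/out") to eager tensors, seeded with the graph's variables and inputs. A rough Python analogue, as a sketch only (the names variables and input_op_names are hypothetical; the real code reads them from the NNGraph object):

# Illustrative sketch, not part of the diff.
def init_env(graph_inputs, variables, input_op_names):
    # variables: dict mapping a variable op's name to its eager tensor
    env = {name + "/out": tensor for name, tensor in variables.items()}
    # graph inputs are bound positionally to the graph's input op names
    for name, tensor in zip(input_op_names, graph_inputs):
        env[name + "/out"] = tensor
    return env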

Maybe<UserOpExpr> OpConfToUserOpExpr(const OperatorConf& op_conf) {
  CHECK_OR_RETURN(op_conf.has_user_conf());
  const auto& user_conf = op_conf.user_conf();
  auto builder = OpBuilder(user_conf.op_type_name());
  for (const auto& pair : user_conf.attr()) { builder.Attr(pair.first, pair.second); }
  for (const auto& pair : user_conf.input()) {
    // ignore "UserSourceOpTickInput"
    if (pair.first == "UserSourceOpTickInput") { continue; }
    builder.Input(pair.first, pair.second.s_size());
  }
  for (const auto& pair : user_conf.output()) { builder.Output(pair.first, pair.second.s_size()); }
  return JUST(builder.Build());
}

template<typename Func>
Maybe<TensorTuple> GetInputTensors(const UserOpConf& user_conf, const Env& env,
                                   const Func& preprocess) {
  TensorTuple inputs;
  for (const auto& pair : user_conf.input()) {
    if (pair.first == "UserSourceOpTickInput") { continue; }
    for (const auto& name : pair.second.s()) {
      inputs.emplace_back(preprocess(JUST(MapAt(env, name))));
    }
  }
  return inputs;
}

OpArgsVector<std::string> GetOutputNamesOfOp(const UserOpConf& user_conf) {
  OpArgsVector<std::string> output_names;
  for (const auto& pair : user_conf.output()) {
    for (const auto& name : pair.second.s()) { output_names.emplace_back(name); }
  }
  return output_names;
}

// Only support a limited subset of view ops for now
bool IsViewOp(const std::shared_ptr<UserOpExpr>& op) {
  return op->op_type_name() == "reshape" || op->op_type_name() == "expand_dims";
}

Maybe<void> RunViewOp(const std::shared_ptr<UserOpExpr>& op, Env& env, const TensorTuple& inputs,
                      const OpArgsVector<std::string>& output_names) {
  // eliminate the memcpy of view ops
  CHECK_OR_RETURN(IsViewOp(op));
  const std::shared_ptr<const LocalTensorInferResult> result =
      JUST([&]() -> Maybe<const LocalTensorInferResult> {
        LocalTensorMetaInferArgs infer_args;
        JUST(infer_args.Init(op->base_attrs(), JUST(inputs[0]->device()), inputs));
        return JUST(op->mut_local_tensor_infer_cache()->GetOrInfer(infer_args));
      }());
  const auto& output_shape = result->output_tensor_metas()[0]->shape();
  const auto output =
      JUST(view::BasicView(inputs[0], output_shape, JUST(inputs[0]->storage_offset())));
  env.emplace(output_names[0], output);
  return Maybe<void>::Ok();
}
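The point of this fast path is that a view op never needs to copy data: its output aliases the input's storage with new metadata (shape and storage offset). NumPy's reshape has the same aliasing behavior, which makes for a quick analogy (illustration only, not OneFlow's implementation):

# Illustrative sketch, not part of the diff.
import numpy as np

a = np.arange(4, dtype=np.float32)
b = a.reshape(2, 2)   # a view: shares a's buffer, no copy is made
b[0, 0] = 42.0
assert a[0] == 42.0   # the write through the view is visible in the original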

Maybe<void> RunNormalOp(const std::shared_ptr<UserOpExpr>& op, Env& env, const TensorTuple& inputs,
                        const OpArgsVector<std::string>& output_names) {
  TensorTuple outputs(output_names.size());
  static EagerLocalInterpreter it;
  static AttrMap empty_attr_map;
  JUST(it.Apply(*op, inputs, &outputs, empty_attr_map));
  for (size_t i = 0; i < output_names.size(); ++i) {
    env.emplace(output_names[i], JUST(VectorAt(outputs, i)));
  }
  return Maybe<void>::Ok();
}

// Tensors in dead_tensors[i] will not be accessed again after the i-th op,
// so they can be released once the i-th op's execution finishes.
std::vector<std::vector<std::string>> GetDeadTensorVector(const Job& job) {
Contributor:
So "dead tensors" here are essentially activation tensors that become outdated?

Contributor Author:
Yes, dead_tensors[i] holds the tensors that become dead after the i-th op. Suggestions for a better name are welcome.

Contributor:
OudatedTensorAfterOp?

Contributor Author (@daquexian, Feb 23, 2023):
Sure :good: Renamed.


  std::vector<std::vector<std::string>> dead_tensors(job.net().op_size());
  std::set<std::string> visited;
  for (int i = job.net().op_size() - 1; i >= 0; --i) {
    const auto& op_conf = job.net().op(i);
    // do not release the graph output tensors
    if (op_conf.has_output_conf()) {
      const auto& output_conf = op_conf.output_conf();
      visited.insert(output_conf.in());
    } else if (op_conf.has_user_conf()) {
      const auto& user_conf = op_conf.user_conf();
      for (const auto& pair : user_conf.input()) {
        if (pair.first == "UserSourceOpTickInput") { continue; }
        for (const auto& name : pair.second.s()) {
          if (visited.find(name) == visited.end()) {
            dead_tensors[i].push_back(name);
            visited.insert(name);
          }
        }
      }
    }
  }
  return dead_tensors;
}
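This is a standard backward liveness pass: scanning the ops in reverse execution order, the first time a tensor name is seen is its last use in forward order, so it can be released right after that op. A minimal Python sketch of the same idea, using a hypothetical (input_names, output_names) op representation; like the C++ code, it seeds the visited set with the graph outputs so they are never freed:

# Illustrative sketch, not part of the diff.
def get_dead_tensors(ops, graph_outputs):
    """ops: list of (input_names, output_names) pairs in execution order.
    Returns dead[i]: names that can be released right after op i runs."""
    dead = [[] for _ in ops]
    visited = set(graph_outputs)  # never release graph outputs
    for i in range(len(ops) - 1, -1, -1):
        for name in ops[i][0]:
            if name not in visited:  # first sight in reverse scan = last use forward
                visited.add(name)
                dead[i].append(name)
    return dead

ops = [(["x/out"], ["a/out"]), (["a/out"], ["b/out"]), (["b/out"], ["y/out"])]
print(get_dead_tensors(ops, ["y/out"]))  # [['x/out'], ['a/out'], ['b/out']]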

Maybe<void> InitOpExprs(const std::shared_ptr<NNGraph>& graph) {
  CHECK_OR_RETURN(graph->cached_op_exprs.empty());

  const auto& job = graph->job();
  for (int i = 0; i < job.net().op_size(); i++) {
    const auto& op_conf = job.net().op(i);
    if (op_conf.has_user_conf()) {
      const auto op_expr = JUST(OpConfToUserOpExpr(op_conf));
      graph->cached_op_exprs.push_back(op_expr);
    } else {
      graph->cached_op_exprs.push_back(nullptr);
    }
  }
  return Maybe<void>::Ok();
}

Maybe<one::TensorTuple> InterpretJob(const one::TensorTuple& graph_inputs,
                                     const std::shared_ptr<NNGraph>& graph) {
  if (graph->cached_op_exprs.empty()) { JUST(InitOpExprs(graph)); }

  const auto& job = graph->job();
  auto env = *JUST(InitEnv(graph_inputs, graph));

  // see the comment on GetDeadTensorVector: dead_tensors[i] holds the tensors
  // whose last use is the i-th op, so they can be released right after it runs.
  const auto dead_tensors = GetDeadTensorVector(job);
Contributor:
What does "dead tensor" mean here? A comment would help.

Contributor Author:
Sure. There is a comment at the definition of GetDeadTensorVector; I'll add a pointer to it here as well.

Contributor Author:
Added.

  one::TensorTuple graph_outputs;
  for (int i = 0; i < job.net().op_size(); i++) {
    const auto& op_conf = job.net().op(i);
    if (op_conf.has_user_conf()) {
      auto op = CHECK_NOTNULL(graph->cached_op_exprs[i]);
      const auto& user_conf = op_conf.user_conf();
      OF_PROFILER_RANGE_GUARD(user_conf.op_type_name());
      TensorTuple inputs =
          *JUST(GetInputTensors(user_conf, env, [&op_conf](const std::shared_ptr<Tensor>& tensor) {
            return CHECK_JUST(functional::To(tensor, op_conf.device_tag()));
          }));
      OpArgsVector<std::string> output_names = GetOutputNamesOfOp(user_conf);
      if (IsViewOp(op)) {
        JUST(RunViewOp(op, env, inputs, output_names));
      } else {
        JUST(RunNormalOp(op, env, inputs, output_names));
      }
      for (const auto& name : dead_tensors[i]) { CHECK_EQ_OR_RETURN(env.erase(name), 1); }
    } else if (op_conf.has_output_conf()) {
      const auto& output_conf = op_conf.output_conf();
      graph_outputs.emplace_back(JUST(MapAt(env, output_conf.in())));
    }
  }
  return graph_outputs;
}
} // namespace one
} // namespace oneflow
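Taken together, InterpretJob is a single-pass interpreter over the job's op list: fetch inputs from the environment, run each user op eagerly (with the view-op fast path), write outputs back, and release tensors that are past their last use. A condensed Python sketch of that control flow, with hypothetical helpers (init_env, get_dead_tensors, is_view_op, run_view_op, run_normal_op, and the op fields) standing in for the C++ above:

# Illustrative sketch, not part of the diff.
def interpret_job(graph_inputs, graph):
    env = init_env(graph_inputs, graph.variables, graph.input_op_names)
    dead = get_dead_tensors(graph.ops, graph.output_names)
    outputs = []
    for i, op in enumerate(graph.ops):
        if op.is_user_op:
            inputs = [env[name] for name in op.input_names]
            if is_view_op(op):
                results = run_view_op(op, inputs)    # alias storage, no copy
            else:
                results = run_normal_op(op, inputs)  # normal eager dispatch
            env.update(zip(op.output_names, results))
            for name in dead[i]:  # release tensors after their last use
                del env[name]
        elif op.is_output_op:
            outputs.append(env[op.in_name])
    return outputs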
11 changes: 11 additions & 0 deletions oneflow/core/job/job_interpreter.h
@@ -0,0 +1,11 @@
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/job.pb.h"

namespace oneflow {
class NNGraph;
namespace one {
class TensorTuple;
Maybe<one::TensorTuple> InterpretJob(const one::TensorTuple& inputs,
                                     const std::shared_ptr<NNGraph>& graph);
} // namespace one
} // namespace oneflow
59 changes: 34 additions & 25 deletions python/oneflow/nn/graph/graph.py
@@ -1472,21 +1472,40 @@ def __run(self, *args, **kwargs):
        self.__ensure_input_tensors_contiguous(*args, **kwargs)
        try:
            flattened_eager_args = self.__flatten_io("input", *args, **kwargs)
            outputs_tensor_tuple = self._outputs_tensor_tuple_buffer[
                self._cur_index_of_ouputs_buffer
            ]
            eager_outputs = self._eager_outputs_buffer[self._cur_index_of_ouputs_buffer]

            # oneflow._oneflow_internal.eager.Sync() NOTE(chengcheng): Need Sync?
            oneflow._oneflow_internal.nn.graph.RunLazyNNGraph(
                convert_to_tensor_tuple(flattened_eager_args),
                outputs_tensor_tuple,
                self._c_nn_graph,
            )
            # Update outputs buffer reading index
            self._cur_index_of_ouputs_buffer += 1
            if self._cur_index_of_ouputs_buffer >= self._outputs_buffer_size:
                self._cur_index_of_ouputs_buffer = 0

            if oneflow.support.env_var_util.parse_boolean_from_env(
                "ONEFLOW_RUN_GRAPH_BY_VM", False
            ):
                eager_outputs = oneflow._oneflow_internal.nn.graph.RunLazyNNGraphByVM(
                    convert_to_tensor_tuple(flattened_eager_args),
                    self._c_nn_graph,
                )
            else:
                outputs_tensor_tuple = self._outputs_tensor_tuple_buffer[
                    self._cur_index_of_ouputs_buffer
                ]
                eager_outputs = self._eager_outputs_buffer[self._cur_index_of_ouputs_buffer]
                # oneflow._oneflow_internal.eager.Sync() NOTE(chengcheng): Need Sync?
                oneflow._oneflow_internal.nn.graph.RunLazyNNGraph(
                    convert_to_tensor_tuple(flattened_eager_args),
                    outputs_tensor_tuple,
                    self._c_nn_graph,
                )
                # Update outputs buffer reading index
                self._cur_index_of_ouputs_buffer += 1
                if self._cur_index_of_ouputs_buffer >= self._outputs_buffer_size:
                    self._cur_index_of_ouputs_buffer = 0

                # Copy outputs from buffer
                eager_outputs, _ = self.__copy_io("output", *eager_outputs)

                # Make sure that last used devices of tensors in `outputs_tensor_tuple` are
                # "critical_section".
                # NNGraph's execution flow will be broken if `last_used_device` of `outputs_tensor_tuple`
                # are not "critical_section".
                oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers(
                    outputs_tensor_tuple, self._c_nn_graph
                )
        except:
            self.__print(
                2,
@@ -1498,16 +1517,6 @@ def __run(self, *args, **kwargs):
            )
            raise

        # Copy outputs from buffer
        eager_outputs, _ = self.__copy_io("output", *eager_outputs)

        # Make sure that last used devices of tensors in `outputs_tensor_tuple` are
        # "critical_section".
        # NNGraph's execution flow will be broken if `last_used_device` of `outputs_tensor_tuple`
        # are not "critical_section".
        oneflow._oneflow_internal.nn.graph.SoftSyncNNGraphBuffers(
            outputs_tensor_tuple, self._c_nn_graph
        )
        # Always pack outputs to retain the output types
        return seq_to_func_return(eager_outputs, True)

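As the diff shows, the VM path is opt-in: __run checks the ONEFLOW_RUN_GRAPH_BY_VM environment variable on every call. A minimal usage sketch (the variable only needs to be set before the graph is run; the new test below does exactly this):

# Illustrative sketch, not part of the diff.
import os

os.environ["ONEFLOW_RUN_GRAPH_BY_VM"] = "1"  # route nn.Graph execution through the VM

import oneflow as flow

# Build and call a flow.nn.Graph as usual; __run now dispatches to
# oneflow._oneflow_internal.nn.graph.RunLazyNNGraphByVM.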
54 changes: 54 additions & 0 deletions python/oneflow/test/graph/test_run_graph_by_vm.py
@@ -0,0 +1,54 @@
import os

os.environ["ONEFLOW_RUN_GRAPH_BY_VM"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1"
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1"

import oneflow as flow
import numpy as np


class Graph(flow.nn.Graph):
    def __init__(self, m):
        super().__init__()
        self.m = m

    def build(self, x):
        return self.m(x)


class M(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = flow.nn.Parameter(flow.randn(4))

    def forward(self, x):
        # these broadcast_sub and cast ops will be
        # eliminated by nn.Graph
        w1 = self.w - self.w - self.w
        x = x * w1.to(flow.float32)
        return x


def test_run_graph_by_vm(capsys):
    m = M().eval()
    g = Graph(m)

    input = flow.randn(4)
    graph_output = g(input)
    eager_output = m(input)
    assert graph_output.shape == (4,)
    assert np.allclose(graph_output, eager_output)

    input = flow.randn(3, 4)
    graph_output = g(input)
    eager_output = m(input)
    assert graph_output.shape == (3, 4)
    assert np.allclose(graph_output, eager_output)

    # Test that the optimizations in nn.Graph work:
    # broadcast_sub, cast and broadcast_mul ops are pruned.
    print(g)
    printed = capsys.readouterr().out  # readouterr() drains the buffer, so capture once
    assert "broadcast_sub" not in printed
    assert "cast" not in printed
    assert "broadcast_mul" not in printed
Contributor:
This doesn't look like a standard unittest. Can CI actually run this case?

Contributor Author:
Yes. It is written in pytest style, which is quite a bit nicer than Python's built-in unittest, and CI already runs the tests with pytest.