From c47fda2cbb083be26c34535a1a66c30b099ecc9b Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Thu, 13 Jan 2022 12:25:08 +0800 Subject: [PATCH 1/9] feat(Parameter): Parameter support both inplace op and setter --- oneflow/core/framework/tensor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index 2d7e2fce348..bb43365d39b 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -288,7 +288,7 @@ class Parameter final : public TensorIf { Maybe> consistent_tensor_meta() const override { return tensor_->consistent_tensor_meta(); } - Maybe data() override { return tensor_; } + Maybe data() override { return tensor_->detach(); } // Must override grad_fn_node function. Otherwise grad_fn will belong to this not tensor_, // and it will be wrong when use Parameter.data() in operators. From 4c475835813c54b80d7c05b19541d0d1548dadee Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Thu, 13 Jan 2022 12:38:23 +0800 Subject: [PATCH 2/9] feat(Tensor): tensor support data's getter interface --- oneflow/core/framework/tensor.h | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index bb43365d39b..109e49948f5 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -417,10 +417,7 @@ class MirroredTensor final : public TensorIf { bool is_cuda() const override; const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); } - Maybe data() override { - OF_LOG_ONCE(LOG(WARNING) << "You shouldn't call `.data` for a LocalTensor."); - return std::static_pointer_cast(shared_from_this()); - } + Maybe data() override { return this->detach(); } // Getters valid only for EagerMirroredTensor Maybe eager_blob_object() const override { @@ -522,10 +519,7 @@ class ConsistentTensor final : public TensorIf { return impl_->cur_rank_phy_tensor(); } bool is_cuda() const override; - Maybe data() override { - OF_LOG_ONCE(LOG(WARNING) << "You shouldn't call `.data` for a ConsistentTensor."); - return std::static_pointer_cast(shared_from_this()); - } + Maybe data() override { return this->detach(); } // Getters valid only for EagerMirroredTensor Maybe eager_blob_object() const override { From 94595a2fca842dffe9ffadce73a23ea1cea24c0a Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Thu, 13 Jan 2022 12:50:12 +0800 Subject: [PATCH 3/9] test(Parameter): add getter test --- python/oneflow/test/tensor/test_parameter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/oneflow/test/tensor/test_parameter.py b/python/oneflow/test/tensor/test_parameter.py index 029e6f89cd2..b8d63b72d6e 100644 --- a/python/oneflow/test/tensor/test_parameter.py +++ b/python/oneflow/test/tensor/test_parameter.py @@ -40,6 +40,12 @@ def test_parameter_set_data_autograd_meta(test_case): z.data = y return z.grad_fn, z.is_leaf + @autotest(n=1, check_graph=False) + def test_parameter_inplace_modify_data(test_case): + x = torch.nn.Parameter(torch.ones(2, 3)) + x.data.mul_(2) + return x + def test_parameter_set_data(test_case): a = flow.nn.Parameter(flow.ones(2, 3), False) old_id = id(a) From 49e3193c73732f5cbec4a68242dc751f53953af7 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 13 Jan 2022 17:59:16 +0800 Subject: [PATCH 4/9] debug --- .../op_interpreter/lazy_op_interpreter.cpp | 1 + python/oneflow/test/tensor/test_parameter.py | 4 ++-- .../automated_test_util/torch_flow_dual_object.py | 15 ++++++++++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp index 151d11a1d9e..636cf40a3c6 100644 --- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp @@ -683,6 +683,7 @@ Maybe LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu (*outputs)[i] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, is_local)); } else { + VLOG(2) << "Lazy nn.Graph name " << graph_name << " op name " << new_op_name << " run with inplace."; const std::shared_ptr& inplace_out = (*outputs)[i]; JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, is_local)); } diff --git a/python/oneflow/test/tensor/test_parameter.py b/python/oneflow/test/tensor/test_parameter.py index b8d63b72d6e..828a25b3c34 100644 --- a/python/oneflow/test/tensor/test_parameter.py +++ b/python/oneflow/test/tensor/test_parameter.py @@ -40,10 +40,10 @@ def test_parameter_set_data_autograd_meta(test_case): z.data = y return z.grad_fn, z.is_leaf - @autotest(n=1, check_graph=False) + @autotest(n=1, check_graph=True) def test_parameter_inplace_modify_data(test_case): x = torch.nn.Parameter(torch.ones(2, 3)) - x.data.mul_(2) + x = x.data.mul_(2) return x def test_parameter_set_data(test_case): diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 94ae59ab922..1745ca9aacd 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -335,7 +335,7 @@ def build(self, *args): test_g = TestGraphOfModule() if verbose: print("Run graph of module: ", repr(oneflow)) - test_g.debug(2) + test_g.debug(3) test_g_res = test_g(*oneflow_args) elif oneflow.__name__ in ignore_apis_list: find_check_module_func = False @@ -379,8 +379,18 @@ def build(self): test_g_res = oneflow_res else: pass + if verbose: + print("Run graph of function: ", repr(oneflow), ", graph check is intentionally skiped.") + elif (oneflow.__name__ == "Parameter"): + # nn.Graph donot deal with Parameter creation. + test_g_res = oneflow_res + if verbose: + print("Run graph of function: ", repr(oneflow), ", graph check is intentionally skiped.") else: test_g = TestGraphOfFunctional() + if verbose: + print("Run graph of function: ", repr(oneflow)) + test_g.debug(3) test_g_res = test_g() except Exception as e: print_note_fake_program() @@ -439,6 +449,9 @@ def build(self): try: test_g = TestGraphOfTensorMethod() + if verbose: + print("Run graph of method: ", repr(oneflow)) + test_g.debug(3) test_g_res = test_g() except Exception as e: print_note_fake_program() From abc42e19d216558b572b6aba5676826bc5499d3e Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 13 Jan 2022 22:55:30 +0800 Subject: [PATCH 5/9] add test --- .../op_interpreter/lazy_op_interpreter.cpp | 102 ++++++++++-------- python/oneflow/framework/graph_build_util.py | 1 - .../graph/test_graph_free_eager_tensor.py | 67 ++++++++++++ python/oneflow/test/tensor/test_parameter.py | 8 +- 4 files changed, 129 insertions(+), 49 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp index 636cf40a3c6..e4dd86991fe 100644 --- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/cpp_attribute.h" @@ -190,6 +191,52 @@ Maybe NewScopeWithParallelDescByTensor(const std::shared_ptr& ten return NewScopeWithParallelConfAndCurScope(parallel_conf); } +Maybe AddFreeEagerTensorToVariableOp(const std::shared_ptr& input_tensor) { + CHECK_OR_RETURN(input_tensor->is_eager()); + const std::string& empty_lbn = TensorNameScope::Global()->Lookup(input_tensor); + CHECK_OR_RETURN(empty_lbn.empty()); + std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(input_tensor)); + OperatorConf op_conf; + op_conf.set_scope_symbol_id(JUST(scope->symbol_id())); + op_conf.set_device_tag(GetDeviceTagOfTensor(input_tensor)); + VariableOpConf* var_conf = op_conf.mutable_variable_conf(); + var_conf->set_out("out"); + input_tensor->shape()->ToProto(var_conf->mutable_shape()); + var_conf->set_data_type(input_tensor->dtype()->data_type()); + // NOTE(chengcheng): VariableOpConf initializer_conf is useless because variable is inited + // by EagerTensor. + var_conf->mutable_initializer()->mutable_empty_conf(); + JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor)); + // NOTE(chengcheng): Free EagerTensor not trainable + var_conf->set_trainable(false); + + auto infer_ctx = JUST(GetCurInferCtx()); + // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp, FreeEagerTensor has no + // name so just new a unique name for it. + const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(op_conf)); + op_conf.set_name(new_op_name); + + VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n" + << op_conf.DebugString() << std::endl; + OpAttribute op_attr = *JUST(infer_ctx->AddAndInferConsistentOp(op_conf)); + VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" + << op_conf.DebugString() << " for FreeEagerTensor.\n"; + VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() + << " infer and and op attr : \n" + << op_attr.DebugString() << " for FreeEagerTensor.\n"; + + // NOTE(chengcheng): MUST store this tensor to MultiClientSessionContext for graph runtime bind. + const std::string graph_name = *JUST(JUST(GlobalJobBuildAndInferCtxMgr())->GetCurrentJobName()); + const std::string lbn = GenLogicalBlobName(new_op_name, "out"); + Global::Get()->StoreFreeEagerTensorWithNameByGraphName( + graph_name, input_tensor, new_op_name); + // NOTE(chengcheng): MUST record this eager_tensor name as new variable output lbn. + TensorNameScope::Global()->Record(input_tensor, lbn); + + return Maybe::Ok(); +} + + Maybe LazyInterpreter::ApplyImpl(const FeedInputOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { // NOTE(chengcheng): inputs[0] is the EagerTensor @@ -308,7 +355,15 @@ Maybe LazyInterpreter::ApplyImpl(const FetchOutputOpExpr& op_expr, const T CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); const std::shared_ptr& input_tensor = inputs.at(0); - CHECK_OR_RETURN(input_tensor->is_lazy()); + if (!input_tensor->is_lazy()) { + const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor); + if (input_lbn.empty()) { + // This output tensor is a new free eager tensor, so treat it as a new variable op output. + JUST(AddFreeEagerTensorToVariableOp(input_tensor)); + } + // Else, this eager output tensor has already been treated as an output of a variable op + // or an inplace op, so do nothing. + } const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor); CHECK_OR_RETURN(!input_lbn.empty()); // lbn must exist. @@ -490,51 +545,6 @@ Maybe LazyInterpreterApplyImplForSourceUserOpExpr(const UserOpExpr& op_exp return Maybe::Ok(); } -Maybe AddFreeEagerTensorToVariableOp(const std::shared_ptr& input_tensor) { - CHECK_OR_RETURN(input_tensor->is_eager()); - const std::string& empty_lbn = TensorNameScope::Global()->Lookup(input_tensor); - CHECK_OR_RETURN(empty_lbn.empty()); - std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(input_tensor)); - OperatorConf op_conf; - op_conf.set_scope_symbol_id(JUST(scope->symbol_id())); - op_conf.set_device_tag(GetDeviceTagOfTensor(input_tensor)); - VariableOpConf* var_conf = op_conf.mutable_variable_conf(); - var_conf->set_out("out"); - input_tensor->shape()->ToProto(var_conf->mutable_shape()); - var_conf->set_data_type(input_tensor->dtype()->data_type()); - // NOTE(chengcheng): VariableOpConf initializer_conf is useless because variable is inited - // by EagerTensor. - var_conf->mutable_initializer()->mutable_empty_conf(); - JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor)); - // NOTE(chengcheng): Free EagerTensor not trainable - var_conf->set_trainable(false); - - auto infer_ctx = JUST(GetCurInferCtx()); - // NOTE(chengcheng): MUST reset unique op name before InferCtx::AddOp, FreeEagerTensor has no - // name so just new a unique name for it. - const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(op_conf)); - op_conf.set_name(new_op_name); - - VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n" - << op_conf.DebugString() << std::endl; - OpAttribute op_attr = *JUST(infer_ctx->AddAndInferConsistentOp(op_conf)); - VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n" - << op_conf.DebugString() << " for FreeEagerTensor.\n"; - VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() - << " infer and and op attr : \n" - << op_attr.DebugString() << " for FreeEagerTensor.\n"; - - // NOTE(chengcheng): MUST store this tensor to MultiClientSessionContext for graph runtime bind. - const std::string graph_name = *JUST(JUST(GlobalJobBuildAndInferCtxMgr())->GetCurrentJobName()); - const std::string lbn = GenLogicalBlobName(new_op_name, "out"); - Global::Get()->StoreFreeEagerTensorWithNameByGraphName( - graph_name, input_tensor, new_op_name); - // NOTE(chengcheng): MUST record this eager_tensor name as new variable output lbn. - TensorNameScope::Global()->Record(input_tensor, lbn); - - return Maybe::Ok(); -} - Maybe LazyInterpreterApplyImplForCopyUserOpExpr(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, diff --git a/python/oneflow/framework/graph_build_util.py b/python/oneflow/framework/graph_build_util.py index bd2b357375b..c145fe9d1b6 100644 --- a/python/oneflow/framework/graph_build_util.py +++ b/python/oneflow/framework/graph_build_util.py @@ -184,7 +184,6 @@ def build_graph_state(op_name, state_tensor, state_config): def build_graph_output(op_name, out): assert isinstance(out, Tensor) - assert out.is_lazy output_conf = ( oneflow._oneflow_internal.oneflow.core.operator.op_conf.FetchOutputOpConf() diff --git a/python/oneflow/test/graph/test_graph_free_eager_tensor.py b/python/oneflow/test/graph/test_graph_free_eager_tensor.py index f4890daa61c..a481664d7a0 100644 --- a/python/oneflow/test/graph/test_graph_free_eager_tensor.py +++ b/python/oneflow/test/graph/test_graph_free_eager_tensor.py @@ -115,6 +115,73 @@ def build(self): np.allclose(mul_out.numpy(), np_x * np_y, atol=1e-4, rtol=1e-4) ) + def test_graph_return_free_eager_tensor(test_case): + np_x = np.random.randn(5, 3) + x = flow.tensor(np_x, dtype=flow.float32) + + class GraphReturnEager(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self): + # Return free eager tensor + return x + + g_return_eager = GraphReturnEager() + + # Run first time + ret_eager_out = g_return_eager() + test_case.assertTrue( + np.allclose(ret_eager_out.numpy(), np_x, atol=1e-4, rtol=1e-4) + ) + + # Run second time + ret_eager_out1 = g_return_eager() + test_case.assertTrue( + np.allclose(ret_eager_out1.numpy(), np_x, atol=1e-4, rtol=1e-4) + ) + + def test_graph_return_inplace_free_eager_tensor(test_case): + np_x = np.random.randn(5, 3) + x = flow.tensor(np_x, dtype=flow.float32) + + class GraphInplaceReturnEager(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self): + # x is free eager tensor + # mul_ is inplace scalar mul + # Input and output of mul_ are both tensor x + # After lazy interpretr, tensor x's name will be the ouput lbn of mul_ + x.mul_(2) + # Here will return the output of mul_ + return x + + g_return_eager = GraphInplaceReturnEager() + + # Run first time + ret_eager_out = g_return_eager() + # x in ouput changed + # So nn.Graph simulate inplace in nn.Graph.build(). + test_case.assertTrue( + np.allclose(ret_eager_out.numpy(), np_x*2, atol=1e-4, rtol=1e-4) + ) + # x has not changed + # So nn.Graph inplace will not change free eager tensor. + test_case.assertTrue( + np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4) + ) + + # Run second time + ret_eager_out = g_return_eager() + test_case.assertTrue( + np.allclose(ret_eager_out.numpy(), np_x*2, atol=1e-4, rtol=1e-4) + ) + test_case.assertTrue( + np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4) + ) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() diff --git a/python/oneflow/test/tensor/test_parameter.py b/python/oneflow/test/tensor/test_parameter.py index 828a25b3c34..9e7e44e92e5 100644 --- a/python/oneflow/test/tensor/test_parameter.py +++ b/python/oneflow/test/tensor/test_parameter.py @@ -40,10 +40,14 @@ def test_parameter_set_data_autograd_meta(test_case): z.data = y return z.grad_fn, z.is_leaf - @autotest(n=1, check_graph=True) + # Not check graph because of 2 reason. + # Reason 1, x.data return a new tensor but share storage with the origin tensor, this is not well dealed in nn.Graph. + # Reason 2, inplace operation mul_ can works well inside nn.Graph but will not change the value in free eager tensor. + # Please refer to test case: test_graph_return_inplace_free_eager_tensor + @autotest(n=1, check_graph=False) def test_parameter_inplace_modify_data(test_case): x = torch.nn.Parameter(torch.ones(2, 3)) - x = x.data.mul_(2) + x.data.mul_(2) return x def test_parameter_set_data(test_case): From bef91d53d23f4cfd631a6333da5e9e8c5238a8e6 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 13 Jan 2022 23:10:16 +0800 Subject: [PATCH 6/9] open flatten graph test --- python/oneflow/test/modules/test_flatten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/test/modules/test_flatten.py b/python/oneflow/test/modules/test_flatten.py index 91145475b9f..2d69caf8131 100644 --- a/python/oneflow/test/modules/test_flatten.py +++ b/python/oneflow/test/modules/test_flatten.py @@ -79,7 +79,7 @@ def test_flatten_module_with_random_data(test_case): y = m(x) return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flatten_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) From 6149ce6e323a78b4f083ce2e770b05f7107cf999 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 13 Jan 2022 23:24:10 +0800 Subject: [PATCH 7/9] add validated flase type --- python/oneflow/test/tensor/test_parameter.py | 2 +- .../test_utils/automated_test_util/torch_flow_dual_object.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/oneflow/test/tensor/test_parameter.py b/python/oneflow/test/tensor/test_parameter.py index 9e7e44e92e5..8c9284eb562 100644 --- a/python/oneflow/test/tensor/test_parameter.py +++ b/python/oneflow/test/tensor/test_parameter.py @@ -44,7 +44,7 @@ def test_parameter_set_data_autograd_meta(test_case): # Reason 1, x.data return a new tensor but share storage with the origin tensor, this is not well dealed in nn.Graph. # Reason 2, inplace operation mul_ can works well inside nn.Graph but will not change the value in free eager tensor. # Please refer to test case: test_graph_return_inplace_free_eager_tensor - @autotest(n=1, check_graph=False) + @autotest(n=1, check_graph="ValidatedFlase") def test_parameter_inplace_modify_data(test_case): x = torch.nn.Parameter(torch.ones(2, 3)) x.data.mul_(2) diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 1745ca9aacd..f6ed25588ef 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -689,6 +689,10 @@ def autotest( ): verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None + if check_graph == "ValidatedFlase": + # check graph is intentionally closed and threre is a validated reason. + check_graph = False + def deco(f): @functools.wraps(f) def new_f(test_case): From 6121b3c70e76ccc4a8a9e243802d2f1f42839043 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Jan 2022 11:07:25 +0800 Subject: [PATCH 8/9] refine --- .../op_interpreter/lazy_op_interpreter.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp index e4dd86991fe..c9855fabaf0 100644 --- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp @@ -355,16 +355,16 @@ Maybe LazyInterpreter::ApplyImpl(const FetchOutputOpExpr& op_expr, const T CHECK_EQ_OR_RETURN(inputs.size(), 1); CHECK_EQ_OR_RETURN(op_expr.input_size(), 1); const std::shared_ptr& input_tensor = inputs.at(0); - if (!input_tensor->is_lazy()) { - const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor); - if (input_lbn.empty()) { - // This output tensor is a new free eager tensor, so treat it as a new variable op output. - JUST(AddFreeEagerTensorToVariableOp(input_tensor)); - } - // Else, this eager output tensor has already been treated as an output of a variable op - // or an inplace op, so do nothing. + std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor); + // Lazy tensor must has lbn. + // Eager tensor may has lbn if it has already been treated as an output of a variable op + // or an output of an inplace op. + if (input_lbn.empty()) { + CHECK_OR_RETURN(input_tensor->is_eager()); + // This output tensor is a new free eager tensor, so treat it as a new variable op output. + JUST(AddFreeEagerTensorToVariableOp(input_tensor)); + input_lbn = TensorNameScope::Global()->Lookup(input_tensor); } - const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor); CHECK_OR_RETURN(!input_lbn.empty()); // lbn must exist. std::shared_ptr scope = JUST(NewScopeWithParallelDescByTensor(input_tensor)); From 713400dc10cedf0c3991a9aaa847f3e3dae765c2 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 14 Jan 2022 11:09:07 +0800 Subject: [PATCH 9/9] foramt --- .../op_interpreter/lazy_op_interpreter.cpp | 4 ++-- .../graph/test_graph_free_eager_tensor.py | 12 ++++-------- .../torch_flow_dual_object.py | 19 +++++++++++++++---- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp index c9855fabaf0..1d851718628 100644 --- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp @@ -236,7 +236,6 @@ Maybe AddFreeEagerTensorToVariableOp(const std::shared_ptr& input_ return Maybe::Ok(); } - Maybe LazyInterpreter::ApplyImpl(const FeedInputOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { // NOTE(chengcheng): inputs[0] is the EagerTensor @@ -693,7 +692,8 @@ Maybe LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu (*outputs)[i] = JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, is_local)); } else { - VLOG(2) << "Lazy nn.Graph name " << graph_name << " op name " << new_op_name << " run with inplace."; + VLOG(2) << "Lazy nn.Graph name " << graph_name << " op name " << new_op_name + << " run with inplace."; const std::shared_ptr& inplace_out = (*outputs)[i]; JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, is_local)); } diff --git a/python/oneflow/test/graph/test_graph_free_eager_tensor.py b/python/oneflow/test/graph/test_graph_free_eager_tensor.py index a481664d7a0..07a3f17006a 100644 --- a/python/oneflow/test/graph/test_graph_free_eager_tensor.py +++ b/python/oneflow/test/graph/test_graph_free_eager_tensor.py @@ -165,22 +165,18 @@ def build(self): # x in ouput changed # So nn.Graph simulate inplace in nn.Graph.build(). test_case.assertTrue( - np.allclose(ret_eager_out.numpy(), np_x*2, atol=1e-4, rtol=1e-4) + np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4) ) # x has not changed # So nn.Graph inplace will not change free eager tensor. - test_case.assertTrue( - np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4) - ) + test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4)) # Run second time ret_eager_out = g_return_eager() test_case.assertTrue( - np.allclose(ret_eager_out.numpy(), np_x*2, atol=1e-4, rtol=1e-4) - ) - test_case.assertTrue( - np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4) + np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4) ) + test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4)) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index f6ed25588ef..7761facd6a3 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -380,16 +380,27 @@ def build(self): else: pass if verbose: - print("Run graph of function: ", repr(oneflow), ", graph check is intentionally skiped.") - elif (oneflow.__name__ == "Parameter"): + print( + "Run graph of function: ", + repr(oneflow), + ", graph check is intentionally skiped.", + ) + elif oneflow.__name__ == "Parameter": # nn.Graph donot deal with Parameter creation. test_g_res = oneflow_res if verbose: - print("Run graph of function: ", repr(oneflow), ", graph check is intentionally skiped.") + print( + "Run graph of function: ", + repr(oneflow), + ", graph check is intentionally skiped.", + ) else: test_g = TestGraphOfFunctional() if verbose: - print("Run graph of function: ", repr(oneflow)) + print( + "Run graph of function: ", + repr(oneflow), + ) test_g.debug(3) test_g_res = test_g() except Exception as e: