From 877dc48cc24e95681a17ee7b34b2ab3c7a07f3e0 Mon Sep 17 00:00:00 2001
From: lixiang <88304454@qq.com>
Date: Mon, 17 Jan 2022 18:33:15 +0800
Subject: [PATCH 1/6] Fix autotest inplace bug, hardsigmoid

---
 .../op_interpreter/lazy_op_interpreter.cpp    | 105 ++++++++++--------
 python/oneflow/framework/graph_build_util.py  |   1 -
 .../graph/test_graph_free_eager_tensor.py     |  63 +++++++++++
 .../oneflow/test/modules/test_activation.py   |   2 +-
 python/oneflow/test/modules/test_flatten.py   |   2 +-
 python/oneflow/test/tensor/test_parameter.py  |   6 +-
 .../torch_flow_dual_object.py                 |  37 +++++-
 7 files changed, 163 insertions(+), 53 deletions(-)

diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
index 151d11a1d9e..1d851718628 100644
--- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include 
 #include "oneflow/core/common/cpp_attribute.h"
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/cpp_attribute.h"
@@ -190,6 +191,51 @@ Maybe<Scope> NewScopeWithParallelDescByTensor(const std::shared_ptr<Tensor>& ten
   return NewScopeWithParallelConfAndCurScope(parallel_conf);
 }
 
+Maybe<void> AddFreeEagerTensorToVariableOp(const std::shared_ptr<Tensor>& input_tensor) {
+  CHECK_OR_RETURN(input_tensor->is_eager());
+  const std::string& empty_lbn = TensorNameScope::Global()->Lookup(input_tensor);
+  CHECK_OR_RETURN(empty_lbn.empty());
+  std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(input_tensor));
+  OperatorConf op_conf;
+  op_conf.set_scope_symbol_id(JUST(scope->symbol_id()));
+  op_conf.set_device_tag(GetDeviceTagOfTensor(input_tensor));
+  VariableOpConf* var_conf = op_conf.mutable_variable_conf();
+  var_conf->set_out("out");
+  input_tensor->shape()->ToProto(var_conf->mutable_shape());
+  var_conf->set_data_type(input_tensor->dtype()->data_type());
+  // NOTE(chengcheng): VariableOpConf initializer_conf is useless because the variable is
+  // initialized by the EagerTensor.
+  var_conf->mutable_initializer()->mutable_empty_conf();
+  JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor));
+  // NOTE(chengcheng): A free EagerTensor is not trainable.
+  var_conf->set_trainable(false);
+
+  auto infer_ctx = JUST(GetCurInferCtx());
+  // NOTE(chengcheng): MUST reset a unique op name before InferCtx::AddOp; a FreeEagerTensor has
+  // no name, so create a unique name for it.
+  const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(op_conf));
+  op_conf.set_name(new_op_name);
+
+  VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n"
+          << op_conf.DebugString() << std::endl;
+  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferConsistentOp(op_conf));
+  VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n"
+          << op_conf.DebugString() << " for FreeEagerTensor.\n";
+  VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name()
+          << " infer and add op attr : \n"
+          << op_attr.DebugString() << " for FreeEagerTensor.\n";
+
+  // NOTE(chengcheng): MUST store this tensor to MultiClientSessionContext for graph runtime bind.
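+  // The tensor stored below is looked up again when the nn.Graph runtime is built, so the
+  // variable op created above is bound to the live eager tensor instead of reading its
+  // (empty) initializer_conf.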
+  const std::string graph_name = *JUST(JUST(GlobalJobBuildAndInferCtxMgr())->GetCurrentJobName());
+  const std::string lbn = GenLogicalBlobName(new_op_name, "out");
+  Global<MultiClientSessionContext>::Get()->StoreFreeEagerTensorWithNameByGraphName(
+      graph_name, input_tensor, new_op_name);
+  // NOTE(chengcheng): MUST record this eager_tensor name as the new variable output lbn.
+  TensorNameScope::Global()->Record(input_tensor, lbn);
+
+  return Maybe<void>::Ok();
+}
+
 Maybe<void> LazyInterpreter::ApplyImpl(const FeedInputOpExpr& op_expr, const TensorTuple& inputs,
                                        TensorTuple* outputs, const OpExprInterpContext& ctx) const {
   // NOTE(chengcheng): inputs[0] is the EagerTensor
@@ -308,8 +354,16 @@ Maybe<void> LazyInterpreter::ApplyImpl(const FetchOutputOpExpr& op_expr, const T
   CHECK_EQ_OR_RETURN(inputs.size(), 1);
   CHECK_EQ_OR_RETURN(op_expr.input_size(), 1);
   const std::shared_ptr<Tensor>& input_tensor = inputs.at(0);
-  CHECK_OR_RETURN(input_tensor->is_lazy());
-  const std::string& input_lbn = TensorNameScope::Global()->Lookup(input_tensor);
+  std::string input_lbn = TensorNameScope::Global()->Lookup(input_tensor);
+  // A lazy tensor must have an lbn.
+  // An eager tensor may have an lbn if it has already been treated as an output of a variable op
+  // or as an output of an inplace op.
+  if (input_lbn.empty()) {
+    CHECK_OR_RETURN(input_tensor->is_eager());
+    // This output tensor is a new free eager tensor, so treat it as a new variable op output.
+    JUST(AddFreeEagerTensorToVariableOp(input_tensor));
+    input_lbn = TensorNameScope::Global()->Lookup(input_tensor);
+  }
   CHECK_OR_RETURN(!input_lbn.empty());  // lbn must exist.
 
   std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(input_tensor));
@@ -490,51 +544,6 @@ Maybe<void> LazyInterpreterApplyImplForSourceUserOpExpr(const UserOpExpr& op_exp
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AddFreeEagerTensorToVariableOp(const std::shared_ptr<Tensor>& input_tensor) {
-  CHECK_OR_RETURN(input_tensor->is_eager());
-  const std::string& empty_lbn = TensorNameScope::Global()->Lookup(input_tensor);
-  CHECK_OR_RETURN(empty_lbn.empty());
-  std::shared_ptr<Scope> scope = JUST(NewScopeWithParallelDescByTensor(input_tensor));
-  OperatorConf op_conf;
-  op_conf.set_scope_symbol_id(JUST(scope->symbol_id()));
-  op_conf.set_device_tag(GetDeviceTagOfTensor(input_tensor));
-  VariableOpConf* var_conf = op_conf.mutable_variable_conf();
-  var_conf->set_out("out");
-  input_tensor->shape()->ToProto(var_conf->mutable_shape());
-  var_conf->set_data_type(input_tensor->dtype()->data_type());
-  // NOTE(chengcheng): VariableOpConf initializer_conf is useless because the variable is
-  // initialized by the EagerTensor.
-  var_conf->mutable_initializer()->mutable_empty_conf();
-  JUST(GenVariableOpConfNdSbpStringByTensor(var_conf, input_tensor));
-  // NOTE(chengcheng): A free EagerTensor is not trainable.
-  var_conf->set_trainable(false);
-
-  auto infer_ctx = JUST(GetCurInferCtx());
-  // NOTE(chengcheng): MUST reset a unique op name before InferCtx::AddOp; a FreeEagerTensor has
-  // no name, so create a unique name for it.
-  const std::string new_op_name = *JUST(infer_ctx->NewUniqueOpNameByFunctionalOpConf(op_conf));
-  op_conf.set_name(new_op_name);
-
-  VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " try to add op: \n"
-          << op_conf.DebugString() << std::endl;
-  OpAttribute op_attr = *JUST(infer_ctx->AddAndInferConsistentOp(op_conf));
-  VLOG(2) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name() << " add op : \n"
-          << op_conf.DebugString() << " for FreeEagerTensor.\n";
-  VLOG(3) << "Lazy nn.Graph name " << infer_ctx->job().job_conf().job_name()
-          << " infer and add op attr : \n"
-          << op_attr.DebugString() << " for FreeEagerTensor.\n";
-
-  // NOTE(chengcheng): MUST store this tensor to MultiClientSessionContext for graph runtime bind.
-  const std::string graph_name = *JUST(JUST(GlobalJobBuildAndInferCtxMgr())->GetCurrentJobName());
-  const std::string lbn = GenLogicalBlobName(new_op_name, "out");
-  Global<MultiClientSessionContext>::Get()->StoreFreeEagerTensorWithNameByGraphName(
-      graph_name, input_tensor, new_op_name);
-  // NOTE(chengcheng): MUST record this eager_tensor name as the new variable output lbn.
-  TensorNameScope::Global()->Record(input_tensor, lbn);
-
-  return Maybe<void>::Ok();
-}
-
 Maybe<void> LazyInterpreterApplyImplForCopyUserOpExpr(const UserOpExpr& op_expr,
                                                       const TensorTuple& inputs,
                                                       TensorTuple* outputs,
@@ -683,6 +692,8 @@ Maybe<void> LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu
       (*outputs)[i] =
           JUST(BuildTensor(op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, is_local));
     } else {
+      VLOG(2) << "Lazy nn.Graph name " << graph_name << " op name " << new_op_name
+              << " run with inplace.";
       const std::shared_ptr<Tensor>& inplace_out = (*outputs)[i];
       JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, is_local));
     }
diff --git a/python/oneflow/framework/graph_build_util.py b/python/oneflow/framework/graph_build_util.py
index bd2b357375b..c145fe9d1b6 100644
--- a/python/oneflow/framework/graph_build_util.py
+++ b/python/oneflow/framework/graph_build_util.py
@@ -184,7 +184,6 @@ def build_graph_state(op_name, state_tensor, state_config):
 
 def build_graph_output(op_name, out):
     assert isinstance(out, Tensor)
-    assert out.is_lazy
 
     output_conf = (
         oneflow._oneflow_internal.oneflow.core.operator.op_conf.FetchOutputOpConf()
diff --git a/python/oneflow/test/graph/test_graph_free_eager_tensor.py b/python/oneflow/test/graph/test_graph_free_eager_tensor.py
index f4890daa61c..07a3f17006a 100644
--- a/python/oneflow/test/graph/test_graph_free_eager_tensor.py
+++ b/python/oneflow/test/graph/test_graph_free_eager_tensor.py
@@ -115,6 +115,69 @@ def build(self):
             np.allclose(mul_out.numpy(), np_x * np_y, atol=1e-4, rtol=1e-4)
         )
 
+    def test_graph_return_free_eager_tensor(test_case):
+        np_x = np.random.randn(5, 3)
+        x = flow.tensor(np_x, dtype=flow.float32)
+
+        class GraphReturnEager(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+
+            def build(self):
+                # Return a free eager tensor
+                return x
+
+        g_return_eager = GraphReturnEager()
+
+        # Run first time
+        ret_eager_out = g_return_eager()
+        test_case.assertTrue(
+            np.allclose(ret_eager_out.numpy(), np_x, atol=1e-4, rtol=1e-4)
+        )
+
+        # Run second time
+        ret_eager_out1 = g_return_eager()
+        test_case.assertTrue(
+            np.allclose(ret_eager_out1.numpy(), np_x, atol=1e-4, rtol=1e-4)
+        )
+
+    def test_graph_return_inplace_free_eager_tensor(test_case):
+        np_x = np.random.randn(5, 3)
+        x = flow.tensor(np_x, dtype=flow.float32)
+
+        class GraphInplaceReturnEager(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+
+            def build(self):
+                # x is a free eager tensor.
+                # mul_ is an inplace scalar mul.
+                # The input and output of mul_ are both tensor x.
+                # After the lazy interpreter runs, tensor x's name will be the output lbn of mul_.
+                x.mul_(2)
+                # This returns the output of mul_.
+                return x
+
+        g_return_eager = GraphInplaceReturnEager()
+
+        # Run first time
+        ret_eager_out = g_return_eager()
+        # x in the output has changed,
+        # so nn.Graph simulates inplace in nn.Graph.build().
+        test_case.assertTrue(
+            np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4)
+        )
+        # x itself has not changed,
+        # so inplace in nn.Graph will not change the free eager tensor.
+        test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4))
+
+        # Run second time
+        ret_eager_out = g_return_eager()
+        test_case.assertTrue(
+            np.allclose(ret_eager_out.numpy(), np_x * 2, atol=1e-4, rtol=1e-4)
+        )
+        test_case.assertTrue(np.allclose(x.numpy(), np_x, atol=1e-4, rtol=1e-4))
+
 
 @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
 @flow.unittest.skip_unless_1n2d()
diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py
index bae166eb9ae..1c4df68e06b 100644
--- a/python/oneflow/test/modules/test_activation.py
+++ b/python/oneflow/test/modules/test_activation.py
@@ -256,7 +256,7 @@ def test_hardsigmoid_module_with_random_data(test_case):
         y = m(x)
         return y
 
-    @autotest(check_graph=False)
+    @autotest(check_graph=True)
     def test_functional_hardsigmoid_with_random_data(test_case):
         device = random_device()
         x = random_pytorch_tensor().to(device)
diff --git a/python/oneflow/test/modules/test_flatten.py b/python/oneflow/test/modules/test_flatten.py
index 91145475b9f..2d69caf8131 100644
--- a/python/oneflow/test/modules/test_flatten.py
+++ b/python/oneflow/test/modules/test_flatten.py
@@ -79,7 +79,7 @@ def test_flatten_module_with_random_data(test_case):
         y = m(x)
         return y
 
-    @autotest(check_graph=False)
+    @autotest(check_graph=True)
     def test_flatten_with_random_data(test_case):
         device = random_device()
         x = random_pytorch_tensor().to(device)
diff --git a/python/oneflow/test/tensor/test_parameter.py b/python/oneflow/test/tensor/test_parameter.py
index b8d63b72d6e..8c9284eb562 100644
--- a/python/oneflow/test/tensor/test_parameter.py
+++ b/python/oneflow/test/tensor/test_parameter.py
@@ -40,7 +40,11 @@ def test_parameter_set_data_autograd_meta(test_case):
         z.data = y
         return z.grad_fn, z.is_leaf
 
-    @autotest(n=1, check_graph=False)
+    # The graph check is disabled for 2 reasons.
+    # Reason 1: x.data returns a new tensor that shares storage with the original tensor; this is not handled well in nn.Graph.
+    # Reason 2: the inplace operation mul_ works inside nn.Graph but will not change the value of the free eager tensor.
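+    # For example, `x.data.mul_(2)` inside nn.Graph.build() doubles the value seen by the
+    # graph outputs, while the free eager tensor behind x keeps its original value.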
+    # Please refer to the test case test_graph_return_inplace_free_eager_tensor.
+    @autotest(n=1, check_graph="ValidatedFalse")
     def test_parameter_inplace_modify_data(test_case):
         x = torch.nn.Parameter(torch.ones(2, 3))
         x.data.mul_(2)
diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py
index 94ae59ab922..50bc3bc8461 100644
--- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py
+++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py
@@ -18,6 +18,7 @@ import inspect
 import os
 import warnings
+import copy
 
 import numpy as np
 import oneflow as flow
@@ -317,6 +318,10 @@ def dual_method(self, *args, **kwargs):
             if name in postulate:
                 oneflow_res = torch_tensor_to_flow(pytorch_res)
             else:
+                graph_args = []
+                for arg in oneflow_args:
+                    copy_arg = copy.deepcopy(arg)
+                    graph_args.append(copy_arg)
                 oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs)
             if testing_graph:
                 find_check_module_func = True
@@ -335,7 +340,7 @@ def build(self, *args):
                     test_g = TestGraphOfModule()
                     if verbose:
                         print("Run graph of module: ", repr(oneflow))
-                        test_g.debug(2)
+                        test_g.debug(3)
                     test_g_res = test_g(*oneflow_args)
                 elif oneflow.__name__ in ignore_apis_list:
                     find_check_module_func = False
@@ -357,7 +362,7 @@ def __init__(self):
 
                         def build(self):
                             return oneflow(
-                                *oneflow_args, **oneflow_kwargs
+                                *graph_args, **oneflow_kwargs
                             )
 
                     try:
@@ -379,8 +384,29 @@ def build(self):
                             test_g_res = oneflow_res
                         else:
                             pass
+                        if verbose:
+                            print(
+                                "Run graph of function: ",
+                                repr(oneflow),
+                                ", graph check is intentionally skipped.",
+                            )
+                    elif oneflow.__name__ == "Parameter":
+                        # nn.Graph does not deal with Parameter creation.
+                        test_g_res = oneflow_res
+                        if verbose:
+                            print(
+                                "Run graph of function: ",
+                                repr(oneflow),
+                                ", graph check is intentionally skipped.",
+                            )
                     else:
                         test_g = TestGraphOfFunctional()
+                        if verbose:
+                            print(
+                                "Run graph of function: ",
+                                repr(oneflow),
+                            )
+                            test_g.debug(3)
                         test_g_res = test_g()
             except Exception as e:
                 print_note_fake_program()
@@ -439,6 +465,9 @@ def build(self):
 
             try:
                 test_g = TestGraphOfTensorMethod()
+                if verbose:
+                    print("Run graph of method: ", repr(oneflow))
+                    test_g.debug(3)
                 test_g_res = test_g()
             except Exception as e:
                 print_note_fake_program()
@@ -676,6 +705,10 @@ def autotest(
 ):
     verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None
 
+    if check_graph == "ValidatedFalse":
+        # The graph check is intentionally disabled, and there is a validated reason.
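+        # e.g. a test may declare @autotest(n=1, check_graph="ValidatedFalse") to record
+        # that its graph check is turned off for a documented reason.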
+ check_graph = False + def deco(f): @functools.wraps(f) def new_f(test_case): From a86e78085eee7042ac9cd30e11068eef3e7fcf5e Mon Sep 17 00:00:00 2001 From: lixiang <88304454@qq.com> Date: Tue, 18 Jan 2022 15:56:01 +0800 Subject: [PATCH 2/6] Fix --- docs/source/oneflow.rst | 1 + python/oneflow/__init__.py | 1 + python/oneflow/nn/modules/is_tensor.py | 42 +++++++++++++++++++ .../torch_flow_dual_object.py | 20 +++++++-- 4 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 python/oneflow/nn/modules/is_tensor.py diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index 1a9133ac372..1301c98c04d 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -147,6 +147,7 @@ oneflow zeros, zeros_like, is_nonzero, + is_tensor, no_grad, grad_enable, inference_mode, diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 055f3f96f4b..761ec851fa8 100755 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -296,6 +296,7 @@ def atexit_hook(hook): adaptive_avg_pool2d, adaptive_avg_pool3d, ) +from oneflow.nn.modules.is_tensor import is_tensor_op as is_tensor from oneflow.nn.modules.arange import arange_op as arange from oneflow.nn.modules.linspace import linspace_op as linspace from oneflow.nn.modules.argsort import argsort_op as argsort diff --git a/python/oneflow/nn/modules/is_tensor.py b/python/oneflow/nn/modules/is_tensor.py new file mode 100644 index 00000000000..13b2f0b4a87 --- /dev/null +++ b/python/oneflow/nn/modules/is_tensor.py @@ -0,0 +1,42 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import oneflow as flow + +def is_tensor_op(obj): + r""" + is_tensor(input) -> (bool) + + Note that this function is simply doing ``isinstance(obj, Tensor)``. + Using that ``isinstance`` check is better for typechecking with mypy, + and more explicit - so it's recommended to use that instead of + ``is_tensor``. + + Args: + obj (Object): Object to test + + For example: + + .. 
code-block:: python + + >>> import oneflow as flow + + >>> x=flow.tensor([1,2,3]) + >>> flow.is_tensor(x) + True + + """ + return isinstance(obj,flow.Tensor) diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 50bc3bc8461..400177bbd2c 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -320,8 +320,12 @@ def dual_method(self, *args, **kwargs): else: graph_args = [] for arg in oneflow_args: - copy_arg = copy.deepcopy(arg) + if flow.is_tensor(arg): + copy_arg=arg.clone() + else: + copy_arg = copy.deepcopy(arg) graph_args.append(copy_arg) + graph_kwargs=copy.deepcopy(oneflow_kwargs) oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) if testing_graph: find_check_module_func = True @@ -341,7 +345,7 @@ def build(self, *args): if verbose: print("Run graph of module: ", repr(oneflow)) test_g.debug(3) - test_g_res = test_g(*oneflow_args) + test_g_res = test_g(*graph_args) elif oneflow.__name__ in ignore_apis_list: find_check_module_func = False # 1. "oneflow.nn.modules" not in oneflow.__module__: For avoid run nn.Module branch graph test, like fold op call Fold Module actually. @@ -362,7 +366,7 @@ def __init__(self): def build(self): return oneflow( - *graph_args, **oneflow_kwargs + *graph_args, **graph_kwargs ) try: @@ -451,6 +455,14 @@ def dual_method(self, *args, **kwargs): "PyTorch has an error but OneFlow is ok, maybe you should check your implementation to align with PyTorch." ) raise PyTorchDoesNotSupportError(e) + tensor_graph_args = [] + for arg in oneflow_args: + if flow.is_tensor(arg): + copy_arg=arg.clone() + else: + copy_arg = copy.deepcopy(arg) + tensor_graph_args.append(copy_arg) + tensor_graph_kwargs=copy.deepcopy(oneflow_kwargs) oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs) if testing_graph: @@ -460,7 +472,7 @@ def __init__(self): def build(self): return oneflow_method( - *oneflow_args, **oneflow_kwargs + *tensor_graph_args, **tensor_graph_kwargs ) try: From 35bac2dade7f1f7bb645053b9609ddb1c727ba94 Mon Sep 17 00:00:00 2001 From: lixiang <88304454@qq.com> Date: Tue, 18 Jan 2022 15:57:39 +0800 Subject: [PATCH 3/6] Format --- python/oneflow/nn/modules/is_tensor.py | 3 ++- .../torch_flow_dual_object.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/oneflow/nn/modules/is_tensor.py b/python/oneflow/nn/modules/is_tensor.py index 13b2f0b4a87..49dacaffed2 100644 --- a/python/oneflow/nn/modules/is_tensor.py +++ b/python/oneflow/nn/modules/is_tensor.py @@ -16,6 +16,7 @@ import oneflow as flow + def is_tensor_op(obj): r""" is_tensor(input) -> (bool) @@ -39,4 +40,4 @@ def is_tensor_op(obj): True """ - return isinstance(obj,flow.Tensor) + return isinstance(obj, flow.Tensor) diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 400177bbd2c..dc4c07f829e 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -321,11 +321,11 @@ def dual_method(self, *args, **kwargs): graph_args = [] for arg in oneflow_args: if flow.is_tensor(arg): - copy_arg=arg.clone() - else: + copy_arg = arg.clone() + else: copy_arg = copy.deepcopy(arg) graph_args.append(copy_arg) - graph_kwargs=copy.deepcopy(oneflow_kwargs) + 
graph_kwargs = copy.deepcopy(oneflow_kwargs) oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) if testing_graph: find_check_module_func = True @@ -365,9 +365,7 @@ def __init__(self): super().__init__() def build(self): - return oneflow( - *graph_args, **graph_kwargs - ) + return oneflow(*graph_args, **graph_kwargs) try: # When the tensor on the cpu executes to to the cpu in nn.Graph, a check error will be reported. @@ -458,11 +456,11 @@ def dual_method(self, *args, **kwargs): tensor_graph_args = [] for arg in oneflow_args: if flow.is_tensor(arg): - copy_arg=arg.clone() - else: + copy_arg = arg.clone() + else: copy_arg = copy.deepcopy(arg) tensor_graph_args.append(copy_arg) - tensor_graph_kwargs=copy.deepcopy(oneflow_kwargs) + tensor_graph_kwargs = copy.deepcopy(oneflow_kwargs) oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs) if testing_graph: From 5e95e1535e1b54b8e360d657ebcefbebf9c9915e Mon Sep 17 00:00:00 2001 From: lixiang <88304454@qq.com> Date: Tue, 18 Jan 2022 16:32:41 +0800 Subject: [PATCH 4/6] Fix --- .../test_utils/automated_test_util/torch_flow_dual_object.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index dc4c07f829e..1716642372f 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -321,7 +321,7 @@ def dual_method(self, *args, **kwargs): graph_args = [] for arg in oneflow_args: if flow.is_tensor(arg): - copy_arg = arg.clone() + copy_arg = arg.clone().detach() else: copy_arg = copy.deepcopy(arg) graph_args.append(copy_arg) @@ -345,6 +345,7 @@ def build(self, *args): if verbose: print("Run graph of module: ", repr(oneflow)) test_g.debug(3) + # When testing module methods, kwargs are not considered. 
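+ # graph_args holds detached clones of the eager inputs, so the graph run below is not
+ # affected by any inplace mutation the eager call above may already have made.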
test_g_res = test_g(*graph_args) elif oneflow.__name__ in ignore_apis_list: find_check_module_func = False @@ -456,7 +457,7 @@ def dual_method(self, *args, **kwargs): tensor_graph_args = [] for arg in oneflow_args: if flow.is_tensor(arg): - copy_arg = arg.clone() + copy_arg = arg.clone().detach() else: copy_arg = copy.deepcopy(arg) tensor_graph_args.append(copy_arg) From 967715e9fa3060eaa1d4dd3be4f4b6fd8631d58c Mon Sep 17 00:00:00 2001 From: lixiang <88304454@qq.com> Date: Tue, 18 Jan 2022 17:25:39 +0800 Subject: [PATCH 5/6] Fix kwargs --- .../automated_test_util/torch_flow_dual_object.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 1716642372f..c90a00d9579 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -325,7 +325,12 @@ def dual_method(self, *args, **kwargs): else: copy_arg = copy.deepcopy(arg) graph_args.append(copy_arg) - graph_kwargs = copy.deepcopy(oneflow_kwargs) + graph_kwargs={} + for key,value in oneflow_kwargs.items(): + if flow.is_tensor(value): + graph_kwargs[key]=value.clone().detach() + else: + graph_kwargs[key]=copy.deepcopy(value) oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) if testing_graph: find_check_module_func = True @@ -461,7 +466,12 @@ def dual_method(self, *args, **kwargs): else: copy_arg = copy.deepcopy(arg) tensor_graph_args.append(copy_arg) - tensor_graph_kwargs = copy.deepcopy(oneflow_kwargs) + tensor_graph_kwargs={} + for key,value in oneflow_kwargs.items(): + if flow.is_tensor(value): + tensor_graph_kwargs[key]=value.clone().detach() + else: + tensor_graph_kwargs[key]=copy.deepcopy(value) oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs) if testing_graph: From 9c2305fd1372937dec98b550e9e5e5ee05ddf450 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Tue, 18 Jan 2022 10:38:39 +0000 Subject: [PATCH 6/6] auto format by CI --- .../torch_flow_dual_object.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index c90a00d9579..23a6baecb47 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -325,12 +325,12 @@ def dual_method(self, *args, **kwargs): else: copy_arg = copy.deepcopy(arg) graph_args.append(copy_arg) - graph_kwargs={} - for key,value in oneflow_kwargs.items(): + graph_kwargs = {} + for key, value in oneflow_kwargs.items(): if flow.is_tensor(value): - graph_kwargs[key]=value.clone().detach() + graph_kwargs[key] = value.clone().detach() else: - graph_kwargs[key]=copy.deepcopy(value) + graph_kwargs[key] = copy.deepcopy(value) oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) if testing_graph: find_check_module_func = True @@ -466,12 +466,12 @@ def dual_method(self, *args, **kwargs): else: copy_arg = copy.deepcopy(arg) tensor_graph_args.append(copy_arg) - tensor_graph_kwargs={} - for key,value in oneflow_kwargs.items(): + tensor_graph_kwargs = {} + for key, value in oneflow_kwargs.items(): if flow.is_tensor(value): - tensor_graph_kwargs[key]=value.clone().detach() + tensor_graph_kwargs[key] = value.clone().detach() else: - 
tensor_graph_kwargs[key]=copy.deepcopy(value) + tensor_graph_kwargs[key] = copy.deepcopy(value) oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs) if testing_graph:
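Note on the argument-snapshot pattern the commits above converge on: before the eager call
runs, tensor arguments are cloned and detached and all other arguments are deep-copied; the
nn.Graph run then replays the op on those pristine copies, so an inplace op in the eager run
cannot make the two results diverge. A minimal standalone sketch of the idea (the helper
name snapshot_args is ours for illustration, not part of the patch):

    import copy

    import oneflow as flow

    def snapshot_args(args, kwargs):
        # Tensors: clone().detach() copies the storage and drops autograd history,
        # so later inplace ops on the originals cannot leak into the snapshot.
        # Non-tensors: fall back to deepcopy, since they may also be mutated.
        def snap(v):
            return v.clone().detach() if flow.is_tensor(v) else copy.deepcopy(v)

        return [snap(a) for a in args], {k: snap(v) for k, v in kwargs.items()}

    x = flow.ones(2, 3)
    (gx,), _ = snapshot_args((x,), {})
    x.mul_(2)  # inplace change to the eager tensor
    assert gx.numpy().mean() == 1.0  # the snapshot still holds the original values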