diff --git a/oneflow/api/python/functional/tensor_api.cpp b/oneflow/api/python/functional/tensor_api.cpp index 55a24df962f..bbe92965734 100644 --- a/oneflow/api/python/functional/tensor_api.cpp +++ b/oneflow/api/python/functional/tensor_api.cpp @@ -54,6 +54,7 @@ class TensorWithDataFunctor { // its a eager tensor by Run functional::Empty() in LazyMode::Grad(false) LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST( functional::GlobalTensorWithData(data, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), requires_grad)); diff --git a/oneflow/core/framework/op_interpreter.h b/oneflow/core/framework/op_interpreter.h index d56fd3ad8c2..b8d8578e0cd 100644 --- a/oneflow/core/framework/op_interpreter.h +++ b/oneflow/core/framework/op_interpreter.h @@ -123,7 +123,7 @@ class LazyInterpreter : public OpExprInterpreter { class EagerInterpreter : public OpExprInterpreter { public: - EagerInterpreter() : OpExprInterpreter() {} + EagerInterpreter(bool is_local) : OpExprInterpreter(), is_local_(is_local) {} virtual ~EagerInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, @@ -134,6 +134,12 @@ class EagerInterpreter : public OpExprInterpreter { Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const override; + protected: + // NOTE(lixiang): To ensure the correctness of GlobalMode, record whether this is a local operation; + // it is initialized to true when using EagerLocalInterpreter. + // Used by EagerInterpreter::Apply.
+ bool is_local_; + private: FOR_EACH_BUILTIN_OPS(DECLARE_PURE_VIRTUAL_APPLY_FUNC); DECLARE_NORMAL_APPLY_FUNC(FunctionOp); @@ -141,7 +147,7 @@ class EagerInterpreter : public OpExprInterpreter { class EagerGlobalInterpreter : public EagerInterpreter { public: - EagerGlobalInterpreter() : EagerInterpreter() {} + EagerGlobalInterpreter() : EagerInterpreter(false) {} virtual ~EagerGlobalInterpreter() = default; private: @@ -150,7 +156,7 @@ class EagerGlobalInterpreter : public EagerInterpreter { class EagerLocalInterpreter : public EagerInterpreter { public: - EagerLocalInterpreter() : EagerInterpreter() {} + EagerLocalInterpreter() : EagerInterpreter(true) {} virtual ~EagerLocalInterpreter() = default; private: diff --git a/oneflow/core/framework/op_interpreter/op_interpreter.cpp b/oneflow/core/framework/op_interpreter/op_interpreter.cpp index 3dfdd60c564..8e85e5049d0 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter.cpp @@ -50,6 +50,13 @@ Maybe LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inp Maybe EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { + // In the op interpreter, decide whether global mode should stay enabled, to avoid recursion + // caused by GlobalMode. + // Global mode stays enabled only if it was already enabled and the current operation is a local + // operation.
+ auto global_mode_guard = GlobalMode::Guard(GlobalMode::is_enabled() && is_local_, GlobalMode::nd_sbp(), GlobalMode::parallel_desc()); + #define APPLY_IF(op_type) \ if (const auto* op = dynamic_cast(&op_expr)) { \ return ApplyImpl(*op, inputs, outputs, ctx); \ diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index e745a88e504..7c019e6ccd6 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -623,6 +623,7 @@ class GlobalTensor final : public TensorIf { Maybe> parallel_desc() const override { return impl_->parallel_desc(); } Maybe> device() const override { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); const auto& device_tag = JUST(parallel_desc())->device_tag(); return JUST(Device::New(device_tag)); } diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 63378cc1742..ef9f26a4728 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -168,6 +168,7 @@ class TensorConstantFunctor { // NOTE: this op is an source op, so the value(scalar tensor) should not have autograd status.
autograd::AutoGradMode mode(false); if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalTensorConstant(shape, value, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); @@ -251,6 +252,7 @@ class ConstantFunctor { Maybe operator()(const Shape& shape, const Scalar& value, const Symbol& dtype, const Optional>& device) const { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalConstant(shape, value, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); @@ -288,6 +290,7 @@ class EmptyFunctor { const bool pin_memory) const { std::shared_ptr empty; if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); empty = JUST(functional::GlobalEmpty(shape, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); if (dtype->is_floating_point()) { JUST(empty->set_requires_grad(requires_grad)); } diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index fcd1095785d..580a4a65526 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -1453,6 +1453,7 @@ class ArangeFunctor { const Optional>& dtype, const Optional>& device) const { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalArange(start, limit, delta, dtype, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())))); @@ -1557,6 +1558,7 @@ class HannWindowFunctor { const Optional>& device, const Optional>& dtype, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalHannWindow( window_length, periodic, GetGlobalParallelDescFromDevice(device), 
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, requires_grad)); diff --git a/oneflow/core/functional/impl/random_functor.cpp b/oneflow/core/functional/impl/random_functor.cpp index 903cff6fd14..4e532a65bc3 100644 --- a/oneflow/core/functional/impl/random_functor.cpp +++ b/oneflow/core/functional/impl/random_functor.cpp @@ -188,6 +188,7 @@ class RandFunctor { const Optional& generator, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalRand(shape, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad)); @@ -264,6 +265,7 @@ class RandNFunctor { const Optional& generator, const bool& requires_grad, const Symbol& layout) const { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalRandN(shape, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad)); @@ -476,6 +478,7 @@ class RandIntFunctor { const Optional& generator, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalRandInt( low, high, shape, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad)); @@ -618,6 +621,7 @@ class RandPermFunctor { const Symbol& dtype, const Optional>& device, const bool& requires_grad) const { if (GlobalMode::is_enabled()) { + auto global_mode_guard = GlobalMode::Guard(false); return JUST(functional::GlobalRandPerm(n, GetGlobalParallelDescFromDevice(device), *JUST(GetSbpList(GlobalMode::nd_sbp())), generator, dtype, requires_grad)); diff --git a/python/oneflow/test/graph/test_graph_with_global.py b/python/oneflow/test/graph/test_graph_with_global.py index 47a81f46967..de0ec2d0d25 100644 --- a/python/oneflow/test/graph/test_graph_with_global.py +++ 
b/python/oneflow/test/graph/test_graph_with_global.py @@ -269,6 +269,24 @@ def build(self): test_case.assertEqual(v.sbp[0], B, k) +def _test_global_mode_with_default_placement_and_sbp(test_case): + # create a tensor with broadcast sbp and placement on rank 0 + a = flow.randn( + (1, 8), sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0]) + ) + # enter global mode with broadcast sbp and placement on 2 GPUs + with global_mode( + True, + placement=flow.placement(type="cuda", ranks=[0, 1]), + sbp=flow.sbp.broadcast, + ): + # check that the tensor keeps its own placement and sbp under global mode + test_case.assertTrue(a.placement == flow.placement("cuda", ranks=[0])) + test_case.assertTrue(a.sbp == (flow.sbp.broadcast,)) + # check that printing the tensor works under global mode + print(a) + + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestLinearTrainGraphWithDDP(oneflow.unittest.TestCase): @@ -277,6 +295,7 @@ def test_linear_train_graph_with_ddp(test_case): def test_global_mode(test_case): _test_global_mode(test_case) + _test_global_mode_with_default_placement_and_sbp(test_case) if __name__ == "__main__":