Fix global mode recursive call #10056

Merged 18 commits on Mar 29, 2023
Changes from 14 commits
1 change: 1 addition & 0 deletions oneflow/api/python/functional/tensor_api.cpp
@@ -54,6 +54,7 @@ class TensorWithDataFunctor {
// it's an eager tensor created by running functional::Empty() in LazyMode::Guard(false)
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(
functional::GlobalTensorWithData(data, dtype, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), requires_grad));
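The hunk above shows the recurring fix in this PR: when global mode is on, disable it for the scope that forwards to the Global* variant. Below is a minimal, self-contained sketch of why the scoped disable is needed; `GlobalMode`, `CreateTensor`, and `CreateGlobalTensor` here are simplified stand-ins for illustration, not OneFlow's real API.

#include <iostream>

// Stand-in for OneFlow's thread-local GlobalMode (simplified).
struct GlobalMode {
  static bool& flag() {
    static thread_local bool enabled = false;
    return enabled;
  }
  static bool is_enabled() { return flag(); }
  // RAII guard: sets the flag on construction, restores it on destruction.
  struct Guard {
    explicit Guard(bool enabled) : prev_(flag()) { flag() = enabled; }
    ~Guard() { flag() = prev_; }
   private:
    bool prev_;
  };
};

int CreateGlobalTensor();

int CreateTensor() {
  if (GlobalMode::is_enabled()) {
    // Without this guard, CreateGlobalTensor() calls back into CreateTensor(),
    // still sees global mode enabled, and recurses until the stack overflows.
    auto global_mode_guard = GlobalMode::Guard(false);
    return CreateGlobalTensor();
  }
  return 0;  // local construction path
}

int CreateGlobalTensor() {
  // The global path reuses the local path internally, e.g. for per-rank data.
  return CreateTensor() + 1;
}

int main() {
  GlobalMode::Guard guard(true);        // emulate `with global_mode(True, ...)`
  std::cout << CreateTensor() << "\n";  // prints 1; no infinite recursion
}
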
12 changes: 9 additions & 3 deletions oneflow/core/framework/op_interpreter.h
@@ -123,7 +123,7 @@ class LazyInterpreter : public OpExprInterpreter {

class EagerInterpreter : public OpExprInterpreter {
public:
EagerInterpreter() : OpExprInterpreter() {}
EagerInterpreter(bool is_local) : OpExprInterpreter(), is_local_(is_local) {}
virtual ~EagerInterpreter() = default;

Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
@@ -134,14 +134,20 @@ class EagerInterpreter : public OpExprInterpreter {
Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
const OpExprInterpContext& ctx) const override;

protected:
// NOTE(lixiang): Whether this interpreter handles local operations. EagerLocalInterpreter
// initializes it to true and EagerGlobalInterpreter to false, so that global mode is applied
// correctly. Used by EagerInterpreter::Apply.
bool is_local_;

private:
FOR_EACH_BUILTIN_OPS(DECLARE_PURE_VIRTUAL_APPLY_FUNC);
DECLARE_NORMAL_APPLY_FUNC(FunctionOp);
};

class EagerGlobalInterpreter : public EagerInterpreter {
public:
EagerGlobalInterpreter() : EagerInterpreter() {}
EagerGlobalInterpreter() : EagerInterpreter(false) {}
virtual ~EagerGlobalInterpreter() = default;

private:
@@ -150,7 +156,7 @@ class EagerGlobalInterpreter : public EagerInterpreter {

class EagerLocalInterpreter : public EagerInterpreter {
public:
EagerLocalInterpreter() : EagerInterpreter() {}
EagerLocalInterpreter() : EagerInterpreter(true) {}
virtual ~EagerLocalInterpreter() = default;

private:
6 changes: 6 additions & 0 deletions oneflow/core/framework/op_interpreter/op_interpreter.cpp
@@ -50,6 +50,12 @@ Maybe<void> LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inp

Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,
TensorTuple* outputs, const OpExprInterpContext& ctx) const {
// Decide here, in the op interpreter, whether global mode should stay enabled, to avoid the
// recursion that GlobalMode would otherwise cause: global mode remains enabled only if it was
// already enabled and the current operation is a local one.
auto global_mode_guard = GlobalMode::Guard(GlobalMode::is_enabled() && is_local_,
GlobalMode::nd_sbp(), GlobalMode::parallel_desc());

#define APPLY_IF(op_type) \
if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
return ApplyImpl(*op, inputs, outputs, ctx); \
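In other words, the guard keeps global mode on only while dispatching local operations; once execution is already on the global path, re-enabling global dispatch is exactly what caused the recursion. A stand-in sketch of the decision (the real guard also re-installs GlobalMode::nd_sbp() and GlobalMode::parallel_desc(), omitted here):

// Stand-in for the condition at the top of EagerInterpreter::Apply; not the
// real OneFlow API. enabled_before is the global-mode flag on entry, is_local
// mirrors the new is_local_ member.
//
//   enabled_before | is_local | global mode inside Apply
//   ---------------+----------+-------------------------
//   true           | true     | enabled  (local op still gets redirected)
//   true           | false    | disabled (already on the global path)
//   false          | any      | disabled
bool GlobalModeInsideApply(bool enabled_before, bool is_local) {
  return enabled_before && is_local;
}
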
1 change: 1 addition & 0 deletions oneflow/core/framework/tensor.h
@@ -623,6 +623,7 @@ class GlobalTensor final : public TensorIf<GlobalTensor> {
Maybe<Symbol<ParallelDesc>> parallel_desc() const override { return impl_->parallel_desc(); }
Maybe<Symbol<Device>> device() const override {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
const auto& device_tag = JUST(parallel_desc())->device_tag();
return JUST(Device::New(device_tag));
}
3 changes: 3 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -168,6 +168,7 @@ class TensorConstantFunctor {
// NOTE: this op is a source op, so the value (scalar tensor) should not have autograd status.
autograd::AutoGradMode mode(false);
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalTensorConstant(shape, value, dtype,
GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
@@ -251,6 +252,7 @@ class ConstantFunctor {
Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const Symbol<DType>& dtype,
const Optional<Symbol<Device>>& device) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalConstant(shape, value, dtype,
GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
@@ -288,6 +290,7 @@ class EmptyFunctor {
const bool pin_memory) const {
std::shared_ptr<Tensor> empty;
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
empty = JUST(functional::GlobalEmpty(shape, dtype, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
if (dtype->is_floating_point()) { JUST(empty->set_requires_grad(requires_grad)); }
2 changes: 2 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -1453,6 +1453,7 @@ class ArangeFunctor {
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalArange(start, limit, delta, dtype,
GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
@@ -1557,6 +1558,7 @@ class HannWindowFunctor {
const Optional<Symbol<Device>>& device,
const Optional<Symbol<DType>>& dtype, const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalHannWindow(
window_length, periodic, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, requires_grad));
4 changes: 4 additions & 0 deletions oneflow/core/functional/impl/random_functor.cpp
@@ -188,6 +188,7 @@ class RandFunctor {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRand(shape, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator,
requires_grad));
@@ -264,6 +265,7 @@ class RandNFunctor {
const Optional<one::Generator>& generator, const bool& requires_grad,
const Symbol<Layout>& layout) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRandN(shape, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator,
requires_grad));
@@ -476,6 +478,7 @@ class RandIntFunctor {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRandInt(
low, high, shape, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad));
@@ -618,6 +621,7 @@ class RandPermFunctor {
const Symbol<DType>& dtype, const Optional<Symbol<Device>>& device,
const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRandPerm(n, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), generator,
dtype, requires_grad));
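Every factory functor in this PR repeats the same three-line pattern: check GlobalMode::is_enabled(), disable it with a guard, and call the Global* variant. Purely as an illustration of a possible refactoring — this helper is hypothetical and not part of the PR — the pattern could be factored out (reusing the stand-in GlobalMode from the first sketch):

// Hypothetical helper, not in the PR: run the global-path branch of a factory
// functor with global mode temporarily disabled, so that any local-path calls
// made inside it cannot re-enter global dispatch and recurse.
template <typename F>
auto CallWithGlobalModeDisabled(F&& global_path) -> decltype(global_path()) {
  auto global_mode_guard = GlobalMode::Guard(false);
  return global_path();
}

// A functor body would then shrink to something like:
//   if (GlobalMode::is_enabled()) {
//     return CallWithGlobalModeDisabled(
//         [&] { return functional::GlobalConstant(/*...*/); });
//   }
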
19 changes: 19 additions & 0 deletions python/oneflow/test/graph/test_graph_with_global.py
@@ -269,6 +269,24 @@ def build(self):
test_case.assertEqual(v.sbp[0], B, k)


def _test_global_mode_with_default_placement_and_sbp(test_case):
# create a tensor with broadcast sbp and placement on rank 0
a = flow.randn(
(1, 8), sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
)
# enter global mode with broadcast sbp and placement on 2 GPUs
with global_mode(
True,
placement=flow.placement(type="cuda", ranks=[0, 1]),
sbp=flow.sbp.broadcast,
):
# check that the tensor keeps its own placement and sbp
test_case.assertTrue(a.placement == flow.placement("cuda", ranks=[0]))
test_case.assertTrue(a.sbp == (flow.sbp.broadcast,))
# check that printing the tensor works under global mode
print(a)


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
class TestLinearTrainGraphWithDDP(oneflow.unittest.TestCase):
@@ -277,6 +295,7 @@ def test_linear_train_graph_with_ddp(test_case):

def test_global_mode(test_case):
_test_global_mode(test_case)
_test_global_mode_with_default_placement_and_sbp(test_case)


if __name__ == "__main__":