Fix global mode print tensor bug (#10056)
## What this PR does:

Fix: Oneflow-Inc/OneTeam#1942

Test with:
```python
import oneflow as flow
from oneflow.utils.global_view import global_mode

a = flow.randn(
    (1, 8),
    sbp=flow.sbp.broadcast,
    placement=flow.placement("cuda", ranks=[0])
)

with global_mode(
    True,
    placement=flow.placement(type="cuda", ranks=[0, 1, 2, 3]),
    sbp=flow.sbp.broadcast,
):
    print(a)
```
Output:

![image](https://user-images.githubusercontent.com/54010254/228256534-0414f773-263a-4759-8789-849ecf184848.png)
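
The root cause, per the interpreter comments in the diff below, is that ops launched while printing a tensor inside `global_mode` could recursively re-enter the global-mode dispatch. The fix threads an `is_local_` flag through the eager interpreters and temporarily disables global mode in the tensor- and source-op functors so those nested calls cannot recurse.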

---------

Co-authored-by: Xiaoyu Xu <xiaoyulink@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people authored Mar 29, 2023
1 parent b305117 commit 7e7fb20
Showing 8 changed files with 46 additions and 3 deletions.
1 change: 1 addition & 0 deletions oneflow/api/python/functional/tensor_api.cpp
@@ -54,6 +54,7 @@ class TensorWithDataFunctor {
// it's an eager tensor created by running functional::Empty() under LazyMode::Guard(false)
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(
functional::GlobalTensorWithData(data, dtype, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), requires_grad));
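
The `GlobalMode::Guard(false)` added here (and in the functors below) switches global mode off for the enclosing scope, so the delegated `functional::GlobalTensorWithData` call cannot re-enter the global-mode dispatch. As a rough sketch of the RAII idiom only — hypothetical names, and omitting the `nd_sbp`/`parallel_desc` state the real guard carries:

```cpp
#include <cassert>

// Hypothetical sketch of an RAII mode guard: it saves the current flag on
// construction, overrides it, and restores the saved value on scope exit.
class ScopedModeGuard {
 public:
  explicit ScopedModeGuard(bool enabled) : saved_(mode_enabled_) { mode_enabled_ = enabled; }
  ~ScopedModeGuard() { mode_enabled_ = saved_; }
  static bool is_enabled() { return mode_enabled_; }

 private:
  bool saved_;
  static thread_local bool mode_enabled_;
};

thread_local bool ScopedModeGuard::mode_enabled_ = false;

int main() {
  ScopedModeGuard outer(true);  // like entering `with global_mode(True, ...)`
  {
    ScopedModeGuard off(false);  // like the `GlobalMode::Guard(false)` above
    assert(!ScopedModeGuard::is_enabled());
  }
  assert(ScopedModeGuard::is_enabled());  // restored when `off` leaves scope
}
```
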
12 changes: 9 additions & 3 deletions oneflow/core/framework/op_interpreter.h
@@ -123,7 +123,7 @@ class LazyInterpreter : public OpExprInterpreter {

class EagerInterpreter : public OpExprInterpreter {
public:
EagerInterpreter() : OpExprInterpreter() {}
EagerInterpreter(bool is_local) : OpExprInterpreter(), is_local_(is_local) {}
virtual ~EagerInterpreter() = default;

Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
@@ -134,14 +134,20 @@ class EagerInterpreter : public OpExprInterpreter {
Maybe<void> Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs,
const OpExprInterpContext& ctx) const override;

protected:
// NOTE(lixiang): To keep GlobalMode correct, record whether this interpreter handles local
// operations; it is initialized to true when using EagerLocalInterpreter.
// Used by EagerInterpreter::Apply.
bool is_local_;

private:
FOR_EACH_BUILTIN_OPS(DECLARE_PURE_VIRTUAL_APPLY_FUNC);
DECLARE_NORMAL_APPLY_FUNC(FunctionOp);
};

class EagerGlobalInterpreter : public EagerInterpreter {
public:
EagerGlobalInterpreter() : EagerInterpreter() {}
EagerGlobalInterpreter() : EagerInterpreter(false) {}
virtual ~EagerGlobalInterpreter() = default;

private:
Expand All @@ -150,7 +156,7 @@ class EagerGlobalInterpreter : public EagerInterpreter {

class EagerLocalInterpreter : public EagerInterpreter {
public:
EagerLocalInterpreter() : EagerInterpreter() {}
EagerLocalInterpreter() : EagerInterpreter(true) {}
virtual ~EagerLocalInterpreter() = default;

private:
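
Passing the flag through the constructor lets `Apply` make the local-vs-global decision without any run-time type check: `EagerLocalInterpreter` hard-codes `true` and `EagerGlobalInterpreter` hard-codes `false`, as the two constructors above show.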
7 changes: 7 additions & 0 deletions oneflow/core/framework/op_interpreter/op_interpreter.cpp
@@ -50,6 +50,13 @@ Maybe<void> LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inp

Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,
TensorTuple* outputs, const OpExprInterpContext& ctx) const {
// Decide in the op interpreter whether global mode should stay on, to avoid the recursion
// that GlobalMode would otherwise cause.
// Global mode remains enabled only if it was already enabled and the current op is local.
auto global_mode_guard = GlobalMode::Guard(GlobalMode::is_enabled() && is_local_,
GlobalMode::nd_sbp(), GlobalMode::parallel_desc());

#define APPLY_IF(op_type) \
if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
return ApplyImpl(*op, inputs, outputs, ctx); \
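
This conditional guard is the heart of the fix: global mode survives into an op's interpretation only if it was already enabled *and* the interpreter is the local one, so the global interpreter never recursively re-applies the global-mode rewrite. A toy model of that decision, with hypothetical names:

```cpp
#include <iostream>

thread_local bool g_global_mode = false;

struct ModeGuard {
  bool saved;
  explicit ModeGuard(bool enabled) : saved(g_global_mode) { g_global_mode = enabled; }
  ~ModeGuard() { g_global_mode = saved; }
};

void Apply(bool is_local) {
  // Effective mode for this op: keep global mode only for local ops.
  ModeGuard guard(g_global_mode && is_local);
  std::cout << (g_global_mode ? "op runs with global mode on\n"
                              : "op runs with global mode off\n");
}

int main() {
  ModeGuard on(true);         // like entering `with global_mode(True, ...)`
  Apply(/*is_local=*/true);   // EagerLocalInterpreter: mode stays on
  Apply(/*is_local=*/false);  // EagerGlobalInterpreter: mode is dropped
}
```
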
1 change: 1 addition & 0 deletions oneflow/core/framework/tensor.h
@@ -623,6 +623,7 @@ class GlobalTensor final : public TensorIf<GlobalTensor> {
Maybe<Symbol<ParallelDesc>> parallel_desc() const override { return impl_->parallel_desc(); }
Maybe<Symbol<Device>> device() const override {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
const auto& device_tag = JUST(parallel_desc())->device_tag();
return JUST(Device::New(device_tag));
}
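
With the guard in place, the device is derived from the parallel descriptor's device tag with global mode off, so the nested `parallel_desc()` and `Device::New` calls cannot take the global-mode path while a tensor is being printed.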
3 changes: 3 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -168,6 +168,7 @@ class TensorConstantFunctor {
// NOTE: this op is a source op, so the value (scalar tensor) should not have autograd status.
autograd::AutoGradMode mode(false);
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalTensorConstant(shape, value, dtype,
GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
@@ -251,6 +252,7 @@ class ConstantFunctor {
Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const Symbol<DType>& dtype,
const Optional<Symbol<Device>>& device) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalConstant(shape, value, dtype,
GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
@@ -288,6 +290,7 @@ class EmptyFunctor {
const bool pin_memory) const {
std::shared_ptr<Tensor> empty;
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
empty = JUST(functional::GlobalEmpty(shape, dtype, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
if (dtype->is_floating_point()) { JUST(empty->set_requires_grad(requires_grad)); }
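
The same guard-then-delegate pattern repeats in the functor changes below (`ArangeFunctor`, `HannWindowFunctor`, `RandFunctor`, `RandNFunctor`, `RandIntFunctor`, `RandPermFunctor`): once a source-op functor decides to take the global branch, it switches global mode off so that the nested `Global*` call does not re-trigger the same dispatch.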
2 changes: 2 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -1453,6 +1453,7 @@ class ArangeFunctor {
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalArange(start, limit, delta, dtype,
GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp()))));
@@ -1557,6 +1558,7 @@ class HannWindowFunctor {
const Optional<Symbol<Device>>& device,
const Optional<Symbol<DType>>& dtype, const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalHannWindow(
window_length, periodic, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, requires_grad));
4 changes: 4 additions & 0 deletions oneflow/core/functional/impl/random_functor.cpp
@@ -188,6 +188,7 @@ class RandFunctor {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRand(shape, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator,
requires_grad));
@@ -264,6 +265,7 @@ class RandNFunctor {
const Optional<one::Generator>& generator, const bool& requires_grad,
const Symbol<Layout>& layout) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRandN(shape, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator,
requires_grad));
@@ -476,6 +478,7 @@ class RandIntFunctor {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRandInt(
low, high, shape, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), dtype, generator, requires_grad));
@@ -618,6 +621,7 @@ class RandPermFunctor {
const Symbol<DType>& dtype, const Optional<Symbol<Device>>& device,
const bool& requires_grad) const {
if (GlobalMode::is_enabled()) {
auto global_mode_guard = GlobalMode::Guard(false);
return JUST(functional::GlobalRandPerm(n, GetGlobalParallelDescFromDevice(device),
*JUST(GetSbpList(GlobalMode::nd_sbp())), generator,
dtype, requires_grad));
19 changes: 19 additions & 0 deletions python/oneflow/test/graph/test_graph_with_global.py
@@ -269,6 +269,24 @@ def build(self):
test_case.assertEqual(v.sbp[0], B, k)


def _test_global_mode_with_default_placement_and_sbp(test_case):
# create a tensor with broadcast sbp, placed on rank 0
a = flow.randn(
(1, 8), sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
)
# enter global mode with broadcast sbp and placement on 2 GPUs
with global_mode(
True,
placement=flow.placement(type="cuda", ranks=[0, 1]),
sbp=flow.sbp.broadcast,
):
# check the tensor's placement and sbp
test_case.assertTrue(a.placement == flow.placement("cuda", ranks=[0]))
test_case.assertTrue(a.sbp == (flow.sbp.broadcast,))
# check that the tensor prints without error
print(a)


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
class TestLinearTrainGraphWithDDP(oneflow.unittest.TestCase):
@@ -277,6 +295,7 @@ def test_linear_train_graph_with_ddp(test_case):

def test_global_mode(test_case):
_test_global_mode(test_case)
_test_global_mode_with_default_placement_and_sbp(test_case)


if __name__ == "__main__":
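
A note on running this test: it is gated by `skip_unless_1n2d()`, so it needs two ranks. Assuming the usual OneFlow test invocation, something like `python3 -m oneflow.distributed.launch --nproc_per_node 2 python/oneflow/test/graph/test_graph_with_global.py` should work; the exact command may differ in CI.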
