Skip to content

Commit

Permalink
support dist tensor in reshape api
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio committed Mar 5, 2024
1 parent d07406f commit 72cecf1
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 3 deletions.
29 changes: 29 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1449,10 +1449,39 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self,
PyObject* kwargs) {
EAGER_TRY
// `ptr` ends up pointing at the dense storage we read the element from,
// regardless of which tensor flavor (selected-rows / dist / dense) we hold.
phi::DenseTensor* ptr = nullptr;
// Keeps the resharded copy alive for the remainder of this function when a
// dist tensor has to be gathered to replicated form (`ptr` may point here).
phi::DenseTensor tensor_after_reshard;
if (self->tensor.is_selected_rows()) {
// SelectedRows: read from its underlying dense value tensor.
auto* selected_rows =
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
ptr = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
} else if (self->tensor.is_dist_tensor()) {
#ifdef PADDLE_WITH_DISTRIBUTE
auto* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
// An uninitialized dist tensor carries no mesh information, so there is
// no way to decide where/how to reshard it.
PADDLE_ENFORCE(dist_tensor->initialized(),
"The input dist tensor can't be uninitialized for we don't "
"know the correct mesh to be reshard.");
const auto& placements = dist_tensor->placements();
// Reshard is only needed if any mesh dim is not replicated (i.e. the
// tensor is sharded or partial somewhere and local data is incomplete).
bool need_reshard = false;
for (const auto& placement : placements) {
if (!placement->is_replicated()) {
need_reshard = true;
break;
}
}
if (need_reshard) {
// Gather to a fully replicated dense tensor; note `tensor_after_reshard`
// (function-scope above) owns the result so `ptr` stays valid.
tensor_after_reshard = ReshardXToReplicated(dist_tensor);
ptr = &tensor_after_reshard;
} else {
// Already replicated everywhere: read the local dense value directly.
ptr = dist_tensor->unsafe_mutable_value();
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"The `_getitem_from_offset` method of (Dist)Tensor is not supported "
"in the current PaddlePaddle, please recompile and install "
"PaddlePaddle "
"with the option of `WITH_DISTRIBUTE=ON`."));
#endif
} else {
// Plain dense tensor: use its storage as-is.
ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
}
Expand Down
2 changes: 1 addition & 1 deletion test/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
py_test_modules(test_dist_tensor_api MODULES test_dist_tensor_api)
set_tests_properties(test_dist_tensor_api
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 200)
py_test_modules(test_gpt_with_pir MODULES test_gpt_with_pir)
set_tests_properties(test_gpt_with_pir PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 100)
Expand Down
47 changes: 47 additions & 0 deletions test/auto_parallel/semi_auto_parallel_for_item.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from semi_auto_parallel_util import SemiAutoParallelTestBase

import paddle
import paddle.distributed as dist


class TestItemApiForSemiAutoParallel(SemiAutoParallelTestBase):
    """Verifies Tensor.item() on a dist tensor matches the dense original."""

    def __init__(self):
        super().__init__()
        # Seed both paddle and numpy so every rank draws identical data.
        paddle.seed(self._seed)
        np.random.seed(self._seed)

    def test_item_api(self):
        # Shard a dense tensor along axis 0 of a 1-D mesh, then check that
        # item() on the dist tensor agrees with indexing the dense source.
        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        dense = paddle.rand(shape=[6, 8])
        sharded = dist.shard_tensor(dense, mesh, [dist.Shard(0)])
        for row, col in ((0, 0), (3, 5)):
            np.testing.assert_equal(
                sharded.item(row, col), dense[row][col].item()
            )

    def run_test_case(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
        elif self._backend == "gpu":
            paddle.set_device("gpu:" + str(dist.get_rank()))
        else:
            raise ValueError("Only support cpu or gpu backend.")

        self.test_item_api()


if __name__ == "__main__":
    # Entry point when launched by the distributed test driver.
    TestItemApiForSemiAutoParallel().run_test_case()
15 changes: 13 additions & 2 deletions test/auto_parallel/semi_auto_parallel_for_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,26 @@ def test_reshape_infer_shape(self):
assert y.shape == [30, 20, 10]
assert y._local_shape == [15, 20, 10]

def test_shape_api_with_reshape(self):
    """Reshape a sharded tensor using a dim taken from paddle.shape()."""
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    dense = paddle.rand(shape=[4, 6, 8])
    sharded = dist.shard_tensor(dense, mesh, [dist.Shard(0)])

    # paddle.shape returns a tensor; feeding one of its entries into
    # reshape exercises the dist-tensor path of the reshape API.
    runtime_shape = paddle.shape(sharded)
    sharded = sharded.reshape((-1, runtime_shape[-1]))
    # Global shape flattens to [24, 8]; each of the two ranks keeps half.
    assert sharded.shape == [24, 8]
    assert sharded._local_shape == [12, 8]

def run_test_case(self):
    """Select the device for this rank, then run every reshape test case.

    Raises:
        ValueError: if the configured backend is neither "cpu" nor "gpu".
    """
    if self._backend == "cpu":
        paddle.set_device("cpu")
    elif self._backend == "gpu":
        paddle.set_device("gpu:" + str(dist.get_rank()))
    else:
        raise ValueError("Only support cpu or gpu backend.")
    # Re-enabled: these two cases were commented out, which silently
    # dropped reshape coverage; disabled tests should be removed or
    # skipped with a reason, never commented out.
    self.test_reshape_forward()
    self.test_reshape_infer_shape()
    self.test_shape_api_with_reshape()


if __name__ == '__main__':
Expand Down
10 changes: 10 additions & 0 deletions test/auto_parallel/test_semi_auto_parallel_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ def test_reshape_api(self):
user_defined_envs=envs,
)

def test_item_api(self):
    # Launch the item-api script once for each generated env combination.
    for envs in test_base.gen_product_envs_list(
        self._default_envs, self._changeable_envs
    ):
        self.run_test_case(
            "semi_auto_parallel_for_item.py",
            user_defined_envs=envs,
        )

def test_squeeze_api(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
Expand Down

0 comments on commit 72cecf1

Please sign in to comment.