
Commit

basically ok
liuzhenhai93 committed Jan 5, 2024
2 parents 350cf5d + 212ec21 commit 4faf821
Showing 4 changed files with 48 additions and 13 deletions.
29 changes: 18 additions & 11 deletions paddle/fluid/eager/custom_operator/custom_operator_utils.cc
@@ -653,6 +653,14 @@ std::
           ? paddle::get<0>(spmd_info.first[0]).process_mesh()
           : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh();
 
+
+  std::vector<std::vector<phi::DDim>> out_dims =
+      RunInferShapeFn(op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx);
+
+  std::vector<std::vector<phi::DataType>> out_dtypes =
+      RunInferDtypeFn(op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx);
+
+
   if (rank_is_in_current_mesh) {
     auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place());
     for (size_t i = 0; i < x.size(); ++i) {
@@ -662,13 +670,8 @@ std::
           std::make_shared<phi::DenseTensor>(dist_input_i->value()));
       dist_inputs.emplace_back(dist_input_i);
     }
+  }
 
-    std::vector<std::vector<phi::DDim>> out_dims =
-        RunInferShapeFn(op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx);
-  }
 
-  std::vector<std::vector<phi::DataType>> out_dtypes =
-      RunInferDtypeFn(op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx);
 
   for (size_t i = 0; i < out_dims.size(); ++i) {
     const auto& out_dim = out_dims.at(i);
@@ -745,17 +748,21 @@ void TransCtxTensorsToDistTensors(
     std::vector<Tensor>* input_all = ctx.AllMutableInput();
     for (size_t i = 0; i < input_all->size(); ++i) {
       auto& tensor = input_all->at(i);
 
       phi::distributed::TensorDistAttr dist_attr;
-      if (!spmd_info.first.empty()) {
-        dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr,
-                                     spmd_info.first[i]);
+      phi::DDim global_dims;
+
+      if (i < dist_inputs.size()) {
+        auto& dist_input = dist_inputs.at(i);
+        global_dims = dist_input->dims();
+        dist_attr = dist_input->dist_attr();
       } else {
         dist_attr =
             phi::distributed::TensorDistAttr(common::vectorize(tensor.dims()));
         dist_attr.set_process_mesh(current_process_mesh);
+        global_dims = tensor.dims();
       }
       auto dist_t = std::make_shared<phi::distributed::DistTensor>(
           std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()),
+          global_dims,
           dist_attr);
       tensor.set_impl(dist_t);
     }
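The two hunks above hoist the RunInferShapeFn and RunInferDtypeFn calls out of the if (rank_is_in_current_mesh) block, so shape and dtype inference now runs on every rank. A toy, self-contained sketch of that control flow (not Paddle code; all names below are made up for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

// A relu-like op keeps the input shape, so inference is the identity here.
std::vector<int64_t> InferShape(const std::vector<int64_t>& in) { return in; }

int main() {
  const bool rank_is_in_current_mesh = false;  // pretend this rank is outside the mesh
  const std::vector<int64_t> input_dims = {16, 4, 4};

  // Runs unconditionally, mirroring the hoisted RunInferShapeFn/RunInferDtypeFn calls.
  const std::vector<int64_t> out_dims = InferShape(input_dims);

  if (rank_is_in_current_mesh) {
    // Only ranks inside the mesh would launch the actual kernel here.
  }

  // Even without running the kernel, this rank knows the global output shape,
  // which is what the placeholder DistTensor construction relies on.
  std::cout << "global out dims: " << out_dims[0] << " x " << out_dims[1] << " x "
            << out_dims[2] << std::endl;
  return 0;
}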
15 changes: 15 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -162,6 +162,21 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& local_value,
   }
 }
 
+DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& local_value,
+                       const DDim& global_dims,
+                       const TensorDistAttr& dist_attr)
+    : global_dims_(global_dims), dist_attr_(dist_attr) {
+  process_mesh_ = dist_attr_.process_mesh();
+  placements_ = ToPlacements(dist_attr);
+  if (IsCurRankInMesh(process_mesh_)) {
+    value_ = local_value;
+  } else {
+    value_ = std::make_shared<DenseTensor>(
+        std::make_shared<phi::Allocation>(nullptr, 0, local_value->place()),
+        phi::DenseTensorMeta(local_value->dtype(), global_dims_));
+  }
+}
+
 DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
                        const ProcessMesh& process_mesh,
                        const Placements& placements)
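For reference, a hedged usage sketch of the constructor added above, mirroring its call site in custom_operator_utils.cc; the variables tensor, global_dims, and current_process_mesh are assumed to already be in scope:

// Fallback path (assumption: mirrors the else-branch in TransCtxTensorsToDistTensors):
// build a replicated dist_attr from the rank-local dims.
phi::distributed::TensorDistAttr dist_attr(common::vectorize(tensor.dims()));
dist_attr.set_process_mesh(current_process_mesh);

// Wrap the rank-local DenseTensor into a DistTensor that records the global shape.
// Per the constructor above, ranks outside the mesh only keep an empty placeholder value.
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
    std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()),
    global_dims,
    dist_attr);
tensor.set_impl(dist_t);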
7 changes: 7 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.h
@@ -58,6 +58,13 @@ class DistTensor final
              const ProcessMesh& process_mesh,
              const Placements& placements);
 
+  /// \brief Construct a dist tensor based local dense tensor.
+  /// \param global_dims The global dim of the dist tensor.
+  /// \param dist_attr The distributed attributes of the current tensor.
+  DistTensor(const std::shared_ptr<phi::DenseTensor>& local_value,
+             const DDim& global_dims,
+             const TensorDistAttr& dist_attr);
+
   /// \brief Construct a dist tensor based local dense tensor.
   /// \param global_dims The global dim of the dist tensor.
   /// \param process_mesh The process mesh of the current tensor.
10 changes: 8 additions & 2 deletions test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py
@@ -35,6 +35,12 @@ def __init__(self):
         self._backend = os.getenv("backend")
         self._seed = eval(os.getenv("seed"))
 
+    def check_placements(self, output, expected_placements):
+        assert (
+            output.placements == expected_placements
+        ), f"{output.placements} vs {expected_placements}"
+
+
     def test_custom_relu(self):
         shapes = [16, 4, 4]
         specs = ['x', None, None]
@@ -55,7 +61,7 @@ def test_custom_relu_no_shard(self):
             op_func=custom_relu.custom_relu,
             with_backward=True,
         )
-        self.check_placements(outputs, [dist.Replicate(0)])
+        self.check_placements(outputs, [dist.Replicate()])
 
     def run_test_case(self):
         if self._backend == "cpu":
@@ -65,7 +71,7 @@ def run_test_case(self):
         else:
             raise ValueError("Only support cpu or gpu backend.")
         self.test_custom_relu_no_shard()
-        # self.test_custom_relu()
+        self.test_custom_relu()
 
 
 if __name__ == '__main__':
