[AutoParallel] Grad node info autoparallel 3 (PaddlePaddle#58727)
* grad_node_info.cc support autoparallel 1
wanghuancoder authored and SecretXV committed Nov 28, 2023
1 parent 9182cef commit 31a2ba5
Showing 2 changed files with 55 additions and 0 deletions.
18 changes: 18 additions & 0 deletions paddle/fluid/eager/grad_node_info.cc
@@ -696,6 +696,24 @@ void GradNodeBase::HandleComplexGradToRealGrad(
            fwd_data_type, curr_data_type, *grad_dense_tensor, out.get());

        (*out_grads)[slot_id][rank_id].set_impl(out);
      } else if (phi::distributed::DistTensor::classof(grad.impl().get())) {
        auto grad_dense_tensor =
            static_cast<phi::distributed::DistTensor*>(grad.impl().get())
                ->value();

        auto curr_data_type =
            paddle::framework::TransToProtoVarType(grad_dense_tensor.type());
        if (!paddle::framework::IsComplexType(curr_data_type)) continue;
        if (grad_dense_tensor.dims().size() == -1) continue;

        // Convert Complex GradOut to Real
        auto out = std::make_shared<phi::DenseTensor>();
        paddle::framework::TransComplexToReal(
            fwd_data_type, curr_data_type, grad_dense_tensor, out.get());

        *(static_cast<phi::distributed::DistTensor*>(
              (*out_grads)[slot_id][rank_id].impl().get())
              ->unsafe_mutable_value()) = *(out.get());
      }
    }
  }
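For context, the new else-if branch mirrors the existing DenseTensor path for distributed tensors: it pulls the local DenseTensor out of the DistTensor, and when the forward dtype is real but the incoming gradient is complex, it casts the gradient back to the real dtype with TransComplexToReal and writes it into the DistTensor in place. Below is a minimal numpy sketch of the dtype situation this handles; the variable names and the conjugate-transpose gradient formula are illustrative assumptions, not code from this commit.

# Hypothetical sketch: x is complex128 and y is float32, so the chain-rule
# gradient flowing into y is complex and must be cast back to y's real dtype
# (roughly what TransComplexToReal does to the DistTensor's local value).
import numpy as np

x = (np.random.random((64, 32)) + 1j * np.random.random((64, 32))).astype(np.complex128)
y = np.random.random((32, 48)).astype(np.float32)

grad_out = np.ones((64, 48), dtype=np.complex128)  # upstream gradient of out = x @ y
grad_y_complex = x.conj().T @ grad_out             # raw chain-rule gradient for y (complex128)
grad_y = grad_y_complex.real.astype(np.float32)    # keep the real part to match y's dtype

print(grad_y.dtype)  # float32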
37 changes: 37 additions & 0 deletions test/auto_parallel/semi_auto_parallel_for_matmul.py
@@ -57,6 +57,7 @@ def test_body(
        dist_out = paddle.matmul(
            dist_x, dist_y, transpose_x=trans_x, transpose_y=trans_y
        )

        self.check_tensor_eq(out, dist_out)

        out.backward()
@@ -227,6 +228,41 @@ def test_matmul_x_row_shard_trans_y(self):
        )
        assert dist_y_grad.dist_attr._is_partial() is False

    def test_matmul_with_complex_type(self):
        paddle.seed(self._seed)
        np.random.seed(self._seed)

        x_np = np.random.random(size=[64, 32]).astype(np.complex128)
        y_np = np.random.random(size=[32, 48]).astype(np.float32)
        x = paddle.to_tensor(x_np)
        y = paddle.to_tensor(y_np)
        x.stop_gradient = False
        y.stop_gradient = False

        x_dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=[None, None]
        )
        y_dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=[None, None]
        )

        dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr)
        dist_y = dist.shard_tensor(y_np, dist_attr=y_dist_attr)
        dist_x.stop_gradient = False
        dist_y.stop_gradient = False

        out = paddle.matmul(x, y, transpose_x=False, transpose_y=False)
        dist_out = paddle.matmul(
            dist_x, dist_y, transpose_x=False, transpose_y=False
        )

        self.check_tensor_eq(out, dist_out)

        out.backward()
        dist_out.backward()
        self.check_tensor_eq(x.grad, dist_x.grad)
        self.check_tensor_eq(y.grad, dist_y.grad)

    def run_test_case(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
@@ -240,6 +276,7 @@ def run_test_case(self):
        self.test_matmul_x_column_shard_trans_x_y()
        self.test_matmul_x_column_shard_trans_x()
        self.test_matmul_x_row_shard_trans_y()
        self.test_matmul_with_complex_type()


if __name__ == '__main__':
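A condensed standalone sketch of what test_matmul_with_complex_type exercises, using the same DistAttr / shard_tensor calls shown in the diff. The ProcessMesh construction is an assumption (the test file builds self._mesh elsewhere), and the script is meant to run under a multi-process distributed launch.

import numpy as np
import paddle
import paddle.distributed as dist

# Assumed two-card mesh; the test above relies on self._mesh from its setup.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

x = dist.shard_tensor(
    np.random.random(size=[64, 32]).astype(np.complex128),
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]),
)
y = dist.shard_tensor(
    np.random.random(size=[32, 48]).astype(np.float32),
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None]),
)
x.stop_gradient = False
y.stop_gradient = False

out = paddle.matmul(x, y)  # complex128 forward result
out.backward()

# y's forward dtype is float32, so HandleComplexGradToRealGrad casts its
# complex gradient back to float32 through the new DistTensor branch.
print(y.grad.dtype)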
