Commit 591c397
Reenable the distributed checkpointing test (#8424)
JackCaoG authored Dec 2, 2024
1 parent 1c91219 commit 591c397
Showing 2 changed files with 12 additions and 3 deletions.
test/tpu/run_tests.sh (1 addition, 2 deletions)

@@ -8,8 +8,7 @@ python3 test/pjrt/test_collective_ops_tpu.py
 python3 test/spmd/test_mp_input_sharding.py
 python3 test/spmd/test_xla_sharding.py
 python3 test/spmd/test_xla_virtual_device.py
-# TODO(JackCaoG): to reenable
-# python3 test/spmd/test_xla_distributed_checkpoint.py
+python3 test/spmd/test_xla_distributed_checkpoint.py
 python3 test/spmd/test_train_spmd_linear_model.py
 python3 test/spmd/test_xla_spmd_python_api_interaction.py
 python3 test/spmd/test_xla_auto_sharding.py
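
With the TODO removed, the distributed checkpointing test runs again as part of the TPU test suite. For a quick local check, the re-enabled test can also be invoked on its own from the repository root, assuming a working torch_xla SPMD/TPU environment:

python3 test/spmd/test_xla_distributed_checkpoint.py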
torch_xla/csrc/tensor.cpp (11 additions, 1 deletion)

@@ -562,7 +562,17 @@ void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
   at::Tensor coyped_tensor = torch::lazy::CopyTensor(tensor, dtype());
   SetTensorData(coyped_tensor);
   data()->handle = nullptr;
-  data()->sharding = nullptr;
+  // Only clear the sharding spec when the new tensor's shape differs from
+  // the recorded sharding shape.
+  if (data()->sharding) {
+    auto coyped_tensor_dims = XlaHelpers::I64List(coyped_tensor.sizes());
+    auto sharding_dims = data()->sharding->shape.dimensions();
+    if (coyped_tensor_dims != sharding_dims) {
+      // The sharding shape from the original tensor differs from the new CPU
+      // tensor, so the sharding must be cleared here.
+      ClearShardingSpec();
+    }
+  }
   AssignIrValue(torch::lazy::Value());
   if (data()->view != nullptr) {
     torch::lazy::Value ir_value = GetIrValueForTensor(coyped_tensor, device);
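
The tensor.cpp change replaces the unconditional data()->sharding = nullptr; with a guard: the sharding spec is kept when the incoming CPU tensor has the same dimensions as the recorded sharding shape, and cleared only on a mismatch. A minimal standalone sketch of that comparison, with a hypothetical ShardingSpecStub standing in for the real torch_xla sharding spec and plain std::vector<int64_t> standing in for XlaHelpers::I64List and the XLA shape dimensions, might look like:

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical stand-in for the sharding spec; only the shape dimensions
// matter for the guard shown in the diff above.
struct ShardingSpecStub {
  std::vector<int64_t> shape_dimensions;
};

// Returns true when the sharding should be cleared, i.e. when a sharding is
// present and its recorded shape no longer matches the new tensor's shape.
bool ShouldClearSharding(const std::optional<ShardingSpecStub>& sharding,
                         const std::vector<int64_t>& new_tensor_dims) {
  if (!sharding) {
    return false;  // No sharding recorded, nothing to clear.
  }
  return sharding->shape_dimensions != new_tensor_dims;
}

int main() {
  std::optional<ShardingSpecStub> sharding = ShardingSpecStub{{8, 128}};

  // Same shape as the recorded sharding: keep it (e.g. restoring a
  // checkpoint into a tensor of unchanged shape).
  std::cout << ShouldClearSharding(sharding, {8, 128}) << "\n";  // prints 0

  // Different shape: the old sharding no longer applies, so clear it.
  std::cout << ShouldClearSharding(sharding, {16, 128}) << "\n";  // prints 1
  return 0;
}

In the actual diff the same idea is applied directly to data()->sharding inside XLATensor::UpdateFromTensor, so a host-side update that keeps the tensor's shape no longer silently discards its sharding annotation, which is what lets the distributed checkpointing test pass again.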
