Commit

fix comment
JackCaoG committed Oct 2, 2024
1 parent aea22ab commit ceb6799
Showing 2 changed files with 9 additions and 6 deletions.
test/spmd/test_xla_sharding.py (2 changes: 1 addition & 1 deletion)
@@ -1365,7 +1365,7 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self):
             mesh, ('data', None, None, None), minibatch=True))
     with self.assertRaisesRegex(
         RuntimeError,
-        "When minibatch is configured, batch dimension of the tensor must be divisible by data mesh*"
+        "When minibatch is configured, batch dimension of the tensor must be divisible by local runtime device count*"
     ):
       data, _ = iter(train_device_loader).__next__()
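
For context, below is a minimal sketch of the scenario this assertion exercises: an SPMD-enabled host with more than one addressable device, fed a batch size that the local runtime device count does not divide. The dataset, shapes, and batch size are invented for illustration and are not the test's own fixtures.

    # Hypothetical repro, assuming SPMD is enabled and >1 addressable device;
    # a batch of (num_devices + 1) is not divisible by the local device count.
    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs
    import torch_xla.distributed.parallel_loader as pl

    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(
        np.arange(num_devices), (num_devices, 1, 1, 1), ('data', 'x', 'y', 'z'))

    dataset = torch.utils.data.TensorDataset(
        torch.zeros(2 * (num_devices + 1), 3, 128, 128))
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=num_devices + 1)
    train_device_loader = pl.MpDeviceLoader(
        train_loader,
        xm.xla_device(),
        input_sharding=xs.ShardingSpec(
            mesh, ('data', None, None, None), minibatch=True))
    data, = next(iter(train_device_loader))  # raises the RuntimeError above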

torch_xla/core/xla_model.py (13 changes: 8 additions & 5 deletions)
@@ -1301,16 +1301,19 @@ def convert_fn(tensors):
       shardings = [input_sharding.xla_spec(t) for t in tensors]
     if input_sharding and input_sharding.minibatch:
       # when minibatch is configured we must make sure batch dimension of
-      # the tensor is divisible by the data mesh dimension.
-      data_mesh_dim = input_sharding.mesh.mesh_shape[0]
+      # the tensor is divisible by the local runtime device count.
       for tensor, sharding in zip(tensors, shardings):
         # assume batch dimension is 0
+        local_runtime_device_count = torch_xla.runtime.addressable_runtime_device_count(
+        )
         if sharding and tensor.dim() > 0 and (tensor.size()[0] %
-                                              data_mesh_dim) != 0:
+                                              local_runtime_device_count) != 0:
           raise RuntimeError(
               "When minibatch is configured, batch dimension of the tensor " +
-              "must be divisible by data mesh dimension.input data shape " +
-              f"={tensor.size()}, mesh data dimension = {data_mesh_dim}")
+              "must be divisible by local runtime device count.input data shape "
+              +
+              f"={tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
+          )
 
     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                       shardings)
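
Beyond the message wording, the substantive change here swaps the divisor in the check from the global mesh's data dimension (mesh_shape[0]) to the count of devices addressable by the current host. A self-contained sketch of why the two can differ, using invented topology numbers (2 hosts, 8 global devices) rather than anything from the commit:

    # Plain illustration of the old vs. new divisor; numbers are assumptions.
    import torch

    data_mesh_dim = 8               # mesh_shape[0]: global 'data' axis (2 hosts x 4 chips)
    local_runtime_device_count = 4  # devices addressable by this host

    batch = torch.zeros(4, 3, 128, 128)  # per-host minibatch of 4 samples

    print(batch.size()[0] % data_mesh_dim != 0)               # True: old check rejects
    print(batch.size()[0] % local_runtime_device_count != 0)  # False: new check passes

With minibatch=True each host loads only its slice of the global batch, so (per the updated comment) the per-host batch only needs to split evenly across local devices; on a single host the two divisors coincide, which presumably is why the mismatch surfaces only in multi-host runs.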