
Commit 8fbdba7

Fix stride computation formula used during compute estimation (#42)
Turns out the previous PR #37 was not correct: it divided the wrong dim's stride. This PR divides the stride of the dim to the left of the one being sharded, which is what really happens. Note: that we have this util at all worries me. Why don't we just use dtensors to propagate?
1 parent 233d68b commit 8fbdba7

File tree

1 file changed: +4 -3 lines changed


autoparallel/compute_estimation.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -169,9 +169,10 @@ def _get_sharded_shape_stride(spec):
         if placement.is_shard():
             dim = placement.dim
             new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
-            new_tensor_stride[dim] = (
-                new_tensor_stride[dim] + mesh_size - 1
-            ) // mesh_size
+            if dim - 1 > 0:
+                new_tensor_stride[dim - 1] = (
+                    new_tensor_stride[dim - 1] + mesh_size - 1
+                ) // mesh_size
     return new_tensor_shape, new_tensor_stride
```