Conversation


@TroyGarden commented Jun 18, 2025

Summary:

# context
* original diff D74366343 broke the cogwheel test and was reverted
* the error stack P1844048578 is shown below:
```
  File "/dev/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dev/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dev/torchrec/distributed/train_pipeline/runtime_forwards.py", line 84, in __call__
    data = request.wait()
  File "/dev/torchrec/distributed/types.py", line 334, in wait
    ret: W = self._wait_impl()
  File "/dev/torchrec/distributed/embedding_sharding.py", line 655, in _wait_impl
    kjts.append(w.wait())
  File "/dev/torchrec/distributed/types.py", line 334, in wait
    ret: W = self._wait_impl()
  File "/dev/torchrec/distributed/dist_data.py", line 426, in _wait_impl
    return type(self._input).dist_init(
  File "/dev/torchrec/sparse/jagged_tensor.py", line 2993, in dist_init
    return kjt.sync()
  File "/dev/torchrec/sparse/jagged_tensor.py", line 2067, in sync
    self.length_per_key()
  File "/dev/torchrec/sparse/jagged_tensor.py", line 2281, in length_per_key
    _length_per_key = _maybe_compute_length_per_key(
  File "/dev/torchrec/sparse/jagged_tensor.py", line 1192, in _maybe_compute_length_per_key
    _length_per_key_from_stride_per_key(lengths, stride_per_key)
  File "/dev/torchrec/sparse/jagged_tensor.py", line 1144, in _length_per_key_from_stride_per_key
    if _use_segment_sum_csr(stride_per_key):
  File "/dev/torchrec/sparse/jagged_tensor.py", line 1131, in _use_segment_sum_csr
    elements_per_segment = sum(stride_per_key) / len(stride_per_key)
ZeroDivisionError: division by zero
```
* the complaint is that `stride_per_key` is an empty list, which comes from the following function call:
```
        stride_per_key = _maybe_compute_stride_per_key(
            self._stride_per_key,
            self._stride_per_key_per_rank,
            self.stride(),
            self._keys,
        )
```
* the only place this `stride_per_key` could be empty is when `stride_per_key_per_rank.dim() != 2`:
```
def _maybe_compute_stride_per_key(
    stride_per_key: Optional[List[int]],
    stride_per_key_per_rank: Optional[torch.IntTensor],
    stride: Optional[int],
    keys: List[str],
) -> Optional[List[int]]:
    if stride_per_key is not None:
        return stride_per_key
    elif stride_per_key_per_rank is not None:
        if stride_per_key_per_rank.dim() != 2:
            # after permute the kjt could be empty
            return []
        rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
            pt2_checks_all_is_size(rt)
        return rt
    elif stride is not None:
        return [stride] * len(keys)
    else:
        return None
```
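The empty-list branch above is what ultimately trips the `ZeroDivisionError`. As a minimal pure-Python sketch of the same control flow (plain lists stand in for the 2-D `torch.IntTensor`, and the function name here is just illustrative):

```python
from typing import List, Optional

def maybe_compute_stride_per_key(
    stride_per_key: Optional[List[int]],
    stride_per_key_per_rank: Optional[List[List[int]]],  # stands in for the 2-D IntTensor
    stride: Optional[int],
    keys: List[str],
) -> Optional[List[int]]:
    # Simplified sketch of the branching above, without torch.
    if stride_per_key is not None:
        return stride_per_key
    elif stride_per_key_per_rank is not None:
        if not all(isinstance(r, list) for r in stride_per_key_per_rank):
            # stands in for stride_per_key_per_rank.dim() != 2
            return []
        # sum(dim=1): total stride per key across ranks
        return [sum(r) for r in stride_per_key_per_rank]
    elif stride is not None:
        return [stride] * len(keys)
    return None

# Normal case: per-key strides are summed across ranks.
print(maybe_compute_stride_per_key(None, [[2, 3], [1, 4]], None, ["f1", "f2"]))  # [5, 5]

# Degenerate case: a non-2-D input yields [], and a downstream
# sum(stride_per_key) / len(stride_per_key) then divides by zero.
print(maybe_compute_stride_per_key(None, [1, 2], None, ["f1"]))  # []
```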

# the main change from D74366343 is how `stride_per_key_per_rank` is staggered in `dist_init`:

* baseline
```
            if stagger > 1:
                stride_per_key_per_rank_stagger: List[List[int]] = []
                local_world_size = num_workers // stagger
                for i in range(len(keys)):
                    stride_per_rank_stagger: List[int] = []
                    for j in range(local_world_size):
                        stride_per_rank_stagger.extend(
                            stride_per_key_per_rank[i][j::local_world_size]
                        )
                    stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
                stride_per_key_per_rank = stride_per_key_per_rank_stagger
```
* D76875546 (correct, this diff)
```
            if stagger > 1:
                indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1)
                stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
```
* D74366343 (incorrect, reverted)
```
            if stagger > 1:
                local_world_size = num_workers // stagger
                indices = [
                    list(range(i, num_workers, local_world_size))
                    for i in range(local_world_size)
                ]
                stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
```
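The tensor version in D76875546 reproduces the baseline's stagger permutation; the reverted D74366343 version indexed with a *nested* Python list, and advanced indexing with a 2-D index expands `stride_per_key_per_rank[:, indices]` into a 3-D result, which then hits the `dim() != 2` branch and yields the empty `stride_per_key` seen in the trace. A small pure-Python check of the permutation (made-up sizes and values; `indices` built by hand to mirror `torch.arange(num_workers).view(stagger, -1).T.reshape(-1)`):

```python
# Hypothetical sizes: 6 ranks, stagger of 2, so local_world_size == 3.
num_workers, stagger = 6, 2
local_world_size = num_workers // stagger

# One key's per-rank strides (made-up values).
row = [10, 11, 12, 13, 14, 15]

# Baseline: concatenate the slices row[j::local_world_size] for each j.
baseline = []
for j in range(local_world_size):
    baseline.extend(row[j::local_world_size])

# Equivalent flat index permutation, mirroring
# torch.arange(num_workers).view(stagger, -1).T.reshape(-1):
indices = [
    j + k * local_world_size
    for j in range(local_world_size)
    for k in range(stagger)
]
staggered = [row[i] for i in indices]

print(indices)                # [0, 3, 1, 4, 2, 5]
print(baseline == staggered)  # True
```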

Rollback Plan:

Differential Revision: D76875546

@pytorch-bot added the ci-no-td label Jun 18, 2025
@facebook-github-bot added the CLA Signed label Jun 18, 2025
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D76903646

Summary:
Pull Request resolved: meta-pytorch#2587

# context
* Previously for a KJT, only the following fields and `_keys` are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly.
```
    _fields = [
        "_values",
        "_weights",
        "_lengths",
        "_offsets",
    ]
```
* Particularly, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`:
```
def _maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max([sum(s) for s in stride_per_key_per_rank])
        elif offsets is not None and offsets.numel() > 0:
            stride = (offsets.numel() - 1) // len(keys)
        elif lengths is not None:
            stride = lengths.numel() // len(keys)
        else:
            stride = 0
    return stride
```
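The fallback order above can be exercised with a pure-Python transliteration (tensors replaced by plain lists, so `numel()` becomes `len()`; the name is just illustrative):

```python
from typing import List, Optional

def maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[List[int]],
    offsets: Optional[List[int]],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    # Same derivation order as _maybe_compute_stride_kjt above.
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max(sum(s) for s in stride_per_key_per_rank)
        elif offsets is not None and len(offsets) > 0:
            stride = (len(offsets) - 1) // len(keys)
        elif lengths is not None:
            stride = len(lengths) // len(keys)
        else:
            stride = 0
    return stride

# 2 keys, offsets of length 9 -> stride (batch size) of (9 - 1) // 2 = 4.
print(maybe_compute_stride_kjt(["f1", "f2"], None, None, [0] * 9, None))  # 4
# Variable-batch case: the per-key-per-rank strides take precedence.
print(maybe_compute_stride_kjt(["f1", "f2"], None, None, None, [[2, 3], [4, 2]]))  # 6
```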
* The previously stored pytree flatten specs are enough when the `batch_size` is static; however, this no longer holds in a variable-batch-size scenario, where the `stride_per_key_per_rank` is not `None`.
* An example is `dedup_ebc`, where the actual batch size is variable (depending on the dedup data), but the output of the EBC should always use the **true** (static) `stride`.
* During ir_export, the output shape is calculated from the `kjt.stride()` function, which would be incorrect if the pytree specs only contain the `keys`.
* This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value.
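The failure mode being fixed can be illustrated with a toy container (a hypothetical `ToyKJT`, no torch involved): when the flatten spec carries only the tensor-like fields, a flatten/unflatten round-trip silently recomputes `stride` from `lengths`, losing an explicitly set variable-batch value:

```python
class ToyKJT:
    """Toy stand-in for a KJT: stride is derived unless given explicitly."""
    def __init__(self, keys, lengths, stride=None):
        self.keys = keys
        self.lengths = lengths
        # Derived from lengths unless set explicitly (the variable-batch case).
        self.stride = stride if stride is not None else len(lengths) // len(keys)

def flatten(t):
    # Old-style spec: children are the tensor-like fields, context is keys only.
    return [t.lengths], t.keys

def unflatten(children, keys):
    # stride is silently recomputed here, because the spec never stored it.
    return ToyKJT(keys, children[0])

kjt = ToyKJT(["f1", "f2"], [1, 2, 0, 3], stride=8)  # true batch size is 8
rebuilt = unflatten(*flatten(kjt))
print(kjt.stride, rebuilt.stride)  # 8 2  <- the explicit stride is lost
```

Storing `stride` in the spec, as this diff does for KJT, makes the round-trip (and hence a fakified KJT) preserve the true value.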

Differential Revision: D66400821

Reviewed By: PaulZhang12
@TroyGarden TroyGarden changed the title Back out "Revert D74366343" [TorchRec] fix stride_per_key_per_rank in stagger scenario in D74366343 Jun 18, 2025
…torch#3111)

Summary:
Pull Request resolved: meta-pytorch#3111

Differential Revision: D76903646

@TroyGarden TroyGarden deleted the export-D76903646 branch June 19, 2025 07:19