[TorchRec] fix stride_per_key_per_rank in stagger scenario in D74366343 #3112
Closed
This pull request was exported from Phabricator. Differential Revision: D76903646
Summary:
Pull Request resolved: meta-pytorch#2587

# context
* Previously, only the following fields plus `_keys` of a KJT were stored in the pytree flatten specs; all other arguments/parameters are derived from them.
  ```
  _fields = [
      "_values",
      "_weights",
      "_lengths",
      "_offsets",
  ]
  ```
* In particular, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`:
  ```
  def _maybe_compute_stride_kjt(
      keys: List[str],
      stride: Optional[int],
      lengths: Optional[torch.Tensor],
      offsets: Optional[torch.Tensor],
      stride_per_key_per_rank: Optional[List[List[int]]],
  ) -> int:
      if stride is None:
          if len(keys) == 0:
              stride = 0
          elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
              stride = max([sum(s) for s in stride_per_key_per_rank])
          elif offsets is not None and offsets.numel() > 0:
              stride = (offsets.numel() - 1) // len(keys)
          elif lengths is not None:
              stride = lengths.numel() // len(keys)
          else:
              stride = 0
      return stride
  ```
* The previously stored pytree flatten specs are sufficient when the `batch_size` is static; this no longer holds in a variable-batch-size scenario, where `stride_per_key_per_rank` is not `None`.
* An example is `dedup_ebc`, where the actual batch size varies (depending on the dedup data), but the output of the EBC should always use the **true** `stride` (static).
* During ir_export, the output shape is calculated from `kjt.stride()`, which would be incorrect if the pytree specs only contain the `keys`.
* This diff adds the `stride` to the KJT pytree flatten/unflatten functions so that a fakified KJT has the correct stride value.

Differential Revision: D66400821

Reviewed By: PaulZhang12
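To make the variable-batch issue concrete, here is a minimal, self-contained sketch. It is not TorchRec's actual implementation: the helper name `derive_stride` and all numbers are made up, and the `offsets` branch is omitted. It shows how the derived stride diverges from the true batch size once only `lengths` survives the flatten/unflatten round trip:

```
from typing import List, Optional

import torch


def derive_stride(
    keys: List[str],
    lengths: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    # Mirrors the branch order of _maybe_compute_stride_kjt quoted above
    # (hypothetical simplification; offsets branch left out for brevity).
    if len(keys) == 0:
        return 0
    if stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
        return max(sum(s) for s in stride_per_key_per_rank)
    if lengths is not None:
        return lengths.numel() // len(keys)
    return 0


keys = ["f1", "f2"]
# Variable batch: "f1" carries 3 samples and "f2" carries 1 (e.g. after dedup).
stride_per_key_per_rank = [[3], [1]]
lengths = torch.tensor([1, 2, 1, 4])  # 3 + 1 = 4 length entries in total

# With the per-key-per-rank info available, the derived stride is 3 ...
print(derive_stride(keys, lengths, stride_per_key_per_rank))  # 3
# ... but if only lengths survive the flatten/unflatten round trip, the
# reconstructed KJT would report 4 // 2 = 2, hence the need to store the
# stride explicitly in the flatten specs.
print(derive_stride(keys, lengths, None))  # 2
```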
…torch#3111)

Summary:
Pull Request resolved: meta-pytorch#3111

# context
* The original diff D74366343 broke a cogwheel test and was reverted.
* The error stack P1844048578 is shown below:
  ```
  File "/dev/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/dev/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dev/torchrec/distributed/train_pipeline/runtime_forwards.py", line 84, in __call__
    data = request.wait()
  File "/dev/torchrec/distributed/types.py", line 334, in wait
    ret: W = self._wait_impl()
  File "/dev/torchrec/distributed/embedding_sharding.py", line 655, in _wait_impl
    kjts.append(w.wait())
  File "/dev/torchrec/distributed/types.py", line 334, in wait
    ret: W = self._wait_impl()
  File "/dev/torchrec/distributed/dist_data.py", line 426, in _wait_impl
    return type(self._input).dist_init(
  File "/dev/torchrec/sparse/jagged_tensor.py", line 2993, in dist_init
    return kjt.sync()
  File "/dev/torchrec/sparse/jagged_tensor.py", line 2067, in sync
    self.length_per_key()
  File "/dev/torchrec/sparse/jagged_tensor.py", line 2281, in length_per_key
    _length_per_key = _maybe_compute_length_per_key(
  File "/dev/torchrec/sparse/jagged_tensor.py", line 1192, in _maybe_compute_length_per_key
    _length_per_key_from_stride_per_key(lengths, stride_per_key)
  File "/dev/torchrec/sparse/jagged_tensor.py", line 1144, in _length_per_key_from_stride_per_key
    if _use_segment_sum_csr(stride_per_key):
  File "/dev/torchrec/sparse/jagged_tensor.py", line 1131, in _use_segment_sum_csr
    elements_per_segment = sum(stride_per_key) / len(stride_per_key)
  ZeroDivisionError: division by zero
  ```
* The complaint is that `stride_per_key` is an empty list, which comes from the following call:
  ```
  stride_per_key = _maybe_compute_stride_per_key(
      self._stride_per_key,
      self._stride_per_key_per_rank,
      self.stride(),
      self._keys,
  )
  ```
* The only place this `stride_per_key` can become empty is when `stride_per_key_per_rank.dim() != 2`:
  ```
  def _maybe_compute_stride_per_key(
      stride_per_key: Optional[List[int]],
      stride_per_key_per_rank: Optional[torch.IntTensor],
      stride: Optional[int],
      keys: List[str],
  ) -> Optional[List[int]]:
      if stride_per_key is not None:
          return stride_per_key
      elif stride_per_key_per_rank is not None:
          if stride_per_key_per_rank.dim() != 2:
              # after permute the kjt could be empty
              return []
          rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
          if not torch.jit.is_scripting() and is_torchdynamo_compiling():
              pt2_checks_all_is_size(rt)
          return rt
      elif stride is not None:
          return [stride] * len(keys)
      else:
          return None
  ```

# the main change from D74366343: how `stride_per_key_per_rank` is reordered in `dist_init`
* baseline
  ```
  if stagger > 1:
      stride_per_key_per_rank_stagger: List[List[int]] = []
      local_world_size = num_workers // stagger
      for i in range(len(keys)):
          stride_per_rank_stagger: List[int] = []
          for j in range(local_world_size):
              stride_per_rank_stagger.extend(
                  stride_per_key_per_rank[i][j::local_world_size]
              )
          stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
      stride_per_key_per_rank = stride_per_key_per_rank_stagger
  ```
* D76875546 (correct, this diff)
  ```
  if stagger > 1:
      indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1)
      stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
  ```
* D74366343 (incorrect, reverted)
  ```
  if stagger > 1:
      local_world_size = num_workers // stagger
      indices = [
          list(range(i, num_workers, local_world_size))
          for i in range(local_world_size)
      ]
      stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
  ```

Differential Revision: D76903646
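For illustration, the following standalone sketch (hypothetical `num_workers`, `stagger`, and stride values; not the actual `dist_init` code) compares the three reorderings above. It shows why the reverted version breaks: indexing a 2-D tensor with a nested Python list triggers advanced indexing and yields a 3-D result, so `_maybe_compute_stride_per_key` sees `dim() != 2`, returns `[]`, and `length_per_key` hits the ZeroDivisionError from the stack trace.

```
import torch

num_workers, stagger = 4, 2  # hypothetical values chosen for illustration
local_world_size = num_workers // stagger

# 2 keys x 4 ranks of per-rank strides (made-up numbers).
spkpr = torch.tensor([[3, 1, 2, 4], [5, 6, 7, 8]])

# Baseline: nested-loop reordering over Python lists.
baseline = []
for row in spkpr.tolist():
    reordered = []
    for j in range(local_world_size):
        reordered.extend(row[j::local_world_size])
    baseline.append(reordered)
print(baseline)  # [[3, 2, 1, 4], [5, 7, 6, 8]]

# D76875546 (correct): a flat 1-D index tensor keeps the result 2-D.
flat_idx = torch.arange(num_workers).view(stagger, -1).T.reshape(-1)
correct = spkpr[:, flat_idx]
print(correct.tolist(), correct.dim())  # [[3, 2, 1, 4], [5, 7, 6, 8]] 2

# D74366343 (reverted): a nested list of indices triggers advanced indexing
# and yields a 3-D tensor, so the downstream dim() != 2 check returns [] and
# length_per_key divides by zero.
nested_idx = [
    list(range(i, num_workers, local_world_size))
    for i in range(local_world_size)
]
wrong = spkpr[:, nested_idx]
print(wrong.shape)  # torch.Size([2, 2, 2])
```

Both the baseline loop and the flat-index version produce the same rank order; the fix simply expresses that order as a 1-D index so the tensor stays two-dimensional.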
Labels
ci-no-td
CLA Signed
fb-exported