From dd628b5e99aa3a8809bcb1b3a4b9dcc2e5ddcece Mon Sep 17 00:00:00 2001
From: Shuangping Liu
Date: Thu, 12 Jun 2025 23:38:08 -0700
Subject: [PATCH] Add unit test for SSD TBE with VBE input (#3086)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3086

Add a new unit test in [`test_model_parallel_nccl_ssd_single_gpu.py`](https://www.internalfb.com/code/fbsource/[5f477259031a]/fbcode/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py) for SSD TBE with VBE input.

### Context
* This test is a prerequisite for testing the upcoming FBGEMM and TorchRec changes that merge VBE output.
* For SSD TBE, the tensor wrapped in a shard is a [`PartiallyMaterializedTensor`](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py) (PMT), which requires special handling when copying the state dict from an unsharded tensor (see the sketch after this list). Specifically:
  - It lacks certain properties such as `ndim`.
  - Its `copy_` method is a no-op; writes must go through the [wrapped C++ object](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp?lines=417) of the PMT.
  - Only the `ROW_WISE`, `TABLE_WISE`, and `TABLE_ROW_WISE` sharding types are supported.
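For illustration, here is a minimal sketch of the PMT-aware copy path that `copy_state_dict` adopts in the diff below. It assumes `fbgemm_gpu` is installed; the helper name `_copy_into_local_shard` is hypothetical and only mirrors the branching and the `set_range(...)` call added in `test_sharding.py`:

```python
from typing import Union

import torch
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
    PartiallyMaterializedTensor,
)


def _copy_into_local_shard(
    shard_tensor: Union[torch.Tensor, PartiallyMaterializedTensor],
    t: torch.Tensor,
) -> None:
    """Hypothetical helper: write the already-sliced global rows `t` into one local shard."""
    if isinstance(shard_tensor, PartiallyMaterializedTensor):
        # PMT's `copy_` is a no-op, so write through the wrapped C++ object.
        # The arguments mirror the `set_range(0, 0, t.size(0), ...)` call in the
        # diff below; the source tensor is moved to CPU first.
        shard_tensor.wrapped.set_range(0, 0, t.size(0), t.to(device="cpu"))
    else:
        # A regular dense shard tensor supports an in-place copy.
        shard_tensor.copy_(t)
```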
NOTE: SSD TBE only supports the `RowWiseAdagrad` optimizer. For **FP16**, the learning rate and eps need to be chosen carefully to avoid numerical instabilities in the unsharded model. Here we use `lr = 0.001` and `eps = 0.001` to pass the test.

Reviewed By: TroyGarden

Differential Revision: D76455104
---
 .../distributed/test_utils/test_sharding.py   | 19 ++++-
 ...test_model_parallel_nccl_ssd_single_gpu.py | 70 +++++++++++++++++++
 2 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index e037ed096..8c8f11d35 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -16,6 +16,9 @@
 import torch.distributed as dist
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
+from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
+    PartiallyMaterializedTensor,
+)
 from torch.distributed._tensor import DeviceMesh, DTensor
 from torch.distributed.optim import (
     _apply_optimizer_in_backward as apply_optimizer_in_backward,
@@ -610,12 +613,16 @@ def copy_state_dict(

     if isinstance(tensor, ShardedTensor):
         for local_shard in tensor.local_shards():
+            # Tensors like `PartiallyMaterializedTensor` do not provide
+            # `ndim` property, so use shape length here as a workaround
+            ndim = len(local_shard.tensor.shape)
             assert (
-                global_tensor.ndim == local_shard.tensor.ndim
-            ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {local_shard.tensor.ndim}"
+                global_tensor.ndim == ndim
+            ), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {ndim}"
             assert (
                 global_tensor.dtype == local_shard.tensor.dtype
             ), f"global tensor dtype: {global_tensor.dtype}, local tensor dtype: {local_shard.tensor.dtype}"
+
             shard_meta = local_shard.metadata
             t = global_tensor.detach()
             if t.ndim == 1:
@@ -632,7 +639,13 @@
                 ]
             else:
                 raise ValueError("Tensors with ndim > 2 are not supported")
-            local_shard.tensor.copy_(t)
+
+            if isinstance(local_shard.tensor, PartiallyMaterializedTensor):
+                local_shard.tensor.wrapped.set_range(
+                    0, 0, t.size(0), t.to(device="cpu")
+                )
+            else:
+                local_shard.tensor.copy_(t)
     elif isinstance(tensor, DTensor):
         for local_shard, global_offset in zip(
             tensor.to_local().local_shards(),
diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
index a9a77fcaf..6b9ee360b 100644
--- a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
+++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
@@ -12,6 +12,7 @@
 from typing import cast, List, OrderedDict, Union

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
 from hypothesis import given, settings, strategies as st, Verbosity
@@ -27,6 +28,7 @@
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner import ParameterConstraints
+from torchrec.distributed.test_utils.test_model import TestSparseNN
 from torchrec.distributed.test_utils.test_model_parallel_base import (
     ModelParallelSingleRankBase,
 )
@@ -34,6 +36,7 @@
     copy_state_dict,
     create_test_sharder,
     SharderType,
+    sharding_single_rank_test_single_process,
 )
 from torchrec.distributed.tests.test_sequence_model import (
     TestEmbeddingCollectionSharder,
@@ -45,6 +48,7 @@
     EmbeddingBagConfig,
     EmbeddingConfig,
 )
+from torchrec.optim import RowWiseAdagrad


 def _load_split_embedding_weights(
@@ -540,6 +544,72 @@ def test_ssd_mixed_kernels(
         self._eval_models(m1, m2, batch, is_deterministic=is_deterministic)
         self._compare_models(m1, m2, is_deterministic=is_deterministic)

+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
+            ]
+        ),
+        dtype=st.sampled_from([DataType.FP32, DataType.FP16]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
+    def test_ssd_mixed_kernels_with_vbe(
+        self,
+        sharding_type: str,
+        dtype: DataType,
+    ) -> None:
+        self._set_table_weights_precision(dtype)
+        fused_params = {
+            "prefetch_pipeline": True,
+        }
+        constraints = {
+            table.name: ParameterConstraints(
+                min_partition=4,
+                compute_kernels=(
+                    [EmbeddingComputeKernel.FUSED.value]
+                    if i % 2 == 0
+                    else [EmbeddingComputeKernel.KEY_VALUE.value]
+                ),
+                sharding_types=[sharding_type],
+            )
+            for i, table in enumerate(self.tables)
+        }
+        optimizer_config = (RowWiseAdagrad, {"lr": 0.001, "eps": 0.001})
+        pg = dist.GroupMember.WORLD
+
+        assert pg is not None, "Process group is not initialized"
+        sharding_single_rank_test_single_process(
+            pg=pg,
+            device=self.device,
+            rank=0,
+            world_size=1,
+            # pyre-fixme[6]: The intake type should be `type[TestSparseNNBase]`
+            model_class=TestSparseNN,
+            embedding_groups={},
+            tables=self.tables,
+            # pyre-fixme[6]
+            sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
+            optim=EmbOptimType.EXACT_SGD,
+            # The optimizer config here will overwrite the SGD optimizer above
+            apply_optimizer_in_backward_config={
+                "embedding_bags": optimizer_config,
+                "embeddings": optimizer_config,
+            },
+            constraints=constraints,
+            variable_batch_per_feature=True,
+        )
+
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",