Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions torchrec/distributed/test_utils/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import torch.distributed as dist
import torch.nn as nn
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
PartiallyMaterializedTensor,
)
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.optim import (
_apply_optimizer_in_backward as apply_optimizer_in_backward,
Expand Down Expand Up @@ -610,12 +613,16 @@ def copy_state_dict(

if isinstance(tensor, ShardedTensor):
for local_shard in tensor.local_shards():
# Tensors like `PartiallyMaterializedTensor` do not provide
# `ndim` property, so use shape length here as a workaround
ndim = len(local_shard.tensor.shape)
assert (
global_tensor.ndim == local_shard.tensor.ndim
), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {local_shard.tensor.ndim}"
global_tensor.ndim == ndim
), f"global_tensor.ndim: {global_tensor.ndim}, local_shard.tensor.ndim: {ndim}"
assert (
global_tensor.dtype == local_shard.tensor.dtype
), f"global tensor dtype: {global_tensor.dtype}, local tensor dtype: {local_shard.tensor.dtype}"

shard_meta = local_shard.metadata
t = global_tensor.detach()
if t.ndim == 1:
Expand All @@ -632,7 +639,13 @@ def copy_state_dict(
]
else:
raise ValueError("Tensors with ndim > 2 are not supported")
local_shard.tensor.copy_(t)

if isinstance(local_shard.tensor, PartiallyMaterializedTensor):
local_shard.tensor.wrapped.set_range(
0, 0, t.size(0), t.to(device="cpu")
)
else:
local_shard.tensor.copy_(t)
elif isinstance(tensor, DTensor):
for local_shard, global_offset in zip(
tensor.to_local().local_shards(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import cast, List, OrderedDict, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from hypothesis import given, settings, strategies as st, Verbosity
Expand All @@ -27,13 +28,15 @@
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import ParameterConstraints
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.test_utils.test_model_parallel_base import (
ModelParallelSingleRankBase,
)
from torchrec.distributed.test_utils.test_sharding import (
copy_state_dict,
create_test_sharder,
SharderType,
sharding_single_rank_test_single_process,
)
from torchrec.distributed.tests.test_sequence_model import (
TestEmbeddingCollectionSharder,
Expand All @@ -45,6 +48,7 @@
EmbeddingBagConfig,
EmbeddingConfig,
)
from torchrec.optim import RowWiseAdagrad


def _load_split_embedding_weights(
Expand Down Expand Up @@ -540,6 +544,72 @@ def test_ssd_mixed_kernels(
self._eval_models(m1, m2, batch, is_deterministic=is_deterministic)
self._compare_models(m1, m2, is_deterministic=is_deterministic)

@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
# pyre-fixme[56]
@given(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
dtype=st.sampled_from([DataType.FP32, DataType.FP16]),
)
@settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
def test_ssd_mixed_kernels_with_vbe(
self,
sharding_type: str,
dtype: DataType,
) -> None:
self._set_table_weights_precision(dtype)
fused_params = {
"prefetch_pipeline": True,
}
constraints = {
table.name: ParameterConstraints(
min_partition=4,
compute_kernels=(
[EmbeddingComputeKernel.FUSED.value]
if i % 2 == 0
else [EmbeddingComputeKernel.KEY_VALUE.value]
),
sharding_types=[sharding_type],
)
for i, table in enumerate(self.tables)
}
optimizer_config = (RowWiseAdagrad, {"lr": 0.001, "eps": 0.001})
pg = dist.GroupMember.WORLD

assert pg is not None, "Process group is not initialized"
sharding_single_rank_test_single_process(
pg=pg,
device=self.device,
rank=0,
world_size=1,
# pyre-fixme[6]: The intake type should be `type[TestSparseNNBase]`
model_class=TestSparseNN,
embedding_groups={},
tables=self.tables,
# pyre-fixme[6]
sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
optim=EmbOptimType.EXACT_SGD,
# The optimizer config here will overwrite the SGD optimizer above
apply_optimizer_in_backward_config={
"embedding_bags": optimizer_config,
"embeddings": optimizer_config,
},
constraints=constraints,
variable_batch_per_feature=True,
)

@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
Expand Down