Describe the bug
Hi torchrec team, I found that if the keys of an input bag are spread across multiple devices, the mean pooling result is incorrect. The root cause is that fbgemm divides the embedding results by the local bag size, while the output_dist of torchrec performs a SUM reduce-scatter (RW sharding), so the per-rank partial means are summed instead of combining into a global mean.
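To make the arithmetic concrete, here is a minimal single-process sketch (plain torch, no torchrec; the values mirror the repro script below) of how a local mean followed by a SUM reduce-scatter over-counts:

import torch

# One bag of 3 ids, every embedding row initialized to 2.0 (as in the repro below).
# RW sharding places ids [0, 1] on rank 0 and id [2] on rank 1.
emb_rank0 = torch.full((2, 4), 2.0)  # local rows looked up on rank 0
emb_rank1 = torch.full((1, 4), 2.0)  # local rows looked up on rank 1

# fbgemm's MEAN pooling divides by the *local* bag size on each rank...
local_mean_rank0 = emb_rank0.sum(dim=0) / 2  # (2 + 2) / 2 = 2
local_mean_rank1 = emb_rank1.sum(dim=0) / 1  # 2 / 1 = 2

# ...and torchrec's output_dist then SUM-reduce-scatters the partial results.
observed = local_mean_rank0 + local_mean_rank1
print(observed)  # tensor([4., 4., 4., 4.])  -- what the repro prints

# A correct global mean would divide the global sum by the global bag size:
expected = (emb_rank0.sum(dim=0) + emb_rank1.sum(dim=0)) / 3
print(expected)  # tensor([2., 2., 2., 2.])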
Steps to reproduce
- script:
import os
import sys
import torch
import torchrec
import torch.distributed as dist
from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology, ParameterConstraints
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingType,
)
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
def init_fn(x: torch.Tensor):
    with torch.no_grad():
        x.fill_(2.0)
ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
            pooling=torchrec.PoolingType.MEAN,
        ),
    ]
)
sharding_types = [ShardingType.ROW_WISE.value]
constraints = {"product_table": ParameterConstraints(sharding_types=sharding_types)}
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
planner = EmbeddingShardingPlanner(
    constraints=constraints,
)
sharders = [EmbeddingBagCollectionSharder()]
plan = planner.collective_plan(ebc, sharders, pg=dist.GroupMember.WORLD)
apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)
model = torchrec.distributed.DistributedModelParallel(ebc, sharders=sharders, device=torch.device("cuda"), plan=plan)
mb = torchrec.KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([0, 1, 2]).cuda(),  # with RW sharding, ids [0, 1] live on rank 0 and [2] on rank 1
    lengths=torch.tensor([3], dtype=torch.int64).cuda(),  # a single bag of size 3
)
ret = model(mb)  # returns an awaitable
product = ret.to_dict()["product"]  # implicitly calls awaitable.wait(); note: EmbeddingCollection output does not have a to_dict attribute
if local_rank == 0:
    print(model.plan)
    print(f"product {product}")  # result is 4: (2 + 2) / 2 + 2 / 1, i.e. the local means are summed, instead of the expected 2
- cmd to run:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun \
    --nnodes 1 \
    --nproc_per_node 2 \
    ./mean_polling.py
- Output result:
module:
param | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | ------
product_table | row_wise | fused | [0, 1]
param | shard offsets | shard sizes | placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0] | [2, 4] | rank:0/cuda:0
product_table | [2, 0] | [2, 4] | rank:1/cuda:1
product tensor([[4., 4., 4., 4.]], device='cuda:0', grad_fn=<SplitWithSizesBackward0>)
- Expected result:
product tensor([[2., 2., 2., 2.]])
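Until the output_dist accounts for MEAN pooling under RW sharding, one possible workaround (only a sketch for this repro's single-key, single-bag input, not an official fix) is to configure the table with torchrec.PoolingType.SUM and divide by the global bag lengths outside the model, since SUM pooling composes correctly with the SUM reduce-scatter:

# Assumes the same repro script, but with pooling=torchrec.PoolingType.SUM in the EmbeddingBagConfig.
pooled_sum = model(mb).to_dict()["product"]           # SUM-pooled output, correct under RW sharding -> [[6., 6., 6., 6.]]
bag_lengths = mb.lengths().view(-1, 1).clamp(min=1)   # global per-bag sizes from the KJT, here [[3]]
product_mean = pooled_sum / bag_lengths.to(pooled_sum.dtype)  # [[2., 2., 2., 2.]]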