[BUG] EBC Mean pooling division is not handled properly. #2362

@JacoCheung


Describe the bug
Hi torchrec team, I found that when the rows of an input bag are sharded across multiple devices, the mean-pooling result is incorrect. The root cause is that fbgemm divides the pooled embeddings by the *local* bag size on each rank, while torchrec's output_dist performs a SUM reduce-scatter (row-wise sharding), so the per-rank means are added together instead of being combined into a global mean.
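For illustration, here is a minimal single-process sketch of the arithmetic (the names `rank0_rows`, `rank1_rows`, etc. are made up for this example and are not torchrec API): each rank computes a mean over only its local rows, and the SUM reduce-scatter then adds the per-rank means.

```python
import torch

# Table filled with 2.0; one bag references rows [0, 1, 2].
# Under row-wise sharding, rows [0, 1] live on rank 0 and row [2] on rank 1.
rank0_rows = torch.full((2, 4), 2.0)  # local shard on rank 0
rank1_rows = torch.full((1, 4), 2.0)  # local shard on rank 1

# fbgemm divides by the *local* bag size on each rank ...
rank0_local_mean = rank0_rows.sum(dim=0) / rank0_rows.shape[0]  # (2 + 2) / 2 = 2
rank1_local_mean = rank1_rows.sum(dim=0) / rank1_rows.shape[0]  # 2 / 1 = 2

# ... and the output_dist SUM reduce-scatter adds the per-rank results.
actual = rank0_local_mean + rank1_local_mean                    # 4 (wrong)

# The correct global mean divides the summed rows by the *global* bag size.
expected = (rank0_rows.sum(dim=0) + rank1_rows.sum(dim=0)) / 3  # 2
print(actual, expected)  # tensor([4., 4., 4., 4.]) tensor([2., 2., 2., 2.])
```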

Steps to reproduce

  • Script:

import os

import torch
import torch.distributed as dist
import torchrec
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, ParameterConstraints
from torchrec.distributed.types import ShardingType
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

dist.init_process_group(backend="nccl")

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

def init_fn(x: torch.Tensor):
    # Fill every embedding row with 2.0 so the expected pooled value is easy to check.
    with torch.no_grad():
        x.fill_(2.0)

ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
            pooling=torchrec.PoolingType.MEAN,
        ),
    ]
)
sharding_types = [ShardingType.ROW_WISE.value]
constraints = {"product_table": ParameterConstraints(sharding_types=sharding_types)}
planner = EmbeddingShardingPlanner(constraints=constraints)
sharders = [EmbeddingBagCollectionSharder()]
plan = planner.collective_plan(ebc, sharders, pg=dist.GroupMember.WORLD)

apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

model = torchrec.distributed.DistributedModelParallel(
    ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
)

mb = torchrec.KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([0, 1, 2]).cuda(),  # rows [0, 1] live on rank 0, row [2] on rank 1
    lengths=torch.tensor([3], dtype=torch.int64).cuda(),  # a single bag of size 3
)
ret = model(mb)  # returns an awaitable
product = ret.to_dict()["product"]  # .to_dict() implicitly calls awaitable.wait()

if local_rank == 0:
    print(model.plan)
    print(f"product {product}")  # result is 4: (2+2)/2 + 2/1, instead of the expected 2
  • Command to run:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun \
  --nnodes 1 \
  --nproc_per_node 2 \
  ./mean_pooling.py
  • Output:
module: 

    param     | sharding type | compute kernel | ranks 
------------- | ------------- | -------------- | ------
product_table | row_wise      | fused          | [0, 1]

    param     | shard offsets | shard sizes |   placement  
------------- | ------------- | ----------- | -------------
product_table | [0, 0]        | [2, 4]      | rank:0/cuda:0
product_table | [2, 0]        | [2, 4]      | rank:1/cuda:1

product tensor([[4., 4., 4., 4.]], device='cuda:0', grad_fn=<SplitWithSizesBackward0>) 
  • Expected output: product tensor([[2., 2., 2., 2.]]) (see the unsharded reference check below)
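For comparison, evaluating the same table with the unsharded EmbeddingBagCollection on a single device returns the expected global mean. This is a minimal sketch, assuming no sharding or DistributedModelParallel is involved:

```python
import torch
import torchrec

def init_fn(x: torch.Tensor):
    with torch.no_grad():
        x.fill_(2.0)

# Same table configuration as in the reproduction script, but unsharded on CPU.
ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("cpu"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
            pooling=torchrec.PoolingType.MEAN,
        ),
    ],
)

# One bag of three rows, each filled with 2.0, so the global mean is 2.0 per dimension.
kjt = torchrec.KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([0, 1, 2]),
    lengths=torch.tensor([3], dtype=torch.int64),
)
print(ebc(kjt).to_dict()["product"])  # values are all 2.0, the expected global mean
```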
