49 changes: 33 additions & 16 deletions torchrec/distributed/embeddingbag.py
@@ -27,9 +27,6 @@

import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
DenseTableBatchedEmbeddingBagsCodegen,
)
from tensordict import TensorDict
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
@@ -61,6 +58,7 @@
get_largest_dims_from_sharding_plan_updates,
shards_all_to_all,
update_module_sharding_plan,
update_optimizer_state_post_resharding,
update_state_dict_post_resharding,
)
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
@@ -1535,7 +1533,7 @@ def update_shards(
return

current_state = self.state_dict()
# TODO: Save Optimizers
has_optimizer = len(self._optim._optims) > 0
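# When a fused optimizer is present, its state (e.g. momentum) is snapshotted
# below and redistributed through the same all-to-all as the weights.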

# TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
# TODO: Ensure lookup tensors are actually being deleted
@@ -1550,6 +1548,7 @@
max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
changed_sharding_params
)
old_optimizer_state = self._optim.state_dict() if has_optimizer else None

local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
module=self,
@@ -1560,16 +1559,7 @@
extend_shard_name=self.extend_shard_name,
max_dim_0=max_dim_0,
max_dim_1=max_dim_1,
)

current_state = update_state_dict_post_resharding(
state_dict=current_state,
ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
output_tensor=local_output_tensor,
new_sharding_params=changed_sharding_params,
curr_rank=dist.get_rank(),
extend_shard_name=self.extend_shard_name,
max_dim_0=max_dim_0,
optimizer_state=old_optimizer_state,
)

for name, param in changed_sharding_params.items():
@@ -1615,8 +1605,6 @@ def update_shards(
if env.process_group and dist.get_backend(env.process_group) != "fake":
self._initialize_torch_state(skip_registering=True)

self.load_state_dict(current_state)

# update optimizer
optims = []
for lookup in self._lookups:
@@ -1635,6 +1623,35 @@

self._optim: CombinedOptimizer = CombinedOptimizer(optims)

if has_optimizer:
split_index = len(local_output_tensor) // 2
local_weight_tensors = local_output_tensor[:split_index]
local_optimizer_tensors = local_output_tensor[split_index:]
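# Each resharded shard contributes max_dim_0 padded rows of weights plus
# max_dim_0 padded rows of optimizer state to the all-to-all output, so the
# received buffer splits evenly into a weight half and an optimizer half.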
# Modifies new_opt_state in place and returns it
optimizer_state = update_optimizer_state_post_resharding(
old_opt_state=old_optimizer_state, # pyre-ignore
new_opt_state=copy.deepcopy(self._optim.state_dict()),
ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
output_tensor=local_optimizer_tensors,
max_dim_0=max_dim_0,
)

self._optim.load_state_dict(optimizer_state)
else:
local_weight_tensors = local_output_tensor

current_state = update_state_dict_post_resharding(
state_dict=current_state,
ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
output_tensor=local_weight_tensors,
new_sharding_params=changed_sharding_params,
curr_rank=dist.get_rank(),
extend_shard_name=self.extend_shard_name,
max_dim_0=max_dim_0,
)

self.load_state_dict(current_state)

update_module_sharding_plan(self, changed_sharding_params)
return

5 changes: 4 additions & 1 deletion torchrec/distributed/model_parallel.py
@@ -687,7 +687,10 @@ def reshard(
self.device,
)

self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
# Need to use .module to maintain FQN consistency
self._optim: CombinedOptimizer = self._init_optim(
self._dmp_wrapped_module.module # pyre-ignore
)
self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
return sharded_module

105 changes: 102 additions & 3 deletions torchrec/distributed/sharding/dynamic_sharding.py
@@ -8,7 +8,7 @@
# pyre-strict

import copy
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
@@ -84,6 +84,7 @@ def shards_all_to_all(
max_dim_0: int,
max_dim_1: int,
extend_shard_name: Callable[[str], str] = lambda x: x,
optimizer_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None,
) -> Tuple[OrderedShardNamesWithSizes, torch.Tensor]:
"""
Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
@@ -121,14 +122,18 @@
# Module sharding plan is used to get the source ranks for each shard
assert hasattr(module, "module_sharding_plan")

has_optimizer = optimizer_state is not None

world_size = env.world_size
rank = dist.get_rank()
input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
output_splits_per_rank = [[0] * world_size for _ in range(world_size)]

output_tensor_tensor_count = 0
output_optimizer_tensor_count = 0
shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
local_table_to_opt_by_dst_rank = [[] for _ in range(world_size)]
for shard_name, param in changed_sharding_params.items():
sharded_t = state_dict[extend_shard_name(shard_name)]
assert param.ranks is not None
@@ -142,24 +147,47 @@
# index needed to distinguish between multiple shards
# within the same shardedTensor for each table
for i in range(len(src_ranks)):

# 1 to 1 mapping from src to dst
dst_rank = dst_ranks[i]
src_rank = src_ranks[i]

shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
input_splits_per_rank[src_rank][dst_rank] += max_dim_0
output_splits_per_rank[dst_rank][src_rank] += max_dim_0
if has_optimizer:
input_splits_per_rank[src_rank][dst_rank] += max_dim_0
output_splits_per_rank[dst_rank][src_rank] += max_dim_0

# If sending from current rank
if src_rank == rank:
if has_optimizer:
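# Look up the momentum tensor for this shard in the optimizer state dict
# (keyed by the extended shard name plus the ".momentum1" suffix) and pad it
# to the same (max_dim_0, max_dim_1) footprint as the weight shard.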
# pyre-ignore
local_optimizer = optimizer_state["state"][
extend_shard_name(shard_name)
][tmp_momentum_extender(shard_name)].local_shards()
assert len(local_optimizer) == 1
padded_local_optimizer = pad_tensor_to_max_dims(
local_optimizer[0].tensor, max_dim_0, max_dim_1
)
local_table_to_opt_by_dst_rank[dst_rank].append(
padded_local_optimizer
)
local_shards = sharded_t.local_shards()
assert len(local_shards) == 1
cur_t = pad_tensor_to_max_dims(
sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
local_shards[0].tensor, max_dim_0, max_dim_1
)
local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)

# If receiving on current rank
if dst_rank == rank:
shard_names_to_lengths_by_src_rank[src_rank].append(
(shard_name, shard_size)
)
output_tensor_tensor_count += max_dim_0
if has_optimizer:
output_optimizer_tensor_count += max_dim_0

local_input_splits = input_splits_per_rank[rank]
local_output_splits = output_splits_per_rank[rank]
@@ -175,9 +203,23 @@
dim=0,
)

for sub_l in local_table_to_opt_by_dst_rank:
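# Append the padded optimizer-state rows after the weight rows in the
# all-to-all input buffer.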
for shard_info in sub_l:
local_input_tensor = torch.cat(
(
local_input_tensor,
shard_info,
),
dim=0,
)

max_embedding_size = max_dim_1
local_output_tensor = torch.empty(
[output_tensor_tensor_count, max_embedding_size], device=device
[
output_tensor_tensor_count + output_optimizer_tensor_count,
max_embedding_size,
],
device=device,
)

assert sum(local_output_splits) == len(local_output_tensor)
@@ -277,6 +319,50 @@ def update_state_dict_post_resharding(
return state_dict


def update_optimizer_state_post_resharding(
old_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
new_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
ordered_shard_names_and_lengths: OrderedShardNamesWithSizes,
output_tensor: torch.Tensor,
max_dim_0: int,
) -> Dict[str, Dict[str, Dict[str, ShardedTensor]]]:
new_opt_state_state = new_opt_state["state"]
old_opt_state_state = old_opt_state["state"]

# Remove padding and store tensors by shard name
slice_index = 0
shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
for shard_name, shard_size in ordered_shard_names_and_lengths:
end_slice_index = slice_index + max_dim_0
cur_t = output_tensor[slice_index:end_slice_index]
cur_t = pad_tensor_to_max_dims(
cur_t, shard_size[0], shard_size[1], remove_padding=True
)
shard_name_to_local_output_tensor[shard_name] = cur_t
slice_index = end_slice_index

for extended_shard_name, item in new_opt_state_state.items():
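# Shards whose optimizer state already exists on this rank are carried over
# directly; newly received shards pick up the unpadded momentum rows
# extracted above.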
if extended_shard_name in old_opt_state_state:
new_opt_state_state[extended_shard_name] = old_opt_state_state[
extended_shard_name
]
else:
shard_name = extract_shard_name(extended_shard_name)
momentum_name = tmp_momentum_extender(shard_name)
sharded_t = item[momentum_name]
assert len(sharded_t._local_shards) == 1
# TODO: support multiple shards in CW sharding
sharded_t._local_shards = [
Shard(
tensor=shard_name_to_local_output_tensor[shard_name],
metadata=shard.metadata,
)
for shard in sharded_t._local_shards
]

return new_opt_state


def update_module_sharding_plan(
module: ShardedModule[Any, Any, Any, Any], # pyre-ignore
changed_sharding_params: Dict[str, ParameterSharding],
@@ -388,3 +474,16 @@ def output_sharding_plan_delta(
if v.ranks != old_plan[k].ranks
}
)


"""
Utilities for accessing optimizer state.
"""


def tmp_momentum_extender(name: str) -> str:
return name + ".momentum1"


def extract_shard_name(name: str) -> str:
return name.split(".")[-2]
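

# Illustrative examples, assuming the momentum naming used elsewhere in this
# diff (rowwise-Adagrad style "momentum1" state):
#   tmp_momentum_extender("table_0")                               -> "table_0.momentum1"
#   extract_shard_name("sparse.ebc.embedding_bags.table_0.weight") -> "table_0"
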
36 changes: 36 additions & 0 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -541,6 +541,8 @@ def dynamic_sharding_test(
)

local_m1_dmp.reshard("sparse.ebc", new_module_sharding_plan_delta)
# Must recreate local_m1_opt, because the current local_m1_opt holds a copy of the underlying fused optimizer from before resharding
local_m1_opt = CombinedOptimizer([local_m1_dmp.fused_optimizer, dense_m1_optim])

local_m1_pred = gen_full_pred_after_one_step(
local_m1_dmp, local_m1_opt, local_input_1
@@ -954,7 +956,12 @@ def gen_full_pred_after_one_step(
opt: torch.optim.Optimizer,
input: ModelInput,
skip_inference: bool = False,
skip_training: bool = False,
) -> torch.Tensor:
if skip_training:
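# Skip the optimizer step and just run a forward pass in eval mode, so the
# models are compared with their current weights.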
model.train(False)
output = model(input)
return output
# Run a single training step of the global model.
opt.zero_grad()
model.train(True)
@@ -1120,3 +1127,32 @@ def generate_rank_placements(
placement = sorted(random.sample(range(world_size), ranks_per_table))
placements.append(placement)
return placements


def compare_opt_local_t(
opt_1: CombinedOptimizer,
opt_2: CombinedOptimizer,
table_id: int,
rtol: float = 1e-4,
atol: float = 1e-4,
) -> None:
"""
Helper function to compare the local optimizer state (momentum) of two models for a given table after one training step.
Useful for debugging sharding tests to see where the optimizer states of the two models diverge.
"""
# TODO: update logic to be generic across other embedding modules
t1 = (
opt_1.state_dict()["state"][
"sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
]["table_" + str(table_id) + ".momentum1"]
.local_shards()[0]
.tensor
)
t2 = (
opt_2.state_dict()["state"][
"sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
]["table_" + str(table_id) + ".momentum1"]
.local_shards()[0]
.tensor
)
torch.testing.assert_close(t1, t2, rtol=rtol, atol=atol)
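
# Hypothetical usage in a sharding test, comparing table 0's momentum between
# a resharded model and a reference model:
#   compare_opt_local_t(m1_dmp.fused_optimizer, m2_dmp.fused_optimizer, table_id=0)
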
8 changes: 4 additions & 4 deletions torchrec/distributed/tests/test_dynamic_sharding.py
@@ -8,8 +8,6 @@
# pyre-strict


import copy

import random
import unittest

@@ -21,7 +19,7 @@

from hypothesis import assume, given, settings, Verbosity

from torch import nn
from torch import nn, optim

from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -530,9 +528,11 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared):
apply_optimizer_in_backward_config=st.sampled_from(
[
None,
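# Adagrad maintains per-row optimizer state, so this config exercises the
# optimizer-state resharding path (plain SGD keeps no such state).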
{
"embedding_bags": (optim.Adagrad, {"lr": 0.04}),
},
{
"embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
),