
Commit eb6f4ab

TroyGarden authored and facebook-github-bot committed
enable customized emb lookup kernel for TorchRec
Summary:

# context
* The NVIDIA dynamicemb package depends on an old TorchRec release (r0.7) plus a PR ([meta-pytorch#2533](meta-pytorch#2533)).
* The goal is to refactor that PR ([meta-pytorch#2533](meta-pytorch#2533)) on trunk so that TorchRec can accept customized compute kernels.

# design rationales
* Because [`EmbeddingComputeKernel`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L64-L72) is an Enum class that can't be dynamically extended outside of the TorchRec codebase, we add a placeholder member named `customized_kernel` that stands for all customized compute kernels.
* `compute_kernel` is set in [ParameterSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L694), along with `sharding_type`, `sharding_specs`, etc. Users can subclass the `ParameterSharding` dataclass to add the extra configs and parameters their customized compute kernel needs, including something like `customized_compute_kernel` to select the exact kernel when there are several (a minimal sketch follows this summary).
* To propagate such [extra config](https://fburl.com/code/bnwp44sz) to the customized kernel, we add a `get_additional_fused_params` hook that merges those params into `fused_params`. (We might consider moving the [`add_params_from_parameter_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359) function onto `ParameterSharding` as a method, so that users can override it when necessary.)

  NOTE: `fused_params` was originally used to pass required parameters to the fbgemm lookup kernels (e.g., TBE, see below). It now also serves as a convenient way of [propagating configs from `ParameterSharding` to the kernel](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359).
  ```
  (Pdb) group_fused_params
  {'optimizer': <EmbOptimType.EXACT_ADAGRAD: 'exact_adagrad'>, 'learning_rate': 0.1}
  ```
* Besides the lookup module, a customized kernel very often also needs a customized input_dist and/or output_dist. These all come from [EmbeddingSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_sharding.py#L964) and its [child classes](https://github.com/pytorch/torchrec/tree/main/torchrec/distributed/sharding) such as cw_sharding, tw_sharding, etc.
* We therefore move the main API [`create_embedding_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150) onto the sharded module as an overridable classmethod; it returns a subclass of `EmbeddingSharding`, which in turn creates the user-defined input_dist, output_dist, lookup modules, and so on.

WARNING: the HKV-based customized compute kernel currently can't go through `_initialize_torch_state`, likely because the table.weight tensor is no longer on the GPU and therefore can't be represented as a ShardedTensor or DTensor. It is the user's responsibility to handle the state_dict correctly by overriding the `_initialize_torch_state` function.

Differential Revision: D70723583
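Below is a minimal sketch of the `ParameterSharding` extension point described above, assuming the subclass lives in user code: the class name `DynamicEmbParameterSharding`, the fields `customized_compute_kernel` and `dynamicemb_options`, and the placement of `get_additional_fused_params` as a dataclass method are illustrative assumptions, not part of this diff.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

from torchrec.distributed.types import ParameterSharding


@dataclass
class DynamicEmbParameterSharding(ParameterSharding):
    # Extra, kernel-specific knobs; the base dataclass already carries
    # sharding_type, compute_kernel, ranks, etc.
    customized_compute_kernel: str = "hkv"
    dynamicemb_options: Dict[str, Any] = field(default_factory=dict)

    def get_additional_fused_params(self) -> Dict[str, Any]:
        # Everything returned here is merged into `fused_params`, which is how
        # the extra config reaches the customized lookup kernel.
        return {
            "customized_compute_kernel": self.customized_compute_kernel,
            **self.dynamicemb_options,
        }
```

The returned dict would be folded into `fused_params` by `add_params_from_parameter_sharding` (or its possible future method form) and handed to the lookup kernel at construction time.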
1 parent 75f1f1c · commit eb6f4ab

File tree: 7 files changed, +259 −194 lines


torchrec/distributed/embedding.py

Lines changed: 63 additions & 44 deletions
@@ -147,46 +147,6 @@ def get_ec_index_dedup() -> bool:
     return EC_INDEX_DEDUP
 
 
-def create_embedding_sharding(
-    sharding_type: str,
-    sharding_infos: List[EmbeddingShardingInfo],
-    env: ShardingEnv,
-    device: Optional[torch.device] = None,
-    qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
-) -> EmbeddingSharding[
-    SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
-]:
-    if sharding_type == ShardingType.TABLE_WISE.value:
-        return TwSequenceEmbeddingSharding(
-            sharding_infos=sharding_infos,
-            env=env,
-            device=device,
-            qcomm_codecs_registry=qcomm_codecs_registry,
-        )
-    elif sharding_type == ShardingType.ROW_WISE.value:
-        return RwSequenceEmbeddingSharding(
-            sharding_infos=sharding_infos,
-            env=env,
-            device=device,
-            qcomm_codecs_registry=qcomm_codecs_registry,
-        )
-    elif sharding_type == ShardingType.DATA_PARALLEL.value:
-        return DpSequenceEmbeddingSharding(
-            sharding_infos=sharding_infos,
-            env=env,
-            device=device,
-        )
-    elif sharding_type == ShardingType.COLUMN_WISE.value:
-        return CwSequenceEmbeddingSharding(
-            sharding_infos=sharding_infos,
-            env=env,
-            device=device,
-            qcomm_codecs_registry=qcomm_codecs_registry,
-        )
-    else:
-        raise ValueError(f"Sharding not supported {sharding_type}")
-
-
 def create_sharding_infos_by_sharding(
     module: EmbeddingCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
@@ -557,7 +517,7 @@ def __init__(
                 SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
             ],
         ] = {
-            sharding_type: create_embedding_sharding(
+            sharding_type: self.create_embedding_sharding(
                 sharding_type=sharding_type,
                 sharding_infos=embedding_confings,
                 env=env,
@@ -637,6 +597,51 @@ def __init__(
         if module.device != torch.device("meta"):
             self.load_state_dict(module.state_dict())
 
+    @classmethod
+    def create_embedding_sharding(
+        cls,
+        sharding_type: str,
+        sharding_infos: List[EmbeddingShardingInfo],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> EmbeddingSharding[
+        SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
+    ]:
+        """
+        This is the main function to generate `EmbeddingSharding` instances based on sharding_type
+        so that the same sharding_type in one EC would be fused.
+        """
+        if sharding_type == ShardingType.TABLE_WISE.value:
+            return TwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                qcomm_codecs_registry=qcomm_codecs_registry,
+            )
+        elif sharding_type == ShardingType.ROW_WISE.value:
+            return RwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                qcomm_codecs_registry=qcomm_codecs_registry,
+            )
+        elif sharding_type == ShardingType.DATA_PARALLEL.value:
+            return DpSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+            )
+        elif sharding_type == ShardingType.COLUMN_WISE.value:
+            return CwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                qcomm_codecs_registry=qcomm_codecs_registry,
+            )
+        else:
+            raise ValueError(f"Sharding not supported {sharding_type}")
+
     @staticmethod
     def _pre_state_dict_hook(
         self: "ShardedEmbeddingCollection",
@@ -757,14 +762,23 @@ def _initialize_torch_state(self) -> None:  # noqa
             parameter_sharding,
         ) in self.module_sharding_plan.items():
             if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
+                # Don't need to use sharded/distributed state tensor for DATA_PARALLEL
+                # because each rank has a full copy of the table in DATA_PARALLEL
+                continue
+            _model_parallel_name_to_compute_kernel[table_name] = (
+                parameter_sharding.compute_kernel
+            )
+            if (
+                parameter_sharding.compute_kernel
+                == EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value
+            ):
+                # Skip state_dict handling for CUSTOMIZED_KERNEL, this should be implemented
+                # in child class for the CUSTOMIZED_KERNEL
                 continue
             self._model_parallel_name_to_local_shards[table_name] = []
             self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
                 [("local_tensors", []), ("local_offsets", [])]
            )
-            _model_parallel_name_to_compute_kernel[table_name] = (
-                parameter_sharding.compute_kernel
-            )
 
         self._name_to_table_size = {}
         for table in self._embedding_configs:
@@ -783,6 +797,11 @@ def _initialize_torch_state(self) -> None:  # noqa
             # save local_shards for transforming MP params to shardedTensor
             for key, v in lookup.state_dict().items():
                 table_name = key[: -len(".weight")]
+                if (
+                    _model_parallel_name_to_compute_kernel[table_name]
+                    == EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value
+                ):
+                    continue
                 if isinstance(v, DTensor):
                     shards_wrapper = self._model_parallel_name_to_shards_wrapper[
                         table_name
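With `create_embedding_sharding` now an overridable classmethod on `ShardedEmbeddingCollection`, a customized stack can return its own `EmbeddingSharding`, which in turn builds the customized input_dist, lookup, and output_dist. A minimal sketch, assuming a hypothetical user-provided `HKVRwSequenceEmbeddingSharding` (the `dynamicemb.sharding` import path is made up); in practice the corresponding sharder would also need to be set up so this subclass actually gets instantiated:

```python
from typing import Dict, List, Optional

import torch
from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.embedding_sharding import (
    EmbeddingSharding,
    EmbeddingShardingInfo,
)
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.distributed.types import QuantizedCommCodecs, ShardingEnv, ShardingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Hypothetical user-provided sharding that wires in the customized
# input_dist, lookup (e.g. HKV-backed), and output_dist.
from dynamicemb.sharding import HKVRwSequenceEmbeddingSharding  # assumption


class DynamicEmbShardedEmbeddingCollection(ShardedEmbeddingCollection):
    @classmethod
    def create_embedding_sharding(
        cls,
        sharding_type: str,
        sharding_infos: List[EmbeddingShardingInfo],
        env: ShardingEnv,
        device: Optional[torch.device] = None,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> EmbeddingSharding[
        SequenceShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
    ]:
        # Route row-wise tables to the customized sharding; defer everything
        # else to the stock TorchRec shardings via the parent classmethod.
        if sharding_type == ShardingType.ROW_WISE.value:
            return HKVRwSequenceEmbeddingSharding(
                sharding_infos=sharding_infos,
                env=env,
                device=device,
                qcomm_codecs_registry=qcomm_codecs_registry,
            )
        return super().create_embedding_sharding(
            sharding_type=sharding_type,
            sharding_infos=sharding_infos,
            env=env,
            device=device,
            qcomm_codecs_registry=qcomm_codecs_registry,
        )
```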

torchrec/distributed/embedding_lookup.py

Lines changed: 72 additions & 71 deletions
@@ -181,46 +181,11 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
     ) -> None:
-        # TODO rename to _create_embedding_kernel
-        def _create_lookup(
-            config: GroupedEmbeddingConfig,
-        ) -> BaseEmbedding:
-            for table in config.embedding_tables:
-                if (
-                    table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
-                    or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
-                ):
-                    self._need_prefetch = True
-            if config.compute_kernel == EmbeddingComputeKernel.DENSE:
-                return BatchedDenseEmbedding(
-                    config=config,
-                    pg=pg,
-                    device=device,
-                )
-            elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
-                return BatchedFusedEmbedding(
-                    config=config,
-                    pg=pg,
-                    device=device,
-                )
-            elif config.compute_kernel in {
-                EmbeddingComputeKernel.KEY_VALUE,
-            }:
-                return KeyValueEmbedding(
-                    config=config,
-                    pg=pg,
-                    device=device,
-                )
-            else:
-                raise ValueError(
-                    f"Compute kernel not supported {config.compute_kernel}"
-                )
-
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         self._need_prefetch: bool = False
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config))
+            self._emb_modules.append(self._create_embedding_kernel(config, pg, device))
 
         self._feature_splits: List[int] = []
         for config in grouped_configs:
@@ -239,6 +204,41 @@ def _create_lookup(
 
         self.grouped_configs = grouped_configs
 
+    def _create_embedding_kernel(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup],
+        device: Optional[torch.device],
+    ) -> BaseEmbedding:
+        for table in config.embedding_tables:
+            if (
+                table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
+                or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
+            ):
+                self._need_prefetch = True
+        if config.compute_kernel == EmbeddingComputeKernel.DENSE:
+            return BatchedDenseEmbedding(
+                config=config,
+                pg=pg,
+                device=device,
+            )
+        elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
+            return BatchedFusedEmbedding(
+                config=config,
+                pg=pg,
+                device=device,
+            )
+        elif config.compute_kernel in {
+            EmbeddingComputeKernel.KEY_VALUE,
+        }:
+            return KeyValueEmbedding(
+                config=config,
+                pg=pg,
+                device=device,
+            )
+        else:
+            raise ValueError(f"Compute kernel not supported {config.compute_kernel}")
+
     def prefetch(
         self,
         sparse_features: KeyedJaggedTensor,
@@ -409,44 +409,12 @@ def __init__(
         scale_weight_gradients: bool = True,
         sharding_type: Optional[ShardingType] = None,
     ) -> None:
-        # TODO rename to _create_embedding_kernel
-        def _create_lookup(
-            config: GroupedEmbeddingConfig,
-            device: Optional[torch.device] = None,
-            sharding_type: Optional[ShardingType] = None,
-        ) -> BaseEmbedding:
-            if config.compute_kernel == EmbeddingComputeKernel.DENSE:
-                return BatchedDenseEmbeddingBag(
-                    config=config,
-                    pg=pg,
-                    device=device,
-                    sharding_type=sharding_type,
-                )
-            elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
-                return BatchedFusedEmbeddingBag(
-                    config=config,
-                    pg=pg,
-                    device=device,
-                    sharding_type=sharding_type,
-                )
-            elif config.compute_kernel in {
-                EmbeddingComputeKernel.KEY_VALUE,
-            }:
-                return KeyValueEmbeddingBag(
-                    config=config,
-                    pg=pg,
-                    device=device,
-                    sharding_type=sharding_type,
-                )
-            else:
-                raise ValueError(
-                    f"Compute kernel not supported {config.compute_kernel}"
-                )
-
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device, sharding_type))
+            self._emb_modules.append(
+                self._create_embedding_kernel(config, device, pg, sharding_type)
+            )
 
         self._feature_splits: List[int] = []
         for config in grouped_configs:
@@ -473,6 +441,39 @@ def _create_lookup(
             else 1
         )
 
+    def _create_embedding_kernel(
+        self,
+        config: GroupedEmbeddingConfig,
+        device: Optional[torch.device],
+        pg: Optional[dist.ProcessGroup],
+        sharding_type: Optional[ShardingType],
+    ) -> BaseEmbedding:
+        if config.compute_kernel == EmbeddingComputeKernel.DENSE:
+            return BatchedDenseEmbeddingBag(
+                config=config,
+                pg=pg,
+                device=device,
+                sharding_type=sharding_type,
+            )
+        elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
+            return BatchedFusedEmbeddingBag(
+                config=config,
+                pg=pg,
+                device=device,
+                sharding_type=sharding_type,
+            )
+        elif config.compute_kernel in {
+            EmbeddingComputeKernel.KEY_VALUE,
+        }:
+            return KeyValueEmbeddingBag(
+                config=config,
+                pg=pg,
+                device=device,
+                sharding_type=sharding_type,
+            )
+        else:
+            raise ValueError(f"Compute kernel not supported {config.compute_kernel}")
+
     def prefetch(
         self,
         sparse_features: KeyedJaggedTensor,
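The closures formerly named `_create_lookup` are now `_create_embedding_kernel` methods, so a child lookup class can intercept the `CUSTOMIZED_KERNEL` placeholder and return its own `BaseEmbedding`. A minimal sketch, assuming the sequence-embedding variant above belongs to `GroupedEmbeddingsLookup` and that `HKVEmbedding` is a user-provided kernel (the `dynamicemb.kernels` import path is made up):

```python
from typing import Optional

import torch
import torch.distributed as dist
from torchrec.distributed.embedding_kernel import BaseEmbedding
from torchrec.distributed.embedding_lookup import GroupedEmbeddingsLookup
from torchrec.distributed.embedding_types import (
    EmbeddingComputeKernel,
    GroupedEmbeddingConfig,
)

from dynamicemb.kernels import HKVEmbedding  # hypothetical user-provided kernel


class DynamicEmbGroupedEmbeddingsLookup(GroupedEmbeddingsLookup):
    def _create_embedding_kernel(
        self,
        config: GroupedEmbeddingConfig,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
    ) -> BaseEmbedding:
        if config.compute_kernel == EmbeddingComputeKernel.CUSTOMIZED_KERNEL:
            # config.fused_params carries anything merged in from the
            # ParameterSharding subclass (see get_additional_fused_params).
            return HKVEmbedding(config=config, pg=pg, device=device)
        # Fall back to the stock DENSE / FUSED / KEY_VALUE handling.
        return super()._create_embedding_kernel(config, pg, device)
```

A customized `EmbeddingSharding` (as in the previous sketch) would then construct this lookup instead of the default one.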

torchrec/distributed/embedding_types.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ class EmbeddingComputeKernel(Enum):
     QUANT_UVM = "quant_uvm"
     QUANT_UVM_CACHING = "quant_uvm_caching"
     KEY_VALUE = "key_value"
+    CUSTOMIZED_KERNEL = "customized_kernel"
 
 
 def compute_kernel_to_embedding_location(
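For reference, a per-table plan entry that opts into the new placeholder kernel could look like the sketch below; the table name and ranks are made up, and in practice the entry usually comes from the planner and a `ParameterSharding` subclass as described in the summary:

```python
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ParameterSharding, ShardingType

# Hypothetical per-table entry of a module sharding plan.
plan_entry = {
    "table_0": ParameterSharding(
        sharding_type=ShardingType.ROW_WISE.value,
        compute_kernel=EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value,
        ranks=[0, 1],
    )
}
```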
