From f5bf27593c88b6160505f457591ea04014c779f0 Mon Sep 17 00:00:00 2001
From: Jasper Shan
Date: Wed, 28 May 2025 22:04:20 -0700
Subject: [PATCH] Refactoring ITEP / PTP Pruning Scuba Logger [3/N] (#3002)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3002

refactor 3/n: replace the injected `scuba_logger` instance on
`GenericITEPModule` with a `pruning_logger_type` class, and collapse the
`PruningLogger` ABC into a single context-manager entry point
(`pruning_logger`) that call sites wrap their work in.

Reviewed By: AKhazane

Differential Revision: D75108474
---
 torchrec/modules/itep_modules.py   | 338 ++++++++++++++---------------
 torchrec/modules/pruning_logger.py |  63 +++---
 2 files changed, 195 insertions(+), 206 deletions(-)

diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py
index a7f87a009..00af41fbc 100644
--- a/torchrec/modules/itep_modules.py
+++ b/torchrec/modules/itep_modules.py
@@ -11,7 +11,7 @@
 import logging
 import math
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from torch import distributed as dist, nn
@@ -73,113 +73,121 @@ def __init__(
         pruning_interval: int = 1001,  # Default pruning interval 1001 iterations
         pg: Optional[dist.ProcessGroup] = None,
         table_name_to_sharding_type: Optional[Dict[str, str]] = None,
-        scuba_logger: Optional[PruningLogger] = None,
+        pruning_logger_type: Type[PruningLogger] = PruningLoggerDefault,
     ) -> None:
-        super(GenericITEPModule, self).__init__()
-
-        if not table_name_to_sharding_type:
-            table_name_to_sharding_type = {}
-
-        # Construct in-training embedding pruning args
-        self.enable_pruning: bool = enable_pruning
-        self.rank_to_virtual_index_mapping: Dict[str, Dict[int, int]] = {}
-        self.pruning_interval: int = pruning_interval
-        self.lookups: Optional[List[nn.Module]] = None if not lookups else lookups
-        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
-            table_name_to_unpruned_hash_sizes
-        )
-        self.table_name_to_sharding_type = table_name_to_sharding_type
-
-        self.scuba_logger: PruningLogger = (
-            scuba_logger if scuba_logger is not None else PruningLoggerDefault()
-        )
-        self.scuba_logger.log_run_info()
-
-        # Map each feature to a physical address_lookup/row_util buffer
-        self.feature_table_map: Dict[str, int] = {}
-        self.table_name_to_idx: Dict[str, int] = {}
-        self.buffer_offsets_list: List[int] = []
-        self.idx_to_table_name: Dict[int, str] = {}
-        # Prevent multi-pruning, after moving iteration counter to outside.
-        self.last_pruned_iter = -1
-        self.pg = pg
-
-        if self.lookups is not None:
-            self.init_itep_state()
-        else:
-            logger.info(
-                "ITEP init: no lookups provided. Skipping init for dummy module."
+        self.pruning_logger: Type[PruningLogger] = pruning_logger_type
+        with self.pruning_logger.pruning_logger(
+            event="ITEP_MODULE_INIT"
+        ) as log_details:
+            setattr(log_details, "enable_pruning", enable_pruning)
+            setattr(log_details, "pruning_interval", pruning_interval)
+
+            super(GenericITEPModule, self).__init__()
+
+            if not table_name_to_sharding_type:
+                table_name_to_sharding_type = {}
+
+            # Construct in-training embedding pruning args
+            self.enable_pruning: bool = enable_pruning
+            self.rank_to_virtual_index_mapping: Dict[str, Dict[int, int]] = {}
+            self.pruning_interval: int = pruning_interval
+            self.lookups: List[nn.Module] = [] if not lookups else lookups
+            self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
+                table_name_to_unpruned_hash_sizes
+            )
+            self.table_name_to_sharding_type: Dict[str, str] = (
+                table_name_to_sharding_type
             )
+            # Map each feature to a physical address_lookup/row_util buffer
+            self.feature_table_map: Dict[str, int] = {}
+            self.table_name_to_idx: Dict[str, int] = {}
+            self.buffer_offsets_list: List[int] = []
+            self.idx_to_table_name: Dict[int, str] = {}
+            # Prevent repeated pruning in one iteration; the iteration counter now lives outside this module.
+            self.last_pruned_iter: int = -1
+            self.pg: Optional[dist.ProcessGroup] = pg
+
+            if self.lookups:
+                self.init_itep_state()
+            else:
+                logger.info(
+                    "ITEP init: no lookups provided. Skipping init for dummy module."
+                )
+
     def print_itep_eviction_stats(
         self,
         pruned_indices_offsets: torch.Tensor,
         pruned_indices_total_length: torch.Tensor,
         cur_iter: int,
     ) -> None:
-        table_name_to_eviction_ratio = {}
-        buffer_idx_to_eviction_ratio = {}
-        buffer_idx_to_sizes = {}
-
-        num_buffers = len(self.buffer_offsets_list) - 1
-        for buffer_idx in range(num_buffers):
-            pruned_start = pruned_indices_offsets[buffer_idx]
-            pruned_end = pruned_indices_offsets[buffer_idx + 1]
-            pruned_length = pruned_end - pruned_start
-
-            if pruned_length > 0:
-                start = self.buffer_offsets_list[buffer_idx]
-                end = self.buffer_offsets_list[buffer_idx + 1]
-                buffer_length = end - start
-                assert buffer_length > 0
-                eviction_ratio = pruned_length.item() / buffer_length
-                table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
-                    eviction_ratio
+        with self.pruning_logger.pruning_logger(event="ITEP_EVICTION"):
+            table_name_to_eviction_ratio = {}
+            buffer_idx_to_eviction_ratio = {}
+            buffer_idx_to_sizes = {}
+
+            num_buffers = len(self.buffer_offsets_list) - 1
+            for buffer_idx in range(num_buffers):
+                pruned_start = pruned_indices_offsets[buffer_idx]
+                pruned_end = pruned_indices_offsets[buffer_idx + 1]
+                pruned_length = pruned_end - pruned_start
+
+                if pruned_length > 0:
+                    start = self.buffer_offsets_list[buffer_idx]
+                    end = self.buffer_offsets_list[buffer_idx + 1]
+                    buffer_length = end - start
+                    assert buffer_length > 0
+                    eviction_ratio = pruned_length.item() / buffer_length
+                    table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
+                        eviction_ratio
+                    )
+                    buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
+                    buffer_idx_to_sizes[buffer_idx] = (
+                        pruned_length.item(),
+                        buffer_length,
+                    )
+
+            # Sort the mapping by eviction ratio in descending order
+            sorted_mapping = dict(
+                sorted(
+                    table_name_to_eviction_ratio.items(),
+                    key=lambda item: item[1],
+                    reverse=True,
                 )
-                buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
-                buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)
-
-        # Sort the mapping by eviction ratio in descending order
-        sorted_mapping = dict(
-            sorted(
-                table_name_to_eviction_ratio.items(),
-                key=lambda item: item[1],
-                reverse=True,
             )
-        )
-        logged_eviction_mapping = {}
-        for idx in sorted_mapping.keys():
-            try:
-                logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
-                    sorted_mapping[idx]
-                )
-            except KeyError:
-                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
-                pass
-
-        table_to_sizes_mapping = {}
-        for idx in buffer_idx_to_sizes.keys():
-            try:
-                table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
-                    buffer_idx_to_sizes[idx]
-                )
-            except KeyError:
-                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
-                pass
-
-        # Print the sorted mapping
-        logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
-
-        # Calculate percentage of indiced updated/evicted during ITEP iter
-        pruned_indices_ratio = (
-            pruned_indices_total_length / self.buffer_offsets_list[-1]
-            if self.buffer_offsets_list[-1] > 0
-            else 0
-        )
-        logger.info(
-            f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices."
-        )
+            logged_eviction_mapping = {}
+            for idx in sorted_mapping.keys():
+                try:
+                    logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
+                        sorted_mapping[idx]
+                    )
+                except KeyError:
+                    # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                    pass
+
+            table_to_sizes_mapping = {}
+            for idx in buffer_idx_to_sizes.keys():
+                try:
+                    table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
+                        buffer_idx_to_sizes[idx]
+                    )
+                except KeyError:
+                    # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                    pass
+
+            # Print the sorted mapping
+            logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
+
+            # Calculate percentage of indices updated/evicted during ITEP iter
+            pruned_indices_ratio = (
+                pruned_indices_total_length / self.buffer_offsets_list[-1]
+                if self.buffer_offsets_list[-1] > 0
+                else 0
+            )
+            logger.info(
+                f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices."
+ ) def get_table_hash_sizes(self, table: ShardedEmbeddingTable) -> Tuple[int, int]: unpruned_hash_size = table.num_embeddings @@ -251,7 +259,6 @@ def init_itep_state(self) -> None: self.current_device = None # Iterate over all tables - # pyre-ignore for lookup in self.lookups: while isinstance(lookup, DistributedDataParallel): lookup = lookup.module @@ -337,55 +344,49 @@ def reset_weight_momentum( pruned_indices: torch.Tensor, pruned_indices_offsets: torch.Tensor, ) -> None: - if self.lookups is not None: - # pyre-ignore - for lookup in self.lookups: - while isinstance(lookup, DistributedDataParallel): - lookup = lookup.module - for emb in lookup._emb_modules: - emb_tables: List[ShardedEmbeddingTable] = ( - emb._config.embedding_tables - ) + for lookup in self.lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + for emb in lookup._emb_modules: + emb_tables: List[ShardedEmbeddingTable] = emb._config.embedding_tables - logical_idx = 0 - logical_table_ids = [] - buffer_ids = [] - for table in emb_tables: - name = table.name - if name in self.table_name_to_idx: - buffer_idx = self.table_name_to_idx[name] - start = pruned_indices_offsets[buffer_idx] - end = pruned_indices_offsets[buffer_idx + 1] - length = end - start - if length > 0: - logical_table_ids.append(logical_idx) - buffer_ids.append(buffer_idx) - logical_idx += table.num_features() - - if len(logical_table_ids) > 0: - emb.emb_module.reset_embedding_weight_momentum( - pruned_indices, - pruned_indices_offsets, - torch.tensor( - logical_table_ids, - dtype=torch.int32, - requires_grad=False, - ), - torch.tensor( - buffer_ids, dtype=torch.int32, requires_grad=False - ), - ) + logical_idx = 0 + logical_table_ids = [] + buffer_ids = [] + for table in emb_tables: + name = table.name + if name in self.table_name_to_idx: + buffer_idx = self.table_name_to_idx[name] + start = pruned_indices_offsets[buffer_idx] + end = pruned_indices_offsets[buffer_idx + 1] + length = end - start + if length > 0: + logical_table_ids.append(logical_idx) + buffer_ids.append(buffer_idx) + logical_idx += table.num_features() + + if len(logical_table_ids) > 0: + emb.emb_module.reset_embedding_weight_momentum( + pruned_indices, + pruned_indices_offsets, + torch.tensor( + logical_table_ids, + dtype=torch.int32, + requires_grad=False, + ), + torch.tensor( + buffer_ids, dtype=torch.int32, requires_grad=False + ), + ) # Flush UVM cache after ITEP eviction to remove stale states def flush_uvm_cache(self) -> None: - if self.lookups is not None: - # pyre-ignore - for lookup in self.lookups: - while isinstance(lookup, DistributedDataParallel): - lookup = lookup.module - for emb in lookup._emb_modules: - emb.emb_module.flush() - emb.emb_module.reset_cache_states() + for lookup in self.lookups: + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + for emb in lookup._emb_modules: + emb.emb_module.flush() + emb.emb_module.reset_cache_states() def get_remap_info(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: keys = features.keys() @@ -460,7 +461,7 @@ def forward( We use the same forward method for sharded and non-sharded case. 
""" - if not self.enable_pruning or self.lookups is None: + if not self.enable_pruning or not self.lookups: return sparse_features num_buffers = self.buffer_offsets.size(dim=0) - 1 @@ -695,7 +696,7 @@ def get_key_from_table_name_and_suffix( key = get_key_from_table_name_and_suffix(table.name, prefix, suffix) param_idx = self.table_name_to_idx[table.name] buffer_param: torch.Tensor = get_param(params, param_idx) - sharding_type = self.table_name_to_sharding_type[table.name] # pyre-ignore + sharding_type = self.table_name_to_sharding_type[table.name] # For inference there is no pg, all tensors are local if table.global_metadata is not None and pg is not None: @@ -806,29 +807,28 @@ def state_dict( destination = OrderedDict() # pyre-ignore [16] destination._metadata = OrderedDict() - if self.lookups is not None: - # pyre-ignore [16] - for lookup in self.lookups: - list_of_tables: List[ShardedEmbeddingTable] = [] - for emb_config in lookup.grouped_configs: - list_of_tables.extend(emb_config.embedding_tables) - - destination = self.get_itp_state_dict( - list_of_tables, - self.address_lookup, # pyre-ignore - self.pg, - destination, - prefix, - suffix="_itp_address_lookup", - dtype=torch.int64, - ) - destination = self.get_itp_state_dict( - list_of_tables, - self.row_util, # pyre-ignore - self.pg, - destination, - prefix, - suffix="_itp_row_util", - dtype=torch.float32, - ) + for lookup in self.lookups: + list_of_tables: List[ShardedEmbeddingTable] = [] + # pyre-ignore [29] + for emb_config in lookup.grouped_configs: + list_of_tables.extend(emb_config.embedding_tables) + + destination = self.get_itp_state_dict( + list_of_tables, + self.address_lookup, # pyre-ignore + self.pg, + destination, + prefix, + suffix="_itp_address_lookup", + dtype=torch.int64, + ) + destination = self.get_itp_state_dict( + list_of_tables, + self.row_util, # pyre-ignore + self.pg, + destination, + prefix, + suffix="_itp_row_util", + dtype=torch.float32, + ) return destination diff --git a/torchrec/modules/pruning_logger.py b/torchrec/modules/pruning_logger.py index 90d6a6b10..53bf1589e 100644 --- a/torchrec/modules/pruning_logger.py +++ b/torchrec/modules/pruning_logger.py @@ -7,26 +7,29 @@ import logging from abc import ABC, abstractmethod -from typing import Mapping, Optional, Tuple, Union +from contextlib import contextmanager +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Generator, Optional logger: logging.Logger = logging.getLogger(__name__) -class PruningLogger(ABC): - @abstractmethod - def log_table_eviction_info( - self, - iteration: Optional[Union[bool, float, int]], - rank: Optional[int], - table_to_sizes_mapping: Mapping[str, Tuple[int, int]], - eviction_tables: Mapping[str, float], - ) -> None: - pass +@dataclass +class PruningLogBase(object): + pass + +class PruningLogger(ABC): + @classmethod @abstractmethod - def log_run_info( - self, - ) -> None: + @contextmanager + def pruning_logger( + cls, + event: str, + trainer: Optional[str] = None, + publisher: Optional[str] = None, + ) -> Generator[object, None, None]: pass @@ -35,26 +38,12 @@ class PruningLoggerDefault(PruningLogger): noop logger as a default """ - def __init__( - self, - ) -> None: - """ - Initialize PruningScubaLogger. 
- """ - pass - - def log_table_eviction_info( - self, - iteration: Optional[Union[bool, float, int]], - rank: Optional[int], - table_to_sizes_mapping: Mapping[str, Tuple[int, int]], - eviction_tables: Mapping[str, float], - ) -> None: - logger.info( - f"iteration={iteration}, rank={rank}, table_to_sizes_mapping={table_to_sizes_mapping}, eviction_tables={eviction_tables}" - ) - - def log_run_info( - self, - ) -> None: - pass + @classmethod + @contextmanager + def pruning_logger( + cls, + event: str, + trainer: Optional[str] = None, + publisher: Optional[str] = None, + ) -> Generator[object, None, None]: + yield SimpleNamespace()
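
Example: a minimal sketch of a concrete logger written against the refactored
API. `LoggingPruningLogger` is hypothetical and not part of this diff; it only
assumes the `PruningLogger` contract above (a `@classmethod` `@contextmanager`
named `pruning_logger` that yields an object whose attributes the call sites
set, e.g. `enable_pruning` and `pruning_interval` during `ITEP_MODULE_INIT`).

    import logging
    from contextlib import contextmanager
    from types import SimpleNamespace
    from typing import Generator, Optional

    from torchrec.modules.pruning_logger import PruningLogger

    logger: logging.Logger = logging.getLogger(__name__)


    class LoggingPruningLogger(PruningLogger):
        """Hypothetical implementation: flushes whatever the call site set."""

        @classmethod
        @contextmanager
        def pruning_logger(
            cls,
            event: str,
            trainer: Optional[str] = None,
            publisher: Optional[str] = None,
        ) -> Generator[object, None, None]:
            details = SimpleNamespace()
            try:
                yield details
            finally:
                # Call sites set attributes on `details` inside the with-block;
                # emit them once the wrapped block exits (even on error).
                logger.info("pruning event=%s details=%s", event, vars(details))

Because `GenericITEPModule` now takes the logger class rather than an
instance, wiring this in is `GenericITEPModule(...,
pruning_logger_type=LoggingPruningLogger)`; the default stays
`PruningLoggerDefault`, which yields a bare `SimpleNamespace` and records
nothing.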