diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 119c304c8..1aff0ecf6 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -764,32 +764,37 @@ def __init__(
         self._local_rows: List[int] = []
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
+        self._embedding_dims: List[int] = []
         self._local_cols: List[int] = []
+        self._row_offset: List[int] = []
+        self._col_offset: List[int] = []
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
-        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
-        #  `ShardedEmbeddingTable`.
-        for idx, config in enumerate(self._config.embedding_tables):
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
-            self._local_rows.append(config.local_rows)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            #  `get_weight_init_min`.
-            self._weight_init_mins.append(config.get_weight_init_min())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            #  `get_weight_init_max`.
-            self._weight_init_maxs.append(config.get_weight_init_max())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            #  `num_embeddings`.
-            self._num_embeddings.append(config.num_embeddings)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
-            self._local_cols.append(config.local_cols)
-            self._feature_table_map.extend([idx] * config.num_features())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
-            if config.name not in self.table_name_to_count:
-                self.table_name_to_count[config.name] = 0
-            self.table_name_to_count[config.name] += 1
+        for idx, table_config in enumerate(self._config.embedding_tables):
+            self._local_rows.append(table_config.local_rows)
+            self._weight_init_mins.append(table_config.get_weight_init_min())
+            self._weight_init_maxs.append(table_config.get_weight_init_max())
+            self._num_embeddings.append(table_config.num_embeddings)
+            self._embedding_dims.append(table_config.embedding_dim)
+            self._row_offset.append(
+                table_config.local_metadata.shard_offsets[0]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 0
+                else 0
+            )
+            self._col_offset.append(
+                table_config.local_metadata.shard_offsets[1]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 1
+                else 0
+            )
+            self._local_cols.append(table_config.local_cols)
+            self._feature_table_map.extend([idx] * table_config.num_features())
+            if table_config.name not in self.table_name_to_count:
+                self.table_name_to_count[table_config.name] = 0
+            self.table_name_to_count[table_config.name] += 1
 
     def init_parameters(self) -> None:
         # initialize embedding weights
@@ -1080,6 +1085,14 @@ def __init__(
                 weights_precision=weights_precision,
                 device=device,
                 table_names=[t.name for t in config.embedding_tables],
+                embedding_shard_info=list(
+                    zip(
+                        self._num_embeddings,
+                        self._embedding_dims,
+                        self._row_offset,
+                        self._col_offset,
+                    )
+                ),
                 **fused_params,
             )
         )
@@ -1216,34 +1229,39 @@ def __init__(
         self._local_rows: List[int] = []
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
+        self._embedding_dims: List[int] = []
         self._local_cols: List[int] = []
+        self._row_offset: List[int] = []
+        self._col_offset: List[int] = []
         self._feature_table_map: List[int] = []
         self._emb_names: List[str] = []
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
-        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
-        #  `ShardedEmbeddingTable`.
-        for idx, config in enumerate(self._config.embedding_tables):
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
-            self._local_rows.append(config.local_rows)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            #  `get_weight_init_min`.
-            self._weight_init_mins.append(config.get_weight_init_min())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            #  `get_weight_init_max`.
-            self._weight_init_maxs.append(config.get_weight_init_max())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            #  `num_embeddings`.
-            self._num_embeddings.append(config.num_embeddings)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
-            self._local_cols.append(config.local_cols)
-            self._feature_table_map.extend([idx] * config.num_features())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
-            if config.name not in self.table_name_to_count:
-                self.table_name_to_count[config.name] = 0
-            self.table_name_to_count[config.name] += 1
+        for idx, table_config in enumerate(self._config.embedding_tables):
+            self._local_rows.append(table_config.local_rows)
+            self._weight_init_mins.append(table_config.get_weight_init_min())
+            self._weight_init_maxs.append(table_config.get_weight_init_max())
+            self._num_embeddings.append(table_config.num_embeddings)
+            self._embedding_dims.append(table_config.embedding_dim)
+            self._row_offset.append(
+                table_config.local_metadata.shard_offsets[0]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 0
+                else 0
+            )
+            self._col_offset.append(
+                table_config.local_metadata.shard_offsets[1]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 1
+                else 0
+            )
+            self._local_cols.append(table_config.local_cols)
+            self._feature_table_map.extend([idx] * table_config.num_features())
+            if table_config.name not in self.table_name_to_count:
+                self.table_name_to_count[table_config.name] = 0
+            self.table_name_to_count[table_config.name] += 1
 
     def init_parameters(self) -> None:
         # initialize embedding weights
@@ -1564,6 +1582,14 @@ def __init__(
                 weights_precision=weights_precision,
                 device=device,
                 table_names=[t.name for t in config.embedding_tables],
+                embedding_shard_info=list(
+                    zip(
+                        self._num_embeddings,
+                        self._embedding_dims,
+                        self._row_offset,
+                        self._col_offset,
+                    )
+                ),
                 **fused_params,
             )
         )
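For reviewers, a minimal standalone sketch of what both constructors now assemble for the SSD kernel: one (num_embeddings, embedding_dim, row_offset, col_offset) tuple per table, with offsets read from the local shard's `ShardMetadata.shard_offsets` and falling back to 0 when a table carries no local metadata (e.g. an unsharded table). `StubShardMetadata`, `StubTable`, and `build_embedding_shard_info` below are hypothetical stand-ins, not torchrec APIs; they model only the fields this patch reads.

```python
# Sketch (assumption): stub types standing in for ShardedEmbeddingTable
# and ShardMetadata, modeling only the fields used by the patch above.
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class StubShardMetadata:
    # Mirrors ShardMetadata.shard_offsets: [row_offset, col_offset].
    shard_offsets: List[int]


@dataclass
class StubTable:
    num_embeddings: int
    embedding_dim: int
    local_metadata: Optional[StubShardMetadata] = None


def build_embedding_shard_info(
    tables: List[StubTable],
) -> List[Tuple[int, int, int, int]]:
    """Replicates the zip of per-table (rows, dim, row_offset, col_offset)."""
    num_embeddings: List[int] = []
    embedding_dims: List[int] = []
    row_offset: List[int] = []
    col_offset: List[int] = []
    for t in tables:
        num_embeddings.append(t.num_embeddings)
        embedding_dims.append(t.embedding_dim)
        # Same fallback as the patch: offset 0 when no local shard metadata.
        offsets = t.local_metadata.shard_offsets if t.local_metadata else []
        row_offset.append(offsets[0] if len(offsets) > 0 else 0)
        col_offset.append(offsets[1] if len(offsets) > 1 else 0)
    return list(zip(num_embeddings, embedding_dims, row_offset, col_offset))


# A row-wise shard of a 1M x 128 table starting at row 250_000, plus an
# unsharded table with no local metadata:
tables = [
    StubTable(1_000_000, 128, StubShardMetadata([250_000, 0])),
    StubTable(10_000, 64),
]
assert build_embedding_shard_info(tables) == [
    (1_000_000, 128, 250_000, 0),
    (10_000, 64, 0, 0),
]
```

Zipping the four parallel lists keeps `embedding_shard_info` aligned with the other per-table arguments, since the same loop populates all of them in table order.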