Skip to content

Commit

Permalink
Extract and export weights offsets/placements initialization functions (pytorch#1669)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#1669

Extract portions initializing the weights_placements/offsets tensors into separate functions and jit.export them.
SplitState is converted to a NamedTuple since we can't jit.script a dataclass that also holds an enum.

Differential Revision: https://internalfb.com/D44338256

fbshipit-source-id: ecf64241f999bccabf3cb2ed9b72923f9c2951af
  • Loading branch information
qxy11 authored and facebook-github-bot committed Mar 28, 2023
1 parent 2776770 commit d1f8973
Showing 1 changed file with 68 additions and 39 deletions.
107 changes: 68 additions & 39 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,16 @@ class CounterBasedRegularizationDefinition:
[("record_cache_miss_counter", bool), ("record_tablewise_cache_miss", bool)],
)


# Pre-change definition of SplitState as shown by this diff: a dataclass with
# three integer sizes (dev/host/uvm) plus parallel `placements` and `offsets`
# lists. The commit summary states it is replaced by a NamedTuple because a
# dataclass that also holds an enum cannot be jit.script-ed.
@dataclass
class SplitState:
dev_size: int
host_size: int
uvm_size: int
placements: List[EmbeddingLocation]
offsets: List[int]
# Result record returned by construct_split_state / nbit_construct_split_state.
# Declared with the functional NamedTuple form; per the commit summary, a
# dataclass that also holds an enum (EmbeddingLocation) cannot be jit.script-ed.
SplitState: NamedTuple = NamedTuple(
"SplitState",
[
# Integer sizes accumulated by the construct_* functions; presumably the
# totals for the device, host and UVM regions (units not visible here —
# TODO confirm against the construct_* function bodies).
("dev_size", int),
("host_size", int),
("uvm_size", int),
# Parallel lists built alongside each other by the construct_* functions:
# an EmbeddingLocation and an integer offset per embedding spec.
("placements", List[EmbeddingLocation]),
("offsets", List[int]),
],
)


def construct_split_state(
Expand All @@ -132,11 +134,11 @@ def construct_split_state(
precision: SparseType = SparseType.FP32,
int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
) -> SplitState:
placements = []
offsets = []
dev_size = 0
host_size = 0
uvm_size = 0
placements: List[EmbeddingLocation] = []
offsets: List[int] = []
dev_size: int = 0
host_size: int = 0
uvm_size: int = 0
for num_embeddings, embedding_dim, location, _ in embedding_specs:
assert (
embedding_dim % 4 == 0
Expand Down Expand Up @@ -1935,8 +1937,8 @@ def nbit_construct_split_state(
scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
cacheline_alignment: bool = True,
) -> SplitState:
placements = []
offsets = []
placements = torch.jit.annotate(List[EmbeddingLocation], [])
offsets = torch.jit.annotate(List[int], [])
dev_size = 0
host_size = 0
uvm_size = 0
Expand Down Expand Up @@ -1984,6 +1986,8 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):
cache_miss_counter: torch.Tensor
uvm_cache_stats: torch.Tensor
local_uvm_cache_stats: torch.Tensor
weights_offsets: torch.Tensor
weights_placements: torch.Tensor

def __init__(
self,
Expand Down Expand Up @@ -2165,21 +2169,7 @@ def max_ty_D(ty: SparseType) -> int:
]
self.max_D_cache: int = max(cached_dims) if len(cached_dims) > 0 else 0

weight_split: SplitState = nbit_construct_split_state(
self.embedding_specs,
cacheable=True,
row_alignment=self.row_alignment,
scale_bias_size_in_bytes=self.scale_bias_size_in_bytes,
cacheline_alignment=cacheline_alignment,
)

self.weights_physical_placements: List[int] = [
t.value for t in weight_split.placements
]
self.weights_physical_offsets: List[int] = weight_split.offsets
self.host_size: int = weight_split.host_size
self.dev_size: int = weight_split.dev_size
self.uvm_size: int = weight_split.uvm_size
self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
self.enforce_hbm: bool = enforce_hbm

# Assign weights after weights and weights_offsets are initialized.
Expand All @@ -2192,7 +2182,8 @@ def max_ty_D(ty: SparseType) -> int:
self.weights_physical_offsets,
self.enforce_hbm,
)
self.assign_embedding_weights(weight_lists) # type: ignore
# pyre-fixme [6]: In call `IntNBitTableBatchedEmbeddingBagsCodegen.assign_embedding_weights`, for 1st positional argument, expected `List[Tuple[Tensor, Optional[Tensor]]]` but got `List[Tuple[Tensor, Tensor]]`.
self.assign_embedding_weights(weight_lists)

# Handle index remapping for embedding pruning.
self.register_buffer(
Expand Down Expand Up @@ -2654,6 +2645,51 @@ def forward(
fp8_exponent_bias=self.fp8_exponent_bias,
)

def initialize_logical_weights_placements_and_offsets(
    self,
) -> None:
    """Rebuild the ``weights_offsets`` and ``weights_placements`` tensors.

    For every entry ``t`` of ``self.feature_table_map`` the physical offset and
    placement stored at index ``t`` are gathered, and the gathered lists are
    materialised as tensors on ``self.current_device`` (int64 offsets, int32
    placements).

    Raises:
        AssertionError: if the physical offsets list does not have the same
            length as ``self.embedding_specs`` or as the physical placements
            list.
    """
    # The physical lists must be consistent with the specs and with each other
    # before they can be indexed through the feature/table map.
    num_physical = len(self.weights_physical_offsets)
    assert num_physical == len(self.embedding_specs)
    assert num_physical == len(self.weights_physical_placements)

    logical_offsets = [
        self.weights_physical_offsets[table] for table in self.feature_table_map
    ]
    logical_placements = [
        self.weights_physical_placements[table] for table in self.feature_table_map
    ]

    target_device = self.current_device
    self.weights_offsets = torch.tensor(
        logical_offsets, device=target_device, dtype=torch.int64
    )
    self.weights_placements = torch.tensor(
        logical_placements, device=target_device, dtype=torch.int32
    )

def initialize_physical_weights_placements_and_offsets(
    self,
    cacheline_alignment: bool = True,
) -> None:
    """Recompute the physical split layout and cache it on ``self``.

    Calls ``nbit_construct_split_state`` with ``self.embedding_specs``, the
    module's row alignment and scale/bias size, and ``cacheable=True``, then
    stores the resulting placements, offsets and host/dev/uvm sizes as
    attributes.

    Args:
        cacheline_alignment: forwarded unchanged to
            ``nbit_construct_split_state``.
    """
    split: SplitState = nbit_construct_split_state(
        self.embedding_specs,
        cacheable=True,
        row_alignment=self.row_alignment,
        scale_bias_size_in_bytes=self.scale_bias_size_in_bytes,
        cacheline_alignment=cacheline_alignment,
    )

    # Region sizes reported by the split.
    self.dev_size = split.dev_size
    self.host_size = split.host_size
    self.uvm_size = split.uvm_size

    # Offsets are stored as-is; placements are stored as the enum members'
    # ``.value`` rather than the enum objects themselves.
    self.weights_physical_offsets = split.offsets
    self.weights_physical_placements = [
        location.value for location in split.placements
    ]

@torch.jit.export
def reset_weights_placements_and_offsets(
    self,
    cacheline_alignment: bool = True,
) -> None:
    """Re-initialize all physical and logical weights placements and offsets.

    Only the placement/offset bookkeeping is rebuilt; this method does not
    itself initialize the large dev weights tensor.

    Args:
        cacheline_alignment: forwarded to
            ``initialize_physical_weights_placements_and_offsets``. Defaults to
            ``True``, which matches the previous behaviour of this method.
            Callers whose module was constructed with
            ``cacheline_alignment=False`` should pass ``False`` here so the
            recomputed layout matches the one built in ``__init__``.
    """
    # Physical layout first: the logical tensors are gathered from it.
    self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
    self.initialize_logical_weights_placements_and_offsets()

def _apply_split(
self,
dev_size: int,
Expand All @@ -2672,14 +2708,7 @@ def _apply_split(
self.dev_size = dev_size
self.uvm_size = uvm_size

offsets = [offsets[t] for t in self.feature_table_map]
placements = [placements[t] for t in self.feature_table_map]
self.weights_offsets = torch.tensor(
offsets, device=self.current_device, dtype=torch.int64
)
self.weights_placements = torch.tensor(
placements, device=self.current_device, dtype=torch.int32
)
self.initialize_logical_weights_placements_and_offsets()

if dev_size > 0:
self.weights_dev = torch.zeros(
Expand Down

0 comments on commit d1f8973

Please sign in to comment.