diff --git a/autoparallel/_testing/models/moe.py b/autoparallel/_testing/models/moe.py
new file mode 100644
index 00000000..3a3a7a46
--- /dev/null
+++ b/autoparallel/_testing/models/moe.py
@@ -0,0 +1,413 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Callable, Literal
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.distributed.tensor.placement_types import Replicate, Shard
+
+from autoparallel._testing.models.moe_utils import (
+    TOKEN_GROUP_ALIGN_SIZE_M,
+    generate_permute_indices,
+)
+
+
+def expert_parallel(func: Callable) -> Callable:
+    """
+    This is a wrapper applied to the GroupedExperts computation, serving
+    the following three purposes:
+    1. Convert parameters from DTensors to plain Tensors, to work with
+       dynamic-shape inputs which cannot be easily expressed as DTensors.
+    2. In Expert Parallel, apply the generate_permute_indices kernel to
+       permute the inputs to be ordered by local experts (see the _token_dispatch
+       function in ExpertParallel) and permute the outputs back.
+    3. In order to use torch._grouped_mm, we need to make sure the number of
+       tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The
+       generate_permute_indices kernel also helps achieve this via padding,
+       without incurring synchronization between device and host. Note that this
+       will create side effects when wrapping the for-loop implementation of
+       GroupedExperts, as it does not need padding.
+
+    Among the above:
+    1 and 2 are needed only when expert_parallel_degree > 1.
+    3 is needed even for single-device computation.
+    2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
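+
+    Example: with 2 local experts, TOKEN_GROUP_ALIGN_SIZE_M = 8, and 5 tokens routed
+    as [3, 2], the aligned group sizes become [8, 8]. The padded slots keep the
+    sentinel index -1, which points at the zero row vstacked onto x below, so they
+    contribute zeros to the grouped mm and are dropped again when un-permuting.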
+ """ + + def wrapper( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + + experts_per_ep_rank = w1.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M, + TOKEN_GROUP_ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func(w1, w2, w3, x, num_tokens_per_expert) + + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper + + +@dataclass +class MoEArgs: + num_experts: int = 8 + num_shared_experts: int = 1 + + # router + score_func: Literal["softmax", "sigmoid"] = "sigmoid" + route_norm: bool = False + route_scale: float = 1.0 + score_before_experts: bool = True + + # token-choice + top_k: int = 1 + use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation + load_balance_coeff: float | None = 1e-3 + + +def _run_shared_experts( + shared_w1: torch.Tensor, + shared_w2: torch.Tensor, + shared_w3: torch.Tensor, + x: torch.Tensor, +) -> torch.Tensor: + + h = F.silu(x @ shared_w1.transpose(-2, -1)) + h = h * x @ shared_w3.transpose(-2, -1) + out = h @ shared_w2.transpose(-2, -1) + return out + + +@expert_parallel +def _run_experts_grouped_mm( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + num_tokens_per_expert: torch.Tensor, +) -> torch.Tensor: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + + h = F.silu( + torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets) + ) + h = h * torch._grouped_mm( + x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets + ) + out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) + + return out + + +def _topk_token_choice_router( + x: torch.Tensor, + gate: torch.Tensor, + top_k: int, + num_experts: int, + score_func: Literal["softmax", "sigmoid"], + route_norm: bool, + route_scale: float, + expert_bias: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # scores shape (bs*slen, num_experts) + scores = x @ gate.transpose(-2, -1) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if score_func == "sigmoid": + scores = torch.sigmoid(scores.to(torch.float32)) + elif score_func == "softmax": + scores = F.softmax(scores.to(torch.float32), dim=1) + else: + raise NotImplementedError(f"Unknown score function {score_func}") + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. 
+    if expert_bias is not None:
+        _, selected_experts_indices = torch.topk(scores + expert_bias, k=top_k, dim=1)
+        top_scores = scores.gather(dim=1, index=selected_experts_indices)
+    else:
+        top_scores, selected_experts_indices = torch.topk(scores, k=top_k, dim=1)
+
+    if score_func == "sigmoid" and route_norm:
+        denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
+        top_scores = top_scores / denominator
+    top_scores = top_scores * route_scale
+
+    # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
+    num_tokens_per_expert = torch.histc(
+        selected_experts_indices.view(-1),
+        bins=num_experts,
+        min=0,
+        max=num_experts,
+    )
+
+    return top_scores, selected_experts_indices, num_tokens_per_expert
+
+
+def _reorder_tokens_by_experts(
+    top_scores: torch.Tensor,
+    selected_experts_indices: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor,
+    top_k: int,
+    num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # Reorders token indices to match the order of experts for MoE routing.
+    # NOTE: the reason we need to compute num_tokens_per_expert again is:
+    #   1st computation in the router is to update self.tokens_per_expert,
+    #       which would be the same across all TP ranks.
+    #   2nd computation in the reorderer is for the actual routing and experts computation,
+    #       which would be sharded over TP ranks if expert_tensor_parallel_degree == 1.
+    #   If tensor_parallel_degree == expert_tensor_parallel_degree, they agree.
+    # num_tokens_per_expert = torch.histc(
+    #     selected_experts_indices.view(-1),
+    #     bins=num_experts,
+    #     min=0,
+    #     max=num_experts,
+    # )
+
+    # Reorder the token indices to match the order of the experts
+    # token_indices_experts_sorted shape (bs*slen*top_k,)
+    token_indices_experts_sorted = torch.argsort(
+        selected_experts_indices.view(-1), stable=True
+    )
+
+    top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]
+    token_indices_experts_sorted = token_indices_experts_sorted // top_k
+
+    return (
+        top_scores_experts_sorted,
+        token_indices_experts_sorted,
+    )
+
+
+def _moe_forward(
+    x: torch.Tensor,
+    gate: torch.Tensor,
+    expert_bias: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    shared_w1: torch.Tensor,
+    shared_w2: torch.Tensor,
+    shared_w3: torch.Tensor,
+    top_k: int,
+    num_experts: int,
+    score_func: Literal["softmax", "sigmoid"],
+    route_norm: bool,
+    route_scale: float,
+    score_before_experts: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    bs, slen, dim = x.shape
+    x = x.view(-1, dim)
+    # top_scores and selected_experts_indices shape (bs*slen, top_k)
+    # num_tokens_per_expert shape (num_experts,)
+    (
+        top_scores,
+        selected_experts_indices,
+        num_tokens_per_expert,
+    ) = _topk_token_choice_router(
+        x,
+        gate,
+        top_k,
+        num_experts,
+        score_func,
+        route_norm,
+        route_scale,
+        expert_bias,
+    )
+    # top_scores_experts_sorted and token_indices_experts_sorted shape (bs*slen*top_k,)
+    (
+        top_scores_experts_sorted,
+        token_indices_experts_sorted,
+    ) = _reorder_tokens_by_experts(
+        top_scores,
+        selected_experts_indices,
+        num_tokens_per_expert,
+        top_k,
+        num_experts,
+    )
+
+    # shape (bs*slen*top_k, dim)
+    token_indices_experts_sorted = token_indices_experts_sorted.reshape(-1, 1).expand(
+        -1, dim
+    )
+
+    # shape (bs*slen*top_k, dim)
+    routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
+
+    if score_before_experts:
+        routed_input = (
+            routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)
+        ).to(x.dtype)
+
+    # shape (bs*slen*top_k, dim)
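+    # NOTE: the expert_parallel wrapper around _run_experts_grouped_mm pads each
+    # expert's token group to a multiple of TOKEN_GROUP_ALIGN_SIZE_M and permutes
+    # the rows into expert order before the grouped mm, then restores the original
+    # row order, so routed_output stays row-aligned with routed_input.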
+    routed_output = _run_experts_grouped_mm(
+        w1, w2, w3, routed_input, num_tokens_per_expert
+    )
+
+    if not score_before_experts:
+        routed_output = (
+            routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1)
+        ).to(x.dtype)
+
+    # shared expert
+    # shape (bs*slen, dim)
+    out = _run_shared_experts(shared_w1, shared_w2, shared_w3, x)
+    out = out.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output)
+    out = out.reshape(bs, slen, dim)
+    return out, num_tokens_per_expert
+
+
+class MoE(nn.Module):
+    def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
+        super().__init__()
+
+        # Routed Experts
+        self.num_experts = moe_args.num_experts
+        self.w1 = nn.Parameter(torch.empty(self.num_experts, hidden_dim, dim))
+        self.w2 = nn.Parameter(torch.empty(self.num_experts, dim, hidden_dim))
+        self.w3 = nn.Parameter(torch.empty(self.num_experts, hidden_dim, dim))
+
+        # Router
+        self.top_k = moe_args.top_k
+        self.score_func = moe_args.score_func
+        self.route_norm = moe_args.route_norm
+        self.route_scale = moe_args.route_scale
+        self.gate = nn.Parameter(torch.empty(self.num_experts, dim))
+
+        # Shared Experts
+        self.use_shared_experts = moe_args.num_shared_experts > 0
+        if self.use_shared_experts:
+            self.shared_w1 = nn.Parameter(
+                torch.empty(hidden_dim * moe_args.num_shared_experts, dim)
+            )
+            self.shared_w2 = nn.Parameter(
+                torch.empty(dim, hidden_dim * moe_args.num_shared_experts)
+            )
+            self.shared_w3 = nn.Parameter(
+                torch.empty(hidden_dim * moe_args.num_shared_experts, dim)
+            )
+
+        self.score_before_experts = moe_args.score_before_experts
+
+        # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664)
+        # NOTE: tokens_per_expert is accumulated in the model forward pass.
+        # expert_bias is updated outside the model in an optimizer step pre hook
+        # to work with gradient accumulation.
+        self.load_balance_coeff = moe_args.load_balance_coeff
+        if self.load_balance_coeff is not None:
+            assert self.load_balance_coeff > 0.0
+            self.register_buffer(
+                "expert_bias",
+                torch.zeros(self.num_experts, dtype=torch.float32),
+                persistent=True,
+            )
+        else:
+            self.expert_bias = None
+        # tokens_per_expert will be used to track expert usage and to update the expert bias for load balancing
+        self.register_buffer(
+            "tokens_per_expert",
+            torch.zeros(self.num_experts, dtype=torch.float32),
+            persistent=False,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.
+
+        Returns:
+            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
+        """
+        # tokens_per_expert will be used to update the expert bias for load balancing,
+        # and also to count the expert usage.
+        # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
+        # first in the forward pass, and then in the backward pass. However, this has no
+        # effect on the expert bias update thanks to the torch.sign() operator.
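+        # (That update lives outside this module in an optimizer-step pre-hook; roughly,
+        #  expert_bias += load_balance_coeff * sign(tokens_per_expert.mean() - tokens_per_expert),
+        #  so only the sign of the imbalance matters, not the double-counted magnitude.)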
+        # NOTE: the tokens_per_expert buffer mutation is kept out of _moe_forward (see below)
+        assert self.expert_bias is not None, "Load balance coeff must be set"
+        assert self.use_shared_experts, "Shared experts must be enabled"
+        out, num_tokens_per_expert = _moe_forward(
+            x,
+            self.gate,
+            self.expert_bias,
+            self.w1,
+            self.w2,
+            self.w3,
+            self.shared_w1,
+            self.shared_w2,
+            self.shared_w3,
+            self.top_k,
+            self.num_experts,
+            self.score_func,
+            self.route_norm,
+            self.route_scale,
+            self.score_before_experts,
+        )
+
+        # HOPs don't support buffer mutations, keep this outside
+        with torch.no_grad():
+            self.tokens_per_expert.add_(num_tokens_per_expert)
+        return out
+
+    def init_weights(
+        self,
+        init_std: float,
+        buffer_device: torch.device,
+    ):
+        # Initialize Routed Expert Weights
+        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
+
+        # Initialize Router Weight
+        nn.init.trunc_normal_(self.gate, mean=0.0, std=init_std)
+
+        # Initialize Shared Expert Weights
+        if self.use_shared_experts:
+            nn.init.trunc_normal_(self.shared_w1, mean=0.0, std=0.02)
+            nn.init.trunc_normal_(self.shared_w2, mean=0.0, std=init_std)
+            nn.init.trunc_normal_(self.shared_w3, mean=0.0, std=init_std)
+
+        # Initialize Buffers
+        with torch.device(buffer_device):
+            self.tokens_per_expert = torch.zeros(
+                self.num_experts, dtype=torch.float32
+            )
+            if self.load_balance_coeff is not None:
+                self.expert_bias = torch.zeros(
+                    self.num_experts, dtype=torch.float32
+                )
diff --git a/autoparallel/_testing/models/moe_utils.py b/autoparallel/_testing/models/moe_utils.py
new file mode 100644
index 00000000..6a0399b9
--- /dev/null
+++ b/autoparallel/_testing/models/moe_utils.py
@@ -0,0 +1,237 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Literal
+
+import torch
+import triton
+import triton.language as tl
+
+TOKEN_GROUP_ALIGN_SIZE_M = 8
+ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
+
+
+def set_token_group_alignment_size_m(
+    alignment_size: ValidTokenGroupAlignmentSize,
+) -> None:
+    """
+    Set the token group alignment size for token groups in MoE. This is implemented by
+    padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+
+    Valid values are: 8, 16, or 32.
+    Different values are needed for different cases:
+
+    * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
+    * For fp8, 16 byte alignment / 1 byte per elem = 16 elements.
+    * For mxfp8, we need 32 (or block_size) because the scaling block size is (1 x 32),
+      so when doing per-token-group quantization on each logically distinct subtensor,
+      we need to ensure the contracting dim is divisible by block_size.
+      In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims
+      of (N, M) @ (M, K), so M is the contracting dim, and group offsets are along M,
+      so we need 32 element alignment.
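+
+    Example:
+        set_token_group_alignment_size_m(16)  # fp8 experts
+        set_token_group_alignment_size_m(32)  # mxfp8 experts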
+ """ + global TOKEN_GROUP_ALIGN_SIZE_M + TOKEN_GROUP_ALIGN_SIZE_M = alignment_size + + +# parallelized kernel +@triton.jit +def _fill_indices_kernel( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + + +# ============== +# wrapper +# ============== + + +def fill_indices_wrapper( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + block_size: int = 128, + max_blocks: int = 1024, # cap on total number of blocks to launch +): + # preallocate output + permuted_indices = torch.full( + (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device + ) + + # write offsets is per local expert... + num_blocks = min(experts_per_rank, max_blocks) + # grid = one block per expert unless capped and then we loop... 
+    grid = (num_blocks,)
+
+    # launch kernel
+    _fill_indices_kernel[grid](
+        tokens_per_expert_group,
+        start_index_values,
+        write_offsets,
+        permuted_indices,
+        experts_per_rank,
+        num_ranks,
+        BLOCK_SIZE=block_size,
+    )
+    return permuted_indices
+
+
+# reference
+def fill_indices_cpu(
+    tokens_per_expert_group: torch.Tensor,
+    start_index_values: torch.Tensor,
+    write_offsets: torch.Tensor,
+    experts_per_rank: int,
+    num_ranks: int,
+    max_len: int,
+):
+    # We need to preallocate the output - we ignore device and force it on cpu
+    # device = tokens_per_expert_group.device
+    permuted_indices = torch.full(
+        (max_len,),
+        -1,
+        dtype=torch.int32,
+    )  # device=device)
+    # Fill the permuted indices
+    # For each local expert
+    for e in range(experts_per_rank):
+        write_start = write_offsets[e].item()
+        # For each remote rank
+        for r in range(num_ranks):
+            i = r * experts_per_rank + e
+            start_index = start_index_values[i].item()
+            length = tokens_per_expert_group[i].item()
+            # Fill in the indices
+            if length > 0:
+                end_idx = min(write_start + length, max_len)
+                permuted_indices[write_start:end_idx] = torch.arange(
+                    start_index,
+                    start_index + (end_idx - write_start),
+                    dtype=torch.int32,
+                    # device=device,
+                )
+            write_start += length
+    return permuted_indices
+
+
+def generate_permute_indices(
+    tokens_per_expert_group: torch.Tensor,
+    experts_per_rank: int,
+    num_ranks: int,
+    max_len: int,
+    alignment: int,
+    use_cpu: bool = False,
+):
+    """
+    Prepare permutation indices and the number of tokens for each expert.
+
+    Args:
+        tokens_per_expert_group: number of tokens for each expert from all ranks.
+        experts_per_rank: number of experts per rank.
+        num_ranks: number of ranks.
+        max_len: maximum length of the output index vector.
+        alignment: each entry of `m_sizes` is padded up to a multiple of this value,
+            which also serves as the minimum group size for experts that receive zero tokens.
+        use_cpu: whether to use the CPU implementation.
+
+    Returns:
+        permuted_indices: indices that map the original token order to the expert-grouped order.
+        m_sizes: aligned number of tokens for each expert (padded to the alignment boundary).
+        m_offsets: cumulative sum of m_sizes, i.e. the exclusive ending position of each expert's tokens.
+
+    Explanatory details:
+        `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example:
+
+        From: |       rank 0      |       rank 1      |
+        To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
+              |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |
+    """
+
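+    # For the docstring's example with alignment = 8, the intermediate values below are:
+    #   start_index_values = [0, 4, 6, 7, 10, 11, 13, 16]
+    #   total_tokens_per_expert (summed over ranks) = [5, 4, 4, 7]
+    #   m_sizes = [8, 8, 8, 8] and write_offsets = [0, 8, 16, 24]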
+    # prefix sum to get the start index of each expert (parallel scan kernel in the future?)
+    start_index_values = (
+        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
+    )
+
+    # total tokens for each expert (sum over ranks)
+    total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
+
+    # pad out empty experts to the alignment requirement
+    total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
+
+    # align the chunk sizes (cdiv)
+    m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
+        torch.int32
+    )
+
+    # additional prefix sum to get the write offset of each expert in permuted_indices
+    # write offsets are per local expert, not global
+    m_offsets = torch.cumsum(m_sizes, 0)
+    write_offsets = m_offsets - m_sizes
+
+    # select the implementation to use
+    if use_cpu:
+        permuted_indices = fill_indices_cpu(
+            tokens_per_expert_group,
+            start_index_values,
+            write_offsets,
+            experts_per_rank,
+            num_ranks,
+            max_len,
+        )
+    else:
+        permuted_indices = fill_indices_wrapper(
+            tokens_per_expert_group,
+            start_index_values,
+            write_offsets,
+            experts_per_rank,
+            num_ranks,
+            max_len,
+        )
+
+    return permuted_indices, m_sizes, m_offsets.to(torch.int32)
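
As a quick sanity check, the CPU path of generate_permute_indices can be exercised on the docstring's example; a minimal sketch (the expected values in the comments follow from the logic above, and max_len is an arbitrary upper bound chosen for illustration):

    import torch
    from autoparallel._testing.models.moe_utils import generate_permute_indices

    # 2 ranks x 4 experts per rank, token counts as in the docstring example
    tokens_per_expert_group = torch.tensor([4, 2, 1, 3, 1, 2, 3, 4], dtype=torch.int32)
    permuted_indices, m_sizes, m_offsets = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank=4,
        num_ranks=2,
        max_len=64,
        alignment=8,
        use_cpu=True,
    )
    print(m_sizes.tolist())               # [8, 8, 8, 8]
    print(m_offsets.tolist())             # [8, 16, 24, 32]
    print(permuted_indices[:8].tolist())  # [0, 1, 2, 3, 10, -1, -1, -1]
    # expert 0 gathers rank-0 tokens 0..3 and rank-1 token 10; the rest is -1 padding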