diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 0746f0fa0..2adb9464f 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -84,6 +84,10 @@ def __init__(
         super().__init__(theta)
 
         if theta.optional_tensor("ffn_gate_exps") is not None:
+            # Expands a single merged expert tensor into individual expert tensors,
+            # e.g. blk.0.ffn_gate_exps.weight becomes blk.0.ffn_gate.0.weight,
+            # blk.0.ffn_gate.1.weight, etc.
+
             merged_tensor = theta.tensor("ffn_gate_exps", "weight")
 
             expert_tensor = extract_ffn_layer(
@@ -130,7 +134,11 @@ def forward(
 def extract_ffn_layer(
     merged_tensor: DefaultPrimitiveTensor, layer_name: str, expert_idx: int
 ):
-    # fetches the block_idx from merged_tensor_name. e.g. blk.0.ffn_gate_exps.weight
+    '''
+    Given a merged expert tensor and an expert_idx, extracts the corresponding
+    expert tensor and wraps it in a DefaultPrimitiveTensor named after that
+    expert's layer.
+    '''
     expert_layer_name = (
         f"blk.{merged_tensor.name.split('.')[1]}.{layer_name}.{expert_idx}.weight"
     )
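
For context, here is a minimal sketch of the per-expert expansion the new comments describe, written in plain PyTorch rather than sharktank's Theta/DefaultPrimitiveTensor API. The function name `split_merged_experts`, the assumption that the merged tensor's leading dimension indexes experts, and the toy shapes are all hypothetical, introduced only to illustrate the naming scheme encoded by the diff's f-string:

```python
import torch


def split_merged_experts(
    merged: torch.Tensor,  # assumed layout: [num_experts, out_features, in_features]
    merged_name: str,      # e.g. "blk.0.ffn_gate_exps.weight"
    layer_name: str,       # e.g. "ffn_gate"
) -> dict[str, torch.Tensor]:
    # Recover the block index from the merged tensor's name, mirroring
    # merged_tensor.name.split('.')[1] in the diff above.
    block_idx = merged_name.split(".")[1]
    return {
        f"blk.{block_idx}.{layer_name}.{expert_idx}.weight": merged[expert_idx]
        for expert_idx in range(merged.shape[0])
    }


# Toy usage: 8 experts with small, made-up FFN dimensions.
experts = split_merged_experts(
    torch.randn(8, 16, 32), "blk.0.ffn_gate_exps.weight", "ffn_gate"
)
assert "blk.0.ffn_gate.0.weight" in experts
assert experts["blk.0.ffn_gate.7.weight"].shape == (16, 32)
```

Each slice is a view into the merged tensor; in sharktank proper, the slice would then be wrapped in a `DefaultPrimitiveTensor` carrying the constructed name, which is what `extract_ffn_layer` does with `expert_layer_name`.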