 void moe_permute(
     const torch::Tensor& input,                       // [n_token, hidden]
-    const torch::Tensor& topk_weights,                // [n_token, topk]
-    torch::Tensor& topk_ids,                          // [n_token, topk]
+    const torch::Tensor& topk_ids,                    // [n_token, topk]
     const torch::Tensor& token_expert_indices,        // [n_token, topk]
     const std::optional<torch::Tensor>& expert_map,   // [n_expert]
     int64_t n_expert, int64_t n_local_expert, int64_t topk,
     const std::optional<int64_t>& align_block_size,
-    torch::Tensor&
-        permuted_input,  // [topk * n_token/align_block_size_m, hidden]
+    torch::Tensor& permuted_input,             // [permuted_size, hidden]
     torch::Tensor& expert_first_token_offset,  // [n_local_expert + 1]
-    torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
+    torch::Tensor& inv_permuted_idx,           // [n_token, topk]
+    torch::Tensor& permuted_idx,               // [permute_size]
     torch::Tensor& m_indices) {                // [align_expand_m]
-  TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
-              "topk_weights must be float32");
   TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
               "expert_first_token_offset must be int64");
   TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
               "topk_ids must be int32");
   TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
               "token_expert_indices must be int32");
-  TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
-              "src_row_id2dst_row_id_map must be int32");
+  TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
+              "inv_permuted_idx must be int32");
   TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
               "expert_first_token_offset shape != n_local_expert+1")
-  TORCH_CHECK(
-      src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(),
-      "token_expert_indices shape must be same as src_row_id2dst_row_id_map");
+  TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
+              "token_expert_indices shape must be same as inv_permuted_idx");
   auto n_token = input.sizes()[0];
   auto n_hidden = input.sizes()[1];
   auto align_block_size_value =
@@ -46,8 +42,9 @@ void moe_permute(
   auto sort_workspace = torch::empty(
       {sorter_size},
       torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+  auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess
   auto permuted_experts_id = torch::empty_like(topk_ids);
-  auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
+  auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
   auto align_expert_first_token_offset =
       torch::zeros_like(expert_first_token_offset);
 
@@ -67,24 +64,22 @@ void moe_permute(
     const int* expert_map_ptr = get_ptr<int>(expert_map.value());
     valid_num_ptr =
         get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
-    preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
+    preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
                              expert_map_ptr, n_expert, stream);
   }
   // expert sort topk expert id and scan expert id get expert_first_token_offset
-  sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices),
-                    get_ptr<int>(permuted_experts_id),
-                    get_ptr<int>(dst_row_id2src_row_id_map),
-                    get_ptr<int64_t>(expert_first_token_offset), n_token,
-                    n_expert, n_local_expert, topk, sorter,
-                    get_ptr<int>(sort_workspace), stream);
+  sortAndScanExpert(
+      get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
+      get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
+      get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
+      n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
 
   // dispatch expandInputRowsKernelLauncher
   MOE_DISPATCH(input.scalar_type(), [&] {
     expandInputRowsKernelLauncher<scalar_t>(
         get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
-        get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
-        get_ptr<int>(dst_row_id2src_row_id_map),
-        get_ptr<int>(src_row_id2dst_row_id_map),
+        get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
+        get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
         get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
         n_hidden, topk, n_local_expert, align_block_size_value, stream);
   });
@@ -101,32 +96,34 @@ void moe_permute(
 }
 
 void moe_unpermute(
-    const torch::Tensor& permuted_hidden_states,     // [n_token * topk, hidden]
-    const torch::Tensor& topk_weights,               // [n_token, topk]
-    const torch::Tensor& topk_ids,                   // [n_token, topk]
-    const torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
-    const torch::Tensor& expert_first_token_offset,  // [n_local_expert+1]
-    int64_t n_expert, int64_t n_local_expert, int64_t topk,
+    const torch::Tensor& permuted_hidden_states,  // [n_token * topk, hidden]
+    const torch::Tensor& topk_weights,            // [n_token, topk]
+    const torch::Tensor& inv_permuted_idx,        // [n_token, topk]
+    const std::optional<torch::Tensor>&
+        expert_first_token_offset,  // [n_local_expert+1]
+    int64_t topk,
     torch::Tensor& hidden_states  // [n_token, hidden]
 ) {
-  TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
-              "topk_ids shape must be same as src_row_id2dst_row_id_map");
-  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
-              "topk_ids must be int32");
   TORCH_CHECK(
       permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
-      "topk_ids dtype must be same as src_row_id2dst_row_id_map");
+      "permuted_hidden_states dtype must be same as hidden_states");
   auto n_token = hidden_states.size(0);
   auto n_hidden = hidden_states.size(1);
   auto stream = at::cuda::getCurrentCUDAStream().stream();
-  const int64_t* valid_ptr =
-      get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
+
+  int64_t const* valid_ptr = nullptr;
+  if (expert_first_token_offset.has_value()) {
+    int n_local_expert = expert_first_token_offset.value().size(0) - 1;
+    valid_ptr =
+        get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
+  }
+
   MOE_DISPATCH(hidden_states.scalar_type(), [&] {
     finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
         get_ptr<scalar_t>(permuted_hidden_states),
         get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
-        get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
-        n_token, n_hidden, topk, valid_ptr, stream);
+        get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
+        stream);
   });
 }
 
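For readers following the rename, here is a minimal CPU-side sketch, not part of this diff, of how the new index tensor is assumed to be consumed by the unpermute path: inv_permuted_idx[t * topk + k] gives the row of permuted_hidden_states that holds token t's k-th expert output, and hidden_states is the topk_weights-weighted sum of those rows. The helper name, containers, and layouts below are assumptions for illustration only.

// Illustrative reference only: a CPU sketch of the gather-and-reduce that the
// finalize kernel performs, assuming inv_permuted_idx[t * topk + k] is the
// permuted row holding token t's k-th expert output.
#include <cstdint>
#include <vector>

std::vector<float> unpermute_reference(            // hypothetical helper name
    const std::vector<float>& permuted,            // [n_token * topk, hidden]
    const std::vector<float>& topk_weights,        // [n_token, topk]
    const std::vector<int32_t>& inv_permuted_idx,  // [n_token, topk]
    int64_t n_token, int64_t topk, int64_t hidden) {
  std::vector<float> out(n_token * hidden, 0.0f);
  for (int64_t t = 0; t < n_token; ++t) {
    for (int64_t k = 0; k < topk; ++k) {
      const int64_t row = inv_permuted_idx[t * topk + k];  // permuted row id
      const float w = topk_weights[t * topk + k];
      for (int64_t h = 0; h < hidden; ++h) {
        // weighted accumulation of the k-th expert's output for token t
        out[t * hidden + h] += w * permuted[row * hidden + h];
      }
    }
  }
  return out;
}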