|
212 | 212 | for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { |
213 | 213 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col); |
214 | 214 | auto tmp1 = tmp0 - vec_max; |
215 | | - auto tmp2 = tmp1.exp_u20(); |
| 215 | + auto tmp2 = tmp1.fexp_u20(); |
216 | 216 | vec_tmp_sum += tmp2; |
217 | 217 | store(tmp_out + col, tmp2); |
218 | 218 | } |
219 | 219 | if (col < kvBlockSize) { |
220 | 220 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col); |
221 | 221 | auto tmp1 = tmp0 - vec_max; |
222 | | - auto tmp2 = tmp1.exp_u20(); |
| 222 | + auto tmp2 = tmp1.fexp_u20(); |
223 | 223 | store(tmp_out + col, tmp2, kvBlockSize - col); |
224 | 224 | vec_tmp_sum = at::vec::Vectorized<float>::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); |
225 | 225 | } |
|
316 | 316 | for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) { |
317 | 317 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col); |
318 | 318 | auto tmp1 = tmp0 - vec_max; |
319 | | - auto tmp2 = tmp1.exp_u20(); |
| 319 | + auto tmp2 = tmp1.fexp_u20(); |
320 | 320 | vec_tmp_sum += tmp2; |
321 | 321 | store(tmp_out + col, tmp2); |
322 | 322 | } |
323 | 323 | if (col < kvBlockSize) { |
324 | 324 | auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col); |
325 | 325 | auto tmp1 = tmp0 - vec_max; |
326 | | - auto tmp2 = tmp1.exp_u20(); |
| 326 | + auto tmp2 = tmp1.fexp_u20(); |
327 | 327 | vec_tmp_sum = at::vec::Vectorized<float>::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col); |
328 | 328 | store(tmp_out + col, tmp2, kvBlockSize - col); |
329 | 329 | } |
|
1300 | 1300 | int64_t i = 0, j = 0, l = 0, n = 0; |
1301 | 1301 | at::native::data_index_init( |
1302 | 1302 | begin, i, batchSize, j, num_head, l, kvSlice); |
1303 | | - uint8_t* B_blocked_xform_u8 = new uint8_t[rndHeadSize * kvSplitSize]; |
1304 | 1303 | for (const auto z : c10::irange(begin, end)) { |
1305 | 1304 | (void)z; // Suppress unused variable |
1306 | 1305 | n = l * kvSplitSize; |
|
1310 | 1309 | i * num_head * kvSlice * v_reorder_strideL + |
1311 | 1310 | j * kvSlice * v_reorder_strideL + n * rndHeadSize; |
1312 | 1311 | int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); |
1313 | | - at::native::utils::transpose<uint8_t>( |
1314 | | - kvBlockSize, |
1315 | | - headSize, |
1316 | | - k_data + i * kStrideB + j * kStrideH + n * kStrideN, |
1317 | | - kStrideN, |
1318 | | - B_blocked_xform_u8, |
1319 | | - kvBlockSize); |
1320 | | - at::vec::pack_vnni4( |
1321 | | - /* src */ B_blocked_xform_u8, |
| 1312 | + at::vec::transpose_pack_vnni4( |
| 1313 | + /* src */ k_data + i * kStrideB + j * kStrideH + n * kStrideN, |
1322 | 1314 | /* dst */ k_reorder, |
1323 | | - /* ld_src */ kvBlockSize, |
1324 | | - /* K */ rndHeadSize, |
1325 | | - /* N */ kvBlockSize); |
| 1315 | + /* ld_src */ kStrideN, |
| 1316 | + /* K */ kvBlockSize, |
| 1317 | + /* N */ rndHeadSize); |
1326 | 1318 | at::vec::pack_vnni4( |
1327 | 1319 | /* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN, |
1328 | 1320 | /* dst */ v_reorder, |
|
0 commit comments