Skip to content

Commit d23ed9e

Browse files
authored
[CPU] Optimize for int8 sdpa kernel (#3025)
* [CPU] Optimize for int8 sdpa kernel
1 parent 4013764 commit d23ed9e

File tree

1 file changed

+9
-17
lines changed

1 file changed

+9
-17
lines changed

torchao/prototype/inductor/codegen/cpp_int8_sdpa_template.py

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -212,14 +212,14 @@
212212
for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
213213
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col);
214214
auto tmp1 = tmp0 - vec_max;
215-
auto tmp2 = tmp1.exp_u20();
215+
auto tmp2 = tmp1.fexp_u20();
216216
vec_tmp_sum += tmp2;
217217
store(tmp_out + col, tmp2);
218218
}
219219
if (col < kvBlockSize) {
220220
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col);
221221
auto tmp1 = tmp0 - vec_max;
222-
auto tmp2 = tmp1.exp_u20();
222+
auto tmp2 = tmp1.fexp_u20();
223223
store(tmp_out + col, tmp2, kvBlockSize - col);
224224
vec_tmp_sum = at::vec::Vectorized<float>::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col);
225225
}
@@ -316,14 +316,14 @@
316316
for (; col < vec_size * (kvBlockSize / vec_size); col += vec_size) {
317317
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col);
318318
auto tmp1 = tmp0 - vec_max;
319-
auto tmp2 = tmp1.exp_u20();
319+
auto tmp2 = tmp1.fexp_u20();
320320
vec_tmp_sum += tmp2;
321321
store(tmp_out + col, tmp2);
322322
}
323323
if (col < kvBlockSize) {
324324
auto tmp0 = at::vec::Vectorized<float>::loadu(tmp_in + col, kvBlockSize - col);
325325
auto tmp1 = tmp0 - vec_max;
326-
auto tmp2 = tmp1.exp_u20();
326+
auto tmp2 = tmp1.fexp_u20();
327327
vec_tmp_sum = at::vec::Vectorized<float>::set(vec_tmp_sum, vec_tmp_sum + tmp2, kvBlockSize - col);
328328
store(tmp_out + col, tmp2, kvBlockSize - col);
329329
}
@@ -1300,7 +1300,6 @@
13001300
int64_t i = 0, j = 0, l = 0, n = 0;
13011301
at::native::data_index_init(
13021302
begin, i, batchSize, j, num_head, l, kvSlice);
1303-
uint8_t* B_blocked_xform_u8 = new uint8_t[rndHeadSize * kvSplitSize];
13041303
for (const auto z : c10::irange(begin, end)) {
13051304
(void)z; // Suppress unused variable
13061305
n = l * kvSplitSize;
@@ -1310,19 +1309,12 @@
13101309
i * num_head * kvSlice * v_reorder_strideL +
13111310
j * kvSlice * v_reorder_strideL + n * rndHeadSize;
13121311
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
1313-
at::native::utils::transpose<uint8_t>(
1314-
kvBlockSize,
1315-
headSize,
1316-
k_data + i * kStrideB + j * kStrideH + n * kStrideN,
1317-
kStrideN,
1318-
B_blocked_xform_u8,
1319-
kvBlockSize);
1320-
at::vec::pack_vnni4(
1321-
/* src */ B_blocked_xform_u8,
1312+
at::vec::transpose_pack_vnni4(
1313+
/* src */ k_data + i * kStrideB + j * kStrideH + n * kStrideN,
13221314
/* dst */ k_reorder,
1323-
/* ld_src */ kvBlockSize,
1324-
/* K */ rndHeadSize,
1325-
/* N */ kvBlockSize);
1315+
/* ld_src */ kStrideN,
1316+
/* K */ kvBlockSize,
1317+
/* N */ rndHeadSize);
13261318
at::vec::pack_vnni4(
13271319
/* src */ v_data + i * vStrideB + j * vStrideH + n * vStrideN,
13281320
/* dst */ v_reorder,

0 commit comments

Comments
 (0)