Commit 2cb799b

[CPU] Support int8 scaled embedding bag (#2938)
* add int8 embeddingbag
* improve code style
* improve code style
* refine ut
1 parent 851e2e6 commit 2cb799b

File tree

2 files changed: +119, -44 lines changed


test/test_ops.py

Lines changed: 41 additions & 15 deletions
@@ -779,19 +779,10 @@ def test_swizzle_mm():
     )


-@pytest.mark.skipif(
-    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
-    reason="cpp kernels not built",
-)
-@pytest.mark.parametrize(
-    "multi_hot, batch_size, vector_size, index_type",
-    EMBEDINGBAG_TEST_PARAMS,
-    ids=str,
-)
-def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type):
-    qtype = torch.float8_e4m3fn
+def _test_scaled_embedding_bag_cpu_helper(
+    multi_hot, batch_size, vector_size, index_type, qtype
+):
     dtype = torch.float32
-    weight_scale = torch.tensor([2.0])
     include_last_offset = True
     mode = "sum"

@@ -811,13 +802,18 @@ def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type
         dtype=dtype,
         include_last_offset=include_last_offset,
     )
-    fp8_weight = m.weight.data.to(qtype)
-    m.weight.data = fp8_weight.to(m.weight.dtype)
+    if qtype == torch.int8:
+        weight_scale = 127.0 / m.weight.data.abs().max()
+        qweight = (m.weight.data * weight_scale).to(qtype)
+    else:
+        weight_scale = torch.tensor([2.0])
+        qweight = m.weight.data.to(qtype)
+    m.weight.data = qweight.to(m.weight.dtype)

     with torch.no_grad():
         refe_out = m.forward(indices, offsets) * weight_scale
         test_out = torch.ops.torchao._scaled_embedding_bag(
-            fp8_weight,
+            qweight,
             indices,
             offsets,
             weight_scale,
@@ -828,5 +824,35 @@ def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.int8
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
+    )
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
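
A note on the new helper: for torch.int8 it uses a per-tensor symmetric scale (127 / max|weight|), writes the quantized values back into the float EmbeddingBag, and uses that module's output times weight_scale as the reference for torchao._scaled_embedding_bag. The sketch below is not part of this commit; it spells out that recipe with stand-in sizes for EMBEDINGBAG_TEST_PARAMS, and the actual op call with its remaining arguments follows the test above.

import torch

# Stand-in sizes for illustration; the real values come from EMBEDINGBAG_TEST_PARAMS.
num_embeddings, vector_size, batch_size, multi_hot = 1000, 128, 4, 10

m = torch.nn.EmbeddingBag(
    num_embeddings, vector_size, mode="sum", include_last_offset=True
)

# Per-tensor symmetric int8 quantization, as in the helper: the scale maps the
# largest-magnitude weight to 127, and .to(torch.int8) truncates toward zero.
weight_scale = 127.0 / m.weight.data.abs().max()
qweight = (m.weight.data * weight_scale).to(torch.int8)

# The quantized values are written back into the float module, so that
# m(indices, offsets) * weight_scale serves as the reference for the kernel
# output produced from (qweight, weight_scale).
m.weight.data = qweight.to(m.weight.dtype)

indices = torch.randint(0, num_embeddings, (batch_size * multi_hot,), dtype=torch.int64)
offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot, dtype=torch.int64)

with torch.no_grad():
    refe_out = m(indices, offsets) * weight_scale  # compared against the op's output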

torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

Lines changed: 78 additions & 29 deletions
@@ -11,19 +11,55 @@ namespace torchao {
 namespace {

 #if defined(CPU_CAPABILITY_AVX512)
+using CHUNK =
+    std::tuple<__m512, __m512, __m512, __m512, __m512, __m512, __m512, __m512>;
 static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
   __m512 o;
   __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
   at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
   return o;
 }
+
+static inline __m512 _mm512_cvt_s8_ps(__m128i x) {
+  return _mm512_cvt_roundepi32_ps(
+      _mm512_cvtepi8_epi32(x), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+}
+
+static inline CHUNK load_chunk(const at::Float8_e4m3fn *x) {
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  x0 = _mm512_load_e4m3_cvt_ps(x + 0);
+  x1 = _mm512_load_e4m3_cvt_ps(x + 16);
+  x2 = _mm512_load_e4m3_cvt_ps(x + 32);
+  x3 = _mm512_load_e4m3_cvt_ps(x + 48);
+  x4 = _mm512_load_e4m3_cvt_ps(x + 64);
+  x5 = _mm512_load_e4m3_cvt_ps(x + 80);
+  x6 = _mm512_load_e4m3_cvt_ps(x + 96);
+  x7 = _mm512_load_e4m3_cvt_ps(x + 112);
+  return {x0, x1, x2, x3, x4, x5, x6, x7};
+}
+
+static inline CHUNK load_chunk(const int8_t *x) {
+  __m512i x00, x64;
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  x00 = _mm512_load_si512(x);
+  x64 = _mm512_load_si512(x + 64);
+  x0 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 0));
+  x1 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 1));
+  x2 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 2));
+  x3 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 3));
+  x4 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 0));
+  x5 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 1));
+  x6 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 2));
+  x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
+  return {x0, x1, x2, x3, x4, x5, x6, x7};
+}
 #endif

-template <typename index_t>
+template <typename index_t, typename data_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
-    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
+    const index_t *offsets, const data_t *weight, const double scale,
     float *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
   if (emb_dim % 128 == 0) {
@@ -32,6 +68,7 @@ inline void _scaled_embedding_bag_krnl(
     __m512 scale_v = _mm512_set1_ps(scale);
     for (int64_t b = bs_begin; b < bs_end; ++b) {
       __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+      __m512 y0, y1, y2, y3, y4, y5, y6, y7;
       int64_t start_idx = offsets[b];
       int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                             ? last_offset
@@ -40,25 +77,19 @@ inline void _scaled_embedding_bag_krnl(
       // load first indices
       int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
       float *block_result = result + block_dim * block_id;
-      x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
-      x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
-      x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
-      x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
-      x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
-      x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
-      x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
-      x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
+      std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
       for (int64_t j = start_idx + 1; j < end_idx; ++j) {
         // add following idx
         idx = indices[j] * emb_dim + block_dim * block_id;
-        x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
-        x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
-        x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
-        x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
-        x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
-        x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
-        x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
-        x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
+        std::tie(y0, y1, y2, y3, y4, y5, y6, y7) = load_chunk(weight + idx);
+        x0 = _mm512_add_ps(x0, y0);
+        x1 = _mm512_add_ps(x1, y1);
+        x2 = _mm512_add_ps(x2, y2);
+        x3 = _mm512_add_ps(x3, y3);
+        x4 = _mm512_add_ps(x4, y4);
+        x5 = _mm512_add_ps(x5, y5);
+        x6 = _mm512_add_ps(x6, y6);
+        x7 = _mm512_add_ps(x7, y7);
       }
       x0 = _mm512_mul_ps(x0, scale_v);
       x1 = _mm512_mul_ps(x1, scale_v);
@@ -143,6 +174,7 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
   int64_t emb_dim = qweight.size(1);

   auto index_type = indices.scalar_type();
+  auto qtype = qweight.scalar_type();
   float w_scale = w_scales.data_ptr<float>()[0];

   TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
@@ -154,22 +186,39 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
               "_scaled_embedding_bag: only accept contiguous weight");
   TORCH_CHECK(qweight.dim() == 2,
               "_scaled_embedding_bag: only accept weight with dim == 2");
-  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
-              "_scaled_embedding_bag: only support e4m3fn weight")
+  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn ||
+                  qweight.scalar_type() == c10::ScalarType::Char,
+              "_scaled_embedding_bag: only support e4m3fn and int8 weight")
   // handle last offsets
   int64_t last_offset = indices.numel();

   at::Tensor output =
       at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
-    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
-    index_t *indices_ptr = indices.data_ptr<index_t>();
-    index_t *offsets_ptr = offsets.data_ptr<index_t>();
-    float *output_ptr = output.data_ptr<float>();
-    _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
-        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
-        last_offset, w_scale, o_scale);
-  });
+  if (qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn) {
+    AT_DISPATCH_INDEX_TYPES(
+        indices.scalar_type(), "_scaled_embedding_bag", [&] {
+          at::Float8_e4m3fn *qweight_ptr =
+              qweight.data_ptr<at::Float8_e4m3fn>();
+          index_t *indices_ptr = indices.data_ptr<index_t>();
+          index_t *offsets_ptr = offsets.data_ptr<index_t>();
+          float *output_ptr = output.data_ptr<float>();
+          _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
+              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
+              emb_dim, last_offset, w_scale, o_scale);
+        });
+  } else {
+    AT_DISPATCH_INDEX_TYPES(
+        indices.scalar_type(), "_scaled_embedding_bag", [&] {
+          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
+          index_t *indices_ptr = indices.data_ptr<index_t>();
+          index_t *offsets_ptr = offsets.data_ptr<index_t>();
+          float *output_ptr = output.data_ptr<float>();
+          _scaled_embedding_bag<index_t, int8_t>(
+              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
+              emb_dim, last_offset, w_scale, o_scale);
+        });
+  }
+
   return output;
 }
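
For readers following the AVX-512 code, here is a hedged pure-PyTorch reference (not from the repository) of what the sum path above computes in the include_last_offset layout used by the tests: per bag, the selected e4m3fn or int8 rows are converted to fp32, summed, and multiplied by the weight scale. Handling of o_scale is not visible in these hunks and is left out.

import torch

def scaled_embedding_bag_ref(qweight, indices, offsets, w_scale):
    # offsets has batch_size + 1 entries (include_last_offset layout);
    # bag b covers indices[offsets[b]:offsets[b + 1]].
    num_batch = offsets.numel() - 1
    out = torch.empty(num_batch, qweight.size(1), dtype=torch.float32)
    for b in range(num_batch):
        start, end = offsets[b].item(), offsets[b + 1].item()
        # Convert the selected quantized rows (int8 or float8_e4m3fn) to fp32,
        # sum them, and apply the per-tensor weight scale, mirroring the
        # load_chunk + _mm512_add_ps + _mm512_mul_ps sequence in the kernel.
        rows = qweight[indices[start:end]].to(torch.float32)
        out[b] = rows.sum(dim=0) * w_scale
    return out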
