
Commit 4a03494

[Reland][CPU] Support int8 scaled embedding bag (#3060)
* re-enable scaled_embedding_bag
* only support fp32 out_dtype
1 parent 0d3217d commit 4a03494
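
For orientation, here is a minimal sketch (illustrative only, not code from this commit) of the computation the new tests use as their eager reference for the int8 path: gather the int8 rows of each bag, sum them in fp32 (mode="sum", include_last_offset=True), and apply the per-tensor weight scale. The fused torch.ops.torchao._scaled_embedding_bag kernel is checked against this within 1e-5. All names, shapes, and values below are made up for the example.

import torch

num_emb, emb_dim = 8, 4
qweight = torch.randint(-128, 127, (num_emb, emb_dim), dtype=torch.int8)  # int8 table
weight_scale = torch.tensor(0.05)                                         # per-tensor scale

indices = torch.tensor([0, 2, 5, 1, 7])
offsets = torch.tensor([0, 2, 5])  # include_last_offset=True -> 2 bags

batch_size = offsets.numel() - 1
ref = torch.zeros(batch_size, emb_dim)
for b in range(batch_size):
    bag = indices[int(offsets[b]) : int(offsets[b + 1])]
    # sum the quantized rows of the bag in fp32, then apply the weight scale
    ref[b] = qweight[bag].to(torch.float32).sum(dim=0) * weight_scale

print(ref.shape, ref.dtype)  # torch.Size([2, 4]) torch.float32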

File tree

3 files changed: +124, -46 lines

- test/test_ops.py
- torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp
- torchao/ops.py


test/test_ops.py

Lines changed: 41 additions & 15 deletions
@@ -863,19 +863,10 @@ def test_swizzle_mm():
     )
 
 
-@pytest.mark.skipif(
-    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
-    reason="cpp kernels not built",
-)
-@pytest.mark.parametrize(
-    "multi_hot, batch_size, vector_size, index_type",
-    EMBEDINGBAG_TEST_PARAMS,
-    ids=str,
-)
-def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type):
-    qtype = torch.float8_e4m3fn
+def _test_scaled_embedding_bag_cpu_helper(
+    multi_hot, batch_size, vector_size, index_type, qtype
+):
     dtype = torch.float32
-    weight_scale = torch.tensor([2.0])
     include_last_offset = True
     mode = "sum"
 
@@ -895,13 +886,18 @@ def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type
         dtype=dtype,
         include_last_offset=include_last_offset,
     )
-    fp8_weight = m.weight.data.to(qtype)
-    m.weight.data = fp8_weight.to(m.weight.dtype)
+    if qtype == torch.int8:
+        weight_scale = 127.0 / m.weight.data.abs().max()
+        qweight = (m.weight.data * weight_scale).to(qtype)
+    else:
+        weight_scale = torch.tensor([2.0])
+        qweight = m.weight.data.to(qtype)
+    m.weight.data = qweight.to(m.weight.dtype)
 
     with torch.no_grad():
         refe_out = m.forward(indices, offsets) * weight_scale
         test_out = torch.ops.torchao._scaled_embedding_bag(
-            fp8_weight,
+            qweight,
             indices,
             offsets,
             weight_scale,
@@ -912,6 +908,36 @@ def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
 
 
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.int8
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
+    )
+
+
 @pytest.mark.skipif(
     "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
     or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),

torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

Lines changed: 79 additions & 30 deletions
@@ -11,19 +11,55 @@ namespace torchao {
 namespace {
 
 #if defined(CPU_CAPABILITY_AVX512)
+using CHUNK =
+    std::tuple<__m512, __m512, __m512, __m512, __m512, __m512, __m512, __m512>;
 static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
   __m512 o;
   __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
   at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
   return o;
 }
+
+static inline __m512 _mm512_cvt_s8_ps(__m128i x) {
+  return _mm512_cvt_roundepi32_ps(
+      _mm512_cvtepi8_epi32(x), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+}
+
+static inline CHUNK load_chunk(const at::Float8_e4m3fn *x) {
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  x0 = _mm512_load_e4m3_cvt_ps(x + 0);
+  x1 = _mm512_load_e4m3_cvt_ps(x + 16);
+  x2 = _mm512_load_e4m3_cvt_ps(x + 32);
+  x3 = _mm512_load_e4m3_cvt_ps(x + 48);
+  x4 = _mm512_load_e4m3_cvt_ps(x + 64);
+  x5 = _mm512_load_e4m3_cvt_ps(x + 80);
+  x6 = _mm512_load_e4m3_cvt_ps(x + 96);
+  x7 = _mm512_load_e4m3_cvt_ps(x + 112);
+  return {x0, x1, x2, x3, x4, x5, x6, x7};
+}
+
+static inline CHUNK load_chunk(const int8_t *x) {
+  __m512i x00, x64;
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  x00 = _mm512_load_si512(x);
+  x64 = _mm512_load_si512(x + 64);
+  x0 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 0));
+  x1 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 1));
+  x2 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 2));
+  x3 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 3));
+  x4 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 0));
+  x5 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 1));
+  x6 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 2));
+  x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
+  return {x0, x1, x2, x3, x4, x5, x6, x7};
+}
 #endif
 
-template <typename index_t>
+template <typename index_t, typename data_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
-    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
+    const index_t *offsets, const data_t *weight, const double scale,
     float *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
   if (emb_dim % 128 == 0) {
@@ -32,6 +68,7 @@ inline void _scaled_embedding_bag_krnl(
     __m512 scale_v = _mm512_set1_ps(scale);
     for (int64_t b = bs_begin; b < bs_end; ++b) {
       __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+      __m512 y0, y1, y2, y3, y4, y5, y6, y7;
      int64_t start_idx = offsets[b];
      int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                            ? last_offset
@@ -40,25 +77,19 @@ inline void _scaled_embedding_bag_krnl(
       // load first indices
       int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
       float *block_result = result + block_dim * block_id;
-      x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
-      x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
-      x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
-      x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
-      x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
-      x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
-      x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
-      x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
+      std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
       for (int64_t j = start_idx + 1; j < end_idx; ++j) {
         // add following idx
         idx = indices[j] * emb_dim + block_dim * block_id;
-        x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
-        x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
-        x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
-        x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
-        x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
-        x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
-        x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
-        x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
+        std::tie(y0, y1, y2, y3, y4, y5, y6, y7) = load_chunk(weight + idx);
+        x0 = _mm512_add_ps(x0, y0);
+        x1 = _mm512_add_ps(x1, y1);
+        x2 = _mm512_add_ps(x2, y2);
+        x3 = _mm512_add_ps(x3, y3);
+        x4 = _mm512_add_ps(x4, y4);
+        x5 = _mm512_add_ps(x5, y5);
+        x6 = _mm512_add_ps(x6, y6);
+        x7 = _mm512_add_ps(x7, y7);
       }
       x0 = _mm512_mul_ps(x0, scale_v);
       x1 = _mm512_mul_ps(x1, scale_v);
@@ -143,6 +174,7 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
   int64_t emb_dim = qweight.size(1);
 
   auto index_type = indices.scalar_type();
+  auto qtype = qweight.scalar_type();
   float w_scale = w_scales.data_ptr<float>()[0];
 
   TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
@@ -154,22 +186,39 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
               "_scaled_embedding_bag: only accept contiguous weight");
   TORCH_CHECK(qweight.dim() == 2,
               "_scaled_embedding_bag: only accept weight with dim == 2");
-  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
-              "_scaled_embedding_bag: only support e4m3fn weight")
+  TORCH_CHECK(qtype == c10::ScalarType::Float8_e4m3fn ||
+                  qtype == c10::ScalarType::Char,
+              "_scaled_embedding_bag: only support e4m3fn and int8 weight")
   // handle last offsets
   int64_t last_offset = indices.numel();
 
   at::Tensor output =
       at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
-    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
-    index_t *indices_ptr = indices.data_ptr<index_t>();
-    index_t *offsets_ptr = offsets.data_ptr<index_t>();
-    float *output_ptr = output.data_ptr<float>();
-    _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
-        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
-        last_offset, w_scale, o_scale);
-  });
+  if (qtype == c10::ScalarType::Float8_e4m3fn) {
+    AT_DISPATCH_INDEX_TYPES(
+        indices.scalar_type(), "_scaled_embedding_bag", [&] {
+          at::Float8_e4m3fn *qweight_ptr =
+              qweight.data_ptr<at::Float8_e4m3fn>();
+          index_t *indices_ptr = indices.data_ptr<index_t>();
+          index_t *offsets_ptr = offsets.data_ptr<index_t>();
+          float *output_ptr = output.data_ptr<float>();
+          _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
+              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
+              emb_dim, last_offset, w_scale, o_scale);
+        });
+  } else {
+    AT_DISPATCH_INDEX_TYPES(
+        indices.scalar_type(), "_scaled_embedding_bag", [&] {
+          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
+          index_t *indices_ptr = indices.data_ptr<index_t>();
+          index_t *offsets_ptr = offsets.data_ptr<index_t>();
+          float *output_ptr = output.data_ptr<float>();
+          _scaled_embedding_bag<index_t, int8_t>(
+              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
+              emb_dim, last_offset, w_scale, o_scale);
+        });
+  }
+
   return output;
 }
 
@@ -179,4 +228,4 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   m.impl("torchao::_scaled_embedding_bag", &_scaled_embedding_bag_impl);
 }
 
-} // namespace torchao
+} // namespace torchao
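
To make the data layout of the AVX512 path easier to follow, here is a hypothetical Python emulation (not code from the repo): each 128-element block of a weight row is consumed as eight 16-lane fp32 vectors (the x0..x7 registers), and for int8 the new load_chunk overload does two 64-byte loads split into 16-byte pieces that are sign-extended and converted via _mm512_cvt_s8_ps, which is numerically just an elementwise cast to float.

import torch

def load_chunk_emulated(row_block: torch.Tensor) -> torch.Tensor:
    # row_block: 128 contiguous int8 (or fp8) weight elements for one 128-wide block
    assert row_block.numel() == 128
    # eight chunks of 16 lanes, each converted to fp32 -- the x0..x7 registers
    return row_block.reshape(8, 16).to(torch.float32)

block = torch.randint(-128, 127, (128,), dtype=torch.int8)
chunks = load_chunk_emulated(block)  # shape (8, 16), dtype float32
assert torch.equal(chunks.flatten(), block.to(torch.float32))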

torchao/ops.py

Lines changed: 4 additions & 1 deletion
@@ -1122,7 +1122,10 @@ def _(
     # Only support include_last_offset == True
     assert include_last_offset == True
     batch_size = offsets.shape[0] - 1
-    return qweight.new_empty(batch_size, qweight.shape[1], dtype=qweight.dtype)
+    # Only support out_dtype == torch.float32
+    # Next setp: support more out_dtype
+    out_dtype = torch.float32
+    return qweight.new_empty(batch_size, qweight.shape[1], dtype=out_dtype)
 
 
 def float8_linear_prepack_cpu(
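
A minimal sketch of the shape/dtype contract the updated fake implementation encodes (an illustrative restatement, not the registered meta kernel verbatim; scaled_embedding_bag_out_spec is a made-up name): with include_last_offset=True the number of bags is offsets.shape[0] - 1, and the output is fp32 regardless of the weight dtype.

import torch

def scaled_embedding_bag_out_spec(qweight: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # only include_last_offset == True is supported, so offsets holds batch_size + 1 entries
    batch_size = offsets.shape[0] - 1
    # only out_dtype == torch.float32 is supported by this commit
    return qweight.new_empty(batch_size, qweight.shape[1], dtype=torch.float32)

qweight = torch.zeros(16, 128, dtype=torch.int8)
offsets = torch.tensor([0, 2, 5, 9])
out = scaled_embedding_bag_out_spec(qweight, offsets)
print(out.shape, out.dtype)  # torch.Size([3, 128]) torch.float32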
