
Commit 83e8e60

Revert "[CPU] Support int8 scaled embedding bag" (#2974)
Revert "[CPU] Support int8 scaled embedding bag (#2938)" This reverts commit 2cb799b.
1 parent b99904b commit 83e8e60

File tree

2 files changed: +44 -119 lines

    test/test_ops.py
    torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

test/test_ops.py

Lines changed: 15 additions & 41 deletions
@@ -779,10 +779,19 @@ def test_swizzle_mm():
     )


-def _test_scaled_embedding_bag_cpu_helper(
-    multi_hot, batch_size, vector_size, index_type, qtype
-):
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type):
+    qtype = torch.float8_e4m3fn
     dtype = torch.float32
+    weight_scale = torch.tensor([2.0])
     include_last_offset = True
     mode = "sum"

@@ -802,18 +811,13 @@ def _test_scaled_embedding_bag_cpu_helper(
         dtype=dtype,
         include_last_offset=include_last_offset,
     )
-    if qtype == torch.int8:
-        weight_scale = 127.0 / m.weight.data.abs().max()
-        qweight = (m.weight.data * weight_scale).to(qtype)
-    else:
-        weight_scale = torch.tensor([2.0])
-        qweight = m.weight.data.to(qtype)
-    m.weight.data = qweight.to(m.weight.dtype)
+    fp8_weight = m.weight.data.to(qtype)
+    m.weight.data = fp8_weight.to(m.weight.dtype)

     with torch.no_grad():
         refe_out = m.forward(indices, offsets) * weight_scale
         test_out = torch.ops.torchao._scaled_embedding_bag(
-            qweight,
+            fp8_weight,
             indices,
             offsets,
             weight_scale,
@@ -824,35 +828,5 @@ def _test_scaled_embedding_bag_cpu_helper(
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


-@pytest.mark.skipif(
-    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
-    reason="cpp kernels not built",
-)
-@pytest.mark.parametrize(
-    "multi_hot, batch_size, vector_size, index_type",
-    EMBEDINGBAG_TEST_PARAMS,
-    ids=str,
-)
-def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
-    _test_scaled_embedding_bag_cpu_helper(
-        multi_hot, batch_size, vector_size, index_type, torch.int8
-    )
-
-
-@pytest.mark.skipif(
-    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
-    reason="cpp kernels not built",
-)
-@pytest.mark.parametrize(
-    "multi_hot, batch_size, vector_size, index_type",
-    EMBEDINGBAG_TEST_PARAMS,
-    ids=str,
-)
-def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
-    _test_scaled_embedding_bag_cpu_helper(
-        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
-    )
-
-
 if __name__ == "__main__":
     pytest.main(sys.argv)
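
For reference, a minimal sketch (not part of this commit; the helper name is illustrative) of the weight preparation the two test paths exercised: the int8 path removed here derived a symmetric per-tensor scale from the weight's absolute maximum, while the fp8 path that remains casts directly to e4m3fn and uses a fixed scale of 2.0.

import torch

def prepare_qweight(weight: torch.Tensor, qtype: torch.dtype):
    # Illustrative only; mirrors the quantization step of the removed test helper.
    if qtype == torch.int8:
        # symmetric per-tensor scale: map the largest |w| onto 127
        weight_scale = 127.0 / weight.abs().max()
        qweight = (weight * weight_scale).to(qtype)
    else:
        # fp8 (e4m3fn) path kept by this revert: fixed scale, direct cast
        weight_scale = torch.tensor([2.0])
        qweight = weight.to(qtype)
    return qweight, weight_scale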

torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

Lines changed: 29 additions & 78 deletions
@@ -11,55 +11,19 @@ namespace torchao {
 namespace {

 #if defined(CPU_CAPABILITY_AVX512)
-using CHUNK =
-    std::tuple<__m512, __m512, __m512, __m512, __m512, __m512, __m512, __m512>;
 static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
   __m512 o;
   __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
   at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
   return o;
 }
-
-static inline __m512 _mm512_cvt_s8_ps(__m128i x) {
-  return _mm512_cvt_roundepi32_ps(
-      _mm512_cvtepi8_epi32(x), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
-}
-
-static inline CHUNK load_chunk(const at::Float8_e4m3fn *x) {
-  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
-  x0 = _mm512_load_e4m3_cvt_ps(x + 0);
-  x1 = _mm512_load_e4m3_cvt_ps(x + 16);
-  x2 = _mm512_load_e4m3_cvt_ps(x + 32);
-  x3 = _mm512_load_e4m3_cvt_ps(x + 48);
-  x4 = _mm512_load_e4m3_cvt_ps(x + 64);
-  x5 = _mm512_load_e4m3_cvt_ps(x + 80);
-  x6 = _mm512_load_e4m3_cvt_ps(x + 96);
-  x7 = _mm512_load_e4m3_cvt_ps(x + 112);
-  return {x0, x1, x2, x3, x4, x5, x6, x7};
-}
-
-static inline CHUNK load_chunk(const int8_t *x) {
-  __m512i x00, x64;
-  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
-  x00 = _mm512_load_si512(x);
-  x64 = _mm512_load_si512(x + 64);
-  x0 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 0));
-  x1 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 1));
-  x2 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 2));
-  x3 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 3));
-  x4 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 0));
-  x5 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 1));
-  x6 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 2));
-  x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
-  return {x0, x1, x2, x3, x4, x5, x6, x7};
-}
 #endif

-template <typename index_t, typename data_t>
+template <typename index_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
-    const index_t *offsets, const data_t *weight, const double scale,
+    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
     float *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
   if (emb_dim % 128 == 0) {
@@ -68,7 +32,6 @@ inline void _scaled_embedding_bag_krnl(
     __m512 scale_v = _mm512_set1_ps(scale);
     for (int64_t b = bs_begin; b < bs_end; ++b) {
       __m512 x0, x1, x2, x3, x4, x5, x6, x7;
-      __m512 y0, y1, y2, y3, y4, y5, y6, y7;
       int64_t start_idx = offsets[b];
       int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                             ? last_offset
@@ -77,19 +40,25 @@ inline void _scaled_embedding_bag_krnl(
         // load first indices
         int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
         float *block_result = result + block_dim * block_id;
-        std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
+        x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
+        x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
+        x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
+        x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
+        x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
+        x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
+        x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
+        x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
         for (int64_t j = start_idx + 1; j < end_idx; ++j) {
           // add following idx
           idx = indices[j] * emb_dim + block_dim * block_id;
-          std::tie(y0, y1, y2, y3, y4, y5, y6, y7) = load_chunk(weight + idx);
-          x0 = _mm512_add_ps(x0, y0);
-          x1 = _mm512_add_ps(x1, y1);
-          x2 = _mm512_add_ps(x2, y2);
-          x3 = _mm512_add_ps(x3, y3);
-          x4 = _mm512_add_ps(x4, y4);
-          x5 = _mm512_add_ps(x5, y5);
-          x6 = _mm512_add_ps(x6, y6);
-          x7 = _mm512_add_ps(x7, y7);
+          x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
+          x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
+          x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
+          x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
+          x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
+          x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
+          x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
+          x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
         }
         x0 = _mm512_mul_ps(x0, scale_v);
         x1 = _mm512_mul_ps(x1, scale_v);
@@ -174,7 +143,6 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
   int64_t emb_dim = qweight.size(1);

   auto index_type = indices.scalar_type();
-  auto qtype = qweight.scalar_type();
   float w_scale = w_scales.data_ptr<float>()[0];

   TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
@@ -186,39 +154,22 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
               "_scaled_embedding_bag: only accept contiguous weight");
   TORCH_CHECK(qweight.dim() == 2,
               "_scaled_embedding_bag: only accept weight with dim == 2");
-  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn ||
-                  qweight.scalar_type() == c10::ScalarType::Char,
-              "_scaled_embedding_bag: only support e4m3fn and int8 weight")
+  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
+              "_scaled_embedding_bag: only support e4m3fn weight")
   // handle last offsets
   int64_t last_offset = indices.numel();

   at::Tensor output =
       at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
-  if (qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn) {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          at::Float8_e4m3fn *qweight_ptr =
-              qweight.data_ptr<at::Float8_e4m3fn>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  } else {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, int8_t>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  }
-
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
+    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
+    index_t *indices_ptr = indices.data_ptr<index_t>();
+    index_t *offsets_ptr = offsets.data_ptr<index_t>();
+    float *output_ptr = output.data_ptr<float>();
+    _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
+        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
+        last_offset, w_scale, o_scale);
+  });
   return output;
 }
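
For orientation, a hedged pure-PyTorch sketch of what the remaining fp8-only kernel computes in "sum" mode (function name and signature are illustrative; the real op's extra output-scale argument is ignored here): for each bag b, the rows of the dequantized weight selected by indices[offsets[b]:offsets[b + 1]] are summed and then multiplied by the per-tensor weight scale.

import torch

def scaled_embedding_bag_ref(qweight, indices, offsets, weight_scale):
    # Reference with include_last_offset=True semantics: offsets carries a
    # trailing entry equal to indices.numel().
    weight = qweight.to(torch.float32)                # dequantize rows to fp32
    batch_size = offsets.numel() - 1
    out = torch.empty(batch_size, weight.size(1))
    for b in range(batch_size):
        start, end = int(offsets[b]), int(offsets[b + 1])
        bag = weight[indices[start:end]].sum(dim=0)   # sum the selected embedding rows
        out[b] = bag * weight_scale                   # apply the weight scale
    return out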