@@ -11,19 +11,55 @@ namespace torchao {
 namespace {
 
 #if defined(CPU_CAPABILITY_AVX512)
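+// A CHUNK holds 128 fp32 values in eight 16-lane AVX-512 registers.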
+using CHUNK =
+    std::tuple<__m512, __m512, __m512, __m512, __m512, __m512, __m512, __m512>;
 static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
   __m512 o;
   __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
   at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
   return o;
 }
+
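+// Sign-extend 16 int8 lanes to int32, then convert them to fp32 (round-to-nearest).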
+static inline __m512 _mm512_cvt_s8_ps(__m128i x) {
+  return _mm512_cvt_roundepi32_ps(
+      _mm512_cvtepi8_epi32(x), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+}
+
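+// Load 128 contiguous fp8 e4m3 weights and convert them to eight fp32 vectors.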
+static inline CHUNK load_chunk(const at::Float8_e4m3fn *x) {
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  x0 = _mm512_load_e4m3_cvt_ps(x + 0);
+  x1 = _mm512_load_e4m3_cvt_ps(x + 16);
+  x2 = _mm512_load_e4m3_cvt_ps(x + 32);
+  x3 = _mm512_load_e4m3_cvt_ps(x + 48);
+  x4 = _mm512_load_e4m3_cvt_ps(x + 64);
+  x5 = _mm512_load_e4m3_cvt_ps(x + 80);
+  x6 = _mm512_load_e4m3_cvt_ps(x + 96);
+  x7 = _mm512_load_e4m3_cvt_ps(x + 112);
+  return {x0, x1, x2, x3, x4, x5, x6, x7};
+}
+
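+// Load 128 contiguous int8 weights with two 512-bit loads and convert each
+// 128-bit quarter into a vector of 16 fp32 values.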
+static inline CHUNK load_chunk(const int8_t *x) {
+  __m512i x00, x64;
+  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+  x00 = _mm512_load_si512(x);
+  x64 = _mm512_load_si512(x + 64);
+  x0 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 0));
+  x1 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 1));
+  x2 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 2));
+  x3 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 3));
+  x4 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 0));
+  x5 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 1));
+  x6 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 2));
+  x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
+  return {x0, x1, x2, x3, x4, x5, x6, x7};
+}
 #endif
 
-template <typename index_t>
+template <typename index_t, typename data_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
-    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
+    const index_t *offsets, const data_t *weight, const double scale,
     float *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
   if (emb_dim % 128 == 0) {
@@ -32,6 +68,7 @@ inline void _scaled_embedding_bag_krnl(
     __m512 scale_v = _mm512_set1_ps(scale);
     for (int64_t b = bs_begin; b < bs_end; ++b) {
       __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+      __m512 y0, y1, y2, y3, y4, y5, y6, y7;
       int64_t start_idx = offsets[b];
       int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                             ? last_offset
@@ -40,25 +77,19 @@ inline void _scaled_embedding_bag_krnl(
         // load first indices
         int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
         float *block_result = result + block_dim * block_id;
-        x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
-        x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
-        x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
-        x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
-        x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
-        x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
-        x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
-        x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
+        std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
         for (int64_t j = start_idx + 1; j < end_idx; ++j) {
           // add following idx
           idx = indices[j] * emb_dim + block_dim * block_id;
-          x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
-          x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
-          x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
-          x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
-          x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
-          x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
-          x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
-          x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
+          std::tie(y0, y1, y2, y3, y4, y5, y6, y7) = load_chunk(weight + idx);
+          x0 = _mm512_add_ps(x0, y0);
+          x1 = _mm512_add_ps(x1, y1);
+          x2 = _mm512_add_ps(x2, y2);
+          x3 = _mm512_add_ps(x3, y3);
+          x4 = _mm512_add_ps(x4, y4);
+          x5 = _mm512_add_ps(x5, y5);
+          x6 = _mm512_add_ps(x6, y6);
+          x7 = _mm512_add_ps(x7, y7);
         }
         x0 = _mm512_mul_ps(x0, scale_v);
         x1 = _mm512_mul_ps(x1, scale_v);
@@ -143,6 +174,7 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
   int64_t emb_dim = qweight.size(1);
 
   auto index_type = indices.scalar_type();
+  auto qtype = qweight.scalar_type();
   float w_scale = w_scales.data_ptr<float>()[0];
 
   TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
@@ -154,22 +186,39 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
154186 " _scaled_embedding_bag: only accept contiguous weight" );
155187 TORCH_CHECK (qweight.dim () == 2 ,
156188 " _scaled_embedding_bag: only accept weight with dim == 2" );
157- TORCH_CHECK (qweight.scalar_type () == c10::ScalarType::Float8_e4m3fn,
158- " _scaled_embedding_bag: only support e4m3fn weight" )
189+ TORCH_CHECK (qweight.scalar_type () == c10::ScalarType::Float8_e4m3fn ||
190+ qweight.scalar_type () == c10::ScalarType::Char,
191+ " _scaled_embedding_bag: only support e4m3fn and int8 weight" )
159192 // handle last offsets
160193 int64_t last_offset = indices.numel ();
161194
162195 at::Tensor output =
163196 at::empty ({batch_size, emb_dim}, qweight.options ().dtype (at::kFloat ));
164- AT_DISPATCH_INDEX_TYPES (indices.scalar_type (), " embeddingbag_cat" , [&] {
165- at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr <at::Float8_e4m3fn>();
166- index_t *indices_ptr = indices.data_ptr <index_t >();
167- index_t *offsets_ptr = offsets.data_ptr <index_t >();
168- float *output_ptr = output.data_ptr <float >();
169- _scaled_embedding_bag<index_t , at::Float8_e4m3fn>(
170- output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
171- last_offset, w_scale, o_scale);
172- });
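+  // Dispatch on the weight dtype: the fp8 e4m3 and int8 paths share the same
+  // templated kernel through the load_chunk overloads above.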
+  if (qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn) {
+    AT_DISPATCH_INDEX_TYPES(
+        indices.scalar_type(), "_scaled_embedding_bag", [&] {
+          at::Float8_e4m3fn *qweight_ptr =
+              qweight.data_ptr<at::Float8_e4m3fn>();
+          index_t *indices_ptr = indices.data_ptr<index_t>();
+          index_t *offsets_ptr = offsets.data_ptr<index_t>();
+          float *output_ptr = output.data_ptr<float>();
+          _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
+              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
+              emb_dim, last_offset, w_scale, o_scale);
+        });
+  } else {
+    AT_DISPATCH_INDEX_TYPES(
+        indices.scalar_type(), "_scaled_embedding_bag", [&] {
+          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
+          index_t *indices_ptr = indices.data_ptr<index_t>();
+          index_t *offsets_ptr = offsets.data_ptr<index_t>();
+          float *output_ptr = output.data_ptr<float>();
+          _scaled_embedding_bag<index_t, int8_t>(
+              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
+              emb_dim, last_offset, w_scale, o_scale);
+        });
+  }
+
   return output;
 }
 
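For reference, a minimal scalar sketch of the per-bag reduction the vectorized kernel above performs, assuming the same pointer layout; the function name and signature are illustrative only and are not part of this patch:

// Scalar reference (illustrative only): each output element is the sum of the
// selected rows, converted to fp32, times the fused scale. It compiles for
// both int8_t and at::Float8_e4m3fn, since both convert to float.
#include <cstdint>

template <typename index_t, typename data_t>
void scalar_bag_reference(float *out, const data_t *weight,
                          const index_t *indices, int64_t start_idx,
                          int64_t end_idx, int64_t emb_dim, float scale) {
  for (int64_t d = 0; d < emb_dim; ++d) {
    float acc = 0.0f;
    for (int64_t j = start_idx; j < end_idx; ++j) {
      acc += static_cast<float>(weight[indices[j] * emb_dim + d]);
    }
    out[d] = acc * scale;
  }
}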