@@ -11,55 +11,19 @@ namespace torchao {
 namespace {
 
 #if defined(CPU_CAPABILITY_AVX512)
-using CHUNK =
-    std::tuple<__m512, __m512, __m512, __m512, __m512, __m512, __m512, __m512>;
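+// Load 16 fp8 (e4m3) values (128 bits) and widen them to 16 fp32 lanes.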
 static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
   __m512 o;
   __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
   at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
   return o;
 }
-
-static inline __m512 _mm512_cvt_s8_ps(__m128i x) {
-  return _mm512_cvt_roundepi32_ps(
-      _mm512_cvtepi8_epi32(x), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
-}
-
-static inline CHUNK load_chunk(const at::Float8_e4m3fn *x) {
-  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
-  x0 = _mm512_load_e4m3_cvt_ps(x + 0);
-  x1 = _mm512_load_e4m3_cvt_ps(x + 16);
-  x2 = _mm512_load_e4m3_cvt_ps(x + 32);
-  x3 = _mm512_load_e4m3_cvt_ps(x + 48);
-  x4 = _mm512_load_e4m3_cvt_ps(x + 64);
-  x5 = _mm512_load_e4m3_cvt_ps(x + 80);
-  x6 = _mm512_load_e4m3_cvt_ps(x + 96);
-  x7 = _mm512_load_e4m3_cvt_ps(x + 112);
-  return {x0, x1, x2, x3, x4, x5, x6, x7};
-}
-
-static inline CHUNK load_chunk(const int8_t *x) {
-  __m512i x00, x64;
-  __m512 x0, x1, x2, x3, x4, x5, x6, x7;
-  x00 = _mm512_load_si512(x);
-  x64 = _mm512_load_si512(x + 64);
-  x0 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 0));
-  x1 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 1));
-  x2 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 2));
-  x3 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x00, 3));
-  x4 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 0));
-  x5 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 1));
-  x6 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 2));
-  x7 = _mm512_cvt_s8_ps(_mm512_extracti32x4_epi32(x64, 3));
-  return {x0, x1, x2, x3, x4, x5, x6, x7};
-}
 #endif
 
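+// Per bag, sums the fp8 (e4m3) weight rows selected by `indices`, widening
+// to fp32 before accumulation, then applies `scale`.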
-template <typename index_t, typename data_t>
+template <typename index_t>
 inline void _scaled_embedding_bag_krnl(
     const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
     const int64_t emb_dim, const index_t last_offset, const index_t *indices,
-    const index_t *offsets, const data_t *weight, const double scale,
+    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
     float *result, const int64_t num_batch) {
 #if defined(CPU_CAPABILITY_AVX512)
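+  // Fast path: when emb_dim is a multiple of 128, each 128-wide block of a
+  // row fits in eight 16-lane fp32 registers.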
   if (emb_dim % 128 == 0) {
@@ -68,7 +32,6 @@ inline void _scaled_embedding_bag_krnl(
     __m512 scale_v = _mm512_set1_ps(scale);
     for (int64_t b = bs_begin; b < bs_end; ++b) {
       __m512 x0, x1, x2, x3, x4, x5, x6, x7;
-      __m512 y0, y1, y2, y3, y4, y5, y6, y7;
       int64_t start_idx = offsets[b];
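+      // The final bag may have no trailing offset entry; use last_offset
+      // (indices.numel()) as its end instead of offsets[b + 1].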
       int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
                             ? last_offset
@@ -77,19 +40,25 @@ inline void _scaled_embedding_bag_krnl(
         // load the first row of the bag
         int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
         float *block_result = result + block_dim * block_id;
-        std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = load_chunk(weight + idx);
+        x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
+        x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
+        x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
+        x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
+        x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
+        x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
+        x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
+        x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
         for (int64_t j = start_idx + 1; j < end_idx; ++j) {
           // accumulate the remaining rows of the bag
           idx = indices[j] * emb_dim + block_dim * block_id;
-          std::tie(y0, y1, y2, y3, y4, y5, y6, y7) = load_chunk(weight + idx);
-          x0 = _mm512_add_ps(x0, y0);
-          x1 = _mm512_add_ps(x1, y1);
-          x2 = _mm512_add_ps(x2, y2);
-          x3 = _mm512_add_ps(x3, y3);
-          x4 = _mm512_add_ps(x4, y4);
-          x5 = _mm512_add_ps(x5, y5);
-          x6 = _mm512_add_ps(x6, y6);
-          x7 = _mm512_add_ps(x7, y7);
+          x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
+          x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
+          x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
+          x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
+          x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
+          x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
+          x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
+          x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
         }
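+        // scale the accumulated fp32 sums before storing the result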
         x0 = _mm512_mul_ps(x0, scale_v);
         x1 = _mm512_mul_ps(x1, scale_v);
@@ -174,7 +143,6 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
   int64_t emb_dim = qweight.size(1);
 
   auto index_type = indices.scalar_type();
-  auto qtype = qweight.scalar_type();
   float w_scale = w_scales.data_ptr<float>()[0];
 
   TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
@@ -186,39 +154,22 @@ at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
               "_scaled_embedding_bag: only accept contiguous weight");
   TORCH_CHECK(qweight.dim() == 2,
               "_scaled_embedding_bag: only accept weight with dim == 2");
-  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn ||
-                  qweight.scalar_type() == c10::ScalarType::Char,
-              "_scaled_embedding_bag: only support e4m3fn and int8 weight")
+  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
+              "_scaled_embedding_bag: only support e4m3fn weight");
   // handle the last offset
   int64_t last_offset = indices.numel();
 
   at::Tensor output =
       at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
-  if (qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn) {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          at::Float8_e4m3fn *qweight_ptr =
-              qweight.data_ptr<at::Float8_e4m3fn>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  } else {
-    AT_DISPATCH_INDEX_TYPES(
-        indices.scalar_type(), "_scaled_embedding_bag", [&] {
-          int8_t *qweight_ptr = qweight.data_ptr<int8_t>();
-          index_t *indices_ptr = indices.data_ptr<index_t>();
-          index_t *offsets_ptr = offsets.data_ptr<index_t>();
-          float *output_ptr = output.data_ptr<float>();
-          _scaled_embedding_bag<index_t, int8_t>(
-              output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size,
-              emb_dim, last_offset, w_scale, o_scale);
-        });
-  }
-
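+  // Dispatch on the index dtype (int32 or int64) and run the fp8 kernel.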
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_scaled_embedding_bag", [&] {
+    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
+    index_t *indices_ptr = indices.data_ptr<index_t>();
+    index_t *offsets_ptr = offsets.data_ptr<index_t>();
+    float *output_ptr = output.data_ptr<float>();
+    _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
+        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
+        last_offset, w_scale, o_scale);
+  });
   return output;
 }
 