@@ -23,7 +23,6 @@ template <typename TX, typename TY>
2323__attribute__ ((global)) void
2424eb_gather_next_token (TX *src, TY *dst, int *encoder_seqs_lods,
2525 int *encoder_batch_map, int *decoder_batch_map,
26- const int * token_type_ids,
2726 int en_batch, int de_batch, int64_t copy_size);
2827} // namespace plugin
2928} // namespace xpu3
@@ -36,22 +35,13 @@ template <typename TX, typename TY>
3635static int
3736cpu_wrapper (api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
3837 const int *encoder_batch_map, const int *decoder_batch_map,
39- const int * token_type_ids, int en_batch, int de_batch,
40- int64_t hidden_dim) {
38+ int en_batch, int de_batch, int64_t hidden_dim) {
4139 int ret = 0 ;
4240 int encoder_len_total = encoder_seqs_lods[en_batch];
4341 for (int i = 0 ; i < en_batch; i++) {
44- int last_text_token = encoder_seqs_lods[i + 1 ] - 1 ;
45- if (token_type_ids != nullptr ) {
46- for (int id = encoder_seqs_lods[i + 1 ] - 1 ; id >= encoder_seqs_lods[i]; id--) {
47- if (token_type_ids[id] == 0 ) {
48- last_text_token = id;
49- break ;
50- }
51- }
52- }
53- ret = api::cast<TX, TY>(ctx, x + last_text_token * hidden_dim,
54- y + encoder_batch_map[i] * hidden_dim, hidden_dim);
42+ ret =
43+ api::cast<TX, TY>(ctx, x + (encoder_seqs_lods[i + 1 ] - 1 ) * hidden_dim,
44+ y + encoder_batch_map[i] * hidden_dim, hidden_dim);
5545 WRAPPER_ASSERT_SUCCESS (ctx, ret);
5646 }
5747 for (int i = 0 ; i < de_batch; i++) {
@@ -67,14 +57,12 @@ static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
6757 api::VectorParam<int32_t > &encoder_seqs_lods, // NOLINT
6858 api::VectorParam<int32_t > &encoder_batch_map, // NOLINT
6959 api::VectorParam<int32_t > &decoder_batch_map, // NOLINT
70- const int32_t * token_type_ids,
71- int en_batch, int de_batch,
72- int64_t hidden_dim) {
60+ int en_batch, int de_batch, int64_t hidden_dim) {
7361 auto eb_gather_next_token_kernel = xpu3::plugin::eb_gather_next_token<TX, TY>;
7462 // NOTE: Don't change 16 to 64, because kernel use gsm
7563 eb_gather_next_token_kernel<<<ctx->ncluster (), 16 , ctx->xpu_stream >>>(
7664 const_cast <TX *>(x), y, encoder_seqs_lods.xpu , encoder_batch_map.xpu ,
77- decoder_batch_map.xpu , token_type_ids, en_batch, de_batch, hidden_dim);
65+ decoder_batch_map.xpu , en_batch, de_batch, hidden_dim);
7866 return api::SUCCESS;
7967}
8068
@@ -83,25 +71,18 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
8371 api::VectorParam<int32_t > &encoder_seqs_lods, // NOLINT
8472 api::VectorParam<int32_t > &encoder_batch_map, // NOLINT
8573 api::VectorParam<int32_t > &decoder_batch_map, // NOLINT
86- const int32_t * token_type_ids, // for VL model
8774 int64_t hidden_dim) {
8875 WRAPPER_CHECK_CTX (ctx);
8976 WRAPPER_DUMP_FUNCTION_T2 (ctx, " eb_gather_next_token" , TX, TY);
9077 WRAPPER_DUMP_PARAM6 (ctx, x, y, encoder_seqs_lods, encoder_batch_map,
91- decoder_batch_map, token_type_ids);
92- WRAPPER_DUMP_PARAM1 (ctx, hidden_dim);
78+ decoder_batch_map, hidden_dim);
9379 WRAPPER_DUMP (ctx);
9480 int encoder_batch = encoder_batch_map.len ;
9581 int batch = encoder_batch + decoder_batch_map.len ;
9682 int max_encoder_lod = encoder_seqs_lods.cpu [encoder_batch];
9783 int m = encoder_seqs_lods.cpu [encoder_batch] + decoder_batch_map.len ;
9884 WRAPPER_CHECK_PTR (ctx, TX, m * hidden_dim, x);
9985 WRAPPER_CHECK_PTR (ctx, TY, batch * hidden_dim, y);
100- if (token_type_ids != nullptr ) {
101- // token_type_ids records the token type, 1 for vision and 0 for text
102- // in text model, token_type_ids is nullptr
103- WRAPPER_CHECK_PTR (ctx, int32_t , m, token_type_ids);
104- }
10586 WRAPPER_ASSERT_GT (ctx, hidden_dim, 0 );
10687 // check VectorParam
10788 WRAPPER_ASSERT_EQ (ctx, encoder_seqs_lods.len , encoder_batch_map.len + 1 );
@@ -118,15 +99,8 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
11899 WRAPPER_ASSERT_GE (ctx, decoder_batch_map.cpu [i], 0 );
119100 }
120101 if (ctx->dev ().type () == api::kCPU ) {
121- std::vector<int > token_type_ids_vec;
122- if (token_type_ids != nullptr ) {
123- token_type_ids_vec.resize (m);
124- int ret = do_device2host (ctx, token_type_ids, token_type_ids_vec.data (), m * sizeof (int32_t ));
125- WRAPPER_ASSERT_SUCCESS (ctx, ret);
126- }
127102 return cpu_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods.cpu ,
128103 encoder_batch_map.cpu , decoder_batch_map.cpu ,
129- token_type_ids_vec.data (),
130104 encoder_batch_map.len , decoder_batch_map.len ,
131105 hidden_dim);
132106 }
@@ -140,7 +114,6 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
140114 decoder_batch_map.to_xpu (RAII_GUARD);
141115 return xpu3_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods_xpu,
142116 encoder_batch_map_xpu, decoder_batch_map_xpu,
143- token_type_ids,
144117 encoder_batch_map.len , decoder_batch_map.len ,
145118 hidden_dim);
146119 }
@@ -149,7 +122,7 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
149122#define INSTANTIATION_EB_GATHER_NEXT_TOKEN (TX, TY ) \
150123 template int eb_gather_next_token<TX, TY>( \
151124 api::Context *, const TX *, TY *, api::VectorParam<int32_t > &, \
152- api::VectorParam<int32_t > &, api::VectorParam<int32_t > &, const int32_t *, int64_t );
125+ api::VectorParam<int32_t > &, api::VectorParam<int32_t > &, int64_t );
153126
154127INSTANTIATION_EB_GATHER_NEXT_TOKEN (float16, float16);
155128INSTANTIATION_EB_GATHER_NEXT_TOKEN (bfloat16, bfloat16);
0 commit comments