Commit 3f08bc3
"fix perf"
1 parent: 4a88598

8 files changed: +34 −118 lines

custom_ops/xpu_ops/src/ops/gather_next_token.cc
Lines changed: 2 additions & 6 deletions

@@ -27,7 +27,6 @@ GatherNextToken(const paddle::Tensor &tmp_out, // [token_num, dim_embed]
     const paddle::Tensor &enc_batch_tensor,
     const paddle::Tensor &dec_batch_tensor,
     const paddle::optional<paddle::Tensor> &output_padding_offset,
-    const paddle::optional<paddle::Tensor> &token_type_ids,
     int max_input_length) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx =
@@ -53,13 +52,11 @@ GatherNextToken(const paddle::Tensor &tmp_out, // [token_num, dim_embed]

   auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place());

-  const int32_t* token_type_ids_ptr = token_type_ids.get_ptr() ?
-      token_type_ids->data<int32_t>() : nullptr;
   int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
       xpu_ctx->x_context(),
       reinterpret_cast<const XPUType *>(tmp_out.data<data_t>()),
       reinterpret_cast<XPUType *>(out.data<data_t>()), encoder_seqs_lods_vp,
-      encoder_batch_map_vp, decoder_batch_map_vp, token_type_ids_ptr, dim);
+      encoder_batch_map_vp, decoder_batch_map_vp, dim);
   return {out};
 }

@@ -103,8 +100,7 @@ PD_BUILD_OP(gather_next_token)
              "decoder_batch_map", "encoder_seq_lod_cpu",
              "encoder_batch_map_cpu", "decoder_batch_map_cpu",
              "enc_batch_tensor", "dec_batch_tensor",
-             paddle::Optional("output_padding_offset"),
-             paddle::Optional("token_type_ids")})
+             paddle::Optional("output_padding_offset")})
    .Outputs({"out"})
    .Attrs({"max_input_length: int"})
    .SetKernelFn(PD_KERNEL(GatherNextToken))

custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
Lines changed: 0 additions & 1 deletion

@@ -133,7 +133,6 @@ eb_gather_next_token(Context *ctx, const TX *x, TY *y,
                      VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
                      VectorParam<int32_t> &encoder_batch_map, // NOLINT
                      VectorParam<int32_t> &decoder_batch_map, // NOLINT
-                     const int32_t* token_type_ids, // for VL model
                      int64_t hidden_dim);

 template <typename TX, typename TSCALE = float, typename TY = int8_t>

custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_gather_next_token.xpu
Lines changed: 1 addition & 20 deletions

@@ -42,7 +42,6 @@ __global__ void eb_gather_next_token(TX* src,
                                      int* encoder_seqs_lods,
                                      int* encoder_batch_map,
                                      int* decoder_batch_map,
-                                     const int* token_type_ids,
                                      int en_batch,
                                      int de_batch,
                                      int64_t copy_size) {
@@ -51,7 +50,6 @@ __global__ void eb_gather_next_token(TX* src,
   __group_shared__ int local_lods_en[MAX_BATCH + 1];
   __group_shared__ int local_map_en[MAX_BATCH];
   __group_shared__ int local_map_de[MAX_BATCH];
-  __group_shared__ int local_token_type_ids[128]; // 128 * 4B = 0.5KB
   GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int));
   if (en_batch > 0) {
     GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int));
@@ -68,23 +66,7 @@ __global__ void eb_gather_next_token(TX* src,
   for (int i = start; i < end; i++) {
     if (i < en_batch) {
       // src encode part
-      int last_text_token = local_lods_en[i + 1] - 1;
-      if (token_type_ids != nullptr) {
-        int token_type = token_type_ids[last_text_token]; // GM2LM, size = 1
-        // token_type: 0 for text and 1 for image
-        if (__builtin_expect(token_type == 0, 1)) {
-          ; // branch prediction
-        } else {
-          // TODU(lilujia): to be optimized with aligned-gm2lm and vectorization
-          for (int id = local_lods_en[i + 1] - 1; id >= local_lods_en[i]; id--) {
-            if (token_type_ids[id] == 0) {
-              last_text_token = id;
-              break;
-            }
-          }
-        }
-      }
-      _global_ptr_ TX* cur_src = src + last_text_token * copy_size;
+      _global_ptr_ TX* cur_src = src + (local_lods_en[i + 1] - 1) * copy_size;
       _global_ptr_ TY* cur_dst = dst + local_map_en[i] * copy_size;
       do_memcpy_1d<TX, TY>(cur_src, cur_dst, copy_size);
     } else {
@@ -103,7 +85,6 @@ __global__ void eb_gather_next_token(TX* src,
                                      int* encoder_seqs_lods, \
                                      int* encoder_batch_map, \
                                      int* decoder_batch_map, \
-                                     const int* token_type_ids, \
                                      int en_batch, \
                                      int de_batch, \
                                      int64_t copy_size);
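
The hunk above is the heart of the perf fix: for VL inputs the kernel read token_type_ids[last_token] from global memory for every encoder sequence and, when a sequence ended in an image token, walked it backwards one GM2LM read at a time to find the last text token. The new code takes the final token of each lod range unconditionally. A minimal Python model of the old versus new index selection, using the removed code's convention (token type 0 = text, 1 = image):

    def last_token_old(lods, token_type_ids, i):
        # Removed behavior: index of the last *text* token of sequence i.
        last = lods[i + 1] - 1
        if token_type_ids is not None and token_type_ids[last] != 0:
            # Slow path: backward scan, one global-memory read per step.
            for t in range(lods[i + 1] - 1, lods[i] - 1, -1):
                if token_type_ids[t] == 0:
                    return t
        return last

    def last_token_new(lods, i):
        # New behavior: simply the final token of sequence i.
        return lods[i + 1] - 1

The two agree whenever a sequence ends in a text token; the image-trailing fallback is dropped here along with the rest of the token_type_ids plumbing.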

custom_ops/xpu_ops/src/plugin/src/wrapper/eb_gather_next_token.cpp
Lines changed: 8 additions & 35 deletions

@@ -23,7 +23,6 @@ template <typename TX, typename TY>
 __attribute__((global)) void
 eb_gather_next_token(TX *src, TY *dst, int *encoder_seqs_lods,
                      int *encoder_batch_map, int *decoder_batch_map,
-                     const int* token_type_ids,
                      int en_batch, int de_batch, int64_t copy_size);
 } // namespace plugin
 } // namespace xpu3
@@ -36,22 +35,13 @@ template <typename TX, typename TY>
 static int
 cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods,
             const int *encoder_batch_map, const int *decoder_batch_map,
-            const int* token_type_ids, int en_batch, int de_batch,
-            int64_t hidden_dim) {
+            int en_batch, int de_batch, int64_t hidden_dim) {
   int ret = 0;
   int encoder_len_total = encoder_seqs_lods[en_batch];
   for (int i = 0; i < en_batch; i++) {
-    int last_text_token = encoder_seqs_lods[i + 1] - 1;
-    if (token_type_ids != nullptr) {
-      for (int id = encoder_seqs_lods[i + 1] - 1; id >= encoder_seqs_lods[i]; id--) {
-        if (token_type_ids[id] == 0) {
-          last_text_token = id;
-          break;
-        }
-      }
-    }
-    ret = api::cast<TX, TY>(ctx, x + last_text_token * hidden_dim,
-                            y + encoder_batch_map[i] * hidden_dim, hidden_dim);
+    ret =
+        api::cast<TX, TY>(ctx, x + (encoder_seqs_lods[i + 1] - 1) * hidden_dim,
+                          y + encoder_batch_map[i] * hidden_dim, hidden_dim);
     WRAPPER_ASSERT_SUCCESS(ctx, ret);
   }
   for (int i = 0; i < de_batch; i++) {
@@ -67,14 +57,12 @@ static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y,
                         api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
                         api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
                         api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
-                        const int32_t* token_type_ids,
-                        int en_batch, int de_batch,
-                        int64_t hidden_dim) {
+                        int en_batch, int de_batch, int64_t hidden_dim) {
   auto eb_gather_next_token_kernel = xpu3::plugin::eb_gather_next_token<TX, TY>;
   // NOTE: Don't change 16 to 64, because kernel use gsm
   eb_gather_next_token_kernel<<<ctx->ncluster(), 16, ctx->xpu_stream>>>(
       const_cast<TX *>(x), y, encoder_seqs_lods.xpu, encoder_batch_map.xpu,
-      decoder_batch_map.xpu, token_type_ids, en_batch, de_batch, hidden_dim);
+      decoder_batch_map.xpu, en_batch, de_batch, hidden_dim);
   return api::SUCCESS;
 }

@@ -83,25 +71,18 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
                          api::VectorParam<int32_t> &encoder_seqs_lods, // NOLINT
                          api::VectorParam<int32_t> &encoder_batch_map, // NOLINT
                          api::VectorParam<int32_t> &decoder_batch_map, // NOLINT
-                         const int32_t* token_type_ids, // for VL model
                          int64_t hidden_dim) {
   WRAPPER_CHECK_CTX(ctx);
   WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_gather_next_token", TX, TY);
   WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, encoder_batch_map,
-                      decoder_batch_map, token_type_ids);
-  WRAPPER_DUMP_PARAM1(ctx, hidden_dim);
+                      decoder_batch_map, hidden_dim);
   WRAPPER_DUMP(ctx);
   int encoder_batch = encoder_batch_map.len;
   int batch = encoder_batch + decoder_batch_map.len;
   int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch];
   int m = encoder_seqs_lods.cpu[encoder_batch] + decoder_batch_map.len;
   WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x);
   WRAPPER_CHECK_PTR(ctx, TY, batch * hidden_dim, y);
-  if (token_type_ids != nullptr) {
-    // token_type_ids records the token type, 1 for vision and 0 for text
-    // in text model, token_type_ids is nullptr
-    WRAPPER_CHECK_PTR(ctx, int32_t, m, token_type_ids);
-  }
   WRAPPER_ASSERT_GT(ctx, hidden_dim, 0);
   // check VectorParam
   WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1);
@@ -118,15 +99,8 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
     WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0);
   }
   if (ctx->dev().type() == api::kCPU) {
-    std::vector<int> token_type_ids_vec;
-    if (token_type_ids != nullptr) {
-      token_type_ids_vec.resize(m);
-      int ret = do_device2host(ctx, token_type_ids, token_type_ids_vec.data(), m * sizeof(int32_t));
-      WRAPPER_ASSERT_SUCCESS(ctx, ret);
-    }
     return cpu_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods.cpu,
                                encoder_batch_map.cpu, decoder_batch_map.cpu,
-                               token_type_ids_vec.data(),
                                encoder_batch_map.len, decoder_batch_map.len,
                                hidden_dim);
   }
@@ -140,7 +114,6 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
                               decoder_batch_map.to_xpu(RAII_GUARD);
   return xpu3_wrapper<TX, TY>(ctx, x, y, encoder_seqs_lods_xpu,
                               encoder_batch_map_xpu, decoder_batch_map_xpu,
-                              token_type_ids,
                               encoder_batch_map.len, decoder_batch_map.len,
                               hidden_dim);
 }
@@ -149,7 +122,7 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y,
 #define INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY) \
   template int eb_gather_next_token<TX, TY>( \
       api::Context *, const TX *, TY *, api::VectorParam<int32_t> &, \
-      api::VectorParam<int32_t> &, api::VectorParam<int32_t> &, const int32_t*, int64_t);
+      api::VectorParam<int32_t> &, api::VectorParam<int32_t> &, int64_t);

 INSTANTIATION_EB_GATHER_NEXT_TOKEN(float16, float16);
 INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, bfloat16);
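
After this commit the op is a plain per-row gather. For reference, a NumPy sketch of the cpu_wrapper semantics above; the decoder indexing (one row per decoder sequence, starting at encoder_len_total) is inferred from the surrounding context lines and should be treated as an assumption:

    import numpy as np

    def eb_gather_next_token_ref(x, lods, enc_map, dec_map):
        # x: [token_num, hidden_dim] hidden states -> y: [batch, hidden_dim].
        en_batch, de_batch = len(enc_map), len(dec_map)
        y = np.zeros((en_batch + de_batch, x.shape[1]), dtype=x.dtype)
        enc_total = lods[en_batch]  # encoder_len_total in the wrapper
        for i in range(en_batch):
            y[enc_map[i]] = x[lods[i + 1] - 1]  # last token of encoder sequence i
        for i in range(de_batch):
            y[dec_map[i]] = x[enc_total + i]    # assumed: decoder rows follow encoder tokens
        return y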

fastdeploy/model_executor/forward_meta.py
Lines changed: 0 additions & 2 deletions

@@ -234,8 +234,6 @@ class XPUForwardMeta(ForwardMeta):
     total_enc_len: Optional[paddle.Tensor] = None
     # position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM'
     pos_emb_type: Optional[str] = "NORMAL"
-    # used in VL model, record the token type, 1 for image, 0 for text
-    token_type_ids: Optional[paddle.Tensor] = None


 @dataclass

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
Lines changed: 0 additions & 3 deletions

@@ -778,9 +778,6 @@ def forward(
         )
         self._input_embeddings.copy_(input_embeddings, False)

-        if vl_moe_meta.image_token_num.item() > 0:  # for XPU
-            forward_meta.token_type_ids = vl_moe_meta.token_type_ids
-
         hidden_states = self.ernie(
             input_embeddings=self._input_embeddings,
             ids_remove_padding=ids_remove_padding,

fastdeploy/worker/xpu_model_runner.py
Lines changed: 9 additions & 1 deletion

@@ -184,7 +184,6 @@ def xpu_process_output(
         xpu_forward_meta.enc_batch,
         xpu_forward_meta.dec_batch,
         None,  # output_padding_offset
-        xpu_forward_meta.token_type_ids,  # token_type_ids
         -1,  # max_input_length
     )
     return hiddden_states
@@ -1231,6 +1230,15 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
             }
         )

+    def clear_block_table(self) -> None:
+        """
+        Clear the block tables and kv cache after profiling.
+        """
+        del self.share_inputs["caches"]
+        if self.forward_meta is not None:
+            del self.forward_meta.caches
+        paddle.device.xpu.empty_cache()
+
     def cal_theortical_kvcache(self):
         """
         Calculate the total block memory required at the model level
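
One point worth noting about the new clear_block_table: del self.share_inputs["caches"] only removes a Python reference, which is why forward_meta.caches is deleted as well; the allocation is freed once no references remain, and paddle.device.xpu.empty_cache() then returns it to the device. A hedged sketch of the intended call order, mirroring determine_available_memory in xpu_worker.py below (the runner variable is illustrative):

    # Sketch only, assuming a model-runner instance named runner.
    runner.prepare_profile()
    runner.profile_run()          # allocates the profiling-time KV caches
    # ... measure memory and compute the KV-cache budget ...
    runner.clear_block_table()    # drops cache references, then empties the allocator cache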

fastdeploy/worker/xpu_worker.py
Lines changed: 14 additions & 50 deletions

@@ -15,7 +15,6 @@
 """

 import gc
-import time
 from typing import List, Optional

 import paddle
@@ -101,73 +100,38 @@ def determine_available_memory(self) -> int:
             len(self.device_ids) > self.local_rank
         ), f"device number must be greater than local rank, but get device number is {len(self.device_ids)}, rank is {self.local_rank}"

-        # 1. Record memory state before profile run
-        start_time = time.perf_counter()
-        Gb = 1024**3
-        local_rank = self.local_rank % 8
-        paddle.device.xpu.reset_max_memory_reserved(local_rank)
-        paddle.device.xpu.reset_max_memory_allocated(local_rank)
-        paddle_reserved_mem_before_run = paddle.device.xpu.max_memory_reserved(local_rank)
-        paddle_allocated_mem_before_run = paddle.device.xpu.max_memory_allocated(local_rank)
-        before_run_mem_total = xpu_get_total_global_memory(int(self.device_ids[self.local_rank]))
-        before_run_mem_used = xpu_get_used_global_memory(int(self.device_ids[self.local_rank]))
-        before_run_mem_free = xpu_get_free_global_memory(int(self.device_ids[self.local_rank]))
+        total_memory = xpu_get_total_global_memory(int(self.device_ids[self.local_rank]))
+        used_memory = xpu_get_used_global_memory(int(self.device_ids[self.local_rank]))
+        free_memory = xpu_get_free_global_memory(int(self.device_ids[self.local_rank]))

         logger.info(
-            (
-                "Before running the profile, the memory usage info is as follows:",
-                f"\nDevice Total memory: {before_run_mem_total / Gb}",
-                f"\nDevice used memory: {before_run_mem_used / Gb}",
-                f"\nDevice free memory: {before_run_mem_free / Gb}",
-                f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
-                f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}",
-            )
+            f"Before warm up, total_memory: {total_memory}, \
+            used_memory: {used_memory}, free_memory: {free_memory}"
         )

-        # 2. Profile run
         self.model_runner.prepare_profile()
         if self.parallel_config.use_ep:
             logger.warning("EP mode does not support profile run.")
         else:
             self.model_runner.profile_run()
         set_random_seed(self.fd_config.model_config.seed)

-        # 3. Statistical memory information
-        paddle_reserved_mem_after_run = paddle.device.xpu.max_memory_reserved(local_rank)
-        paddle_allocated_mem_after_run = paddle.device.xpu.max_memory_allocated(local_rank)
-
+        total_available_memory = int(total_memory * self.cache_config.gpu_memory_utilization)
+        used_memory = xpu_get_used_global_memory(int(self.device_ids[self.local_rank]))
+        available_kv_cache_memory = total_available_memory - used_memory
         model_block_memory_used = self.cal_theortical_kvcache()
-        paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
-
-        paddle.device.xpu.empty_cache()
-
-        after_run_mem_total = xpu_get_total_global_memory(int(self.device_ids[self.local_rank])).item()
-        after_run_mem_used = xpu_get_used_global_memory(int(self.device_ids[self.local_rank])).item()
-        after_run_mem_free = xpu_get_free_global_memory(int(self.device_ids[self.local_rank])).item()
-
-        available_kv_cache_memory = (
-            after_run_mem_total * self.cache_config.gpu_memory_utilization - after_run_mem_used - paddle_peak_increase
-        )
         available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num
-
         if self.parallel_config.use_ep:
             available_kv_cache_memory = int(available_kv_cache_memory * 0.6)

-        end_time = time.perf_counter()
+        self.model_runner.clear_block_table()
+
         logger.info(
-            (
-                "After running the profile, the memory usage info is as follows:",
-                f"\nDevice Total memory: {after_run_mem_total / Gb}",
-                f"\nDevice used memory: {after_run_mem_used / Gb}",
-                f"\nDevice free memory: {after_run_mem_free / Gb}",
-                f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
-                f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
-                f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
-                f"Profile time: {end_time - start_time}",
-            )
+            f"After warm up, total_available_memory: {total_available_memory}, \
+            used_memory: {used_memory}, available_kv_cache_memory: {available_kv_cache_memory}"
         )
-
-        return available_kv_cache_memory  # return to calculate the block num in this device
+        paddle.device.xpu.empty_cache()
+        return available_kv_cache_memory  # approximate value

     def load_model(self) -> None:
         """Load model"""
