Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions custom_ops/xpu_ops/src/ops/block_attn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_map_cpu,
const paddle::Tensor &decoder_context_len_cpu,
const paddle::Tensor &decoder_batch_map_cpu) {
const paddle::Tensor &decoder_batch_map_cpu,
const std::string &pos_emb_type="NORMAL",
bool rope_3d=false) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
Expand Down Expand Up @@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
int rope_max_seqlen = 0;
int rope_3d_num_seqs = 1;
if (rope_3d) {
rope_max_seqlen = rotary_embs.dims()[3];
rope_3d_num_seqs = rotary_embs.dims()[0];
} else {
rope_max_seqlen = rotary_embs.dims()[2];
}

auto block_attn_out =
paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
Expand Down Expand Up @@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rotary_embs.dims()[2], // max_seqlen
rope_max_seqlen, // max_seqlen
param.head_num, param.kv_head_num, param.head_dim,
param.max_batch_size, block_size, max_block_per_seq, "BLHD",
"HLD", "NORMAL",
"HLD", pos_emb_type,
!p_kcache_perhead_scale.defined()
? nullptr
: p_kcache_perhead_scale.data<float>() +
Expand Down Expand Up @@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size
rotary_embs.dims()[2], // max_seqlen TODO!!double check
rope_max_seqlen, // max_seqlen
param.head_num, param.kv_head_num, param.head_dim,
param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
"NORMAL",
pos_emb_type,
!p_kcache_perhead_scale.defined()
? nullptr
: p_kcache_perhead_scale.data<float>() +
Expand All @@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
param.kv_head_num, // v_cache_scale_inv
nullptr, // k_cache_zp
nullptr, // v_cache_zp
false); // b_c8_pc
false, // b_c8_pc
rope_3d, // rope_3d
rope_3d_num_seqs);
XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);

// attn decode
Expand Down Expand Up @@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
"decoder_context_len_cpu",
"decoder_batch_map_cpu",
})
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
.Outputs({"block_attn_out"})
.SetKernelFn(PD_KERNEL(BlockAttnKernel))
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
Expand Down
60 changes: 60 additions & 0 deletions custom_ops/xpu_ops/src/ops/get_img_boundaries.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

// Splits a token sequence into consecutive text / image segments.
//
// Walks `task_input_ids` from left to right. A run of tokens that differ from
// `image_patch_id` forms one segment; a token equal to `image_patch_id` starts
// an image segment whose length is (grid_thw[i][1] * grid_thw[i][2]) / 4 for
// the i-th image encountered. After every segment the current token offset and
// the number of images consumed so far are recorded.
//
// Args:
//   task_input_ids: token ids, read through data<int64_t>() on the host.
//   grid_thw:       read as consecutive int64 triples, one triple per image;
//                   only elements 1 and 2 of each triple are used here.
//   image_patch_id: token id that marks the start of an image segment.
//
// Returns:
//   A single CPU INT64 tensor of shape {2, K}: row 0 holds the segment end
//   offsets (starting with 0), row 1 holds the running image count at each
//   of those offsets.
//
// Throws (PD_THROW) when the sequence contains more image segments than
// `grid_thw` has entries, or when an image segment would have a non-positive
// token length. Without these checks the first case reads past the end of
// `grid_thw` and the second never advances `token_idx`, looping forever.
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
                                             const paddle::Tensor& grid_thw,
                                             const int64_t image_patch_id) {
  // All tensors are accessed directly from host code, so they must be on CPU.
  const int64_t* input_ids_ptr = task_input_ids.data<int64_t>();
  const int64_t seq_lens_origin = task_input_ids.numel();
  const int64_t* grid_thw_ptr = grid_thw.data<int64_t>();
  // grid_thw is indexed as [image_idx * 3 + k], i.e. three values per image.
  const int64_t num_images = grid_thw.numel() / 3;

  // Divisor applied to grid[1] * grid[2] to obtain the image token count.
  constexpr int64_t token_times = 4;

  int64_t token_idx = 0;
  int64_t image_idx = 0;
  std::vector<int64_t> img_boundaries{0};
  std::vector<int64_t> img_nums{0};

  while (token_idx < seq_lens_origin) {
    if (input_ids_ptr[token_idx] != image_patch_id) {
      // Text segment: advance to the next image patch token or to the end.
      do {
        ++token_idx;
      } while (token_idx < seq_lens_origin &&
               input_ids_ptr[token_idx] != image_patch_id);
    } else {
      if (image_idx >= num_images) {
        PD_THROW("get_img_boundaries: task_input_ids contains more image "
                 "segments than grid_thw has entries (",
                 num_images,
                 ").");
      }
      const int64_t cur_image_token_len =
          (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) /
          token_times;
      if (cur_image_token_len <= 0) {
        // A zero-length segment would leave token_idx unchanged forever.
        PD_THROW("get_img_boundaries: image ",
                 image_idx,
                 " has a non-positive token length (",
                 cur_image_token_len,
                 ").");
      }
      ++image_idx;
      token_idx += cur_image_token_len;
    }
    img_boundaries.emplace_back(token_idx);
    img_nums.emplace_back(image_idx);
  }

  const int64_t num_img_boundaries =
      static_cast<int64_t>(img_boundaries.size());
  auto out = paddle::full({2, num_img_boundaries},
                          0,
                          paddle::DataType::INT64,
                          paddle::CPUPlace());

  // Row 0: boundaries, row 1: image counts (row-major layout).
  int64_t* out_ptr = out.data<int64_t>();
  for (int64_t i = 0; i < num_img_boundaries; ++i) {
    out_ptr[i] = img_boundaries[i];
    out_ptr[num_img_boundaries + i] = img_nums[i];
  }

  return {out};
}

// Registers the `get_img_boundaries` custom op and binds it to
// GetImgBoundaries: two tensor inputs, one int64 attribute
// (`image_patch_id`) and a single output tensor `img_boundaries`.
// No infer-shape / infer-dtype function is registered here.
PD_BUILD_OP(get_img_boundaries)
.Inputs({"task_input_ids", "grid_thw"})
.Attrs({"image_patch_id: int64_t"})
.Outputs({"img_boundaries"})
.SetKernelFn(PD_KERNEL(GetImgBoundaries));
6 changes: 4 additions & 2 deletions custom_ops/xpu_ops/src/ops/moe_layer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,17 @@ std::vector<paddle::Tensor> MoeLayerKernel(
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
);

xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
const_cast<float *>(down_proj_weight_scale.get_ptr()
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
);
}
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
Expand Down
83 changes: 83 additions & 0 deletions custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft/xdnn_plugin.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"

// Host-side launcher for the XPU `text_image_gather_scatter` plugin kernel.
//
// Reads the token count and hidden size from `input` (dims 0 and 1) and the
// text / image token counts from dim 0 of `text_input` / `image_input`, then
// forwards all buffers plus `is_scatter` to the plugin on the current XPU
// device. Only BFLOAT16 `input` is accepted; any other dtype raises via
// PD_THROW. The three index tensors are accessed as `int`.
void TextImageGatherScatter(
    paddle::Tensor& input,
    paddle::Tensor& text_input,
    paddle::Tensor& image_input,
    paddle::Tensor& token_type_ids,
    paddle::Tensor& text_index,
    paddle::Tensor& image_index,
    const bool is_scatter) {
  namespace xpu_plugin = baidu::xpu::api::plugin;

  // Resolve the XPU context for the device that is current on this thread.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto device_context =
      paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_context = static_cast<const phi::XPUContext*>(device_context);

  const int64_t token_num = input.dims()[0];
  const int64_t hidden_size = input.dims()[1];
  const int64_t text_token_num = text_input.dims()[0];
  const int64_t image_token_num = image_input.dims()[0];

  // Guard clause: the plugin is only instantiated for bfloat16 here.
  if (input.type() != paddle::DataType::BFLOAT16) {
    PD_THROW(
        "NOT supported data type. Only support BFLOAT16. ");
  }

  using data_t = paddle::bfloat16;
  using XPUType = XPUTypeTrait<bfloat16>::Type;

  XPUType* input_ptr = reinterpret_cast<XPUType*>(input.data<data_t>());
  XPUType* text_ptr = reinterpret_cast<XPUType*>(text_input.data<data_t>());
  XPUType* image_ptr = reinterpret_cast<XPUType*>(image_input.data<data_t>());
  int* token_type_ptr = token_type_ids.data<int>();
  int* text_index_ptr = text_index.data<int>();
  int* image_index_ptr = image_index.data<int>();

  int r = xpu_plugin::text_image_gather_scatter<XPUType>(
      xpu_context->x_context(),
      input_ptr,
      text_ptr,
      image_ptr,
      token_type_ptr,
      text_index_ptr,
      image_index_ptr,
      token_num,
      text_token_num,
      image_token_num,
      hidden_size,
      is_scatter);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
}


// Registers the `text_image_gather_scatter` custom op and binds it to
// TextImageGatherScatter. The four outputs are in-place aliases of the
// `text_input`, `image_input`, `text_index` and `image_index` inputs
// (see SetInplaceMap); `input` and `token_type_ids` have no output alias.
// `is_scatter` is forwarded unchanged to the plugin kernel.
PD_BUILD_OP(text_image_gather_scatter)
.Inputs({"input",
"text_input",
"image_input",
"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_input_out",
"image_input_out",
"text_index_out",
"image_index_out"})
.Attrs({"is_scatter:bool"})
.SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"},
{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));
48 changes: 48 additions & 0 deletions custom_ops/xpu_ops/src/ops/text_image_index_out.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"

// Host-side launcher for the XPU `text_image_index_out` plugin kernel.
//
// Validates that all three tensors are INT32, then passes `token_type_ids`
// (read-only) together with the `text_index` / `image_index` buffers and the
// token count (dim 0 of `token_type_ids`) to the plugin on the current XPU
// device.
//
// `text_index` and `image_index` arrive as const references but are registered
// as in-place outputs of this op, so their buffers are handed to the plugin as
// mutable pointers via const_cast.
//
// Throws (PD_THROW) when any tensor is not INT32, and through
// PADDLE_ENFORCE_XDNN_SUCCESS when the plugin reports a failure.
void TextImageIndexOut(
    const paddle::Tensor& token_type_ids,
    const paddle::Tensor& text_index,
    const paddle::Tensor& image_index) {
  if (token_type_ids.type() != paddle::DataType::INT32 ||
      text_index.type() != paddle::DataType::INT32 ||
      image_index.type() != paddle::DataType::INT32) {
    // The check above is for INT32; the message must say so.
    PD_THROW("NOT supported data type. Only support INT32. ");
  }
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
  const int64_t token_num = token_type_ids.shape()[0];
  int r = baidu::xpu::api::plugin::text_image_index_out(
      xpu_ctx->x_context(),
      token_type_ids.data<int32_t>(),
      const_cast<int32_t*>(text_index.data<int32_t>()),
      const_cast<int32_t*>(image_index.data<int32_t>()),
      token_num);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
}


// Registers the `text_image_index_out` custom op and binds it to
// TextImageIndexOut. Both outputs are in-place aliases of the `text_index`
// and `image_index` inputs (see SetInplaceMap); the op takes no attributes.
PD_BUILD_OP(text_image_index_out)
.Inputs({"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_index_out",
"image_index_out"})
.SetInplaceMap({{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageIndexOut));
19 changes: 19 additions & 0 deletions custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
const TSCALE *scale_in, TY *y,
TSCALE *scale_out, int64_t m, int64_t n);

// Plugin entry point used by the `text_image_index_out` op.
// `token_type_ids` is the read-only input (x); `text_index` and `image_index`
// are the two mutable buffers (y1, y2). The visible caller passes the first
// dimension of token_type_ids as `token_num`.
// Returns an int status that the caller checks with
// PADDLE_ENFORCE_XDNN_SUCCESS.
// NOTE(review): the exact values written to y1 / y2 are defined by the kernel
// implementation, which is not shown here — confirm before relying on them.
DLL_EXPORT int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);

// Plugin entry point used by the `text_image_gather_scatter` op.
// The visible caller instantiates it with the XPU bfloat16 type and passes:
//   token_num        - dim 0 of `input`
//   hidden_size      - dim 1 of `input`
//   text_token_num   - dim 0 of `text_input`
//   image_token_num  - dim 0 of `image_input`
// Returns an int status that the caller checks with
// PADDLE_ENFORCE_XDNN_SUCCESS.
// NOTE(review): presumably `is_scatter` selects the copy direction between
// `input` and the text / image buffers — confirm against the kernel
// implementation, which is not shown here.
template <typename T>
DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);

/*--------------------------------------- MTP begin --------------------------------------------*/

Expand Down
Loading
Loading