diff --git a/backends/npu/custom_op/llama_decoder_layer_parallel_op.cc b/backends/npu/custom_op/llama_decoder_layer_parallel_op.cc
index 8c807780b..91470ca6a 100644
--- a/backends/npu/custom_op/llama_decoder_layer_parallel_op.cc
+++ b/backends/npu/custom_op/llama_decoder_layer_parallel_op.cc
@@ -255,13 +255,13 @@ std::vector<paddle::Tensor> LlamaDecoderLayerParallelOp(
     std::vector<int32_t> layer_id_vec(1, 0);
     custom_kernel::TensorFromVector(
         *dev_ctx, layer_id_vec, *dev_ctx,
         &(g_llamaDecoderLayerParallelOp->layerIdTensor_));
-    g_llamaDecoderLayerParallelOp->output_ = std::make_shared<phi::DenseTensor>();
-    g_llamaDecoderLayerParallelOp->output_->Resize(phi::make_ddim(hidden.shape()));
-    dev_ctx->Alloc(g_llamaDecoderLayerParallelOp->output_.get(),
-                   static_cast<const phi::DenseTensor*>(hidden.impl().get())->dtype());
   }
 
   if (executeCount % layer_num == 0) {  // first time this token enters a layer; update the stop flag
+    g_llamaDecoderLayerParallelOp->output_ = std::make_shared<phi::DenseTensor>();
+    g_llamaDecoderLayerParallelOp->output_->Resize(phi::make_ddim(hidden.shape()));
+    dev_ctx->Alloc(g_llamaDecoderLayerParallelOp->output_.get(),
+                   static_cast<const phi::DenseTensor*>(hidden.impl().get())->dtype());
     g_llamaDecoderLayerParallelOp->UpdateInputTensorAndParam(kv_seq_len);
   }
diff --git a/backends/npu/custom_op/llama_encoder_layer_parallel_op.cc b/backends/npu/custom_op/llama_encoder_layer_parallel_op.cc
index 6f490bc18..9725c200e 100644
--- a/backends/npu/custom_op/llama_encoder_layer_parallel_op.cc
+++ b/backends/npu/custom_op/llama_encoder_layer_parallel_op.cc
@@ -221,13 +221,13 @@ std::vector<paddle::Tensor> LlamaEncoderLayerParallelOp(
     std::vector<int32_t> layer_id_vec(1, 0);
     custom_kernel::TensorFromVector(
         *dev_ctx, layer_id_vec, *dev_ctx,
         &(g_llamaEncoderLayerParallelOp->layerIdTensor_));
-    g_llamaEncoderLayerParallelOp->output_ = std::make_shared<phi::DenseTensor>();
-    g_llamaEncoderLayerParallelOp->output_->Resize(phi::make_ddim(hidden.shape()));
-    dev_ctx->Alloc(g_llamaEncoderLayerParallelOp->output_.get(),
-                   static_cast<const phi::DenseTensor*>(hidden.impl().get())->dtype());
   }
 
   if (executeCount % layer_num == 0) {
+    g_llamaEncoderLayerParallelOp->output_ = std::make_shared<phi::DenseTensor>();
+    g_llamaEncoderLayerParallelOp->output_->Resize(phi::make_ddim(hidden.shape()));
+    dev_ctx->Alloc(g_llamaEncoderLayerParallelOp->output_.get(),
+                   static_cast<const phi::DenseTensor*>(hidden.impl().get())->dtype());
     g_llamaEncoderLayerParallelOp->UpdateInputTensorAndParam(kv_seq_len);
   }
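
Note on both hunks (the decoder and encoder ops receive the same change): previously the output tensor was created once, inside the one-time initialization block, so its shape and dtype were fixed to whatever the hidden input looked like on the very first call. Moving the allocation into the executeCount % layer_num == 0 branch appears intended to re-create output_ at the first layer of every token, so it always matches the current hidden input. A minimal sketch of the resulting control flow, using hypothetical stand-in names (DenseTensor, LayerOpState, RunLayer are illustrations, not the ops' real types):

// Sketch only: simplified stand-ins for the pattern this patch introduces.
#include <memory>

struct DenseTensor {};  // stand-in for phi::DenseTensor

struct LayerOpState {
  std::shared_ptr<DenseTensor> output_;
  bool initialized_ = false;
};

void RunLayer(LayerOpState* op, int executeCount, int layer_num) {
  if (!op->initialized_) {
    // One-time setup (weights, layer-id tensor, ...).
    // Before the patch, output_ was also allocated here, once, sized to
    // the first hidden input ever seen.
    op->initialized_ = true;
  }
  if (executeCount % layer_num == 0) {
    // After the patch, the first layer of every token re-creates output_,
    // so its shape and dtype track the current hidden input.
    op->output_ = std::make_shared<DenseTensor>();
  }
  // ... run the layer, writing into op->output_ ...
}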