93 changes: 63 additions & 30 deletions csrc/kernels/get_masked_input_and_mask_kernel.cpp
@@ -54,6 +54,7 @@ class GetMaskedInputAndMask {
pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));

// Initialize calculation buffers
// NOTE: calc_buf_1 and calc_buf_2 are also used for int16 casting on older archs.
pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));

@@ -66,7 +67,7 @@ class GetMaskedInputAndMask {
// Initialize temporary buffers
pipe.InitBuffer(start_buf, size_ * sizeof(float));
pipe.InitBuffer(end_buf, size_ * sizeof(float));
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float));
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float)); // Also used for half intermediate in casting
pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
@@ -121,7 +122,6 @@ class GetMaskedInputAndMask {
const float start_value,
const float end_value) {

// Use already initialized buffers
AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();

@@ -134,7 +134,35 @@ class GetMaskedInputAndMask {
CompareWithValue(ge_result, start_value_tensor, input, true);
CompareWithValue(lt_result, input, end_value_tensor, false);

#if (__CCE_AICORE__ >= 220)
AscendC::And(range_mask, ge_result, lt_result, size_);
#else
{
// WORKAROUND for older arch
// No direct int8->int16 cast. Use half as intermediate.
// No direct int8 And. Use int16 And.
AscendC::LocalTensor<int16_t> ge_result_i16 = calc_buf_1.Get<int16_t>();
AscendC::LocalTensor<int16_t> lt_result_i16 = calc_buf_2.Get<int16_t>();
AscendC::LocalTensor<int16_t> range_mask_i16 = ge_result_i16;
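// range_mask_i16 aliases ge_result_i16, so the And below writes its result in place over calc_buf_1.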

// Use a temporary buffer for half type
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();

// 1. Cast inputs: int8_t -> half -> int16_t
AscendC::Cast(tmp_half, ge_result, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(ge_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);

AscendC::Cast(tmp_half, lt_result, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(lt_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);

// 2. Perform And on int16_t tensors
AscendC::And(range_mask_i16, ge_result_i16, lt_result_i16, size_);

// 3. Cast result back: int16_t -> half -> int8_t
AscendC::Cast(tmp_half, range_mask_i16, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(range_mask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
}
#endif
}

__aicore__ inline void Compute() {
@@ -145,24 +173,18 @@ class GetMaskedInputAndMask {
AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);

// Calculate mask for org_vocab range
// org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
ComputeRangeMask(orgVocabMask,
inputFloat,
static_cast<float>(org_vocab_start_index_),
static_cast<float>(org_vocab_end_index_));

// Calculate mask for added_vocab range
// added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
ComputeRangeMask(addedVocabMask,
inputFloat,
static_cast<float>(added_vocab_start_index_),
static_cast<float>(added_vocab_end_index_));

// Calculate validOffset
// valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask)
AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();

@@ -173,10 +195,7 @@ class GetMaskedInputAndMask {
AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

AscendC::Mul(validOffset,
constOrgStartIndex,
orgVocabMask_fp32,
size_);
AscendC::Mul(validOffset, constOrgStartIndex, orgVocabMask_fp32, size_);

AscendC::LocalTensor<float> addedOffset;
AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
@@ -187,44 +206,61 @@ class GetMaskedInputAndMask {
AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

AscendC::Mul(addedOffset,
addedOffsetTensor,
addedVocabMask_fp32,
size_);

AscendC::Mul(addedOffset, addedOffsetTensor, addedVocabMask_fp32, size_);
AscendC::Add(validOffset, validOffset, addedOffset, size_);

// vocab_mask = org_vocab_mask | added_vocab_mask
AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();


#if (__CCE_AICORE__ >= 220)
AscendC::Or(vocabMask,
orgVocabMask,
addedVocabMask,
size_);

#else
{
// WORKAROUND for older arch
// No direct int8->int16 cast. Use half as intermediate.
// No direct int8 Or. Use int16 Or.
AscendC::LocalTensor<int16_t> orgVocabMask_i16 = calc_buf_1.Get<int16_t>();
AscendC::LocalTensor<int16_t> addedVocabMask_i16 = calc_buf_2.Get<int16_t>();
AscendC::LocalTensor<int16_t> vocabMask_i16 = orgVocabMask_i16;

// Use a temporary buffer for half type. inputFloat_buf is free now.
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();

// 1. Cast inputs: int8_t -> half -> int16_t
AscendC::Cast(tmp_half, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(orgVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);

AscendC::Cast(tmp_half, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(addedVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);

// 2. Perform Or on int16_t tensors
AscendC::Or(vocabMask_i16, orgVocabMask_i16, addedVocabMask_i16, size_);

// 3. Cast result back: int16_t -> half -> int8_t
AscendC::Cast(tmp_half, vocabMask_i16, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(vocabMask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
}
#endif

AscendC::Sub(inputFloat, inputFloat, validOffset, size_);

// input_ = vocab_mask * (input_ - valid_offset)
AscendC::LocalTensor<half> vocabMask_fp16;
AscendC::LocalTensor<float> vocabMask_fp32;
AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);

AscendC::LocalTensor<float> inputFloat_fp32;
AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);

AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
outQueue.EnQue(maskedLocal);

// ~vocab_mask
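// (computed below as 1.0f - vocab_mask_fp32, then cast back to half for the output mask)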
AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
AscendC::Duplicate(ones_tensor, (float)1, size_);
AscendC::LocalTensor<float> maskLocal_fp32;

AscendC::Sub(maskLocal_fp32,
ones_tensor,
vocabMask_fp32,
size_);
AscendC::Sub(maskLocal_fp32, ones_tensor, vocabMask_fp32, size_);

AscendC::LocalTensor<half> maskLocal_fp16;
AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
@@ -262,8 +298,6 @@ class GetMaskedInputAndMask {
// Temporary buffers
AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;

// Temporary buffers continued
AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
@@ -342,4 +376,3 @@ void get_masked_input_and_mask_impl(
}

} // namespace vllm_ascend
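
For orientation, the per-element math that the tiled kernel above implements (as stated in its in-code comments) can be restated as a scalar reference. This is a sketch only; the helper name and signature are illustrative and not part of this change:

#include <cstdint>

// Scalar restatement of the formulas in the kernel's comments (illustrative only).
inline void masked_input_and_mask_ref(int64_t input, int64_t org_vocab_start_index,
                                      int64_t org_vocab_end_index, int64_t added_vocab_start_index,
                                      int64_t added_vocab_end_index, int64_t added_offset,
                                      int64_t &masked_input, bool &mask) {
    const bool org_vocab_mask   = input >= org_vocab_start_index && input < org_vocab_end_index;
    const bool added_vocab_mask = input >= added_vocab_start_index && input < added_vocab_end_index;
    const int64_t valid_offset  = org_vocab_start_index * org_vocab_mask + added_offset * added_vocab_mask;
    const bool vocab_mask       = org_vocab_mask || added_vocab_mask;
    masked_input = vocab_mask ? (input - valid_offset) : 0;  // input_ = vocab_mask * (input_ - valid_offset)
    mask         = !vocab_mask;                              // ~vocab_mask
}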

20 changes: 15 additions & 5 deletions csrc/kernels/pos_encoding_kernels.cpp
@@ -30,7 +30,11 @@ using vllm_ascend::local_mem_copy;
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
// NOTE(ganyi): we use 512B as the load stride for the pipe; we still need a way to
// retrieve this size from the runtime for broader SoC support
static int constexpr loadSize = 512;
#if (__CCE_AICORE__ >= 220)
static int constexpr loadSize = 512;
#else
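// larger load stride used on older AI Core versions (< 220), e.g. 310P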
static int constexpr loadSize = 1024 * 4;
#endif
using dst_t = scalar_t;
using acc_t = typename AccType<scalar_t>::type;
// only half tensors have a cast instruction to int8, so hardcode acc_dst_t as half
@@ -326,7 +330,9 @@ template <typename scalar_t, bool isNeox> class RotaryEmbedding {

// Declare all the kernel entry here
ROPE_CUSTOM_KERNEL_DECLARE(half)
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
#if (__CCE_AICORE__ >= 220)
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
#endif

namespace vllm_ascend {

@@ -342,7 +348,7 @@ namespace vllm_ascend {
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);

// Maximum number of blocks the runtime can launch for an AscendC kernel;
// we use it to cap the block dimension.
static const int64_t maxParallelSize = 65535;

@@ -357,9 +363,13 @@ extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, in
int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize;
if (type == AscendType::FP16) {
ROTARY_EMBEDDING_KERNEL_CALL(half);
} else if (type == AscendType::BF16) {
}
#if (__CCE_AICORE__ >= 220)
else if (type == AscendType::BF16) {
ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t);
} else {
}
#endif
else {
return;
}
}
4 changes: 3 additions & 1 deletion csrc/kernels/utils.h
@@ -20,9 +20,11 @@ namespace vllm_ascend {

template <typename scalar_t> struct AccType;
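// Maps a storage dtype to the dtype used for intermediate math in the kernels,
// e.g. using acc_t = typename AccType<scalar_t>::type; in RotaryEmbedding.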

#if (__CCE_AICORE__ >= 220)
template <> struct AccType<bfloat16_t> {
using type = float;
};
#endif

template <> struct AccType<half> {
using type = half;
2 changes: 1 addition & 1 deletion format.sh
@@ -273,7 +273,7 @@ echo 'vllm-ascend isort: Done'
# Clang-format section
# Exclude some files for formatting because they are vendored
CLANG_FORMAT_EXCLUDES=(
'csrc/kernels/pos_encoding_kernels.cpp' 'csrc/kernels/advance_step.cpp' 'csrc/kernels/get_masked_input_and_mask_kernel.cpp' 'csrc/torch_binding.cpp' 'csrc/ops.h'
'csrc/kernels/utils.h' 'csrc/kernels/pos_encoding_kernels.cpp' 'csrc/kernels/advance_step.cpp' 'csrc/kernels/get_masked_input_and_mask_kernel.cpp' 'csrc/torch_binding.cpp' 'csrc/ops.h'
)

# Format specified files with clang-format
31 changes: 29 additions & 2 deletions vllm_ascend/attention/attention.py
@@ -36,7 +36,8 @@

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.cache import concat_and_cache_mla
from vllm_ascend.utils import enable_custom_op
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
enable_custom_op, is_310p, nd_to_nz_2d)
from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)

@@ -170,7 +171,11 @@ def get_kv_cache_shape(
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads, head_size)
if is_310p():
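# 310P keeps the KV cache in the NZ (fractal) layout: head dims are folded and the innermost dim is fixed at 16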
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
16)
else:
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def swap_blocks(
@@ -654,6 +659,11 @@ def build(
# normal mask
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
max_prefill_seq_len, dtype, device)
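# convert the mask to NZ format for 310P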
if is_310p():
mask_nz = nd_to_nz_2d(self.attn_mask)
mask_nz = torch_npu.npu_format_cast(
mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ)
self.attn_mask = mask_nz
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
# compress mask for prefix cache
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
@@ -868,6 +878,18 @@ def forward(
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.prefill_metadata.seq_lens).
astype(np.int32))
if is_310p():
# align q, k, v and output tensors
query = aligned_16(query)
key = aligned_16(key)
value = aligned_16(value)
output = aligned_16(output)

# re-expand and reformat the mask in case it was broadcast
mask = mask.repeat(
self.seq_lens_tensor_cpu.size(0), 1, 1, 1)
mask = torch_npu.npu_format_cast(
mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
torch_npu._npu_flash_attention(
query=query,
key=key,
Expand All @@ -878,6 +900,7 @@ def forward(
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
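# slice back to the original token count in case aligned_16 padded the output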
output = output[:num_tokens, :, :]
# Prefix cache only and cache hit
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
assert kv_cache is not None
@@ -935,6 +958,10 @@ def forward(
self.seq_lens_tensor_cpu = torch.from_numpy(
np.array(attn_metadata.decode_metadata.seq_lens).astype(
np.int32))
if is_310p():
# seq_lens_tensor needs to be transferred to the device for 310P
self.seq_lens_tensor_cpu = self.seq_lens_tensor_cpu.to(
device=self.key_cache.device)
block_tables = attn_metadata.decode_metadata.block_tables
torch_npu._npu_paged_attention(
query=query,