Skip to content

Commit 07aacec

Browse files
committed
fix: e2e tests
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
1 parent ddf54ef commit 07aacec

File tree

2 files changed

+70
-32
lines changed

2 files changed

+70
-32
lines changed

csrc/kernels/get_masked_input_and_mask_kernel.cpp

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class GetMaskedInputAndMask {
5454
pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));
5555

5656
// Initialize calculation buffers
57+
// NOTE: calc_buf_1 and calc_buf_2 are also used for int16 casting on older archs.
5758
pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
5859
pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));
5960

@@ -66,7 +67,7 @@ class GetMaskedInputAndMask {
6667
// Initialize temporary buffers
6768
pipe.InitBuffer(start_buf, size_ * sizeof(float));
6869
pipe.InitBuffer(end_buf, size_ * sizeof(float));
69-
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float));
70+
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float)); // Also used for half intermediate in casting
7071
pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
7172
pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
7273
pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
@@ -121,7 +122,6 @@ class GetMaskedInputAndMask {
121122
const float start_value,
122123
const float end_value) {
123124

124-
// Use already initialized buffers
125125
AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
126126
AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();
127127

@@ -134,7 +134,35 @@ class GetMaskedInputAndMask {
134134
CompareWithValue(ge_result, start_value_tensor, input, true);
135135
CompareWithValue(lt_result, input, end_value_tensor, false);
136136

137+
#if (__CCE_AICORE__ >= 220)
137138
AscendC::And(range_mask, ge_result, lt_result, size_);
139+
#else
140+
{
141+
// WORKAROUND for older arch
142+
// No direct int8->int16 cast. Use half as intermediate.
143+
// No direct int8 And. Use int16 And.
144+
AscendC::LocalTensor<int16_t> ge_result_i16 = calc_buf_1.Get<int16_t>();
145+
AscendC::LocalTensor<int16_t> lt_result_i16 = calc_buf_2.Get<int16_t>();
146+
AscendC::LocalTensor<int16_t> range_mask_i16 = ge_result_i16;
147+
148+
// Use a temporary buffer for half type
149+
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
150+
151+
// 1. Cast inputs: int8_t -> half -> int16_t
152+
AscendC::Cast(tmp_half, ge_result, AscendC::RoundMode::CAST_NONE, size_);
153+
AscendC::Cast(ge_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
154+
155+
AscendC::Cast(tmp_half, lt_result, AscendC::RoundMode::CAST_NONE, size_);
156+
AscendC::Cast(lt_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
157+
158+
// 2. Perform And on int16_t tensors
159+
AscendC::And(range_mask_i16, ge_result_i16, lt_result_i16, size_);
160+
161+
// 3. Cast result back: int16_t -> half -> int8_t
162+
AscendC::Cast(tmp_half, range_mask_i16, AscendC::RoundMode::CAST_NONE, size_);
163+
AscendC::Cast(range_mask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
164+
}
165+
#endif
138166
}
139167

140168
__aicore__ inline void Compute() {
@@ -145,24 +173,18 @@ class GetMaskedInputAndMask {
145173
AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
146174
AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);
147175

148-
// Calculate mask for org_vocab range
149-
// org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
150176
AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
151177
ComputeRangeMask(orgVocabMask,
152178
inputFloat,
153179
static_cast<float>(org_vocab_start_index_),
154180
static_cast<float>(org_vocab_end_index_));
155181

156-
// Calculate mask for added_vocab range
157-
// added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
158182
AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
159183
ComputeRangeMask(addedVocabMask,
160184
inputFloat,
161185
static_cast<float>(added_vocab_start_index_),
162186
static_cast<float>(added_vocab_end_index_));
163187

164-
// Calculate validOffset
165-
// valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask)
166188
AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
167189
AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();
168190

@@ -173,10 +195,7 @@ class GetMaskedInputAndMask {
173195
AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
174196
AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
175197

176-
AscendC::Mul(validOffset,
177-
constOrgStartIndex,
178-
orgVocabMask_fp32,
179-
size_);
198+
AscendC::Mul(validOffset, constOrgStartIndex, orgVocabMask_fp32, size_);
180199

181200
AscendC::LocalTensor<float> addedOffset;
182201
AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
@@ -187,44 +206,61 @@ class GetMaskedInputAndMask {
187206
AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
188207
AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
189208

190-
AscendC::Mul(addedOffset,
191-
addedOffsetTensor,
192-
addedVocabMask_fp32,
193-
size_);
194-
209+
AscendC::Mul(addedOffset, addedOffsetTensor, addedVocabMask_fp32, size_);
195210
AscendC::Add(validOffset, validOffset, addedOffset, size_);
196211

197-
// vocab_mask = org_vocab_mask | added_vocab_mask
198212
AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();
199-
213+
214+
#if (__CCE_AICORE__ >= 220)
200215
AscendC::Or(vocabMask,
201216
orgVocabMask,
202217
addedVocabMask,
203218
size_);
204-
219+
#else
220+
{
221+
// WORKAROUND for older arch
222+
// No direct int8->int16 cast. Use half as intermediate.
223+
// No direct int8 Or. Use int16 Or.
224+
AscendC::LocalTensor<int16_t> orgVocabMask_i16 = calc_buf_1.Get<int16_t>();
225+
AscendC::LocalTensor<int16_t> addedVocabMask_i16 = calc_buf_2.Get<int16_t>();
226+
AscendC::LocalTensor<int16_t> vocabMask_i16 = orgVocabMask_i16;
227+
228+
// Use a temporary buffer for half type. inputFloat_buf is free now.
229+
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
230+
231+
// 1. Cast inputs: int8_t -> half -> int16_t
232+
AscendC::Cast(tmp_half, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
233+
AscendC::Cast(orgVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
234+
235+
AscendC::Cast(tmp_half, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
236+
AscendC::Cast(addedVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
237+
238+
// 2. Perform Or on int16_t tensors
239+
AscendC::Or(vocabMask_i16, orgVocabMask_i16, addedVocabMask_i16, size_);
240+
241+
// 3. Cast result back: int16_t -> half -> int8_t
242+
AscendC::Cast(tmp_half, vocabMask_i16, AscendC::RoundMode::CAST_NONE, size_);
243+
AscendC::Cast(vocabMask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
244+
}
245+
#endif
246+
205247
AscendC::Sub(inputFloat, inputFloat, validOffset, size_);
206248

207-
// input_ = vocab_mask * (input_ - valid_offset)
208249
AscendC::LocalTensor<half> vocabMask_fp16;
209250
AscendC::LocalTensor<float> vocabMask_fp32;
210251
AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
211252
AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
212253

213-
AscendC::LocalTensor<float> inputFloat_fp32;
214254
AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);
215255

216256
AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
217257
outQueue.EnQue(maskedLocal);
218258

219-
// ~vocab_mask
220259
AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
221260
AscendC::Duplicate(ones_tensor, (float)1, size_);
222261
AscendC::LocalTensor<float> maskLocal_fp32;
223262

224-
AscendC::Sub(maskLocal_fp32,
225-
ones_tensor,
226-
vocabMask_fp32,
227-
size_);
263+
AscendC::Sub(maskLocal_fp32, ones_tensor, vocabMask_fp32, size_);
228264

229265
AscendC::LocalTensor<half> maskLocal_fp16;
230266
AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
@@ -262,8 +298,6 @@ class GetMaskedInputAndMask {
262298
// Temporary buffers
263299
AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
264300
AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;
265-
266-
// Temporary buffers continued
267301
AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
268302
AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
269303
AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
@@ -342,4 +376,3 @@ void get_masked_input_and_mask_impl(
342376
}
343377

344378
} // namespace vllm_ascend
345-

vllm_ascend/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757

5858
CUSTOM_OP_ENABLED = None
5959

60-
# 310P3 202, 910B4 224
6160
SOC_VERSION = None
6261
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
6362

@@ -66,9 +65,15 @@
6665

6766

6867
def is_310p():
68+
if not torch.npu.is_available():
69+
return False
70+
device_count = torch.npu.device_count()
71+
if device_count <= 0:
72+
return False
73+
current_device = torch.npu.current_device()
6974
global SOC_VERSION
7075
if SOC_VERSION is None:
71-
SOC_VERSION = torch.npu.get_device_name(0)
76+
SOC_VERSION = torch.npu.get_device_name(current_device)
7277
return SOC_VERSION in SOC_VERSION_INFERENCE_SERIES
7378

7479

0 commit comments

Comments
 (0)