@@ -54,6 +54,7 @@ class GetMaskedInputAndMask {
         pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));
 
         // Initialize calculation buffers
+        // NOTE: calc_buf_1 and calc_buf_2 are also used for int16 casting on older archs.
         pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
         pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));
 
@@ -66,7 +67,7 @@ class GetMaskedInputAndMask {
         // Initialize temporary buffers
         pipe.InitBuffer(start_buf, size_ * sizeof(float));
         pipe.InitBuffer(end_buf, size_ * sizeof(float));
-        pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float));
+        pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float));  // Also used as a half intermediate when casting
         pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
         pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
         pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
@@ -121,7 +122,6 @@ class GetMaskedInputAndMask {
             const float start_value,
             const float end_value) {
 
-        // Use already initialized buffers
         AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
         AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();
 
@@ -134,7 +134,35 @@ class GetMaskedInputAndMask {
         CompareWithValue(ge_result, start_value_tensor, input, true);
         CompareWithValue(lt_result, input, end_value_tensor, false);
 
+#if (__CCE_AICORE__ >= 220)
         AscendC::And(range_mask, ge_result, lt_result, size_);
+#else
+        {
+            // WORKAROUND for older arch
+            // No direct int8->int16 cast. Use half as intermediate.
+            // No direct int8 And. Use int16 And.
+            AscendC::LocalTensor<int16_t> ge_result_i16 = calc_buf_1.Get<int16_t>();
+            AscendC::LocalTensor<int16_t> lt_result_i16 = calc_buf_2.Get<int16_t>();
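+            // Reuse ge_result_i16's buffer for the result; the And below runs in place.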
+            AscendC::LocalTensor<int16_t> range_mask_i16 = ge_result_i16;
+
+            // Use a temporary buffer for half type
+            AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
+
+            // 1. Cast inputs: int8_t -> half -> int16_t
+            AscendC::Cast(tmp_half, ge_result, AscendC::RoundMode::CAST_NONE, size_);
+            AscendC::Cast(ge_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
+
+            AscendC::Cast(tmp_half, lt_result, AscendC::RoundMode::CAST_NONE, size_);
+            AscendC::Cast(lt_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
+
+            // 2. Perform And on int16_t tensors
+            AscendC::And(range_mask_i16, ge_result_i16, lt_result_i16, size_);
+
+            // 3. Cast result back: int16_t -> half -> int8_t
+            AscendC::Cast(tmp_half, range_mask_i16, AscendC::RoundMode::CAST_NONE, size_);
+            AscendC::Cast(range_mask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
+        }
+#endif
     }
 
     __aicore__ inline void Compute() {
@@ -145,24 +173,18 @@ class GetMaskedInputAndMask {
         AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
         AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);
 
-        // Calculate mask for org_vocab range
-        // org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
         AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
         ComputeRangeMask(orgVocabMask,
                          inputFloat,
                          static_cast<float>(org_vocab_start_index_),
                          static_cast<float>(org_vocab_end_index_));
 
-        // Calculate mask for added_vocab range
-        // added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
         AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
         ComputeRangeMask(addedVocabMask,
                          inputFloat,
                          static_cast<float>(added_vocab_start_index_),
                          static_cast<float>(added_vocab_end_index_));
 
-        // Calculate validOffset
-        // valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask)
         AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
         AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();
 
@@ -173,10 +195,7 @@ class GetMaskedInputAndMask {
         AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
         AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
 
-        AscendC::Mul(validOffset,
-                     constOrgStartIndex,
-                     orgVocabMask_fp32,
-                     size_);
+        AscendC::Mul(validOffset, constOrgStartIndex, orgVocabMask_fp32, size_);
 
         AscendC::LocalTensor<float> addedOffset;
         AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
@@ -187,44 +206,61 @@ class GetMaskedInputAndMask {
         AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
         AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
 
-        AscendC::Mul(addedOffset,
-                     addedOffsetTensor,
-                     addedVocabMask_fp32,
-                     size_);
-
+        AscendC::Mul(addedOffset, addedOffsetTensor, addedVocabMask_fp32, size_);
         AscendC::Add(validOffset, validOffset, addedOffset, size_);
 
-        // vocab_mask = org_vocab_mask | added_vocab_mask
         AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();
-
+
+#if (__CCE_AICORE__ >= 220)
         AscendC::Or(vocabMask,
                     orgVocabMask,
                     addedVocabMask,
                     size_);
-
+#else
+        {
+            // WORKAROUND for older arch
+            // No direct int8->int16 cast. Use half as intermediate.
+            // No direct int8 Or. Use int16 Or.
+            AscendC::LocalTensor<int16_t> orgVocabMask_i16 = calc_buf_1.Get<int16_t>();
+            AscendC::LocalTensor<int16_t> addedVocabMask_i16 = calc_buf_2.Get<int16_t>();
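+            // Reuse orgVocabMask_i16's buffer for the result; the Or below runs in place.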
+            AscendC::LocalTensor<int16_t> vocabMask_i16 = orgVocabMask_i16;
+
+            // Use a temporary buffer for half type (reusing inputFloat_buf)
+            AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
+
+            // 1. Cast inputs: int8_t -> half -> int16_t
+            AscendC::Cast(tmp_half, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
+            AscendC::Cast(orgVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
+
+            AscendC::Cast(tmp_half, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
+            AscendC::Cast(addedVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
+
+            // 2. Perform Or on int16_t tensors
+            AscendC::Or(vocabMask_i16, orgVocabMask_i16, addedVocabMask_i16, size_);
+
+            // 3. Cast result back: int16_t -> half -> int8_t
+            AscendC::Cast(tmp_half, vocabMask_i16, AscendC::RoundMode::CAST_NONE, size_);
+            AscendC::Cast(vocabMask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
+        }
+#endif
+
         AscendC::Sub(inputFloat, inputFloat, validOffset, size_);
 
-        // input_ = vocab_mask * (input_ - valid_offset)
         AscendC::LocalTensor<half> vocabMask_fp16;
         AscendC::LocalTensor<float> vocabMask_fp32;
         AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
         AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
 
-        AscendC::LocalTensor<float> inputFloat_fp32;
         AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);
 
         AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
         outQueue.EnQue(maskedLocal);
 
-        // ~vocab_mask
         AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
         AscendC::Duplicate(ones_tensor, (float)1, size_);
         AscendC::LocalTensor<float> maskLocal_fp32;
 
-        AscendC::Sub(maskLocal_fp32,
-                     ones_tensor,
-                     vocabMask_fp32,
-                     size_);
+        AscendC::Sub(maskLocal_fp32, ones_tensor, vocabMask_fp32, size_);
 
         AscendC::LocalTensor<half> maskLocal_fp16;
         AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
@@ -262,8 +298,6 @@ class GetMaskedInputAndMask {
     // Temporary buffers
     AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
     AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;
-
-    // Temporary buffers continued
     AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
     AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
     AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
@@ -342,4 +376,3 @@ void get_masked_input_and_mask_impl(
 }
 
 } // namespace vllm_ascend
-