File tree Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Expand file tree Collapse file tree 1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -198,6 +198,13 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
198198 const int64_t num_rows,
199199 const int64_t num_cols,
200200 const bool descending) {
201+ PADDLE_ENFORCE_LE (num_cols,
202+ std::numeric_limits<int >::max (),
203+ ::common::errors::PreconditionNotMet (
204+ " The dimension being sorted should be less than "
205+ " 2^31, but got %lld. Please check the input tensor. " ,
206+ num_cols));
207+
201208 auto cu_stream = dev_ctx.stream ();
202209 auto ComputeBlockSize = [](IndType col) {
203210 if (col > 512 )
@@ -228,8 +235,14 @@ void ArgFullSort(const phi::GPUContext& dev_ctx,
228235 const int64_t total_elements = num_cols * num_rows;
229236 const int64_t segment_size = num_cols;
230237 const int64_t element_per_call = std::min (max_elements, total_elements);
238+
239+ // make sure element_per_call >= segment_size
240+ const int64_t adjusted_elements_per_call =
241+ std::max (max_elements, segment_size);
242+
231243 // make sure batch size is the multiple of segment_size
232- const int64_t batch_size = (element_per_call / segment_size) * segment_size;
244+ const int64_t batch_size =
245+ (adjusted_elements_per_call / segment_size) * segment_size;
233246 int64_t offset = 0 ;
234247 DenseTensor input_indices;
235248
You can’t perform that action at this time.
0 commit comments