Skip to content

Commit

Permalink
bug fix, test pass
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 4, 2021
1 parent 697956d commit 26254f5
Showing 1 changed file with 76 additions and 66 deletions.
142 changes: 76 additions & 66 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,45 @@ void thrust_sort(DLTensor* input,
}
} else {
// segmented sort by key
// Follow the back-to-back stable_sort_by_key strategy explained below
// https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY
thrust::device_vector<int64_t> argsort_order(size);
thrust::sequence(argsort_order.begin(), argsort_order.begin() + size);
thrust::sequence(argsort_order.begin(), argsort_order.end());

auto do_sort_by_key = [n_values, is_ascend](auto keys_ptr, auto values_ptr) {
if (is_ascend) {
thrust::stable_sort_by_key(keys_ptr, keys_ptr + n_values, values_ptr);
} else {
thrust::stable_sort_by_key(keys_ptr, keys_ptr + n_values, values_ptr,
thrust::greater<DataType>());
}
}; // NOLINT(*)

do_sort_by_key(values_ptr, argsort_order.begin());
// First, sort values and store the sorted order in argsort_order.
if (is_ascend) {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin());
} else {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(),
thrust::greater<DataType>());
}

// The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2
// without materializing it
auto counting_iter = thrust::counting_iterator<int64_t>(0);
auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
return i % n_values;
}; // NOLINT(*)
auto init_indices_iter = thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index);
auto init_indices_iter = thrust::make_transform_iterator(counting_iter,
linear_index_to_sort_axis_index);

// This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr);

thrust::device_vector<int> segment_ids(size);
auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
return i / n_values;
}; // NOLINT(*)
thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(), linear_index_to_segment_id);
// We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
linear_index_to_segment_id);

do_sort_by_key(segment_ids.begin(), thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr)));
// The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
// values_ptr and indices_ptr will also be sorted in the order of segmend_ids above
// Since sorting has been done in a stable way, relative orderings of values and indices
// in the segment do not change and hence they remain sorted.
auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip);
}
}

Expand All @@ -102,58 +113,57 @@ void thrust_sort_common(DLTensor* input,
int sort_len,
std::string data_dtype,
std::string out_dtype) {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
// if (data_dtype == "float32") {
// if (out_dtype == "int32") {

// } else if (out_dtype == "int64") {
// thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float32") {
// thrust_sort<float, float>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float64") {
// thrust_sort<float, double>(input, values_out, indices_out, is_ascend, sort_len);
// } else {
// LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
// }
// } else if (data_dtype == "float64") {
// if (out_dtype == "int32") {
// thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "int64") {
// thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float32") {
// thrust_sort<double, float>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float64") {
// thrust_sort<double, double>(input, values_out, indices_out, is_ascend, sort_len);
// } else {
// LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
// }
// } else if (data_dtype == "int32") {
// if (out_dtype == "int32") {
// thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "int64") {
// thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float32") {
// thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float64") {
// thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, sort_len);
// } else {
// LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
// }
// } else if (data_dtype == "int64") {
// if (out_dtype == "int32") {
// thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "int64") {
// thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float32") {
// thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, sort_len);
// } else if (out_dtype == "float64") {
// thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, sort_len);
// } else {
// LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
// }
// } else {
// LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
// }
if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else {
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
Expand Down

0 comments on commit 26254f5

Please sign in to comment.