Skip to content

Commit

Permalink
[alphafold] Transpose support large tensors where their numel is bigg…
Browse files Browse the repository at this point in the history
…er than INT32_MAX (#45753)
  • Loading branch information
FeixLiu authored Sep 7, 2022
1 parent 0ddcf30 commit d9a9e63
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 86 deletions.
14 changes: 8 additions & 6 deletions paddle/fluid/framework/gpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,24 @@ struct Index3 : DeviceArray<int, 3, 0> {
};

// Flat index with real dimension.
//
// Collapses the 3-D coordinate `index` into a single linear offset for a
// tensor whose extents are `dims`, with the last axis varying fastest:
//   (index[0] * dims[1] + index[1]) * dims[2] + index[2]
// dims[0] is never read.
//
// NOTE(review): this span is a rendered diff without +/- markers, so two
// signatures and two accumulator declarations appear back to back. The first
// three code lines look like the removed int-only version and the templated
// lines after them like the replacement -- confirm against the real header.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int FlatTensorIndex(const Index3& index,
const Dim3& dims) {
int flat_index = index[0];
// IDX_T is the type of the accumulator and of the returned offset. It occurs
// only in the return type, so it cannot be deduced from the arguments: a call
// without an explicit template argument gets the default `int`, and callers
// that need a wider range must name it (e.g. FlatTensorIndex<int64_t>(...)).
template <typename IDX_T = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IDX_T FlatTensorIndex(const Index3& index,
const Dim3& dims) {
IDX_T flat_index = index[0];
// Fold in axes 1 and 2. Because flat_index is IDX_T, the multiply-add is
// evaluated in IDX_T whenever IDX_T is wider than int.
for (int i = 1; i < 3; i++) {
flat_index = flat_index * dims[i] + index[i];
}
return flat_index;
}

// Convert index to tensor index with dimension.
//
// Splits the linear offset `index` into a 3-D coordinate for a tensor whose
// extents are `dims`, peeling axes from last (2) to first (0): each axis
// receives the remainder of the running offset by dims[i], and the quotient
// is carried to the next axis. The quotient left after axis 0 is discarded.
//
// NOTE(review): this span is a rendered diff without +/- markers, so the
// `int index` signature and the `int new_index` / uncast assignment lines
// look like the removed version, each followed by its IDX_T replacement --
// confirm against the real header.
//
// IDX_T is the type of the incoming offset and of the intermediate quotient.
// Unlike a return-type-only parameter it is deducible from the `index`
// argument, so the `int` default only applies when deduction is bypassed.
template <typename IDX_T = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3
ConvertTensorIndex(int index, const Dim3& dims) {
ConvertTensorIndex(IDX_T index, const Dim3& dims) {
Index3 tensor_index;
for (int i = 2; i >= 0; i--) {
int new_index = index / dims[i];
tensor_index[i] = index - dims[i] * new_index;
// Quotient kept in IDX_T so a wide offset is not truncated between axes.
IDX_T new_index = index / dims[i];
// index - dims[i] * new_index is the remainder of index by dims[i]; it is
// narrowed to int explicitly before being stored in the coordinate.
tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
index = new_index;
}
return tensor_index;
Expand Down
Loading

0 comments on commit d9a9e63

Please sign in to comment.