Currently, when we call `TensorView::merge(int axis)`, it accumulates non-contiguity and stops collapsing dimensions once it hits the first non-contiguous axis.
This is due to a limitation of our underlying implementation. To handle this trade-off, PR #294 prioritizes collapsing the faster-changing (inner) dimensions.
For example, take a tensor with a non-contiguous axis in the middle:

```python
import torch

x = torch.randn(8, 8, 8, 8)
x = x[:, ::2, :, :]  # stride-2 slice leaves a gap between dim1 and dim2
```
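Where the break sits can be checked with the contiguity-like condition given in the `torch.Tensor.view` documentation. Adjacent axes `i` and `i+1` can be folded into one only if `stride[i] == stride[i+1] * size[i+1]`. Below is a minimal pure-Python sketch. The shape and strides are hard-coded to what `x.shape` and `x.stride()` give for the sliced tensor above. `collapsible_pairs` is a hypothetical helper, not part of PyTorch or the fuser.

```python
# Shape and strides of torch.randn(8, 8, 8, 8)[:, ::2, :, :].
# The ::2 slice halves dim1's size (8 -> 4) and doubles its stride (64 -> 128).
SHAPE = (8, 4, 8, 8)
STRIDES = (512, 128, 8, 1)


def collapsible_pairs(shape, strides):
    """For each adjacent pair (i, i+1), report whether the two axes can be
    folded into one axis without changing how elements are addressed."""
    return [
        strides[i] == strides[i + 1] * shape[i + 1]
        for i in range(len(shape) - 1)
    ]


if __name__ == "__main__":
    for i, ok in enumerate(collapsible_pairs(SHAPE, STRIDES)):
        print(f"dim{i} <-> dim{i + 1}: {'contiguous' if ok else 'NOT contiguous'}")
```

For this tensor, the check reports dim0/dim1 and dim2/dim3 as contiguous and dim1/dim2 as not. That single break in the middle is what makes the merge order matter.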
If we merge from left to right, as in

```cpp
// Repeatedly merge the two outermost axes until one axis remains.
while (x_tensor_view->nDims() > 1) {
  x_tensor_view->merge(0);
}
```
we get dim0 and dim1 collapsed, but nothing on the right-hand side.
By contrast, if we merge from right to left, as in
```cpp
// Repeatedly merge the two innermost axes until one axis remains.
while (x_tensor_view->nDims() > 1) {
  x_tensor_view->merge(-2);
}
```
we get dim2 and dim3 collapsed instead (and dim3 can be squeezed out in the kernel, since it has stride 1).
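The order dependence can be reproduced with a toy model of the merge behaviour described above. This is a sketch under stated assumptions, not the fuser's code. `Axis`, `merge`, and `merge_all` are hypothetical. The model assumes that once a merge fails the contiguity check, the resulting axis is marked non-contiguous ("tainted"). Every later merge involving that axis is then also treated as non-collapsing.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Shape and strides of torch.randn(8, 8, 8, 8)[:, ::2, :, :].
SHAPE = (8, 4, 8, 8)
STRIDES = (512, 128, 8, 1)


@dataclass
class Axis:
    dims: Tuple[int, ...]  # original dims folded into this axis
    size: int
    stride: int  # only meaningful while the axis is not tainted
    tainted: bool = False  # set once a merge fails the contiguity check


def merge(axes: List[Axis], axis: int) -> Tuple[List[Axis], bool]:
    """Merge axes[axis] (outer) with the axis to its right (inner).
    Negative indices behave as in merge(-2). Returns the new axis list and
    whether the merge was a real collapse."""
    i = axis % len(axes)
    outer, inner = axes[i], axes[i + 1]
    collapsed = (
        not outer.tainted
        and not inner.tainted
        and outer.stride == inner.size * inner.stride
    )
    merged = Axis(outer.dims + inner.dims, outer.size * inner.size,
                  inner.stride, tainted=not collapsed)
    return axes[:i] + [merged] + axes[i + 2:], collapsed


def merge_all(shape, strides, axis: int) -> List[Tuple[int, int]]:
    """Merge down to a single axis, always calling merge(axis). Return the
    original dim boundaries (outer_dim, inner_dim) that actually collapsed."""
    axes = [Axis((d,), n, s) for d, (n, s) in enumerate(zip(shape, strides))]
    collapsed = []
    while len(axes) > 1:
        i = axis % len(axes)
        boundary = (axes[i].dims[-1], axes[i + 1].dims[0])
        axes, ok = merge(axes, axis)
        if ok:
            collapsed.append(boundary)
    return collapsed


if __name__ == "__main__":
    print("left to right, merge(0): ", merge_all(SHAPE, STRIDES, 0))
    print("right to left, merge(-2):", merge_all(SHAPE, STRIDES, -2))
```

With the example strides, repeated `merge(0)` collapses only the dim0/dim1 boundary. Repeated `merge(-2)` collapses only dim2/dim3. In both cases, the taint picked up at the dim1/dim2 boundary blocks the remaining side.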
In the long term, we should properly recognize contiguous axes and collapse on both sides. Starting this thread to track the issue.
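One way to get that behaviour is to decide contiguity per boundary up front, independent of merge order, and then fold every contiguous run. Below is a minimal sketch. It assumes plain shape/stride tuples and does no special handling of size-1 or broadcast (stride-0) axes. `collapse_contiguous` is a hypothetical helper, not an existing fuser API.

```python
from typing import List, Tuple

# Shape and strides of torch.randn(8, 8, 8, 8)[:, ::2, :, :].
SHAPE = (8, 4, 8, 8)
STRIDES = (512, 128, 8, 1)


def collapse_contiguous(shape, strides) -> Tuple[List[List[int]], tuple, tuple]:
    """Walk from the innermost axis outward, folding each axis into the
    current outermost group whenever the pair is contiguous.
    Returns (groups of original dims, collapsed shape, collapsed strides)."""
    groups: List[List[int]] = [[len(shape) - 1]]
    new_shape = [shape[-1]]
    new_strides = [strides[-1]]
    for d in range(len(shape) - 2, -1, -1):
        # Contiguous with the group if its stride equals the group's extent.
        if strides[d] == new_shape[0] * new_strides[0]:
            groups[0].insert(0, d)
            new_shape[0] *= shape[d]  # group keeps its innermost stride
        else:
            # Non-contiguous boundary: start a new group, but keep going.
            groups.insert(0, [d])
            new_shape.insert(0, shape[d])
            new_strides.insert(0, strides[d])
    return groups, tuple(new_shape), tuple(new_strides)


if __name__ == "__main__":
    print(collapse_contiguous(SHAPE, STRIDES))
```

For the example tensor, this yields groups `[[0, 1], [2, 3]]`. That is a 2-D problem of shape `(32, 64)` with strides `(128, 1)`, so both sides of the non-contiguous boundary are collapsed. Unlike the accumulating merge, a failed boundary only starts a new group; it does not stop collapsing further out.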
Agreed, this should be fixed, but fixing it is unfortunately tricky. We'd likely need some notion of merge reordering; I'll have to think about this more to figure out whether there's an easier or more natural change that addresses it.