Skip to content

Commit

Permalink
fix uncontiguous tensor resize bug (#60684)
Browse files Browse the repository at this point in the history
* fix uncontiguous tensor resize bug
  • Loading branch information
wanghuancoder authored Jan 10, 2024
1 parent 06fdeeb commit da91813
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions paddle/phi/kernels/stride/as_strided_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ void AsStridedKernel(const Context& dev_ctx,
const std::vector<int64_t>& stride,
int64_t offset,
DenseTensor* out) {
out->Resize(DDim(dims.data(), static_cast<int>(dims.size())));
out->set_strides(DDim(stride.data(), static_cast<int>(stride.size())));
out->set_offset(offset);
auto meta = out->meta();
meta.dims = DDim(dims.data(), static_cast<int>(dims.size()));
meta.strides = DDim(stride.data(), static_cast<int>(stride.size()));
meta.offset = offset;
out->set_meta(meta);
out->ResetHolder(input.Holder());
out->ShareInplaceVersionCounterWith(input);
}
Expand Down
8 changes: 5 additions & 3 deletions paddle/phi/kernels/stride/tensor_unfold_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ void TensorUnfoldKernel(const Context& dev_ctx,
}
}

out->Resize(DDim(shape.data(), static_cast<int>(shape.size())));
out->set_strides(DDim(stride.data(), static_cast<int>(stride.size())));
out->set_offset(input.offset());
auto meta = out->meta();
meta.dims = DDim(shape.data(), static_cast<int>(shape.size()));
meta.strides = DDim(stride.data(), static_cast<int>(stride.size()));
meta.offset = input.offset();
out->set_meta(meta);
out->ResetHolder(input.Holder());
out->ShareInplaceVersionCounterWith(input);
}
Expand Down

0 comments on commit da91813

Please sign in to comment.