Perf. improvement - save the gather result into dst directly if dst is contiguous (copy_kernel_mps) #44

Merged: 2 commits, Jun 28, 2022
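The change is small but targeted: when the destination of an MPS-to-MPS copy is contiguous and has no storage byte offset, the gather graph that materializes a non-contiguous source view now writes its result straight into the destination, instead of into a temporary tensor that then has to be blitted over. Below is a rough illustration of the kind of user-level copy that should hit this path; it is my own sketch (names and shapes invented), not taken from the PR description.

```python
import torch

x = torch.randn(256, 256, device="mps")
view = x.t()                                # non-contiguous view, so the copy needs a gather

dst = torch.empty(256, 256, device="mps")   # contiguous destination, storage offset 0
dst.copy_(view)                             # gather output can land directly in dst, skipping the extra blit
```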
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/OperationUtils.h
@@ -51,7 +51,7 @@ std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value
 double getMPSScalarValue(const Tensor& t);
 std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
-Tensor gatherViewTensor(const at::Tensor& src);
+Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
 Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
 
 MPSShape* getMPSShape(const Tensor& t);
3 changes: 2 additions & 1 deletion aten/src/ATen/native/mps/OperationUtils.mm
@@ -257,8 +257,9 @@ void printTensorNDArray(const Tensor& t) {
   id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
   // a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose())
   if (src.is_view() || !src.is_contiguous()) {
+    Tensor emptyShell = Tensor();
     // use "_tensor" from Placeholder to retain view's output during its usage in other ops
-    _tensor = gatherViewTensor(src);
+    _tensor = gatherViewTensor(src, emptyShell);
     if (!_tensor.has_storage()) {
       // if we cannot gather, we make the the tensor contiguous implicitly, and keep
       // it in placeholder to be able to retrieve it when we return from constructor
17 changes: 14 additions & 3 deletions aten/src/ATen/native/mps/operations/Copy.mm
@@ -101,7 +101,8 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
 
   auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
   if (!src_.is_contiguous()) {
-    src = gatherViewTensor(src_);
+    Tensor emptyShell = Tensor();
+    src = gatherViewTensor(src_, emptyShell);
     if (src.has_storage()) {
       storage_byte_offset = 0;
     } else {
@@ -250,10 +251,21 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
 static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
 {
   auto src_byte_offset = src_.storage_offset() * src_.itemsize();
+  auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
+
+  // If dst is contiguous and there is no byte offset, we can save directly the result of
+  // gather into dst. This reduces the overhead of doing an additional blit for most cases
+  bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset);
   Tensor src;
+
   if (!src_.is_contiguous()) {
-    src = gatherViewTensor(src_);
+    Tensor emptyShell = Tensor();
+    src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell);
+
     if (src.has_storage()) {
+      if (returnGatherOutput)
+        return dst_;
+
       src_byte_offset = 0;
     } else {
       src = src_.expand_as(dst_).contiguous();
@@ -271,7 +283,6 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   src._set_conj(src_.is_conj());
   src._set_neg(src_.is_neg());
 
-  auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
   id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
   id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
   const size_t src_size = src.nbytes();
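The fast path in copy_kernel_mps is gated on returnGatherOutput, i.e. dst_.is_contiguous() && !dst_byte_offset, so a destination that is contiguous but starts at a nonzero storage offset still takes the old gather-into-temporary-then-blit route. A minimal sketch of both cases, assuming my reading of the condition is correct (the tensors and shapes below are hypothetical):

```python
import torch

src = torch.randn(128, 64, device="mps").t()  # non-contiguous source view, shape (64, 128)

# Fast path: contiguous destination with storage offset 0
dst = torch.empty(64, 128, device="mps")
dst.copy_(src)                                 # gather result is written directly into dst

# Fallback: contiguous slice with a nonzero storage offset
buf = torch.empty(2, 64, 128, device="mps")
offset_dst = buf[1]                            # storage_offset() == 64 * 128, so dst_byte_offset != 0
offset_dst.copy_(src)                          # gather into a temporary tensor, then blit into dst
```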
9 changes: 6 additions & 3 deletions aten/src/ATen/native/mps/operations/View.mm
@@ -209,7 +209,7 @@
   }
 }
 
-Tensor gatherViewTensor(const at::Tensor& src)
+Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
 {
   ViewCachedGraph* cachedGraph = nullptr;
 
@@ -224,9 +224,12 @@ Tensor gatherViewTensor(const at::Tensor& src)
   if (!cachedGraph) {
     return Tensor();
   }
-  Tensor output = at::native::empty_mps(src.sizes(), src.scalar_type(), c10::nullopt, kMPS);
+  Tensor output;
+  if (!dst.has_storage())
+    output = at::native::empty_mps(src.sizes(), src.scalar_type(), c10::nullopt, kMPS);
 
-  return runViewGraph(cachedGraph, src, output, /*needsScatter*/ false);
+  return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false);
 }
 
 Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
2 changes: 1 addition & 1 deletion test/test_mps.py
@@ -4641,7 +4641,7 @@ def test_slicing_with_step(self):
         x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
         x_mps[::2] = 1.0
 
-        x_cpu = torch.zeros(10, dtype=torch.float32, device="mps")
+        x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
         x_cpu[::2] = 1.0
 
         self.assertEqual(x_cpu, x_mps)
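The test fix makes the reference tensor actually live on the CPU; before, both tensors were created on MPS, so the assertion never compared against a CPU reference. A standalone repro of what the corrected test checks (my own snippet, not part of the test suite):

```python
import torch

x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
x_mps[::2] = 1.0        # strided assignment exercises the MPS view/scatter path

x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
x_cpu[::2] = 1.0

assert torch.equal(x_cpu, x_mps.cpu())   # move to CPU for an exact elementwise comparison
```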