-
Notifications
You must be signed in to change notification settings - Fork 654
[MPS] Add portable grid_sampler_2d implementation + tests #10561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[MPS] Add portable grid_sampler_2d implementation + tests #10561
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10561
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below:
✅ No FailuresAs of commit b4a4e04 with merge base b4e1145 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
8f26df6
to
682b539
Compare
682b539
to
4450646
Compare
@cccclai @manuelcandales bumping up for a review. |
It's portable op implementation, can @manuelcandales or @swolchok take a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar with this operator and its mathematical function, but here are some general C++ comments
namespace torch { | ||
namespace executor { | ||
namespace native { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we have C++17, new code should use nested namespaces (namespace torch::executor::native {
)
using Tensor = exec_aten::Tensor; | ||
using ScalarType = executorch::aten::ScalarType; | ||
using SizesType = executorch::aten::SizesType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: sort
using SizesType = executorch::aten::SizesType; | ||
|
||
// Transform normalized coordinates to pixel space | ||
inline float unnormalize_coord(float coord, int64_t size, bool align_corners) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for internal functions like this one: remove inline, put in anonymous namespace
} | ||
|
||
// Compute source index and interpolation weight | ||
inline std::pair<int64_t, float> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto inline
int64_t abs_coord = std::abs(coord); | ||
abs_coord = abs_coord % double_size; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super nit: I see a couple places in this PR where you could be terser by performing more than one operation per statement. for example, here would be better as int64_t abs_coord = std::abs(coord) % double_size;
|
||
// Create a 1x1x4x4 input tensor | ||
const std::vector<int32_t> input_sizes = {1, 1, 4, 4}; | ||
std::vector<float> input_data = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you could use std::iota to make this a lot shorter
https://en.cppreference.com/w/cpp/algorithm/iota
align_corners, | ||
out); | ||
|
||
// Check non-zero |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this comment just repeats the code, delete (and similar comments throughout that repeat the code)
|
||
// Create a 1x1x4x4 input tensor | ||
const std::vector<int32_t> input_sizes = {1, 1, 4, 4}; | ||
std::vector<float> input_data = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto iota
@@ -1330,4 +1337,4 @@ def portable_source_list(): | |||
|
|||
def portable_header_list(): | |||
"""All the header file names from //executorch/kernels/portable/cpu/""" | |||
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h"] | |||
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not 100% sure, but it looks like github is saying you removed the end-of-file newline? if so, please put it back
(out.size(0) == N && out.size(1) == C && out.size(2) == out_H && | ||
out.size(3) == out_W), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't we be attempting to resize output to the desired size so that we can handle dynamic shapes smoothly? @manuelcandales
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.
@DenisVieriu97 please follow this pattern
@@ -335,4 +336,4 @@ def define_common_targets(): | |||
_common_op_test("op_view_as_real_copy_test", ["aten", "portable"]) | |||
_common_op_test("op_view_copy_test", ["aten", "portable"]) | |||
_common_op_test("op_where_test", ["aten", "portable"]) | |||
_common_op_test("op_zeros_test", ["aten", "portable"]) | |||
_common_op_test("op_zeros_test", ["aten", "portable"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
newline?
const int64_t out_W = grid.size(2); | ||
|
||
// Check for 4D input and grid | ||
ET_KERNEL_CHECK(ctx, (input.dim() == 4), InvalidArgument, out); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@DenisVieriu97 do you mind refactoring all of the InvalidArgument checks into a helper function dedicated to check the arguments, which returns True only if all checks pass, and logs when a check doesn't pass. Here is an example
and then you would be calling
ET_KERNEL_CHECK(ctx, check_grid_sampler_2d_args(...) == 4), InvalidArgument, out);
|
||
// Process grid sampling with specified input, grid, and modes | ||
template <typename T> | ||
void grid_sampler_2d_impl( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation requires default dim order. Add a check in your check_args function, to check that the dim order is the default.
This PR needs a
|
Upstream grid_sampler_2d (cpu implementation) and the tests implemented by @versi379
Summary of changes:
cc @cccclai , @shoumikhin