diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index a8fa6611478..2156e85ab3c 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -423,6 +423,8 @@ - op: var.out +- op: view_as_real_copy.out + - op: view_copy.out - op: where.self_out diff --git a/kernels/portable/cpu/op_view_as_real_copy.cpp b/kernels/portable/cpu/op_view_as_real_copy.cpp new file mode 100644 index 00000000000..4a2803eded0 --- /dev/null +++ b/kernels/portable/cpu/op_view_as_real_copy.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = executorch::aten::Tensor; + +namespace { + +template +inline void _to_impl(const Tensor& self, Tensor& out) { + auto self_data = self.mutable_data_ptr(); + auto out_data = out.mutable_data_ptr(); + + for (size_t i = 0, e = self.numel(); i < e; i++) { + auto val_in = self_data[i]; + out_data[2 * i] = static_cast(val_in.real_); + out_data[2 * i + 1] = static_cast(val_in.imag_); + } +} + +} // namespace + +// view_as_real_copy(Tensor self) -> Tensor +Tensor& view_as_real_copy_out( + KernelRuntimeContext& ctx, + const Tensor& self, + Tensor& out) { + (void)ctx; + + // Get the output shape + Tensor::SizesType expected_output_size[kTensorDimensionLimit]; + get_view_as_real_copy_out_target_size(self, expected_output_size); + + // Resize for dynamic shape + ET_KERNEL_CHECK_MSG( + ctx, + resize_tensor( + out, {expected_output_size, static_cast(out.dim())}) == + Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + + // The input tensor must be complex type + ET_KERNEL_CHECK_MSG( + ctx, + executorch::runtime::isComplexType(self.scalar_type()), + InvalidArgument, + out, + "Input tensor must be complex type"); + + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out); + + constexpr auto op_name = "view_as_real_copy.out"; + + ET_SWITCH_COMPLEXH_TYPES(self.scalar_type(), ctx, op_name, CTYPE_IN, [&] { + ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] { + _to_impl(self, out); + }); + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/util/copy_ops_util.cpp b/kernels/portable/cpu/util/copy_ops_util.cpp index 93725d92dab..02b2910fc88 100644 --- a/kernels/portable/cpu/util/copy_ops_util.cpp +++ b/kernels/portable/cpu/util/copy_ops_util.cpp @@ -1018,5 +1018,14 @@ void get_unfold_copy_out_target_size( *out_ndim = self.dim() + 1; } +void get_view_as_real_copy_out_target_size( + const Tensor& self, + executorch::aten::SizesType* out_sizes) { + for (auto i : c10::irange(self.dim())) { + out_sizes[i] = self.size(i); + } + out_sizes[self.dim()] = 2; +} + } // namespace executor } // namespace torch diff --git a/kernels/portable/cpu/util/copy_ops_util.h b/kernels/portable/cpu/util/copy_ops_util.h index edcc6eb0021..cef2b3d4ee1 100644 --- a/kernels/portable/cpu/util/copy_ops_util.h +++ b/kernels/portable/cpu/util/copy_ops_util.h @@ -247,5 +247,9 @@ void get_unfold_copy_out_target_size( executorch::aten::SizesType* out_sizes, size_t* out_ndim); +void get_view_as_real_copy_out_target_size( + const Tensor& self, + executorch::aten::SizesType* out_sizes); + } // namespace executor } // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index 5e45a210a70..ab04d3b26ac 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -957,6 +957,11 @@ - arg_meta: null kernel_name: torch::executor::var_out +- op: view_as_real_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::view_as_real_copy_out + - op: view_copy.out kernels: - arg_meta: null diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 2d497dfc124..dcefa8c2e68 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -242,6 +242,7 @@ set(all_test_sources "op_upsample_bilinear2d_test.cpp" "op_upsample_nearest2d_test.cpp" "op_var_test.cpp" + "op_view_as_real_copy_test.cpp" "op_view_copy_test.cpp" "op_where_test.cpp" "op_zeros_test.cpp" diff --git a/kernels/test/op_view_as_real_copy_test.cpp b/kernels/test/op_view_as_real_copy_test.cpp new file mode 100644 index 00000000000..8e959c3db8c --- /dev/null +++ b/kernels/test/op_view_as_real_copy_test.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include + +#include + +using namespace ::testing; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using torch::executor::testing::TensorFactory; + +class OpViewAsRealTest : public OperatorTest { + protected: + Tensor& view_as_real_copy_out(const Tensor& self, Tensor& out) { + return torch::executor::aten::view_as_real_copy_outf(context_, self, out); + } + + template + void run_complex_smoke_test() { + TensorFactory tf; + constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE); + TensorFactory tf_out; + + Tensor in = tf.make( + {2, 2}, + {CTYPE(3, 4), CTYPE(-1.7, 7.4), CTYPE(5, -12), CTYPE(8.3, 0.1)}); + Tensor out = tf_out.zeros({2, 2, 2}); + Tensor expected = + tf_out.make({2, 2, 2}, {3, 4, -1.7, 7.4, 5, -12, 8.3, 0.1}); + Tensor ret = view_as_real_copy_out(in, out); + + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); + } + + // Tests on tensors with 0 size + template + void test_empty_input() { + TensorFactory tf; + constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE); + TensorFactory tf_out; + + Tensor in = tf.make(/*sizes=*/{3, 0, 4}, /*data=*/{}); + Tensor out = tf_out.zeros({3, 0, 4, 2}); + Tensor expected = tf_out.make(/*sizes=*/{3, 0, 4, 2}, /*data=*/{}); + Tensor ret = view_as_real_copy_out(in, out); + + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); + } + + // Tests on 0-dim input tensors + template + void zero_dim_input() { + TensorFactory tf; + constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE); + TensorFactory tf_out; + + Tensor in = tf.make(/*sizes=*/{}, {CTYPE(0, 0)}); + Tensor out = tf_out.zeros({2}); + Tensor expected = tf_out.zeros(/*sizes=*/{2}); + Tensor ret = view_as_real_copy_out(in, out); + + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); + } +}; + +TEST_F(OpViewAsRealTest, ComplexSmokeTest) { +#define RUN_SMOKE_TEST(ctype, dtype) \ + run_complex_smoke_test(); \ + test_empty_input(); \ + zero_dim_input(); + ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST); +#undef RUN_SMOKE_TEST +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 05e678c6229..8c5fad1f588 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -331,6 +331,7 @@ def define_common_targets(): _common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"]) _common_op_test("op_upsample_nearest2d_test", ["aten", "portable"]) _common_op_test("op_var_test", ["aten", "portable"]) + _common_op_test("op_view_as_real_copy_test", ["aten", "portable"]) _common_op_test("op_view_copy_test", ["aten", "portable"]) _common_op_test("op_where_test", ["aten", "portable"]) _common_op_test("op_zeros_test", ["aten", "portable"]) diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl index 06f9a452935..d0c39bcf17f 100644 --- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -1268,6 +1268,12 @@ ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:reduce_util", ], ), + op_target( + name = "op_view_as_real_copy", + deps = [ + "//executorch/kernels/portable/cpu/util:copy_ops_util", + ], + ), op_target( name = "op_view_copy", deps = [