diff --git a/kernels/optimized/cpu/op_where.cpp b/kernels/optimized/cpu/op_where.cpp
new file mode 100644
index 00000000000..4d897ea6281
--- /dev/null
+++ b/kernels/optimized/cpu/op_where.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& opt_where_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& cond,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "where.self_out";
+
+  if (a.scalar_type() == b.scalar_type() &&
+      a.scalar_type() == out.scalar_type() && a.scalar_type() == compute_type &&
+      // Using a Byte tensor for cond has been deprecated for a long time.
+      cond.scalar_type() == ScalarType::Bool) {
+    auto out_numel = out.numel();
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+      const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+      const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
+      const bool any_is_broadcasted =
+          (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);
+      const CTYPE_COMPUTE* const data_a = a.const_data_ptr<CTYPE_COMPUTE>();
+      const CTYPE_COMPUTE* const data_b = b.const_data_ptr<CTYPE_COMPUTE>();
+      const bool* const data_cond = cond.const_data_ptr<bool>();
+      CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
+      if (any_is_broadcasted) {
+        for (const auto [out_index, a_index, b_index, cond_index] :
+             BroadcastIndexesRange<3>(out, a, b, cond)) {
+          data_out[out_index] =
+              data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
+        }
+      } else {
+        for (const auto i : c10::irange(out_numel)) {
+          data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
+        }
+      }
+    });
+  } else {
+    // Fall back for mixed dtype to keep code size and compile time
+    // reasonable.
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+          [](const CTYPE_COMPUTE val_a,
+             const CTYPE_COMPUTE val_b,
+             const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          cond,
+          utils::SupportedTensorDtypes::BOOL_OR_BYTE,
+          out,
+          utils::SupportedTensorDtypes::SAME_AS_COMMON);
+    });
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 83b2c320266..dc189708992 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -95,6 +95,12 @@ _OPTIMIZED_ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
     ),
+    op_target(
+        name = "op_where",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+        ],
+    ),
 )
 
diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml
index fd5143b1511..4f90059aa93 100644
--- a/kernels/optimized/optimized.yaml
+++ b/kernels/optimized/optimized.yaml
@@ -101,3 +101,8 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::opt_sub_scalar_out
+
+- op: where.self_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_where_out
diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h
index 10bd07baee2..f6bfae9bdaa 100644
--- a/kernels/portable/cpu/util/broadcast_util.h
+++ b/kernels/portable/cpu/util/broadcast_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include
 #include
 
@@ -290,23 +291,18 @@ inline void apply_binary_elementwise_fn(
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
     }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
 
-    data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+      data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+    }
   }
 }
 
@@ -338,28 +334,16 @@ inline void apply_ternary_elementwise_fn(
   const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      data_out[out_index] =
+          compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
     }
-
-    data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
   }
 }
 
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 778006f1b99..09db5f7180d 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include
 #include
 #include
@@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   auto out_numel = out.numel();
 
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
@@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   auto out_numel = out.numel();
 
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]),
+          load_c_to_common(&data_c[c_index * c_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index c42f38fd8b0..739bc117fbf 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -70,6 +70,9 @@ def define_common_targets():
         exported_headers = [
            "broadcast_util.h",
         ],
+        exported_deps = [
+            ":broadcast_indexes_range",
+        ],
         deps = [
             ":repeat_util",
             "//executorch/runtime/kernel:kernel_includes",
diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
index d1db40fca48..f147958558d 100644
--- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
+++ b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
@@ -112,8 +112,66 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) {
   EXPECT_EQ(expected, actual);
 }
 
-// Here we assume that the previous tests established that padding
-// with leading 1s is working, and test:
+// Make sure nothing is thrown off by a size-1 dim in the output:
+// [] -> [1, W]
+// [] -> [H, 1]
+// [1] -> [1, W]
+// [1] -> [H, 1]
+// [W] -> [1, W]
+// [1, 1] -> [1, W]
+// [1, 1] -> [H, 1]
+// [1, W] -> [1, W]
+// [H, 1] -> [H, 1]
+TEST(BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) {
+  TensorFactory<ScalarType::Int> tf;
+  constexpr auto H = 2;
+  constexpr auto W = 3;
+  Tensor out_row = tf.zeros({1, W});
+  Tensor out_col = tf.zeros({H, 1});
+  Tensor in_0d_scalar = tf.zeros({});
+  Tensor in_1d_scalar = tf.zeros({1});
+  Tensor in_2d_scalar = tf.zeros({1, 1});
+
+  Tensor in_row = tf.zeros({W});
+  Tensor in_leading_one_row = tf.zeros({1, W});
+
+  Tensor in_col = tf.zeros({H, 1});
+
+  size_t idx = 0;
+  for (const auto
+           [out_idx,
+            in_0d_idx,
+            in_1d_idx,
+            in_2d_idx,
+            in_row_idx,
+            in_leading_one_row_idx] :
+       BroadcastIndexesRange<5>(
+           out_row,
+           in_0d_scalar,
+           in_1d_scalar,
+           in_2d_scalar,
+           in_row,
+           in_leading_one_row)) {
+    EXPECT_EQ(out_idx, idx++);
+    EXPECT_EQ(in_0d_idx, 0);
+    EXPECT_EQ(in_1d_idx, 0);
+    EXPECT_EQ(in_2d_idx, 0);
+    EXPECT_EQ(in_row_idx, out_idx);
+    EXPECT_EQ(in_leading_one_row_idx, out_idx);
+  }
+
+  idx = 0;
+  for (const auto [out_idx, in_0d_idx, in_1d_idx, in_2d_idx, in_col_idx] :
+       BroadcastIndexesRange<4>(
+           out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col)) {
+    EXPECT_EQ(out_idx, idx++);
+    EXPECT_EQ(in_0d_idx, 0);
+    EXPECT_EQ(in_1d_idx, 0);
+    EXPECT_EQ(in_2d_idx, 0);
+    EXPECT_EQ(in_col_idx, out_idx);
+  }
+}
+
 // [1, 1, 1] -> [C, H, W]
 // [C, H, 1] -> [C, H, W]
 // [C, 1, W] -> [C, H, W]
@@ -166,11 +224,12 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) {
 
 // 4-D should generalize, but we will go ahead and test:
 // [N, 1, H, 1] -> [N, C, H, W]
 // [1, C, 1, W] -> [N, C, H, W]
-TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
+template <int N, int C, int H, int W>
+void four_d_broadcasting_test() {
   TensorFactory<ScalarType::Int> tf;
-  Tensor out = tf.zeros({2, 3, 4, 5});
-  Tensor in_broadcast_cw = tf.zeros({2, 1, 4, 1});
-  Tensor in_broadcast_nh = tf.zeros({1, 3, 1, 5});
+  Tensor out = tf.zeros({N, C, H, W});
+  Tensor in_broadcast_cw = tf.zeros({N, 1, H, 1});
+  Tensor in_broadcast_nh = tf.zeros({1, C, 1, W});
   // Writing out all the indexes would be too cumbersome, so here we
   // take the opportunity to mutation test against delinearize_index
@@ -190,3 +249,12 @@ TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
         linearize_access_indexes(out_indexes, out.dim(), in_broadcast_nh));
   }
 }
+
+TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
+  four_d_broadcasting_test<2, 3, 4, 5>();
+}
+
+TEST(BroadcastIndexesRangeTest, FourDBroadcastingWithOneDimsInOutput) {
+  four_d_broadcasting_test<2, 3, 1, 5>();
+  four_d_broadcasting_test<2, 1, 3, 1>();
+}
diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt
index 24adb8d9c80..394ec241698 100644
--- a/kernels/test/CMakeLists.txt
+++ b/kernels/test/CMakeLists.txt
@@ -275,6 +275,7 @@ set(_optimized_kernels_test_sources
   "op_native_layer_norm_test.cpp"
   "op_neg_test.cpp"
   "op_sub_test.cpp"
+  "op_where_test.cpp"
   "UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp"
   ${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.cpp
 )
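
Note on the recurring change above (this note and the sketch are illustrative commentary, not part of the patch): BroadcastIndexesRange replaces the per-element delinearize_index / linearize_access_indexes calls with a single range that yields the output's linear index together with each input's already-broadcast-resolved linear index, so every loop body becomes one structured binding. The standalone C++ sketch below shows the index arithmetic involved for one fixed 2-D where-style case; the shapes, values, and variable names are assumptions chosen for illustration, and none of it is ExecuTorch code.

// Illustrative sketch only; assumed fixed shapes: cond/out are [2, 3],
// a is [1, 3], b is [2, 1].
#include <array>
#include <cstdio>

int main() {
  constexpr int H = 2;
  constexpr int W = 3;
  // cond has the output shape, so its index equals the output index.
  std::array<bool, H * W> cond = {true, false, true, false, true, false};
  std::array<float, W> a = {1.0f, 2.0f, 3.0f}; // [1, W]: broadcast down rows
  std::array<float, H> b = {10.0f, 20.0f};     // [H, 1]: broadcast across cols
  std::array<float, H * W> out{};

  for (int out_index = 0; out_index < H * W; ++out_index) {
    const int a_index = out_index % W; // [1, W] input: only the column matters
    const int b_index = out_index / W; // [H, 1] input: only the row matters
    // Same selection semantics as where.self_out.
    out[out_index] = cond[out_index] ? a[a_index] : b[b_index];
  }

  for (int i = 0; i < H * W; ++i) {
    std::printf("%g%c", out[i], (i + 1) % W == 0 ? '\n' : ' ');
  }
  // Prints:
  // 1 10 3
  // 20 2 20
  return 0;
}

The patch also keeps a separate non-broadcasted path in each helper, so the common same-shape case remains a plain linear loop over out.numel().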