From b5299c65323be8561dc0254f77fe017206fee36e Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Fri, 28 Feb 2025 17:11:04 -0800
Subject: [PATCH] Update [ghstack-poisoned]

---
 kernels/portable/cpu/util/broadcast_util.h   | 57 +++++++-------
 kernels/portable/cpu/util/elementwise_util.h | 79 ++++++++++++--------
 2 files changed, 79 insertions(+), 57 deletions(-)

diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h
index 10bd07baee2..47ac9da4af2 100644
--- a/kernels/portable/cpu/util/broadcast_util.h
+++ b/kernels/portable/cpu/util/broadcast_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include
+#include
 #include
 #include
 
@@ -290,23 +291,27 @@ inline void apply_binary_elementwise_fn(
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
+  if (any_is_broadcasted) {
+    size_t i = 0;
+    for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
       if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
+        a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
       }
       if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
+        b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
       }
+
+      data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
     }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
 
-    data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+      data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+    }
   }
 }
 
@@ -338,28 +343,28 @@ inline void apply_ternary_elementwise_fn(
   const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
+  if (any_is_broadcasted) {
+    size_t i = 0;
+    for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
       if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
+        a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
       }
       if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
+        b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
       }
       if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
+        c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c);
       }
-    }
 
-    data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
+      data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
+    }
   }
 }
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 778006f1b99..ee19a3640fb 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -10,6 +10,7 @@
 
 #include
 #include
+#include
 #include
 #include
 
@@ -121,26 +122,33 @@ inline void apply_bitensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
+  if (any_is_broadcasted) {
+    size_t i = 0;
+    for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
       if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
+        a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
       }
       if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
+        b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
       }
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
+      i++;
     }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
 
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
   }
 }
 
@@ -211,31 +219,40 @@ inline void apply_tritensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
+  if (any_is_broadcasted) {
+    size_t i = 0;
+    for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
       if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
+        a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
       }
       if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
+        b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
       }
       if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
+        c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c);
       }
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
+      i++;
     }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
 
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
   }
 }
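
Note on the change: the broadcasted paths previously rebuilt the multi-dimensional output index from the flat offset on every iteration via delinearize_index (roughly a divide and modulo per dimension per element); the patch instead walks a DelinearizedIndexesRange once and reuses the existing fast path when nothing is broadcasted. The DelinearizedIndexesRange header itself is not part of this patch, so the sketch below is only an illustration of what such an iterator could do; the names, the dimension limit, and the member layout are assumptions, not the real implementation.

    #include <array>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical sketch only -- not the real DelinearizedIndexesRange header.
    // Assumed: the dimension limit mirrors the runtime's kTensorDimensionLimit.
    constexpr size_t kSketchDimensionLimit = 16;

    struct DelinearizedIndexesIteratorSketch {
      std::array<size_t, kSketchDimensionLimit> indexes{};  // current multi-dim index
      const int32_t* sizes = nullptr;  // shape of the output tensor (assumed int32 sizes)
      size_t dim = 0;                  // number of valid dimensions

      // Dereferencing yields the current index vector, which the caller can feed
      // to linearize_access_indexes() exactly as the patched loops do.
      const std::array<size_t, kSketchDimensionLimit>& operator*() const {
        return indexes;
      }

      // Advance odometer-style: bump the innermost dimension and carry outward.
      // This costs a few increments/compares per element instead of recomputing
      // the whole index vector from a flat offset with divides/modulos, which is
      // what delinearize_index() did on every iteration of the old loops.
      DelinearizedIndexesIteratorSketch& operator++() {
        for (size_t d = dim; d-- > 0;) {
          if (++indexes[d] < static_cast<size_t>(sizes[d])) {
            break;
          }
          indexes[d] = 0;  // this dimension wrapped; carry into the next-outer one
        }
        return *this;
      }
    };

A range wrapping such an iterator would pair it with a separate linear counter i, as the patched loops do, so the output write stays a simple data_out[i++] while only the broadcasted inputs pay for index linearization.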