Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions kernels/portable/cpu/util/broadcast_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <c10/util/irange.h>
#include <executorch/kernels/portable/cpu/util/delinearized_indexes_range.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

Expand Down Expand Up @@ -290,23 +291,27 @@ inline void apply_binary_elementwise_fn(
const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

for (const auto i : c10::irange(out.numel())) {
size_t a_linear_index = i;
size_t b_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (any_is_broadcasted) {
size_t i = 0;
for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
}

data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
size_t a_linear_index = i;
size_t b_linear_index = i;

data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
}
}
}

Expand Down Expand Up @@ -338,28 +343,28 @@ inline void apply_ternary_elementwise_fn(
const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

for (const auto i : c10::irange(out.numel())) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (any_is_broadcasted) {
size_t i = 0;
for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;
if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
}
if (c_is_broadcasted) {
c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c);
}
}

data_out[i] = compute_fun(
data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
data_out[i++] = compute_fun(data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
}
}
}

Expand Down
79 changes: 48 additions & 31 deletions kernels/portable/cpu/util/elementwise_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <c10/util/irange.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/delinearized_indexes_range.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

Expand Down Expand Up @@ -121,26 +122,33 @@ inline void apply_bitensor_elementwise_fn(
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (any_is_broadcasted) {
size_t i = 0;
for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
}
auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
i++;
}
} else {
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}
}

Expand Down Expand Up @@ -211,31 +219,40 @@ inline void apply_tritensor_elementwise_fn(
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (any_is_broadcasted) {
size_t i = 0;
for (const auto& delinearized_indexes : DelinearizedIndexesRange(out)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;
if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
a_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
b_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), b);
}
if (c_is_broadcasted) {
c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
c_linear_index = linearize_access_indexes(delinearized_indexes, out.dim(), c);
}
auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]),
load_c_to_common(&data_c[c_linear_index * c_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
i++;
}
} else {
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]),
load_c_to_common(&data_c[c_linear_index * c_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]),
load_c_to_common(&data_c[c_linear_index * c_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}
}

Expand Down
Loading