174 changes: 171 additions & 3 deletions kernels/portable/cpu/op_index.cpp
@@ -22,21 +22,189 @@ namespace native {
using Tensor = executorch::aten::Tensor;
using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;

namespace {

bool check_fast_path_conditions(
    ET_UNUSED const Tensor& in,
    TensorOptList indices,
    size_t* dim) {
  bool found_index = false;
  for (const auto i : c10::irange(indices.size())) {
    if (indices[i].has_value()) {
      *dim = i;
      // Fast path only supports a single non-null index tensor
      if (found_index) {
        return false;
      }
      found_index = true;
      const Tensor& index = indices[i].value();
      ScalarType ix_type = index.scalar_type();
      // Fast path only supports Long or Int index tensors
      if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
        return false;
      }
      // Fast path only supports a 1-dimensional index tensor
      if (index.dim() != 1) {
        return false;
      }
    }
  }

  // Fast path needs at least one non-null index tensor
  if (!found_index) {
    return false;
  }

  return true;
}
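// Illustrative note (editorial comment, not part of the original diff):
// with in.sizes() == {2, 4, 3} and indices == {nullopt, idx, nullopt},
// where idx is a 1-D Long tensor, the conditions above hold and *dim is
// set to 1. Two non-null index tensors, a Float index tensor, or a 2-D
// index tensor would all fall through to the general path.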

bool check_fast_path_args(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
    size_t dim,
    Tensor& out) {
  (void)ctx;
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  ET_CHECK_OR_RETURN_FALSE(
      static_cast<ssize_t>(indices.size()) <= in.dim(),
      "Indexing too many dimensions");

  const Tensor& index = indices[dim].value();

  bool is_valid_index = true;
  ET_SWITCH_TWO_TYPES(
      Long, Int, index.scalar_type(), ctx, "index.Tensor", CTYPE, [&]() {
        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
        for (const auto i : c10::irange(index.numel())) {
          if (index_arr[i] < 0 ||
              index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
            ET_LOG(
                Error,
                "Index %" PRId64
                " out of range for tensor with size %zd"
                " at dimension %zu",
                static_cast<int64_t>(index_arr[i]),
                in.size(dim),
                dim);
            is_valid_index = false;
            break;
          }
        }
      });

  ET_CHECK_OR_RETURN_FALSE(
      is_valid_index,
      "Some index values are not within bounds of input tensor at indexed dim");

  return true;
}
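// Note on the dispatch above (editorial comment): ET_SWITCH_TWO_TYPES
// selects a branch at runtime from index.scalar_type() and binds CTYPE
// to the matching C type (int64_t for Long, int32_t for Int), so the
// bounds check reads the raw index buffer at its true element width.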

void get_fast_path_index_out_target_size(
    const Tensor& in,
    TensorOptList indices,
    size_t dim,
    Tensor::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = in.dim();

  for (const auto d : c10::irange(static_cast<size_t>(in.dim()))) {
    if (d != dim) {
      out_sizes[d] = static_cast<Tensor::SizesType>(in.size(d));
    } else {
      out_sizes[d] =
          static_cast<Tensor::SizesType>(indices[dim].value().numel());
    }
  }
}
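// Shape example (editorial comment): for in.sizes() == {2, 4, 3},
// dim == 1, and an index tensor with numel() == 2, the expected output
// size computed here is {2, 2, 3}.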

Tensor& fast_path(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
    size_t dim,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      check_fast_path_args(ctx, in, indices, dim, out),
      InvalidArgument,
      out);

  const Tensor& index = indices[dim].value();
  ScalarType index_type = index.scalar_type();

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType expected_size[kTensorDimensionLimit];
  size_t expected_ndim = 0;
  get_fast_path_index_out_target_size(
      in, indices, dim, expected_size, &expected_ndim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  if (out.dim() == 0) {
    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
    return out;
  }

  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);

  if (leading_dims == 0 || trailing_dims == 0) {
    return out;
  }

  size_t in_dim_length = in.size(dim);
  size_t out_dim_length = out.size(dim);

  size_t length_per_step = trailing_dims * in.element_size();

  const char* in_data = in.const_data_ptr<char>();
  char* out_data = out.mutable_data_ptr<char>();

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "index.Tensor_out";

  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
    for (const auto i : c10::irange(leading_dims)) {
      const char* src = in_data + i * in_dim_length * length_per_step;
      char* dest = out_data + i * out_dim_length * length_per_step;
      for (const auto j : c10::irange(out_dim_length)) {
        const char* copy_src = src + index_arr[j] * length_per_step;
        char* copy_dest = dest + j * length_per_step;
        memcpy(copy_dest, copy_src, length_per_step);
      }
    }
  });

  return out;
}
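// Worked example for the copy loop above (editorial comment): with
// in.sizes() == {2, 4, 3}, dim == 1, and idx == {3, 1}, we get
// leading_dims == 2, trailing_dims == 3, out_dim_length == 2, and
// length_per_step == 3 * element_size. For each leading slice i, the
// inner loop memcpys rows 3 and 1 of in[i] into out[i][0] and out[i][1].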

} // namespace

Tensor& index_Tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  size_t dim = 0;
  bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
  if (is_fast_path) {
    return fast_path(ctx, in, indices, dim, out);
  }

  ET_KERNEL_CHECK(
      ctx, check_index_args(in, indices, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  size_t block_count = count_index_blocks(indices);

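To make the fast-path copy strategy concrete, here is a minimal standalone sketch, assuming only plain C++17 and none of the ExecuTorch types above; the shapes, index values, and buffer contents are invented for illustration. It reproduces the same [leading, dim, trailing] gather, copying one contiguous run of trailing elements per index entry:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  // Input viewed as shape {2, 4, 3}: leading = 2, in_dim = 4, trailing = 3.
  std::vector<float> in(2 * 4 * 3);
  for (size_t i = 0; i < in.size(); ++i) {
    in[i] = static_cast<float>(i);
  }

  const int64_t idx[] = {3, 1}; // stand-in for a 1-D Long index tensor at dim 1
  const size_t leading = 2, in_dim = 4, out_dim = 2, trailing = 3;
  const size_t step = trailing * sizeof(float); // bytes per row along dim 1

  std::vector<float> out(leading * out_dim * trailing);
  const char* in_data = reinterpret_cast<const char*>(in.data());
  char* out_data = reinterpret_cast<char*>(out.data());

  for (size_t i = 0; i < leading; ++i) {
    const char* src = in_data + i * in_dim * step;
    char* dest = out_data + i * out_dim * step;
    for (size_t j = 0; j < out_dim; ++j) {
      // One memcpy per index entry, exactly as in fast_path's inner loop.
      std::memcpy(dest + j * step, src + idx[j] * step, step);
    }
  }

  for (float v : out) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}

Running it prints 9 10 11 3 4 5 21 22 23 15 16 17, i.e. rows 3 and 1 of each leading slice, which is what fast_path produces for a contiguous input.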