Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/src/arrow/compute/function_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
/// Encode a DataType-valued option as a Scalar.
///
/// A null `value` is encoded as a NullScalar instead of being rejected, which
/// matches the DataType overload of GenericFromScalar: it decodes a scalar of
/// null type (Type::NA) back into an empty type pointer. A non-null `value` is
/// encoded as a null scalar of that type, so only the type carries information.
static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
    const std::shared_ptr<DataType>& value) {
  if (!value) {
    // "No type given" is a legitimate state for optional type fields.
    return std::make_shared<NullScalar>();
  }
  return MakeNullScalar(value);
}
Expand Down Expand Up @@ -448,6 +448,9 @@ static inline enable_if_same_result<T, SortKey> GenericFromScalar(
/// Recover a DataType-valued option from its Scalar encoding.
///
/// Only the scalar's type is inspected: a scalar whose type id is Type::NA is
/// decoded as an empty (null) type pointer, any other scalar yields its own type.
template <typename T>
static inline enable_if_same_result<T, std::shared_ptr<DataType>> GenericFromScalar(
    const std::shared_ptr<Scalar>& value) {
  const bool is_null_type = value->type->id() == Type::NA;
  if (!is_null_type) {
    return value->type;
  }
  return std::shared_ptr<NullType>();
}

Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/vector_swizzle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace {

// Documentation for the "inverse_permutation" function: summary, description,
// argument names, and the name of the options class the function accepts.
const FunctionDoc inverse_permutation_doc(
    "Return the inverse permutation of the given indices",
    "For the `i`-th `index` in `indices`, the `index`-th output is `i`", {"indices"},
    "InversePermutationOptions");

const InversePermutationOptions* GetDefaultInversePermutationOptions() {
static const auto kDefaultInversePermutationOptions =
Expand Down Expand Up @@ -332,7 +333,7 @@ void RegisterVectorInversePermutation(FunctionRegistry* registry) {
// Documentation for the "scatter" function: summary, description, argument
// names, and the name of the options class the function accepts.
const FunctionDoc scatter_doc(
    "Scatter the values into specified positions according to the indices",
    "Place the `i`-th value at the position specified by the `i`-th index",
    {"values", "indices"}, "ScatterOptions");

const ScatterOptions* GetDefaultScatterOptions() {
static const auto kDefaultScatterOptions = ScatterOptions::Defaults();
Expand Down
4 changes: 3 additions & 1 deletion docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ Selections
drop_null
filter
inverse_permutation
scatter
take

Sorts and Partitions
--------------------
Expand Down Expand Up @@ -606,6 +606,7 @@ Compute Options
ExtractRegexSpanOptions
FilterOptions
IndexOptions
InversePermutationOptions
JoinOptions
ListFlattenOptions
ListSliceOptions
Expand Down Expand Up @@ -635,6 +636,7 @@ Compute Options
SkewOptions
SliceOptions
SortOptions
ScatterOptions
SplitOptions
SplitPatternOptions
StrftimeOptions
Expand Down
56 changes: 56 additions & 0 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,62 @@ class RunEndEncodeOptions(_RunEndEncodeOptions):
self._set_options(run_end_type)


cdef class _InversePermutationOptions(FunctionOptions):
    def _set_options(self, max_index, output_type):
        # With an explicit output type, normalize it to a DataType and use the
        # two-argument C++ constructor; otherwise use the max_index-only one.
        if output_type is not None:
            resolved_type = ensure_type(output_type)
            self.wrapped.reset(
                new CInversePermutationOptions(
                    max_index, pyarrow_unwrap_data_type(resolved_type)))
        else:
            self.wrapped.reset(new CInversePermutationOptions(max_index))


class InversePermutationOptions(_InversePermutationOptions):
    """
    Options for the `inverse_permutation` function.

    Parameters
    ----------
    max_index : int64, default -1
        The largest value allowed in the input indices.
        The output has length ``max_index + 1``.
        A negative value means "length of the input indices minus 1", so the
        output has the same length as the input indices.
    output_type : DataType, default None
        The type of the resulting inverse permutation.
        If None, the output has the same type as the input indices. Otherwise
        it must be a signed integer type; an invalid error is reported if the
        type cannot represent the length of the input indices.
    """

    def __init__(self, max_index=-1, output_type=None):
        self._set_options(max_index, output_type)


cdef class _ScatterOptions(FunctionOptions):
    def _set_options(self, max_index):
        # Replace the wrapped C++ options object with one built from max_index.
        self.wrapped.reset(
            new CScatterOptions(max_index))


class ScatterOptions(_ScatterOptions):
    """
    Options for the `scatter` function.

    Parameters
    ----------
    max_index : int64, default -1
        The largest value allowed in the input indices.
        The output has length ``max_index + 1``.
        A negative value means "length of the input indices minus 1", so the
        output has the same length as the input indices.
    """

    def __init__(self, max_index=-1):
        self._set_options(max_index)


cdef class _TakeOptions(FunctionOptions):
def _set_options(self, boundscheck):
self.wrapped.reset(new CTakeOptions(boundscheck))
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
ExtractRegexSpanOptions,
FilterOptions,
IndexOptions,
InversePermutationOptions,
JoinOptions,
ListSliceOptions,
ListFlattenOptions,
Expand All @@ -66,6 +67,7 @@
RoundTemporalOptions,
RoundToMultipleOptions,
ScalarAggregateOptions,
ScatterOptions,
SelectKOptions,
SetLookupOptions,
SkewOptions,
Expand Down
12 changes: 12 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,18 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
CTakeOptions(c_bool boundscheck)
c_bool boundscheck

# Cython declaration of the C++ class arrow::compute::InversePermutationOptions.
cdef cppclass CInversePermutationOptions \
"arrow::compute::InversePermutationOptions"(CFunctionOptions):
# Two constructors are declared: max_index alone, or max_index plus an output type.
CInversePermutationOptions(int64_t max_index)
CInversePermutationOptions(int64_t max_index, shared_ptr[CDataType] output_type)
# Data members made accessible from Cython.
int64_t max_index
shared_ptr[CDataType] output_type

# Cython declaration of the C++ class arrow::compute::ScatterOptions.
cdef cppclass CScatterOptions \
"arrow::compute::ScatterOptions"(CFunctionOptions):
# Single declared constructor, taking the maximum allowed index.
CScatterOptions(int64_t max_index)
# Data member made accessible from Cython.
int64_t max_index

cdef cppclass CStrptimeOptions \
"arrow::compute::StrptimeOptions"(CFunctionOptions):
CStrptimeOptions(c_string format, TimeUnit unit, c_bool raise_error)
Expand Down
34 changes: 33 additions & 1 deletion python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

import pyarrow as pa
import pyarrow.compute as pc
from pyarrow.lib import ArrowNotImplementedError
from pyarrow.lib import ArrowNotImplementedError, ArrowIndexError

try:
import pyarrow.substrait as pas
Expand Down Expand Up @@ -1590,6 +1590,38 @@ def test_filter_null_type():
assert len(table.filter(mask).column(0)) == 5


def test_inverse_permutation():
    """inverse_permutation on chunked input, via defaults, options and kwargs."""
    empty_chunk = pa.array([], type=pa.int32())
    indices = pa.chunked_array([
        empty_chunk, [9, 7, 5, 3, 1], [0], [2, 4, 6], [8], empty_chunk,
    ])
    want = pa.chunked_array([[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]], type=pa.int32())

    # Default options.
    got = pc.inverse_permutation(indices)
    assert got.equals(want)

    # Explicit options object, then the keyword-argument form.
    opts = pc.InversePermutationOptions(max_index=9, output_type=pa.int32())
    got = pc.inverse_permutation(indices, options=opts)
    assert got.equals(want)
    got = pc.inverse_permutation(indices, max_index=-1)
    assert got.equals(want)

    # An index larger than max_index must be rejected.
    with pytest.raises(ArrowIndexError, match="Index out of bounds: 9"):
        pc.inverse_permutation(indices, max_index=4)


def test_scatter():
    """scatter with reversing indices, via defaults, options and kwargs."""
    src = pa.array([True, False, True, True, False, False, True, True, True, False])
    positions = pa.array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
    # The indices reverse the order, so the output is the input reversed.
    want = pa.array([False, True, True, True, False,
                     False, True, True, False, True])

    # Default options.
    got = pc.scatter(src, positions)
    assert got.equals(want)

    # Explicit options object, then the keyword-argument form.
    opts = pc.ScatterOptions(max_index=-1)
    assert pc.scatter(src, positions, options=opts).equals(want)
    assert pc.scatter(src, positions, max_index=9).equals(want)

    # An index larger than max_index must be rejected.
    with pytest.raises(ArrowIndexError, match="Index out of bounds: 9"):
        pc.scatter(src, positions, max_index=4)


@pytest.mark.parametrize("typ", ["array", "chunked_array"])
def test_compare_array(typ):
if typ == "array":
Expand Down
Loading