Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/pyarrow/_compute.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ cdef class ScalarUdfContext(_Weakrefable):

cdef void init(self, const CScalarUdfContext& c_context)

# Extension type exposing a C++ compute function registry to Python.
cdef class FunctionRegistry(_Weakrefable):
cdef:
# Raw pointer to the underlying C++ CFunctionRegistry.
CFunctionRegistry* registry

cdef class FunctionOptions(_Weakrefable):
cdef:
shared_ptr[CFunctionOptions] wrapped
Expand Down
31 changes: 27 additions & 4 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,20 @@ cdef _pack_compute_args(object values, vector[CDatum]* out):


cdef class FunctionRegistry(_Weakrefable):
cdef CFunctionRegistry* registry

def __init__(self):
self.registry = GetFunctionRegistry()
def __init__(self, parent=None):
    """
    Create a function registry.

    Parameters
    ----------
    parent : FunctionRegistry, optional
        When omitted, this object wraps the global function registry.
        Otherwise a new registry nested under `parent` is created.

    Raises
    ------
    TypeError
        If `parent` is neither None nor a FunctionRegistry.
    """
    cdef FunctionRegistry parent_rg
    if parent is None:
        self.registry = GetFunctionRegistry()
    else:
        # Checked cast: an unchecked <FunctionRegistry> cast on an
        # arbitrary Python object would read a bogus C pointer.
        parent_rg = <FunctionRegistry?>parent
        # NOTE(review): only the raw pointer of the parent's C++ registry
        # is handed to Make(); presumably the caller has to keep `parent`
        # alive for as long as this registry is used -- confirm.
        self.registry = CFunctionRegistry.Make(
            parent_rg.registry).release()

def __dealloc__(self):
    # Free the C++ registry here instead of dispatching to a Python-level
    # method: the object may already be partially torn down when
    # __dealloc__ runs, so only C-level state is touched.
    #
    # The pointer returned by GetFunctionRegistry() is the shared global
    # registry and is not owned by this object, so it is never deleted;
    # only a registry obtained through Make(...).release() is destroyed.
    if self.registry != NULL and self.registry != GetFunctionRegistry():
        del self.registry
    self.registry = NULL

def list_functions(self):
"""
Expand All @@ -505,6 +515,8 @@ cdef class FunctionRegistry(_Weakrefable):
func = GetResultValue(self.registry.GetFunction(c_name))
return wrap_function(func)

def clear(self):
    """
    Destroy the underlying C++ registry if this object owns it.

    A wrapper around the global function registry is left untouched,
    because that registry is shared and not owned by this object.
    Calling this method more than once is harmless.
    """
    if self.registry == NULL:
        # Already cleared; deleting again would be a double free.
        return
    if self.registry == GetFunctionRegistry():
        # Shared global registry: not ours to delete.
        return
    del self.registry
    # Reset the pointer so a later clear() or deallocation does not
    # delete the same registry a second time.
    self.registry = NULL

cdef FunctionRegistry _global_func_registry = FunctionRegistry()

Expand Down Expand Up @@ -2515,7 +2527,7 @@ def _get_scalar_udf_context(memory_pool, batch_length):


def register_scalar_function(func, function_name, function_doc, in_types,
out_type):
out_type, registry=function_registry()):
"""
Register a user-defined scalar function.

Expand Down Expand Up @@ -2556,6 +2568,12 @@ def register_scalar_function(func, function_name, function_doc, in_types,
arity.
out_type : DataType
Output type of the function.
registry : FunctionRegistry
FunctionRegistry with which the function will be registered.
The default value is set to the global function registry.
This is an optional feature to allow grouping functions into
different registries to enable removing functions if they
are not intended to be used further.

Examples
--------
Expand Down Expand Up @@ -2593,6 +2611,8 @@ def register_scalar_function(func, function_name, function_doc, in_types,
PyObject* c_function
shared_ptr[CDataType] c_out_type
CScalarUdfOptions c_options
CFunctionRegistry* c_registry
FunctionRegistry func_registry

if callable(func):
c_function = <PyObject*>func
Expand Down Expand Up @@ -2628,11 +2648,14 @@ def register_scalar_function(func, function_name, function_doc, in_types,

c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type))

func_registry = <FunctionRegistry>(registry)

c_options.func_name = c_func_name
c_options.arity = c_arity
c_options.func_doc = c_func_doc
c_options.input_types = c_in_types
c_options.output_type = c_out_type
c_options.registry = func_registry.registry

check_status(RegisterScalarFunction(c_function,
<function[CallbackUdf]> &_scalar_udf_callback, c_options))
6 changes: 6 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
CExecContext()
CExecContext(CMemoryPool* pool)
CExecContext(CMemoryPool* pool, CExecutor* exc)
CExecContext(CMemoryPool* pool, CExecutor* exc, CFunctionRegistry* rgr)

CMemoryPool* memory_pool() const
CExecutor* executor()
Expand Down Expand Up @@ -1956,6 +1957,9 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
vector[c_string] GetFunctionNames() const
int num_functions() const

@staticmethod
unique_ptr[CFunctionRegistry] Make(CFunctionRegistry* parent)

CFunctionRegistry* GetFunctionRegistry()

cdef cppclass CElementWiseAggregateOptions \
Expand Down Expand Up @@ -2762,13 +2766,15 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py":
cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext":
CMemoryPool *pool
int64_t batch_length
CFunctionRegistry *registry

cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions":
c_string func_name
CArity arity
CFunctionDoc func_doc
vector[shared_ptr[CDataType]] input_types
shared_ptr[CDataType] output_type
CFunctionRegistry* registry

CStatus RegisterScalarFunction(PyObject* function,
function[CallbackUdf] wrapper, const CScalarUdfOptions& options)
3 changes: 1 addition & 2 deletions python/pyarrow/src/arrow/python/udf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback
kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
auto registry = compute::GetFunctionRegistry();
RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
RETURN_NOT_OK(options.registry->AddFunction(std::move(scalar_func)));
return Status::OK();
}

Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/src/arrow/python/udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct ARROW_PYTHON_EXPORT ScalarUdfOptions {
compute::FunctionDoc func_doc;
std::vector<std::shared_ptr<DataType>> input_types;
std::shared_ptr<DataType> output_type;
compute::FunctionRegistry* registry;
};

/// \brief A context passed as the first argument of scalar UDF functions.
Expand Down
108 changes: 108 additions & 0 deletions python/pyarrow/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


import pytest
import multiprocessing as mp

import pyarrow as pa
from pyarrow import compute as pc
Expand Down Expand Up @@ -504,3 +505,110 @@ def test_input_lifetime(unary_func_fixture):
# Calling a UDF should not have kept `v` alive longer than required
v = None
assert proxy_pool.bytes_allocated() == 0


def test_nested_function_registry():
    """Register the same UDF name at each level of a registry chain."""
    udf_name = "f1"
    udf_doc = {"summary": "add function",
               "description": "test add function"}

    def add_one(ctx, x):
        return pc.call_function("add", [x, 1],
                                memory_pool=ctx.memory_pool)

    def register_in(target):
        # Same UDF, same signature; only the destination registry varies.
        pc.register_scalar_function(add_one,
                                    udf_name,
                                    udf_doc,
                                    {"array": pa.int64()},
                                    pa.int64(),
                                    target)

    global_registry = pc.function_registry()
    child = pc.FunctionRegistry(global_registry)
    grandchild = pc.FunctionRegistry(child)

    # A function registered in the innermost registry is found there ...
    register_in(grandchild)
    assert grandchild.get_function(udf_name).name == udf_name

    # ... but is not visible from its parent.
    with pytest.raises(pa.lib.ArrowKeyError,
                       match="No function registered with name: f1"):
        child.get_function(udf_name)

    # The same name can then be registered at each outer level in turn.
    for target in (child, global_registry):
        register_in(target)
        assert target.get_function(udf_name).name == udf_name


def parallel_task1(data, q):
    """
    Register an "add 10" UDF named "f1" in a registry nested under the
    global one, call it with `data`, and put the result on queue `q`.
    """
    udf_name = "f1"

    def add_ten(ctx, x):
        return pc.call_function("add", [x, 10],
                                memory_pool=ctx.memory_pool)

    scoped_registry = pc.FunctionRegistry(pc.function_registry())
    pc.register_scalar_function(
        add_ten,
        udf_name,
        {"summary": "add function",
         "description": "test add function"},
        {"array": pa.int64()},
        pa.int64(),
        scoped_registry)

    # Look the function up through the nested registry it was added to.
    q.put(scoped_registry.get_function(udf_name).call(data))


def parallel_task2(data, q):
    """
    Register a "multiply by 10" UDF named "f1" in a registry nested under
    the global one, call it with `data`, and put the result on queue `q`.
    """
    udf_name = "f1"

    def times_ten(ctx, x):
        return pc.call_function("multiply", [x, 10],
                                memory_pool=ctx.memory_pool)

    scoped_registry = pc.FunctionRegistry(pc.function_registry())
    pc.register_scalar_function(
        times_ten,
        udf_name,
        {"summary": "multiply function",
         "description": "test multiply function"},
        {"array": pa.int64()},
        pa.int64(),
        scoped_registry)

    # Look the function up through the nested registry it was added to.
    udf = scoped_registry.get_function(udf_name)
    outcome = udf.call(data)
    q.put(outcome)


def test_udf_usage_by_scope():
    # With support for registering custom functions on both the global
    # and nested function registries, a user can create a scope for a
    # registry, get some tasks done in a particular process and drop the
    # registry once that process terminates. Here two processes each
    # register a different UDF under the same name "f1" in their own
    # nested registry; the output of the first feeds the second.
    timeout = 60  # seconds; presumably ample for a spawned interpreter

    ctx = mp.get_context('spawn')
    q = ctx.Queue()

    def run_task(task, data):
        proc = ctx.Process(target=task, args=(data, q))
        proc.start()
        try:
            # A bounded wait turns a crashed or stuck child into a test
            # failure (queue.Empty) instead of hanging the whole run.
            return q.get(timeout=timeout)
        finally:
            proc.join(timeout)
            if proc.is_alive():
                # Do not leave a stuck child process behind.
                proc.terminate()
                proc.join()

    summed = run_task(parallel_task1, [pa.array([10, 20, 30], pa.int64())])
    final_result = run_task(parallel_task2, [summed])
    assert final_result == pa.array([200, 300, 400], pa.int64())