diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd
index 8b09cbd445e..545c730831a 100644
--- a/python/pyarrow/_compute.pxd
+++ b/python/pyarrow/_compute.pxd
@@ -27,6 +27,18 @@ cdef class ScalarUdfContext(_Weakrefable):
 
     cdef void init(self, const CScalarUdfContext& c_context)
 
+cdef class FunctionRegistry(_Weakrefable):
+    cdef:
+        # Borrowed pointer used for all lookups and registrations. Points
+        # at the global registry or at the one held by owned_registry.
+        CFunctionRegistry* registry
+        # Owns the C++ registry of a nested FunctionRegistry. Stays empty
+        # when wrapping the global registry, which must never be deleted.
+        unique_ptr[CFunctionRegistry] owned_registry
+        # A nested C++ registry is created from a raw pointer to its
+        # parent, so keep the parent wrapper alive as long as this one.
+        FunctionRegistry parent
+
 cdef class FunctionOptions(_Weakrefable):
     cdef:
         shared_ptr[CFunctionOptions] wrapped
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 2aa65e75c50..0d17c1d1226 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -477,10 +477,32 @@ cdef _pack_compute_args(object values, vector[CDatum]* out):
 
 
 cdef class FunctionRegistry(_Weakrefable):
-    cdef CFunctionRegistry* registry
+    """
+    A registry of compute functions.
+
+    Parameters
+    ----------
+    parent : FunctionRegistry, default None
+        If None, wrap the global function registry. Otherwise create a
+        new registry nested under ``parent``.
+    """
 
-    def __init__(self):
-        self.registry = GetFunctionRegistry()
+    def __init__(self, parent=None):
+        cdef FunctionRegistry parent_rg
+
+        if parent is None:
+            # The global registry is not owned by this object: borrow it
+            # and leave owned_registry empty so it is never deleted.
+            self.registry = GetFunctionRegistry()
+        else:
+            # Typed assignment: raises TypeError unless `parent` is a
+            # FunctionRegistry.
+            parent_rg = parent
+            # Keep the parent alive for as long as the nested registry.
+            self.parent = parent_rg
+            self.owned_registry.reset(
+                CFunctionRegistry.Make(parent_rg.registry).release())
+            self.registry = self.owned_registry.get()
 
     def list_functions(self):
         """
@@ -505,6 +527,26 @@ cdef class FunctionRegistry(_Weakrefable):
             func = GetResultValue(self.registry.GetFunction(c_name))
         return wrap_function(func)
 
+    def clear(self):
+        """
+        Drop every function registered directly in this nested registry.
+
+        The underlying C++ registry is replaced by an empty one nested
+        under the same parent, so this object remains usable. Registries
+        nested under this one must not be used afterwards.
+
+        Raises
+        ------
+        ValueError
+            If this object wraps the global registry, which is not owned
+            by Python and cannot be cleared.
+        """
+        if self.parent is None:
+            raise ValueError("Cannot clear the global function registry")
+        self.owned_registry.reset(
+            CFunctionRegistry.Make(self.parent.registry).release())
+        self.registry = self.owned_registry.get()
+
 
 cdef FunctionRegistry _global_func_registry = FunctionRegistry()
 
@@ -2515,7 +2557,7 @@ def _get_scalar_udf_context(memory_pool, batch_length):
 
 
 def register_scalar_function(func, function_name, function_doc, in_types,
-                             out_type):
+                             out_type, registry=None):
     """
     Register a user-defined scalar function.
 
@@ -2556,6 +2598,12 @@ def register_scalar_function(func, function_name, function_doc, in_types,
         arity.
     out_type : DataType
         Output type of the function.
+    registry : FunctionRegistry, default None
+        FunctionRegistry with which the function will be registered.
+        If None, the global function registry is used.
+        Passing a nested registry allows grouping functions into
+        different registries so that they can be dropped together
+        once they are no longer needed.
 
     Examples
     --------
@@ -2593,6 +2641,7 @@ def register_scalar_function(func, function_name, function_doc, in_types,
         PyObject* c_function
         shared_ptr[CDataType] c_out_type
         CScalarUdfOptions c_options
+        FunctionRegistry func_registry
 
     if callable(func):
         c_function = <PyObject*>func
@@ -2628,11 +2677,17 @@ def register_scalar_function(func, function_name, function_doc, in_types,
 
     c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type))
 
+    if registry is None:
+        registry = function_registry()
+    # Typed assignment: raises TypeError unless a FunctionRegistry is given.
+    func_registry = registry
+
     c_options.func_name = c_func_name
     c_options.arity = c_arity
     c_options.func_doc = c_func_doc
     c_options.input_types = c_in_types
     c_options.output_type = c_out_type
+    c_options.registry = func_registry.registry
 
     check_status(RegisterScalarFunction(c_function,
                                         <function[CallbackUdf]> &_scalar_udf_callback, c_options))
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index e44fa2615e2..1d027a3451b 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1853,6 +1853,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
         CExecContext()
         CExecContext(CMemoryPool* pool)
         CExecContext(CMemoryPool* pool, CExecutor* exc)
+        CExecContext(CMemoryPool* pool, CExecutor* exc, CFunctionRegistry* rgr)
 
         CMemoryPool* memory_pool() const
         CExecutor* executor()
@@ -1956,6 +1957,9 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
         vector[c_string] GetFunctionNames() const
         int num_functions() const
 
+        @staticmethod
+        unique_ptr[CFunctionRegistry] Make(CFunctionRegistry* parent)
+
     CFunctionRegistry* GetFunctionRegistry()
 
     cdef cppclass CElementWiseAggregateOptions \
@@ -2769,6 +2773,7 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py":
         CFunctionDoc func_doc
         vector[shared_ptr[CDataType]] input_types
         shared_ptr[CDataType] output_type
+        CFunctionRegistry* registry
 
     CStatus RegisterScalarFunction(PyObject* function,
                                    function[CallbackUdf] wrapper, const CScalarUdfOptions& options)
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 81bf47c0ade..fe50133713c 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -115,7 +115,12 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback
   kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
   kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
   RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel)));
-  auto registry = compute::GetFunctionRegistry();
+  // Use the global registry when the caller did not select one, so callers
+  // that never set `options.registry` keep their previous behavior.
+  auto registry = options.registry;
+  if (registry == nullptr) {
+    registry = compute::GetFunctionRegistry();
+  }
   RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func)));
   return Status::OK();
 }
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index 9a3666459fd..e3784baf632 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -39,6 +39,8 @@ struct ARROW_PYTHON_EXPORT ScalarUdfOptions {
   compute::FunctionDoc func_doc;
   std::vector<std::shared_ptr<DataType>> input_types;
   std::shared_ptr<DataType> output_type;
+  /// Registry to add the function to; nullptr selects the global registry.
+  compute::FunctionRegistry* registry = nullptr;
 };
 
 /// \brief A context passed as the first argument of scalar UDF functions.
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index e711619582d..66aae8529eb 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -17,5 +17,7 @@
 
 
+import multiprocessing as mp
+
 import pytest
 
 import pyarrow as pa
@@ -504,3 +506,88 @@ def test_input_lifetime(unary_func_fixture):
     # Calling a UDF should not have kept `v` alive longer than required
     v = None
     assert proxy_pool.bytes_allocated() == 0
+
+
+def test_nested_function_registry():
+    def add_one(ctx, x):
+        return pc.call_function("add", [x, 1],
+                                memory_pool=ctx.memory_pool)
+
+    # This name also ends up in the global registry, which outlives the
+    # test, so keep it specific to this test.
+    func_name = "test_nested_function_registry_add_one"
+    unary_doc = {"summary": "add function",
+                 "description": "test add function"}
+    in_types = {"array": pa.int64()}
+    out_type = pa.int64()
+    error_msg = f"No function registered with name: {func_name}"
+
+    default_registry = pc.function_registry()
+    registry1 = pc.FunctionRegistry(default_registry)
+    registry2 = pc.FunctionRegistry(registry1)
+
+    # A function registered in a nested registry is not visible from
+    # the registries it is nested under.
+    pc.register_scalar_function(add_one, func_name, unary_doc, in_types,
+                                out_type, registry2)
+    assert registry2.get_function(func_name).name == func_name
+    for ancestor in (registry1, default_registry):
+        with pytest.raises(pa.lib.ArrowKeyError, match=error_msg):
+            ancestor.get_function(func_name)
+
+    # The same name can then be registered one level up at a time.
+    pc.register_scalar_function(add_one, func_name, unary_doc, in_types,
+                                out_type, registry1)
+    assert registry1.get_function(func_name).name == func_name
+
+    pc.register_scalar_function(add_one, func_name, unary_doc, in_types,
+                                out_type, default_registry)
+    assert default_registry.get_function(func_name).name == func_name
+
+
+def _run_scoped_udf(kernel_name, data, q):
+    # Register a UDF that applies `kernel_name` with the constant 10 in a
+    # registry nested under the global one, call it on `data` and send
+    # the result back through `q`.
+    def f1(ctx, x):
+        return pc.call_function(kernel_name, [x, 10],
+                                memory_pool=ctx.memory_pool)
+
+    func_name = "f1"
+    unary_doc = {"summary": f"{kernel_name} function",
+                 "description": f"test {kernel_name} function"}
+
+    registry = pc.FunctionRegistry(pc.function_registry())
+    pc.register_scalar_function(f1, func_name, unary_doc,
+                                {"array": pa.int64()}, pa.int64(),
+                                registry)
+    q.put(registry.get_function(func_name).call(data))
+
+
+def parallel_task1(data, q):
+    _run_scoped_udf("add", data, q)
+
+
+def parallel_task2(data, q):
+    _run_scoped_udf("multiply", data, q)
+
+
+def test_udf_usage_by_scope():
+    # With nested function registries, a registry can be scoped to one
+    # task: each child process below registers "f1" in its own nested
+    # registry, which goes away when that process terminates.
+    timeout = 60  # seconds; fail instead of hanging if a child dies
+    ctx = mp.get_context('spawn')
+    q = ctx.Queue()
+    data = [pa.array([10, 20, 30], pa.int64())]
+    for task in (parallel_task1, parallel_task2):
+        p = ctx.Process(target=task, args=(data, q))
+        p.start()
+        try:
+            data = [q.get(timeout=timeout)]
+        finally:
+            p.join(timeout=timeout)
+            if p.is_alive():
+                p.terminate()
+        assert p.exitcode == 0
+    assert data[0] == pa.array([200, 300, 400], pa.int64())