Skip to content

Commit ff80b30

Browse files
paleolimbot authored and kou committed
ARROW-17460: [R] Don't warn if the new UDF I'm registering is the same as the existing one (#14436)
This PR makes it so that you can do the following without a warning:

``` r
library(arrow, warn.conflicts = FALSE)

register_scalar_function(
  "times_32",
  function(context, x) x * 32L,
  in_type = list(int32(), int64(), float64()),
  out_type = function(in_types) in_types[[1]],
  auto_convert = TRUE
)

register_scalar_function(
  "times_32",
  function(context, x) x * 32L,
  in_type = list(int32(), int64(), float64()),
  out_type = function(in_types) in_types[[1]],
  auto_convert = TRUE
)
```

In fixing that, I also ran across an important discovery, which is that `cpp11::function` does *not* protect the underlying `SEXP` from garbage collection!!!! The two functions we used this for were being protected by proxy because the execution environment of `register_scalar_function()` was being saved when the binding was registered.

Authored-by: Dewey Dunnington <dewey@voltrondata.com>
Signed-off-by: Dewey Dunnington <dewey@fishandwhistle.net>
1 parent 04ccd84 commit ff80b30

File tree

5 files changed

+21
-8
lines changed

5 files changed

+21
-8
lines changed

r/R/compute.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,17 @@ register_scalar_function <- function(name, fun, in_type, out_type,
379379
RegisterScalarUDF(name, scalar_function)
380380

381381
# register with dplyr binding (enables its use in mutate(), filter(), etc.)
382+
binding_fun <- function(...) build_expr(name, ...)
383+
384+
# inject the value of `name` into the expression to avoid saving this
385+
# execution environment in the binding, which eliminates a warning when the
386+
# same binding is registered twice
387+
body(binding_fun) <- expr_substitute(body(binding_fun), sym("name"), name)
388+
environment(binding_fun) <- asNamespace("arrow")
389+
382390
register_binding(
383391
name,
384-
function(...) build_expr(name, ...),
392+
binding_fun,
385393
update_cache = TRUE
386394
)
387395

r/R/dplyr-funcs.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ register_binding <- function(fun_name,
7575
previous_fun <- registry[[unqualified_name]]
7676

7777
# if the unqualified name exists in the registry, warn
78-
if (!is.null(previous_fun)) {
78+
if (!is.null(previous_fun) && !identical(fun, previous_fun)) {
7979
warn(
8080
paste0(
8181
"A \"",

r/src/compute.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,8 @@ class RScalarUDFKernelState : public arrow::compute::KernelState {
611611
RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
612612
: exec_func_(exec_func), resolver_(resolver) {}
613613

614-
cpp11::function exec_func_;
615-
cpp11::function resolver_;
614+
cpp11::sexp exec_func_;
615+
cpp11::sexp resolver_;
616616
};
617617

618618
arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
@@ -630,7 +630,8 @@ arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
630630
cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
631631
}
632632

633-
cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
633+
cpp11::sexp output_type_sexp =
634+
cpp11::function(state->resolver_)(input_types_sexp);
634635
if (!Rf_inherits(output_type_sexp, "DataType")) {
635636
cpp11::stop(
636637
"Function specified as arrow_scalar_function() out_type argument must "
@@ -674,7 +675,8 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
674675
cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
675676
udf_context.names() = {"batch_length", "output_type"};
676677

677-
cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp);
678+
cpp11::sexp func_result_sexp =
679+
cpp11::function(state->exec_func_)(udf_context, args_sexp);
678680

679681
if (Rf_inherits(func_result_sexp, "Array")) {
680682
auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);

r/src/recordbatchreader.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
7070

7171
arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
7272
auto batch = SafeCallIntoR<std::shared_ptr<arrow::RecordBatch>>([&]() {
73-
cpp11::sexp result_sexp = fun_();
73+
cpp11::sexp result_sexp = cpp11::function(fun_)();
7474
if (result_sexp == R_NilValue) {
7575
return std::shared_ptr<arrow::RecordBatch>(nullptr);
7676
} else if (!Rf_inherits(result_sexp, "RecordBatch")) {
@@ -94,7 +94,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
9494
}
9595

9696
private:
97-
cpp11::function fun_;
97+
cpp11::sexp fun_;
9898
std::shared_ptr<arrow::Schema> schema_;
9999
};
100100

r/tests/testthat/test-dplyr-funcs.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ test_that("register_binding()/unregister_binding() works", {
3535
register_binding("some.pkg2::some_fun", fun2, fake_registry),
3636
"A \"some_fun\" binding already exists in the registry and will be overwritten."
3737
)
38+
39+
# No warning when an identical function is re-registered
40+
expect_silent(register_binding("some.pkg2::some_fun", fun2, fake_registry))
3841
})
3942

4043
test_that("register_binding_agg() works", {

0 commit comments

Comments
 (0)