diff --git a/r/NAMESPACE b/r/NAMESPACE index c7d2657baed..60f53524c14 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -45,7 +45,9 @@ S3method(as_arrow_array,data.frame) S3method(as_arrow_array,default) S3method(as_arrow_array,pyarrow.lib.Array) S3method(as_arrow_table,RecordBatch) +S3method(as_arrow_table,RecordBatchReader) S3method(as_arrow_table,Table) +S3method(as_arrow_table,arrow_dplyr_query) S3method(as_arrow_table,data.frame) S3method(as_arrow_table,default) S3method(as_arrow_table,pyarrow.lib.RecordBatch) @@ -343,6 +345,7 @@ export(read_schema) export(read_tsv_arrow) export(record_batch) export(register_extension_type) +export(register_scalar_function) export(reregister_extension_type) export(s3_bucket) export(schema) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 84f6ee54fc7..dfe0db614ad 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -408,6 +408,10 @@ ExecPlan_run <- function(plan, final_node, sort_options, metadata, head) { .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options, metadata, head) } +ExecPlan_read_table <- function(plan, final_node, sort_options, metadata, head) { + .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options, metadata, head) +} + ExecPlan_StopProducing <- function(plan) { invisible(.Call(`_arrow_ExecPlan_StopProducing`, plan)) } @@ -480,6 +484,10 @@ compute__GetFunctionNames <- function() { .Call(`_arrow_compute__GetFunctionNames`) } +RegisterScalarUDF <- function(name, func_sexp) { + invisible(.Call(`_arrow_RegisterScalarUDF`, name, func_sexp)) +} + build_info <- function() { .Call(`_arrow_build_info`) } @@ -1108,12 +1116,12 @@ ipc___feather___Reader__version <- function(reader) { .Call(`_arrow_ipc___feather___Reader__version`, reader) } -ipc___feather___Reader__Read <- function(reader, columns, on_old_windows) { - .Call(`_arrow_ipc___feather___Reader__Read`, reader, columns, on_old_windows) +ipc___feather___Reader__Read <- function(reader, columns) { + 
.Call(`_arrow_ipc___feather___Reader__Read`, reader, columns) } -ipc___feather___Reader__Open <- function(stream, on_old_windows) { - .Call(`_arrow_ipc___feather___Reader__Open`, stream, on_old_windows) +ipc___feather___Reader__Open <- function(stream) { + .Call(`_arrow_ipc___feather___Reader__Open`, stream) } ipc___feather___Reader__schema <- function(reader) { @@ -1792,6 +1800,10 @@ InitializeMainRThread <- function() { invisible(.Call(`_arrow_InitializeMainRThread`)) } +CanRunWithCapturedR <- function() { + .Call(`_arrow_CanRunWithCapturedR`) +} + TestSafeCallIntoR <- function(r_fun_that_returns_a_string, opt) { .Call(`_arrow_TestSafeCallIntoR`, r_fun_that_returns_a_string, opt) } diff --git a/r/R/compute.R b/r/R/compute.R index 1cd12f2e29d..0985e73a5f2 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -306,3 +306,179 @@ cast_options <- function(safe = TRUE, ...) { ) modifyList(opts, list(...)) } + +#' Register user-defined functions +#' +#' These functions support calling R code from query engine execution +#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]). +#' Use [register_scalar_function()] to attach Arrow input and output types to an +#' R function and make it available for use in the dplyr interface and/or +#' [call_function()]. Scalar functions are currently the only type of +#' user-defined function supported. In Arrow, scalar functions must be +#' stateless and return output with the same shape (i.e., the same number +#' of rows) as the input. +#' +#' @param name The function name to be used in the dplyr bindings +#' @param in_type A [DataType] of the input type or a [schema()] +#' for functions with more than one argument. This signature will be used +#' to determine if this function is appropriate for a given set of arguments. +#' If this function is appropriate for more than one signature, pass a +#' `list()` of the above.
+#' @param out_type A [DataType] of the output type or a function accepting +#' a single argument (`types`), which is a `list()` of [DataType]s. If a +#' function, it must return a [DataType]. +#' @param fun An R function. The function +#' will be called with a first argument `context` which is a `list()` +#' with elements `batch_length` (the expected length of the output) and +#' `output_type` (the required [DataType] of the output) that may be used +#' to ensure that the output has the correct type and length. Subsequent +#' arguments are passed by position as specified by `in_type`. If +#' `auto_convert` is `TRUE`, subsequent arguments are converted to +#' R vectors before being passed to `fun` and the output is automatically +#' constructed with the expected output type via [as_arrow_array()]. +#' @param auto_convert Use `TRUE` to convert inputs before passing to `fun` +#' and construct an Array of the correct type from the output. Use this +#' option to write functions of R objects as opposed to functions of +#' Arrow R6 objects.
+#' +#' @return `NULL`, invisibly +#' @export +#' +#' @examplesIf arrow_with_dataset() +#' library(dplyr, warn.conflicts = FALSE) +#' +#' some_model <- lm(mpg ~ disp + cyl, data = mtcars) +#' register_scalar_function( +#' "mtcars_predict_mpg", +#' function(context, disp, cyl) { +#' predict(some_model, newdata = data.frame(disp, cyl)) +#' }, +#' in_type = schema(disp = float64(), cyl = float64()), +#' out_type = float64(), +#' auto_convert = TRUE +#' ) +#' +#' as_arrow_table(mtcars) %>% +#' transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) %>% +#' collect() %>% +#' head() +#' +register_scalar_function <- function(name, fun, in_type, out_type, + auto_convert = FALSE) { + assert_that(is.string(name)) + + scalar_function <- arrow_scalar_function( + fun, + in_type, + out_type, + auto_convert = auto_convert + ) + + # register with Arrow C++ function registry (enables its use in + # call_function() and Expression$create()) + RegisterScalarUDF(name, scalar_function) + + # register with dplyr binding (enables its use in mutate(), filter(), etc.) + register_binding( + name, + function(...) build_expr(name, ...), + update_cache = TRUE + ) + + invisible(NULL) +} + +arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) { + assert_that(is.function(fun)) + + # Create a small wrapper function that is easier to call from C++. + # TODO(ARROW-17148): This wrapper could be implemented in C/C++ to + # reduce evaluation overhead and generate prettier backtraces when + # errors occur (probably using a similar approach to purrr). 
+ if (auto_convert) { + wrapper_fun <- function(context, args) { + args <- lapply(args, as.vector) + result <- do.call(fun, c(list(context), args)) + as_arrow_array(result, type = context$output_type) + } + } else { + wrapper_fun <- function(context, args) { + do.call(fun, c(list(context), args)) + } + } + + # in_type can be a list() if registering multiple kernels at once + if (is.list(in_type)) { + in_type <- lapply(in_type, in_type_as_schema) + } else { + in_type <- list(in_type_as_schema(in_type)) + } + + # out_type can be a list() if registering multiple kernels at once + if (is.list(out_type)) { + out_type <- lapply(out_type, out_type_as_function) + } else { + out_type <- list(out_type_as_function(out_type)) + } + + # recycle out_type (which is frequently length 1 even if multiple kernels + # are being registered at once) + out_type <- rep_len(out_type, length(in_type)) + + # check n_kernels and number of args in fun (the + 1L below accounts for the leading `context` argument; only the first kernel's field count is checked) + n_kernels <- length(in_type) + if (n_kernels == 0) { + abort("Can't register user-defined scalar function with 0 kernels") + } + + expected_n_args <- in_type[[1]]$num_fields + 1L + fun_formals_have_dots <- any(names(formals(fun)) == "...") + if (!fun_formals_have_dots && length(formals(fun)) != expected_n_args) { + abort( + sprintf( + paste0( + "Expected `fun` to accept %d argument(s)\n", + "but found a function that accepts %d argument(s)\n", + "Did you forget to include `context` as the first argument?" + ), + expected_n_args, + length(formals(fun)) + ) + ) + } + + structure( + list( + wrapper_fun = wrapper_fun, + in_type = in_type, + out_type = out_type + ), + class = "arrow_scalar_function" + ) +} + +# This function sanitizes the in_type argument for arrow_scalar_function(), +# which can be a data type (e.g., int32()), a field for a unary function +# or a schema() for functions accepting more than one argument. C++ expects +# a schema().
+in_type_as_schema <- function(x) { + if (inherits(x, "Field")) { + schema(x) + } else if (inherits(x, "DataType")) { + schema(field("", x)) + } else { + as_schema(x) + } +} + +# This function sanitizes the out_type argument for arrow_scalar_function(), +# which can be a data type (e.g., int32()) or a function of the input types. +# C++ currently expects a function. +out_type_as_function <- function(x) { + if (is.function(x)) { + x + } else { + x <- as_data_type(x) + function(types) x + } +} diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index 7f10ed307e8..3e83475a8c8 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -20,7 +20,7 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { tryCatch( - out <- as_record_batch_reader(x)$read_table(), + out <- as_arrow_table(x), # n = 4 because we want the error to show up as being from collect() # and not handle_csv_read_error() error = function(e, call = caller_env(n = 4)) { diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R index 7c4ed99e2ed..c1dcdd17744 100644 --- a/r/R/dplyr-funcs.R +++ b/r/R/dplyr-funcs.R @@ -50,6 +50,13 @@ NULL #' - `fun`: string function name #' - `data`: `Expression` (these are all currently a single field) #' - `options`: list of function options, as passed to call_function +#' @param update_cache Update .cache$functions at the time of registration. +#' the default is FALSE because the majority of usage is to register +#' bindings at package load, after which we create the cache once. The +#' reason why .cache$functions is needed in addition to nse_funcs for +#' non-aggregate functions could be revisited...it is currently used +#' as the data mask in mutate, filter, and aggregate (but not +#' summarise) because the data mask has to be a list. #' @param registry An environment in which the functions should be #' assigned. #' @@ -57,13 +64,14 @@ NULL #' registered function existed. 
#' @keywords internal #' -register_binding <- function(fun_name, fun, registry = nse_funcs) { +register_binding <- function(fun_name, fun, registry = nse_funcs, + update_cache = FALSE) { unqualified_name <- sub("^.*?:{+}", "", fun_name) previous_fun <- registry[[unqualified_name]] # if the unqualified name exists in the registry, warn - if (!is.null(fun) && !is.null(previous_fun)) { + if (!is.null(previous_fun)) { warn( paste0( "A \"", @@ -73,11 +81,36 @@ register_binding <- function(fun_name, fun, registry = nse_funcs) { } # register both as `pkg::fun` and as `fun` if `qualified_name` is prefixed - if (grepl("::", fun_name)) { - registry[[unqualified_name]] <- fun - registry[[fun_name]] <- fun - } else { - registry[[unqualified_name]] <- fun + # unqualified_name and fun_name will be the same if not prefixed + registry[[unqualified_name]] <- fun + registry[[fun_name]] <- fun + + if (update_cache) { + fun_cache <- .cache$functions + fun_cache[[unqualified_name]] <- fun + fun_cache[[fun_name]] <- fun + .cache$functions <- fun_cache + } + + invisible(previous_fun) +} + +unregister_binding <- function(fun_name, registry = nse_funcs, + update_cache = FALSE) { + unqualified_name <- sub("^.*?:{+}", "", fun_name) + previous_fun <- registry[[unqualified_name]] + + rm( + list = unique(c(fun_name, unqualified_name)), + envir = registry, + inherits = FALSE + ) + + if (update_cache) { + fun_cache <- .cache$functions + fun_cache[[unqualified_name]] <- NULL + fun_cache[[fun_name]] <- NULL + .cache$functions <- fun_cache } invisible(previous_fun) diff --git a/r/R/feather.R b/r/R/feather.R index 02871396fa6..73eb5d8b6fd 100644 --- a/r/R/feather.R +++ b/r/R/feather.R @@ -190,7 +190,7 @@ FeatherReader <- R6Class("FeatherReader", inherit = ArrowObject, public = list( Read = function(columns) { - ipc___feather___Reader__Read(self, columns, on_old_windows()) + ipc___feather___Reader__Read(self, columns) }, print = function(...) 
{ cat("FeatherReader:\n") @@ -211,5 +211,5 @@ names.FeatherReader <- function(x) x$column_names FeatherReader$create <- function(file) { assert_is(file, "RandomAccessFile") - ipc___feather___Reader__Open(file, on_old_windows()) + ipc___feather___Reader__Open(file) } diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 511bf3dbc27..e63fa75ebf1 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# nolint start: cyclocomp_linter. ExecPlan <- R6Class("ExecPlan", inherit = ArrowObject, public = list( @@ -191,7 +192,7 @@ ExecPlan <- R6Class("ExecPlan", } node }, - Run = function(node) { + Run = function(node, as_table = FALSE) { assert_is(node, "ExecNode") # Sorting and head/tail (if sorted) are handled in the SinkNode, @@ -209,7 +210,14 @@ ExecPlan <- R6Class("ExecPlan", sorting$orders <- as.integer(sorting$orders) } - out <- ExecPlan_run( + # If we are going to return a Table anyway, we do this in one step and + # entirely in one C++ call to ensure that we can execute user-defined + # functions from the worker threads spawned by the ExecPlan. If not, we + # use ExecPlan_run which returns a RecordBatchReader that can be + # manipulated in R code (but that right now won't work with + # user-defined functions).
+ exec_fun <- if (as_table) ExecPlan_read_table else ExecPlan_run + out <- exec_fun( self, node, sorting, @@ -232,10 +240,12 @@ ExecPlan <- R6Class("ExecPlan", } else if (!is.null(node$extras$tail)) { # TODO(ARROW-16630): proper BottomK support # Reverse the row order to get back what we expect - out <- out$read_table() + out <- as_arrow_table(out) out <- out[rev(seq_len(nrow(out))), , drop = FALSE] # Put back into RBR - out <- as_record_batch_reader(out) + if (!as_table) { + out <- as_record_batch_reader(out) + } } # If arrange() created $temp_columns, make sure to omit them from the result @@ -243,9 +253,13 @@ ExecPlan <- R6Class("ExecPlan", # happens in the end (SinkNode) so nothing comes after it. # TODO(ARROW-16631): move into ExecPlan if (length(node$extras$sort$temp_columns) > 0) { - tab <- out$read_table() + tab <- as_arrow_table(out) tab <- tab[, setdiff(names(tab), node$extras$sort$temp_columns), drop = FALSE] - out <- as_record_batch_reader(tab) + if (!as_table) { + out <- as_record_batch_reader(tab) + } else { + out <- tab + } } out @@ -262,6 +276,8 @@ ExecPlan <- R6Class("ExecPlan", Stop = function() ExecPlan_StopProducing(self) ) ) +# nolint end. + ExecPlan$create <- function(use_threads = option_use_threads()) { ExecPlan_create(use_threads) } diff --git a/r/R/table.R b/r/R/table.R index 305f305129e..5579c676d51 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -318,3 +318,18 @@ as_arrow_table.RecordBatch <- function(x, ..., schema = NULL) { as_arrow_table.data.frame <- function(x, ..., schema = NULL) { Table$create(x, schema = schema) } + +#' @rdname as_arrow_table +#' @export +as_arrow_table.RecordBatchReader <- function(x, ...) { + x$read_table() +} + +#' @rdname as_arrow_table +#' @export +as_arrow_table.arrow_dplyr_query <- function(x, ...) 
{ + # See query-engine.R for ExecPlan/Nodes + plan <- ExecPlan$create() + final_node <- plan$Build(x) + plan$Run(final_node, as_table = TRUE) +} diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index c0f599fb8a5..b04cab8195e 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -219,6 +219,7 @@ reference: - match_arrow - value_counts - list_compute_functions + - register_scalar_function - title: Connections to other systems contents: - to_arrow diff --git a/r/man/as_arrow_table.Rd b/r/man/as_arrow_table.Rd index 0ba563f581b..aac4495e7c6 100644 --- a/r/man/as_arrow_table.Rd +++ b/r/man/as_arrow_table.Rd @@ -6,6 +6,8 @@ \alias{as_arrow_table.Table} \alias{as_arrow_table.RecordBatch} \alias{as_arrow_table.data.frame} +\alias{as_arrow_table.RecordBatchReader} +\alias{as_arrow_table.arrow_dplyr_query} \title{Convert an object to an Arrow Table} \usage{ as_arrow_table(x, ..., schema = NULL) @@ -17,6 +19,10 @@ as_arrow_table(x, ..., schema = NULL) \method{as_arrow_table}{RecordBatch}(x, ..., schema = NULL) \method{as_arrow_table}{data.frame}(x, ..., schema = NULL) + +\method{as_arrow_table}{RecordBatchReader}(x, ...) + +\method{as_arrow_table}{arrow_dplyr_query}(x, ...) } \arguments{ \item{x}{An object to convert to an Arrow Table} diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd index e776e7b3f5b..c53df707516 100644 --- a/r/man/register_binding.Rd +++ b/r/man/register_binding.Rd @@ -4,7 +4,7 @@ \alias{register_binding} \title{Register compute bindings} \usage{ -register_binding(fun_name, fun, registry = nse_funcs) +register_binding(fun_name, fun, registry = nse_funcs, update_cache = FALSE) } \arguments{ \item{fun_name}{A string containing a function name in the form \code{"function"} or @@ -18,6 +18,14 @@ This function must accept \code{Expression} objects as arguments and return \item{registry}{An environment in which the functions should be assigned.} +\item{update_cache}{Update .cache$functions at the time of registration. 
+the default is FALSE because the majority of usage is to register +bindings at package load, after which we create the cache once. The +reason why .cache$functions is needed in addition to nse_funcs for +non-aggregate functions could be revisited...it is currently used +as the data mask in mutate, filter, and aggregate (but not +summarise) because the data mask has to be a list.} + \item{agg_fun}{An aggregate function or \code{NULL} to un-register a previous aggregate function. This function must accept \code{Expression} objects as arguments and return a \code{list()} with components: diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd new file mode 100644 index 00000000000..4da8f54f645 --- /dev/null +++ b/r/man/register_scalar_function.Rd @@ -0,0 +1,70 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compute.R +\name{register_scalar_function} +\alias{register_scalar_function} +\title{Register user-defined functions} +\usage{ +register_scalar_function(name, fun, in_type, out_type, auto_convert = FALSE) +} +\arguments{ +\item{name}{The function name to be used in the dplyr bindings} + +\item{fun}{An R function. The function +will be called with a first argument \code{context} which is a \code{list()} +with elements \code{batch_length} (the expected length of the output) and +\code{output_type} (the required \link{DataType} of the output) that may be used +to ensure that the output has the correct type and length. Subsequent +arguments are passed by position as specified by \code{in_type}. If +\code{auto_convert} is \code{TRUE}, subsequent arguments are converted to +R vectors before being passed to \code{fun} and the output is automatically +constructed with the expected output type via \code{\link[=as_arrow_array]{as_arrow_array()}}.} + +\item{in_type}{A \link{DataType} of the input type or a \code{\link[=schema]{schema()}} +for functions with more than one argument.
This signature will be used +to determine if this function is appropriate for a given set of arguments. +If this function is appropriate for more than one signature, pass a +\code{list()} of the above.} + +\item{out_type}{A \link{DataType} of the output type or a function accepting +a single argument (\code{types}), which is a \code{list()} of \link{DataType}s. If a +function, it must return a \link{DataType}.} + +\item{auto_convert}{Use \code{TRUE} to convert inputs before passing to \code{fun} +and construct an Array of the correct type from the output. Use this +option to write functions of R objects as opposed to functions of +Arrow R6 objects.} +} +\value{ +\code{NULL}, invisibly +} +\description{ +These functions support calling R code from query engine execution +(i.e., a \code{\link[dplyr:mutate]{dplyr::mutate()}} or \code{\link[dplyr:filter]{dplyr::filter()}} on a \link{Table} or \link{Dataset}). +Use \code{\link[=register_scalar_function]{register_scalar_function()}} to attach Arrow input and output types to an +R function and make it available for use in the dplyr interface and/or +\code{\link[=call_function]{call_function()}}. Scalar functions are currently the only type of +user-defined function supported. In Arrow, scalar functions must be +stateless and return output with the same shape (i.e., the same number +of rows) as the input.
+} +\examples{ +\dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +library(dplyr, warn.conflicts = FALSE) + +some_model <- lm(mpg ~ disp + cyl, data = mtcars) +register_scalar_function( + "mtcars_predict_mpg", + function(context, disp, cyl) { + predict(some_model, newdata = data.frame(disp, cyl)) + }, + in_type = schema(disp = float64(), cyl = float64()), + out_type = float64(), + auto_convert = TRUE +) + +as_arrow_table(mtcars) \%>\% + transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) \%>\% + collect() \%>\% + head() +\dontshow{\}) # examplesIf} +} diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index e89718144ab..fd9f92e5d1a 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -881,6 +881,18 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp +std::shared_ptr ExecPlan_read_table(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head); +extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type plan(plan_sexp); + arrow::r::Input&>::type final_node(final_node_sexp); + arrow::r::Input::type sort_options(sort_options_sexp); + arrow::r::Input::type metadata(metadata_sexp); + arrow::r::Input::type head(head_sexp); + return cpp11::as_sexp(ExecPlan_read_table(plan, final_node, sort_options, metadata, head)); +END_CPP11 +} +// compute-exec.cpp void ExecPlan_StopProducing(const std::shared_ptr& plan); extern "C" SEXP _arrow_ExecPlan_StopProducing(SEXP plan_sexp){ BEGIN_CPP11 @@ -1099,6 +1111,16 @@ BEGIN_CPP11 return cpp11::as_sexp(compute__GetFunctionNames()); END_CPP11 } +// compute.cpp +void RegisterScalarUDF(std::string name, cpp11::list func_sexp); +extern "C" SEXP _arrow_RegisterScalarUDF(SEXP name_sexp, SEXP func_sexp_sexp){ +BEGIN_CPP11 + arrow::r::Input::type name(name_sexp); + 
arrow::r::Input::type func_sexp(func_sexp_sexp); + RegisterScalarUDF(name, func_sexp); + return R_NilValue; +END_CPP11 +} // config.cpp std::vector build_info(); extern "C" SEXP _arrow_build_info(){ @@ -2788,22 +2810,20 @@ BEGIN_CPP11 END_CPP11 } // feather.cpp -std::shared_ptr ipc___feather___Reader__Read(const std::shared_ptr& reader, cpp11::sexp columns, bool on_old_windows); -extern "C" SEXP _arrow_ipc___feather___Reader__Read(SEXP reader_sexp, SEXP columns_sexp, SEXP on_old_windows_sexp){ +std::shared_ptr ipc___feather___Reader__Read(const std::shared_ptr& reader, cpp11::sexp columns); +extern "C" SEXP _arrow_ipc___feather___Reader__Read(SEXP reader_sexp, SEXP columns_sexp){ BEGIN_CPP11 arrow::r::Input&>::type reader(reader_sexp); arrow::r::Input::type columns(columns_sexp); - arrow::r::Input::type on_old_windows(on_old_windows_sexp); - return cpp11::as_sexp(ipc___feather___Reader__Read(reader, columns, on_old_windows)); + return cpp11::as_sexp(ipc___feather___Reader__Read(reader, columns)); END_CPP11 } // feather.cpp -std::shared_ptr ipc___feather___Reader__Open(const std::shared_ptr& stream, bool on_old_windows); -extern "C" SEXP _arrow_ipc___feather___Reader__Open(SEXP stream_sexp, SEXP on_old_windows_sexp){ +std::shared_ptr ipc___feather___Reader__Open(const std::shared_ptr& stream); +extern "C" SEXP _arrow_ipc___feather___Reader__Open(SEXP stream_sexp){ BEGIN_CPP11 arrow::r::Input&>::type stream(stream_sexp); - arrow::r::Input::type on_old_windows(on_old_windows_sexp); - return cpp11::as_sexp(ipc___feather___Reader__Open(stream, on_old_windows)); + return cpp11::as_sexp(ipc___feather___Reader__Open(stream)); END_CPP11 } // feather.cpp @@ -4601,6 +4621,13 @@ BEGIN_CPP11 END_CPP11 } // safe-call-into-r-impl.cpp +bool CanRunWithCapturedR(); +extern "C" SEXP _arrow_CanRunWithCapturedR(){ +BEGIN_CPP11 + return cpp11::as_sexp(CanRunWithCapturedR()); +END_CPP11 +} +// safe-call-into-r-impl.cpp std::string TestSafeCallIntoR(cpp11::function 
r_fun_that_returns_a_string, std::string opt); extern "C" SEXP _arrow_TestSafeCallIntoR(SEXP r_fun_that_returns_a_string_sexp, SEXP opt_sexp){ BEGIN_CPP11 @@ -5240,6 +5267,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_io___CompressedInputStream__Make", (DL_FUNC) &_arrow_io___CompressedInputStream__Make, 2}, { "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create, 1}, { "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 5}, + { "_arrow_ExecPlan_read_table", (DL_FUNC) &_arrow_ExecPlan_read_table, 5}, { "_arrow_ExecPlan_StopProducing", (DL_FUNC) &_arrow_ExecPlan_StopProducing, 1}, { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, @@ -5258,6 +5286,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3}, { "_arrow_compute__CallFunction", (DL_FUNC) &_arrow_compute__CallFunction, 3}, { "_arrow_compute__GetFunctionNames", (DL_FUNC) &_arrow_compute__GetFunctionNames, 0}, + { "_arrow_RegisterScalarUDF", (DL_FUNC) &_arrow_RegisterScalarUDF, 2}, { "_arrow_build_info", (DL_FUNC) &_arrow_build_info, 0}, { "_arrow_runtime_info", (DL_FUNC) &_arrow_runtime_info, 0}, { "_arrow_set_timezone_database", (DL_FUNC) &_arrow_set_timezone_database, 1}, @@ -5415,8 +5444,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_arrow__UnregisterRExtensionType", (DL_FUNC) &_arrow_arrow__UnregisterRExtensionType, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, { "_arrow_ipc___feather___Reader__version", (DL_FUNC) &_arrow_ipc___feather___Reader__version, 1}, - { "_arrow_ipc___feather___Reader__Read", (DL_FUNC) &_arrow_ipc___feather___Reader__Read, 3}, - { "_arrow_ipc___feather___Reader__Open", (DL_FUNC) &_arrow_ipc___feather___Reader__Open, 2}, + { "_arrow_ipc___feather___Reader__Read", (DL_FUNC) &_arrow_ipc___feather___Reader__Read, 2}, + { 
"_arrow_ipc___feather___Reader__Open", (DL_FUNC) &_arrow_ipc___feather___Reader__Open, 1}, { "_arrow_ipc___feather___Reader__schema", (DL_FUNC) &_arrow_ipc___feather___Reader__schema, 1}, { "_arrow_Field__initialize", (DL_FUNC) &_arrow_Field__initialize, 3}, { "_arrow_Field__ToString", (DL_FUNC) &_arrow_Field__ToString, 1}, @@ -5586,6 +5615,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchFileWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchFileWriter__Open, 4}, { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 4}, { "_arrow_InitializeMainRThread", (DL_FUNC) &_arrow_InitializeMainRThread, 0}, + { "_arrow_CanRunWithCapturedR", (DL_FUNC) &_arrow_CanRunWithCapturedR, 0}, { "_arrow_TestSafeCallIntoR", (DL_FUNC) &_arrow_TestSafeCallIntoR, 2}, { "_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, { "_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 76112b4cefd..e348675fc17 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -16,6 +16,7 @@ // under the License. #include "./arrow_types.h" +#include "./safe-call-into-r.h" #include #include @@ -55,11 +56,10 @@ std::shared_ptr MakeExecNodeOrStop( }); } -// [[arrow::export]] -std::shared_ptr ExecPlan_run( - const std::shared_ptr& plan, - const std::shared_ptr& final_node, cpp11::list sort_options, - cpp11::strings metadata, int64_t head = -1) { +std::pair, std::shared_ptr> +ExecPlan_prepare(const std::shared_ptr& plan, + const std::shared_ptr& final_node, + cpp11::list sort_options, cpp11::strings metadata, int64_t head = -1) { // For now, don't require R to construct SinkNodes. // Instead, just pass the node we should collect as an argument. 
arrow::AsyncGenerator> sink_gen; @@ -89,7 +89,6 @@ std::shared_ptr ExecPlan_run( } StopIfNotOk(plan->Validate()); - StopIfNotOk(plan->StartProducing()); // If the generator is destroyed before being completely drained, inform plan std::shared_ptr stop_producing{nullptr, [plan](...) { @@ -109,9 +108,40 @@ std::shared_ptr ExecPlan_run( auto kv = strings_to_kvm(metadata); out_schema = out_schema->WithMetadata(kv); } - return compute::MakeGeneratorReader( + + std::pair, std::shared_ptr> + out; + out.first = plan; + out.second = compute::MakeGeneratorReader( out_schema, [stop_producing, plan, sink_gen] { return sink_gen(); }, gc_memory_pool()); + return out; +} + +// [[arrow::export]] +std::shared_ptr ExecPlan_run( + const std::shared_ptr& plan, + const std::shared_ptr& final_node, cpp11::list sort_options, + cpp11::strings metadata, int64_t head = -1) { + auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); + StopIfNotOk(prepared_plan.first->StartProducing()); + return prepared_plan.second; +} + +// [[arrow::export]] +std::shared_ptr ExecPlan_read_table( + const std::shared_ptr& plan, + const std::shared_ptr& final_node, cpp11::list sort_options, + cpp11::strings metadata, int64_t head = -1) { + auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); + + auto result = RunWithCapturedRIfPossible>( + [&]() -> arrow::Result> { + ARROW_RETURN_NOT_OK(prepared_plan.first->StartProducing()); + return prepared_plan.second->ToTable(); + }); + + return ValueOrStop(result); } // [[arrow::export]] @@ -196,8 +226,14 @@ void ExecPlan_Write( ds::WriteNodeOptions{std::move(opts), std::move(kv)}); StopIfNotOk(plan->Validate()); - StopIfNotOk(plan->StartProducing()); - StopIfNotOk(plan->finished().status()); + + arrow::Status result = RunWithCapturedRIfPossibleVoid([&]() { + RETURN_NOT_OK(plan->StartProducing()); + RETURN_NOT_OK(plan->finished().status()); + return arrow::Status::OK(); + }); + + StopIfNotOk(result); } 
#endif diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 0db558972e8..f15117f7e48 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -16,7 +16,9 @@ // under the License. #include "./arrow_types.h" +#include "./safe-call-into-r.h" +#include #include #include #include @@ -574,3 +576,168 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list std::vector compute__GetFunctionNames() { return arrow::compute::GetFunctionRegistry()->GetFunctionNames(); } + +class RScalarUDFKernelState : public arrow::compute::KernelState { + public: + RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver) + : exec_func_(exec_func), resolver_(resolver) {} + + cpp11::function exec_func_; + cpp11::function resolver_; +}; + +arrow::Result ResolveScalarUDFOutputType( + arrow::compute::KernelContext* context, + const std::vector& input_types) { + return SafeCallIntoR( + [&]() -> arrow::TypeHolder { + auto kernel = + reinterpret_cast(context->kernel()); + auto state = std::dynamic_pointer_cast(kernel->data); + + cpp11::writable::list input_types_sexp(input_types.size()); + for (size_t i = 0; i < input_types.size(); i++) { + input_types_sexp[i] = + cpp11::to_r6(input_types[i].GetSharedPtr()); + } + + cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); + if (!Rf_inherits(output_type_sexp, "DataType")) { + cpp11::stop( + "Function specified as arrow_scalar_function() out_type argument must " + "return a DataType"); + } + + return arrow::TypeHolder( + cpp11::as_cpp>(output_type_sexp)); + }, + "resolve scalar user-defined function output data type"); +} + +arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context, + const arrow::compute::ExecSpan& span, + arrow::compute::ExecResult* result) { + if (result->is_array_span()) { + return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF"); + } + + return SafeCallIntoRVoid( + [&]() { + auto kernel = + reinterpret_cast(context->kernel()); + auto state = 
std::dynamic_pointer_cast(kernel->data); + + cpp11::writable::list args_sexp(span.num_values()); + + for (int i = 0; i < span.num_values(); i++) { + const arrow::compute::ExecValue& exec_val = span[i]; + if (exec_val.is_array()) { + args_sexp[i] = cpp11::to_r6(exec_val.array.ToArray()); + } else if (exec_val.is_scalar()) { + args_sexp[i] = cpp11::to_r6(exec_val.scalar->GetSharedPtr()); + } + } + + cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); + + std::shared_ptr output_type = result->type()->GetSharedPtr(); + cpp11::sexp output_type_sexp = cpp11::to_r6(output_type); + cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp}; + udf_context.names() = {"batch_length", "output_type"}; + + cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); + + if (Rf_inherits(func_result_sexp, "Array")) { + auto array = cpp11::as_cpp>(func_result_sexp); + + // Error for an Array result of the wrong type + if (!result->type()->Equals(array->type())) { + return cpp11::stop( + "Expected return Array or Scalar with type '%s' from user-defined " + "function but got Array with type '%s'", + result->type()->ToString().c_str(), array->type()->ToString().c_str()); + } + + result->value = std::move(array->data()); + } else if (Rf_inherits(func_result_sexp, "Scalar")) { + auto scalar = cpp11::as_cpp>(func_result_sexp); + + // handle a Scalar result of the wrong type + if (!result->type()->Equals(scalar->type)) { + return cpp11::stop( + "Expected return Array or Scalar with type '%s' from user-defined " + "function but got Scalar with type '%s'", + result->type()->ToString().c_str(), scalar->type->ToString().c_str()); + } + + auto array = ValueOrStop( + arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool())); + result->value = std::move(array->data()); + } else { + cpp11::stop("arrow_scalar_function must return an Array or Scalar"); + } + }, + "execute scalar user-defined function"); +} + +// [[arrow::export]] +void 
RegisterScalarUDF(std::string name, cpp11::list func_sexp) { + cpp11::list in_type_r(func_sexp["in_type"]); + cpp11::list out_type_r(func_sexp["out_type"]); + R_xlen_t n_kernels = in_type_r.size(); + + if (n_kernels == 0) { + cpp11::stop("Can't register user-defined function with zero kernels"); + } + + // Compute the Arity from the list of input kernels. We don't currently handle + // variable numbers of arguments in a user-defined function. + int64_t n_args = + cpp11::as_cpp>(in_type_r[0])->num_fields(); + for (R_xlen_t i = 1; i < n_kernels; i++) { + auto in_types = cpp11::as_cpp>(in_type_r[i]); + if (in_types->num_fields() != n_args) { + cpp11::stop( + "Kernels for user-defined function must accept the same number of arguments"); + } + } + + arrow::compute::Arity arity(n_args, false); + + // The function documentation isn't currently accessible from R but is required + // for the C++ function constructor. + std::vector dummy_argument_names(n_args); + for (int64_t i = 0; i < n_args; i++) { + dummy_argument_names[i] = "arg"; + } + const arrow::compute::FunctionDoc dummy_function_doc{ + "A user-defined R function", "returns something", std::move(dummy_argument_names)}; + + auto func = + std::make_shared(name, arity, dummy_function_doc); + + for (R_xlen_t i = 0; i < n_kernels; i++) { + auto in_types = cpp11::as_cpp>(in_type_r[i]); + cpp11::sexp out_type_func = out_type_r[i]; + + std::vector compute_in_types(in_types->num_fields()); + for (int64_t j = 0; j < in_types->num_fields(); j++) { + compute_in_types[j] = arrow::compute::InputType(in_types->field(j)->type()); + } + + arrow::compute::OutputType out_type((&ResolveScalarUDFOutputType)); + + auto signature = std::make_shared( + std::move(compute_in_types), std::move(out_type), true); + arrow::compute::ScalarKernel kernel(signature, &CallRScalarUDF); + kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; + kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.data 
= + std::make_shared(func_sexp["wrapper_fun"], out_type_func); + + StopIfNotOk(func->AddKernel(std::move(kernel))); + } + + auto registry = arrow::compute::GetFunctionRegistry(); + StopIfNotOk(registry->AddFunction(std::move(func), true)); +} diff --git a/r/src/csv.cpp b/r/src/csv.cpp index d031cc87cac..7ce55feb5fe 100644 --- a/r/src/csv.cpp +++ b/r/src/csv.cpp @@ -162,16 +162,9 @@ std::shared_ptr csv___TableReader__Make( // [[arrow::export]] std::shared_ptr csv___TableReader__Read( const std::shared_ptr& table_reader) { -#if !defined(HAS_SAFE_CALL_INTO_R) - return ValueOrStop(table_reader->Read()); -#else - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR>([&]() { - return DeferNotOk( - io_context.executor()->Submit([&]() { return table_reader->Read(); })); - }); + auto result = RunWithCapturedRIfPossible>( + [&]() { return table_reader->Read(); }); return ValueOrStop(result); -#endif } // [[arrow::export]] diff --git a/r/src/extension-impl.cpp b/r/src/extension-impl.cpp index efb9f0f4675..e6efcf36479 100644 --- a/r/src/extension-impl.cpp +++ b/r/src/extension-impl.cpp @@ -38,18 +38,19 @@ bool RExtensionType::ExtensionEquals(const arrow::ExtensionType& other) const { // With any ambiguity, we need to materialize the R6 instance and call its // ExtensionEquals method. We can't do this on the non-R thread. - // After ARROW-15841, we can use SafeCallIntoR. 
- arrow::Result result = SafeCallIntoR([&]() { - cpp11::environment instance = r6_instance(); - cpp11::function instance_ExtensionEquals(instance["ExtensionEquals"]); - - std::shared_ptr other_shared = - ValueOrStop(other.Deserialize(other.storage_type(), other.Serialize())); - cpp11::sexp other_r6 = cpp11::to_r6(other_shared, "ExtensionType"); - - cpp11::logicals result(instance_ExtensionEquals(other_r6)); - return cpp11::as_cpp(result); - }); + arrow::Result result = SafeCallIntoR( + [&]() { + cpp11::environment instance = r6_instance(); + cpp11::function instance_ExtensionEquals(instance["ExtensionEquals"]); + + std::shared_ptr other_shared = + ValueOrStop(other.Deserialize(other.storage_type(), other.Serialize())); + cpp11::sexp other_r6 = cpp11::to_r6(other_shared, "ExtensionType"); + + cpp11::logicals result(instance_ExtensionEquals(other_r6)); + return cpp11::as_cpp(result); + }, + "RExtensionType$ExtensionEquals()"); if (!result.ok()) { throw std::runtime_error(result.status().message()); diff --git a/r/src/feather.cpp b/r/src/feather.cpp index debabe49689..cf68faef1b5 100644 --- a/r/src/feather.cpp +++ b/r/src/feather.cpp @@ -49,8 +49,7 @@ int ipc___feather___Reader__version( // [[arrow::export]] std::shared_ptr ipc___feather___Reader__Read( - const std::shared_ptr& reader, cpp11::sexp columns, - bool on_old_windows) { + const std::shared_ptr& reader, cpp11::sexp columns) { bool use_names = columns != R_NilValue; std::vector names; if (use_names) { @@ -61,7 +60,7 @@ std::shared_ptr ipc___feather___Reader__Read( } } - auto read_table = [&]() { + auto result = RunWithCapturedRIfPossible>([&]() { std::shared_ptr table; arrow::Status read_result; if (use_names) { @@ -75,39 +74,17 @@ std::shared_ptr ipc___feather___Reader__Read( } else { return arrow::Result>(read_result); } - }; + }); -#if !defined(HAS_SAFE_CALL_INTO_R) - return ValueOrStop(read_table()); -#else - if (!on_old_windows) { - const auto& io_context = arrow::io::default_io_context(); - auto result 
= RunWithCapturedR>( - [&]() { return DeferNotOk(io_context.executor()->Submit(read_table)); }); - return ValueOrStop(result); - } else { - return ValueOrStop(read_table()); - } -#endif + return ValueOrStop(result); } // [[arrow::export]] std::shared_ptr ipc___feather___Reader__Open( - const std::shared_ptr& stream, bool on_old_windows) { -#if !defined(HAS_SAFE_CALL_INTO_R) - return ValueOrStop(arrow::ipc::feather::Reader::Open(stream)); -#else - if (!on_old_windows) { - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR>([&]() { - return DeferNotOk(io_context.executor()->Submit( - [&]() { return arrow::ipc::feather::Reader::Open(stream); })); - }); - return ValueOrStop(result); - } else { - return ValueOrStop(arrow::ipc::feather::Reader::Open(stream)); - } -#endif + const std::shared_ptr& stream) { + auto result = RunWithCapturedRIfPossible>( + [&]() { return arrow::ipc::feather::Reader::Open(stream); }); + return ValueOrStop(result); } // [[arrow::export]] diff --git a/r/src/io.cpp b/r/src/io.cpp index 42766ddd2f5..321b1b17feb 100644 --- a/r/src/io.cpp +++ b/r/src/io.cpp @@ -223,8 +223,8 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { closed_ = true; - return SafeCallIntoRVoid( - [&]() { cpp11::package("base")["close"](connection_sexp_); }); + return SafeCallIntoRVoid([&]() { cpp11::package("base")["close"](connection_sexp_); }, + "close() on R connection"); } arrow::Result Tell() const { @@ -232,10 +232,12 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoR([&]() { - cpp11::sexp result = cpp11::package("base")["seek"](connection_sexp_); - return cpp11::as_cpp(result); - }); + return SafeCallIntoR( + [&]() { + cpp11::sexp result = cpp11::package("base")["seek"](connection_sexp_); + return cpp11::as_cpp(result); + }, + "tell() on R connection"); } bool closed() const { return closed_; } @@ 
-251,17 +253,19 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoR([&] { - cpp11::function read_bin = cpp11::package("base")["readBin"]; - cpp11::writable::raws ptype((R_xlen_t)0); - cpp11::integers n = cpp11::as_sexp(nbytes); + return SafeCallIntoR( + [&] { + cpp11::function read_bin = cpp11::package("base")["readBin"]; + cpp11::writable::raws ptype((R_xlen_t)0); + cpp11::integers n = cpp11::as_sexp(nbytes); - cpp11::sexp result = read_bin(connection_sexp_, ptype, n); + cpp11::sexp result = read_bin(connection_sexp_, ptype, n); - int64_t result_size = cpp11::safe[Rf_xlength](result); - memcpy(out, cpp11::safe[RAW](result), result_size); - return result_size; - }); + int64_t result_size = cpp11::safe[Rf_xlength](result); + memcpy(out, cpp11::safe[RAW](result), result_size); + return result_size; + }, + "readBin() on R connection"); } arrow::Result> ReadBase(int64_t nbytes) { @@ -278,13 +282,15 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoRVoid([&]() { - cpp11::writable::raws data_raw(nbytes); - memcpy(cpp11::safe[RAW](data_raw), data, nbytes); - - cpp11::function write_bin = cpp11::package("base")["writeBin"]; - write_bin(data_raw, connection_sexp_); - }); + return SafeCallIntoRVoid( + [&]() { + cpp11::writable::raws data_raw(nbytes); + memcpy(cpp11::safe[RAW](data_raw), data, nbytes); + + cpp11::function write_bin = cpp11::package("base")["writeBin"]; + write_bin(data_raw, connection_sexp_); + }, + "writeBin() on R connection"); } arrow::Status SeekBase(int64_t pos) { @@ -292,9 +298,11 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoRVoid([&]() { - cpp11::package("base")["seek"](connection_sexp_, cpp11::as_sexp(pos)); - }); + return 
SafeCallIntoRVoid( + [&]() { + cpp11::package("base")["seek"](connection_sexp_, cpp11::as_sexp(pos)); + }, + "seek() on R connection"); } private: @@ -305,10 +313,12 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return true; } - auto is_open_result = SafeCallIntoR([&]() { - cpp11::sexp result = cpp11::package("base")["isOpen"](connection_sexp_); - return cpp11::as_cpp(result); - }); + auto is_open_result = SafeCallIntoR( + [&]() { + cpp11::sexp result = cpp11::package("base")["isOpen"](connection_sexp_); + return cpp11::as_cpp(result); + }, + "isOpen() on R connection"); if (!is_open_result.ok()) { closed_ = true; diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp index 7c5e75b788e..7318c81bb55 100644 --- a/r/src/safe-call-into-r-impl.cpp +++ b/r/src/safe-call-into-r-impl.cpp @@ -29,6 +29,21 @@ MainRThread& GetMainRThread() { // [[arrow::export]] void InitializeMainRThread() { GetMainRThread().Initialize(); } +// [[arrow::export]] +bool CanRunWithCapturedR() { +#if defined(HAS_UNWIND_PROTECT) + static int on_old_windows = -1; + if (on_old_windows == -1) { + cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; + on_old_windows = on_old_windows_fun(); + } + + return !on_old_windows; +#else + return false; +#endif +} + // [[arrow::export]] std::string TestSafeCallIntoR(cpp11::function r_fun_that_returns_a_string, std::string opt) { diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 0555628d7d5..937163a05df 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -20,6 +20,7 @@ #include "./arrow_types.h" +#include #include #include @@ -27,11 +28,11 @@ #include // Unwind protection was added in R 3.5 and some calls here use it -// and crash R in older versions (ARROW-16201). We use this define -// to make sure we don't crash on R 3.4 and lower. 
-#if defined(HAS_UNWIND_PROTECT) -#define HAS_SAFE_CALL_INTO_R -#endif +// and crash R in older versions (ARROW-16201). Crashes also occur +// on 32-bit R builds on R 3.6 and lower. Implementation provided +// in safe-call-into-r-impl.cpp so that we can skip some tests +// when this feature is not provided. +bool CanRunWithCapturedR(); // The MainRThread class keeps track of the thread on which it is safe // to call the R API to facilitate its safe use (or erroring @@ -48,7 +49,7 @@ class MainRThread { void Initialize() { thread_id_ = std::this_thread::get_id(); initialized_ = true; - SetError(R_NilValue); + ResetError(); } bool IsInitialized() { return initialized_; } @@ -56,33 +57,34 @@ class MainRThread { // Check if the current thread is the main R thread bool IsMainThread() { return initialized_ && std::this_thread::get_id() == thread_id_; } + // Check if a SafeCallIntoR call is able to execute + bool CanExecuteSafeCallIntoR() { return IsMainThread() || executor_ != nullptr; } + // The Executor that is running on the main R thread, if it exists arrow::internal::Executor*& Executor() { return executor_; } - // Save an error token generated from a cpp11::unwind_exception - // so that it can be properly handled after some cleanup code - // has run (e.g., cancelling some futures or waiting for them - // to finish). - void SetError(cpp11::sexp token) { error_token_ = token; } + // Save an error (possibly with an error token generated from + // a cpp11::unwind_exception) so that it can be properly handled + // after some cleanup code has run (e.g., cancelling some futures + // or waiting for them to finish). 
+ void SetError(arrow::Status status) { status_ = status; } - void ResetError() { error_token_ = R_NilValue; } + void ResetError() { status_ = arrow::Status::OK(); } // Check if there is a saved error - bool HasError() { return error_token_ != R_NilValue; } + bool HasError() { return !status_.ok(); } - // Throw a cpp11::unwind_exception() with the saved token if it exists + // Throw a cpp11::unwind_exception() if void ClearError() { - if (HasError()) { - cpp11::unwind_exception e(error_token_); - ResetError(); - throw e; - } + arrow::Status maybe_error_status = status_; + ResetError(); + arrow::StopIfNotOk(maybe_error_status); } private: bool initialized_; std::thread::id thread_id_; - cpp11::sexp error_token_; + arrow::Status status_; arrow::internal::Executor* executor_; }; @@ -93,55 +95,76 @@ MainRThread& GetMainRThread(); // a SEXP (use cpp11::as_cpp to convert it to a C++ type inside // `fun`). template -arrow::Future SafeCallIntoRAsync(std::function(void)> fun) { +arrow::Future SafeCallIntoRAsync(std::function(void)> fun, + std::string reason = "unspecified") { MainRThread& main_r_thread = GetMainRThread(); if (main_r_thread.IsMainThread()) { // If we're on the main thread, run the task immediately and let // the cpp11::unwind_exception be thrown since it will be caught // at the top level. return fun(); - } else if (main_r_thread.Executor() != nullptr) { + } else if (main_r_thread.CanExecuteSafeCallIntoR()) { // If we are not on the main thread and have an Executor, // use it to run the task on the main R thread. We can't throw // a cpp11::unwind_exception here, so we need to propagate it back // to RunWithCapturedR through the MainRThread singleton. - return DeferNotOk(main_r_thread.Executor()->Submit([fun]() { + return DeferNotOk(main_r_thread.Executor()->Submit([fun, reason]() { + // This occurs when some other R code that was previously scheduled to run + // has errored, in which case we skip execution and let the original + // error surface. 
if (GetMainRThread().HasError()) { - return arrow::Result(arrow::Status::UnknownError("R code execution error")); + return arrow::Result( + arrow::Status::Cancelled("Previous R code execution error (", reason, ")")); } try { return fun(); } catch (cpp11::unwind_exception& e) { - GetMainRThread().SetError(e.token); - return arrow::Result(arrow::Status::UnknownError("R code execution error")); + // Here we save the token and set the main R thread to an error state + GetMainRThread().SetError(arrow::StatusUnwindProtect(e.token)); + + // We also return an error although this should not surface because + // main_r_thread.ClearError() will get called before this value can be + // returned and will StopIfNotOk(). We don't save the error token here + // to ensure that it will only get thrown once. + return arrow::Result( + arrow::Status::UnknownError("R code execution error (", reason, ")")); } })); } else { return arrow::Status::NotImplemented( - "Call to R from a non-R thread without calling RunWithCapturedR"); + "Call to R (", reason, ") from a non-R thread from an unsupported context"); } } template -arrow::Result SafeCallIntoR(std::function fun) { - arrow::Future future = SafeCallIntoRAsync(std::move(fun)); +arrow::Result SafeCallIntoR(std::function fun, + std::string reason = "unspecified") { + arrow::Future future = SafeCallIntoRAsync(std::move(fun), reason); return future.result(); } -static inline arrow::Status SafeCallIntoRVoid(std::function fun) { - arrow::Future future = SafeCallIntoRAsync([&fun]() { - fun(); - return true; - }); +static inline arrow::Status SafeCallIntoRVoid(std::function fun, + std::string reason = "unspecified") { + arrow::Future future = SafeCallIntoRAsync( + [&fun]() { + fun(); + return true; + }, + reason); return future.status(); } +// Performs an Arrow call (e.g., run an exec plan) in such a way that background threads +// can use SafeCallIntoR(). This version is useful for Arrow calls that already +// return a Future<>. 
template arrow::Result RunWithCapturedR(std::function()> make_arrow_call) { -#if !defined(HAS_SAFE_CALL_INTO_R) - return arrow::Status::NotImplemented("RunWithCapturedR() without UnwindProtect"); -#else + if (!CanRunWithCapturedR()) { + return arrow::Status::NotImplemented( + "RunWithCapturedR() without UnwindProtect or on 32-bit Windows + R <= 3.6"); + } + if (GetMainRThread().Executor() != nullptr) { return arrow::Status::AlreadyExists("Attempt to use more than one R Executor()"); } @@ -158,7 +181,39 @@ arrow::Result RunWithCapturedR(std::function()> make_arrow_c GetMainRThread().ClearError(); return result; -#endif +} + +// Performs an Arrow call (e.g., run an exec plan) in such a way that background threads +// can use SafeCallIntoR(). This version is useful for Arrow calls that do not already +// return a Future<>(). If it is not possible to use RunWithCapturedR() (i.e., +// CanRunWithCapturedR() returns false), this will run make_arrow_call on the main +// R thread (which will cause background threads that try to SafeCallIntoR() to +// error). +template +arrow::Result RunWithCapturedRIfPossible( + std::function()> make_arrow_call) { + if (CanRunWithCapturedR()) { + // Note that the use of the io_context here is arbitrary (i.e. we could use + // any construct that launches a background thread). + const auto& io_context = arrow::io::default_io_context(); + return RunWithCapturedR([&]() { + return DeferNotOk(io_context.executor()->Submit(std::move(make_arrow_call))); + }); + } else { + return make_arrow_call(); + } +} + +// Like RunWithCapturedRIfPossible<>() but for arrow calls that don't return +// a Result. 
+static inline arrow::Status RunWithCapturedRIfPossibleVoid( + std::function make_arrow_call) { + auto result = RunWithCapturedRIfPossible([&]() -> arrow::Result { + ARROW_RETURN_NOT_OK(make_arrow_call()); + return true; + }); + ARROW_RETURN_NOT_OK(result); + return arrow::Status::OK(); } #endif diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md new file mode 100644 index 00000000000..89506a7fbc2 --- /dev/null +++ b/r/tests/testthat/_snaps/compute.md @@ -0,0 +1,4 @@ +# arrow_scalar_function() works + + fun is not a function + diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R new file mode 100644 index 00000000000..946583ae004 --- /dev/null +++ b/r/tests/testthat/test-compute.R @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +test_that("list_compute_functions() works", { + expect_type(list_compute_functions(), "character") + expect_true(all(!grepl("^hash_", list_compute_functions()))) +}) + + +test_that("arrow_scalar_function() works", { + # check in/out type as schema/data type + fun <- arrow_scalar_function( + function(context, x) x$cast(int64()), + schema(x = int32()), int64() + ) + expect_equal(fun$in_type[[1]], schema(x = int32())) + expect_equal(fun$out_type[[1]](), int64()) + + # check in/out type as data type/data type + fun <- arrow_scalar_function( + function(context, x) x$cast(int64()), + int32(), int64() + ) + expect_equal(fun$in_type[[1]][[1]], field("", int32())) + expect_equal(fun$out_type[[1]](), int64()) + + # check in/out type as field/data type + fun <- arrow_scalar_function( + function(context, a_name) x$cast(int64()), + field("a_name", int32()), + int64() + ) + expect_equal(fun$in_type[[1]], schema(a_name = int32())) + expect_equal(fun$out_type[[1]](), int64()) + + # check in/out type as lists + fun <- arrow_scalar_function( + function(context, x) x, + list(int32(), int64()), + list(int64(), int32()), + auto_convert = TRUE + ) + + expect_equal(fun$in_type[[1]][[1]], field("", int32())) + expect_equal(fun$in_type[[2]][[1]], field("", int64())) + expect_equal(fun$out_type[[1]](), int64()) + expect_equal(fun$out_type[[2]](), int32()) + + expect_snapshot_error(arrow_scalar_function(NULL, int32(), int32())) +}) + +test_that("arrow_scalar_function() works with auto_convert = TRUE", { + times_32_wrapper <- arrow_scalar_function( + function(context, x) x * 32, + float64(), + float64(), + auto_convert = TRUE + ) + + dummy_kernel_context <- list() + + expect_equal( + times_32_wrapper$wrapper_fun(dummy_kernel_context, list(Scalar$create(2))), + Array$create(2 * 32) + ) +}) + +test_that("register_scalar_function() adds a compute function to the registry", { + skip_if_not(CanRunWithCapturedR()) + + register_scalar_function( + "times_32", + function(context, x) x * 32.0, + 
int32(), float64(), + auto_convert = TRUE + ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) + + expect_true("times_32" %in% names(asNamespace("arrow")$.cache$functions)) + expect_true("times_32" %in% list_compute_functions()) + + expect_equal( + call_function("times_32", Array$create(1L, int32())), + Array$create(32L, float64()) + ) + + expect_equal( + call_function("times_32", Scalar$create(1L, int32())), + Scalar$create(32L, float64()) + ) + + expect_identical( + record_batch(a = 1L) %>% + dplyr::mutate(b = times_32(a)) %>% + dplyr::collect(), + tibble::tibble(a = 1L, b = 32.0) + ) +}) + +test_that("arrow_scalar_function() with bad return type errors", { + skip_if_not(CanRunWithCapturedR()) + + register_scalar_function( + "times_32_bad_return_type_array", + function(context, x) Array$create(x, int32()), + int32(), + float64() + ) + on.exit( + unregister_binding("times_32_bad_return_type_array", update_cache = TRUE) + ) + + expect_error( + call_function("times_32_bad_return_type_array", Array$create(1L)), + "Expected return Array or Scalar with type 'double'" + ) + + register_scalar_function( + "times_32_bad_return_type_scalar", + function(context, x) Scalar$create(x, int32()), + int32(), + float64() + ) + on.exit( + unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE) + ) + + expect_error( + call_function("times_32_bad_return_type_scalar", Array$create(1L)), + "Expected return Array or Scalar with type 'double'" + ) +}) + +test_that("register_user_defined_function() can register multiple kernels", { + skip_if_not(CanRunWithCapturedR()) + + register_scalar_function( + "times_32", + function(context, x) x * 32L, + in_type = list(int32(), int64(), float64()), + out_type = function(in_types) in_types[[1]], + auto_convert = TRUE + ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) + + expect_equal( + call_function("times_32", Scalar$create(1L, int32())), + Scalar$create(32L, int32()) + ) + + expect_equal( + 
call_function("times_32", Scalar$create(1L, int64())), + Scalar$create(32L, int64()) + ) + + expect_equal( + call_function("times_32", Scalar$create(1L, float64())), + Scalar$create(32L, float64()) + ) +}) + +test_that("register_user_defined_function() errors for unsupported specifications", { + expect_error( + register_scalar_function( + "no_kernels", + function(...) NULL, + list(), + list() + ), + "Can't register user-defined scalar function with 0 kernels" + ) + + expect_error( + register_scalar_function( + "wrong_n_args", + function(x) NULL, + int32(), + int32() + ), + "Expected `fun` to accept 2 argument\\(s\\)" + ) + + expect_error( + register_scalar_function( + "var_kernels", + function(...) NULL, + list(float64(), schema(x = float64(), y = float64())), + float64() + ), + "Kernels for user-defined function must accept the same number of arguments" + ) +}) + +test_that("user-defined functions work during multi-threaded execution", { + skip_if_not(CanRunWithCapturedR()) + skip_if_not_available("dataset") + + n_rows <- 10000 + n_partitions <- 10 + example_df <- expand.grid( + part = letters[seq_len(n_partitions)], + value = seq_len(n_rows), + stringsAsFactors = FALSE + ) + + # make sure values are different for each partition and + example_df$row_num <- seq_len(nrow(example_df)) + example_df$value <- example_df$value + match(example_df$part, letters) + + tf_dataset <- tempfile() + tf_dest <- tempfile() + on.exit(unlink(c(tf_dataset, tf_dest))) + write_dataset(example_df, tf_dataset, partitioning = "part") + + register_scalar_function( + "times_32", + function(context, x) x * 32.0, + int32(), + float64(), + auto_convert = TRUE + ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) + + # check a regular collect() + result <- open_dataset(tf_dataset) %>% + dplyr::mutate(fun_result = times_32(value)) %>% + dplyr::collect() %>% + dplyr::arrange(row_num) + + expect_identical(result$fun_result, example_df$value * 32) + + # check a write_dataset() + 
open_dataset(tf_dataset) %>% + dplyr::mutate(fun_result = times_32(value)) %>% + write_dataset(tf_dest) + + result2 <- dplyr::collect(open_dataset(tf_dest)) %>% + dplyr::arrange(row_num) %>% + dplyr::collect() + + expect_identical(result2$fun_result, example_df$value * 32) +}) + +test_that("user-defined error when called from an unsupported context", { + skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) + + register_scalar_function( + "times_32", + function(context, x) x * 32.0, + int32(), + float64(), + auto_convert = TRUE + ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) + + stream_plan_with_udf <- function() { + record_batch(a = 1:1000) %>% + dplyr::mutate(b = times_32(a)) %>% + as_record_batch_reader() %>% + as_arrow_table() + } + + collect_plan_with_head <- function() { + record_batch(a = 1:1000) %>% + dplyr::mutate(fun_result = times_32(a)) %>% + head(11) %>% + dplyr::collect() + } + + if (identical(tolower(Sys.info()[["sysname"]]), "windows")) { + expect_equal( + stream_plan_with_udf(), + record_batch(a = 1:1000) %>% + dplyr::mutate(b = times_32(a)) %>% + dplyr::collect(as_data_frame = FALSE) + ) + + result <- collect_plan_with_head() + expect_equal(nrow(result), 11) + } else { + expect_error( + stream_plan_with_udf(), + "Call to R \\(.*?\\) from a non-R thread from an unsupported context" + ) + expect_error( + collect_plan_with_head(), + "Call to R \\(.*?\\) from a non-R thread from an unsupported context" + ) + } +}) diff --git a/r/tests/testthat/test-csv.R b/r/tests/testthat/test-csv.R index fca717cc051..d4878e6d670 100644 --- a/r/tests/testthat/test-csv.R +++ b/r/tests/testthat/test-csv.R @@ -293,9 +293,7 @@ test_that("more informative error when reading a CSV with headers and schema", { }) test_that("read_csv_arrow() and write_csv_arrow() accept connection objects", { - # connections with csv need RunWithCapturedR, which is not available - # in R <= 3.4.4 - skip_on_r_older_than("3.5") + 
skip_if_not(CanRunWithCapturedR()) tf <- tempfile() on.exit(unlink(tf)) diff --git a/r/tests/testthat/test-dplyr-funcs.R b/r/tests/testthat/test-dplyr-funcs.R index 2156ad9af06..86f984dd32c 100644 --- a/r/tests/testthat/test-dplyr-funcs.R +++ b/r/tests/testthat/test-dplyr-funcs.R @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -test_that("register_binding() works", { +test_that("register_binding()/unregister_binding() works", { fake_registry <- new.env(parent = emptyenv()) fun1 <- function() NULL fun2 <- function() "Hello" @@ -24,8 +24,9 @@ test_that("register_binding() works", { expect_identical(fake_registry$some_fun, fun1) expect_identical(fake_registry$`some.pkg::some_fun`, fun1) - expect_identical(register_binding("some.pkg::some_fun", NULL, fake_registry), fun1) - expect_silent(expect_null(register_binding("some.pkg::some_fun", NULL, fake_registry))) + expect_identical(unregister_binding("some.pkg::some_fun", fake_registry), fun1) + expect_false("some.pkg::some_fun" %in% names(fake_registry)) + expect_false("some_fun" %in% names(fake_registry)) expect_null(register_binding("somePkg::some_fun", fun1, fake_registry)) expect_identical(fake_registry$some_fun, fun1) diff --git a/r/tests/testthat/test-extension.R b/r/tests/testthat/test-extension.R index 638869dc8c3..55a1f8d21ee 100644 --- a/r/tests/testthat/test-extension.R +++ b/r/tests/testthat/test-extension.R @@ -312,6 +312,7 @@ test_that("Table can roundtrip extension types", { test_that("Dataset/arrow_dplyr_query can roundtrip extension types", { skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) tf <- tempfile() on.exit(unlink(tf, recursive = TRUE)) diff --git a/r/tests/testthat/test-feather.R b/r/tests/testthat/test-feather.R index bed097762a2..99dc8ab9c90 100644 --- a/r/tests/testthat/test-feather.R +++ b/r/tests/testthat/test-feather.R @@ -179,11 +179,7 @@ test_that("read_feather requires RandomAccessFile and errors nicely otherwise (A 
}) test_that("read_feather() and write_feather() accept connection objects", { - # connection object don't work on Windows i386 before R 4.0 - skip_if(on_old_windows()) - # connections with feather need RunWithCapturedR, which is not available - # in R <= 3.4.4 - skip_on_r_older_than("3.5") + skip_if_not(CanRunWithCapturedR()) tf <- tempfile() on.exit(unlink(tf)) diff --git a/r/tests/testthat/test-safe-call-into-r.R b/r/tests/testthat/test-safe-call-into-r.R index a8027ac4237..c07d90433fd 100644 --- a/r/tests/testthat/test-safe-call-into-r.R +++ b/r/tests/testthat/test-safe-call-into-r.R @@ -32,7 +32,7 @@ test_that("SafeCallIntoR works from the main R thread", { }) test_that("SafeCallIntoR works within RunWithCapturedR", { - skip_on_r_older_than("3.5") + skip_if_not(CanRunWithCapturedR()) skip_on_cran() expect_identical( @@ -47,16 +47,16 @@ test_that("SafeCallIntoR works within RunWithCapturedR", { }) test_that("SafeCallIntoR errors from the non-R thread", { - skip_on_r_older_than("3.5") + skip_if_not(CanRunWithCapturedR()) skip_on_cran() expect_error( TestSafeCallIntoR(function() "string one!", opt = "async_without_executor"), - "Call to R from a non-R thread" + "Call to R \\(unspecified\\) from a non-R thread" ) expect_error( TestSafeCallIntoR(function() stop("an error!"), opt = "async_without_executor"), - "Call to R from a non-R thread" + "Call to R \\(unspecified\\) from a non-R thread" ) })