-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-16444: [R] Implement user-defined scalar functions in R bindings #13397
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2eb3c9b
80293f7
29d02d8
190d059
5e3f682
e129471
83ad7ad
94c0b2f
9e1a362
ddc0d46
fadf258
2eb48ae
c29fc00
b1c8cbf
cf98635
c171da6
e129028
4631cb9
5b82d79
99f7225
80e5683
1e343a2
df6ea0c
36feaac
32e8d83
8ff947b
87402fc
6d60921
e338f7d
e0dd5c0
cdecb55
65a5dc0
f5ec713
bb38274
565c5b5
b96469b
0c139d1
52880b1
e8856b7
2665fdf
010ccf6
f732505
4c01654
21a932a
0c1b8cf
017f681
85519a2
9f251fc
ebc0b84
ed735e1
8877c12
a89ce07
f687451
2e9e261
514d91e
58c8573
b4154af
1ed6d25
88bf4d2
6f3d601
031ec64
0652ae0
49261d6
c1207eb
72d650d
a1f8b53
4acaa61
7ccb23b
6ff4fb2
83aa148
8a8955f
0d3520a
4ac9ec5
ba34d1a
dfbdbc2
7fd6a77
b07e736
c79aca5
8d18754
8d33b07
abd938a
86c1c7e
652175f
12b9721
259eed9
1f8b248
510221f
e906632
1805836
aa9165f
7952710
e31f2b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -306,3 +306,179 @@ cast_options <- function(safe = TRUE, ...) { | |
| ) | ||
| modifyList(opts, list(...)) | ||
| } | ||
|
|
||
| #' Register user-defined functions | ||
| #' | ||
| #' These functions support calling R code from query engine execution | ||
| #' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]). | ||
| #' Use [register_scalar_function()] attach Arrow input and output types to an | ||
| #' R function and make it available for use in the dplyr interface and/or | ||
| #' [call_function()]. Scalar functions are currently the only type of | ||
| #' user-defined function supported. In Arrow, scalar functions must be | ||
| #' stateless and return output with the same shape (i.e., the same number | ||
| #' of rows) as the input. | ||
| #' | ||
| #' @param name The function name to be used in the dplyr bindings | ||
| #' @param in_type A [DataType] of the input type or a [schema()] | ||
| #' for functions with more than one argument. This signature will be used | ||
| #' to determine if this function is appropriate for a given set of arguments. | ||
| #' If this function is appropriate for more than one signature, pass a | ||
| #' `list()` of the above. | ||
| #' @param out_type A [DataType] of the output type or a function accepting | ||
| #' a single argument (`types`), which is a `list()` of [DataType]s. If a | ||
| #' function it must return a [DataType]. | ||
| #' @param fun An R function or rlang-style lambda expression. The function | ||
| #' will be called with a first argument `context` which is a `list()` | ||
| #' with elements `batch_size` (the expected length of the output) and | ||
| #' `output_type` (the required [DataType] of the output) that may be used | ||
| #' to ensure that the output has the correct type and length. Subsequent | ||
| #' arguments are passed by position as specified by `in_types`. If | ||
| #' `auto_convert` is `TRUE`, subsequent arguments are converted to | ||
| #' R vectors before being passed to `fun` and the output is automatically | ||
| #' constructed with the expected output type via [as_arrow_array()]. | ||
| #' @param auto_convert Use `TRUE` to convert inputs before passing to `fun` | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I envision it being a lot more common to use |
||
| #' and construct an Array of the correct type from the output. Use this | ||
| #' option to write functions of R objects as opposed to functions of | ||
| #' Arrow R6 objects. | ||
| #' | ||
| #' @return `NULL`, invisibly | ||
| #' @export | ||
| #' | ||
| #' @examplesIf arrow_with_dataset() | ||
| #' library(dplyr, warn.conflicts = FALSE) | ||
| #' | ||
| #' some_model <- lm(mpg ~ disp + cyl, data = mtcars) | ||
| #' register_scalar_function( | ||
| #' "mtcars_predict_mpg", | ||
| #' function(context, disp, cyl) { | ||
| #' predict(some_model, newdata = data.frame(disp, cyl)) | ||
| #' }, | ||
| #' in_type = schema(disp = float64(), cyl = float64()), | ||
| #' out_type = float64(), | ||
| #' auto_convert = TRUE | ||
| #' ) | ||
| #' | ||
| #' as_arrow_table(mtcars) %>% | ||
| #' transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) %>% | ||
| #' collect() %>% | ||
| #' head() | ||
| #' | ||
| register_scalar_function <- function(name, fun, in_type, out_type, | ||
| auto_convert = FALSE) { | ||
| assert_that(is.string(name)) | ||
|
|
||
| scalar_function <- arrow_scalar_function( | ||
| fun, | ||
| in_type, | ||
| out_type, | ||
| auto_convert = auto_convert | ||
| ) | ||
|
|
||
| # register with Arrow C++ function registry (enables its use in | ||
| # call_function() and Expression$create()) | ||
| RegisterScalarUDF(name, scalar_function) | ||
|
|
||
| # register with dplyr binding (enables its use in mutate(), filter(), etc.) | ||
| register_binding( | ||
| name, | ||
| function(...) build_expr(name, ...), | ||
| update_cache = TRUE | ||
| ) | ||
|
|
||
| invisible(NULL) | ||
| } | ||
|
|
||
| arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) { | ||
| assert_that(is.function(fun)) | ||
|
|
||
| # Create a small wrapper function that is easier to call from C++. | ||
| # TODO(ARROW-17148): This wrapper could be implemented in C/C++ to | ||
| # reduce evaluation overhead and generate prettier backtraces when | ||
| # errors occur (probably using a similar approach to purrr). | ||
| if (auto_convert) { | ||
| wrapper_fun <- function(context, args) { | ||
| args <- lapply(args, as.vector) | ||
| result <- do.call(fun, c(list(context), args)) | ||
| as_arrow_array(result, type = context$output_type) | ||
| } | ||
| } else { | ||
| wrapper_fun <- function(context, args) { | ||
| do.call(fun, c(list(context), args)) | ||
| } | ||
| } | ||
|
|
||
| # in_type can be a list() if registering multiple kernels at once | ||
| if (is.list(in_type)) { | ||
| in_type <- lapply(in_type, in_type_as_schema) | ||
| } else { | ||
| in_type <- list(in_type_as_schema(in_type)) | ||
| } | ||
|
|
||
| # out_type can be a list() if registering multiple kernels at once | ||
| if (is.list(out_type)) { | ||
| out_type <- lapply(out_type, out_type_as_function) | ||
| } else { | ||
| out_type <- list(out_type_as_function(out_type)) | ||
| } | ||
|
|
||
| # recycle out_type (which is frequently length 1 even if multiple kernels | ||
| # are being registered at once) | ||
| out_type <- rep_len(out_type, length(in_type)) | ||
|
|
||
| # check n_kernels and number of args in fun | ||
| n_kernels <- length(in_type) | ||
| if (n_kernels == 0) { | ||
| abort("Can't register user-defined scalar function with 0 kernels") | ||
| } | ||
|
|
||
| expected_n_args <- in_type[[1]]$num_fields + 1L | ||
| fun_formals_have_dots <- any(names(formals(fun)) == "...") | ||
| if (!fun_formals_have_dots && length(formals(fun)) != expected_n_args) { | ||
| abort( | ||
| sprintf( | ||
| paste0( | ||
| "Expected `fun` to accept %d argument(s)\n", | ||
| "but found a function that acccepts %d argument(s)\n", | ||
| "Did you forget to include `context` as the first argument?" | ||
| ), | ||
| expected_n_args, | ||
| length(formals(fun)) | ||
| ) | ||
| ) | ||
| } | ||
|
|
||
| structure( | ||
| list( | ||
| wrapper_fun = wrapper_fun, | ||
| in_type = in_type, | ||
| out_type = out_type | ||
| ), | ||
| class = "arrow_scalar_function" | ||
| ) | ||
| } | ||
|
|
||
| # This function sanitizes the in_type argument for arrow_scalar_function(), | ||
| # which can be a data type (e.g., int32()), a field for a unary function | ||
| # or a schema() for functions accepting more than one argument. C++ expects | ||
| # a schema(). | ||
| in_type_as_schema <- function(x) { | ||
| if (inherits(x, "Field")) { | ||
| schema(x) | ||
| } else if (inherits(x, "DataType")) { | ||
| schema(field("", x)) | ||
| } else { | ||
| as_schema(x) | ||
| } | ||
| } | ||
|
|
||
| # This function sanitizes the out_type argument for arrow_scalar_function(), | ||
| # which can be a data type (e.g., int32()) or a function of the input types. | ||
| # C++ currently expects a function. | ||
| out_type_as_function <- function(x) { | ||
| if (is.function(x)) { | ||
| x | ||
| } else { | ||
| x <- as_data_type(x) | ||
| function(types) x | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,20 +50,28 @@ NULL | |
| #' - `fun`: string function name | ||
| #' - `data`: `Expression` (these are all currently a single field) | ||
| #' - `options`: list of function options, as passed to call_function | ||
| #' @param update_cache Update .cache$functions at the time of registration. | ||
|
||
| #' the default is FALSE because the majority of usage is to register | ||
| #' bindings at package load, after which we create the cache once. The | ||
| #' reason why .cache$functions is needed in addition to nse_funcs for | ||
| #' non-aggregate functions could be revisited...it is currently used | ||
| #' as the data mask in mutate, filter, and aggregate (but not | ||
| #' summarise) because the data mask has to be a list. | ||
| #' @param registry An environment in which the functions should be | ||
| #' assigned. | ||
| #' | ||
| #' @return The previously registered binding or `NULL` if no previously | ||
| #' registered function existed. | ||
| #' @keywords internal | ||
| #' | ||
| register_binding <- function(fun_name, fun, registry = nse_funcs) { | ||
| register_binding <- function(fun_name, fun, registry = nse_funcs, | ||
| update_cache = FALSE) { | ||
| unqualified_name <- sub("^.*?:{+}", "", fun_name) | ||
|
|
||
| previous_fun <- registry[[unqualified_name]] | ||
|
|
||
| # if the unqualified name exists in the registry, warn | ||
| if (!is.null(fun) && !is.null(previous_fun)) { | ||
| if (!is.null(previous_fun)) { | ||
| warn( | ||
| paste0( | ||
| "A \"", | ||
|
|
@@ -73,11 +81,36 @@ register_binding <- function(fun_name, fun, registry = nse_funcs) { | |
| } | ||
|
|
||
| # register both as `pkg::fun` and as `fun` if `qualified_name` is prefixed | ||
| if (grepl("::", fun_name)) { | ||
| registry[[unqualified_name]] <- fun | ||
| registry[[fun_name]] <- fun | ||
| } else { | ||
| registry[[unqualified_name]] <- fun | ||
| # unqualified_name and fun_name will be the same if not prefixed | ||
| registry[[unqualified_name]] <- fun | ||
| registry[[fun_name]] <- fun | ||
|
|
||
| if (update_cache) { | ||
| fun_cache <- .cache$functions | ||
| fun_cache[[unqualified_name]] <- fun | ||
| fun_cache[[fun_name]] <- fun | ||
| .cache$functions <- fun_cache | ||
| } | ||
|
|
||
| invisible(previous_fun) | ||
| } | ||
|
|
||
| unregister_binding <- function(fun_name, registry = nse_funcs, | ||
| update_cache = FALSE) { | ||
| unqualified_name <- sub("^.*?:{+}", "", fun_name) | ||
| previous_fun <- registry[[unqualified_name]] | ||
|
|
||
| rm( | ||
| list = unique(c(fun_name, unqualified_name)), | ||
| envir = registry, | ||
| inherits = FALSE | ||
| ) | ||
|
|
||
| if (update_cache) { | ||
| fun_cache <- .cache$functions | ||
| fun_cache[[unqualified_name]] <- NULL | ||
| fun_cache[[fun_name]] <- NULL | ||
| .cache$functions <- fun_cache | ||
| } | ||
|
|
||
| invisible(previous_fun) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This first argument
contextfeels like a real 🦶 🔫 . A few questions:context?contextas an arg to my function? (I'm guessing it's not pretty.) Can we detect up front if someone has forgotten to put context in the function? Something like checking thatlength(formals(fun)) == length(as_schema(in_type)) + 1and raise a useful error message if the check fails?contextin my function?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A previous version of this PR didn't require the
contextargument when what was the equivalent ofauto_convertwasTRUE, but the comment was raised of "why two APIs" (and I agree...one wrapper function scheme is easier to remember).In its current form, the
contextargument provides the information needed forauto_convertto do its magic. Whenauto_convertisTRUE, you could also use it to do something likerunif(n = context$batch_size). The python version also provides the memory pool here but we don't provide a way to use the memory pool for constructing arrays, so I didn't add it to the context object.Because it's a
list(), assignments won't have any effect outsidefun. A future version may be an environment to avoid the extra unwind protects needed to allocate a new list for each call (but could be one with an overridden[[<-to prevent modification).I added some text to the documentation for
funand disallowed lambdas for now, since a potential future workaround could be to not pass thecontextargument for an rlang/purrr style lambda (e.g.,~.x + .ywould be the equivalent offunction(context, x, y) x + y). I hesitate to add too much convenience functionality in this PR since it's already rather unwieldy.I added a
length(formals(fun))check...you're right that the error message was awful.