-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ARROW-15010: [R] Create a function registry for our NSE funcs
This is PR implementing that admittedly does a few things. I'd be happy to split up with a suggestion on the desired steps if that makes it easier to review. In fact, I'm totally game to use reviews on this PR to collect ideas about the changes that collectively we agree on and just implement that subset in a new PR. This PR: - Defines `register_translation()` and `register_translation_agg()` instead of direct assignment into `nse_funcs` and `agg_funcs`. This enables attaching a package name when the function is being registered (e.g., `register_translation("stringr::str_dup", function(x, ...) ...)`) and makes it possible in the future to allow other packages to define translations and/or for us to change how we evaluate translated expressions without changing many function assignments. - Moves the registration of translations to package load time rather than package build time. This enables splitting up translations into multiple files, adds the usual CMD check that normal functions undergo (e.g., for use of missing variables), and opens up the possibility of defining different translations for different versions of packages (or omitting translations if a package isn't installed). - Splits up the definition of translations into dplyr-funcs-type.R, dplyr-funcs-conditional.R, dplyr-funcs-string.R, dplyr-funcs-datetime.R, and dplyr-funcs-math.R. This matches where the translations are tested (the test files were named test-dplyr-funcs-string.R, etc.). Some translations were moved to dplyr-summarise.R because the translations were being tested in test-dplyr-summarise.R. This makes it easier for parallel PRs defining translations and makes it easier for new contributors to figure out where tests should go. - Consolidates internal documentation on how to write translations into one .Rd file rather than scattered around as comments in dplyr-functions.R as they were before. - Removes direct references to `nse_funcs` and `agg_funcs` except where used to implement evaluation This PR does not: - Change any test filenames or remove any tests. - Change how translations are stored in the package namespace - Change anything about how evaluation works Reprex with the gist of how the registration works: ``` r # remotes::install_github("apache/arrow/r#11904") withr::with_namespace("arrow", { # translations get defined in function wrappers so that they can get called # at package load (so there's no need to consider collate order) register_file_translations <- function() { register_translation("some_func", function() { Expression$scalar(1L) }) } # the .onLoad() hook for registering translations lives in dplyr-funcs.R register_all_translations <- function() { register_file_translations() # ...and other translation functions that might live in other files } # ...and gets called in .onLoad() register_all_translations() # if you need to call a translation, use call_translation() (to keep references # to nse_funcs and/or agg_funcs constrained to the eval implementation) call_translation("some_func") # ...the same machinery exists for agg_funcs register_translation_agg("some_agg_func", function() { Expression$scalar(2L) }) call_translation_agg("some_agg_func") }) #> Expression #> 2 ``` <sup>Created on 2021-12-08 by the [reprex package](https://reprex.tidyverse.org) (v2.0.1)</sup> Closes #11904 from paleolimbot/r-register-translation Lead-authored-by: Dewey Dunnington <dewey@fishandwhistle.net> Co-authored-by: Jonathan Keane <jkeane@gmail.com> Signed-off-by: Jonathan Keane <jkeane@gmail.com>
- Loading branch information
1 parent
85f6f7c
commit 314d8bf
Showing
21 changed files
with
1,717 additions
and
1,429 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
register_bindings_conditional <- function() { | ||
register_binding("coalesce", function(...) { | ||
args <- list2(...) | ||
if (length(args) < 1) { | ||
abort("At least one argument must be supplied to coalesce()") | ||
} | ||
|
||
# Treat NaN like NA for consistency with dplyr::coalesce(), but if *all* | ||
# the values are NaN, we should return NaN, not NA, so don't replace | ||
# NaN with NA in the final (or only) argument | ||
# TODO: if an option is added to the coalesce kernel to treat NaN as NA, | ||
# use that to simplify the code here (ARROW-13389) | ||
attr(args[[length(args)]], "last") <- TRUE | ||
args <- lapply(args, function(arg) { | ||
last_arg <- is.null(attr(arg, "last")) | ||
attr(arg, "last") <- NULL | ||
|
||
if (!inherits(arg, "Expression")) { | ||
arg <- Expression$scalar(arg) | ||
} | ||
|
||
if (last_arg && arg$type_id() %in% TYPES_WITH_NAN) { | ||
# store the NA_real_ in the same type as arg to avoid avoid casting | ||
# smaller float types to larger float types | ||
NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type())) | ||
Expression$create("if_else", Expression$create("is_nan", arg), NA_expr, arg) | ||
} else { | ||
arg | ||
} | ||
}) | ||
Expression$create("coalesce", args = args) | ||
}) | ||
|
||
if_else_binding <- function(condition, true, false, missing = NULL) { | ||
if (!is.null(missing)) { | ||
return(if_else_binding( | ||
call_binding("is.na", (condition)), | ||
missing, | ||
if_else_binding(condition, true, false) | ||
)) | ||
} | ||
|
||
build_expr("if_else", condition, true, false) | ||
} | ||
|
||
register_binding("if_else", if_else_binding) | ||
|
||
# Although base R ifelse allows `yes` and `no` to be different classes | ||
register_binding("ifelse", function(test, yes, no) { | ||
if_else_binding(condition = test, true = yes, false = no) | ||
}) | ||
|
||
register_binding("case_when", function(...) { | ||
formulas <- list2(...) | ||
n <- length(formulas) | ||
if (n == 0) { | ||
abort("No cases provided in case_when()") | ||
} | ||
query <- vector("list", n) | ||
value <- vector("list", n) | ||
mask <- caller_env() | ||
for (i in seq_len(n)) { | ||
f <- formulas[[i]] | ||
if (!inherits(f, "formula")) { | ||
abort("Each argument to case_when() must be a two-sided formula") | ||
} | ||
query[[i]] <- arrow_eval(f[[2]], mask) | ||
value[[i]] <- arrow_eval(f[[3]], mask) | ||
if (!call_binding("is.logical", query[[i]])) { | ||
abort("Left side of each formula in case_when() must be a logical expression") | ||
} | ||
if (inherits(value[[i]], "try-error")) { | ||
abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]]))) | ||
} | ||
} | ||
build_expr( | ||
"case_when", | ||
args = c( | ||
build_expr( | ||
"make_struct", | ||
args = query, | ||
options = list(field_names = as.character(seq_along(query))) | ||
), | ||
value | ||
) | ||
) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
register_bindings_datetime <- function() { | ||
register_binding("strptime", function(x, format = "%Y-%m-%d %H:%M:%S", tz = NULL, | ||
unit = "ms") { | ||
# Arrow uses unit for time parsing, strptime() does not. | ||
# Arrow has no default option for strptime (format, unit), | ||
# we suggest following format = "%Y-%m-%d %H:%M:%S", unit = MILLI/1L/"ms", | ||
# (ARROW-12809) | ||
|
||
# ParseTimestampStrptime currently ignores the timezone information (ARROW-12820). | ||
# Stop if tz is provided. | ||
if (is.character(tz)) { | ||
arrow_not_supported("Time zone argument") | ||
} | ||
|
||
unit <- make_valid_time_unit(unit, c(valid_time64_units, valid_time32_units)) | ||
|
||
Expression$create("strptime", x, options = list(format = format, unit = unit)) | ||
}) | ||
|
||
register_binding("strftime", function(x, format = "", tz = "", usetz = FALSE) { | ||
if (usetz) { | ||
format <- paste(format, "%Z") | ||
} | ||
if (tz == "") { | ||
tz <- Sys.timezone() | ||
} | ||
# Arrow's strftime prints in timezone of the timestamp. To match R's strftime behavior we first | ||
# cast the timestamp to desired timezone. This is a metadata only change. | ||
if (call_binding("is.POSIXct", x)) { | ||
ts <- Expression$create("cast", x, options = list(to_type = timestamp(x$type()$unit(), tz))) | ||
} else { | ||
ts <- x | ||
} | ||
Expression$create("strftime", ts, options = list(format = format, locale = Sys.getlocale("LC_TIME"))) | ||
}) | ||
|
||
register_binding("format_ISO8601", function(x, usetz = FALSE, precision = NULL, ...) { | ||
ISO8601_precision_map <- | ||
list( | ||
y = "%Y", | ||
ym = "%Y-%m", | ||
ymd = "%Y-%m-%d", | ||
ymdh = "%Y-%m-%dT%H", | ||
ymdhm = "%Y-%m-%dT%H:%M", | ||
ymdhms = "%Y-%m-%dT%H:%M:%S" | ||
) | ||
|
||
if (is.null(precision)) { | ||
precision <- "ymdhms" | ||
} | ||
if (!precision %in% names(ISO8601_precision_map)) { | ||
abort( | ||
paste( | ||
"`precision` must be one of the following values:", | ||
paste(names(ISO8601_precision_map), collapse = ", "), | ||
"\nValue supplied was: ", | ||
precision | ||
) | ||
) | ||
} | ||
format <- ISO8601_precision_map[[precision]] | ||
if (usetz) { | ||
format <- paste0(format, "%z") | ||
} | ||
Expression$create("strftime", x, options = list(format = format, locale = "C")) | ||
}) | ||
|
||
register_binding("second", function(x) { | ||
Expression$create("add", Expression$create("second", x), Expression$create("subsecond", x)) | ||
}) | ||
|
||
register_binding("wday", function(x, label = FALSE, abbr = TRUE, | ||
week_start = getOption("lubridate.week.start", 7), | ||
locale = Sys.getlocale("LC_TIME")) { | ||
if (label) { | ||
if (abbr) { | ||
format <- "%a" | ||
} else { | ||
format <- "%A" | ||
} | ||
return(Expression$create("strftime", x, options = list(format = format, locale = locale))) | ||
} | ||
|
||
Expression$create("day_of_week", x, options = list(count_from_zero = FALSE, week_start = week_start)) | ||
}) | ||
|
||
register_binding("month", function(x, label = FALSE, abbr = TRUE, locale = Sys.getlocale("LC_TIME")) { | ||
if (label) { | ||
if (abbr) { | ||
format <- "%b" | ||
} else { | ||
format <- "%B" | ||
} | ||
return(Expression$create("strftime", x, options = list(format = format, locale = locale))) | ||
} | ||
|
||
Expression$create("month", x) | ||
}) | ||
|
||
register_binding("is.Date", function(x) { | ||
inherits(x, "Date") || | ||
(inherits(x, "Expression") && x$type_id() %in% Type[c("DATE32", "DATE64")]) | ||
}) | ||
|
||
is_instant_binding <- function(x) { | ||
inherits(x, c("POSIXt", "POSIXct", "POSIXlt", "Date")) || | ||
(inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP", "DATE32", "DATE64")]) | ||
} | ||
register_binding("is.instant", is_instant_binding) | ||
register_binding("is.timepoint", is_instant_binding) | ||
|
||
register_binding("is.POSIXct", function(x) { | ||
inherits(x, "POSIXct") || | ||
(inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP")]) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
register_bindings_math <- function() { | ||
log_binding <- function(x, base = exp(1)) { | ||
# like other binary functions, either `x` or `base` can be Expression or double(1) | ||
if (is.numeric(x) && length(x) == 1) { | ||
x <- Expression$scalar(x) | ||
} else if (!inherits(x, "Expression")) { | ||
arrow_not_supported("x must be a column or a length-1 numeric; other values") | ||
} | ||
|
||
# handle `base` differently because we use the simpler ln, log2, and log10 | ||
# functions for specific scalar base values | ||
if (inherits(base, "Expression")) { | ||
return(Expression$create("logb_checked", x, base)) | ||
} | ||
|
||
if (!is.numeric(base) || length(base) != 1) { | ||
arrow_not_supported("base must be a column or a length-1 numeric; other values") | ||
} | ||
|
||
if (base == exp(1)) { | ||
return(Expression$create("ln_checked", x)) | ||
} | ||
|
||
if (base == 2) { | ||
return(Expression$create("log2_checked", x)) | ||
} | ||
|
||
if (base == 10) { | ||
return(Expression$create("log10_checked", x)) | ||
} | ||
|
||
Expression$create("logb_checked", x, Expression$scalar(base)) | ||
} | ||
|
||
register_binding("log", log_binding) | ||
register_binding("logb", log_binding) | ||
|
||
register_binding("pmin", function(..., na.rm = FALSE) { | ||
build_expr( | ||
"min_element_wise", | ||
..., | ||
options = list(skip_nulls = na.rm) | ||
) | ||
}) | ||
|
||
register_binding("pmax", function(..., na.rm = FALSE) { | ||
build_expr( | ||
"max_element_wise", | ||
..., | ||
options = list(skip_nulls = na.rm) | ||
) | ||
}) | ||
|
||
register_binding("trunc", function(x, ...) { | ||
# accepts and ignores ... for consistency with base::trunc() | ||
build_expr("trunc", x) | ||
}) | ||
|
||
register_binding("round", function(x, digits = 0) { | ||
build_expr( | ||
"round", | ||
x, | ||
options = list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN) | ||
) | ||
}) | ||
} |
Oops, something went wrong.