ARROW-15010: [R] Create a function registry for our NSE funcs
This PR admittedly does a few things. I'd be happy to split it up, given a suggestion on the desired steps, if that makes it easier to review. In fact, I'm totally game to use reviews on this PR to collect ideas about the changes we collectively agree on and then implement just that subset in a new PR.

This PR:

- Defines `register_translation()` and `register_translation_agg()` instead of direct assignment into `nse_funcs` and `agg_funcs`. This enables attaching a package name when the function is being registered (e.g., `register_translation("stringr::str_dup", function(x, ...) ...)`) and makes it possible in the future to allow other packages to define translations and/or for us to change how we evaluate translated expressions without changing many function assignments.
- Moves the registration of translations to package load time rather than package build time. This enables splitting up translations into multiple files, subjects translations to the same R CMD check diagnostics as normal functions (e.g., for use of missing variables), and opens up the possibility of defining different translations for different versions of packages (or omitting translations if a package isn't installed).
- Splits up the definition of translations into dplyr-funcs-type.R, dplyr-funcs-conditional.R, dplyr-funcs-string.R, dplyr-funcs-datetime.R, and dplyr-funcs-math.R. This matches where the translations are tested (the test files were named test-dplyr-funcs-string.R, etc.). Some translations were moved to dplyr-summarise.R because the translations were being tested in test-dplyr-summarise.R. This makes it easier for parallel PRs defining translations and makes it easier for new contributors to figure out where tests should go.
- Consolidates internal documentation on how to write translations into one .Rd file instead of leaving it scattered as comments in dplyr-functions.R.
- Removes direct references to `nse_funcs` and `agg_funcs` except where they are used to implement evaluation.
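
To make the registration idea above concrete, here is a minimal, self-contained sketch of a function registry in base R. It is not the arrow implementation: the names `translation_registry`, `register_translation_sketch()`, and `call_translation_sketch()` are hypothetical, and storing the function under its unqualified name with the package recorded as an attribute is an assumption made only for illustration.

```r
# Hypothetical stand-in for a package-level registry (not arrow's internals)
translation_registry <- new.env(parent = emptyenv())

register_translation_sketch <- function(fun_name, fun) {
  # Split an optional "pkg::" prefix off the name, e.g. "stringr::str_dup"
  parts <- strsplit(fun_name, "::", fixed = TRUE)[[1]]
  unqualified <- parts[length(parts)]

  # Assumption: remember the package as an attribute on the function
  attr(fun, "package") <- if (length(parts) == 2) parts[1] else NA_character_

  # Assumption: store under the bare name so lookups do not need the prefix
  assign(unqualified, fun, envir = translation_registry)
  invisible(fun)
}

call_translation_sketch <- function(fun_name, ...) {
  fun <- get0(fun_name, envir = translation_registry, inherits = FALSE)
  if (is.null(fun)) {
    stop("No translation registered for '", fun_name, "'")
  }
  fun(...)
}

# Register a toy translation and call it through the registry
register_translation_sketch(
  "stringr::str_dup",
  function(x, times) strrep(x, times)
)
result <- call_translation_sketch("str_dup", "ab", 2)
```

Because every caller goes through the two accessor functions, the storage (here an environment) could later be swapped out without touching the individual registrations, which is the flexibility the first bullet describes.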

This PR does not:

- Change any test filenames or remove any tests.
- Change how translations are stored in the package namespace.
- Change anything about how evaluation works.

Reprex with the gist of how the registration works:

``` r
# remotes::install_github("apache/arrow/r#11904")
withr::with_namespace("arrow", {

  # translations get defined in function wrappers so that they can get called
  # at package load (so there's no need to consider collate order)
  register_file_translations <- function() {

    register_translation("some_func", function() {
      Expression$scalar(1L)
    })

  }

  # the .onLoad() hook for registering translations lives in dplyr-funcs.R
  register_all_translations <- function() {
    register_file_translations()
    # ...and other translation functions that might live in other files
  }

  # ...and gets called in .onLoad()
  register_all_translations()

  # if you need to call a translation, use call_translation() (to keep references
  # to nse_funcs and/or agg_funcs constrained to the eval implementation)
  call_translation("some_func")

  # ...the same machinery exists for agg_funcs
  register_translation_agg("some_agg_func", function() {
    Expression$scalar(2L)
  })

  call_translation_agg("some_agg_func")
})
#> Expression
#> 2
```

<sup>Created on 2021-12-08 by the [reprex package](https://reprex.tidyverse.org) (v2.0.1)</sup>

Closes #11904 from paleolimbot/r-register-translation

Lead-authored-by: Dewey Dunnington <dewey@fishandwhistle.net>
Co-authored-by: Jonathan Keane <jkeane@gmail.com>
Signed-off-by: Jonathan Keane <jkeane@gmail.com>
paleolimbot and jonkeane committed Jan 4, 2022
1 parent 85f6f7c commit 314d8bf
Showing 21 changed files with 1,717 additions and 1,429 deletions.
7 changes: 6 additions & 1 deletion r/DESCRIPTION
@@ -90,8 +90,13 @@ Collate:
     'dplyr-distinct.R'
     'dplyr-eval.R'
     'dplyr-filter.R'
+    'dplyr-funcs-conditional.R'
+    'dplyr-funcs-datetime.R'
+    'dplyr-funcs-math.R'
+    'dplyr-funcs-string.R'
+    'dplyr-funcs-type.R'
     'expression.R'
-    'dplyr-functions.R'
+    'dplyr-funcs.R'
     'dplyr-group-by.R'
     'dplyr-join.R'
     'dplyr-mutate.R'
19 changes: 4 additions & 15 deletions r/R/arrow-package.R
@@ -56,21 +56,10 @@
     s3_register("reticulate::r_to_py", cl)
   }
 
-  # Create these once, at package build time
-  if (arrow_available()) {
-    # Also include all available Arrow Compute functions,
-    # namespaced as arrow_fun.
-    # We can't do this at install time because list_compute_functions() may error
-    all_arrow_funs <- list_compute_functions()
-    arrow_funcs <- set_names(
-      lapply(all_arrow_funs, function(fun) {
-        force(fun)
-        function(...) build_expr(fun, ...)
-      }),
-      paste0("arrow_", all_arrow_funs)
-    )
-    .cache$functions <- c(nse_funcs, arrow_funcs)
-  }
+  # Create the .cache$functions list at package load time.
+  # We can't do this at build time because list_compute_functions() may error
+  # if arrow_available() is FALSE
+  create_binding_cache()
 
   if (tolower(Sys.info()[["sysname"]]) == "windows") {
     # Disable multithreading on Windows
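
The body of `create_binding_cache()` is not part of this excerpt. As a rough illustration of the pattern that the removed lines used (one prefixed wrapper closure per compute function name), here is a minimal sketch in base R. The helper `make_prefixed_wrappers()` and the `toy_builder()` stand-in for `build_expr()` are hypothetical names introduced only for this example.

```r
# Hypothetical helper: build one closure per function name, named with a prefix
make_prefixed_wrappers <- function(fun_names, builder, prefix = "arrow_") {
  wrappers <- lapply(fun_names, function(fun) {
    # Evaluate `fun` now so each closure captures its own function name
    force(fun)
    function(...) builder(fun, ...)
  })
  stats::setNames(wrappers, paste0(prefix, fun_names))
}

# Toy stand-in for build_expr(): just records what it was asked to build
toy_builder <- function(fun, ...) {
  list(fun = fun, args = list(...))
}

cache <- make_prefixed_wrappers(c("add", "subtract"), toy_builder)
built <- cache$arrow_subtract(5, 3)
```

Running a helper like this at load time rather than build time means the list of function names can be queried from the installed library when the package is attached, which matches the motivation given in the new comment.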
2 changes: 1 addition & 1 deletion r/R/dplyr-collect.R
@@ -113,7 +113,7 @@ implicit_schema <- function(.data) {
     hash <- length(.data$group_by_vars) > 0
     agg_fields <- imap(
       new_fields[setdiff(names(new_fields), .data$group_by_vars)],
-      ~ output_type(.data$aggregations[[.y]][["fun"]], .x, hash)
+      ~ agg_fun_output_type(.data$aggregations[[.y]][["fun"]], .x, hash)
     )
     new_fields <- c(group_fields, agg_fields)
   }
105 changes: 105 additions & 0 deletions r/R/dplyr-funcs-conditional.R
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

register_bindings_conditional <- function() {
  register_binding("coalesce", function(...) {
    args <- list2(...)
    if (length(args) < 1) {
      abort("At least one argument must be supplied to coalesce()")
    }

    # Treat NaN like NA for consistency with dplyr::coalesce(), but if *all*
    # the values are NaN, we should return NaN, not NA, so don't replace
    # NaN with NA in the final (or only) argument
    # TODO: if an option is added to the coalesce kernel to treat NaN as NA,
    # use that to simplify the code here (ARROW-13389)
    attr(args[[length(args)]], "last") <- TRUE
    args <- lapply(args, function(arg) {
      last_arg <- is.null(attr(arg, "last"))
      attr(arg, "last") <- NULL

      if (!inherits(arg, "Expression")) {
        arg <- Expression$scalar(arg)
      }

      if (last_arg && arg$type_id() %in% TYPES_WITH_NAN) {
        # store the NA_real_ in the same type as arg to avoid casting
        # smaller float types to larger float types
        NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type()))
        Expression$create("if_else", Expression$create("is_nan", arg), NA_expr, arg)
      } else {
        arg
      }
    })
    Expression$create("coalesce", args = args)
  })

  if_else_binding <- function(condition, true, false, missing = NULL) {
    if (!is.null(missing)) {
      return(if_else_binding(
        call_binding("is.na", (condition)),
        missing,
        if_else_binding(condition, true, false)
      ))
    }

    build_expr("if_else", condition, true, false)
  }

  register_binding("if_else", if_else_binding)

  # Although base R ifelse allows `yes` and `no` to be different classes
  register_binding("ifelse", function(test, yes, no) {
    if_else_binding(condition = test, true = yes, false = no)
  })

  register_binding("case_when", function(...) {
    formulas <- list2(...)
    n <- length(formulas)
    if (n == 0) {
      abort("No cases provided in case_when()")
    }
    query <- vector("list", n)
    value <- vector("list", n)
    mask <- caller_env()
    for (i in seq_len(n)) {
      f <- formulas[[i]]
      if (!inherits(f, "formula")) {
        abort("Each argument to case_when() must be a two-sided formula")
      }
      query[[i]] <- arrow_eval(f[[2]], mask)
      value[[i]] <- arrow_eval(f[[3]], mask)
      if (!call_binding("is.logical", query[[i]])) {
        abort("Left side of each formula in case_when() must be a logical expression")
      }
      if (inherits(value[[i]], "try-error")) {
        abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]])))
      }
    }
    build_expr(
      "case_when",
      args = c(
        build_expr(
          "make_struct",
          args = query,
          options = list(field_names = as.character(seq_along(query)))
        ),
        value
      )
    )
  })
}
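
The NaN handling described in the `coalesce` comments can be illustrated on plain R vectors. The following is a minimal sketch of the intended semantics only, not the Arrow binding: `coalesce_nan_sketch()` is a hypothetical helper, and it assumes all inputs are numeric vectors of equal length.

```r
# Hypothetical base R illustration of the semantics in the coalesce comments:
# NaN counts as missing in every argument except the last one.
coalesce_nan_sketch <- function(...) {
  args <- list(...)
  n <- length(args)
  if (n < 1) {
    stop("At least one argument must be supplied")
  }

  # Replace NaN with NA everywhere except in the final (or only) argument
  args[-n] <- lapply(args[-n], function(x) {
    x[is.nan(x)] <- NA_real_
    x
  })

  # Fill missing positions from left to right
  out <- args[[1]]
  for (arg in args[-1]) {
    missing <- is.na(out)
    out[missing] <- arg[missing]
  }
  out
}

filled <- coalesce_nan_sketch(c(NaN, 1, NA), c(2, NaN, NaN))
only_arg <- coalesce_nan_sketch(c(NaN, 1))
```

In `filled`, the leading NaN is treated as missing and replaced by 2, while the final position keeps the NaN from the last argument instead of becoming NA. In `only_arg`, the single argument is also the last one, so its NaN is preserved.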
133 changes: 133 additions & 0 deletions r/R/dplyr-funcs-datetime.R
@@ -0,0 +1,133 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

register_bindings_datetime <- function() {
  register_binding("strptime", function(x, format = "%Y-%m-%d %H:%M:%S", tz = NULL,
                                        unit = "ms") {
    # Arrow uses unit for time parsing, strptime() does not.
    # Arrow has no default option for strptime (format, unit),
    # we suggest following format = "%Y-%m-%d %H:%M:%S", unit = MILLI/1L/"ms",
    # (ARROW-12809)

    # ParseTimestampStrptime currently ignores the timezone information (ARROW-12820).
    # Stop if tz is provided.
    if (is.character(tz)) {
      arrow_not_supported("Time zone argument")
    }

    unit <- make_valid_time_unit(unit, c(valid_time64_units, valid_time32_units))

    Expression$create("strptime", x, options = list(format = format, unit = unit))
  })

  register_binding("strftime", function(x, format = "", tz = "", usetz = FALSE) {
    if (usetz) {
      format <- paste(format, "%Z")
    }
    if (tz == "") {
      tz <- Sys.timezone()
    }
    # Arrow's strftime prints in timezone of the timestamp. To match R's strftime behavior we first
    # cast the timestamp to desired timezone. This is a metadata only change.
    if (call_binding("is.POSIXct", x)) {
      ts <- Expression$create("cast", x, options = list(to_type = timestamp(x$type()$unit(), tz)))
    } else {
      ts <- x
    }
    Expression$create("strftime", ts, options = list(format = format, locale = Sys.getlocale("LC_TIME")))
  })

  register_binding("format_ISO8601", function(x, usetz = FALSE, precision = NULL, ...) {
    ISO8601_precision_map <-
      list(
        y = "%Y",
        ym = "%Y-%m",
        ymd = "%Y-%m-%d",
        ymdh = "%Y-%m-%dT%H",
        ymdhm = "%Y-%m-%dT%H:%M",
        ymdhms = "%Y-%m-%dT%H:%M:%S"
      )

    if (is.null(precision)) {
      precision <- "ymdhms"
    }
    if (!precision %in% names(ISO8601_precision_map)) {
      abort(
        paste(
          "`precision` must be one of the following values:",
          paste(names(ISO8601_precision_map), collapse = ", "),
          "\nValue supplied was: ",
          precision
        )
      )
    }
    format <- ISO8601_precision_map[[precision]]
    if (usetz) {
      format <- paste0(format, "%z")
    }
    Expression$create("strftime", x, options = list(format = format, locale = "C"))
  })

  register_binding("second", function(x) {
    Expression$create("add", Expression$create("second", x), Expression$create("subsecond", x))
  })

  register_binding("wday", function(x, label = FALSE, abbr = TRUE,
                                    week_start = getOption("lubridate.week.start", 7),
                                    locale = Sys.getlocale("LC_TIME")) {
    if (label) {
      if (abbr) {
        format <- "%a"
      } else {
        format <- "%A"
      }
      return(Expression$create("strftime", x, options = list(format = format, locale = locale)))
    }

    Expression$create("day_of_week", x, options = list(count_from_zero = FALSE, week_start = week_start))
  })

  register_binding("month", function(x, label = FALSE, abbr = TRUE, locale = Sys.getlocale("LC_TIME")) {
    if (label) {
      if (abbr) {
        format <- "%b"
      } else {
        format <- "%B"
      }
      return(Expression$create("strftime", x, options = list(format = format, locale = locale)))
    }

    Expression$create("month", x)
  })

  register_binding("is.Date", function(x) {
    inherits(x, "Date") ||
      (inherits(x, "Expression") && x$type_id() %in% Type[c("DATE32", "DATE64")])
  })

  is_instant_binding <- function(x) {
    inherits(x, c("POSIXt", "POSIXct", "POSIXlt", "Date")) ||
      (inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP", "DATE32", "DATE64")])
  }
  register_binding("is.instant", is_instant_binding)
  register_binding("is.timepoint", is_instant_binding)

  register_binding("is.POSIXct", function(x) {
    inherits(x, "POSIXct") ||
      (inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP")])
  })
}
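
The precision lookup in the `format_ISO8601` binding can be tried outside of Arrow. Below is a minimal sketch that reuses the same precision map but hands the resulting format string to base R's `format()` instead of an Arrow `strftime` expression. `iso8601_format_sketch()` is a hypothetical helper name, and `stop()` stands in for `abort()` so the example has no dependencies.

```r
# Hypothetical helper: map a lubridate-style precision to a strftime format
iso8601_format_sketch <- function(precision = NULL, usetz = FALSE) {
  precision_map <- list(
    y = "%Y",
    ym = "%Y-%m",
    ymd = "%Y-%m-%d",
    ymdh = "%Y-%m-%dT%H",
    ymdhm = "%Y-%m-%dT%H:%M",
    ymdhms = "%Y-%m-%dT%H:%M:%S"
  )

  # Default to full precision, as the binding does
  if (is.null(precision)) {
    precision <- "ymdhms"
  }
  if (!precision %in% names(precision_map)) {
    stop(
      "`precision` must be one of: ",
      paste(names(precision_map), collapse = ", ")
    )
  }

  fmt <- precision_map[[precision]]
  if (usetz) {
    # Append the numeric UTC offset
    fmt <- paste0(fmt, "%z")
  }
  fmt
}

x <- as.POSIXct("2021-12-08 14:30:15", tz = "UTC")
minute_precision <- format(x, iso8601_format_sketch("ymdhm"), tz = "UTC")
day_precision <- format(x, iso8601_format_sketch("ymd"), tz = "UTC")
```

Keeping the format construction separate from the expression building, as this sketch does, makes the precision validation easy to test on its own.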
83 changes: 83 additions & 0 deletions r/R/dplyr-funcs-math.R
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

register_bindings_math <- function() {
  log_binding <- function(x, base = exp(1)) {
    # like other binary functions, either `x` or `base` can be Expression or double(1)
    if (is.numeric(x) && length(x) == 1) {
      x <- Expression$scalar(x)
    } else if (!inherits(x, "Expression")) {
      arrow_not_supported("x must be a column or a length-1 numeric; other values")
    }

    # handle `base` differently because we use the simpler ln, log2, and log10
    # functions for specific scalar base values
    if (inherits(base, "Expression")) {
      return(Expression$create("logb_checked", x, base))
    }

    if (!is.numeric(base) || length(base) != 1) {
      arrow_not_supported("base must be a column or a length-1 numeric; other values")
    }

    if (base == exp(1)) {
      return(Expression$create("ln_checked", x))
    }

    if (base == 2) {
      return(Expression$create("log2_checked", x))
    }

    if (base == 10) {
      return(Expression$create("log10_checked", x))
    }

    Expression$create("logb_checked", x, Expression$scalar(base))
  }

  register_binding("log", log_binding)
  register_binding("logb", log_binding)

  register_binding("pmin", function(..., na.rm = FALSE) {
    build_expr(
      "min_element_wise",
      ...,
      options = list(skip_nulls = na.rm)
    )
  })

  register_binding("pmax", function(..., na.rm = FALSE) {
    build_expr(
      "max_element_wise",
      ...,
      options = list(skip_nulls = na.rm)
    )
  })

  register_binding("trunc", function(x, ...) {
    # accepts and ignores ... for consistency with base::trunc()
    build_expr("trunc", x)
  })

  register_binding("round", function(x, digits = 0) {
    build_expr(
      "round",
      x,
      options = list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN)
    )
  })
}
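
The base-dispatch logic in `log_binding` can be isolated from the Arrow expression machinery. The sketch below returns only the name of the compute function that the binding would select for a scalar `base`; `pick_log_kernel_sketch()` is a hypothetical helper, and `stop()` stands in for `arrow_not_supported()`.

```r
# Hypothetical helper: choose the compute function name for a scalar log base
pick_log_kernel_sketch <- function(base = exp(1)) {
  if (!is.numeric(base) || length(base) != 1) {
    stop("base must be a length-1 numeric")
  }

  # Prefer the specialised kernels for the common bases
  if (base == exp(1)) {
    return("ln_checked")
  }
  if (base == 2) {
    return("log2_checked")
  }
  if (base == 10) {
    return("log10_checked")
  }

  # Fall back to the general two-argument kernel
  "logb_checked"
}

kernels <- vapply(
  list(exp(1), 2, 10, 3),
  pick_log_kernel_sketch,
  character(1)
)
```

The exact comparison `base == exp(1)` is enough here because the default value is itself computed as `exp(1)`; a caller passing a rounded approximation of e would fall through to the general kernel.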