ARROW-15010: [R] Create a function registry for our NSE funcs #11904

Closed
49 commits
a71dee6
make nse_funcs and agg_funcs environments instead of lists
paleolimbot Dec 8, 2021
47dba00
move translation cache generation to dplyr-functions
paleolimbot Dec 8, 2021
5a69943
use register_translation() for .array_function_map registration
paleolimbot Dec 8, 2021
403af23
group string functions
paleolimbot Dec 8, 2021
382dcb8
move strptime and strftime tests to test-dplyr-funcs-datetime
paleolimbot Dec 8, 2021
7afdc06
move pmin, pmax, trunc, and round to be next to the other math transl…
paleolimbot Dec 8, 2021
07c0d89
group datetime translations
paleolimbot Dec 8, 2021
d308884
better name for 'output_type()'
paleolimbot Dec 8, 2021
09c2a64
group aggregate functions
paleolimbot Dec 8, 2021
9d576c9
rename dplyr-functions to dplyr-funcs to match names of test files
paleolimbot Dec 8, 2021
c7e9770
regsiter translations at load time and not at build time
paleolimbot Dec 8, 2021
0d15ddd
remove debugging
paleolimbot Dec 8, 2021
d834b3d
separate conditional and string translations
paleolimbot Dec 8, 2021
247297f
separate datetime and type funcs
paleolimbot Dec 8, 2021
dac53f0
separate math translations
paleolimbot Dec 8, 2021
78e1d08
use Apache 2.0 file header
paleolimbot Dec 8, 2021
b8ad455
move aggregate functions to dplyr-summarise (because they're tested i…
paleolimbot Dec 8, 2021
6bc82df
move array map registration to where the functions are defined
paleolimbot Dec 8, 2021
d5b149f
move documentation from comments into internal documentation for regi…
paleolimbot Dec 8, 2021
2c45de9
don't use links in internal docs
paleolimbot Dec 8, 2021
04f380b
fix undefined variable error from CMD check
paleolimbot Dec 8, 2021
312d836
test register_translation()
paleolimbot Dec 8, 2021
ec6897c
`nse_funcs$fun <- ` -> `register_translation("fun", ` for dplyr-funcs…
paleolimbot Dec 8, 2021
dfb6e77
use register_translation for dplyr-funcs-conditional.R
paleolimbot Dec 8, 2021
3b017d7
use_register_translation in dplyr-funcs-datetime.R
paleolimbot Dec 8, 2021
ed5835d
use register_translation
paleolimbot Dec 8, 2021
328373a
use_register_translation for dplyr-funcs-string.R
paleolimbot Dec 8, 2021
a2e399f
remove more references to nse_funcs
paleolimbot Dec 8, 2021
c5fbdf1
remove more references to `nse_funcs`
paleolimbot Dec 8, 2021
24eabdb
remove more references to nse_funcs
paleolimbot Dec 8, 2021
76b4301
remove references to agg_funcs
paleolimbot Dec 8, 2021
d02ad94
try to reduce the complexity of the string register function for the …
paleolimbot Dec 8, 2021
298e69a
split type functions into several chunks to satisfy the linter
paleolimbot Dec 9, 2021
16baf91
fix lint on dplyr-funcs-type.R
paleolimbot Dec 9, 2021
344b986
fix arrowExports.cpp
paleolimbot Dec 9, 2021
dc581f1
_translation -> _binding
paleolimbot Dec 13, 2021
3d3c56a
more translation -> binding
paleolimbot Dec 14, 2021
5ec8e41
Update r/R/dplyr-funcs.R
paleolimbot Dec 22, 2021
1ddd9e1
Update r/R/dplyr-funcs.R
paleolimbot Dec 22, 2021
bf510e7
redocument
paleolimbot Dec 22, 2021
8d1ffa6
remove translation_registry() functions
paleolimbot Dec 22, 2021
9e24621
register_bindings_X instead of register_X_bindings
paleolimbot Dec 22, 2021
c01f5d8
Update r/R/dplyr-funcs-datetime.R
paleolimbot Dec 30, 2021
5b8d6c9
Update r/R/dplyr-funcs-conditional.R
paleolimbot Dec 30, 2021
2637fe7
Update r/R/dplyr-funcs-math.R
paleolimbot Dec 30, 2021
db7cb6d
Update r/R/dplyr-funcs-type.R
paleolimbot Dec 30, 2021
19d1f74
Update r/R/dplyr-funcs-type.R
paleolimbot Dec 30, 2021
12e96d2
Update r/R/dplyr-summarize.R
paleolimbot Dec 30, 2021
54286eb
Oops, add in missing register_bindings_type_cast()
jonkeane Jan 3, 2022
7 changes: 6 additions & 1 deletion r/DESCRIPTION
@@ -90,8 +90,13 @@ Collate:
  'dplyr-distinct.R'
  'dplyr-eval.R'
  'dplyr-filter.R'
+ 'dplyr-funcs-conditional.R'
+ 'dplyr-funcs-datetime.R'
+ 'dplyr-funcs-math.R'
+ 'dplyr-funcs-string.R'
+ 'dplyr-funcs-type.R'
  'expression.R'
- 'dplyr-functions.R'
+ 'dplyr-funcs.R'
  'dplyr-group-by.R'
  'dplyr-join.R'
  'dplyr-mutate.R'
19 changes: 4 additions & 15 deletions r/R/arrow-package.R
@@ -56,21 +56,10 @@
  s3_register("reticulate::r_to_py", cl)
  }

- # Create these once, at package build time
- if (arrow_available()) {
- # Also include all available Arrow Compute functions,
- # namespaced as arrow_fun.
- # We can't do this at install time because list_compute_functions() may error
- all_arrow_funs <- list_compute_functions()
- arrow_funcs <- set_names(
- lapply(all_arrow_funs, function(fun) {
- force(fun)
- function(...) build_expr(fun, ...)
- }),
- paste0("arrow_", all_arrow_funs)
- )
- .cache$functions <- c(nse_funcs, arrow_funcs)
- }
+ # Create the .cache$functions list at package load time.
+ # We can't do this at build time because list_compute_functions() may error
+ # if arrow_available() is FALSE
+ create_binding_cache()

  if (tolower(Sys.info()[["sysname"]]) == "windows") {
  # Disable multithreading on Windows
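The change above replaces a list built at package build time with a registry that is snapshotted into `.cache$functions` at load time. As a rough illustration of that pattern only, here is a minimal base-R sketch. The names `binding_registry`, `binding_cache`, `register_binding_sketch()` and `create_binding_cache_sketch()` are hypothetical and are not the arrow package's actual implementation; in particular, the real `create_binding_cache()` also adds the `arrow_`-prefixed compute functions, which this sketch omits.

```r
# Hypothetical, simplified registry; NOT the arrow package's actual code.
binding_registry <- new.env(parent = emptyenv())
binding_cache <- new.env(parent = emptyenv())

# Store `fun` under `name`; passing NULL removes an existing binding.
# The previous binding (or NULL) is returned invisibly.
register_binding_sketch <- function(name, fun, registry = binding_registry) {
  previous <- NULL
  if (exists(name, envir = registry, inherits = FALSE)) {
    previous <- get(name, envir = registry, inherits = FALSE)
  }
  if (is.null(fun)) {
    if (!is.null(previous)) rm(list = name, envir = registry)
  } else {
    assign(name, fun, envir = registry)
  }
  invisible(previous)
}

# Snapshot the registry into a plain list, as would happen at load time.
create_binding_cache_sketch <- function(registry = binding_registry) {
  binding_cache$functions <- as.list(registry, all.names = TRUE)
  invisible(binding_cache$functions)
}

register_binding_sketch("double_it", function(x) x * 2)
create_binding_cache_sketch()
binding_cache$functions$double_it(21)
```

Using an environment rather than a list means registration can happen from several files without each one having to reassign a shared object.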
2 changes: 1 addition & 1 deletion r/R/dplyr-collect.R
@@ -113,7 +113,7 @@ implicit_schema <- function(.data) {
  hash <- length(.data$group_by_vars) > 0
  agg_fields <- imap(
  new_fields[setdiff(names(new_fields), .data$group_by_vars)],
- ~ output_type(.data$aggregations[[.y]][["fun"]], .x, hash)
+ ~ agg_fun_output_type(.data$aggregations[[.y]][["fun"]], .x, hash)
  )
  new_fields <- c(group_fields, agg_fields)
  }
105 changes: 105 additions & 0 deletions r/R/dplyr-funcs-conditional.R
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

register_bindings_conditional <- function() {
register_binding("coalesce", function(...) {
args <- list2(...)
if (length(args) < 1) {
abort("At least one argument must be supplied to coalesce()")
}

# Treat NaN like NA for consistency with dplyr::coalesce(), but if *all*
# the values are NaN, we should return NaN, not NA, so don't replace
# NaN with NA in the final (or only) argument
# TODO: if an option is added to the coalesce kernel to treat NaN as NA,
# use that to simplify the code here (ARROW-13389)
attr(args[[length(args)]], "last") <- TRUE
args <- lapply(args, function(arg) {
# TRUE for every argument except the one tagged "last" above
not_last_arg <- is.null(attr(arg, "last"))
attr(arg, "last") <- NULL

if (!inherits(arg, "Expression")) {
arg <- Expression$scalar(arg)
}

if (not_last_arg && arg$type_id() %in% TYPES_WITH_NAN) {
# store the NA_real_ in the same type as arg to avoid casting
# smaller float types to larger float types
NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type()))
Expression$create("if_else", Expression$create("is_nan", arg), NA_expr, arg)
} else {
arg
}
})
Expression$create("coalesce", args = args)
})

if_else_binding <- function(condition, true, false, missing = NULL) {
if (!is.null(missing)) {
return(if_else_binding(
call_binding("is.na", (condition)),
missing,
if_else_binding(condition, true, false)
))
}

build_expr("if_else", condition, true, false)
}

register_binding("if_else", if_else_binding)

# Although base R ifelse() allows `yes` and `no` to be different classes,
# this binding simply forwards to if_else_binding()
register_binding("ifelse", function(test, yes, no) {
if_else_binding(condition = test, true = yes, false = no)
})

register_binding("case_when", function(...) {
formulas <- list2(...)
n <- length(formulas)
if (n == 0) {
abort("No cases provided in case_when()")
}
query <- vector("list", n)
value <- vector("list", n)
mask <- caller_env()
for (i in seq_len(n)) {
f <- formulas[[i]]
if (!inherits(f, "formula")) {
abort("Each argument to case_when() must be a two-sided formula")
}
query[[i]] <- arrow_eval(f[[2]], mask)
value[[i]] <- arrow_eval(f[[3]], mask)
if (!call_binding("is.logical", query[[i]])) {
abort("Left side of each formula in case_when() must be a logical expression")
}
if (inherits(value[[i]], "try-error")) {
abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]])))
}
}
build_expr(
"case_when",
args = c(
build_expr(
"make_struct",
args = query,
options = list(field_names = as.character(seq_along(query)))
),
value
)
)
})
}
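The NaN rule described in the `coalesce` binding's comments (treat NaN like NA everywhere except in the final argument) can be illustrated on plain numeric vectors. This is a base-R sketch only; `coalesce_nan_sketch()` is a hypothetical helper that does not build Arrow Expressions and is not part of the PR.

```r
# Base-R illustration of the NaN rule on plain numeric vectors of equal
# length. Hypothetical helper, not the arrow binding itself.
coalesce_nan_sketch <- function(...) {
  args <- list(...)
  n <- length(args)
  stopifnot(n >= 1)

  # Replace NaN with NA in every argument except the last one.
  # is.nan() is FALSE for NA, so only true NaN values are touched.
  if (n > 1) {
    args[-n] <- lapply(args[-n], function(arg) {
      arg[is.nan(arg)] <- NA_real_
      arg
    })
  }

  # Fill missing positions from each later argument in turn.
  out <- args[[1]]
  for (arg in args[-1]) {
    missing <- is.na(out)
    out[missing] <- arg[missing]
  }
  out
}

res <- coalesce_nan_sketch(c(1, NaN, NA, NaN), c(NaN, 2, NA, NaN))
print(res)
```

In the last position both inputs are NaN, so the result keeps NaN rather than NA, which is the case the source comment calls out.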
133 changes: 133 additions & 0 deletions r/R/dplyr-funcs-datetime.R
@@ -0,0 +1,133 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

register_bindings_datetime <- function() {
register_binding("strptime", function(x, format = "%Y-%m-%d %H:%M:%S", tz = NULL,
unit = "ms") {
# Arrow's strptime kernel takes a time unit; base::strptime() does not.
# The kernel has no defaults for format or unit, so this binding supplies
# format = "%Y-%m-%d %H:%M:%S" and unit = "ms" (equivalently MILLI or 1L)
# (ARROW-12809)

# ParseTimestampStrptime currently ignores the timezone information (ARROW-12820).
# Stop if tz is provided.
if (is.character(tz)) {
arrow_not_supported("Time zone argument")
}

unit <- make_valid_time_unit(unit, c(valid_time64_units, valid_time32_units))

Expression$create("strptime", x, options = list(format = format, unit = unit))
})

register_binding("strftime", function(x, format = "", tz = "", usetz = FALSE) {
if (usetz) {
format <- paste(format, "%Z")
}
if (tz == "") {
tz <- Sys.timezone()
}
# Arrow's strftime prints in the timezone of the timestamp. To match R's strftime behavior we first
# cast the timestamp to the desired timezone. This is a metadata-only change.
if (call_binding("is.POSIXct", x)) {
ts <- Expression$create("cast", x, options = list(to_type = timestamp(x$type()$unit(), tz)))
} else {
ts <- x
}
Expression$create("strftime", ts, options = list(format = format, locale = Sys.getlocale("LC_TIME")))
})

register_binding("format_ISO8601", function(x, usetz = FALSE, precision = NULL, ...) {
ISO8601_precision_map <-
list(
y = "%Y",
ym = "%Y-%m",
ymd = "%Y-%m-%d",
ymdh = "%Y-%m-%dT%H",
ymdhm = "%Y-%m-%dT%H:%M",
ymdhms = "%Y-%m-%dT%H:%M:%S"
)

if (is.null(precision)) {
precision <- "ymdhms"
}
if (!precision %in% names(ISO8601_precision_map)) {
abort(
paste(
"`precision` must be one of the following values:",
paste(names(ISO8601_precision_map), collapse = ", "),
"\nValue supplied was: ",
precision
)
)
}
format <- ISO8601_precision_map[[precision]]
if (usetz) {
format <- paste0(format, "%z")
}
Expression$create("strftime", x, options = list(format = format, locale = "C"))
})

register_binding("second", function(x) {
Expression$create("add", Expression$create("second", x), Expression$create("subsecond", x))
})

register_binding("wday", function(x, label = FALSE, abbr = TRUE,
week_start = getOption("lubridate.week.start", 7),
locale = Sys.getlocale("LC_TIME")) {
if (label) {
if (abbr) {
format <- "%a"
} else {
format <- "%A"
}
return(Expression$create("strftime", x, options = list(format = format, locale = locale)))
}

Expression$create("day_of_week", x, options = list(count_from_zero = FALSE, week_start = week_start))
})

register_binding("month", function(x, label = FALSE, abbr = TRUE, locale = Sys.getlocale("LC_TIME")) {
if (label) {
if (abbr) {
format <- "%b"
} else {
format <- "%B"
}
return(Expression$create("strftime", x, options = list(format = format, locale = locale)))
}

Expression$create("month", x)
})

register_binding("is.Date", function(x) {
inherits(x, "Date") ||
(inherits(x, "Expression") && x$type_id() %in% Type[c("DATE32", "DATE64")])
})

is_instant_binding <- function(x) {
inherits(x, c("POSIXt", "POSIXct", "POSIXlt", "Date")) ||
(inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP", "DATE32", "DATE64")])
}
register_binding("is.instant", is_instant_binding)
register_binding("is.timepoint", is_instant_binding)

register_binding("is.POSIXct", function(x) {
inherits(x, "POSIXct") ||
(inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP")])
})
}
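The `format_ISO8601` binding maps a `precision` code to a strftime format string and optionally appends `%z`. The same mapping can be tried out with base R's `format()` on a `POSIXct` value. This is a sketch under assumptions: `format_iso8601_sketch()` is a hypothetical helper, it uses `stop()` rather than `abort()`, and it pins the output timezone to UTC so the result does not depend on the local setting.

```r
# Precision codes and format strings copied from the binding above.
iso8601_precision_map <- c(
  y = "%Y",
  ym = "%Y-%m",
  ymd = "%Y-%m-%d",
  ymdh = "%Y-%m-%dT%H",
  ymdhm = "%Y-%m-%dT%H:%M",
  ymdhms = "%Y-%m-%dT%H:%M:%S"
)

# Hypothetical base-R stand-in for the binding; formats in UTC (an assumption).
format_iso8601_sketch <- function(x, usetz = FALSE, precision = "ymdhms") {
  if (!precision %in% names(iso8601_precision_map)) {
    stop(
      "`precision` must be one of: ",
      paste(names(iso8601_precision_map), collapse = ", ")
    )
  }
  fmt <- iso8601_precision_map[[precision]]
  if (usetz) {
    fmt <- paste0(fmt, "%z")
  }
  format(x, format = fmt, tz = "UTC")
}

x <- as.POSIXct("2021-12-08 14:30:15", tz = "UTC")
format_iso8601_sketch(x, precision = "ymdh")
format_iso8601_sketch(x, usetz = TRUE)
```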
83 changes: 83 additions & 0 deletions r/R/dplyr-funcs-math.R
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

register_bindings_math <- function() {
log_binding <- function(x, base = exp(1)) {
# like other binary functions, either `x` or `base` can be Expression or double(1)
if (is.numeric(x) && length(x) == 1) {
x <- Expression$scalar(x)
} else if (!inherits(x, "Expression")) {
arrow_not_supported("x must be a column or a length-1 numeric; other values")
}

# handle `base` differently because we use the simpler ln, log2, and log10
# functions for specific scalar base values
if (inherits(base, "Expression")) {
return(Expression$create("logb_checked", x, base))
}

if (!is.numeric(base) || length(base) != 1) {
arrow_not_supported("base must be a column or a length-1 numeric; other values")
}

if (base == exp(1)) {
return(Expression$create("ln_checked", x))
}

if (base == 2) {
return(Expression$create("log2_checked", x))
}

if (base == 10) {
return(Expression$create("log10_checked", x))
}

Expression$create("logb_checked", x, Expression$scalar(base))
}

register_binding("log", log_binding)
register_binding("logb", log_binding)

register_binding("pmin", function(..., na.rm = FALSE) {
build_expr(
"min_element_wise",
...,
options = list(skip_nulls = na.rm)
)
})

register_binding("pmax", function(..., na.rm = FALSE) {
build_expr(
"max_element_wise",
...,
options = list(skip_nulls = na.rm)
)
})

register_binding("trunc", function(x, ...) {
# accepts and ignores ... for consistency with base::trunc()
build_expr("trunc", x)
})

register_binding("round", function(x, digits = 0) {
build_expr(
"round",
x,
options = list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN)
)
})
}
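The `log_binding()` function above chooses a specialised kernel when `base` is a scalar equal to e, 2, or 10, and falls back to `logb_checked` otherwise. The selection logic can be isolated as below. `log_kernel_sketch()` is a hypothetical helper that only returns the kernel name; it does not build an Expression and omits the case where `base` is itself an Expression.

```r
# Hypothetical helper mirroring the scalar-base branch of log_binding().
# Returns the name of the compute function that would be used.
log_kernel_sketch <- function(base = exp(1)) {
  if (!is.numeric(base) || length(base) != 1) {
    stop("base must be a length-1 numeric in this sketch")
  }
  if (base == exp(1)) {
    return("ln_checked")
  }
  if (base == 2) {
    return("log2_checked")
  }
  if (base == 10) {
    return("log10_checked")
  }
  # Any other scalar base uses the general two-argument kernel
  "logb_checked"
}

log_kernel_sketch()
log_kernel_sketch(2)
log_kernel_sketch(3)
```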