Skip to content

Commit

Permalink
ARROW-17689: [R] Implement dplyr::across() inside group_by() (apache#…
Browse files Browse the repository at this point in the history
…14122)

Because the handling of the `.add = TRUE` case and of the `add` argument has been changed, test cases for both are also added.

Authored-by: SHIMA Tatsuya <ts1s1andn@gmail.com>
Signed-off-by: Dewey Dunnington <dewey@fishandwhistle.net>
  • Loading branch information
eitsupi authored and zagto committed Oct 7, 2022
1 parent dd284a2 commit ad98203
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 22 deletions.
38 changes: 16 additions & 22 deletions r/R/dplyr-group-by.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,31 @@
# group_by() method for `arrow_dplyr_query` objects.
#
# Converts `.data` with as_adq(), adds any new grouping expressions to the
# query via dplyr::mutate(), then records the grouping variable names in
# `.data$group_by_vars` and the drop setting in `.data$drop_empty_groups`
# before returning the modified `.data`.
#
# Arguments:
#   .data  the object to group; passed through as_adq() first.
#   ...    grouping variables or expressions.
#   .add   if TRUE, new groups are combined with the existing ones
#          (see the `if (.add)` branch below); otherwise they replace them.
#   add    deprecated alias of `.add`; supplying it triggers .Deprecated()
#          and its value is copied into `.add`.
#   .drop  stored in `drop_empty_groups` only when at least one grouping
#          variable results; defaults to dplyr::group_by_drop_default(.data).
#
# NOTE(review): this listing appears to show both the removed and the added
# lines of a diff without +/- markers (two `add =` formals, two assignments to
# `.data$group_by_vars`, unbalanced braces), so it is not runnable exactly as
# shown -- confirm against the real file before relying on any single line.
group_by.arrow_dplyr_query <- function(.data,
...,
.add = FALSE,
# NOTE(review): the next two lines both declare `add`; presumably one is the
# old default and one the new -- TODO confirm which side of the diff each is.
add = .add,
add = NULL,
.drop = dplyr::group_by_drop_default(.data)) {
# `add` is deprecated: warn when the caller supplied it, then honour its value
# by copying it into `.add`.
if (!missing(add)) {
.Deprecated(
msg = paste("The `add` argument of `group_by()` is deprecated. Please use the `.add` argument instead.")
)
.add <- add
}

.data <- as_adq(.data)
new_groups <- enquos(...)
# ... can contain expressions (i.e. can add (or rename?) columns) and so we
# need to identify those and add them on to the query with mutate. Specifically,
# we want to mark as new:
# * expressions (named or otherwise)
# * variables that have new names
# All others (i.e. simple references to variables) should not be (re)-added
# Presumably expand_across() expands across() calls in `...` and
# ensure_named_exprs() gives every resulting expression a name; both helpers
# are defined elsewhere -- verify there.
expression_list <- expand_across(.data, quos(...))
new_groups <- ensure_named_exprs(expression_list)

# Identify any groups with names which aren't in names of .data
new_group_ind <- map_lgl(new_groups, ~ !(quo_name(.x) %in% names(.data)))
# Identify any groups which were given an explicit (non-empty) name
named_group_ind <- map_lgl(names(new_groups), nzchar)
# Retain any new groups identified above
new_groups <- new_groups[new_group_ind | named_group_ind]
if (length(new_groups)) {
# now either use the name that was given in ... or if that is "" then use the expr
names(new_groups) <- imap_chr(new_groups, ~ ifelse(.y == "", quo_name(.x), .y))

# Add them to the data
.data <- dplyr::mutate(.data, !!!new_groups)
}
if (".add" %in% names(formals(dplyr::group_by))) {
# For compatibility with dplyr >= 1.0
gv <- dplyr::group_by_prepare(.data, ..., .add = .add)$group_names

# With `.add = TRUE` the new group names are appended to the existing grouping
# variables (duplicates removed by union()); otherwise they replace them.
if (.add) {
gv <- union(dplyr::group_vars(.data), names(new_groups))
} else {
gv <- dplyr::group_by_prepare(.data, ..., add = add)$group_names
gv <- names(new_groups)
}
.data$group_by_vars <- gv

# Store an empty character vector rather than NULL when there are no groups.
.data$group_by_vars <- gv %||% character()
# `.drop` is only used when there is at least one grouping variable; with no
# groups the dplyr default for `.data` is stored instead.
.data$drop_empty_groups <- ifelse(length(gv), .drop, dplyr::group_by_drop_default(.data))
.data
}
Expand Down
110 changes: 110 additions & 0 deletions r/tests/testthat/test-dplyr-group-by.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,113 @@ test_that("group_by() with namespaced functions", {
tbl
)
})

test_that("group_by() with .add", {
  # Regroup an already-grouped input without naming any new variables,
  # once with `.add = FALSE` and once with `.add = TRUE`.
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(.add = FALSE) %>%
      collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(.add = TRUE) %>%
      collect(),
    tbl
  )

  # Same, but the second group_by() names an additional variable.
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(chr, .add = FALSE) %>%
      collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(chr, .add = TRUE) %>%
      collect(),
    tbl
  )

  # Here the grouping by `dbl2` is already present on the input data itself.
  tbl_grouped_by_dbl2 <- group_by(tbl, dbl2)
  compare_dplyr_binding(
    .input %>%
      group_by(chr, .add = FALSE) %>%
      collect(),
    tbl_grouped_by_dbl2
  )
  compare_dplyr_binding(
    .input %>%
      group_by(chr, .add = TRUE) %>%
      collect(),
    tbl_grouped_by_dbl2
  )

  # The `add` spelling is checked against the "deprecated" warning pattern.
  suppressWarnings(compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(add = FALSE) %>%
      collect(),
    tbl,
    warning = "deprecated"
  ))
  suppressWarnings(compare_dplyr_binding(
    .input %>%
      group_by(dbl2) %>%
      group_by(add = TRUE) %>%
      collect(),
    tbl,
    warning = "deprecated"
  ))

  # Check the full deprecation message, and the error raised when `add` is
  # given a bare column name, directly on an Arrow table.
  arrow_tbl <- arrow_table(tbl)
  expect_warning(
    arrow_tbl %>%
      group_by(add = TRUE) %>%
      collect(),
    "The `add` argument of `group_by\\(\\)` is deprecated"
  )
  expect_error(
    suppressWarnings(
      arrow_tbl %>%
        group_by(add = dbl2) %>%
        collect()
    ),
    "object 'dbl2' not found"
  )
})

test_that("Can use across() within group_by()", {
  # Column names held in a local character vector, used in the third
  # comparison below.
  grouping_columns <- c("dbl", "int", "chr")

  # across() with no selection.
  compare_dplyr_binding(
    .input %>%
      group_by(across()) %>%
      collect(),
    tbl
  )

  # across() with a tidyselect helper.
  compare_dplyr_binding(
    .input %>%
      group_by(across(starts_with("d"))) %>%
      collect(),
    tbl
  )

  # across() with an embraced external character vector.
  compare_dplyr_binding(
    .input %>%
      group_by(across({{ grouping_columns }})) %>%
      collect(),
    tbl
  )

  # ARROW-12778 - `where()` is not yet supported
  expect_error(
    compare_dplyr_binding(
      .input %>%
        group_by(across(where(is.numeric))) %>%
        collect(),
      tbl
    ),
    "Unsupported selection helper"
  )
})

0 comments on commit ad98203

Please sign in to comment.