From 1cb018f6070ed93a566f148e3f3cabf8b3ee655d Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Tue, 6 Sep 2022 15:29:35 -0500 Subject: [PATCH 1/2] Use tidyverse recycling rules Fixes #878 --- NEWS.md | 4 +++ R/modify.R | 24 +++++++----------- R/pmap.R | 3 +++ R/utils.R | 10 -------- src/map.c | 39 ++++++++++-------------------- tests/testthat/_snaps/map2.md | 10 ++++++++ tests/testthat/_snaps/pmap.md | 16 ++++++++++++ tests/testthat/test-map2.R | 34 ++++++-------------------- tests/testthat/test-pmap.R | 29 +++++++++++----------- tests/testthat/test-recycle_args.R | 23 ------------------ 10 files changed, 77 insertions(+), 115 deletions(-) create mode 100644 tests/testthat/_snaps/map2.md create mode 100644 tests/testthat/_snaps/pmap.md delete mode 100644 tests/testthat/test-recycle_args.R diff --git a/NEWS.md b/NEWS.md index 5ab561c9..c2e16708 100644 --- a/NEWS.md +++ b/NEWS.md @@ -38,6 +38,10 @@ ## Features and fixes +* `map2()`, `modify2()`, and `pmap()` now use tidyverse recycling rules where + vectors of length 1 are recycled to any size but all others must have + the same length (#878). + * `as_mapper()` is now around twice as fast when used with character, integer, or list (#820). diff --git a/R/modify.R b/R/modify.R index 3abf3014..225f74ba 100644 --- a/R/modify.R +++ b/R/modify.R @@ -329,18 +329,9 @@ modify2 <- function(.x, .y, .f, ...) { } #' @export modify2.default <- function(.x, .y, .f, ...) { - .f <- as_mapper(.f, ...) - - args <- recycle_args(list(.x, .y)) - .x <- args[[1]] - .y <- args[[2]] - - for (i in seq_along(.x)) { - list_slice2(.x, i) <- .f(.x[[i]], .y[[i]], ...) - } - - .x + modify_base(map2, .x, .y, .f, ...) } + #' @rdname modify #' @export imodify <- function(.x, .f, ...) { @@ -366,11 +357,14 @@ modify2.logical <- function(.x, .y, .f, ...) { } modify_base <- function(mapper, .x, .y, .f, ...) { - args <- recycle_args(list(.x, .y)) - .x <- args[[1]] - .y <- args[[2]] + .f <- as_mapper(.f, ...) + out <- mapper(.x, .y, .f, ...) - .x[] <- mapper(.x, .y, .f, ...) + # if .x got recycled by map2 + if (length(out) > length(.x)) { + .x <- .x[rep(1L, length(out))] + } + .x[] <- out .x } diff --git a/R/pmap.R b/R/pmap.R index f762b6c2..da659970 100644 --- a/R/pmap.R +++ b/R/pmap.R @@ -10,6 +10,9 @@ #' arguments that `.f` will be called with. Arguments will be supply by #' position if unnamed, and by name if named. #' +#' Vectors of length 1 will be recycled to any length; all other elements +#' must be have the same length. +#' #' A data frame is an important special case of `.l`. It will cause `.f` #' to be called once for each row. #' @inheritParams map diff --git a/R/utils.R b/R/utils.R index d92851c1..2a7cf84a 100644 --- a/R/utils.R +++ b/R/utils.R @@ -18,16 +18,6 @@ at_selection <- function(nm, .at){ .at } -recycle_args <- function(args) { - lengths <- map_int(args, length) - n <- max(lengths) - - stopifnot(all(lengths == 1L | lengths == n)) - to_recycle <- lengths == 1L - args[to_recycle] <- lapply(args[to_recycle], function(x) rep.int(x, n)) - args -} - #' Infix attribute accessor #' #' @description diff --git a/src/map.c b/src/map.c index 6437b7ea..2946bec9 100644 --- a/src/map.c +++ b/src/map.c @@ -111,13 +111,7 @@ SEXP map2_impl(SEXP env, SEXP x_name_, SEXP y_name_, SEXP f_name_, SEXP type_) { check_vector(y_val, ".y"); int nx = Rf_length(x_val), ny = Rf_length(y_val); - if (nx == 0 || ny == 0) { - SEXP out = PROTECT(Rf_allocVector(type, 0)); - copy_names(x_val, out); - UNPROTECT(3); - return out; - } - if (nx != ny && !(nx == 1 || ny == 1)) { + if (nx != ny && nx != 1 && ny != 1) { Rf_errorcall(R_NilValue, "Mapped vectors must have consistent lengths:\n" "* `.x` has length %d\n" @@ -125,7 +119,7 @@ SEXP map2_impl(SEXP env, SEXP x_name_, SEXP y_name_, SEXP f_name_, SEXP type_) { nx, ny); } - int n = (nx > ny) ? nx : ny; + int n = (nx == 1) ? ny : nx; // Constructs a call like f(x[[i]], y[[i]], ...) SEXP one = PROTECT(Rf_ScalarInteger(1)); @@ -150,9 +144,10 @@ SEXP pmap_impl(SEXP env, SEXP l_name_, SEXP f_name_, SEXP type_) { stop_bad_type(l_val, "a list", NULL, l_name); } - // Check all elements are lists and find maximum length + // Check all elements are lists and find recycled length int m = Rf_length(l_val); - int n = 0; + int has_scalar = 0; + int n = -1; for (int j = 0; j < m; ++j) { SEXP j_val = VECTOR_ELT(l_val, j); @@ -161,28 +156,20 @@ SEXP pmap_impl(SEXP env, SEXP l_name_, SEXP f_name_, SEXP type_) { } int nj = Rf_length(j_val); - - if (nj == 0) { - SEXP out = PROTECT(Rf_allocVector(type, 0)); - copy_names(j_val, out); - UNPROTECT(2); - return out; + if (nj == 1) { + has_scalar = 1; + continue; } - if (nj > n) { + if (n == -1) { n = nj; + } else if (nj != n) { + stop_bad_element_length(j_val, j + 1, n, NULL, ".l", true); } - } - // Check length of all elements - for (int j = 0; j < m; ++j) { - SEXP j_val = VECTOR_ELT(l_val, j); - int nj = Rf_length(j_val); - - if (nj != 1 && nj != n) { - stop_bad_element_length(j_val, j + 1, n, NULL, ".l", true); - } + if (n == -1) { + n = has_scalar ? 1 : 0; } SEXP l_names = PROTECT(Rf_getAttrib(l_val, R_NamesSymbol)); diff --git a/tests/testthat/_snaps/map2.md b/tests/testthat/_snaps/map2.md new file mode 100644 index 00000000..bb313d7d --- /dev/null +++ b/tests/testthat/_snaps/map2.md @@ -0,0 +1,10 @@ +# map2 recycles inputs + + Code + map2(1:2, 1:3, `+`) + Condition + Error: + ! Mapped vectors must have consistent lengths: + * `.x` has length 2 + * `.y` has length 3 + diff --git a/tests/testthat/_snaps/pmap.md b/tests/testthat/_snaps/pmap.md new file mode 100644 index 00000000..96e62154 --- /dev/null +++ b/tests/testthat/_snaps/pmap.md @@ -0,0 +1,16 @@ +# inputs are recycled + + Code + pmap(list(1:2, 1:3), identity) + Condition + Error in `stop_bad_length()`: + ! Element 2 of `.l` must have length 1 or 2, not 3 + +--- + + Code + pmap(list(1:2, integer()), identity) + Condition + Error in `stop_bad_length()`: + ! Element 2 of `.l` must have length 1 or 2, not 0 + diff --git a/tests/testthat/test-map2.R b/tests/testthat/test-map2.R index b2c85bac..db7ae5ad 100644 --- a/tests/testthat/test-map2.R +++ b/tests/testthat/test-map2.R @@ -1,14 +1,3 @@ -test_that("map2 inputs must be same length", { - expect_error( - map2(1:3, 2:3, function(...) NULL), - paste_line( - "Mapped vectors must have consistent lengths:", - "\\* `.x` has length 3", - "\\* `.y` has length 2" - ) - ) -}) - test_that("map2 can't simplify if elements longer than length 1", { expect_bad_element_vector_error( map2_int(1:4, 5:8, range), @@ -21,19 +10,14 @@ test_that("fails on non-vectors", { expect_bad_type_error(map2("a", environment(), identity), "`.y` must be a vector, not an environment") }) -test_that("map2 vectorised inputs of length 1", { - expect_equal(map2(1:2, 1, `+`), list(2, 3)) - expect_equal(map2(1, 1:2, `+`), list(2, 3)) -}) +test_that("map2 recycles inputs", { + expect_equal(map2(1, 1, `+`), list(2)) -test_that("any 0 length input gives 0 length output", { - expect_equal(map2(list(), list(), ~ 1), list()) - expect_equal(map2(1:10, list(), ~ 1), list()) - expect_equal(map2(list(), 1:10, ~ 1), list()) + expect_equal(map2(1:2, 1, `+`), list(2, 3)) + expect_equal(map2(integer(), 1, `+`), list()) + expect_equal(map2(NULL, 1, `+`), list()) - expect_equal(map2(NULL, NULL, ~ 1), list()) - expect_equal(map2(1:10, NULL, ~ 1), list()) - expect_equal(map2(NULL, 1:10, ~ 1), list()) + expect_snapshot(map2(1:2, 1:3, `+`), error = TRUE) }) test_that("map2 takes only names from x", { @@ -57,13 +41,9 @@ test_that("map2() with empty input copies names", { expect_identical(map2_chr(named_list, list(), identity), named(chr())) }) -test_that("map2() and pmap() recycle names (#779)", { +test_that("map2() recycle names (#779)", { expect_identical( map2(c(a = 1), 1:2, ~ .x), list(a = 1, a = 1) ) - expect_identical( - pmap(list(c(a = 1), 1:2), ~ .x), - list(a = 1, a = 1) - ) }) diff --git a/tests/testthat/test-pmap.R b/tests/testthat/test-pmap.R index dcb65ccc..671a52a0 100644 --- a/tests/testthat/test-pmap.R +++ b/tests/testthat/test-pmap.R @@ -3,21 +3,15 @@ test_that("input must be a list of vectors", { expect_bad_type_error(pmap(list(environment()), identity), "Element 1 of `.l` must be a vector, not an environment") }) -test_that("elements must be same length", { - expect_bad_element_length_error(pmap(list(1:2, 1:3), identity), "Element 1 of `.l` must have length 1 or 3, not 2") -}) - -test_that("handles any length 0 input", { - expect_equal(pmap(list(list(), list(), list()), ~ 1), list()) - expect_equal(pmap(list(NULL, NULL, NULL), ~ 1), list()) +test_that("inputs are recycled", { + expect_equal(pmap(list(1, 1), c), list(c(1, 1))) + expect_equal(pmap(list(1:2, 1), c), list(c(1, 1), c(2, 1))) - expect_equal(pmap(list(list(), list(), 1:10), ~ 1), list()) - expect_equal(pmap(list(NULL, NULL, 1:10), ~ 1), list()) -}) + expect_equal(pmap(list(list(), 1), ~ 1), list()) + expect_equal(pmap(list(NULL, 1), ~ 1), list()) -test_that("length 1 elemetns are recycled", { - out <- pmap(list(1:2, 1), c) - expect_equal(out, list(c(1, 1), c(2, 1))) + expect_snapshot(pmap(list(1:2, 1:3), identity), error = TRUE) + expect_snapshot(pmap(list(1:2, integer()), identity), error = TRUE) }) test_that(".f called with named arguments", { @@ -30,6 +24,13 @@ test_that("names are preserved", { expect_equal(names(out), c("x", "y")) }) +test_that("pmap() recycles names (#779)", { + expect_identical( + pmap(list(c(a = 1), 1:2), ~ .x), + list(a = 1, a = 1) + ) +}) + test_that("... are passed on", { out <- pmap(list(x = 1:2), list, n = 1) expect_equal(out, list( @@ -66,7 +67,7 @@ test_that("pmap on data frames performs rowwise operations", { }) test_that("pmap works with empty lists", { - expect_identical(pmap(list(), identity), list()) + expect_identical(pmap(list(), ~ 1), list()) }) test_that("preserves S3 class of input vectors (#358)", { diff --git a/tests/testthat/test-recycle_args.R b/tests/testthat/test-recycle_args.R deleted file mode 100644 index 8beee0a7..00000000 --- a/tests/testthat/test-recycle_args.R +++ /dev/null @@ -1,23 +0,0 @@ -test_that("rejects uneven lengths", { - args <- list(1, c(1:2), NULL) - expect_error(purrr:::recycle_args(args), "lengths == 1L \\| lengths == n") -}) - - -test_that("recycles single values and preserves longer ones", { - args <- list(1, 1:12, month.name, "a") - recycled <- purrr:::recycle_args(args) - - expect_equal(recycled[[1]], rep(1, 12)) - expect_equal(recycled[[2]], 1:12) - expect_equal(recycled[[3]], month.name) - expect_equal(recycled[[4]], rep("a", 12)) -}) - -test_that("will not recycle non-vectors", { - args <- list(1:12, identity) - expect_error( - purrr:::recycle_args(args), - "replicate an object of type 'closure'" - ) -}) From baf101d6420e0cc35f0eb9c0e6208c8c1608392e Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 7 Sep 2022 13:19:09 -0500 Subject: [PATCH 2/2] Re-document --- man/pmap.Rd | 3 +++ 1 file changed, 3 insertions(+) diff --git a/man/pmap.Rd b/man/pmap.Rd index 08fe5cc2..b68e024b 100644 --- a/man/pmap.Rd +++ b/man/pmap.Rd @@ -33,6 +33,9 @@ pwalk(.l, .f, ...) arguments that \code{.f} will be called with. Arguments will be supply by position if unnamed, and by name if named. +Vectors of length 1 will be recycled to any length; all other elements +must be have the same length. + A data frame is an important special case of \code{.l}. It will cause \code{.f} to be called once for each row.}