diff --git a/NEWS.md b/NEWS.md index afbf154d71..5faf40723f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -173,6 +173,40 @@ 29. `setkey()` now supports type `raw` as value columns (not as key columns), [#5100](https://github.com/Rdatatable/data.table/issues/5100). Thanks Hugh Parsonage for requesting, and Benjamin Schwendinger for the PR. +30. `shift()` is now optimised by group, [#1534](https://github.com/Rdatatable/data.table/issues/1534). Thanks to Gerhard Nachtmann for requesting, and Benjamin Schwendinger for the PR. + + ```R + N = 1e7 + DT = data.table(x=sample(N), y=sample(1e6,N,TRUE)) + shift_no_opt = shift # an alias under a different name is not optimised, so it serves as the unoptimised baseline + microbenchmark( + DT[, c(NA, head(x,-1)), y], + DT[, shift_no_opt(x, 1, type="lag"), y], + DT[, shift(x, 1, type="lag"), y], + times=10L, unit="s") + # Unit: seconds + # expr min lq mean median uq max neval + # DT[, c(NA, head(x, -1)), y] 8.7620 9.0240 9.1870 9.2800 9.3700 9.4110 10 + # DT[, shift_no_opt(x, 1, type = "lag"), y] 20.5500 20.9000 21.1600 21.3200 21.4400 21.5200 10 + # DT[, shift(x, 1, type = "lag"), y] 0.4865 0.5238 0.5463 0.5446 0.5725 0.5982 10 + ``` + + Example from [Stack Overflow](https://stackoverflow.com/questions/35179911/shift-in-data-table-v1-9-6-is-slow-for-many-groups): + ```R + set.seed(1) + mg = data.table(expand.grid(year=2012:2016, id=1:1000), + value=rnorm(5000)) + microbenchmark(v1.9.4 = mg[, c(value[-1], NA), by=id], + v1.9.6 = mg[, shift_no_opt(value, n=1, type="lead"), by=id], + v1.14.4 = mg[, shift(value, n=1, type="lead"), by=id], + unit="ms") + # Unit: milliseconds + # expr min lq mean median uq max neval + # v1.9.4 3.6600 3.8250 4.4930 4.1720 4.9490 11.700 100 + # v1.9.6 18.5400 19.1800 21.5100 20.6900 23.4200 29.040 100 + # v1.14.4 0.4826 0.5586 0.6586 0.6329 0.7348 1.318 100 + ``` + ## BUG FIXES 1. 
`by=.EACHI` when `i` is keyed but `on=` different columns than `i`'s key could create an invalidly keyed result, [#4603](https://github.com/Rdatatable/data.table/issues/4603) [#4911](https://github.com/Rdatatable/data.table/issues/4911). Thanks to @myoung3 and @adamaltmejd for reporting, and @ColeMiller1 for the PR. An invalid key is where a `data.table` is marked as sorted by the key columns but the data is not sorted by those columns, leading to incorrect results from subsequent queries. diff --git a/R/data.table.R b/R/data.table.R index b8c1132f6a..e020ea3e3d 100644 --- a/R/data.table.R +++ b/R/data.table.R @@ -1745,6 +1745,10 @@ replace_dot_alias = function(e) { if (!(is.call(q) && is.symbol(q[[1L]]) && is.symbol(q[[2L]]) && (q1 <- q[[1L]]) %chin% gfuns)) return(FALSE) if (!(q2 <- q[[2L]]) %chin% names(SDenv$.SDall) && q2 != ".I") return(FALSE) # 875 if ((length(q)==2L || (!is.null(names(q)) && startsWith(names(q)[3L], "na")))) return(TRUE) + if (length(q)>=2L && q[[1L]] == "shift") { + q_named = match.call(shift, q) + if (!is.call(q_named[["fill"]]) && is.null(q_named[["give.names"]])) return(TRUE) + } # add gshift support # ^^ base::startWith errors on NULL unfortunately # head-tail uses default value n=6 which as of now should not go gforce ... 
^^ # otherwise there must be three arguments, and only in two cases: @@ -1848,6 +1852,17 @@ replace_dot_alias = function(e) { gi = if (length(o__)) o__[f__] else f__ g = lapply(grpcols, function(i) groups[[i]][gi]) + # returns all rows instead of one per group + nrow_funs = c("gshift") + .is_nrows = function(q) { + if (!is.call(q)) return(FALSE) + if (q[[1L]] == "list") { + any(vapply(q, .is_nrows, FALSE)) + } else { + q[[1L]] %chin% nrow_funs + } + } + # adding ghead/gtail(n) support for n > 1 #5060 #523 q3 = 0 if (!is.symbol(jsub)) { @@ -1865,6 +1880,8 @@ replace_dot_alias = function(e) { if (q3 > 0) { grplens = pmin.int(q3, len__) g = lapply(g, rep.int, times=grplens) + } else if (.is_nrows(jsub)) { + g = lapply(g, rep.int, times=len__) } ans = c(g, ans) } else { @@ -2970,7 +2987,7 @@ rleidv = function(x, cols=seq_along(x), prefix=NULL) { # (2) edit .gforce_ok (defined within `[`) to catch which j will apply the new function # (3) define the gfun = function() R wrapper gfuns = c("[", "[[", "head", "tail", "first", "last", "sum", "mean", "prod", - "median", "min", "max", "var", "sd", ".N") # added .N for #334 + "median", "min", "max", "var", "sd", ".N", "shift") # added .N for #334 `g[` = `g[[` = function(x, n) .Call(Cgnthvalue, x, as.integer(n)) # n is of length=1 here. 
ghead = function(x, n) .Call(Cghead, x, as.integer(n)) # n is not used at the moment gtail = function(x, n) .Call(Cgtail, x, as.integer(n)) # n is not used at the moment @@ -2984,6 +3001,11 @@ gmin = function(x, na.rm=FALSE) .Call(Cgmin, x, na.rm) gmax = function(x, na.rm=FALSE) .Call(Cgmax, x, na.rm) gvar = function(x, na.rm=FALSE) .Call(Cgvar, x, na.rm) gsd = function(x, na.rm=FALSE) .Call(Cgsd, x, na.rm) +gshift = function(x, n=1L, fill=NA, type=c("lag", "lead", "shift", "cyclic")) { + type = match.arg(type) + stopifnot(is.numeric(n)) + .Call(Cgshift, x, as.integer(n), fill, type) +} gforce = function(env, jsub, o, f, l, rows) .Call(Cgforce, env, jsub, o, f, l, rows) .prepareFastSubset = function(isub, x, enclos, notjoin, verbose = FALSE){ diff --git a/inst/tests/test2224.Rdata b/inst/tests/test2224.Rdata new file mode 100644 index 0000000000..9c6423b9fb Binary files /dev/null and b/inst/tests/test2224.Rdata differ diff --git a/inst/tests/tests.Rraw b/inst/tests/tests.Rraw index 26f531455b..6382a13a85 100644 --- a/inst/tests/tests.Rraw +++ b/inst/tests/tests.Rraw @@ -18243,21 +18243,39 @@ test(2217, DT1[, by = grp, .(agg = list(setNames(as.numeric(value), id)))], DT2) testnum = 2218 funs = c(as.integer, as.double, as.complex, as.character, if (test_bit64) as.integer64) # when test_bit64==FALSE these all passed before; now passes with test_bit64==TRUE too +# add grouping tests for #5205 +g = rep(c(1,2), each=2) +options(datatable.optimize = 2L) for (f1 in funs) { - DT = data.table(x=f1(1:4)) + DT = data.table(x=f1(1:4), g=g) for (f2 in funs) { - testnum = testnum + 0.01 + testnum = testnum + 0.001 test(testnum, DT[, shift(x)], f1(c(NA, 1:3))) - testnum = testnum + 0.01 + testnum = testnum + 0.001 w = if (identical(f2,as.character) && !identical(f1,as.character)) "Coercing.*character.*to match the type of target vector" test(testnum, DT[, shift(x, fill=f2(NA))], f1(c(NA, 1:3)), warning=w) - testnum = testnum + 0.01 + testnum = testnum + 0.001 if 
(identical(f1,as.character) && identical(f2,as.complex)) { # one special case due to as.complex(0)=="0+0i"!="0" test(testnum, DT[, shift(x, fill="0")], f1(0:3)) } else { test(testnum, DT[, shift(x, fill=f2(0))], f1(0:3), warning=w) } + + testnum = testnum + 0.001 + test(testnum, DT[, shift(x), by=g], data.table(g=g, V1=f1(c(NA, 1, NA, 3)))) + testnum = testnum + 0.001 + w = if (identical(f2,as.character) && !identical(f1,as.character)) "Coercing.*character.*to match the type of target vector" + f = f2(NA) + test(testnum, DT[, shift(x, fill=f), by=g], data.table(g=g, V1=f1(c(NA, 1, NA, 3))), warning=w) + testnum = testnum + 0.001 + if (identical(f1,as.character) && identical(f2,as.complex)) { + # one special case due to as.complex(0)=="0+0i"!="0" + test(testnum, DT[, shift(x, fill="0"), by=g], data.table(g=g, V1=f1(c(0,1,0,3)))) + } else { + f = f2(0) + test(testnum, DT[, shift(x, fill=f), by=g], data.table(g=g, V1=f1(c(0,1,0,3))), warning=w) + } } } @@ -18292,6 +18310,41 @@ DT = data.table(A=1:3, key="A") test(2223.1, DT[.(4), nomatch=FALSE], data.table(A=integer(), key="A")) test(2223.2, DT[.(4), nomatch=NA_character_], data.table(A=4L, key="A")) +# gshift, #5205 +options(datatable.optimize = 2L) +set.seed(123) +DT = data.table(x = sample(letters[1:5], 20, TRUE), + y = rep.int(1:2, 10), # to test 2 grouping columns get rep'd properly + i = sample(c(-2L,0L,3L,NA), 20, TRUE), + d = sample(c(1.2,-3.4,5.6,NA), 20, TRUE), + s = sample(c("foo","bar",NA), 20, TRUE), + c = sample(c(0+3i,1,-1-1i,NA), 20, TRUE), + l = sample(c(TRUE, FALSE, NA), 20, TRUE), + r = as.raw(sample(1:5, 20, TRUE))) +load(testDir("test2224.Rdata")) # ans array +if (test_bit64) { + DT[, i64:=as.integer64(sample(c(-2L,0L,2L,NA), 20, TRUE))] +} else { + ans = ans[, -match("i64",colnames(ans))] +} +test(2224.01, sapply(names(DT)[-1], function(col) { + sapply(list(1, 5, -1, -5, c(1,2), c(-1,1)), function(n) list( + # fill is tested by group in tests 2218.*; see comments in #5205 + EVAL(sprintf("DT[, 
shift(%s, %d, type='lag'), by=x]$V1", col, n)), + EVAL(sprintf("DT[, shift(%s, %d, type='lead'), by=x]$V1", col, n)), + EVAL(sprintf("DT[, shift(%s, %d, type='shift'), by=x]$V1", col, n)), + EVAL(sprintf("DT[, shift(%s, %d, type='cyclic'), by=x]$V1", col, n)) + )) +}), ans) +a = 1:2 # fill argument with length > 1 which is not a call +test(2224.02, DT[, shift(i, fill=a), by=x], error="fill must be a vector of length 1") +DT = data.table(x=pairlist(1), g=1) +# unsupported type as argument +test(2224.03, DT[, shift(x), g], error="Type 'list' is not supported by GForce gshift.") + # groupingsets by named by argument -test(2224.1, groupingsets(data.table(iris), j = sum(Sepal.Length), by = c('Sp'='Species'), sets = list('Species')), data.table(Species = factor(c("setosa", "versicolor", "virginica")), V1=c(250.3, 296.8, 329.4))) -test(2224.2, groupingsets(data.table(iris), j = mean(Sepal.Length), by = c('Sp'='Species'), sets = list('Species')), groupingsets(data.table(iris), j = mean(Sepal.Length), by = c('Species'), sets = list('Species'))) +test(2225.1, groupingsets(data.table(iris), j=sum(Sepal.Length), by=c('Sp'='Species'), sets=list('Species')), + data.table(Species=factor(c("setosa", "versicolor", "virginica")), V1=c(250.3, 296.8, 329.4))) +test(2225.2, groupingsets(data.table(iris), j=mean(Sepal.Length), by=c('Sp'='Species'), sets=list('Species')), + groupingsets(data.table(iris), j=mean(Sepal.Length), by=c('Species'), sets=list('Species'))) + diff --git a/src/gsumm.c b/src/gsumm.c index 5bb2620243..4964de8b6e 100644 --- a/src/gsumm.c +++ b/src/gsumm.c @@ -1162,3 +1162,93 @@ SEXP gprod(SEXP x, SEXP narmArg) { return(ans); } +SEXP gshift(SEXP x, SEXP nArg, SEXP fillArg, SEXP typeArg) { + const bool nosubset = irowslen == -1; + const bool issorted = !isunsorted; + const int n = nosubset ? 
length(x) : irowslen; + if (nrow != n) error(_("Internal error: nrow [%d] != length(x) [%d] in %s"), nrow, n, "gshift"); + + int nprotect=0; + enum {LAG, LEAD/*, SHIFT*/,CYCLIC} stype = LAG; + if (!(length(fillArg) == 1)) + error(_("fill must be a vector of length 1")); + + if (!isString(typeArg) || length(typeArg) != 1) + error(_("Internal error: invalid type for gshift(), should have been caught before. please report to data.table issue tracker")); // # nocov + if (!strcmp(CHAR(STRING_ELT(typeArg, 0)), "lag")) stype = LAG; + else if (!strcmp(CHAR(STRING_ELT(typeArg, 0)), "lead")) stype = LEAD; + else if (!strcmp(CHAR(STRING_ELT(typeArg, 0)), "shift")) stype = LAG; + else if (!strcmp(CHAR(STRING_ELT(typeArg, 0)), "cyclic")) stype = CYCLIC; + else error(_("Internal error: invalid type for gshift(), should have been caught before. please report to data.table issue tracker")); // # nocov + + bool lag; + const bool cycle = stype == CYCLIC; + + R_xlen_t nx = xlength(x), nk = length(nArg); + if (!isInteger(nArg)) error(_("Internal error: n must be integer")); // # nocov + const int *kd = INTEGER(nArg); + for (int i=0; i grpn -> jend = jstart */ \ + if (lag) { \ + const int o = ff[i]-1+(grpn-thisn); \ + for (int j=0; j