diff --git a/R/bmerge.R b/R/bmerge.R index 3d6ab028f3..69e9c89e02 100644 --- a/R/bmerge.R +++ b/R/bmerge.R @@ -1,6 +1,23 @@ bmerge = function(i, x, icols, xcols, roll, rollends, nomatch, mult, ops, verbose) { + if (length(icols)==1L && length(xcols)==1L && is.integer(i[[icols]]) && is.integer(x[[xcols]]) ## single column integer + && isTRUE(getOption("datatable.smerge")) ## enable option + && identical(nomatch, NA_integer_) ## for now only outer join + && identical(ops, 1L) ## equi join + && identical(roll, 0) && identical(rollends, c(FALSE, TRUE)) ## non-rolling join + ) { + getIdxGrp = function(x, cols) { ## get index only if retGrp=T + if (!isTRUE(getOption("datatable.use.index"))) return() + if (is.numeric(cols)) cols = names(x)[cols] + idx = attr(attr(x, "index", exact=TRUE), paste0("__", cols, collapse=""), exact=TRUE) + if (!is.null(attr(idx, "starts", exact=TRUE))) idx + } + if (verbose) {last.started.at=proc.time();cat("Starting smerge ...\n");flush.console()} + ans = smerge(x=i[[icols]], y=x[[xcols]], x.idx=getIdxGrp(i, icols), y.idx=getIdxGrp(x, xcols), mult=mult, out.bmerge=TRUE) + if (verbose) {cat("smerge done in",timetaken(last.started.at),"\n"); flush.console()} + return(ans) + } callersi = i i = shallow(i) # Just before the call to bmerge() in [.data.table there is a shallow() copy of i to prevent coercions here @@ -130,7 +147,7 @@ bmerge = function(i, x, icols, xcols, roll, rollends, nomatch, mult, ops, verbos } else { xo = NULL if (isTRUE(getOption("datatable.use.index"))) { - xo = getindex(x, names(x)[xcols]) + xo = c(getindex(x, names(x)[xcols])) ## c takes care of future #4386 if (verbose && !is.null(xo)) cat("on= matches existing index, using index\n") } if (is.null(xo)) { @@ -180,9 +197,7 @@ bmerge = function(i, x, icols, xcols, roll, rollends, nomatch, mult, ops, verbos if (verbose) {last.started.at=proc.time();cat("Starting bmerge ...\n");flush.console()} ans = .Call(Cbmerge, i, x, as.integer(icols), as.integer(xcols), io, xo, roll, rollends, nomatch, mult, ops, nqgrp, nqmaxgrp) if (verbose) {cat("bmerge done in",timetaken(last.started.at),"\n"); flush.console()} - # TO DO: xo could be moved inside Cbmerge - ans$xo = xo # for further use by [.data.table return(ans) } diff --git a/R/data.table.R b/R/data.table.R index 75b6b290ed..27e27dec2d 100644 --- a/R/data.table.R +++ b/R/data.table.R @@ -513,6 +513,7 @@ replace_dot_alias = function(e) { } # TODO: when nomatch=NA, len__ need not be allocated / set at all for mult="first"/"last"? # TODO: how about when nomatch=0L, can we avoid allocating then as well? + # if we take nomatch out from [b|s]merge then it should be easier to avoid allocation } if (length(xo) && length(irows)) { irows = xo[irows] # TO DO: fsort here? diff --git a/R/wrappers.R b/R/wrappers.R index 5fec33a92f..636e6e04ee 100644 --- a/R/wrappers.R +++ b/R/wrappers.R @@ -12,3 +12,5 @@ colnamesInt = function(x, cols, check_dups=FALSE) .Call(CcolnamesInt, x, cols, c coerceFill = function(x) .Call(CcoerceFillR, x) testMsg = function(status=0L, nx=2L, nk=2L) .Call(CtestMsgR, as.integer(status)[1L], as.integer(nx)[1L], as.integer(nk)[1L]) + +smerge = function(x, y, x.idx=NULL, y.idx=NULL, mult=c("all","first","last","error"), out.bmerge=FALSE) .Call(CsmergeR, x, y, x.idx, y.idx, match.arg(mult), out.bmerge) diff --git a/inst/tests/smerge.Rraw b/inst/tests/smerge.Rraw new file mode 100644 index 0000000000..3a3bf9812a --- /dev/null +++ b/inst/tests/smerge.Rraw @@ -0,0 +1,237 @@ +require(methods) + +if (exists("test.data.table", .GlobalEnv, inherits=FALSE)) { + if ((tt<-compiler::enableJIT(-1))>0) + cat("This is dev mode and JIT is enabled (level ", tt, ") so there will be a brief pause around the first test.\n", sep="") +} else { + require(data.table) + test = data.table:::test + smerge = data.table:::smerge + bmerge = data.table:::bmerge + forderv = data.table:::forderv + vecseq = data.table:::vecseq +} + +bm = function(x, y, mult="all") { + stopifnot(is.integer(x), is.integer(y)) + ans = bmerge(data.table(x=x), data.table(y=y), 1L, 1L, roll=0, rollends=c(FALSE, TRUE), nomatch=NA_integer_, mult=mult, ops=1L, verbose=FALSE) + ## if undefining SMERGE_STATS then we have to ignore allLen1 as well + ans$nMatch = as.numeric(sum(!is.na(vecseq(ans$starts, ans$lens, NULL)))) + ans +} +sm = function(x, y, mult="all") { + stopifnot(is.integer(x), is.integer(y)) + ans = smerge(x, y, mult=mult, out.bmerge=TRUE) + ## if undefining SMERGE_STATS then we have to ignore allLen1 as well + ans$nMatch = smerge(x, y, mult=mult, out.bmerge=FALSE)$nMatch + ans +} + +#setDTthreads(2) +#options(datatable.verbose=TRUE) + +# unique and sort +## x y sorted +x = c(1L,2L,3L,4L) # unq +y = c(2L,3L,5L) # unq +test(1.01, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,3L,4L) +y = c(2L,3L,5L) # unq +test(1.02, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,4L) # unq +y = c(2L,3L,3L,5L) +test(1.03, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,3L,4L) +y = c(2L,3L,3L,5L) +test(1.04, sm(x, y), bm(x, y)) +## y unsorted +x = c(1L,2L,3L,4L) # unq +y = c(2L,5L,3L) # unq +test(2.01, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,3L,4L) +y = c(3L,2L,5L) # unq +test(2.02, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,4L) # unq +y = c(5L,3L,2L,3L) +test(2.03, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,3L,4L) +y = c(5L,3L,3L,2L) +test(2.04, sm(x, y), bm(x, y)) +## x unsorted +x = c(2L,3L,1L,4L) # unq +y = c(2L,3L,5L) # unq +test(3.01, sm(x, y), bm(x, y)) +x = c(1L,3L,2L,4L,3L) +y = c(2L,3L,5L) # unq +test(3.02, sm(x, y), bm(x, y)) +x = c(4L,2L,3L,1L) +y = c(2L,3L,3L,5L) # unq +test(3.03, sm(x, y), bm(x, y)) +x = c(1L,2L,4L,3L,3L) +y = c(2L,3L,3L,5L) +test(3.04, sm(x, y), bm(x, y)) +## xy unsorted +x = c(4L,1L,3L,2L) # unq +y = c(2L,5L,3L) +test(4.01, sm(x, y), bm(x, y)) +x = c(1L,3L,2L,4L,3L) +y = c(5L,3L,2L) # unq +test(4.02, sm(x, y), bm(x, y)) +x = c(4L,2L,3L,1L) +y = c(3L,3L,2L,5L) # unq +test(4.03, sm(x, y), bm(x, y)) +x = c(1L,2L,4L,3L,3L) +y = c(5L,2L,3L,3L) +test(4.04, sm(x, y), bm(x, y)) + +# ties +x = c(1L,2L,3L,4L,5L) +y = c(2L,4L) # within +test(5.01, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,4L,5L) +y = c(-1L,2L,4L) # left tie +test(5.02, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,4L,5L) +y = c(2L,4L,7L) # right tie +test(5.03, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,4L,5L) +y = c(-1L,2L,4L,6L) # both ties +test(5.04, sm(x, y), bm(x, y)) + +# nomatch +x = c(1L,3L,5L) +y = c(2L,4L) # within nomatch +test(6.01, sm(x, y), bm(x, y)) +x = c(1L,2L,3L,4L,5L) +y = c(-1L,6L) # ties nomatch +test(6.02, sm(x, y), bm(x, y)) +x = c(1L,2L,2L,2L,5L) +y = c(2L,4L) # x duplicates single match +test(6.03, sm(x, y), bm(x, y)) +x = c(1L,2L,2L,2L,3L,3L,4L,5L) +y = c(2L,4L) # x duplicates multi, single match +test(6.04, sm(x, y), bm(x, y)) +x = c(1L,2L,2L,2L,3L,4L,4L,5L) +y = c(2L,4L) # x duplicates multi, multi match +test(6.05, sm(x, y), bm(x, y)) +x = c(1L,2L,2L,2L,5L) +y = c(-1L,6L) # x duplicates nomatch +test(6.06, sm(x, y), bm(x, y)) + +# skew +N = 2e3L +x = seq_len(N) +y = c(head(x), tail(x)) +test(7.01, sm(x, y), bm(x, y)) +y = c(1:6, 750L, 1250L, 1995:2000) +test(7.02, sm(x, y), bm(x, y)) + +# custom cases +x=c(39L, 41L, 41L, 37L, 86L, 93L, 20L, 34L, 38L, 21L, 79L, 84L, +2L, 80L, 51L, 58L, 66L, 33L, 32L, 22L, 24L, 4L, 67L, 59L, 89L, +1L, 44L, 62L, 34L, 18L, 93L, 67L, 22L, 42L, 8L, 72L, 45L, 87L, +41L, 85L, 30L, 61L, 5L, 45L, 48L, 41L, 57L, 63L, 68L, 96L, 72L, +62L, 14L, 84L, 57L, 43L, 6L, 49L, 33L, 68L, 2L, 18L, 69L, 41L, +2L, 52L, 69L, 94L, 56L, 72L, 13L, 50L, 86L, 81L, 8L, 28L, 96L, +28L, 87L, 28L, 1L, 27L, 60L, 61L, 99L, 19L, 39L, 99L, 67L, 70L, +53L, 86L, 64L, 49L, 99L, 91L, 36L, 7L, 57L, 63L) +y=c(44L, 50L, 47L, 26L, 44L, 11L, 18L, 60L, 9L, 96L, 25L, 59L, +53L, 82L, 4L, 41L, 65L, 30L, 29L, 34L, 29L, 23L, 12L, 40L, 76L, +40L, 30L, 29L, 98L, 2L, 57L, 13L, 44L, 68L, 72L, 82L, 19L, 88L, +19L, 95L, 22L, 46L, 43L, 36L, 67L, 96L, 34L, 6L, 16L, 20L, 86L, +65L, 89L, 78L, 36L, 95L, 19L, 67L, 65L, 99L, 59L, 77L, 16L, 50L, +99L, 98L, 72L, 26L, 35L, 46L, 52L, 55L, 56L, 1L, 91L, 21L, 52L, +69L, 7L, 87L, 97L, 97L, 71L, 48L, 6L, 35L, 62L, 26L, 44L, 36L, +50L, 75L, 100L, 63L, 39L, 3L, 94L, 85L, 99L, 61L) +test(8.01, sm(x, y), bm(x, y)) + +# scale up +ssa = function(unq_n, size, sort=FALSE) { + if (unq_n > size) return(sample.int(unq_n, size)) + unq_sub = seq_len(unq_n) + ans = sample(c(unq_sub, sample(unq_sub, size=max(size-unq_n, 0), replace=TRUE))) + if (sort) sort(ans) else ans +} +set.seed(108) +N = 1e4L +## xy sorted +x = ssa(N, N, sort=TRUE) # unq +y = ssa(N, N, sort=TRUE) # unq +test(11.01, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1, sort=TRUE) +y = ssa(N, N, sort=TRUE) # unq +test(11.02, sm(x, y), bm(x, y)) +x = ssa(N, N, sort=TRUE) # unq +y = ssa(N, N*1.1, sort=TRUE) +test(11.03, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1, sort=TRUE) +y = ssa(N, N*1.1, sort=TRUE) +test(11.04, sm(x, y), bm(x, y)) +## y unsorted +x = ssa(N, N, sort=TRUE) # unq +y = ssa(N, N) # unq +test(12.01, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1, sort=TRUE) +y = ssa(N, N) # unq +test(12.02, sm(x, y), bm(x, y)) +x = ssa(N, N, sort=TRUE) # unq +y = ssa(N, N*1.1) +test(12.03, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1, sort=TRUE) +y = ssa(N, N*1.1) +test(12.04, sm(x, y), bm(x, y)) +## x unsorted +x = ssa(N, N) # unq +y = ssa(N, N, sort=TRUE) # unq +test(13.01, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1) +y = ssa(N, N, sort=TRUE) # unq +test(13.02, sm(x, y), bm(x, y)) +x = ssa(N, N) # unq +y = ssa(N, N*1.1, sort=TRUE) +test(13.03, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1) +y = ssa(N, N*1.1, sort=TRUE) +test(13.04, sm(x, y), bm(x, y)) +## xy unsorted +x = ssa(N, N) # unq +y = ssa(N, N) # unq +test(14.01, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1) +y = ssa(N, N) # unq +test(14.02, sm(x, y), bm(x, y)) +x = ssa(N, N) # unq +y = ssa(N, N*1.1) +test(14.03, sm(x, y), bm(x, y)) +x = ssa(N, N*1.1) +y = ssa(N, N*1.1) +test(14.04, sm(x, y), bm(x, y)) + +# sparse +x = sample.int(2e2L, 1e2L) +y = sample.int(2e2L, 1e2L) +test(21.01, sm(x, y), bm(x, y)) +x = sample.int(2e2L, 1e2L, TRUE) +y = sample.int(2e2L, 1e2L, TRUE) +test(21.02, sm(x, y), bm(x, y)) + +# [.data.table join +d1 = data.table(x=sample.int(2e2L, 1e2L, TRUE), v1=seq_along(x)) +d2 = data.table(y=sample.int(2e2L, 1e2L, TRUE), v2=seq_along(y)) +options(datatable.smerge=FALSE, datatable.verbose=TRUE) ## verbose=2L after #4491 +test(101.01, expected <- d1[d2, on="x==y"], output="bmerge", notOutput="smerge") +options(datatable.smerge=TRUE) +test(101.02, d1[d2, on="x==y"], expected, output="smerge") ## for now extra computation of bmerge is still done so no: #, notOutput="bmerge") +setindexv2 = function(x, cols) { ## pretend we are after #4386 + stopifnot(is.data.table(x), is.character(cols)) + if (is.null(attr(x, "index", TRUE))) setattr(x, "index", integer()) + setattr(attr(x, "index", TRUE), paste0("__", cols, collapse="__"), forderv(x, cols, retGrp=TRUE)) + invisible(x) +} +options(datatable.verbose=FALSE) +setindexv2(d1, "x"); setindexv2(d2, "y") +options(datatable.use.index=TRUE, datatable.verbose=TRUE) +test(101.03, d1[d2, on="x==y"], expected, output="smerge.*already indexed") +options(datatable.use.index=FALSE) +test(101.04, d1[d2, on="x==y"], expected, output="smerge", notOutput="already indexed") +options(datatable.use.index=TRUE) diff --git a/src/bmerge.c b/src/bmerge.c index 5273ae59b9..cd6d2b307d 100644 --- a/src/bmerge.c +++ b/src/bmerge.c @@ -175,18 +175,20 @@ SEXP bmerge(SEXP iArg, SEXP xArg, SEXP icolsArg, SEXP xcolsArg, SEXP isorted, SE memcpy(INTEGER(retLengthArg), retLength, sizeof(int)*ctr); memcpy(INTEGER(retIndexArg), retIndex, sizeof(int)*ctr); } - SEXP ans = PROTECT(allocVector(VECSXP, 5)); protecti++; - SEXP ansnames = PROTECT(allocVector(STRSXP, 5)); protecti++; + SEXP ans = PROTECT(allocVector(VECSXP, 6)); protecti++; + SEXP ansnames = PROTECT(allocVector(STRSXP, 6)); protecti++; SET_VECTOR_ELT(ans, 0, retFirstArg); SET_VECTOR_ELT(ans, 1, retLengthArg); SET_VECTOR_ELT(ans, 2, retIndexArg); SET_VECTOR_ELT(ans, 3, allLen1Arg); SET_VECTOR_ELT(ans, 4, allGrp1Arg); + SET_VECTOR_ELT(ans, 5, xoArg); SET_STRING_ELT(ansnames, 0, char_starts); // changed from mkChar to char_ to pass the grep in CRAN_Release.cmd SET_STRING_ELT(ansnames, 1, char_lens); SET_STRING_ELT(ansnames, 2, char_indices); SET_STRING_ELT(ansnames, 3, char_allLen1); SET_STRING_ELT(ansnames, 4, char_allGrp1); + SET_STRING_ELT(ansnames, 5, char_xo); setAttrib(ans, R_NamesSymbol, ansnames); if (nqmaxgrp > 1 && mult == ALL) { Free(retFirst); diff --git a/src/data.table.h b/src/data.table.h index 1cf975e68b..231f8f28b4 100644 --- a/src/data.table.h +++ b/src/data.table.h @@ -79,6 +79,11 @@ extern SEXP char_lens; extern SEXP char_indices; extern SEXP char_allLen1; extern SEXP char_allGrp1; +extern SEXP char_xo; +extern SEXP char_io; +extern SEXP char_lhsLen1; +extern SEXP char_xyLen1; +extern SEXP char_nMatch; extern SEXP char_factor; extern SEXP char_ordered; extern SEXP char_datatable; @@ -243,3 +248,6 @@ SEXP testMsgR(SEXP status, SEXP x, SEXP k); //fifelse.c SEXP fifelseR(SEXP l, SEXP a, SEXP b, SEXP na); SEXP fcaseR(SEXP na, SEXP rho, SEXP args); + +// smjoin.c +SEXP smergeR(SEXP x, SEXP y, SEXP x_idx, SEXP y_idx, SEXP multArg, SEXP out_bmerge); diff --git a/src/init.c b/src/init.c index 916db3ab57..a47cfd99c7 100644 --- a/src/init.c +++ b/src/init.c @@ -15,6 +15,11 @@ SEXP char_lens; SEXP char_indices; SEXP char_allLen1; SEXP char_allGrp1; +SEXP char_xo; +SEXP char_io; +SEXP char_lhsLen1; +SEXP char_xyLen1; +SEXP char_nMatch; SEXP char_factor; SEXP char_ordered; SEXP char_datatable; @@ -119,6 +124,7 @@ SEXP lock(); SEXP unlock(); SEXP islockedR(); SEXP allNAR(); +SEXP smergeR(); // .Externals SEXP fastmean(); @@ -211,6 +217,7 @@ R_CallMethodDef callMethods[] = { {"CfrollapplyR", (DL_FUNC) &frollapplyR, -1}, {"CtestMsgR", (DL_FUNC) &testMsgR, -1}, {"C_allNAR", (DL_FUNC) &allNAR, -1}, +{"CsmergeR", (DL_FUNC) &smergeR, -1}, {NULL, NULL, 0} }; @@ -317,6 +324,11 @@ void attribute_visible R_init_datatable(DllInfo *info) char_indices = PRINTNAME(install("indices")); char_allLen1 = PRINTNAME(install("allLen1")); char_allGrp1 = PRINTNAME(install("allGrp1")); + char_xo = PRINTNAME(install("xo")); + char_io = PRINTNAME(install("io")); + char_lhsLen1 = PRINTNAME(install("lhsLen1")); + char_xyLen1 = PRINTNAME(install("xyLen1")); + char_nMatch = PRINTNAME(install("nMatch")); char_factor = PRINTNAME(install("factor")); char_ordered = PRINTNAME(install("ordered")); char_datatable = PRINTNAME(install("data.table")); diff --git a/src/smerge.c b/src/smerge.c new file mode 100644 index 0000000000..8918546e88 --- /dev/null +++ b/src/smerge.c @@ -0,0 +1,646 @@ +#include "data.table.h" + +/* + * sort-merge join + * + * join on a single integer column + * sort LHS and RHS + * split LHS into equal batches based on its unique values + * split RHS into batches by matching corresponding upper-lower bounds of LHS batches using binary search + * parallel sort-merge join + * + * for a maximum speed collecting of following statistics can be disabled by undefine SMERGE_STATS + * y_len1/y_lens1/allLen1: signals if there where multiple matches in RHS of join, ['s _x_ table + * x_len1/x_lens1/lhsLen1: signals if multiple matches in LHS of join, ['s _i_ table + * xy_len1/xy_lens1/xlLen1: signals if "many to many" matches between LHS and RHS + * cnt/nmatch/n_match/n_matchr: count of matches, taking multiple matches into account, uint64_t + * + * when hardcoding or changing default number of batches it is advised to undefine SMERGE_BATCHING_BALANCED + */ +#define SMERGE_STATS +#define SMERGE_BATCHING_BALANCED + +// this only refers to RHS side: y +enum emult{ALL, FIRST, LAST, ERR}; + +// workhorse join that runs in parallel on batches +static void smerge(const int bx_off, const int bnx, + const int by_off, const int bny, + const int *restrict x, const int *restrict x_starts, const int *restrict x_lens, const bool unq_x, + const int *restrict y, const int *restrict y_starts, const int *restrict y_lens, const bool unq_y, + int *restrict starts, int *restrict lens, + uint64_t *nmatch, bool *xlens1, bool *ylens1, bool *xylens1, + const enum emult mult) { + uint64_t cnt = 0; + bool xlen1 = true, ylen1 = true, xylen1 = true; + if (unq_x && unq_y) { + int i = bx_off, j = by_off; + const int ni = bx_off+bnx, nj = by_off+bny; + while (i y_j) { + j++; + } + } + } else if (unq_x) { + int i = bx_off, js = by_off; + const int ni = bx_off+bnx, njs = by_off+bny; + if (mult == ALL || mult == ERR) { // mult==err is raised based on ylens1 flag outside of parallel region + while (i1) + ylen1 = false; + cnt += (uint64_t)yl1; +#endif + } else if (x_i < y_j) { + i++; + } else if (x_i > y_j) { + js++; + } + } + } else if (mult == FIRST) { + while (i y_j) { + js++; + } + } + } else if (mult == LAST) { + while (i y_j) { + js++; + } + } + } + } else if (unq_y) { + int is = bx_off, j = by_off; + const int nis = bx_off+bnx, nj = by_off+bny; + while (is1) + xlen1 = false; + cnt += (uint64_t)xl1; +#endif + } else if (x_i < y_j) { + is++; + } else if (x_i > y_j) { + j++; + } + } + } else { + int is = bx_off, js = by_off; + const int nis = bx_off+bnx, njs = by_off+bny; + if (mult==ALL || mult==ERR) { + while (is1) + xlen1 = false; + if (ylen1 && yl1>1) + ylen1 = false; + if (xylen1 && xl1>1 && yl1>1) + xylen1 = false; + cnt += (uint64_t)xl1 * (uint64_t)yl1; + #endif + } else if (x_i < y_j) { + is++; + } else if (x_i > y_j) { + js++; + } + } + } else if (mult==FIRST) { + while (is1) + xlen1 = false; + cnt += (uint64_t)xl1; +#endif + } else if (x_i < y_j) { + is++; + } else if (x_i > y_j) { + js++; + } + } + } else if (mult==LAST) { + while (is1) + xlen1 = false; + cnt += (uint64_t)xl1; +#endif + } else if (x_i < y_j) { + is++; + } else if (x_i > y_j) { + js++; + } + } + } + } + xlens1[0] = xlen1; ylens1[0] = ylen1; xylens1[0] = xylen1; nmatch[0] = cnt; + return; +} + +/* + * 'rolling nearest' binary search + * used to find 'y' lower and upper, 0-based, bounds for each batch + * side -1: lower bound; side 1: upper bound + */ +static int rollbs(const int *restrict x, const int *restrict ix, const int nix, const int val, const int side) { + int verbose = 0; // devel debug only + if (verbose>0) + Rprintf("rollbs: side=%d=%s; val=%d\n", side, side<0?"min":"max", val); // # nocov + if (x[ix[0]-1] == val) { // common early stopping + if (verbose>0) + Rprintf("rollbs: min elements %d match: 0\n", val); // # nocov + return 0; + } + if (x[ix[nix-1]-1] == val) { + if (verbose>0) + Rprintf("rollbs: max elements %d match: nix-1=%d\n", val, nix-1); // # nocov + return nix-1; + } + if (side < 0) { // lower bound early stopping + if (x[ix[nix-1]-1] < val) { + if (verbose>0) + Rprintf("rollbs: max element %d is still smaller than %d: -1\n", x[ix[nix-1]-1], val); // # nocov + return -1; + } + if (x[ix[0]-1] > val) { + if (verbose>0) + Rprintf("rollbs: min element %d is bigger than %d: 0\n", x[ix[0]-1], val); // # nocov + return 0; + } + } else if (side > 0) { // upper bound early stopping + if (x[ix[0]-1] > val) { + if (verbose>0) + Rprintf("rollbs: min element %d is still bigger than %d: -1\n", x[ix[0]-1], val); // # nocov + return -1; + } + if (x[ix[nix-1]-1] < val) { + if (verbose>0) + Rprintf("rollbs: max element %d is smaller than %d: nix-1=%d\n", x[ix[nix-1]-1], val, nix-1); // # nocov + return nix-1; + } + } + int lower=0, upper=nix, i=0; + while (lower<=upper) { + i = lower + (upper-lower)/2; + int thisx = x[ix[i]-1]; + //Rprintf("rollbs: x[ix[%d]-1]=%d ?? %d\n", i, thisx, val); // if (verbose) here would slow down while loop so need to be commented out + if (thisx==val) + return(i); + else if (thisx < val) + lower = i+1; + else if (thisx > val) + upper = i-1; + } + if (verbose>0) + Rprintf("rollbs: nomatch: i=%d; this=%d; lower=%d, upper=%d; side=%d: %d\n", i, x[ix[i]-1], lower, upper, side, side<0?lower:upper); // # nocov + if (side < 0) // anyone to stress test this logic? + return lower; + else + return upper; +} + +// cuts x_starts into equal batches and binary search corresponding y_starts ranges +static void batching(const int nBatch, + const int *restrict x, const int nx, const int *restrict x_starts, const int nx_starts, + const int *restrict y, const int ny, const int *restrict y_starts, const int ny_starts, + int *restrict Bx_off, int *restrict Bnx, int *restrict By_off, int *restrict Bny, + const int verbose) { +#ifdef SMERGE_BATCHING_BALANCED + //if (nBatch > nx_starts || (nx_starts/nBatch==1 && nx_starts%nBatch>0)) error("internal error: batching %d input into %d batches, number of batches should have been reduced be now", nx_starts, nBatch); // # nocov + size_t batchSize = (nx_starts-1)/nBatch + 1; // this is fragile for arbitrary nBatch + bool balanced = true; // this is only for verbose message +#else + size_t batchSize = nx_starts / nBatch; // last batch size can be anything between 1 and 2*batchSize-1 + bool balanced = false; +#endif + size_t lastBatchSize = nx_starts - (nBatch-1)*batchSize; + if (verbose>0) + Rprintf("batching: input %d into %s %d batches (batchSize=%d, lastBatchSize=%d) of sorted x y: x[1]<=y[1] && x[nx]>=y[ny]:\n", nx_starts, balanced?"balanced":"unbalanced", nBatch, batchSize, lastBatchSize); + if (lastBatchSize==0 || ((nBatch-1) * batchSize + lastBatchSize != nx_starts)) + error("internal error: batching %d input is attempting to use invalid batches: balanced=%d, nBatch=%d, batchSize=%d, lastBatchSize=%d", nx_starts, balanced?"balanced":"unbalanced", nBatch, batchSize, lastBatchSize); // # nocov + for (int b=0; b= 0 && y_i_max >= 0; + By_off[b] = y_match ? y_i_min : 0; + Bny[b] = y_match ? y_i_max - y_i_min + 1 : 0; + } + if (verbose>0) { // print batches, 1-indexed! x y sorted! for debugging and verbose + for (int b=0; b 0) { + int x_i_min = (x_starts + Bx_off[b])[0], x_i_max = (x_starts + Bx_off[b])[Bnx[b]-1]; + int y_i_min = (y_starts + By_off[b])[0], y_i_max = (y_starts + By_off[b])[Bny[b]-1]; + Rprintf("## lower: x[%d]: %d <= %d :y[%d]\n", x_i_min, x[x_i_min-1], y[y_i_min-1], y_i_min); + Rprintf("## upper: x[%d]: %d >= %d :y[%d]\n", x_i_max, x[x_i_max-1], y[y_i_max-1], y_i_max); + } + } + } + return; +} + +// helper for verbose messages to count how many threads were used, due to schedule dynamic not all may be used +static int unqNth(const int *x, const int nx) { // x have 0:(nx-1) values + int ans = 0; + uint8_t *seen = (uint8_t *)R_alloc(nx, sizeof(uint8_t)); + memset(seen, 0, nx*sizeof(uint8_t)); + for (int i=0; i0) + t = omp_get_wtime(); + int *restrict x_lens = 0, *restrict y_lens = 0; + const bool unq_x = nx_starts==nx, unq_y = ny_starts==ny; + if (!unq_x) { + x_lens = (int *)R_alloc(nx_starts, sizeof(int)); // remove after #4395 + grpLens(x_starts, nx_starts, nx, x_lens); + } + if (!(unq_y || mult==FIRST)) { + y_lens = (int *)R_alloc(ny_starts, sizeof(int)); + grpLens(y_starts, ny_starts, ny, y_lens); + } + if (verbose>0) + Rprintf("smergeC: grpLens %s took %.3fs\n", verboseDone(!unq_x, !(unq_y || mult==FIRST), "(x already unq, y unq or mult='first')", "(y)", "(x)", "(x, y)"), omp_get_wtime() - t); + + if (verbose>0) + t = omp_get_wtime(); + int nBatch = 0; + const int nth = getDTthreads(); + if (nth == 1 || nx_starts < 1024) { +#ifdef SMERGE_BATCHING_BALANCED + nBatch = 1; // when using balanced lastBatchSize is never bigger than batchSize, any hardcoding here is likely to raise internal error in batching or segfault, testing possible with: for (i in 1:10) cc("smerge.Rraw") +#else + nBatch = 1; // so if hardcoding needed, do it here, and undefine SMERGE_BATCHING_BALANCED at the top, for unbalanced batching lastBatchSize is anything between 1 and 2*batchSize-1 +#endif + } else if (nx_starts < nth * 2) { + nBatch = nx_starts; // stress test single row batches, will be usually escaped by branch above + } else { + nBatch = nth * 2; + } + int *restrict Bx_off = (int *)R_alloc(nBatch, sizeof(int)), *restrict Bnx = (int *)R_alloc(nBatch, sizeof(int)); + int *restrict By_off = (int *)R_alloc(nBatch, sizeof(int)), *restrict Bny = (int *)R_alloc(nBatch, sizeof(int)); + batching(nBatch, x, nx, x_starts, nx_starts, y, ny, y_starts, ny_starts, Bx_off, Bnx, By_off, Bny, verbose-1); + int *restrict th = (int *)R_alloc(nBatch, sizeof(int)); // report threads used + if (verbose>0) + Rprintf("smergeC: preparing %d batches took %.3fs\n", nBatch, omp_get_wtime() - t); + + if (verbose>0) + t = omp_get_wtime(); + uint64_t nmatch = 0; + bool xlens1 = true, ylens1 = true, xylens1 = true; + #pragma omp parallel for schedule(dynamic) reduction(&&:xlens1,ylens1,xylens1) reduction(+:nmatch) num_threads(nth) + for (int b=0; b0) + Rprintf("smergeC: %d calls to smerge using %d/%d threads took %.3fs\n", nBatch, unqNth(th, nBatch), nth, omp_get_wtime() - t); // all threads may not always be used bc schedule(dynamic) + if (mult==ERR && !ylens1) + error("mult='error' and multiple matches during merge"); + + return; +} + +void sortInt(const int *restrict x, const int nx, const int *restrict idx, int *restrict ans) { + #pragma omp parallel for schedule(static) num_threads(getDTthreads()) + for (int i=0; i (uint64_t)DBL_MAX) { // 1e9 x 1e9 cartesian join results 1e18 still less than DBL_MAX, should we check against DBL_MAX or something lesser? we cast uinst64_t to double here + REAL(n_matchr)[0] = NA_REAL; + warning("count of matches exceeds DBL_MAX, returning NA in 'nMatch' field"); + } else { + REAL(n_matchr)[0] = (double)n_match; + } + SET_STRING_ELT(ansnames, 9, char_nMatch); SET_VECTOR_ELT(ans, 9, n_matchr); + UNPROTECT(1); + } + UNPROTECT(1); + return ans; +} + +const enum emult matchMultArg(SEXP multArg) { + enum emult mult; + if (!strcmp(CHAR(STRING_ELT(multArg, 0)), "all")) + mult = ALL; + else if (!strcmp(CHAR(STRING_ELT(multArg, 0)), "first")) + mult = FIRST; + else if (!strcmp(CHAR(STRING_ELT(multArg, 0)), "last")) + mult = LAST; + else if (!strcmp(CHAR(STRING_ELT(multArg, 0)), "error")) + mult = ERR; + else + error(_("Internal error: invalid value for 'mult'. please report to data.table issue tracker")); // # nocov + return mult; +} + +// main interface from R +SEXP smergeR(SEXP x, SEXP y, SEXP x_idx, SEXP y_idx, SEXP multArg, SEXP out_bmerge) { + + const int verbose = GetVerbose()*3; // remove *3 after #4491 + double t_total = 0, t = 0; + if (verbose>0) + t_total = omp_get_wtime(); + if (!isInteger(x) || !isInteger(y)) + error("'x' and 'y' must be integer"); + if (!isString(multArg)) + error("'mult' must be a string"); + const enum emult mult = matchMultArg(multArg); + const bool multLen1 = mult==FIRST || mult==LAST; + if (!IS_TRUE_OR_FALSE(out_bmerge)) + error("'out.bmerge' must be TRUE or FALSE"); + const bool ans_bmerge = (bool)LOGICAL(out_bmerge)[0]; + int protecti = 0, nx = LENGTH(x), ny = LENGTH(y); + + if (verbose>0) + t = omp_get_wtime(); + const bool do_x_idx = isNull(x_idx), do_y_idx = isNull(y_idx); + if (do_x_idx) { + x_idx = PROTECT(forder(x, R_NilValue, ScalarLogical(true), ScalarLogical(true), ScalarInteger(1), ScalarLogical(false))); protecti++; // verbose=verbose-2L after #4533 + } + if (do_y_idx) { + y_idx = PROTECT(forder(y, R_NilValue, ScalarLogical(true), ScalarLogical(true), ScalarInteger(1), ScalarLogical(false))); protecti++; // verbose=verbose-2L after #4533 + } + if (!isInteger(x_idx) || !isInteger(y_idx)) + error("'x.idx' and 'y.idx' must be integer"); + SEXP x_starts = getAttrib(x_idx, sym_starts); SEXP y_starts = getAttrib(y_idx, sym_starts); + if (isNull(x_starts) || isNull(y_starts)) + error("Indices provided to smerge must carry 'starts' attribute"); + if (verbose>0) + Rprintf("smergeR: index %s took %.3fs\n", verboseDone(do_x_idx, do_y_idx, "(already indexed)", "(y)", "(x)", "(x, y)"), omp_get_wtime() - t); + + if (verbose>0) + t = omp_get_wtime(); + const bool x_ord = !LENGTH(x_idx), y_ord = !LENGTH(y_idx); + int *xp, *yp; + if (!x_ord) { + xp = (int *)R_alloc(nx, sizeof(int)); + sortInt(INTEGER(x), nx, INTEGER(x_idx), xp); + } else { + xp = INTEGER(x); + } + if (!y_ord) { + yp = (int *)R_alloc(ny, sizeof(int)); + sortInt(INTEGER(y), ny, INTEGER(y_idx), yp); + } else { + yp = INTEGER(y); + } + if (verbose>0) + Rprintf("smergeR: sort %s took %.3fs\n", verboseDone(!x_ord, !y_ord, "(already sorted)", "(y)", "(x)", "(x, y)"), omp_get_wtime() - t); + + if (verbose>0) + t = omp_get_wtime(); + SEXP out_starts = R_NilValue, out_lens = R_NilValue; + int *restrict starts=0, *restrict lens=0; + const int lens_len = (!multLen1 || ans_bmerge) ? nx : 0; // for mult=first|last we dont need to allocate lens + if (x_ord) { // we dont need to reorder results so can save one allocation + out_starts = PROTECT(allocVector(INTSXP, nx)); protecti++; + out_lens = PROTECT(allocVector(INTSXP, lens_len)); protecti++; + starts = INTEGER(out_starts); + lens = INTEGER(out_lens); + } else { + starts = (int *)R_alloc(nx, sizeof(int)); + lens = (int *)R_alloc(lens_len, sizeof(int)); + } + // this fills default values, bmerge's defaults are tricky (dictated by how they are consumed): nomatch=0 makes starts=0 not NA, lens=0 is fine there; nomatch=NA makes lens=1 not NA, starts=NA is fine there + // AFAIU it make sense to take out nomatch argument from merge + if (multLen1 && !ans_bmerge) { + #pragma omp parallel for schedule(static) num_threads(getDTthreads()) + for (int i=0; i0) + Rprintf("smergeR: alloc of size %d took %.3fs\n", nx, omp_get_wtime() - t); + + if (verbose>0) + t = omp_get_wtime(); + uint64_t n_match = 0; + bool x_lens1 = true, y_lens1 = true, xy_lens1 = true; + smergeC( + xp, nx, INTEGER(x_starts), LENGTH(x_starts), + yp, ny, INTEGER(y_starts), LENGTH(y_starts), + starts, lens, + &n_match, &x_lens1, &y_lens1, &xy_lens1, + mult, verbose-1 + ); + if (verbose>0) + Rprintf("smergeR: smergeC of %d x %d = %"PRIu64"; took %.3fs\n", nx, ny, n_match, omp_get_wtime() - t); + + if (verbose>0) + t = omp_get_wtime(); + SEXP ans = outSmergeR(nx, starts, lens, x_ord, out_starts, out_lens, x_idx, y_idx, n_match, x_lens1, y_lens1, xy_lens1, multLen1, ans_bmerge); + if (verbose>0) + Rprintf("smergeR: outSmerge %s took %.3fs\n", x_ord ? "(was sorted)" : "(alloc and unsort)", omp_get_wtime() - t); + if (verbose>0) + Rprintf("smergeR: all took %.3fs\n", omp_get_wtime() - t_total); + + UNPROTECT(protecti); + return ans; +}