-
Notifications
You must be signed in to change notification settings - Fork 986
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* cast multiple 'value.var' columns * multiple 'fun.aggregate' as well * accept undefined variables in formula * accept optional column prefixes
- Loading branch information
1 parent
a71e21d
commit 25a74df
Showing
5 changed files
with
354 additions
and
391 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,122 +17,210 @@ dcast <- function(data, formula, fun.aggregate = NULL, ..., margins = NULL, | |
subset = subset, fill = fill, value.var = value.var) | ||
} | ||
|
||
check_formula <- function(formula, varnames, valnames) { | ||
if (is.character(formula)) formula = as.formula(formula) | ||
if (class(formula) != "formula" || length(formula) != 3L) | ||
stop("Invalid formula. Cast formula should be of the form LHS ~ RHS, for e.g., a + b ~ c.") | ||
vars = all.vars(formula) | ||
vars = vars[!vars %chin% c(".", "...")] | ||
allvars = c(vars, valnames) | ||
ans = deparse_formula(as.list(formula)[-1L], varnames, allvars) | ||
} | ||
|
||
deparse_formula <- function(expr, varnames, allvars) { | ||
lvars = lapply(expr, function(this) { | ||
if (is.call(this)) { | ||
if (this[[1L]] == quote(`+`)) | ||
unlist(deparse_formula(as.list(this)[-1L], varnames, allvars)) | ||
else this | ||
} else if (is.name(this)) { | ||
if (this == quote(`...`)) { | ||
subvars = setdiff(varnames, allvars) | ||
lapply(subvars, as.name) | ||
} else this | ||
} | ||
}) | ||
lvars = lapply(lvars, function(x) if (length(x) && !is.list(x)) list(x) else x) | ||
} | ||
|
||
value_vars <- function(value.var, varnames) { | ||
if (is.character(value.var)) | ||
value.var = list(value.var) | ||
value.var = lapply(value.var, unique) | ||
valnames = unique(unlist(value.var)) | ||
iswrong = which(!valnames %in% varnames) | ||
if (length(iswrong)) | ||
stop("value.var values [", paste(value.var[iswrong], collapse=", "), "] are not found in 'data'.") | ||
value.var | ||
} | ||
|
||
aggregate_funs <- function(funs, vals, ...) { | ||
if (is.call(funs) && funs[[1L]] == "eval") | ||
funs = eval(funs[[2L]], parent.frame(2L), parent.frame(2L)) | ||
if (is.call(funs) && as.character(funs[[1L]]) %in% c("c", "list")) | ||
funs = lapply(as.list(funs)[-1L], function(x) { | ||
if (is.call(x)) as.list(x)[-1L] else x | ||
}) | ||
else funs = list(funs) | ||
if (length(funs) != length(vals)) { | ||
if (length(vals) == 1L) | ||
vals = replicate(length(funs), vals) | ||
else stop("When 'fun.aggregate' and 'value.var' are both lists, 'value.var' must be either of length =1 or =length(fun.aggregate).") | ||
} | ||
dots = list(...) | ||
construct_funs <- function(fun, val) { | ||
if (is.name(fun)) fun = list(fun) | ||
ans = vector("list", length(fun)*length(val)) | ||
nms = vector("character", length(ans)) | ||
k = 1L | ||
for (i in fun) { | ||
for (j in val) { | ||
expr = list(i, as.name(j)) | ||
if (length(dots)) | ||
expr = c(expr, dots) | ||
ans[[k]] = as.call(expr) | ||
nms[k] = paste(all.names(i, max.names=1L, functions=TRUE), j, sep="_") | ||
k = k+1L; | ||
} | ||
} | ||
setattr(ans, 'names', nms) | ||
} | ||
ans = lapply(seq_along(funs), function(i) construct_funs(funs[[i]], vals[[i]])) | ||
as.call(c(quote(list), unlist(ans))) | ||
} | ||
|
||
dcast.data.table <- function(data, formula, fun.aggregate = NULL, ..., margins = NULL, | ||
subset = NULL, fill = NULL, drop = TRUE, value.var = guess(data), verbose = getOption("datatable.verbose")) { | ||
if (!is.data.table(data)) stop("'data' must be a data.table.") | ||
if (anyDuplicated(names(data))) stop('data.table to cast must have unique column names') | ||
is.formula <- function(x) class(x) == "formula" | ||
strip <- function(x) gsub("[[:space:]]*", "", x) | ||
if (is.formula(formula)) formula <- deparse(formula, 500) | ||
if (is.character(formula)) { | ||
ff <- strsplit(strip(formula), "~", fixed=TRUE)[[1]] | ||
if (length(ff) > 2) | ||
stop("Cast formula length is > 2, must be = 2.") | ||
ff <- strsplit(ff, "+", fixed=TRUE) | ||
setattr(ff, 'names', c("ll", "rr")) | ||
ff <- lapply(ff, function(x) x[x != "."]) | ||
ff_ <- unlist(ff, use.names=FALSE) | ||
replace_dots <- function(x) { | ||
if (!is.list(x)) x = as.list(x) | ||
for (i in seq_along(x)) { | ||
if (x[[i]] == "...") | ||
x[[i]] = setdiff(names(data), c(value.var, ff_)) | ||
} | ||
unlist(x) | ||
} | ||
ff <- lapply(ff, replace_dots) | ||
} else stop("Invalid formula.") | ||
ff_ <- unlist(ff, use.names=FALSE) | ||
if (length(is_wrong <- which(is.na(chmatch(ff_, names(data))))) > 0) stop("Column '", ff_[is_wrong[1]], "' not found.") | ||
if (length(ff$ll) == 0) stop("LHS of formula evaluates to 'character(0)', invalid formula.") | ||
if (length(value.var) != 1 || !is.character(value.var)) stop("'value.var' must be a character vector of length 1.") | ||
if (is.na(chmatch(value.var, names(data)))) stop("'value.var' column '", value.var, "' not found.") | ||
if (any(unlist(lapply(as.list(data)[ff_], class), use.names=FALSE) == "list")) | ||
stop("Only 'value.var' column maybe of type 'list'. This may change in the future.") | ||
drop <- as.logical(drop[1]) | ||
if (is.na(drop)) stop("'drop' must be TRUE/FALSE") | ||
|
||
# subset | ||
m <- as.list(match.call()[-1]) | ||
subset <- m$subset[[2]] | ||
drop = as.logical(drop[1]) | ||
if (is.na(drop)) stop("'drop' must be logical TRUE/FALSE") | ||
lvals = value_vars(value.var, names(data)) | ||
valnames = unique(unlist(lvals)) | ||
lvars = check_formula(formula, names(data), valnames) | ||
lvars = lapply(lvars, function(x) if (!length(x)) quote(`.`) else x) | ||
# tired of lapply and the way it handles environments! | ||
allcols = c(unlist(lvars), lapply(valnames, as.name)) | ||
dat = vector("list", length(allcols)) | ||
for (i in seq_along(allcols)) { | ||
x = allcols[[i]] | ||
dat[[i]] = if (identical(x, quote(`.`))) rep(".", nrow(data)) | ||
else eval(x, data, parent.frame()) | ||
if (is.function(dat[[i]])) | ||
stop("Column [", deparse(x), "] not found or of unknown type.") | ||
} | ||
setattr(lvars, 'names', c("lhs", "rhs")) | ||
# Have to take care of duplicate names, and provide names for expression columns properly. | ||
varnames = make.unique(sapply(unlist(lvars), all.vars, max.names=1L), sep="_") | ||
dupidx = which(valnames %in% varnames) | ||
if (length(dupidx)) { | ||
dups = valnames[dupidx] | ||
valnames = tail(make.unique(c(varnames, valnames)), -length(varnames)) | ||
lvals = lapply(lvals, function(x) { x[x %in% dups] = valnames[dupidx]; x }) | ||
} | ||
lhsnames = head(varnames, length(lvars$lhs)) | ||
rhsnames = tail(varnames, -length(lvars$lhs)) | ||
setattr(dat, 'names', c(varnames, valnames)) | ||
setDT(dat) | ||
if (any(sapply(as.list(dat)[varnames], is.list))) { | ||
stop("Columns specified in formula can not be of type list") | ||
} | ||
m <- as.list(match.call()[-1L]) | ||
subset <- m[["subset"]][[2L]] | ||
if (!is.null(subset)) { | ||
if (is.name(subset)) subset = as.call(list(quote(`(`), subset)) | ||
data = data[eval(subset, data, parent.frame()), unique(c(ff_, value.var)), with=FALSE] | ||
idx = which(eval(subset, data, parent.frame())) # any advantage thro' secondary keys? | ||
dat = .Call(CsubsetDT, dat, idx, seq_along(dat)) | ||
} | ||
if (nrow(data) == 0L || ncol(data) == 0L) stop("Can't 'cast' on an empty data.table") | ||
|
||
# set 'fun.aggregate = length' if max group size > 1 | ||
fun.null=FALSE | ||
if (is.null(fun.aggregate)) { | ||
fun.null=TRUE | ||
oo = forderv(data, by=ff_, retGrp=TRUE) | ||
if (!nrow(dat) || !ncol(dat)) stop("Can not cast an empty data.table") | ||
fun.call = m[["fun.aggregate"]] | ||
fill.default = NULL | ||
if (is.null(fun.call)) { | ||
oo = forderv(dat, by=varnames, retGrp=TRUE) | ||
if (attr(oo, 'maxgrpn') > 1L) { | ||
message("Aggregate function missing, defaulting to 'length'") | ||
fun.aggregate <- length | ||
m[["fun.aggregate"]] = quote(length) | ||
fun.call = quote(length) | ||
} | ||
} | ||
fill.default <- NULL | ||
if (!is.null(fun.aggregate)) { # construct the 'call' | ||
fill.default = fun.aggregate(data[[value.var]][0], ...) | ||
if (!length(fill.default) && (is.null(fill) || !length(fill))) | ||
stop("Aggregating function provided to argument 'fun.aggregate' should always return a length 1 vector, but returns 0-length value for fun.aggregate(", typeof(data[[value.var]]), "(0)).", " This value will have to be used to fill missing combinations, if any, and therefore can not be of length 0. Either override by setting the 'fill' argument explicitly or modify your function to handle this case appropriately.") | ||
args <- c("data", "formula", "margins", "subset", "fill", "value.var", "verbose", "drop") | ||
m <- m[setdiff(names(m), args)] | ||
.CASTcall = as.call(c(m[1], as.name(value.var), m[-1])) # issues/713 | ||
.CASTcall = as.call(c(as.name("list"), setattr(list(.CASTcall), 'names', value.var))) | ||
# workaround until #5191 (issues/497) is fixed | ||
if (length(intersect(value.var, ff_))) | ||
.CASTcall = as.call(list(as.name("{"), as.name(".SD"), .CASTcall)) | ||
} | ||
# special case | ||
if (length(ff$rr) == 0) { | ||
if (is.null(fun.aggregate)) | ||
ans = data[, c(ff$ll, value.var), with=FALSE] | ||
else { | ||
# workaround until #5191 (issues/497) is fixed | ||
if (length(intersect(value.var, ff_))) ans = data[, eval(.CASTcall), by=c(ff$ll), .SDcols=value.var] | ||
else ans = data[, eval(.CASTcall), by=c(ff$ll)] | ||
if (!is.null(fun.call)) { | ||
fun.call = aggregate_funs(fun.call, lvals, ...) | ||
errmsg = "Aggregating function(s) should take vector inputs and return a single value (length=1). However, function(s) returns length!=1. This value will have to be used to fill any missing combinations, and therefore must be length=1. Either override by setting the 'fill' argument explicitly or modify your function to handle this case appropriately." | ||
if (is.null(fill)) { | ||
tryCatch(fill.default <- dat[0][, eval(fun.call)], warning = function(x) stop(errmsg, call.=FALSE)) | ||
if (nrow(fill.default) != 1L) stop(errmsg, call.=FALSE) | ||
} | ||
if (anyDuplicated(names(ans))) { | ||
message("Duplicate column names found in cast data.table. Setting unique names using 'make.unique'") | ||
setnames(ans, make.unique(names(ans))) | ||
if (!any(valnames %chin% varnames)) { | ||
dat = dat[, eval(fun.call), by=c(varnames)] | ||
} else { | ||
dat = dat[, { .SD; eval(fun.call) }, by=c(varnames), .SDcols = valnames] | ||
} | ||
if (!identical(key(ans), ff$ll)) setkeyv(ans, names(ans)[seq_along(ff$ll)]) | ||
return(ans) | ||
} | ||
# aggregation moved to R now that 'adhoc-by' is crazy fast! | ||
if (!is.null(fun.aggregate)) { | ||
if (length(intersect(value.var, ff_))) { | ||
data = data[, eval(.CASTcall), by=c(ff_), .SDcols=value.var] | ||
value.var = tail(make.unique(names(data)), 1L) | ||
setnames(data, ncol(data), value.var) | ||
} | ||
else data = data[, eval(.CASTcall), by=c(ff_)] | ||
setkeyv(data, ff_) | ||
# issues/693 | ||
fun_agg_chk <- function(x) { | ||
# sorted now, 'forderv' should be as fast as uniqlist+uniqlengths | ||
oo = forderv(data, by=key(data), retGrp=TRUE) | ||
attr(oo, 'maxgrpn') > 1L | ||
order_ <- function(x) { | ||
o = forderv(x, retGrp=TRUE, sort=TRUE) | ||
idx = attr(o, 'starts') | ||
if (!length(o)) o = seq_along(x) | ||
o[idx] # subsetVector retains attributes, using R's subset for now | ||
} | ||
cj_uniq <- function(DT) { | ||
do.call("CJ", lapply(DT, function(x) | ||
if (is.factor(x)) { | ||
xint = seq_along(levels(x)) | ||
setattr(xint, 'levels', levels(x)) | ||
setattr(xint, 'class', class(x)) | ||
} else .Call(CsubsetVector, x, order_(x)) | ||
))} | ||
valnames = setdiff(names(dat), varnames) | ||
# 'dat' != 'data'? then setkey to speed things up (slightly), else ad-hoc (for now). Still very fast! | ||
if (!is.null(fun.call) || !is.null(subset)) | ||
setkeyv(dat, varnames) | ||
if (length(rhsnames)) { | ||
lhs = shallow(dat, lhsnames); rhs = shallow(dat, rhsnames); val = shallow(dat, valnames) | ||
# handle drop=TRUE/FALSE - Update: Logic moved to R, AND faster than previous version. Take that... old me :-). | ||
if (drop) { | ||
map = setDT(lapply(list(lhsnames, rhsnames), function(cols) frankv(dat, cols=cols, ties.method="dense"))) | ||
maporder = lapply(map, order_) | ||
mapunique = lapply(seq_along(map), function(i) .Call(CsubsetVector, map[[i]], maporder[[i]])) | ||
lhs = .Call(CsubsetDT, lhs, maporder[[1L]], seq_along(lhs)) | ||
rhs = .Call(CsubsetDT, rhs, maporder[[2L]], seq_along(rhs)) | ||
} else { | ||
lhs_ = cj_uniq(lhs); rhs_ = cj_uniq(rhs) | ||
map = vector("list", 2L) | ||
.Call(Csetlistelt, map, 1L, lhs_[lhs, which=TRUE]) | ||
.Call(Csetlistelt, map, 2L, rhs_[rhs, which=TRUE]) | ||
setDT(map) | ||
mapunique = vector("list", 2L) | ||
.Call(Csetlistelt, mapunique, 1L, seq_len(nrow(lhs_))) | ||
.Call(Csetlistelt, mapunique, 2L, seq_len(nrow(rhs_))) | ||
lhs = lhs_; rhs = rhs_ | ||
} | ||
if (!fun.null && fun_agg_chk(data)) | ||
stop("Aggregating function provided to argument 'fun.aggregate' should always return a length 1 vector for each group, but returns length != 1 for atleast one group. Please have a look at the DETAILS section of ?dcast.data.table ") | ||
maplen = sapply(mapunique, length) | ||
idx = do.call("CJ", mapunique)[map, I := .I][["I"]] # TO DO: move this to C and avoid materialising the Cross Join. | ||
ans = .Call("Cfcast", lhs, val, maplen[[1L]], maplen[[2L]], idx, fill, fill.default, is.null(fun.call)) | ||
allcols = do.call("paste", c(rhs, sep="_")) | ||
if (length(valnames) > 1L) | ||
allcols = do.call("paste", c(setcolorder(CJ(valnames, allcols, sorted=FALSE), 2:1), sep="_")) | ||
setattr(ans, 'names', c(lhsnames, allcols)) | ||
setDT(ans); setattr(ans, 'sorted', lhsnames) | ||
} else { | ||
if (is.null(subset)) | ||
data = data[, unique(c(ff_, value.var)), with=FALSE] # data is untouched so far. subset only required columns | ||
if (length(oo)) .Call(Creorder, data, oo) | ||
setattr(data, 'sorted', ff_) | ||
} | ||
.CASTenv = new.env(parent=parent.frame()) | ||
assign("forder", forderv, .CASTenv) | ||
assign("CJ", CJ, .CASTenv) | ||
ans <- .Call("Cfcast", data, ff$ll, ff$rr, value.var, fill, fill.default, is.null(fun.aggregate), .CASTenv, drop) | ||
setDT(ans) | ||
if (anyDuplicated(names(ans))) { | ||
message("Duplicate column names found in cast data.table. Setting unique names using 'make.unique'") | ||
setnames(ans, make.unique(names(ans))) | ||
# formula is of the form x + y ~ . (rare case) | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
MichaelChirico
Member
|
||
if (drop) { | ||
if (is.null(subset) && is.null(fun.call)) { | ||
dat = copy(dat) # can't be avoided | ||
setkeyv(dat, lhsnames) | ||
} | ||
ans = dat | ||
} else { | ||
lhs = shallow(dat, lhsnames) | ||
val = shallow(dat, valnames) | ||
lhs_ = cj_uniq(lhs) | ||
idx = lhs_[lhs, I := .I][["I"]] | ||
lhs_[, I := NULL] | ||
ans = .Call("Cfcast", lhs_, val, nrow(lhs_), 1L, idx, fill, fill.default, is.null(fun.call)) | ||
setDT(ans); setattr(ans, 'sorted', lhsnames) | ||
setnames(ans, c(lhsnames, valnames)) | ||
} | ||
if (length(valnames) == 1L) | ||
setnames(ans, valnames, value.var) | ||
} | ||
setattr(ans, 'sorted', names(ans)[seq_along(ff$ll)]) | ||
ans | ||
return (ans) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
@arunsrinivasan could you elaborate on what this branch is for? It's not covered in tests and this comment seems to be spurious:
runs just fine without touching that branch (regardless of
drop
)