-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adds tf_split/combine for functional fragments
- Loading branch information
fabian-s
committed
Dec 2, 2024
1 parent
86f83df
commit 622c337
Showing
8 changed files
with
265 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#' Split / Combine functional fragments | ||
#' | ||
#' `tf_split` separates each function into a vector of functions defined on a sub-interval of | ||
#' its domain, either with overlap at the cut points or without. | ||
#' | ||
#' @param x a `tf` object | ||
#' @param splits numeric vector containing `arg`-values at which to split | ||
#' @param include which of the end points defined by `splits` to include in each | ||
#' of the resulting split functions. Defaults to `"both"`, other options are "`left`" or | ||
#' "`right`". See examples. | ||
#' | ||
#' @return for `tf_split`: a list of `tf` objects | ||
#' @export | ||
#' @rdname tfsplitcombine | ||
#' | ||
#' @examples | ||
#' x <- tfd(1:100, arg = 1:100) | ||
#' tf_split(x, splits = c(20, 80)) | ||
#' tf_split(x, splits = c(20, 80), include = "left") | ||
#' tf_split(x, splits = c(20, 80), include = "right") | ||
tf_split <- function(x, splits, include = c("both", "left", "right")) { | ||
assert_tf(x) | ||
assert_numeric(splits, | ||
lower = tf_domain(x)[1], | ||
upper = tf_domain(x)[2], | ||
any.missing = FALSE, sorted = TRUE, unique = TRUE) | ||
include <- match.arg(include) | ||
resolution_x <- get_resolution(tf_arg(x)) | ||
# if user supplied domain limit(s), remove | ||
if (splits[1] == tf_domain(x)[1]) { | ||
splits <- splits[-1] | ||
} | ||
if (splits[length(splits)] == tf_domain(x)[2]) { | ||
splits <- splits[-length(splits)] | ||
} | ||
|
||
start <- c(tf_domain(x)[1], splits) | ||
end <- c(splits, tf_domain(x)[2]) | ||
if (include == "left") { | ||
end[1:(length(end) -1)] <- head(end, -1) - resolution_x | ||
} | ||
if (include == "right") { | ||
start[-1] <- start[-1] + resolution_x | ||
} | ||
|
||
map2(start, end, \(a, b) tf_zoom(x, begin = a, end = b)) | ||
} | ||
|
||
|
||
|
||
#' @description | ||
#' `tf_combine` joins functional fragments together to create longer (or more densely evaluated) functions. | ||
#' @param ... `tf`-objects of identical lengths to combine | ||
#' @param strict only combine functions whose argument ranges do not overlap, | ||
#' are given in the correct order & contain no duplicate values at identical arguments? | ||
#' defaults to `FALSE`. By default, only the first function value at duplicate locations | ||
#' are used, the rest are discarded (with a warning). | ||
#' | ||
#' @return for `tf_combine`: a `tfd` with the combined subfunctions on the union of the input `tf_arg`-values | ||
#' @export | ||
#' @rdname tfsplitcombine | ||
#' | ||
#' @examples | ||
#' x <- tf_rgp(5) | ||
#' tfs <- tf_split(x, splits = c(.2, .6)) | ||
#' x2 <- tf_combine(tfs[[1]], tfs[[2]], tfs[[3]]) | ||
#' # tf_combine(tfs[[1]], tfs[[2]], tfs[[3]], strict = TRUE) # errors out! | ||
#' all.equal(x, x2) | ||
#' # combine works for different input types: | ||
#' tfs2_sparse <- tf_sparsify(tfs[[2]]) | ||
#' tfs3_spline <- tfb(tfs[[3]]) | ||
#' tf_combine(tfs[[1]], tfs2_sparse, tfs3_spline) | ||
|
||
tf_combine <- function(..., strict = FALSE) { | ||
tfs <- list(...) | ||
map(tfs, assert_tf) | ||
sizes <- map_int(tfs, vec_size) | ||
if (!all(duplicated(sizes)[-1])) { | ||
cli::cli_abort("can't {.fun tf_combine} objects of different sizes") | ||
} | ||
size <- sizes[1] | ||
|
||
if (strict) { | ||
# assert arg ranges in tfs at each vector position are strictly ordered | ||
args <- map(tfs, | ||
function(x) tf_arg(x) |> ensure_list()) | ||
irreg <- map(args, length) != 1 | ||
if (any(irreg)) { | ||
args <- map_if(args, !irreg, \(x) replicate(size, x)) | ||
} | ||
|
||
arg_mins <- do.call(cbind, map_depth(args, pluck_depth(args) - 1, min)) | ||
arg_maxs <- do.call(cbind, map_depth(args, pluck_depth(args) - 1, max)) | ||
|
||
min_overlap <- apply(arg_mins, 1, \(x) is.unsorted(as.numeric(x))) | ||
max_overlap <- apply(arg_maxs, 1, \(x) is.unsorted(as.numeric(x))) | ||
if( any(min_overlap) || any(max_overlap) ) { | ||
cli::cli_abort("{.fun tf_arg}-ranges of input data are not strictly ordered.") | ||
} | ||
} | ||
|
||
domains <- do.call(rbind, map(tfs, tf_domain)) | ||
new_domain <- c(min(domains[, 1]), max(domains[, 2])) | ||
tfs <- map(tfs, | ||
function(x) { | ||
suppressWarnings(tf_domain(x) <- new_domain) | ||
x | ||
}) | ||
# TODO: add check for names and reuse for output if identical? map(tfs, names) |> reduce(identical) | ||
tfs_data <- do.call(rbind, | ||
map(tfs, \(x) { | ||
names(x) <- seq_along(x) | ||
tf_2_df(x) | ||
})) | ||
|
||
duplicates <- which(duplicated(tfs_data[, -3])) #check for multiple values at some id&arg | ||
if (length(duplicates)) { | ||
if (strict) { | ||
cli::cli_abort("can't combine functions with multiple values at same argument.") | ||
} | ||
cli::cli_alert_warning("removing {length(duplicates)} duplicated points from input data.") | ||
tfs_data <- tfs_data[-duplicates, ] | ||
} | ||
|
||
tfd(tfs_data, domain = new_domain) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
|
||
test_that("tf_split works as expected", { | ||
x <- tf_rgp(3) | ||
x_i <- tf_jiggle(x) |> tf_sparsify() | ||
x_sp <- tfb(x, verbose = FALSE) | ||
x_pc <- tfb_fpc(x, verbose = FALSE) | ||
|
||
|
||
|
||
tfs <- tf_split(x, splits = c(.3, .5)) | ||
expect_list(tfs, types = "tfd_reg", len = 3) | ||
|
||
tfs_i <- tf_split(x_i, splits = c(.3, .5)) | ||
expect_list(tfs_i, types = "tfd_irreg", len = 3) | ||
|
||
tfs_sp <- tf_split(x_sp, splits = c(.3, .5)) | ||
expect_list(tfs_sp, types = "tfb_spline", len = 3) | ||
|
||
tfs_pc <- suppressWarnings(tf_split(x_pc, splits = c(.3, .5))) | ||
expect_list(tfs_pc, types = "tfb_fpc", len = 3) | ||
|
||
expect_identical(tfs, | ||
tf_split(x, splits = c(0, .3, .5, 1))) | ||
|
||
expect_error(tf_split(x, splits = c(.3, .5, 2)), | ||
"<= 1") | ||
expect_error(tf_split(x, splits = c(-2, .3, .5)), | ||
">= 0") | ||
|
||
tfs_l <- tf_split(x, splits = c(.3), include = "left") | ||
expect_identical(map(tfs_l, tf_domain), list(c(0, .299), c(.3, 1))) | ||
|
||
tfs_r <- tf_split(x, splits = c(.3), include = "right") | ||
expect_identical(map(tfs_r, tf_domain), list(c(0, .3), c(.301, 1))) | ||
}) | ||
|
||
|
||
test_that("tf_combine works as expected", { | ||
x <- tf_rgp(3) | ||
x_i <- tf_jiggle(x) |> tf_sparsify() | ||
x_sp <- tfb(x, verbose = FALSE) | ||
x_pc <- tfb_fpc(x, verbose = FALSE) | ||
|
||
expect_identical(x, | ||
do.call(tf_combine, tf_split(x, c(.3), "left"))) | ||
expect_identical(x_i, | ||
do.call(tf_combine, tf_split(x_i, c(.3), "left"))) | ||
|
||
expect_equal(tfd(x_sp), | ||
do.call(tf_combine, tf_split(x_sp, c(.3), "left")), | ||
tolerance = 1e-3) | ||
|
||
expect_equal(tfd(x_pc), | ||
do.call(tf_combine, tf_split(x_pc, c(.3), "left") |> | ||
suppressWarnings()), | ||
tolerance = 1e-5) | ||
|
||
# | ||
expect_identical( | ||
do.call(tf_combine, tf_split(x, .3)) |> suppressMessages(), | ||
do.call(tf_combine, rev(tf_split(x, .3)))|> suppressMessages()) | ||
|
||
expect_error( | ||
do.call(tf_combine, c(tf_split(x, .3), strict = TRUE)), | ||
"multiple values") | ||
|
||
expect_class(tf_combine(x, tf_jiggle(x)), | ||
"tfd_irreg") | ||
|
||
expect_error(tf_combine(x, tf_jiggle(x), strict = TRUE), | ||
"not strictly ordered") | ||
|
||
tfs <- tf_split(x, splits = c(.2, .6), include = "left") | ||
tfs2_sparse <- tf_sparsify(tfs[[2]]) | ||
tfs3_spline <- tfb(tfs[[3]], verbose = FALSE) | ||
expect_class(tf_combine(tfs[[1]], tfs2_sparse, tfs3_spline), "tfd_irreg") | ||
expect_equal(tf_combine(tfs[[1]], tfs2_sparse, tfs3_spline) |> tf_domain(), | ||
c(0, 1)) | ||
}) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters