diff --git a/.Rbuildignore b/.Rbuildignore index 5c6e83d..d2fede9 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -6,3 +6,4 @@ ^_pkgdown\.yml$ ^docs$ ^pkgdown$ +^SCRATCH$ diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index a3ac618..03a22c9 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -18,11 +18,11 @@ jobs: fail-fast: false matrix: config: - - {os: macos-latest, r: 'release'} - - {os: windows-latest, r: 'release'} + # - {os: macos-latest, r: 'release'} + # - {os: windows-latest, r: 'release'} - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} - {os: ubuntu-latest, r: 'release'} - - {os: ubuntu-latest, r: 'oldrel-1'} + # - {os: ubuntu-latest, r: 'oldrel-1'} env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index f545b2b..8c80211 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ .RData docs inst/doc +SCRATCH/ diff --git a/DESCRIPTION b/DESCRIPTION index 6ed4feb..5167d69 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parttree Title: Visualise simple decision tree partitions -Version: 0.0.1.9004 +Version: 0.0.1.9005 Authors@R: c( person(given = "Grant", family = "McDermott", @@ -12,8 +12,7 @@ Authors@R: c( role = "ctb", email = "Achim.Zeileis@R-project.org", comment = c(ORCID = "0000-0003-0918-3766")), - person(given = "Brian", - middle = "Heseung", + person(given = "Brian Heseung", family = "Kim", role = "ctb", email = "brhkim@gmail.com", @@ -27,20 +26,26 @@ Description: Simple functions for plotting 2D decision tree partition plots. License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 LazyData: true URL: https://github.com/grantmcdermott/parttree, http://grantmcdermott.com/parttree BugReports: https://github.com/grantmcdermott/parttree/issues -Depends: - ggplot2 (>= 3.4.0) -Imports: - rpart, +Imports: + graphics, + stats, data.table, partykit, - rlang + rlang, + rpart, + ggplot2 (>= 3.4.0), + tinyplot (> 0.1.0) Suggests: tinytest, + tinysnapshot (>= 0.0.3), + fontquiver, + rsvg, + svglite, palmerpenguins, titanic, mlr3, @@ -48,7 +53,7 @@ Suggests: workflows, magick, imager, - patchwork, knitr, rmarkdown +Remotes: grantmcdermott/tinyplot VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index edd82b9..601e6f6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,9 +6,12 @@ S3method(parttree,LearnerRegrRpart) S3method(parttree,constparty) S3method(parttree,rpart) S3method(parttree,workflow) +S3method(plot,parttree) export(geom_parttree) export(parttree) -import(ggplot2) importFrom(data.table,":=") importFrom(data.table,.SD) importFrom(data.table,fifelse) +importFrom(graphics,par) +importFrom(stats,reformulate) +importFrom(tinyplot,tinyplot) diff --git a/NEWS.md b/NEWS.md index dbc4feb..47d0414 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,19 +1,34 @@ -# parttree 0.0.1.9004 +# parttree 0.0.1.9005 -To be released as 0.1 +To be released as 0.1.0 + +#### Breaking changes + +* Move ggplot2 to Suggests, following the addition of native (base R) +`plot.parttree` method. The `geom_parttree()` function now checks whether +ggplot2 is available on the user's system before executing any code. (#18) +* The `flipaxes` argument has been renamed to `flip`, e.g. +`parttree(..., flip = TRUE)`. (#18) #### Improvements -* Major speed-up for extracting parttree nodes and coordinates on complicated trees (#15). 
-* Add method for tidymodels workflows objects fitted with `"rpart"` engine (#7 by @juliasilge). +* Parttree objects now have their own class with a dedicated `plot.parttree` +method, powered by tinyplot. (#18) +* Major speed-up for extracting parttree nodes and coordinates on complicated +trees. (#15) +* Add method for tidymodels workflows objects fitted with `"rpart"` engine. (#7 +by @juliasilge). #### Bug fixes -* Support for negative values (#6 by @pjgeens). -* Better handling of single-level factors and `flipaxes` (#5). +* Support for negative values. (#6 by @pjgeens) +* Better handling of single-level factors and `flip(axes)`. (#5) #### Internals +* Several dependency adjustments, e.g. tinyplot to Imports and ggplot2 to +Suggests. (#18) +* Added SVG snapshots for image-based tests. (#18) * Bump ggplot2 version dependency to match deprecated functions from 3.4.0. * Switched to "main" as primary GitHub branch for development. * Added two dedicated vignettes. diff --git a/R/geom_parttree.R b/R/geom_parttree.R index 6a0b0c6..5754015 100644 --- a/R/geom_parttree.R +++ b/R/geom_parttree.R @@ -1,18 +1,18 @@ -#' @title Visualise tree partitions +#' @title Visualise tree partitions with ggplot2 #' #' @description `geom_parttree()` is a simple extension of #' [ggplot2::geom_rect()]that first calls #' [parttree()] to convert the inputted tree object into an -#' amenable data frame. +#' amenable data frame. Please note that `ggplot2` is not a hard dependency +#' of `parttree` and should thus be installed separately on the user's system. #' @param data An [rpart::rpart.object] or an object of compatible #' type (e.g. a decision tree constructed via the `partykit`, `tidymodels`, or #' `mlr3` front-ends). -#' @param flipaxes Logical. By default, the "x" and "y" axes variables for +#' @param flip Logical. By default, the "x" and "y" axes variables for #' plotting are determined by the first split in the tree. This can cause #' plot orientation mismatches depending on how users specify the other layers #' of their plot. Setting to `TRUE` will flip the "x" and "y" variables for #' the `geom_parttree` layer. -#' @import ggplot2 #' @inheritParams ggplot2::layer #' @inheritParams ggplot2::geom_point #' @inheritParams ggplot2::geom_segment @@ -39,7 +39,9 @@ #' @seealso [parttree()], [ggplot2::geom_rect()]. #' @export #' @examples -#' library(rpart) +#' library(parttree) # this package +#' library(rpart) # decision trees +#' library(ggplot2) # ggplot2 must be loaded separately #' #' ### Simple decision tree (max of two predictor variables) #' @@ -67,8 +69,7 @@ #' ## Oops #' p2 + geom_parttree(data = iris_tree, aes(fill=Species), alpha = 0.1) #' -#' ## Fix with 'flipaxes = TRUE' -#' p2 + geom_parttree(data = iris_tree, aes(fill=Species), alpha = 0.1, flipaxes = TRUE) +#' ## Fix with 'flip = TRUE' #' #' #' ### Various front-end frameworks are also supported, e.g.: @@ -106,19 +107,28 @@ geom_parttree = function(mapping = NULL, data = NULL, stat = "identity", position = "identity", linejoin = "mitre", na.rm = FALSE, show.legend = NA, - inherit.aes = TRUE, flipaxes = FALSE, ...) { - pdata = parttree(data, flipaxes = flipaxes) + inherit.aes = TRUE, flip = FALSE, ...) 
{ + + ggplot2_installed = requireNamespace("ggplot2", quietly = TRUE) + if (isFALSE(ggplot2_installed)) { + stop("Please install the ggplot2 package.", call. = FALSE) + } else if (utils::packageVersion("ggplot2") < "3.4.0") { + stop("Please install a newer version of ggplot2 (>= 3.4.0).") + } + + + pdata = parttree(data, flip = flip) mapping_null = is.null(mapping) mapping$xmin = quote(xmin) mapping$xmax = quote(xmax) mapping$ymin = quote(ymin) mapping$ymax = quote(ymax) if (mapping_null) { - mapping = aes_all(mapping) + mapping = ggplot2::aes_all(mapping) } mapping$x = rlang::quo(NULL) mapping$y = rlang::quo(NULL) - layer( + ggplot2::layer( stat = stat, geom = GeomParttree, data = pdata, mapping = mapping, @@ -129,11 +139,13 @@ geom_parttree = ## Underlying ggproto object GeomParttree = - ggproto( - "GeomParttree", GeomRect, - default_aes = aes(colour = "black", linewidth = 0.5, linetype = 1, + ggplot2::ggproto( + "GeomParttree", ggplot2::GeomRect, + default_aes = ggplot2::aes(colour = "black", linewidth = 0.5, linetype = 1, x=NULL, y = NULL, fill = NA, alpha = NA ), non_missing_aes = c("x", "y", "xmin", "xmax", "ymin", "ymax") ) + + diff --git a/R/parttree.R b/R/parttree.R index 30796f2..ded35f4 100644 --- a/R/parttree.R +++ b/R/parttree.R @@ -1,52 +1,100 @@ #' @title Convert a decision tree into a data frame of partition coordinates -#' @aliases parttree parttree.rpart parttree._rpart parttree.workflow parttree.LearnerClassifRpart parttree.LearnerRegrRpart parttree.constparty -#' -#' @description Extracts the terminal leaf nodes of a decision tree with one or -#' two numeric predictor variables. These leaf nodes are then converted into a data -#' frame, where each row represents a partition (or leaf or terminal node) -#' that can easily be plotted in coordinate space. -#' @param tree A tree object. Supported classes include -#' [rpart::rpart.object], or the compatible classes from -#' from the `parsnip`, `workflows`, or `mlr3` front-ends, or the -#' `constparty` class inheriting from [partykit::party()]. +#' @aliases parttree parttree.rpart parttree._rpart parttree.workflow +#' parttree.LearnerClassifRpart parttree.LearnerRegrRpart parttree.constparty +#' @description Extracts the terminal leaf nodes of a decision tree that +#' contains no more than two numeric predictor variables. These leaf nodes are +#' then converted into a data frame, where each row represents a partition (or +#' leaf or terminal node) that can easily be plotted in 2-D coordinate space. +#' @param tree An \code{\link[rpart]{rpart.object}} or similar. This includes +#' compatible classes from the `mlr3` and `tidymodels` frontends, or the +#' `constparty` class inheriting from \code{\link[partykit]{party}}. #' @param keep_as_dt Logical. The function relies on `data.table` for internal #' data manipulation. But it will coerce the final return object into a #' regular data frame (default behavior) unless the user specifies `TRUE`. -#' @param flipaxes Logical. The function will automatically set the y-axis -#' variable as the first split variable in the tree provided unless -#' the user specifies `TRUE`. -#' @details This function can be used with a regression or classification tree -#' containing one or (at most) two numeric predictors. -#' @seealso [geom_parttree()], [rpart::rpart()], [partykit::ctree()].
-#' @return A data frame comprising seven columns: the leaf node, its path, a set -#' of coordinates understandable to `ggplot2` (i.e., xmin, xmax, ymin, ymax), -#' and a final column corresponding to the predicted value for that leaf. -#' @importFrom data.table := -#' @importFrom data.table .SD -#' @importFrom data.table fifelse +#' @param flip Logical. Should we flip the "x" and "y" variables in the return +#' data frame? The default behaviour is for the first split variable in the +#' tree to take the "y" slot, and any second split variable to take the "x" +#' slot. Setting to `TRUE` switches these around. +#' @seealso [plot.parttree], [geom_parttree], \code{\link[rpart]{rpart}}, +#' \code{\link[partykit]{ctree}} [partykit::ctree]. +#' @returns A data frame comprising seven columns: the leaf node, its path, a +#' set of rectangle limits (i.e., xmin, xmax, ymin, ymax), and a final column +#' corresponding to the predicted value for that leaf. +#' @importFrom data.table := .SD fifelse #' @export #' @examples +#' library("parttree") +#' #' ## rpart trees +#' #' library("rpart") -#' rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris) -#' parttree(rp) +#' rp = rpart(Kyphosis ~ Start + Age, data = kyphosis) +#' +#' # A parttree object is just a data frame with additional attributes +#' (rp_pt = parttree(rp)) +#' attr(rp_pt, "parttree") +#' +#' # simple plot +#' plot(rp_pt) +#' +#' # removing the (recursive) partition borders helps to emphasise overall fit +#' plot(rp_pt, border = NA) +#' +#' # customize further by passing extra options to (tiny)plot +#' plot( +#' rp_pt, +#' border = NA, # no partition borders +#' pch = 19, # filled points +#' alpha = 0.6, # point transparency +#' grid = TRUE, # background grid +#' palette = "classic", # new colour palette +#' xlab = "Topmost vertebra operated on", # custom x title +#' ylab = "Patient age (months)", # custom y title +#' main = "Tree predictions: Kyphosis recurrence" # custom title +#' ) +#' +#' ## conditional inference trees from partyit #' -#' ## conditional inference trees #' library("partykit") #' ct = ctree(Species ~ Petal.Length + Petal.Width, data = iris) -#' parttree(ct) +#' ct_pt = parttree(ct) +#' plot(ct_pt, pch = 19, palette = "okabe", main = "ctree predictions: iris species") #' #' ## rpart via partykit #' rp2 = as.party(rp) #' parttree(rp2) +#' +#' ## various front-end frameworks are also supported, e.g. +#' +#' # tidymodels +#' +#' library(parsnip) +#' +#' decision_tree() |> +#' set_engine("rpart") |> +#' set_mode("classification") |> +#' fit(Species ~ Petal.Length + Petal.Width, data=iris) |> +#' parttree() |> +#' plot(main = "This time brought to you via parsnip...") +#' +#' # mlr3 (NB: use `keep_model = TRUE` for mlr3 learners) +#' +#' library(mlr3) +#' +#' task_iris = TaskClassif$new("iris", iris, target = "Species") +#' task_iris$formula(rhs = "Petal.Length + Petal.Width") +#' fit_iris = lrn("classif.rpart", keep_model = TRUE) # NB! +#' fit_iris$train(task_iris) +#' plot(parttree(fit_iris), main = "... and now mlr3") +#' parttree = - function(tree, keep_as_dt = FALSE, flipaxes = FALSE) { + function(tree, keep_as_dt = FALSE, flip = FALSE) { UseMethod("parttree") } #' @export parttree.rpart = - function(tree, keep_as_dt = FALSE, flipaxes = FALSE) { + function(tree, keep_as_dt = FALSE, flip = FALSE, ...) { ## Silence NSE notes in R CMD check. 
See: ## https://cran.r-project.org/web/packages/data.table/vignettes/datatable-importing.html#globals V1 = node = path = variable = side = ..vars = xvar = yvar = xmin = xmax = ymin = ymax = NULL @@ -64,7 +112,8 @@ parttree.rpart = ## Get details about y variable for later ### y variable string (i.e. name) - y_var = attr(tree$terms, "variables")[[2]] + y_var = paste(tree$terms)[2] + # y_var = attr(tree$terms, "variables")[[2]] ### y values yvals = tree$frame[tree$frame$var == "", ]$yval y_factored = attr(tree$terms, "dataClasses")[paste(y_var)] == "factor" @@ -106,7 +155,7 @@ parttree.rpart = ## special case we can assume is likely wrong, notwithstanding ability to still flip axes if (vars[1]=="y" & vars[2]=="x") vars = rev(vars) } - if (flipaxes) { + if (flip) { vars = rev(vars) ## Handle edge cases with only 1 level if (length(vars)==1) { @@ -150,46 +199,146 @@ parttree.rpart = part_coords = as.data.frame(part_coords) } + class(part_coords) = c("parttree", class(part_coords)) + + # attributes (for plot method) + dots = list(...) + if (!is.null(dots[["xvar"]])) { + xvar = dots[["xvar"]] + } else { + xvar = ifelse(isFALSE(flip), vars[1], vars[2]) + } + if (!is.null(dots[["yvar"]])) { + yvar = dots[["yvar"]] + } else { + yvar = ifelse(isFALSE(flip), vars[2], vars[1]) + } + if (!is.null(dots[["xrange"]])) { + xrange = dots[["xrange"]] + } else { + xrange = range(eval(tree$call$data, envir = attr(tree$terms, ".Environment"))[[xvar]], na.rm = TRUE) + } + if (!is.null(dots[["yrange"]])) { + yrange = dots[["yrange"]] + } else { + yrange = range(eval(tree$call$data, envir = attr(tree$terms, ".Environment"))[[yvar]], na.rm = TRUE) + } + raw_data = orig_call = orig_na.action = NULL + if (!is.null(dots[["raw_data"]])) { + raw_data = substitute(dots[["raw_data"]]) + } else { + orig_call = tree$call + orig_na.action = tree$na.action + } + + attr(part_coords, "parttree") = list( + xvar = xvar, + yvar = yvar, + xrange = xrange, + yrange = yrange, + response = y_var, + call = orig_call, + na.action = orig_na.action, + raw_data = raw_data + ) + return(part_coords) } #' @export parttree._rpart = - function(tree, keep_as_dt = FALSE, flipaxes = FALSE) { + function(tree, keep_as_dt = FALSE, flip = FALSE) { ## parsnip front-end if (is.null(tree$fit)) { stop("No model detected.\n", "Did you forget to fit a model? See `?parsnip::fit`.") } tree = tree$fit - parttree.rpart(tree, keep_as_dt = keep_as_dt, flipaxes = flipaxes) + # extra attribute arguments to pass through ... to parttree.rpart + raw_data = attr(tree$terms, ".Environment")$data + vars = attr(tree$terms, "term.labels") + xvar = ifelse(isFALSE(flip), vars[1], vars[2]) + yvar = ifelse(isFALSE(flip), vars[2], vars[1]) + xrange = range(raw_data[[xvar]]) + yrange = range(raw_data[[yvar]]) + + parttree.rpart( + tree, keep_as_dt = keep_as_dt, flip = flip, + raw_data = raw_data, + xvar = xvar, yvar = yvar, + xrange = xrange, yrange = yrange + ) } #' @export parttree.workflow = - function(tree, keep_as_dt = FALSE, flipaxes = FALSE) { + function(tree, keep_as_dt = FALSE, flip = FALSE) { ## workflow front-end if (!workflows::is_trained_workflow(tree)) { stop("No model detected.\n", "Did you forget to fit a model? 
See `?workflows::fit`.") } y_name = names(tree$pre$mold$outcomes)[[1]] + raw_data = cbind(tree$pre$mold$predictors, tree$pre$mold$outcomes) tree = workflows::extract_fit_engine(tree) + tree$terms[[2]] = y_name attr(tree$terms, "variables")[[2]] = y_name names(attr(tree$terms, "dataClasses"))[[1]] = y_name - parttree.rpart(tree, keep_as_dt = keep_as_dt, flipaxes = flipaxes) + + # extra attribute arguments to pass through ... to parttree.rpart + vars = attr(tree$terms, "term.labels") + xvar = ifelse(isFALSE(flip), vars[1], vars[2]) + yvar = ifelse(isFALSE(flip), vars[2], vars[1]) + xrange = range(raw_data[[xvar]]) + yrange = range(raw_data[[yvar]]) + + parttree.rpart( + tree, keep_as_dt = keep_as_dt, flip = flip, + raw_data = raw_data, + xvar = xvar, yvar = yvar, + xrange = xrange, yrange = yrange + ) } #' @export parttree.LearnerClassifRpart = - function(tree, keep_as_dt = FALSE, flipaxes = FALSE) { + function(tree, keep_as_dt = FALSE, flip = FALSE) { ## mlr3 front-end if (is.null(tree$model)) { stop("No model detected.\n", "Did you forget to assign a learner? See `?mlr3::lrn`.") } + + pars = tree$param_set$get_values() + keep_model = isTRUE(pars$keep_model) + tree = tree$model - parttree.rpart(tree, keep_as_dt = keep_as_dt, flipaxes = flipaxes) + + # extra attribute arguments to pass through ... to parttree.rpart + # raw_data = eval(tree$call$data) + vars = attr(tree$terms, "term.labels") + xvar = ifelse(isFALSE(flip), vars[1], vars[2]) + yvar = ifelse(isFALSE(flip), vars[2], vars[1]) + if (keep_model) { + raw_data = tree$model + xrange = range(raw_data[[xvar]]) + yrange = range(raw_data[[yvar]]) + } else { + raw_data = NA + xrange = NA + yrange = NA + message( + "\nUnable to retrieve the original data, which we need for the default plot.parttree method.", + "\nFor mlr3 workflows, we recommended an explicit call to `keep_model = TRUE` when defining your Learner before training the model.\n" + ) + } + + parttree.rpart( + tree, keep_as_dt = keep_as_dt, flip = flip, + raw_data = raw_data, + xvar = xvar, yvar = yvar, + xrange = xrange, yrange = yrange + ) } #' @export @@ -197,7 +346,7 @@ parttree.LearnerRegrRpart = parttree.LearnerClassifRpart #' @export parttree.constparty = - function(tree, keep_as_dt = FALSE, flipaxes = FALSE) { + function(tree, keep_as_dt = FALSE, flip = FALSE) { ## sanity checks for tree mt = tree$terms mf = attr(mt, "factors") @@ -303,11 +452,25 @@ parttree.constparty = path = labs ) names(rval)[2L] = my - rval = cbind(rval, if(flipaxes) ints[, c(3L:4L, 1L:2L, drop = FALSE)] else ints) + rval = cbind(rval, if(flip) ints[, c(3L:4L, 1L:2L, drop = FALSE)] else ints) colnames(rval)[4L:7L] = c("xmin", "xmax", "ymin", "ymax") ## turn into data.table? if(keep_as_dt) rval = data.table::as.data.table(rval) + class(rval) = c("parttree", class(rval)) + xvar = ifelse(isFALSE(flip), mx[1], mx[2]) + yvar = ifelse(isFALSE(flip), mx[2], mx[1]) + attr(rval, "parttree") = list( + xvar = xvar, + yvar = yvar, + xrange = range(eval(tree$data)[[xvar]], na.rm = TRUE), + yrange = range(eval(tree$data)[[yvar]], na.rm = TRUE), + response = my, + call = NULL, + na.action = NULL, + raw_data = substitute(tree$data) # Or, partykit::model_frame_rpart? + ) + return(rval) } diff --git a/R/plot.R b/R/plot.R new file mode 100644 index 0000000..babb849 --- /dev/null +++ b/R/plot.R @@ -0,0 +1,145 @@ +#' @title Plot decision tree partitions +#' @description Provides a plot method for parttree objects. +#' @returns No return value, called for side effect of producing a plot. 
+#' @param x A [parttree] data frame. +#' @param raw Logical. Should the raw (i.e., original) data be plotted alongside +#' the tree partitions? Default is `TRUE`. +#' @param border Colour of the partition borders (edges). Default is "black". To +#' remove the borders altogether, specify as `NA`. +#' @param fill_alpha Numeric in the range `[0,1]`. Alpha transparency of the +#' filled partition rectangles. Default is `0.3`. +#' @param expand Logical. Should the partition limits be expanded to meet the +#' edge of the plot axes? Default is `TRUE`. If `FALSE`, then the partition +#' limits will extend only until the range of the raw data. +#' @param jitter Logical. Should the raw points be jittered? Default is `FALSE`. +#' Only evaluated if `raw = TRUE`. +#' @param add Logical. Add to an existing plot? Default is `FALSE`. +#' @param ... Additional arguments passed down to +#' \code{\link[tinyplot]{tinyplot}}. +#' @importFrom stats reformulate +#' @importFrom graphics par +#' @importFrom tinyplot tinyplot +#' @rdname plot.parttree +#' @inherit parttree examples +#' @export +plot.parttree = function( + x, + raw = TRUE, + border = "black", + fill_alpha = 0.3, + expand = TRUE, + jitter = FALSE, + add = FALSE, + ... + ) { + object = x + xvar = attr(object, "parttree")[["xvar"]] + yvar = attr(object, "parttree")[["yvar"]] + xrange = attr(object, "parttree")[["xrange"]] + yrange = attr(object, "parttree")[["yrange"]] + response = attr(object, "parttree")[["response"]] + raw_data = attr(object, "parttree")[["raw_data"]] + orig_call = attr(object, "parttree")[["call"]] + orig_na_idx = attr(object, "parttree")[["na.action"]] + + if (isTRUE(raw)) { + if (!is.null(raw_data)) { + raw_data = eval(raw_data) + } else { + raw_data = eval(orig_call$data)[, c(response, xvar, yvar)] + if (!is.null(orig_na_idx)) raw_data = raw_data[-orig_na_idx, , drop = FALSE] + } + if (is.null(raw_data) || (is.atomic(raw_data) && is.na(raw_data))) { + warning( + "\nCould not find original data. Setting `raw = FALSE`.\n" + ) + raw = FALSE + } + } + + ## First adjust our parttree object to better fit some base R graphics + ## requirements + + xmin_idxr = object$xmin == -Inf + xmax_idxr = object$xmax == Inf + ymin_idxr = object$ymin == -Inf + ymax_idxr = object$ymax == Inf + + object$xmin[xmin_idxr] = xrange[1] + object$xmax[xmax_idxr] = xrange[2] + object$ymin[ymin_idxr] = yrange[1] + object$ymax[ymax_idxr] = yrange[2] + + ## Start plotting... + + plot_fml = reformulate(paste(xvar, "|", response), response = yvar) + + # First draw an empty plot (since we need the plot corners to correctly + # expand the partition limits to the edges of the plot). We'll create a + # dummy object for this task. + if (isFALSE(add)) { + dobj = data.frame( + response = rep(object[[response]], 2), + x = c(object[["xmin"]], object[["xmax"]]), + y = c(object[["ymin"]], object[["ymax"]]) + ) + colnames(dobj) = c(response, xvar, yvar) + + if (isTRUE(raw) && isTRUE(jitter)) { + dobj[[xvar]] = range(c(dobj[[xvar]], jitter(raw_data[[xvar]])), na.rm = TRUE) + dobj[[yvar]] = range(c(dobj[[yvar]], jitter(raw_data[[yvar]])), na.rm = TRUE) + } + + tinyplot( + plot_fml, + data = dobj, + type = "rect", + col = border, + fill = fill_alpha, + empty = TRUE, + ...
+ ) + } + + object$response = object[[response]] + + # Grab the plot corners and adjust the partition limits + if (isTRUE(expand)) { + corners = par("usr") + object$xmin[xmin_idxr] = corners[1] + object$xmax[xmax_idxr] = corners[2] + object$ymin[ymin_idxr] = corners[3] + object$ymax[ymax_idxr] = corners[4] + } + + + # Add the (adjusted) partition rectangles + with( + object, + tinyplot( + xmin = xmin, ymin = ymin, xmax = xmax, ymax = ymax, + by = response, + type = "rect", + add = TRUE, + col = border, + fill = fill_alpha, + ... + ) + ) + + # Add the original data points (if requested) + if (isTRUE(raw)) { + ptype = ifelse(isTRUE(jitter), "j", "p") + tinyplot( + plot_fml, + data = raw_data, + type = ptype, + add = TRUE, + ... + ) + } +} + diff --git a/README.Rmd b/README.Rmd index cd1a4b2..6160f87 100644 --- a/README.Rmd +++ b/README.Rmd @@ -21,17 +21,17 @@ knitr::opts_chunk$set( Visualize simple 2-D decision tree partitions in R. The **parttree** -package is optimised to work with [**ggplot2**](https://ggplot2.tidyverse.org/), -although it can be used to visualize tree partitions with base R graphics too. +package provides visualization methods for both base R graphics (via +[**tinyplot**](https://grantmcdermott.com/tinyplot/)) and +[**ggplot2**](https://ggplot2.tidyverse.org/). ## Installation -This package is not yet on CRAN, but can be installed from [GitHub](https://github.com/) -with: +This package is not on CRAN yet, but can be installed from +[r-universe](https://grantmcdermott.r-universe.dev/parttree): ``` r -# install.packages("remotes") -remotes::install_github("grantmcdermott/parttree") +install.packages("parttree", repos = "https://grantmcdermott.r-universe.dev") ``` ## Quickstart @@ -42,15 +42,44 @@ quickstart example using the dataset that comes bundled with the **rpart** package. In this case, we are interested in predicting kyphosis recovery after spinal surgery, as a function of 1) the number of topmost vertebra that were operated, and 2) patient age. -The key visualization layer below---provided by this package---is -`geom_partree()`. + +The key function is `parttree()`, which comes with its own plotting method. ```{r quickstart} library(rpart) # For the dataset and fitting decisions trees -library(parttree) # This package (will automatically load ggplot2 too) +library(parttree) # This package fit = rpart(Kyphosis ~ Start + Age, data = kyphosis) +# Grab the partitions and plot +fit_pt = parttree(fit) +plot(fit_pt) +``` + +Customize your plots by passing additional arguments: + +```{r quickstart2} +plot( + fit_pt, + border = NA, # no partition borders + pch = 19, # filled points + alpha = 0.6, # point transparency + grid = TRUE, # background grid + palette = "classic", # new colour palette + xlab = "Topmost vertebra operated on", # custom x title + ylab = "Patient age (months)", # custom y title + main = "Tree predictions: Kyphosis recurrence" # custom title +) +``` + +### ggplot2 + +For **ggplot2** users, we offer an equivalent workflow via the `geom_partree()` +visualization layer. + +```{r quickstart_gg} +library(ggplot2) ## Should be loaded separately + ggplot(kyphosis, aes(x = Start, y = Age)) + geom_parttree(data = fit, alpha = 0.1, aes(fill = Kyphosis)) + # <-- key layer geom_point(aes(col = Kyphosis)) + diff --git a/README.md b/README.md index ff399f2..3a261b3 100644 --- a/README.md +++ b/README.md @@ -10,18 +10,17 @@ Visualize simple 2-D decision tree partitions in R. 
The **parttree** -package is optimised to work with -[**ggplot2**](https://ggplot2.tidyverse.org/), although it can be used -to visualize tree partitions with base R graphics too. +package provides visualization methods for both base R graphics (via +[**tinyplot**](https://grantmcdermott.com/tinyplot/)) and +[**ggplot2**](https://ggplot2.tidyverse.org/). ## Installation -This package is not yet on CRAN, but can be installed from -[GitHub](https://github.com/) with: +This package is not on CRAN yet, but can be installed from +[r-universe](https://grantmcdermott.r-universe.dev/parttree): ``` r -# install.packages("remotes") -remotes::install_github("grantmcdermott/parttree") +install.packages("parttree", repos = "https://grantmcdermott.r-universe.dev") ``` ## Quickstart @@ -34,16 +33,50 @@ quickstart example using the dataset that comes bundled with the **rpart** package. In this case, we are interested in predicting kyphosis recovery after spinal surgery, as a function of 1) the number of topmost vertebra that were operated, and -2) patient age. The key visualization layer below—provided by this -package—is `geom_partree()`. +2) patient age. + +The key function is `parttree()`, which comes with its own plotting +method. ``` r library(rpart) # For the dataset and fitting decisions trees -library(parttree) # This package (will automatically load ggplot2 too) -#> Loading required package: ggplot2 +library(parttree) # This package fit = rpart(Kyphosis ~ Start + Age, data = kyphosis) +# Grab the partitions and plot +fit_pt = parttree(fit) +plot(fit_pt) +``` + + + +Customize your plots by passing additional arguments: + +``` r +plot( + fit_pt, + border = NA, # no partition borders + pch = 19, # filled points + alpha = 0.6, # point transparency + grid = TRUE, # background grid + palette = "classic", # new colour palette + xlab = "Topmost vertebra operated on", # custom x title + ylab = "Patient age (months)", # custom y title + main = "Tree predictions: Kyphosis recurrence" # custom title +) +``` + + + +### ggplot2 + +For **ggplot2** users, we offer an equivalent workflow via the +`geom_parttree()` visualization layer.
+ +``` r +library(ggplot2) ## Should be loaded separately + ggplot(kyphosis, aes(x = Start, y = Age)) + geom_parttree(data = fit, alpha = 0.1, aes(fill = Kyphosis)) + # <-- key layer geom_point(aes(col = Kyphosis)) + @@ -54,4 +87,4 @@ ggplot(kyphosis, aes(x = Start, y = Age)) + theme_minimal() ``` - + diff --git a/inst/tinytest/_tinysnapshot/iris_classification.svg b/inst/tinytest/_tinysnapshot/iris_classification.svg new file mode 100644 index 0000000..ffb1192 --- /dev/null +++ b/inst/tinytest/_tinysnapshot/iris_classification.svg @@ -0,0 +1,230 @@ + + + + + + + + + + + + + + + +Species +setosa +versicolor +virginica + + + + + + + +Petal.Length +Petal.Width + + + + + + + + + + +1 +2 +3 +4 +5 +6 +7 + + + + + + +0.5 +1.0 +1.5 +2.0 +2.5 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/inst/tinytest/_tinysnapshot/iris_regression.svg b/inst/tinytest/_tinysnapshot/iris_regression.svg new file mode 100644 index 0000000..a65edc5 --- /dev/null +++ b/inst/tinytest/_tinysnapshot/iris_regression.svg @@ -0,0 +1,235 @@ + + + + + + + + + + + + + + 5.5 + 6.5 + 7.5 +- - +- - +- - +Sepal.Length + + + + + + + +Petal.Length +Sepal.Width + + + + + + + + + + +1 +2 +3 +4 +5 +6 +7 + + + + + + +2.0 +2.5 +3.0 +3.5 +4.0 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/inst/tinytest/helpers.R b/inst/tinytest/helpers.R new file mode 100644 index 0000000..e0c7db0 --- /dev/null +++ b/inst/tinytest/helpers.R @@ -0,0 +1,10 @@ +library(tinytest) +library(tinysnapshot) + +# # Skip tests if not on Linux +ON_LINUX = Sys.info()["sysname"] == "Linux" +if (!ON_LINUX) exit_file("Linux snapshots") + +options("tinysnapshot_os" = "Linux") +options("tinysnapshot_device" = "svglite") +options("tinysnapshot_device_args" = list(user_fonts = fontquiver::font_families("Liberation"))) diff --git a/inst/tinytest/test_ctree.R b/inst/tinytest/test_ctree.R index 3f86e71..4286136 100644 --- a/inst/tinytest/test_ctree.R +++ b/inst/tinytest/test_ctree.R @@ -8,5 +8,7 @@ if (require(partykit)) { data = iris) ct = parttree(ct) row.names(ct) = NULL + attr(ct, "parttree") = NULL + class(ct) = "data.frame" expect_equal(ct, pt_ct_cl_known) } diff --git a/inst/tinytest/test_mlr3.R b/inst/tinytest/test_mlr3.R index a12923b..655f02b 100644 --- a/inst/tinytest/test_mlr3.R +++ b/inst/tinytest/test_mlr3.R @@ -6,16 +6,22 @@ if (require(mlr3)) { # task_cl = tsk("iris", target = "Species") # simpler (but less precise?) 
version of below task_cl = TaskClassif$new("iris", iris, target = "Species") task_cl$formula(rhs = "Petal.Length + Petal.Width") - fit_cl = lrn("classif.rpart") + fit_cl = lrn("classif.rpart", keep_model = TRUE) fit_cl$train(task_cl) - expect_equal(pt_cl_known, parttree(fit_cl)) + fit_cl_pt = parttree(fit_cl) + attr(fit_cl_pt, "parttree") = NULL + class(fit_cl_pt) = "data.frame" + expect_equal(pt_cl_known, fit_cl_pt) # Regression source('known_output/parttree_rpart_regression.R') task_reg = TaskRegr$new("iris", iris, target = "Sepal.Length") task_reg$formula(rhs = "Petal.Length + Sepal.Width") - fit_reg = lrn("regr.rpart") + fit_reg = lrn("regr.rpart", keep_model = TRUE) fit_reg$train(task_reg) - expect_equal(pt_reg_known, parttree(fit_reg), , tolerance = 1e-7) + fit_reg_pt = parttree(fit_reg) + attr(fit_reg_pt, "parttree") = NULL + class(fit_reg_pt) = "data.frame" + expect_equal(pt_reg_known, fit_reg_pt, tolerance = 1e-7) } diff --git a/inst/tinytest/test_rpart.R b/inst/tinytest/test_rpart.R index 57cd453..8caee82 100644 --- a/inst/tinytest/test_rpart.R +++ b/inst/tinytest/test_rpart.R @@ -1,3 +1,7 @@ +# For tinysnapshot +source("helpers.R") +using("tinysnapshot") + # # Classification # @@ -7,25 +11,38 @@ source('known_output/parttree_rpart_classification.R') # rpart rp = rpart::rpart(Species ~ Petal.Length + Petal.Width, data = iris) -expect_equal(pt_cl_known, parttree(rp)) +rp_pt = parttree(rp) +# plot method +f = function() {plot(rp_pt)} +expect_snapshot_plot(f, label = "iris_classification") +# now strip attributes and compare data frames +attr(rp_pt, "parttree") = NULL +class(rp_pt) = "data.frame" +expect_equal(pt_cl_known, rp_pt) + # partykit if (require(partykit)) { rp2 = as.party(rp) - rp2 = parttree(rp2) - row.names(rp2) = NULL + rp2_pt = parttree(rp2) + row.names(rp2_pt) = NULL + attr(rp2_pt, "parttree") = NULL + class(rp2_pt) = "data.frame" expect_equal(pt_cl_known[, setdiff(names(pt_cl_known), 'node')], - rp2[, setdiff(names(rp2), 'node')]) + rp2_pt[, setdiff(names(rp2_pt), 'node')]) } # -# flipaxes +# flip # # Comparison data source('known_output/parttree_rpart_classification_flip.R') -expect_equal(pt_cl_known_flip, parttree(rp, flipaxes = TRUE)) +rp_pt_flip = parttree(rp, flip = TRUE) +attr(rp_pt_flip, "parttree") = NULL +class(rp_pt_flip) = "data.frame" +expect_equal(pt_cl_known_flip, rp_pt_flip) # @@ -35,5 +52,13 @@ expect_equal(pt_cl_known_flip, parttree(rp, flipaxes = TRUE)) # Comparison data source('known_output/parttree_rpart_regression.R') -rp = rpart::rpart(Sepal.Length ~ Petal.Length + Sepal.Width, data = iris) -expect_equal(pt_reg_known, parttree(rp), tolerance = 1e-7) +rp_reg = rpart::rpart(Sepal.Length ~ Petal.Length + Sepal.Width, data = iris) +rp_reg_pt = parttree(rp_reg) +# plot method +f = function() {plot(rp_reg_pt)} +expect_snapshot_plot(f, label = "iris_regression") +# now strip attributes and compare data frames +attr(rp_reg_pt, "parttree") = NULL +class(rp_reg_pt) = "data.frame" +expect_equal(pt_reg_known, rp_reg_pt, tolerance = 1e-7) + diff --git a/inst/tinytest/test_tidymodels.R b/inst/tinytest/test_tidymodels.R index 6dd7a84..19e04df 100644 --- a/inst/tinytest/test_tidymodels.R +++ b/inst/tinytest/test_tidymodels.R @@ -9,17 +9,23 @@ if (require(tidymodels)) { fml_cl = Species ~ Petal.Length + Petal.Width # parsnip - ps_cl = decision_tree() %>% - set_engine("rpart") %>% - set_mode("classification") %>% - fit(fml_cl, data = iris) - expect_equal(pt_cl_known, parttree(ps_cl)) + ps_cl_pt = decision_tree() |> + set_engine("rpart") |> + 
set_mode("classification") |> + fit(fml_cl, data = iris) |> + parttree() + attr(ps_cl_pt, "parttree") = NULL + class(ps_cl_pt) = "data.frame" + expect_equal(pt_cl_known, ps_cl_pt) # workflows - wf_spec_cl = decision_tree() %>% set_mode("classification") + wf_spec_cl = decision_tree() |> set_mode("classification") wf_tree_cl = workflow(fml_cl, spec = wf_spec_cl) wf_cl = fit(wf_tree_cl, iris) - expect_equal(pt_cl_known, parttree(wf_cl)) + wf_cl_pt = parttree(wf_cl) + attr(wf_cl_pt, "parttree") = NULL + class(wf_cl_pt) = "data.frame" + expect_equal(pt_cl_known, wf_cl_pt) } # @@ -33,15 +39,21 @@ if (require(tidymodels)) { fml_reg = Sepal.Length ~ Petal.Length + Sepal.Width # parsnip - ps_reg = decision_tree() %>% - set_engine("rpart") %>% - set_mode("regression") %>% - fit(fml_reg, data = iris) - expect_equal(pt_reg_known, parttree(ps_reg), tolerance = 1e-7) + ps_reg_pt = decision_tree() |> + set_engine("rpart") |> + set_mode("regression") |> + fit(fml_reg, data = iris) |> + parttree() + attr(ps_reg_pt, "parttree") = NULL + class(ps_reg_pt) = "data.frame" + expect_equal(pt_reg_known, ps_reg_pt, tolerance = 1e-7) # workflows - wf_spec_reg = decision_tree() %>% set_mode("regression") + wf_spec_reg = decision_tree() |> set_mode("regression") wf_tree_reg = workflow(fml_reg, spec = wf_spec_reg) wf_reg = fit(wf_tree_reg, iris) - expect_equal(pt_reg_known, parttree(wf_reg), tolerance = 1e-7) + wf_reg_pt = parttree(wf_reg) + attr(wf_reg_pt, "parttree") = NULL + class(wf_reg_pt) = "data.frame" + expect_equal(pt_reg_known, wf_reg_pt, tolerance = 1e-7) } diff --git a/man/figures/README-quickstart-1.png b/man/figures/README-quickstart-1.png index 0445a4f..bba8a5a 100644 Binary files a/man/figures/README-quickstart-1.png and b/man/figures/README-quickstart-1.png differ diff --git a/man/figures/README-quickstart2-1.png b/man/figures/README-quickstart2-1.png new file mode 100644 index 0000000..e5c234e Binary files /dev/null and b/man/figures/README-quickstart2-1.png differ diff --git a/man/figures/README-quickstart_gg-1.png b/man/figures/README-quickstart_gg-1.png new file mode 100644 index 0000000..0445a4f Binary files /dev/null and b/man/figures/README-quickstart_gg-1.png differ diff --git a/man/geom_parttree.Rd b/man/geom_parttree.Rd index 9a04fa3..cf30bb3 100644 --- a/man/geom_parttree.Rd +++ b/man/geom_parttree.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/geom_parttree.R \name{geom_parttree} \alias{geom_parttree} -\title{Visualise tree partitions} +\title{Visualise tree partitions with ggplot2} \usage{ geom_parttree( mapping = NULL, @@ -13,7 +13,7 @@ geom_parttree( na.rm = FALSE, show.legend = NA, inherit.aes = TRUE, - flipaxes = FALSE, + flip = FALSE, ... ) } @@ -27,15 +27,31 @@ mapping.} type (e.g. a decision tree constructed via the \code{partykit}, \code{tidymodels}, or \code{mlr3} front-ends).} -\item{stat}{The statistical transformation to use on the data for this -layer, either as a \code{ggproto} \code{Geom} subclass or as a string naming the -stat stripped of the \code{stat_} prefix (e.g. \code{"count"} rather than -\code{"stat_count"})} - -\item{position}{Position adjustment, either as a string naming the adjustment -(e.g. \code{"jitter"} to use \code{position_jitter}), or the result of a call to a -position adjustment function. Use the latter if you need to change the -settings of the adjustment.} +\item{stat}{The statistical transformation to use on the data for this layer. 
+When using a \verb{geom_*()} function to construct a layer, the \code{stat} +argument can be used the override the default coupling between geoms and +stats. The \code{stat} argument accepts the following: +\itemize{ +\item A \code{Stat} ggproto subclass, for example \code{StatCount}. +\item A string naming the stat. To give the stat as a string, strip the +function name of the \code{stat_} prefix. For example, to use \code{stat_count()}, +give the stat as \code{"count"}. +\item For more information and other ways to specify the stat, see the +\link[ggplot2:layer_stats]{layer stat} documentation. +}} + +\item{position}{A position adjustment to use on the data for this layer. This +can be used in various ways, including to prevent overplotting and +improving the display. The \code{position} argument accepts the following: +\itemize{ +\item The result of calling a position function, such as \code{position_jitter()}. +This method allows for passing extra arguments to the position. +\item A string naming the position adjustment. To give the position as a +string, strip the function name of the \code{position_} prefix. For example, +to use \code{position_jitter()}, give the position as \code{"jitter"}. +\item For more information and other ways to specify the position, see the +\link[ggplot2:layer_positions]{layer position} documentation. +}} \item{linejoin}{Line join style (round, mitre, bevel).} @@ -53,22 +69,46 @@ rather than combining with them. This is most useful for helper functions that define both data and aesthetics and shouldn't inherit behaviour from the default plot specification, e.g. \code{\link[ggplot2:borders]{borders()}}.} -\item{flipaxes}{Logical. By default, the "x" and "y" axes variables for +\item{flip}{Logical. By default, the "x" and "y" axes variables for plotting are determined by the first split in the tree. This can cause plot orientation mismatches depending on how users specify the other layers of their plot. Setting to \code{TRUE} will flip the "x" and "y" variables for the \code{geom_parttree} layer.} -\item{...}{Other arguments passed on to \code{\link[ggplot2:layer]{layer()}}. These are -often aesthetics, used to set an aesthetic to a fixed value, like -\code{colour = "red"} or \code{size = 3}. They may also be parameters -to the paired geom/stat.} +\item{...}{Other arguments passed on to \code{\link[ggplot2:layer]{layer()}}'s \code{params} argument. These +arguments broadly fall into one of 4 categories below. Notably, further +arguments to the \code{position} argument, or aesthetics that are required +can \emph{not} be passed through \code{...}. Unknown arguments that are not part +of the 4 categories below are ignored. +\itemize{ +\item Static aesthetics that are not mapped to a scale, but are at a fixed +value and apply to the layer as a whole. For example, \code{colour = "red"} +or \code{linewidth = 3}. The geom's documentation has an \strong{Aesthetics} +section that lists the available options. The 'required' aesthetics +cannot be passed on to the \code{params}. Please note that while passing +unmapped aesthetics as vectors is technically possible, the order and +required length is not guaranteed to be parallel to the input data. +\item When constructing a layer using +a \verb{stat_*()} function, the \code{...} argument can be used to pass on +parameters to the \code{geom} part of the layer. An example of this is +\code{stat_density(geom = "area", outline.type = "both")}. The geom's +documentation lists which parameters it can accept. 
+\item Inversely, when constructing a layer using a +\verb{geom_*()} function, the \code{...} argument can be used to pass on parameters +to the \code{stat} part of the layer. An example of this is +\code{geom_area(stat = "density", adjust = 0.5)}. The stat's documentation +lists which parameters it can accept. +\item The \code{key_glyph} argument of \code{\link[ggplot2:layer]{layer()}} may also be passed on through +\code{...}. This can be one of the functions described as +\link[ggplot2:draw_key]{key glyphs}, to change the display of the layer in the legend. +}} } \description{ \code{geom_parttree()} is a simple extension of \code{\link[ggplot2:geom_tile]{ggplot2::geom_rect()}}that first calls \code{\link[=parttree]{parttree()}} to convert the inputted tree object into an -amenable data frame. +amenable data frame. Please note that \code{ggplot2} is not a hard dependency +of \code{parttree} and should thus be installed separately on the user's system. } \details{ Because of the way that \code{ggplot2} validates inputs and assembles @@ -95,7 +135,9 @@ cue regarding the prediction in each partition region)} } \examples{ -library(rpart) +library(parttree) # this package +library(rpart) # decision trees +library(ggplot2) # ggplot2 must be loaded separately ### Simple decision tree (max of two predictor variables) @@ -123,8 +165,7 @@ p2 = ggplot(iris, aes(x=Petal.Width, y=Petal.Length)) + ## Oops p2 + geom_parttree(data = iris_tree, aes(fill=Species), alpha = 0.1) -## Fix with 'flipaxes = TRUE' -p2 + geom_parttree(data = iris_tree, aes(fill=Species), alpha = 0.1, flipaxes = TRUE) +## Fix with 'flip = TRUE' ### Various front-end frameworks are also supported, e.g.: diff --git a/man/parttree.Rd b/man/parttree.Rd index 94e8fbf..4839975 100644 --- a/man/parttree.Rd +++ b/man/parttree.Rd @@ -10,52 +10,100 @@ \alias{parttree.constparty} \title{Convert a decision tree into a data frame of partition coordinates} \usage{ -parttree(tree, keep_as_dt = FALSE, flipaxes = FALSE) +parttree(tree, keep_as_dt = FALSE, flip = FALSE) } \arguments{ -\item{tree}{A tree object. Supported classes include -\link[rpart:rpart.object]{rpart::rpart.object}, or the compatible classes from -from the \code{parsnip}, \code{workflows}, or \code{mlr3} front-ends, or the -\code{constparty} class inheriting from \code{\link[partykit:party]{partykit::party()}}.} +\item{tree}{An \code{\link[rpart]{rpart.object}} or alike. This includes +compatible classes from the \code{mlr3} and \code{tidymodels} frontends, or the +\code{constparty} class inheriting from \code{\link[partykit]{party}}.} \item{keep_as_dt}{Logical. The function relies on \code{data.table} for internal data manipulation. But it will coerce the final return object into a regular data frame (default behavior) unless the user specifies \code{TRUE}.} -\item{flipaxes}{Logical. The function will automatically set the y-axis -variable as the first split variable in the tree provided unless -the user specifies \code{TRUE}.} +\item{flip}{Logical. Should we flip the "x" and "y" variables in the return +data frame? The default behaviour is for the first split variable in the +tree to take the "y" slot, and any second split variable to take the "x" +slot. Setting to \code{TRUE} switches these around.} } \value{ -A data frame comprising seven columns: the leaf node, its path, a set -of coordinates understandable to \code{ggplot2} (i.e., xmin, xmax, ymin, ymax), -and a final column corresponding to the predicted value for that leaf. 
+A data frame comprising seven columns: the leaf node, its path, a +set of rectangle limits (i.e., xmin, xmax, ymin, ymax), and a final column +corresponding to the predicted value for that leaf. } \description{ -Extracts the terminal leaf nodes of a decision tree with one or -two numeric predictor variables. These leaf nodes are then converted into a data -frame, where each row represents a partition (or leaf or terminal node) -that can easily be plotted in coordinate space. -} -\details{ -This function can be used with a regression or classification tree -containing one or (at most) two numeric predictors. +Extracts the terminal leaf nodes of a decision tree that +contains no more that two numeric predictor variables. These leaf nodes are +then converted into a data frame, where each row represents a partition (or +leaf or terminal node) that can easily be plotted in 2-D coordinate space. } \examples{ +library("parttree") + ## rpart trees + library("rpart") -rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris) -parttree(rp) +rp = rpart(Kyphosis ~ Start + Age, data = kyphosis) + +# A parttree object is just a data frame with additional attributes +(rp_pt = parttree(rp)) +attr(rp_pt, "parttree") + +# simple plot +plot(rp_pt) + +# removing the (recursive) partition borders helps to emphasise overall fit +plot(rp_pt, border = NA) + +# customize further by passing extra options to (tiny)plot +plot( + rp_pt, + border = NA, # no partition borders + pch = 19, # filled points + alpha = 0.6, # point transparency + grid = TRUE, # background grid + palette = "classic", # new colour palette + xlab = "Topmost vertebra operated on", # custom x title + ylab = "Patient age (months)", # custom y title + main = "Tree predictions: Kyphosis recurrence" # custom title +) + +## conditional inference trees from partyit -## conditional inference trees library("partykit") ct = ctree(Species ~ Petal.Length + Petal.Width, data = iris) -parttree(ct) +ct_pt = parttree(ct) +plot(ct_pt, pch = 19, palette = "okabe", main = "ctree predictions: iris species") ## rpart via partykit rp2 = as.party(rp) parttree(rp2) + +## various front-end frameworks are also supported, e.g. + +# tidymodels + +library(parsnip) + +decision_tree() |> + set_engine("rpart") |> + set_mode("classification") |> + fit(Species ~ Petal.Length + Petal.Width, data=iris) |> + parttree() |> + plot(main = "This time brought to you via parsnip...") + +# mlr3 (NB: use `keep_model = TRUE` for mlr3 learners) + +library(mlr3) + +task_iris = TaskClassif$new("iris", iris, target = "Species") +task_iris$formula(rhs = "Petal.Length + Petal.Width") +fit_iris = lrn("classif.rpart", keep_model = TRUE) # NB! +fit_iris$train(task_iris) +plot(parttree(fit_iris), main = "... and now mlr3") + } \seealso{ -\code{\link[=geom_parttree]{geom_parttree()}}, \code{\link[rpart:rpart]{rpart::rpart()}}, \code{\link[partykit:ctree]{partykit::ctree()}}. +\link{plot.parttree}, \link{geom_parttree}, \code{\link[rpart]{rpart}}, +\code{\link[partykit]{ctree}} \link[partykit:ctree]{partykit::ctree}. } diff --git a/man/plot.parttree.Rd b/man/plot.parttree.Rd new file mode 100644 index 0000000..29ab8d6 --- /dev/null +++ b/man/plot.parttree.Rd @@ -0,0 +1,115 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot.R +\name{plot.parttree} +\alias{plot.parttree} +\title{Plot decision tree partitions} +\usage{ +\method{plot}{parttree}( + x, + raw = TRUE, + border = "black", + fill_alpha = 0.3, + expand = TRUE, + jitter = FALSE, + add = FALSE, + ... 
+) +} +\arguments{ +\item{x}{A \link{parttree} data frame.} + +\item{raw}{Logical. Should the raw (original) data points be plotted too? +Default is TRUE.} + +\item{border}{Colour of the partition borders (edges). Default is "black". To +remove the borders altogether, specify as \code{NA}.} + +\item{fill_alpha}{Numeric in the range \verb{[0,1]}. Alpha transparency of the +filled partition rectangles. Default is \code{0.3}.} + +\item{expand}{Logical. Should the partition limits be expanded to to meet the +edge of the plot axes? Default is \code{TRUE}. If \code{FALSE}, then the partition +limits will extend only until the range of the raw data.} + +\item{jitter}{Logical. Should the raw points be jittered? Default is \code{FALSE}. +Only evaluated if \code{raw = TRUE}.} + +\item{add}{Logical. Add to an existing plot? Default is \code{FALSE}.} + +\item{...}{Additional arguments passed down to +\code{\link[tinyplot]{tinyplot}}.} +} +\value{ +No return value, called for side effect of producing a plot. + +No return value; called for its side effect of producing a plot. +} +\description{ +Provides a plot method for parttree objects. +} +\examples{ +library("parttree") + +## rpart trees + +library("rpart") +rp = rpart(Kyphosis ~ Start + Age, data = kyphosis) + +# A parttree object is just a data frame with additional attributes +(rp_pt = parttree(rp)) +attr(rp_pt, "parttree") + +# simple plot +plot(rp_pt) + +# removing the (recursive) partition borders helps to emphasise overall fit +plot(rp_pt, border = NA) + +# customize further by passing extra options to (tiny)plot +plot( + rp_pt, + border = NA, # no partition borders + pch = 19, # filled points + alpha = 0.6, # point transparency + grid = TRUE, # background grid + palette = "classic", # new colour palette + xlab = "Topmost vertebra operated on", # custom x title + ylab = "Patient age (months)", # custom y title + main = "Tree predictions: Kyphosis recurrence" # custom title +) + +## conditional inference trees from partyit + +library("partykit") +ct = ctree(Species ~ Petal.Length + Petal.Width, data = iris) +ct_pt = parttree(ct) +plot(ct_pt, pch = 19, palette = "okabe", main = "ctree predictions: iris species") + +## rpart via partykit +rp2 = as.party(rp) +parttree(rp2) + +## various front-end frameworks are also supported, e.g. + +# tidymodels + +library(parsnip) + +decision_tree() |> + set_engine("rpart") |> + set_mode("classification") |> + fit(Species ~ Petal.Length + Petal.Width, data=iris) |> + parttree() |> + plot(main = "This time brought to you via parsnip...") + +# mlr3 (NB: use `keep_model = TRUE` for mlr3 learners) + +library(mlr3) + +task_iris = TaskClassif$new("iris", iris, target = "Species") +task_iris$formula(rhs = "Petal.Length + Petal.Width") +fit_iris = lrn("classif.rpart", keep_model = TRUE) # NB! +fit_iris$train(task_iris) +plot(parttree(fit_iris), main = "... and now mlr3") + +} diff --git a/vignettes/parttree-art.Rmd b/vignettes/parttree-art.Rmd index b15bbc7..5a8c842 100644 --- a/vignettes/parttree-art.Rmd +++ b/vignettes/parttree-art.Rmd @@ -29,6 +29,8 @@ library(parttree) # This package library(rpart) # For decision trees library(magick) # For reading and manipulating images library(imager) # Another image library, with some additional features + +op = par(mar = c(0,0,0,0)) # Remove plot margins ``` While the exact details will vary depending on the image at hand, the essential @@ -71,7 +73,8 @@ manipulation.] 
rosalba = image_read("https://upload.wikimedia.org/wikipedia/commons/a/aa/Rembrandt_Peale_-_Portrait_of_Rosalba_Peale_-_Google_Art_Project.jpg") # Crop around the eyes -rosalba = image_crop(rosalba, "855x450+890+1350") +rosalba = image_crop(rosalba, "850x400+890+1350") +# rosalba = image_crop(rosalba, "750x350+890+1350") # Convert to cimg (better for in-memory manipulation) rosalba = magick2cimg(rosalba) @@ -85,7 +88,7 @@ With our cropped image in hand, let's walk through the 4-step recipe from above. Step 1 is converting the image into a data frame. -```{r} +```{r rosalba_df} # Coerce to data frame rosalba_df = as.data.frame(rosalba) @@ -98,7 +101,7 @@ head(rosalba_df) Step 2 is splitting the image by RGB colour channel. This is the `cc` column above, where 1=Red, 2=Green, and 3=Blue. -```{r} +```{r rosalba_ccs} rosalba_ccs = split(rosalba_df, rosalba_df$cc) # We have a list of three DFs by colour channel. Uncomment if you want to see: @@ -112,13 +115,14 @@ we see more variation in the final predictions) and trimming each tree to a maximum depth of 30 nodes. The next code chunk takes about 15 seconds to run on my laptop, but should be much quicker if you downloaded a lower-res image. -```{r} +```{r trees} ## Start creating regression tree for each color channel. We'll adjust some ## control parameters to give us the "right" amount of resolution in the final ## plots. trees = lapply( rosalba_ccs, - function(d) rpart(value ~ x + y, data=d, control=list(cp=0.00001, maxdepth=20)) + # function(d) rpart(value ~ x + y, data=d, control=list(cp=0.00001, maxdepth=20)) + function(d) rpart(value ~ x + y, data=d, control=list(cp=0.00002, maxdepth=20)) ) ``` @@ -126,7 +130,7 @@ Step 4 is using our model (colour) predictions to construct our abstracted art piece. I was bit glib about it earlier, since it really involves a few sub-steps. First, let's grab the predictions for each of our trees. -```{r} +```{r pred_trees} pred = lapply(trees, predict) # get predictions for each tree ``` @@ -138,7 +142,7 @@ overwriting the "value" column of our original (pre-split) `rosalba_df` data frame. We can then coerce the data frame back into a `cimg` object, which comes with a bunch of nice plotting methods. -```{r} +```{r pred_img} # The pred object is a list, so we convert it to a vector before overwriting the # value column of the original data frame rosalba_df$value = do.call("c", pred) @@ -152,10 +156,7 @@ Now we're ready to draw our abstracted art piece. It's also where the partitioned areas of the downscaled pixels. Here's how we can do it using base R plotting functions. -```{r} -## Maximum/minimum for plotting range as base rect() does not handle Inf well -m = 1000 - +```{r rosalba_abstract} # get a list of parttree data frames (one for each tree) pts = lapply(trees, parttree) @@ -164,22 +165,23 @@ plot(pred_img, axes = FALSE) ## ... then layer the partitions as a series of rectangles lapply( pts, - function(pt) rect( - pmax(-m, pt$xmin), pmax(-m, pt$ymin), pmin(m, pt$xmax), pmin(m, pt$ymax), - lwd = 0.06, border = "grey15" - ) + function(pt) plot( + pt, raw = FALSE, add = TRUE, expand = FALSE, + fill_alpha = NULL, lwd = 0.1, border = "grey15" ) +) ``` We can achieve the same effect with **ggplot2** if you prefer to use that. 
-```{r} +```{r rosalba_abstract_gg} +library(ggplot2) ggplot() + annotation_raster(pred_img, ymin=-Inf, ymax=Inf, xmin=-Inf, xmax=Inf) + lapply(trees, function(d) geom_parttree(data = d, lwd = 0.05, col = "grey15")) + scale_x_continuous(limits=c(0, max(rosalba_df$x)), expand=c(0,0)) + scale_y_reverse(limits=c(max(rosalba_df$y), 0), expand=c(0,0)) + - coord_fixed(ratio = 0.9) + + coord_fixed(ratio = Reduce(x = dim(rosalba)[2:1], f = "/") * 2) + theme_void() ``` @@ -196,7 +198,7 @@ Here are the main image conversion, modeling, and prediction steps. These follow the same recipe that we saw in the previous portrait example, so I won't repeat my explanations. -```{r} +```{r bonzai} bonzai = load.image("https://upload.wikimedia.org/wikipedia/commons/thumb/0/0c/Japanese_Zelkova_bonsai_16%2C_30_April_2012.JPG/480px-Japanese_Zelkova_bonsai_16%2C_30_April_2012.JPG") plot(bonzai, axes = FALSE) @@ -214,7 +216,7 @@ bonzai_ccs = split(bonzai_df, bonzai_df$cc) bonzai_trees = lapply( bonzai_ccs, function(d) rpart(value ~ x + y, data=d, control=list(cp=0.00001, maxdepth=10)) - ) +) # Overwrite value column with predictions vector bonzai_df$value = do.call("c", lapply(bonzai_trees, predict)) @@ -232,7 +234,7 @@ I first saw this trick or adapted this function from. But it works particularly well in cases like this where we want the partition lines to blend in with the main image.] -```{r} +```{r bonzai_mean_cols} mean_cols = function(dat) { mcols = tapply(dat$value, dat$cc, FUN = "mean") col1 = rgb(mcols[1], mcols[2], mcols[3]) @@ -246,9 +248,33 @@ mean_cols = function(dat) { bonzai_mean_cols = mean_cols(bonzai_df) ``` +The penultimate step is to generate `parttree` objects from each of our colour +channel-based trees. + +```{r bonza_pts} +bonzai_pts = lapply(bonzai_trees, parttree) +``` + And now we can plot everything together. +```{r bonza_abstract} +plot(bonzai_pred_img, axes = FALSE) +Map( + function(pt, cols) { + plot( + pt, raw = FALSE, add = TRUE, expand = FALSE, + fill_alpha = 0, lwd = 0.2, border = cols + ) + }, + pt = bonzai_pts, + cols = bonzai_mean_cols +) +``` + +Again, equivalent result for those prefer **ggplot2**. + ```{r} +# library(ggplot2) ggplot() + annotation_raster(bonzai_pred_img, ymin=-Inf, ymax=Inf, xmin=-Inf, xmax=Inf) + Map(function(d, cols) geom_parttree(data = d, lwd = 0.1, col = cols), d = bonzai_trees, cols = bonzai_mean_cols) + @@ -262,7 +288,7 @@ ggplot() + The individual trees for each colour channel make for nice stained glass prints... -```{r} +```{r, eval=FALSE, include=FALSE} library(patchwork) ## Aside: We're reversing the y-scale since higher values actually correspond @@ -286,4 +312,23 @@ g = g[[1]] + g[[2]] + g[[3]] ``` +```{r bonzai_glass} +lapply( + seq_along(bonzai_pts), + function(i) { + plot( + bonzai_pts[[i]], raw = FALSE, expand = FALSE, + axes = FALSE, legend = FALSE, + main = paste0(c("R", "G", "B")[i]), + ## Aside: We're reversing the y-scale since higher values actually + ## correspond to points lower on the image, visually. 
+ ylim = rev(attr(bonzai_pts[[i]], "parttree")[["yrange"]]) + ) + } +) +``` +```{r reset_par} +# reset the plot margins +par(op) +``` diff --git a/vignettes/parttree-intro.Rmd b/vignettes/parttree-intro.Rmd index 3554ec0..878ef6f 100644 --- a/vignettes/parttree-intro.Rmd +++ b/vignettes/parttree-intro.Rmd @@ -15,11 +15,10 @@ knitr::opts_chunk$set( ) ``` -## Basic use +## Motivating example: Classifying penguin species -Let's start by loading the **parttree** package alongside **rpart**, which comes -bundled with the base R installation and is what we'll use for fitting our -decision trees (at least, to start with). For the basic examples that follow, +Start by loading the **parttree** package alongside **rpart**, which comes +bundled with the base R installation. For the basic examples that follow, I'll use the well-known [Palmer Penguins](https://allisonhorst.github.io/palmerpenguins/) dataset to demonstrate functionality. You can load this dataset via the parent package (as @@ -27,74 +26,108 @@ I have here), or import it directly as a CSV [here](https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv). ```{r setup} +library(parttree) # This package library(rpart) # For fitting decisions trees -library(parttree) # This package (will automatically load ggplot2 too) - -theme_set(theme_linedraw()) # install.packages("palmerpenguins") data("penguins", package = "palmerpenguins") head(penguins) ``` -### Categorical predictions +Dataset in hand, let's say that we are interested in predicting penguin +_species_ as a function of 1) flipper length and 2) bill length. We could model +this as a simple decision tree: + +```{r tree} +tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins) +tree +``` -Say we are interested in predicting the penguins _species_ as a function of 1) -flipper length and 2) bill length. We can visualize these relationships as a -simple scatter plot prior to doing any formal modeling. +Like most tree-based frameworks, **rpart** comes with a default `plot` method +for visualizing the resulting node splits. -```{r penguin_plot_cat1} -p = - ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) + - geom_point(aes(col = species)) -p +```{r rpart_plot} +plot(tree, compress = TRUE) +text(tree, use.n = TRUE) ``` -Recasting in terms of a decision tree is easily done (e.g., with `rpart`). -However, visualizing the resulting tree predictions against the raw data is hard -to do out of the box and this where **parttree** enters the fray. The main -function that users will interact with is `geom_parttree()`, which provides a -new geom layer for **ggplot2** objects. +While this is okay, I don't feel that it provides much intuition about the +model's prediction on the _scale of the actual data_. In other words, what I'd +prefer to see is: How has our tree partitioned the original penguin data? +This is where **parttree** enters the fray. The package is named for its primary +workhorse function `parttree()`, which extracts all of the information needed +to produce a nice plot of our tree partitions alongside the original data. -```{r penguin_plot_cat2} -## Fit a decision tree using the same variables as the above plot -tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins) +```{r penguin_cl_plot} +ptree = parttree(tree) +plot(ptree) +``` + +_Et voila!_ Now we can clearly see how our model has divided up the Cartesian +space of the data. 
Gentoo penguins typically have longer flippers than Chinstrap
+or Adelie penguins, while the latter have the shortest bills.
+
+From the perspective of the end-user, the `ptree` parttree object is not all
+that interesting in and of itself. It is simply a data frame that contains the basic
+information needed for our plot (partition coordinates, etc.). You can think of
+it as a helpful intermediate object on our way to the visualization of interest.
 
-## Visualize the tree partitions by adding it to our plot with geom_parttree()
-p +
-  geom_parttree(data = tree, aes(fill=species), alpha = 0.1) +
-  labs(caption = "Note: Points denote observations. Shaded regions denote model predictions.")
+```{r ptree}
+# See also `attr(ptree, "parttree")`
+ptree
 ```
 
-#### Continuous predictions
+Speaking of visualization, underneath the hood `plot.parttree` calls the
+powerful
+[**tinyplot**](https://grantmcdermott.com/tinyplot)
+package. All of the latter's various customization arguments can be passed on to
+our `parttree` plot to make it look a bit nicer. For example:
 
-Trees with continuous independent variables are also supported. However, I
-recommend adjusting the plot fill aesthetic since your model will likely
-partition the data into intervals that don't match up exactly with the raw data.
-The easiest way to do this is by setting your colour and fill aesthetic together
-as part of the same `scale_colour_*` call.
+```{r penguin_cl_plot_custom}
+plot(ptree, pch = 16, palette = "classic", alpha = 0.75, grid = TRUE)
+```
 
-```{r penguin_plot_con}
-tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data=penguins)
-ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
-  geom_parttree(data = tree2, aes(fill=body_mass_g), alpha = 0.3) +
-  geom_point(aes(col = body_mass_g)) +
-  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
+### Continuous predictions
+
+In addition to discrete classification problems, **parttree** also supports
+regression trees with continuous independent variables.
+
+```{r penguin_reg_plot}
+tree_cont = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)
+
+tree_cont |>
+  parttree() |>
+  plot(pch = 16, palette = "viridis")
 ```
+
 ## Supported model classes
 
-Currently, the package works with decision trees created by the
-[**rpart**](https://CRAN.R-project.org/web/package=rpart) and
-[**partykit**](https://CRAN.R-project.org/web/package=partykit) packages.
-Moreover, it supports other front-end modes that call `rpart::rpart()` as
-the underlying engine; in particular the
-[**tidymodels**](https://www.tidymodels.org/) ([parsnip](https://parsnip.tidymodels.org/)
-or [workflows](https://workflows.tidymodels.org/)) and
-[**mlr3**](https://mlr3.mlr-org.com/) packages. Here's a quick example with
-**parsnip**.
+Alongside the [**rpart**](https://CRAN.R-project.org/web/package=rpart) model
+objects that we have been working with thus far, **parttree** also supports
+decision trees created by the
+[**partykit**](https://CRAN.R-project.org/web/package=partykit) package. Here we
+see how the latter's `ctree` (conditional inference tree) algorithm yields a
+slightly more sophisticated partitioning than the former's default.
+
+```{r penguin_ctree_plot}
+library(partykit)
+
+ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins) |>
+  parttree() |>
+  plot(pch = 19, palette = "classic", alpha = 0.5)
+```
+
+**parttree** also supports a variety of "frontend" modes that call
+`rpart::rpart()` as the underlying engine. This includes packages from both the
+[**mlr3**](https://mlr3.mlr-org.com/) and
+[**tidymodels**](https://www.tidymodels.org/)
+([parsnip](https://parsnip.tidymodels.org/)
+or [workflows](https://workflows.tidymodels.org/))
+ecosystems. Here is a quick demonstration using **parsnip**, where we'll also
+pull in a different dataset just to change things up a little.
 
 ```{r titanic_plot}
 set.seed(123) ## For consistent jitter
@@ -110,51 +143,81 @@ ti_tree =
   set_engine("rpart") |>
   set_mode("classification") |>
   fit(Survived ~ Pclass + Age, data = titanic_train)
 
+## Now pass to parttree and plot
+ti_tree |>
+  parttree() |>
+  plot(pch = 16, jitter = TRUE, palette = "dark", alpha = 0.7)
+```
+
+## ggplot2
+
+The default `plot.parttree` method produces a base graphics plot. But we also
+support **ggplot2** via a dedicated `geom_parttree()` function. Here we
+demonstrate with our initial classification tree from earlier.
+
+```{r penguin_cl_ggplot2}
+library(ggplot2)
+theme_set(theme_linedraw())
 
-## Plot the data and model partitions
-titanic_train |>
-  ggplot(aes(x=Pclass, y=Age)) +
-  geom_parttree(data = ti_tree, aes(fill=Survived), alpha = 0.1) +
-  geom_jitter(aes(col=Survived), alpha=0.7)
+## re-using the tree model object from above...
+ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
+  geom_point(aes(col = species)) +
+  geom_parttree(data = tree, aes(fill=species), alpha = 0.1)
 ```
 
-## Plot orientation
+Compared to the "native" `plot.parttree` method, note that the **ggplot2**
+workflow requires a few tweaks:
 
-Underneath the hood, `geom_parttree()` is calling the companion `parttree()`
-function, which coerces the **rpart** tree object into a data frame that is
-easily understood by **ggplot2**. For example, consider again our first "tree"
-model from earlier. Here's the print output of the raw model.
+- We need to plot the original dataset as a separate layer (i.e., `geom_point()`).
+- `geom_parttree()` accepts the tree object _itself_, not the result of `parttree()`.^[This is because `geom_parttree(data = )` calls `parttree()` internally.]
 
-```{r tree}
-tree
+Continuous regression trees can also be drawn with `geom_parttree`. However, I
+recommend adjusting the plot fill aesthetic since your model will likely
+partition the data into intervals that don't match up exactly with the raw data.
+The easiest way to do this is by setting your colour and fill aesthetic together
+as part of the same `scale_colour_*` call.
+
+```{r penguin_reg_ggplot2}
+## re-using the tree_cont model object from above...
+ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
+  geom_parttree(data = tree_cont, aes(fill=body_mass_g), alpha = 0.3) +
+  geom_point(aes(col = body_mass_g)) +
+  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
 ```
 
-And here's what we get after we feed it to `parttree()`.
+### Gotcha: (gg)plot orientation
+
+As we have already said, `geom_parttree()` calls the companion `parttree()`
+function internally, which coerces the **rpart** tree object into a data frame
+that is easily understood by **ggplot2**. For example, consider our initial
+"ptree" object from earlier.
-```{r tree_parted}
-parttree(tree)
+```{r ptree_redux}
+# ptree = parttree(tree)
+ptree
 ```
 
 Again, the resulting data frame is designed to be amenable to a **ggplot2** geom
-layer, with columns like `xmin`, `xmax`, etc. specifying aesthetics that
-**ggplot2** recognises. (Fun fact: `geom_parttree()` is really just a thin
-wrapper around `geom_rect()`.) The goal of the package is to abstract away these
-kinds of details
-from the user, so we can just specify `geom_parttree()` — with a valid
-tree object as the data input — and be done with it. However, while this
-generally works well, it can sometimes lead to unexpected behaviour in terms of
-plot orientation. That's because it's hard to guess ahead of time what the user
-will specify as the x and y variables (i.e. axes) in their other plot layers. To
-see what I mean, let's redo our penguin plot from earlier, but this time switch
-the axes in the main `ggplot()` call.
-
-```{r tree_plot_mismatch}
+layer, with columns like `xmin`, `xmax`, etc. specifying aesthetics that
+**ggplot2** recognizes. (Fun fact: `geom_parttree()` is really just a thin
+wrapper around `geom_rect()`.) The goal of **parttree** is to abstract away
+these kinds of details from the user, so that they can just specify
+`geom_parttree()`—with a valid tree object as the data input—and be
+done with it. However, while this generally works well, it can sometimes lead to
+unexpected behaviour in terms of plot orientation. That's because it's hard to
+guess ahead of time what the user will specify as the x and y variables (i.e.
+axes) in their other **ggplot2** layers.^[The default `plot.parttree` method
+doesn't have this problem because it assigns the x and y variables for both the
+partitions and raw data points as part of the same function call.] To see what I
+mean, let's redo our penguin plot from earlier, but this time switch the axes in
+the main `ggplot()` call.
+
+```{r penguin_cl_ggplot_mismatch}
 ## First, redo our first plot but this time switch the x and y variables
-p3 =
-  ggplot(
-    data = penguins,
-    aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
-  ) +
+p3 = ggplot(
+  data = penguins,
+  aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
+  ) +
   geom_point(aes(col = species))
 
 ## Add on our tree (and some preemptive titling..)
@@ -163,49 +226,18 @@ p3 +
   labs(
     title = "Oops!",
     subtitle = "Looks like a mismatch between our x and y axes..."
-  )
+    )
 ```
 
 As was the case here, this kind of orientation mismatch is normally (hopefully)
-pretty easy to recognize. To fix, we can use the `flipaxes = TRUE` argument to
+pretty easy to recognize. To fix, we can use the `flip = TRUE` argument to
 flip the orientation of the `geom_parttree` layer.
 
-```{r tree_plot_flip}
+```{r penguin_cl_ggplot_mismatch_flip}
 p3 +
   geom_parttree(
     data = tree, aes(fill = species), alpha = 0.1,
-    flipaxes = TRUE ## Flip the orientation
-  ) +
+    flip = TRUE ## Flip the orientation
+    ) +
   labs(title = "That's better")
 ```
-
-## Base graphics
-
-While the package has been primarily designed to work with **ggplot2**, the
-`parttree()` infrastructure can also be used to generate plots with base
-graphics. Here, the `ctree()` function from **partykit** is used for fitting
-the tree.
- -```{r ctree_base_graphics} -library(partykit) - -## CTree and corresponding partition -ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins) -pt = parttree(ct) - -## Color palette -pal = palette.colors(4, "R4")[-1] - -## Maximum/minimum for plotting range as rect() does not handle Inf well -m = 1000 - -## scatter plot() with added rect() -plot( - bill_length_mm ~ flipper_length_mm, - data = penguins, col = pal[species], pch = 19 - ) -rect( - pmax(-m, pt$xmin), pmax(-m, pt$ymin), pmin(m, pt$xmax), pmin(m, pt$ymax), - col = adjustcolor(pal, alpha.f = 0.1)[pt$species] - ) -```