From 7435428911a590afee5172619b6a38760c9e66be Mon Sep 17 00:00:00 2001 From: Nima Hejazi Date: Thu, 23 Dec 2021 01:18:08 -0500 Subject: [PATCH] restyle and fix args --- R/Lrnr_caret.R | 1 - R/Lrnr_hal9001.R | 2 - R/importance.R | 115 ++++++++++++++++++++++----------------------- R/loss_functions.R | 31 ++++++------ R/zzz.R | 3 +- 5 files changed, 73 insertions(+), 79 deletions(-) diff --git a/R/Lrnr_caret.R b/R/Lrnr_caret.R index e1099000..9b8d6ff0 100644 --- a/R/Lrnr_caret.R +++ b/R/Lrnr_caret.R @@ -66,7 +66,6 @@ Lrnr_caret <- R6Class( ), private = list( .properties = c("continuous", "binomial", "categorical", "wrapper"), - .train = function(task) { # set type outcome_type <- self$get_outcome_type(task) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index 6d6b13d4..9cdd02aa 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -99,7 +99,6 @@ Lrnr_hal9001 <- R6Class( return(fit_object) }, - .predict = function(task = NULL) { predictions <- stats::predict( self$fit_object, @@ -111,7 +110,6 @@ Lrnr_hal9001 <- R6Class( } return(predictions) }, - .required_packages = c("hal9001", "glmnet") ) ) diff --git a/R/importance.R b/R/importance.R index 54fe426c..b471e4de 100644 --- a/R/importance.R +++ b/R/importance.R @@ -1,17 +1,19 @@ utils::globalVariables(c("score")) + #' Variable Importance #' -#' Function that takes a cross-validated fit (i.e., cross-validated learner that -#' has already been trained on a task), which could be a cross-validated single -#' learner or super learner, and generates a risk-based variable importance -#' score for either each covariate or each group of covariates in the task. -#' This function outputs a \code{data.table}, where each row corresponds to the -#' risk difference or the risk ratio between the following two risks: the risk -#' when a covariate (or group of covariates) is permuted or removed, and the -#' original risk (i.e., when all covariates are included as they were in the -#' observed data). 
A higher risk ratio/difference corresponds to a more -#' important covariate/group. A plot can be generated from the returned -#' \code{data.table} by calling companion function \code{\link{importance_plot}}. +#' Function that takes a cross-validated fit (i.e., cross-validated learner +#' that has already been trained on a task), which could be a cross-validated +#' single learner or super learner, and generates a risk-based variable +#' importance score for either each covariate or each group of covariates in +#' the task. This function outputs a \code{data.table}, where each row +#' corresponds to the risk difference or the risk ratio between the following +#' two risks: the risk when a covariate (or group of covariates) is permuted or +#' removed, and the original risk (i.e., when all covariates are included as +#' they were in the observed data). A higher risk ratio/difference corresponds +#' to a more important covariate/group. A plot can be generated from the +#' returned \code{data.table} by calling companion function +#' \code{\link{importance_plot}}. #' #' @export #' @@ -23,45 +25,42 @@ utils::globalVariables(c("score")) #' #' @return A \code{data.table} of variable importance for each covariate. #' -#' @section Parameters: -#' - \code{fit}: A trained cross-validated (CV) learner (such as a CV stack -#' or super learner), from which cross-validated predictions can be -#' generated. -#' - \code{eval_fun = NULL}: The evaluation function (risk or loss function) -#' for evaluating the risk. Defaults vary based on the outcome type, -#' matching defaults in \code{\link{default_metalearner}}. See -#' \code{\link{loss_functions}} and \code{\link{risk_functions}} for -#' options. -#' - \code{fold_number}: The fold number to use for obtaining the predictions -#' from the fit. 
Either a positive integer for obtaining predictions from -#' a specific fold's fit; \code{"full"} for obtaining predictions from a -#' fit on all of the data, or \code{"validation"} (default) for obtaining -#' cross-validated predictions, where the data used for training and -#' prediction never overlaps across the folds. Note that if a positive -#' integer or \code{"full"} is supplied here then there will be overlap -#' between the data used for training and validation, so -#' \code{fold_number ="validation"} is recommended. -#' - \code{type}: Which method should be used to obscure the relationship -#' between each covariate / covariate group and the outcome? When -#' \code{type} is \code{"remove"} (default), each covariate / covariate -#' group is removed one at a time from the task; the cross-validated -#' learner is refit to this modified task; and finally, predictions are -#' obtained from this refit. When \code{type} is \code{"permute"}, each -#' covariate / covariate group is permuted (sampled without replacement) -#' one at a time, and then predictions are obtained from this modified -#' data. -#' - \code{importance_metric}: Either \code{"ratio"} or \code{"difference"} -#' (default). For each covariate / covariate group, \code{"ratio"} -#' returns the risk of the permuted/removed covariate / covariate group -#' divided by observed/original risk (i.e., the risk with all covariates -#' as they existed in the sample) and \code{"difference"} returns the -#' difference between the risk with the permuted/removed covariate / -#' covariate group and the observed risk. -#' - \code{covariate_groups}: Optional named list covariate groups which will -#' invoke variable importance evaluation at the group-level, by -#' removing/permuting all covariates in the same group together. If -#' covariates in the task are not specified in the list of groups, then -#' those covariates will be added as additional single-covariate groups. 
+#' @param fit A trained cross-validated (CV) learner (such as a CV stack or +#' super learner), from which cross-validated predictions can be generated. +#' @param eval_fun The evaluation function (risk or loss function) for +#' evaluating the risk. Defaults vary based on the outcome type, matching +#' defaults in \code{\link{default_metalearner}}. See +#' \code{\link{loss_functions}} and \code{\link{risk_functions}} for options. +#' Default is \code{NULL}. +#' @param fold_number The fold number to use for obtaining the predictions from +#' the fit. Either a positive integer for obtaining predictions from a +#' specific fold's fit; \code{"full"} for obtaining predictions from a fit on +#' all of the data, or \code{"validation"} (default) for obtaining +#' cross-validated predictions, where the data used for training and +#' prediction never overlaps across the folds. Note that if a positive integer +#' or \code{"full"} is supplied here then there will be overlap between the +#' data used for training and validation, so \code{fold_number ="validation"} +#' is recommended. +#' @param type Which method should be used to obscure the relationship between +#' each covariate / covariate group and the outcome? When \code{type} is +#' \code{"remove"} (default), each covariate / covariate group is removed one +#' at a time from the task; the cross-validated learner is refit to this +#' modified task; and finally, predictions are obtained from this refit. When +#' \code{type} is \code{"permute"}, each covariate / covariate group is +#' permuted (sampled without replacement) one at a time, and then predictions +#' are obtained from this modified data. +#' @param importance_metric Either \code{"ratio"} or \code{"difference"} +#' (default). 
For each covariate / covariate group, \code{"ratio"} returns the +#' risk of the permuted/removed covariate / covariate group divided by +#' observed/original risk (i.e., the risk with all covariates as they existed +#' in the sample) and \code{"difference"} returns the difference between the +#' risk with the permuted/removed covariate / covariate group and the observed +#' risk. +#' @param covariate_groups Optional named list of covariate groups which will +#' invoke variable importance evaluation at the group-level, by +#' removing/permuting all covariates in the same group together. If covariates +#' in the task are not specified in the list of groups, then those covariates +#' will be added as additional single-covariate groups. #' #' @examples #' # define ML task @@ -246,12 +245,11 @@ importance <- function(fit, eval_fun = NULL, #' Variable Importance Plot #' -#' @section Parameters: -#' - \code{x}: The 2-column \code{data.table} returned by -#' \code{\link{importance}}, where the first column is the -#' covariate/groups and the second column is the importance score. -#' - \code{nvar}: The maximum number of predictors to be plotted. Defaults to -#' the minimum between 30 and the number of rows in \code{x}. +#' @param x The two-column \code{data.table} returned by +#' \code{\link{importance}}, where the first column is the covariate/groups +#' and the second column is the importance score. +#' @param nvar The maximum number of predictors to be plotted. Defaults to the +#' minimum between 30 and the number of rows in \code{x}. #' #' @return A \code{\link[ggplot2]{ggplot}} of variable importance. 
#' @@ -284,7 +282,6 @@ importance <- function(fit, eval_fun = NULL, #' importance_result <- importance(sl_fit) #' importance_plot(importance_result) importance_plot <- function(x, nvar = min(30, nrow(x))) { - # get the importance metric xlab <- colnames(x)[2] @@ -294,7 +291,9 @@ importance_plot <- function(x, nvar = min(30, nrow(x))) { x <- x[1:(min(nvar, nrow(x))), ] # format for ggplot - d <- data.table::data.table(vars = factor(x[[1]], levels = x[[1]]), score = x[[2]]) + d <- data.table::data.table( + vars = factor(x[[1]], levels = x[[1]]), score = x[[2]] + ) ggplot2::ggplot(d, aes(x = vars, y = score)) + ggplot2::geom_point() + diff --git a/R/loss_functions.R b/R/loss_functions.R index 8fd0a042..168ecc8a 100644 --- a/R/loss_functions.R +++ b/R/loss_functions.R @@ -1,3 +1,5 @@ +utils::globalVariables(c("id", "loss", "obs", "pred", "wts")) + #' Loss Function Definitions #' #' Loss functions for use in evaluating learner fits. @@ -97,7 +99,7 @@ risk <- function(pred, observed, loss = loss_squared_error, weights = NULL) { } # calculate risk - risk <- weighted.mean(losses, weights) + risk <- stats::weighted.mean(losses, weights) return(risk) } @@ -113,30 +115,28 @@ risk <- function(pred, observed, loss = loss_squared_error, weights = NULL) { #' #' @name risk_functions #' -#' @param pred A vector of predicted values. -#' @param observed A vector of binary observed values. #' @param measure A character indicating which \code{ROCR} performance measure -#' to use for evaluation. The \code{measure} must be either cutoff-dependent -#' so a single value can be selected (e.g., "tpr"), or it's value is a scalar -#' (e.g., "aucpr"). For more information, see \code{\link[ROCR]{performance}}. +#' to use for evaluation. The \code{measure} must be either cutoff-dependent +#' so a single value can be selected (e.g., "tpr"), or its value is a scalar +#' (e.g., "aucpr"). For more information, see \code{\link[ROCR]{performance}}. 
#' @param cutoff A numeric value specifying the cutoff for choosing a single -#' performance measure from the returned set. Only used for performance measures -#' that are cutoff-dependent and default is 0.5. See -#' \code{\link[ROCR]{performance}} for more detail. +#' performance measure from the returned set. Only used for performance +#' measures that are cutoff-dependent and default is 0.5. See +#' \code{\link[ROCR]{performance}} for more detail. #' @param name An optional character string for user to supply their desired -#' name for the performance measure, which will be used for naming subsequent -#' risk-related tables and metrics (e.g., \code{cv_risk} column names). When -#' \code{name} is not supplied, the \code{measure} will be used for naming. +#' name for the performance measure, which will be used for naming subsequent +#' risk-related tables and metrics (e.g., \code{cv_risk} column names). When +#' \code{name} is not supplied, the \code{measure} will be used for naming. #' @param ... Optional arguments to specific \code{ROCR} performance -#' measures. See \code{\link[ROCR]{performance}} for more detail. +#' measures. See \code{\link[ROCR]{performance}} for more detail. #' #' @rdname risk_functions #' #' @importFrom ROCR prediction performance #' #' @export -#' custom_ROCR_risk <- function(measure, cutoff = 0.5, name = NULL, ...) { + # NOTE: arguments to factory-produced function go undocumented function(pred, observed) { # remove NA, NaN, Inf values @@ -174,7 +174,6 @@ custom_ROCR_risk <- function(measure, cutoff = 0.5, name = NULL, ...) { } } -utils::globalVariables(c("id", "loss", "obs", "pred", "wts")) #' Cross-validated Risk Estimation #' #' Estimates the cross-validated risk for a given learner and evaluation @@ -182,7 +181,7 @@ utils::globalVariables(c("id", "loss", "obs", "pred", "wts")) #' #' @param learner A trained learner object. #' @param eval_fun A valid loss or risk function. 
See -#' \code{\link{loss_functions}} and \code{\link{risk_functions}}. +#' \code{\link{loss_functions}} and \code{\link{risk_functions}}. #' @param coefs A \code{numeric} vector of coefficients. #' #' @importFrom assertthat assert_that diff --git a/R/zzz.R b/R/zzz.R index b7d8eed4..ac66f07f 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -38,8 +38,7 @@ sl3Options <- function(o, value) { } if (is.null(value)) { res[o] <- list(NULL) - } - else { + } else { res[[o]] <- value } options(res[o])