Skip to content

Commit

Permalink
restyle and fix args
Browse files Browse the repository at this point in the history
  • Loading branch information
nhejazi committed Dec 23, 2021
1 parent 8b641cc commit 7435428
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 79 deletions.
1 change: 0 additions & 1 deletion R/Lrnr_caret.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ Lrnr_caret <- R6Class(
),
private = list(
.properties = c("continuous", "binomial", "categorical", "wrapper"),

.train = function(task) {
# set type
outcome_type <- self$get_outcome_type(task)
Expand Down
2 changes: 0 additions & 2 deletions R/Lrnr_hal9001.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ Lrnr_hal9001 <- R6Class(

return(fit_object)
},

.predict = function(task = NULL) {
predictions <- stats::predict(
self$fit_object,
Expand All @@ -111,7 +110,6 @@ Lrnr_hal9001 <- R6Class(
}
return(predictions)
},

.required_packages = c("hal9001", "glmnet")
)
)
115 changes: 57 additions & 58 deletions R/importance.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
utils::globalVariables(c("score"))

#' Variable Importance
#'
#' Function that takes a cross-validated fit (i.e., cross-validated learner that
#' has already been trained on a task), which could be a cross-validated single
#' learner or super learner, and generates a risk-based variable importance
#' score for either each covariate or each group of covariates in the task.
#' This function outputs a \code{data.table}, where each row corresponds to the
#' risk difference or the risk ratio between the following two risks: the risk
#' when a covariate (or group of covariates) is permuted or removed, and the
#' original risk (i.e., when all covariates are included as they were in the
#' observed data). A higher risk ratio/difference corresponds to a more
#' important covariate/group. A plot can be generated from the returned
#' \code{data.table} by calling companion function \code{\link{importance_plot}}.
#' Function that takes a cross-validated fit (i.e., cross-validated learner
#' that has already been trained on a task), which could be a cross-validated
#' single learner or super learner, and generates a risk-based variable
#' importance score for either each covariate or each group of covariates in
#' the task. This function outputs a \code{data.table}, where each row
#' corresponds to the risk difference or the risk ratio between the following
#' two risks: the risk when a covariate (or group of covariates) is permuted or
#' removed, and the original risk (i.e., when all covariates are included as
#' they were in the observed data). A higher risk ratio/difference corresponds
#' to a more important covariate/group. A plot can be generated from the
#' returned \code{data.table} by calling companion function
#' \code{\link{importance_plot}}.
#'
#' @export
#'
Expand All @@ -23,45 +25,42 @@ utils::globalVariables(c("score"))
#'
#' @return A \code{data.table} of variable importance for each covariate.
#'
#' @section Parameters:
#' - \code{fit}: A trained cross-validated (CV) learner (such as a CV stack
#' or super learner), from which cross-validated predictions can be
#' generated.
#' - \code{eval_fun = NULL}: The evaluation function (risk or loss function)
#' for evaluating the risk. Defaults vary based on the outcome type,
#' matching defaults in \code{\link{default_metalearner}}. See
#' \code{\link{loss_functions}} and \code{\link{risk_functions}} for
#' options.
#' - \code{fold_number}: The fold number to use for obtaining the predictions
#' from the fit. Either a positive integer for obtaining predictions from
#' a specific fold's fit; \code{"full"} for obtaining predictions from a
#' fit on all of the data, or \code{"validation"} (default) for obtaining
#' cross-validated predictions, where the data used for training and
#' prediction never overlaps across the folds. Note that if a positive
#' integer or \code{"full"} is supplied here then there will be overlap
#' between the data used for training and validation, so
#' \code{fold_number ="validation"} is recommended.
#' - \code{type}: Which method should be used to obscure the relationship
#' between each covariate / covariate group and the outcome? When
#' \code{type} is \code{"remove"} (default), each covariate / covariate
#' group is removed one at a time from the task; the cross-validated
#' learner is refit to this modified task; and finally, predictions are
#' obtained from this refit. When \code{type} is \code{"permute"}, each
#' covariate / covariate group is permuted (sampled without replacement)
#' one at a time, and then predictions are obtained from this modified
#' data.
#' - \code{importance_metric}: Either \code{"ratio"} or \code{"difference"}
#' (default). For each covariate / covariate group, \code{"ratio"}
#' returns the risk of the permuted/removed covariate / covariate group
#' divided by observed/original risk (i.e., the risk with all covariates
#' as they existed in the sample) and \code{"difference"} returns the
#' difference between the risk with the permuted/removed covariate /
#' covariate group and the observed risk.
#' - \code{covariate_groups}: Optional named list covariate groups which will
#' invoke variable importance evaluation at the group-level, by
#' removing/permuting all covariates in the same group together. If
#' covariates in the task are not specified in the list of groups, then
#' those covariates will be added as additional single-covariate groups.
#' @param fit A trained cross-validated (CV) learner (such as a CV stack or
#' super learner), from which cross-validated predictions can be generated.
#' @param eval_fun The evaluation function (risk or loss function) for
#' evaluating the risk. Defaults vary based on the outcome type, matching
#' defaults in \code{\link{default_metalearner}}. See
#' \code{\link{loss_functions}} and \code{\link{risk_functions}} for options.
#' Default is \code{NULL}.
#' @param fold_number The fold number to use for obtaining the predictions from
#' the fit. Either a positive integer for obtaining predictions from a
#' specific fold's fit; \code{"full"} for obtaining predictions from a fit on
#' all of the data, or \code{"validation"} (default) for obtaining
#' cross-validated predictions, where the data used for training and
#' prediction never overlaps across the folds. Note that if a positive integer
#' or \code{"full"} is supplied here then there will be overlap between the
#' data used for training and validation, so \code{fold_number ="validation"}
#' is recommended.
#' @param type Which method should be used to obscure the relationship between
#' each covariate / covariate group and the outcome? When \code{type} is
#' \code{"remove"} (default), each covariate / covariate group is removed one
#' at a time from the task; the cross-validated learner is refit to this
#' modified task; and finally, predictions are obtained from this refit. When
#' \code{type} is \code{"permute"}, each covariate / covariate group is
#' permuted (sampled without replacement) one at a time, and then predictions
#' are obtained from this modified data.
#' @param importance_metric Either \code{"ratio"} or \code{"difference"}
#' (default). For each covariate / covariate group, \code{"ratio"} returns the
#' risk of the permuted/removed covariate / covariate group divided by
#' observed/original risk (i.e., the risk with all covariates as they existed
#' in the sample) and \code{"difference"} returns the difference between the
#' risk with the permuted/removed covariate / covariate group and the observed
#' risk.
#' @param covariate_groups Optional named list covariate groups which will
#' invoke variable importance evaluation at the group-level, by
#' removing/permuting all covariates in the same group together. If covariates
#' in the task are not specified in the list of groups, then those covariates
#' will be added as additional single-covariate groups.
#'
#' @examples
#' # define ML task
Expand Down Expand Up @@ -246,12 +245,11 @@ importance <- function(fit, eval_fun = NULL,

#' Variable Importance Plot
#'
#' @section Parameters:
#' - \code{x}: The 2-column \code{data.table} returned by
#' \code{\link{importance}}, where the first column is the
#' covariate/groups and the second column is the importance score.
#' - \code{nvar}: The maximum number of predictors to be plotted. Defaults to
#' the minimum between 30 and the number of rows in \code{x}.
#' @param x The two-column \code{data.table} returned by
#' \code{\link{importance}}, where the first column is the covariate/groups
#' and the second column is the importance score.
#' @param nvar The maximum number of predictors to be plotted. Defaults to the
#' minimum between 30 and the number of rows in \code{x}.
#'
#' @return A \code{\link[ggplot2]{ggplot}} of variable importance.
#'
Expand Down Expand Up @@ -284,7 +282,6 @@ importance <- function(fit, eval_fun = NULL,
#' importance_result <- importance(sl_fit)
#' importance_plot(importance_result)
importance_plot <- function(x, nvar = min(30, nrow(x))) {

# get the importance metric
xlab <- colnames(x)[2]

Expand All @@ -294,7 +291,9 @@ importance_plot <- function(x, nvar = min(30, nrow(x))) {
x <- x[1:(min(nvar, nrow(x))), ]

# format for ggplot
d <- data.table::data.table(vars = factor(x[[1]], levels = x[[1]]), score = x[[2]])
d <- data.table::data.table(
vars = factor(x[[1]], levels = x[[1]]), score = x[[2]]
)

ggplot2::ggplot(d, aes(x = vars, y = score)) +
ggplot2::geom_point() +
Expand Down
31 changes: 15 additions & 16 deletions R/loss_functions.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
utils::globalVariables(c("id", "loss", "obs", "pred", "wts"))

#' Loss Function Definitions
#'
#' Loss functions for use in evaluating learner fits.
Expand Down Expand Up @@ -97,7 +99,7 @@ risk <- function(pred, observed, loss = loss_squared_error, weights = NULL) {
}

# calculate risk
risk <- weighted.mean(losses, weights)
risk <- stats::weighted.mean(losses, weights)
return(risk)
}

Expand All @@ -113,30 +115,28 @@ risk <- function(pred, observed, loss = loss_squared_error, weights = NULL) {
#'
#' @name risk_functions
#'
#' @param pred A vector of predicted values.
#' @param observed A vector of binary observed values.
#' @param measure A character indicating which \code{ROCR} performance measure
#' to use for evaluation. The \code{measure} must be either cutoff-dependent
#' so a single value can be selected (e.g., "tpr"), or it's value is a scalar
#' (e.g., "aucpr"). For more information, see \code{\link[ROCR]{performance}}.
#' to use for evaluation. The \code{measure} must be either cutoff-dependent
#' so a single value can be selected (e.g., "tpr"), or it's value is a scalar
#' (e.g., "aucpr"). For more information, see \code{\link[ROCR]{performance}}.
#' @param cutoff A numeric value specifying the cutoff for choosing a single
#' performance measure from the returned set. Only used for performance measures
#' that are cutoff-dependent and default is 0.5. See
#' \code{\link[ROCR]{performance}} for more detail.
#' performance measure from the returned set. Only used for performance
#' measures that are cutoff-dependent and default is 0.5. See
#' \code{\link[ROCR]{performance}} for more detail.
#' @param name An optional character string for user to supply their desired
#' name for the performance measure, which will be used for naming subsequent
#' risk-related tables and metrics (e.g., \code{cv_risk} column names). When
#' \code{name} is not supplied, the \code{measure} will be used for naming.
#' name for the performance measure, which will be used for naming subsequent
#' risk-related tables and metrics (e.g., \code{cv_risk} column names). When
#' \code{name} is not supplied, the \code{measure} will be used for naming.
#' @param ... Optional arguments to specific \code{ROCR} performance
#' measures. See \code{\link[ROCR]{performance}} for more detail.
#' measures. See \code{\link[ROCR]{performance}} for more detail.
#'
#' @rdname risk_functions
#'
#' @importFrom ROCR prediction performance
#'
#' @export
#'
custom_ROCR_risk <- function(measure, cutoff = 0.5, name = NULL, ...) {
# NOTE: arguments to factory-produced function goes undocumented
function(pred, observed) {

# remove NA, NaN, Inf values
Expand Down Expand Up @@ -174,15 +174,14 @@ custom_ROCR_risk <- function(measure, cutoff = 0.5, name = NULL, ...) {
}
}

utils::globalVariables(c("id", "loss", "obs", "pred", "wts"))
#' Cross-validated Risk Estimation
#'
#' Estimates the cross-validated risk for a given learner and evaluation
#' function, which can be either a loss or a risk function.
#'
#' @param learner A trained learner object.
#' @param eval_fun A valid loss or risk function. See
#' \code{\link{loss_functions}} and \code{\link{risk_functions}}.
#' \code{\link{loss_functions}} and \code{\link{risk_functions}}.
#' @param coefs A \code{numeric} vector of coefficients.
#'
#' @importFrom assertthat assert_that
Expand Down
3 changes: 1 addition & 2 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ sl3Options <- function(o, value) {
}
if (is.null(value)) {
res[o] <- list(NULL)
}
else {
} else {
res[[o]] <- value
}
options(res[o])
Expand Down

0 comments on commit 7435428

Please sign in to comment.