diff --git a/R/Lrnr_gam.R b/R/Lrnr_gam.R index c52ec44e..09a7516f 100644 --- a/R/Lrnr_gam.R +++ b/R/Lrnr_gam.R @@ -77,7 +77,7 @@ Lrnr_gam <- R6Class( } ), private = list( - .properties = c("continuous", "binomial"), + .properties = c("continuous", "binomial", "weights", "offset"), .train = function(task) { # load args args <- self$params @@ -87,21 +87,27 @@ Lrnr_gam <- R6Class( Y <- data.frame(outcome_type$format(task$Y)) colnames(Y) <- task$nodes$outcome args$data <- cbind(task$X, Y) - ## family + + + # family if (is.null(args$family)) { - if (outcome_type$type == "continuous") { - args$family <- stats::gaussian() - } else if (outcome_type$type == "binomial") { - args$family <- stats::binomial() - } else if (outcome_type$type == "categorical") { - # TODO: implement categorical? - # NOTE: must specify (#{categories}-1)+linear_predictors) in formula - stop("Categorical outcomes are unsupported by Lrnr_gam for now.") - } else { - stop("Specified outcome type is unsupported by Lrnr_gam.") - } + args$family <- outcome_type$glm_family(return_object = TRUE) + } + family_name <- args$family$family + linkinv_fun <- args$family$linkinv + link_fun <- args$family$linkfun + + # weights + if (task$has_node("weights")) { + args$weights <- task$weights } - ## formula + + # offset + if (task$has_node("offset")) { + args$offset <- task$offset_transformed(link_fun) + } + + # formula if (is.null(args$formula)) { covars_type <- lapply( task$X, @@ -111,9 +117,8 @@ Lrnr_gam <- R6Class( X_discrete <- task$X[, ..i_discrete] i_continuous <- covars_type == "continuous" X_continuous <- task$X[, ..i_continuous] - "y ~ s(x1) + s(x2) + x3" X_smooth <- sapply(colnames(X_continuous), function(x) { - unique_x <- unlist(unique(task$X[, x, with = F])) + unique_x <- unlist(unique(task$X[, x, with = FALSE])) if (length(unique_x) < 10) { paste0("s(", x, ", k=", length(unique_x), ")") } else { @@ -151,7 +156,7 @@ Lrnr_gam <- R6Class( collapse = " ~ " )) } else { - stop("Specified covariates types are unsupported in Lrnr_gam.") + stop("Specified covariates types are unsupported in Lrnr_gam") } } else if (is.list(args$formula)) { args$formula <- lapply(args$formula, as.formula) @@ -164,6 +169,15 @@ Lrnr_gam <- R6Class( return(fit_object) }, .predict = function(task) { + # offset check + if (self$fit_object$training_offset) { + offset <- task$offset_transformed( + self$fit_object$link_fun, + for_prediction = TRUE + ) + warning("Lrnr_gam: offset given in training is ignored in prediction") + } + # get predictions predictions <- stats::predict( private$.fit_object, diff --git a/R/Lrnr_glm.R b/R/Lrnr_glm.R index 9b4001a1..dacf8a08 100644 --- a/R/Lrnr_glm.R +++ b/R/Lrnr_glm.R @@ -73,7 +73,6 @@ Lrnr_glm <- R6Class( } args$y <- outcome_type$format(task$Y) - if (task$has_node("weights")) { args$weights <- task$weights } diff --git a/R/Lrnr_svm.R b/R/Lrnr_svm.R index 0fbc2c4b..b8fbc5c3 100644 --- a/R/Lrnr_svm.R +++ b/R/Lrnr_svm.R @@ -85,7 +85,7 @@ Lrnr_svm <- R6Class( # set SVM type based on detected outcome family outcome_type <- self$get_outcome_type(task) if (is.null(args$type)) { - if (outcome_type$type == "continuous") { + if (outcome_type$type %in% c("continuous", "quasibinomial")) { args$type <- "eps-regression" } else if (outcome_type$type %in% c("binomial", "categorical")) { args$type <- "C-classification" @@ -126,7 +126,7 @@ Lrnr_svm <- R6Class( if (outcome_type == "categorical") { predictions <- pack_predictions(predictions) } else if (outcome_type == "binomial") { - predictions <- as.numeric(predictions[, 2]) # P(Y=1|X) + predictions <- as.numeric(predictions[, 2]) } } else { predictions <- as.numeric(predictions)