diff --git a/R/PipeOpSmote.R b/R/PipeOpSmote.R index 9524df701..c42ab1e51 100644 --- a/R/PipeOpSmote.R +++ b/R/PipeOpSmote.R @@ -87,7 +87,13 @@ PipeOpSmote = R6Class("PipeOpSmote", .train_task = function(task) { assert_true(all(task$feature_types$type == "numeric")) - cols = private$.select_cols(task) + cols = task$feature_names + + unsupported_cols = setdiff(unlist(task$col_roles), union(cols, task$target_names)) + if (length(unsupported_cols)) { + stopf("SMOTE cannot generate synthetic data for the following columns since they are neither features nor targets: '%s'", + paste(unsupported_cols, collapse = "', '")) + } if (!length(cols)) { return(task) @@ -102,7 +108,7 @@ PipeOpSmote = R6Class("PipeOpSmote", # rename target column and fix character conversion st[["class"]] = as_factor(st[["class"]], levels = task$class_names) setnames(st, "class", task$target_names) - + task$rbind(st) } )