diff --git a/R/Lrnr_gru_keras.R b/R/Lrnr_gru_keras.R
index 0630316e..a774fa34 100644
--- a/R/Lrnr_gru_keras.R
+++ b/R/Lrnr_gru_keras.R
@@ -46,6 +46,7 @@
 #' be applied at given stages of the training procedure. Default callback
 #' function \code{callback_early_stopping} stops training if the validation
 #' loss does not improve across \code{patience} number of epochs.
+#' - \code{validation_split}: Fraction of the training data to be used as validation data. Default is 0 (no validation).
 #' - \code{...}: Other parameters passed to \code{\link[keras]{keras}}.
 #'
 #' @examples
@@ -74,7 +75,7 @@
 #' valid_task <- validation(task, fold = task$folds[[1]])
 #'
 #' # instantiate learner, then fit and predict (simplifed example)
-#' gru_lrnr <- Lrnr_gru_keras$new(batch_size = 1, epochs = 200)
+#' gru_lrnr <- Lrnr_gru_keras$new(batch_size = 1, epochs = 200, validation_split = 0.2)
 #' gru_fit <- gru_lrnr$train(train_task)
 #' gru_preds <- gru_fit$predict(valid_task)
 #' }
@@ -95,6 +96,7 @@ Lrnr_gru_keras <- R6Class(
                           callbacks = list(
                             keras::callback_early_stopping(patience = 10)
                           ),
+                          validation_split = 0,
                           ...) {
       params <- args_to_list()
       super$initialize(params = params, ...)
@@ -186,7 +188,8 @@ Lrnr_gru_keras <- R6Class(
         y = args$y,
         batch_size = args$batch_size,
         epochs = args$epochs,
-        callbacks = callbacks,
+        callbacks = args$callbacks,
+        validation_split = args$validation_split,
         verbose = verbose,
         shuffle = FALSE
       )
diff --git a/R/Lrnr_lstm_keras.R b/R/Lrnr_lstm_keras.R
index ba27992d..022d59a8 100644
--- a/R/Lrnr_lstm_keras.R
+++ b/R/Lrnr_lstm_keras.R
@@ -44,6 +44,7 @@
 #' be applied at given stages of the training procedure. Default callback
 #' function \code{callback_early_stopping} stops training if the validation
 #' loss does not improve across \code{patience} number of epochs.
+#' - \code{validation_split}: Fraction of the training data to be used as validation data. Default is 0 (no validation).
 #' - \code{...}: Other parameters passed to \code{\link[keras]{keras}}.
 #'
 #' @examples
@@ -72,7 +73,7 @@
 #' valid_task <- validation(task, fold = task$folds[[1]])
 #'
 #' # instantiate learner, then fit and predict (simplifed example)
-#' lstm_lrnr <- Lrnr_lstm_keras$new(batch_size = 1, epochs = 200)
+#' lstm_lrnr <- Lrnr_lstm_keras$new(batch_size = 1, epochs = 200, validation_split = 0.2)
 #' lstm_fit <- lstm_lrnr$train(train_task)
 #' lstm_preds <- lstm_fit$predict(valid_task)
 #' }
@@ -93,6 +94,7 @@ Lrnr_lstm_keras <- R6Class(
                           lr = 0.001,
                           layers = 1,
                           callbacks = list(keras::callback_early_stopping(patience = 10)),
+                          validation_split = 0,
                           ...) {
       params <- args_to_list()
       super$initialize(params = params, ...)
@@ -184,7 +186,8 @@ Lrnr_lstm_keras <- R6Class(
         y = args$y,
         batch_size = args$batch_size,
         epochs = args$epochs,
-        callbacks = callbacks,
+        callbacks = args$callbacks,
+        validation_split = args$validation_split,
         verbose = verbose,
         shuffle = FALSE
       )
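Reviewer note, not part of the patch: a minimal usage sketch of the new argument, assuming the train_task and valid_task objects constructed in the roxygen example inside the diff above. The motivation is that the default callback_early_stopping() monitors val_loss, which keras only computes when fit() is given validation data; with validation_split left at 0 the callback has nothing to monitor.

library(sl3)
library(keras)

# With validation_split = 0.2, keras::fit() holds out the last 20% of the
# (unshuffled) training series as validation data, so val_loss is computed
# each epoch and the default early-stopping callback can actually trigger.
gru_lrnr <- Lrnr_gru_keras$new(
  batch_size = 1,
  epochs = 200,
  validation_split = 0.2,
  callbacks = list(
    keras::callback_early_stopping(monitor = "val_loss", patience = 10)
  )
)
gru_fit <- gru_lrnr$train(train_task)
gru_preds <- gru_fit$predict(valid_task)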