diff --git a/R/package.R b/R/package.R index 8fcc9ca..8ba8bc1 100644 --- a/R/package.R +++ b/R/package.R @@ -117,6 +117,12 @@ tf_v2 <- function() { tryCatch(tf$python$util$deprecation$silence()$`__enter__`(), error = function(e) NULL) + if (isNamespaceLoaded("keras")) { + keras2_backcompat_hook() + } else { + setHook(packageEvent("keras", "onLoad"), keras2_backcompat_hook) + } + # TODO: move this into .onAttach, where you either emit immediately if # already loaded otherwise register emit hook for reticulate # emit <- get("packageStartupMessage") # R CMD check @@ -136,7 +142,6 @@ tf_v2 <- function() { "a Python installation where the tensorflow module is installed.", call. = FALSE) }) - # provide a common base S3 class for tensors reticulate::register_class_filter(function(classes) { if (any(c("tensorflow.python.ops.variables.Variable", @@ -160,6 +165,82 @@ is_string <- function(x) { is.character(x) && length(x) == 1L && !is.na(x) } + +keras2_backcompat_hook <- #function(){} + function(...) { + message("Calling keras2 backcompat hooks") + keras_ns <- asNamespace("keras") + + new_model_class_names <- c("keras.src.models.sequential.Sequential", + "keras.models.sequential.Sequential") + + generic <- "compose_layer" + method <- keras_ns[["compose_layer.keras.models.Sequential"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + new_model_class_names <- c("keras.src.models.model.Model", + "keras.models.model.Model") + + generic <- "fit" + method <- keras_ns[["fit.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "compile" + method <- keras_ns[["compile.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "predict" + method <- keras_ns[["predict.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "evaluate" + method <- keras_ns[["evaluate.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "export_savedmodel" + method <- keras_ns[["export_savedmodel.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "format" + method <- keras_ns[["format.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "print" + method <- keras_ns[["print.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "summary" + method <- keras_ns[["summary.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + + generic <- "plot" + method <- keras_ns[["plot.keras.engine.training.Model"]] + envir <- environment(get(generic, keras_ns)) + for(name in new_model_class_names) + registerS3method(generic, name, method, envir) + +} + + + #' TensorFlow configuration information #' #' @return List with information on the current configuration of TensorFlow.