From 75be22469b695f49e3b59a7c0196269fcdefa305 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Wed, 24 Jan 2024 10:57:00 -0300 Subject: [PATCH] make sure deep cloning preserves state dict attributes --- R/nn.R | 8 ++++++-- tests/testthat/test-nn.R | 9 +++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/R/nn.R b/R/nn.R index a90a1d2255..8324b58283 100644 --- a/R/nn.R +++ b/R/nn.R @@ -256,7 +256,7 @@ nn_Module <- R6::R6Class( par <- private$parameters_[[i]] # par or buf might not be available in `table` if, for some reason they # have already been replaced. This happens for example, when a module - # has the same layer twice. this also applies for modules, they might be duplicated + # has the same layer twice. this also applies for modules, they might be duplicated private$parameters_[[i]] <- table[[xptr_address(par)]] %||% par } for (i in seq_along(private$buffers_)) { @@ -527,7 +527,11 @@ create_nn_module_callable <- function(instance) { names(state_dict) <- sapply(state_dict, xptr_address) state_dict <- state_dict[!duplicated(names(state_dict))] - state_dict <- lapply(state_dict, function(x) x$detach()$clone()) + state_dict <- lapply(state_dict, function(x) { + out <- x$detach()$clone() + attributes(out) <- attributes(x) + out + }) # also need to append a clone of the modules to this list. # child modules can be duplicated - and have the same name diff --git a/tests/testthat/test-nn.R b/tests/testthat/test-nn.R index efb4a87c2f..8066446e27 100644 --- a/tests/testthat/test-nn.R +++ b/tests/testthat/test-nn.R @@ -706,6 +706,15 @@ test_that("deep cloning", { # make sure we re-lock binding expect_true(bindingIsLocked("clone", attr(x, "module"))) + + # make sure the class of parameters remains + a <- nn_linear(1, 1) + b <- a$clone(deep = TRUE) + + expect_equal( + attributes(b$parameters$weight), + attributes(a$parameters$weight) + ) }) test_that("Can initialize a model in the meta device and copy parameters to it", {