diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py index 3a72e65c..59bbb2db 100644 --- a/pytorch_tabnet/tab_network.py +++ b/pytorch_tabnet/tab_network.py @@ -2,7 +2,6 @@ from torch.nn import Linear, BatchNorm1d, ReLU import numpy as np from pytorch_tabnet import sparsemax -from copy import deepcopy def initialize_non_glu(module, input_dim, output_dim): @@ -263,7 +262,7 @@ def __init__(self, input_dim, output_dim, shared_blocks, n_glu, Float value between 0 and 1 which will be used for momentum in batch norm """ - self.shared = deepcopy(shared_blocks) + self.shared = shared_blocks if self.shared is not None: for l in self.shared.glu_layers: l.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size,