diff --git a/tpot/builtins/nn.py b/tpot/builtins/nn.py index 953a9c0f..5c8dd049 100644 --- a/tpot/builtins/nn.py +++ b/tpot/builtins/nn.py @@ -120,6 +120,7 @@ def fit(self, X, y): # pylint: disable=no-member self._init_model(X, y) + self.classes_ = np.unique(y) assert _pytorch_model_is_fully_initialized(self)