diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py
index 94f45b0a..9c59881c 100644
--- a/pytorch_tabnet/tab_model.py
+++ b/pytorch_tabnet/tab_model.py
@@ -283,7 +283,7 @@ def train_epoch(self, train_loader):
         for data, targets in train_loader:
             batch_outs = self.train_batch(data, targets)
             if self.output_dim == 1:
-                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy())
+                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy().flatten())
             elif self.output_dim == 2:
                 y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy())
             else:
@@ -328,7 +328,11 @@ def train_batch(self, data, targets):
         """
         self.network.train()
         data = data.to(self.device).float()
-        targets = targets.to(self.device).long()
+
+        if self.output_dim == 1:
+            targets = targets.to(self.device).float()
+        else:
+            targets = targets.to(self.device).long()
 
         self.optimizer.zero_grad()
         output, M_loss, M_explain, _ = self.network(data)
@@ -365,7 +369,7 @@ def predict_epoch(self, loader):
             batch_outs = self.predict_batch(data, targets)
             total_loss += batch_outs["loss"]
             if self.output_dim == 1:
-                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy())
+                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy().flatten())
             elif self.output_dim == 2:
                 y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy())
             else:
@@ -409,7 +413,10 @@ def predict_batch(self, data, targets):
         """
         self.network.eval()
         data = data.to(self.device).float()
-        targets = targets.to(self.device).long()
+        if self.output_dim == 1:
+            targets = targets.to(self.device).float()
+        else:
+            targets = targets.to(self.device).long()
 
         output, M_loss, M_explain, _ = self.network(data)
 