fix: float type when output_dim=1
eduardocarvp authored and Optimox committed Oct 30, 2019
Parent: 2094d00 · Commit: 7bb7dfd
Showing 1 changed file with 11 additions and 4 deletions.
pytorch_tabnet/tab_model.py:
```diff
@@ -283,7 +283,7 @@ def train_epoch(self, train_loader):
         for data, targets in train_loader:
             batch_outs = self.train_batch(data, targets)
             if self.output_dim == 1:
-                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy())
+                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy().flatten())
             elif self.output_dim == 2:
                 y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy())
             else:
```
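
Why the added `.flatten()` matters (this note and the snippet are an explanatory addition, not part of the commit): when `output_dim == 1` the network head emits a `(batch_size, 1)` tensor, so each per-batch array must be flattened before the epoch's predictions are concatenated, otherwise downstream scoring receives a 2-D column instead of the expected 1-D prediction vector. A minimal shape sketch:

```python
import numpy as np
import torch

# A single-output head yields one column per batch: shape (batch_size, 1).
batch_preds = torch.randn(4, 1)

column = batch_preds.cpu().detach().numpy()           # shape (4, 1)
flat = batch_preds.cpu().detach().numpy().flatten()   # shape (4,)

# Collecting flattened batches gives the 1-D vector that scoring
# functions (e.g. sklearn metrics) expect for single-output models.
y_preds = np.hstack([flat, flat])
print(column.shape, flat.shape, y_preds.shape)  # (4, 1) (4,) (8,)
```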
```diff
@@ -328,7 +328,11 @@ def train_batch(self, data, targets):
         """
         self.network.train()
         data = data.to(self.device).float()
-        targets = targets.to(self.device).long()
+
+        if self.output_dim == 1:
+            targets = targets.to(self.device).float()
+        else:
+            targets = targets.to(self.device).long()
         self.optimizer.zero_grad()

         output, M_loss, M_explain, _ = self.network(data)
```
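
The dtype branch follows PyTorch's loss-function conventions; the concrete losses below are illustrative assumptions, not code quoted from tab_model.py. `nn.CrossEntropyLoss`, used for multi-class targets, requires long class indices, while a regression-style loss for a single output, such as `nn.MSELoss`, requires float targets and rejects long ones:

```python
import torch
import torch.nn as nn

# Multi-class case (output_dim > 1): targets are long class indices.
logits = torch.randn(4, 3)
classes = torch.tensor([0, 2, 1, 0])          # dtype long
print(nn.CrossEntropyLoss()(logits, classes))

# Single-output case (output_dim == 1): targets must be float.
preds = torch.randn(4)
targets = torch.tensor([0.0, 1.0, 0.0, 1.0])  # dtype float
print(nn.MSELoss()(preds, targets))

# With long targets the float loss fails, e.g.:
#   nn.MSELoss()(preds, targets.long())  # raises a RuntimeError (dtype mismatch)
```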
```diff
@@ -365,7 +369,7 @@ def predict_epoch(self, loader):
             batch_outs = self.predict_batch(data, targets)
             total_loss += batch_outs["loss"]
             if self.output_dim == 1:
-                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy())
+                y_preds.append(batch_outs["y_preds"].cpu().detach().numpy().flatten())
             elif self.output_dim == 2:
                 y_preds.append(batch_outs["y_preds"][:, 1].cpu().detach().numpy())
             else:
```
```diff
@@ -409,7 +413,10 @@ def predict_batch(self, data, targets):
         """
         self.network.eval()
         data = data.to(self.device).float()
-        targets = targets.to(self.device).long()
+        if self.output_dim == 1:
+            targets = targets.to(self.device).float()
+        else:
+            targets = targets.to(self.device).long()

         output, M_loss, M_explain, _ = self.network(data)
```
