Skip to content

Commit

Permalink
printing hot fix
Browse files Browse the repository at this point in the history
fixing the
  • Loading branch information
isaacmg committed Sep 19, 2021
1 parent c2da628 commit 053a504
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions flood_forecast/pytorch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,6 @@ def torch_single_train(model: PyTorchForecast,
output = output[:, :, 0:multi_targets]
labels = trg[:, -pred_len:, 0:multi_targets]
multi_targets = False
print(trg.shape)
if model.params["dataset_params"]["class"] == "GeneralClassificationLoader":
labels = trg
elif multi_targets == 1:
Expand Down Expand Up @@ -552,7 +551,7 @@ def compute_validation(validation_loader: DataLoader,
scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
wandb.log({'epoch': epoch, val_or_test: scaled})
if classification:
print("Plotting classification metrics")
print("Plotting test classification metrics")
label_list = torch.cat(label_list)
label_list = label_list[:, 0, :]
mod_output1 = torch.cat(mod_output_list)[:, 0, :]
Expand Down

0 comments on commit 053a504

Please sign in to comment.