Skip to content

Commit

Permalink
add predict_step to MultiLabelClassificationTask
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Sep 27, 2022
1 parent d049e7e commit f94a684
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def test_trainer(
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)
trainer.predict(model=model, dataloaders=datamodule.val_dataloader())

def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml"))
Expand Down
12 changes: 12 additions & 0 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,15 @@ def test_step(self, *args: Any, **kwargs: Any) -> None:
# by default, the test and validation steps only log per *epoch*
self.log("test_loss", loss, on_step=False, on_epoch=True)
self.test_metrics(y_hat_hard, y)

def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
"""Compute and return the predictions.
Args:
batch: the output of your DataLoader
Returns:
predicted sigmoid probabilities
"""
batch = args[0]
x = batch["image"]
y_hat = torch.sigmoid(self(x))
return y_hat

0 comments on commit f94a684

Please sign in to comment.