Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiLabelClassificationTask predict step #792

Conversation

isaaccorley
Copy link
Collaborator

This PR is similar to #790 and adds a predict_step method to MultiLabelClassificationTask so that users can utilize PyTorch Lightning's trainer.predict() function which will automatically loop and predict over a PyTorch DataLoader e.g.:

preds = trainer.predict(model=task, dataloaders=datamodule.test_dataloader())

# Or if your datamodule has a `predict_dataloader` method defined
preds = trainer.predict(model=task, datamodule=datamodule)

@isaaccorley isaaccorley added the trainers PyTorch Lightning trainers label Sep 26, 2022
@isaaccorley isaaccorley added this to the 0.3.2 milestone Sep 26, 2022
@isaaccorley isaaccorley self-assigned this Sep 26, 2022
@github-actions github-actions bot added the testing Continuous integration testing label Sep 26, 2022
@isaaccorley isaaccorley force-pushed the trainers/multilabelclassificationtask-predict-step branch from 3523e88 to f94a684 Compare September 27, 2022 17:22
@isaaccorley isaaccorley force-pushed the trainers/multilabelclassificationtask-predict-step branch from f94a684 to 3e9cd35 Compare September 28, 2022 14:17
@isaaccorley
Copy link
Collaborator Author

Not sure why codecov is complaining. When I go to the link it only shows coverage dips on scripts I didn't modify.

@isaaccorley isaaccorley reopened this Sep 28, 2022
@calebrob6 calebrob6 merged commit 079e2e4 into microsoft:main Sep 30, 2022
@isaaccorley isaaccorley deleted the trainers/multilabelclassificationtask-predict-step branch September 30, 2022 18:37
@adamjstewart adamjstewart mentioned this pull request Oct 4, 2022
6 tasks
@adamjstewart adamjstewart modified the milestones: 0.3.2, 0.4.0 Jan 23, 2023
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* add predict_step to MultiLabelClassificationTask

* fix docs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants