Increase coverage of trainers #109
Conversation
(force-pushed from f75dafa to c14173c)
@adamjstewart it doesn't look like Lightning is made for you to call
Alright, debugged this and got a pattern for getting 100% test coverage + 100% coverage in SEN12MS and Cyclone trainers
Alright @adamjstewart, summary of what happened here. I wrote tests for the Cyclone, SEN12MS, LandCoverAI, So2Sat, and RESISC45 trainers. In most cases this involved some cleaning up of the trainer code to make everything a little more homogeneous. In RESISC45 specifically:
Things that are missing for this PR (can you take these?):
(force-pushed from 03d8752 to 7d71dec)
Up to you. I believe the original intent was to use this to generate pre-trained model weights.
(force-pushed from a60b5f4 to f00b4e5)
@calebrob6 can you take a look at my last commit? As far as I can tell, the tests are correct, but the trainer task itself is broken? I tried all possible combinations of model and loss and most of them break in one way or another.
Looking through the test output it seems that the label mask used for training has more values than expected. Also, the training data patches might be very small. How did you generate the test files?
Ah, that might explain it. I generated 512x512 files with the same number of bands and random ints. So it sounds like I need to be careful with the range of those ints in the label files.
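For reference, generating fake label rasters with the class range capped might look roughly like this (a sketch; the file name, size, and class count are placeholders, not the actual test data script):

```python
import numpy as np
import rasterio
from rasterio.transform import from_origin

size, num_classes = 64, 5  # placeholder values

# Draw labels only from valid class indices so the loss/metrics never see
# out-of-range values.
labels = np.random.randint(0, num_classes, size=(size, size)).astype(np.uint8)

with rasterio.open(
    "fake_mask.tif",
    "w",
    driver="GTiff",
    height=size,
    width=size,
    count=1,
    dtype=labels.dtype,
    crs="EPSG:4326",
    transform=from_origin(0, 0, 1, 1),
) as dst:
    dst.write(labels, 1)
```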
Okay, this should be good to go now! Once this is merged, TorchGeo will have 100% test coverage. The tests themselves should be fine, but please carefully review all changes to torchgeo itself to make sure I didn't break anything in the process of getting the tests to work.
@@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
conftest.py is a special file that pytest looks for. It allows us to share fixtures across multiple files in a directory.
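For anyone unfamiliar, a minimal sketch of the idea (the mocked_log body here is illustrative and lives directly in conftest.py for brevity, unlike the import above):

```python
# tests/trainers/conftest.py
# Fixtures defined here are visible to every test module in this directory
# without an explicit import.
from typing import Any, Callable

import pytest


@pytest.fixture
def mocked_log() -> Callable[..., None]:
    # Illustrative stand-in for Lightning's self.log during unit tests.
    def log(*args: Any, **kwargs: Any) -> None:
        pass

    return log
```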
from .test_utils import mocked_log


@pytest.fixture(scope="module")
scope="module" means that this fixture is only instantiated a single time instead of every time it gets requested. This speeds things up a bit, although the trainer tests are still very slow.
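A tiny sketch of what the scope changes (a made-up fixture, just to show the instantiation count):

```python
import pytest

instantiations = 0


@pytest.fixture(scope="module")
def expensive_fixture() -> int:
    # With scope="module" this body runs once per test module; with the
    # default function scope it would run once per requesting test.
    global instantiations
    instantiations += 1
    return instantiations


def test_first(expensive_fixture: int) -> None:
    assert expensive_fixture == 1


def test_second(expensive_fixture: int) -> None:
    # The same instance is reused, so the counter has not advanced.
    assert expensive_fixture == 1
```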
@pytest.fixture(
    scope="module", params=[("all", 15), ("s1", 2), ("s2-all", 13), ("s2-reduced", 6)]
)
def bands(request: SubRequest) -> Tuple[str, int]:
    return cast(Tuple[str, int], request.param)
Found this super clever way of syncing two separate fixtures (datamodule and config) to share the same values.
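Roughly, the pattern is that both downstream fixtures request the bands fixture above, so each parametrization flows to both of them in lockstep (the fixture bodies below are placeholders, not the real datamodule/config construction):

```python
from typing import Any, Dict, Tuple

import pytest


@pytest.fixture(scope="module")
def datamodule(bands: Tuple[str, int]) -> Dict[str, Any]:
    # Placeholder body: receives the same (band set, channel count) pair as
    # ``config`` for every parametrization of ``bands``.
    return {"bands": bands[0]}


@pytest.fixture(scope="module")
def config(bands: Tuple[str, int]) -> Dict[str, Any]:
    # Placeholder body: the channel count stays in sync with ``datamodule``.
    return {"in_channels": bands[1]}
```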
def test_extract_encoder_unsupported_model(tmp_path: Path) -> None:
    checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}}
    path = os.path.join(str(tmp_path), "dummy.ckpt")
    torch.save(checkpoint, path)
    err = """Unknown checkpoint task. Only encoder or classification_model"""
    """extraction is supported"""
These tests weren't actually working properly. The second line doesn't get appended to the string, so it was only ever checking for the first half.
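In isolation, the pitfall looks like this (a standalone sketch, not the exact fix committed here):

```python
# Buggy: the second literal is a separate expression statement and is discarded.
err = """Unknown checkpoint task. Only encoder or classification_model"""
"""extraction is supported"""
assert err == "Unknown checkpoint task. Only encoder or classification_model"

# Working: wrap both literals in one expression (note the trailing space) so
# Python concatenates them into a single string.
err = (
    "Unknown checkpoint task. Only encoder or classification_model "
    "extraction is supported"
)
assert err.endswith("extraction is supported")
```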
@@ -80,15 +80,15 @@ def load_state_dict(model: Module, state_dict: Dict[str, Tensor]) -> Module:

    if in_channels != expected_in_channels:
        warnings.warn(
            f"""input channels {in_channels} != input channels in pretrained"""
            """model {expected_in_channels}. Overriding with new input channels"""
These second lines were missing the f-string prefix and a space between words.
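In other words, the corrected call presumably ends up looking something like this (a sketch of the fix, with illustrative channel counts):

```python
import warnings

in_channels, expected_in_channels = 4, 3  # illustrative values

if in_channels != expected_in_channels:
    warnings.warn(
        f"input channels {in_channels} != input channels in pretrained "
        f"model {expected_in_channels}. Overriding with new input channels"
    )
```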
Some other things:
- conf/naipchesapeake.conf needs to be updated
- An FYI, there are currently 578 tests for the trainers and it takes 10 minutes to run on my VM.
- I'm guessing a lot of test time is spent running Conv2ds on the CPU. If so, the Chesapeake tests could be sped up with smaller patch sizes. I'm wondering if the fake LandCoverAI data could be reduced in size too (e.g. 512 --> 64 would be way faster).
Have you tried to run the naipchesapeake trainer? Edit: I'm able to run the naipchesapeake trainer in a notebook.
Updated in what way?
After the last commit, most time is actually spent initializing the models:

$ pyinstrument $(which pytest) tests/trainers/test_chesapeake.py
...
117.755 <module> <string>:1
[4 frames hidden] <string>, runpy
117.755 _run_code runpy.py:64
└─ 117.755 <module> pytest:3
└─ 117.561 console_main _pytest/config/__init__.py:178
[3486 frames hidden] _pytest, pluggy, typing, inspect, <bu...
64.225 call_fixture_func _pytest/fixtures.py:916
├─ 62.387 task tests/trainers/test_chesapeake.py:55
│ └─ 62.384 __init__ torchgeo/trainers/chesapeake.py:100
│ └─ 61.314 config_task torchgeo/trainers/chesapeake.py:59
│ ├─ 31.118 __init__ segmentation_models_pytorch/unet/model.py:50
│ │ [1183 frames hidden] segmentation_models_pytorch, torchvis...
│ └─ 29.098 __init__ segmentation_models_pytorch/deeplabv3/model.py:123
│ [1361 frames hidden] segmentation_models_pytorch, torchvis...
└─ 1.239 datamodule tests/trainers/test_chesapeake.py:19
└─ 1.238 wrapped_fn pytorch_lightning/core/datamodule.py:393
1.558 call_fixture_func _pytest/fixtures.py:916
└─ 1.553 config tests/trainers/test_chesapeake.py:38
└─ 1.482 load omegaconf/omegaconf.py:178
[1343 frames hidden] omegaconf, yaml, <built-in>, abc, cod...
38.158 pytest_pyfunc_call _pytest/python.py:176
├─ 13.584 test_validation tests/trainers/test_chesapeake.py:77
│ └─ 12.254 validation_step torchgeo/trainers/chesapeake.py:174
│ ├─ 7.470 wrapper matplotlib/_api/deprecation.py:459
│ │ [6280 frames hidden] matplotlib, <string>, <built-in>, con...
│ └─ 4.062 forward torchgeo/trainers/chesapeake.py:128
│ └─ 4.062 _call_impl torch/nn/modules/module.py:1045
│ [239 frames hidden] torch, segmentation_models_pytorch, t...
│ 2.182 forward torchgeo/models/fcn.py:61
│ └─ 2.182 _call_impl torch/nn/modules/module.py:1045
│ [17 frames hidden] torch, <built-in>
├─ 11.427 test_test tests/trainers/test_chesapeake.py:84
│ └─ 10.160 test_step torchgeo/trainers/chesapeake.py:247
│ └─ 9.354 forward torchgeo/trainers/chesapeake.py:128
│ └─ 9.354 _call_impl torch/nn/modules/module.py:1045
│ [237 frames hidden] torch, segmentation_models_pytorch, t...
│ 7.211 forward torchgeo/models/fcn.py:61
│ └─ 7.211 _call_impl torch/nn/modules/module.py:1045
│ [15 frames hidden] torch, <built-in>
...
The last couple of commits have reduced total pytest time from 8m 26s to 2m 24s (Python 3.9, Ubuntu). Still more than I would like, but you can always test only the file you are working on instead of running all tests. There are likely additional places we can speed things up if we need to, such as smaller raster files and parallel testing.
Would like to get rid of these warnings:
Any idea if this is in code we can control or if it's internal to PyTorch/torchvision/Kornia?
I believe this happens in the BYOL trainer augmentations when using K.Resize and K.RandomResizedCrop without explicitly setting align_corners=True. But the warning comes from the nn.Upsample under the hood, so setting it may not suppress the warning; it's worth a try though.
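For example, something along these lines (a sketch assuming plain Kornia augmentation modules; the sizes are placeholders rather than the actual BYOL settings, and per the caveat above it may not silence every warning):

```python
import kornia.augmentation as K
import torch
import torch.nn as nn

# Pass align_corners explicitly so the interpolation underneath
# (nn.Upsample / F.interpolate) doesn't warn about its default changing.
augmentations = nn.Sequential(
    K.Resize(size=(96, 96), align_corners=True),
    K.RandomResizedCrop(size=(64, 64), align_corners=True),
)

x = torch.rand(2, 3, 128, 128)
out = augmentations(x)  # shape: (2, 3, 64, 64)
```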
Alright, all warnings have been silenced and the tests are 4x faster now. I think this is ready for a second round of review.
It doesn't follow how the other config files are arranged and would not work if passed to train.py. This can probably just be deleted.
Nicely done!!
How much did reducing the patch sizes speed everything up?
I think reducing patch size was 2x and reducing params was another 2x.
Good to know, thanks! Everything here looks good to me.
* Increase coverage of trainers
* Actually make the tests work
* Updated Cyclone trainer
* Style fix in cyclone tests
* Fixing landcoverai trainer
* Moving mock log to utils
* Fixing the RESISC45 trainer and related
* Skip RESISC45 trainer tests if Windows
* Removing some stupid docstrings from the RESISC45 trainer
* Adding So2Sat trainer tests
* isort
* Adding RESISC45 test data
* Use os.path.join for paths
* Remove unused import
* Add tests for ChesapeakeCVPR trainer
* mypy fixes
* Fix most Chesapeake tests
* Fixed test batching issue in the test dataset
* Get 100% coverage of Chesapeake trainer
* use a FakeTrainer instead of pl.Trainer
* Add naive BYOL trainer tests
* Style fixes
* Add 100% test coverage for BYOL trainer
* Get 100% coverage for LandCover.ai trainer
* Simplify tests
* Add tests for NAIP + Chesapeake trainer
* Fix tests
* Add tests for checkpoint loading
* Reorganize fixtures and specify scope
* Fix various test bugs
* Mypy fixes
* Reduce patch sizes
* Test fewer possible combinations of params
* Prevent warnings in tests
* Restore missing line of coverage in So2Sat trainer
* Silence resampling warning
* Ignore difference in output classes

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Since we no longer run integration tests on main, our trainer modules are currently the least well-tested code. This PR attempts to add unit tests for these trainers. I'm a bit stuck at the moment, so opening this up for feedback on better ways to test this code while I focus on more important things.