-
Notifications
You must be signed in to change notification settings - Fork 384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove type ignores for PyTorch #460
Conversation
what was the bug in ETCI2021? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added comments to everything I think was worth pointing out. 85% of the changes are simply removing "type: ignore" from PyTorch functions, 10% are fixes to type hinting mistakes like pytest and pytorch-lightning, and 5% are bugs in ETCI 2021 plotting that were uncovered by mypy. I can try to separate those 10% of changes into a separate PR if that makes it easier to review, but I think the other 90% have to be merged first in order to get our unit tests to pass with PyTorch 1.11.
It's crazy that simply removing "type: ignore" and reformatting code that had to be on multiple lines is enough to drop 500+ lines of code from TorchGeo!
build: | ||
os: ubuntu-20.04 | ||
tools: | ||
python: "3.9" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had to do some hacky type hint stuff that now requires Python 3.9+ to build the docs or run mypy. Basically, both typing
and collections
have an OrderedDict
, but collections.OrderedDict
doesn't support type hints until Python 3.9+ and typing.OrderedDict
doesn't exist until Python 3.7.2+. So if we want to support Python 3.6 (will be dropped soon) or 3.7.1 we need to use collections.OrderedDict
. This does not affect run-time since I wrapped it in quotes and it's only evaluated by mypy/sphinx.
device: torch.device, # type: ignore[name-defined] | ||
metrics: Metric, | ||
device: torch.device, | ||
metrics: MetricCollection, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could use Union[Metric, MetricCollection]
but currently we are only using MetricCollection
arguments.
@@ -158,7 +159,7 @@ def main(args: argparse.Namespace) -> None: | |||
"loss": model.hparams["loss"], | |||
} | |||
elif issubclass(TASK, SemanticSegmentationTask): | |||
val_row: Dict[str, Union[str, float]] = { # type: ignore[no-redef] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to redefine this type and then ignore the fact that we redefined this type
@@ -23,38 +23,30 @@ def download_url(url: str, root: str, *args: str) -> None: | |||
|
|||
class TestADVANCE: | |||
@pytest.fixture | |||
def dataset( | |||
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know why I was under the impression that monkeypatch
is a Generator
but it is in fact just a MonkeyPatch
object. So we no longer need to ignore type hints relating to pytest!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This definitely makes more sense -- I skeptically copy+pasted the Generator type several times but was never skeptical enough to investigate.
@@ -263,7 +234,7 @@ def dataset( | |||
|
|||
def test_getitem(self, dataset: SpaceNet5) -> None: | |||
# Iterate over all elements to maximize coverage | |||
samples = [i for i in dataset] # type: ignore[attr-defined] | |||
samples = [dataset[i] for i in range(len(dataset))] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm actually pretty confused why this used to work because dataset
isn't iterable.
@@ -161,7 +159,7 @@ def __init__( | |||
model: Module, | |||
projection_size: int = 256, | |||
hidden_size: int = 4096, | |||
layer: Union[str, int] = -2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When we index self.model.children()
, I believe this is a nn.ModuleList
, so only integer indices are allowed. So far we're only using integer indices in our code. I don't know of a reason why we need to support string indices.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe there was an option here where you could look up a layer by name that I removed for simplicity and forgot to update
# Copying the weights of the old layer to the extra channels | ||
for i in range(in_channels - layer.in_channels): | ||
channel = layer.in_channels + i | ||
new_layer.weight[:, channel : channel + 1, :, :].data[ | ||
... # type: ignore[index] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mypy doesn't like ...
for some reason, :
should be equivalent as far as I know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is okay here
self.save_hyperparameters() # creates `self.hparams` from kwargs | ||
|
||
# Creates `self.hparams` from kwargs | ||
self.save_hyperparameters() # type: ignore[operator] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For some reason mypy thinks this is a Tensor
, not a Callable
. We could cast it but I'm not sure exactly what kind of Callable
this is. I'm going to chalk this up to "pytorch-lightning is super hacky" and leave it for future work to figure out why this doesn't work.
|
||
self.config_task() | ||
|
||
def forward(self, x: Tensor) -> Any: # type: ignore[override] | ||
def forward(self, *args: Any, **kwargs: Any) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's best to avoid overriding type signatures of functions in subclasses
@@ -18,7 +18,7 @@ | |||
Conv2d.__module__ = "nn.Conv2d" | |||
|
|||
|
|||
def extract_encoder(path: str) -> Tuple[str, Dict[str, Tensor]]: | |||
def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the OrderedDict
hack I mentioned above. model.load_state_dict
requires an OrderedDict
, not Dict
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the hparams
changes in the trainers, can we just reset self.hparams
in the constructor after calling save_hyperparameters to avoid having to have make a local copy of the hyperparameters everywhere (just to keep mypy happy)?
That's what I tried originally but pytorch-lightning doesn't like this:
I'm guessing it's a read only |
Can we just make a new Note: this isn't a requirement for this PR (I can open a followup with the .long() and .float() changes too), just trying to figure out how to work around having these local |
* Remove type ignores for PyTorch * Mypy fixes for pytest MonkeyPatch * Black * Ignore Identity * Generic fixes * Remove unused Generator import * More fixes * Fix remaining mypy errors * More typing cleanups * typing.OrderedDict isn't available until Python 3.7.2+ * Need Python 3.9 to build docs for fancy OrderedDict * Fix Python 3.8 and earlier support * Fix BigEarthNet tests * Fix bug in ETCI 2021 tests * Remove unused flake8 ignore * More robust and well-documented trainer steps * Many functions don't actually use batch_idx * Store cast hparams in trainers
* Remove type ignores for PyTorch * Mypy fixes for pytest MonkeyPatch * Black * Ignore Identity * Generic fixes * Remove unused Generator import * More fixes * Fix remaining mypy errors * More typing cleanups * typing.OrderedDict isn't available until Python 3.7.2+ * Need Python 3.9 to build docs for fancy OrderedDict * Fix Python 3.8 and earlier support * Fix BigEarthNet tests * Fix bug in ETCI 2021 tests * Remove unused flake8 ignore * More robust and well-documented trainer steps * Many functions don't actually use batch_idx * Store cast hparams in trainers
* Remove type ignores for PyTorch * Mypy fixes for pytest MonkeyPatch * Black * Ignore Identity * Generic fixes * Remove unused Generator import * More fixes * Fix remaining mypy errors * More typing cleanups * typing.OrderedDict isn't available until Python 3.7.2+ * Need Python 3.9 to build docs for fancy OrderedDict * Fix Python 3.8 and earlier support * Fix BigEarthNet tests * Fix bug in ETCI 2021 tests * Remove unused flake8 ignore * More robust and well-documented trainer steps * Many functions don't actually use batch_idx * Store cast hparams in trainers
PyTorch 1.11 added type hints for most of the library. This PR removes type ignores for the majority of PyTorch functions and fixes other random typing mistakes I found along the way.
Closes #266 (no longer needed)