Update mypy job to 1.1.1 (#16974)
(cherry picked from commit 9583128)
carmocca authored and Borda committed Mar 31, 2023
1 parent 65b2637 commit b20da7f
Showing 14 changed files with 45 additions and 43 deletions.
2 changes: 1 addition & 1 deletion requirements/typing.txt
@@ -1,4 +1,4 @@
-mypy==0.982
+mypy==1.1.1
 -f https://download.pytorch.org/whl/test/cpu/torch_test.html --pre
 torch==2.0.0

2 changes: 1 addition & 1 deletion src/lightning_app/core/app.py
@@ -76,7 +76,7 @@ def __init__(
         root: Union["LightningFlow", LightningWork],
         flow_cloud_compute: Optional["CloudCompute"] = None,
         log_level: str = "info",
-        info: frontend.AppInfo = None,
+        info: Optional[frontend.AppInfo] = None,
         root_path: str = "",
     ) -> None:
         """The Lightning App, or App in short runs a tree of one or more components that interact to create end-to-end
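
Background on this change: mypy made `--no-implicit-optional` the default around 0.990, so a parameter that defaults to `None` now has to be annotated `Optional[...]` explicitly, as done for `info` above. A minimal standalone sketch of the rule (not Lightning code):

from typing import Optional


def connect(timeout: Optional[float] = None) -> float:
    # `timeout: float = None` would now be rejected under mypy's default settings
    return 30.0 if timeout is None else timeout
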
2 changes: 1 addition & 1 deletion src/lightning_app/utilities/packaging/build_config.py
@@ -196,7 +196,7 @@ def to_dict(self) -> Dict:
         return {"__build_config__": asdict(self)}
 
     @classmethod
-    def from_dict(cls, d: Dict) -> Self:  # type: ignore[valid-type]
+    def from_dict(cls, d: Dict) -> Self:
         return cls(**d["__build_config__"])
 
 
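
The dropped `# type: ignore[valid-type]` comments here and in the files below follow from mypy 1.x understanding the `typing_extensions.Self` return type natively. A minimal standalone sketch (not Lightning code):

from typing import List

from typing_extensions import Self


class Builder:
    def __init__(self) -> None:
        self.steps: List[str] = []

    def add(self, step: str) -> Self:  # no ignore needed on mypy >= 1.0
        self.steps.append(step)
        return self
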
6 changes: 3 additions & 3 deletions src/lightning_fabric/plugins/collectives/collective.py
@@ -111,12 +111,12 @@ def destroy_group(cls, group: CollectibleGroup) -> None:
     def _convert_to_native_op(cls, op: str) -> Any:
         ...
 
-    def setup(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
+    def setup(self, **kwargs: Any) -> Self:
         if not self.is_initialized():
             self.init_group(**kwargs)
         return self
 
-    def create_group(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
+    def create_group(self, **kwargs: Any) -> Self:
         """Create a group.
 
         This assumes that :meth:`~lightning_fabric.plugins.collectives.Collective.init_group` has been
@@ -127,7 +127,7 @@ def create_group(self, **kwargs: Any) -> Self:  # type: ignore[valid-type]
         self._group = self.new_group(**kwargs)
         return self
 
-    def teardown(self) -> Self:  # type: ignore[valid-type]
+    def teardown(self) -> Self:
         if self._group is None:
             raise RuntimeError(f"`{type(self).__name__}` does not own a group to destroy.")
         self.destroy_group(self._group)
6 changes: 2 additions & 4 deletions src/lightning_fabric/plugins/collectives/torch_collective.py
@@ -106,9 +106,7 @@ def barrier(self, device_ids: Optional[List[int]] = None) -> None:
     def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
         dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
 
-    def setup(
-        self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any
-    ) -> Self:  # type: ignore[valid-type]
+    def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:
         if self.is_initialized():
             return self
         # maybe set addr
@@ -134,7 +132,7 @@ def setup(
             os.environ.pop("MASTER_PORT", None)
         return self
 
-    def teardown(self) -> Self:  # type: ignore[valid-type]
+    def teardown(self) -> Self:
         non_group_member = self.group == dist.GroupMember.NON_GROUP_MEMBER
         super().teardown()  # will destroy its own group
         # try to destroy the default group. this should only be done by a group member to avoid race conditions,
14 changes: 7 additions & 7 deletions src/lightning_fabric/utilities/device_dtype_mixin.py
@@ -46,14 +46,14 @@ def device(self) -> torch.device:
 
         return device
 
-    def to(self, *args: Any, **kwargs: Any) -> Self:  # type: ignore[valid-type]
+    def to(self, *args: Any, **kwargs: Any) -> Self:
         """See :meth:`torch.nn.Module.to`."""
         # this converts `str` device to `torch.device`
         device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
         self.__update_properties(device=device, dtype=dtype)
         return super().to(*args, **kwargs)
 
-    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # type: ignore[valid-type]
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
         """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
         different objects. So it should be called before constructing optimizer if the module will live on GPU
         while being optimized.
@@ -72,27 +72,27 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # type: ignore[valid-type]
         self.__update_properties(device=device)
         return super().cuda(device=device)
 
-    def cpu(self) -> Self:  # type: ignore[valid-type]
+    def cpu(self) -> Self:
         """See :meth:`torch.nn.Module.cpu`."""
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()
 
-    def type(self, dst_type: Union[str, torch.dtype]) -> Self:  # type: ignore[valid-type]
+    def type(self, dst_type: Union[str, torch.dtype]) -> Self:
         """See :meth:`torch.nn.Module.type`."""
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)
 
-    def float(self) -> Self:  # type: ignore[valid-type]
+    def float(self) -> Self:
         """See :meth:`torch.nn.Module.float`."""
         self.__update_properties(dtype=torch.float)
         return super().float()
 
-    def double(self) -> Self:  # type: ignore[valid-type]
+    def double(self) -> Self:
         """See :meth:`torch.nn.Module.double`."""
         self.__update_properties(dtype=torch.double)
         return super().double()
 
-    def half(self) -> Self:  # type: ignore[valid-type]
+    def half(self) -> Self:
         """See :meth:`torch.nn.Module.half`."""
         self.__update_properties(dtype=torch.half)
         return super().half()
4 changes: 2 additions & 2 deletions src/lightning_fabric/utilities/types.py
@@ -30,7 +30,7 @@
 if torch.distributed.is_available():
     from torch.distributed import ProcessGroup, ReduceOp
 
-    RedOpType: TypeAlias = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object  # type: ignore[misc]
+    RedOpType: TypeAlias = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object  # type: ignore[valid-type]
 else:
     ProcessGroup = Any  # type: ignore[assignment,misc]
     ReduceOp = object  # type: ignore[assignment,misc]  # we are using isinstance check once
@@ -75,7 +75,7 @@ def step(self, epoch: Optional[int] = None) -> None:
 
 
 _TORCH_LRSCHEDULER: TypeAlias = (
-    torch.optim.lr_scheduler.LRScheduler  # type: ignore[misc]
+    torch.optim.lr_scheduler.LRScheduler  # type: ignore[valid-type]
     if _TORCH_GREATER_EQUAL_2_0
     else torch.optim.lr_scheduler._LRScheduler
 )
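
The swapped ignore codes above reflect that mypy 1.1.1 reports a type alias whose target is picked by a runtime condition under `valid-type`, where 0.982 used `misc`. A minimal standalone sketch, with made-up names (mypy's exact message may vary):

import sys

from typing_extensions import TypeAlias

_NEW_ENOUGH = sys.version_info >= (3, 8)
# The alias target depends on a runtime check, so it is not a valid static type expression:
MaybeInt: TypeAlias = int if _NEW_ENOUGH else object  # type: ignore[valid-type]
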
1 change: 1 addition & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892))
 
+- Removed the `lightning.pytorch.core.saving.ModelIO` class interface ([#16974](https://github.com/Lightning-AI/lightning/pull/16974))
 
 
 ### Fixed
@@ -40,7 +40,7 @@ def compare_version(package: str, op: Callable, version: str, use_base_version:
 
 try:
     if hasattr(torchmetrics.metric, "_compare_version"):
-        torchmetrics.metric._compare_version = compare_version  # type: ignore
+        torchmetrics.metric._compare_version = compare_version
 except AttributeError:
     pass
 pickle.Unpickler = RedirectingUnpickler  # type: ignore
15 changes: 8 additions & 7 deletions src/pytorch_lightning/core/datamodule.py
@@ -14,7 +14,7 @@
 """LightningDataModule for loading DataLoaders with ease."""
 import inspect
 from argparse import ArgumentParser, Namespace
-from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, cast, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 from typing_extensions import Self
@@ -188,13 +188,13 @@ def predict_dataloader() -> EVAL_DATALOADERS:
 
         datamodule = cls(**datamodule_kwargs, **special_kwargs)
         if train_dataset is not None:
-            datamodule.train_dataloader = train_dataloader  # type: ignore[assignment]
+            datamodule.train_dataloader = train_dataloader  # type: ignore[method-assign]
         if val_dataset is not None:
-            datamodule.val_dataloader = val_dataloader  # type: ignore[assignment]
+            datamodule.val_dataloader = val_dataloader  # type: ignore[method-assign]
         if test_dataset is not None:
-            datamodule.test_dataloader = test_dataloader  # type: ignore[assignment]
+            datamodule.test_dataloader = test_dataloader  # type: ignore[method-assign]
         if predict_dataset is not None:
-            datamodule.predict_dataloader = predict_dataloader  # type: ignore[assignment]
+            datamodule.predict_dataloader = predict_dataloader  # type: ignore[method-assign]
         return datamodule
 
     def state_dict(self) -> Dict[str, Any]:
@@ -219,7 +219,7 @@ def load_from_checkpoint(
         checkpoint_path: Union[_PATH, IO],
         hparams_file: Optional[_PATH] = None,
         **kwargs: Any,
-    ) -> Self:  # type: ignore[valid-type]
+    ) -> Self:
         r"""
         Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
         it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
@@ -273,11 +273,12 @@ def load_from_checkpoint(
             )
         """
-        return _load_from_checkpoint(
+        loaded = _load_from_checkpoint(
             cls,
             checkpoint_path,
             map_location=None,
             hparams_file=hparams_file,
             strict=None,
             **kwargs,
         )
+        return cast(Self, loaded)
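
The updated ignores on the dataloader reassignments above come from newer mypy splitting assignments that override a method into the dedicated `method-assign` error code, instead of the generic `assignment` code. A minimal standalone sketch (not Lightning code):

from typing import List


class Pipeline:
    def run(self) -> List[int]:
        return [1, 2, 3]


def _fake_run(self: Pipeline) -> List[int]:
    return []


# Overriding a method after the fact is reported as [method-assign] by mypy 1.x:
Pipeline.run = _fake_run  # type: ignore[method-assign]
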
7 changes: 5 additions & 2 deletions src/pytorch_lightning/strategies/ipu.py
@@ -304,7 +304,9 @@ def teardown(self) -> None:
         assert self.lightning_module is not None
         if self._optimizer_zero_grad_original is not None:
             # re-enable `optimizer_zero_grad`
-            self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original  # type: ignore[assignment]
+            self.lightning_module.optimizer_zero_grad = (  # type: ignore[method-assign]
+                self._optimizer_zero_grad_original
+            )
 
         for model in self.poptorch_models.values():
             model.destroy()
@@ -362,7 +364,8 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)
 
     @property
-    def root_device(self) -> torch.device:
+    def root_device(self) -> torch.device:  # type: ignore[empty-body]
+        # TODO: this should return `self.parallel_devices[self.local_rank]`
         pass
 
     def model_to_device(self) -> None:
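
The new ignore on the stub `root_device` property reflects mypy 1.x flagging a function that declares a non-None return type but has an empty body under the `empty-body` code. A minimal standalone sketch (not Lightning code):

import torch


class _StubAccelerator:
    @property
    def root_device(self) -> torch.device:  # type: ignore[empty-body]
        pass  # intentionally left unimplemented in this stub
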
@@ -201,7 +201,7 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_Meta
         return meta
 
 
-class _ResultMetric(Metric, _DeviceDtypeModuleMixin):
+class _ResultMetric(Metric, _DeviceDtypeModuleMixin):  # type: ignore[misc]  # torchmetrics methods should return Self
     """Wraps the value provided to `:meth:`~pytorch_lightning.core.module.LightningModule.log`"""
 
     def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
20 changes: 10 additions & 10 deletions src/pytorch_lightning/utilities/compile.py
@@ -54,13 +54,13 @@ def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule
         "original_predict_step": orig_module.predict_step,
     }
 
-    orig_module.forward = model.dynamo_ctx(orig_module.forward)  # type: ignore[assignment]
+    orig_module.forward = model.dynamo_ctx(orig_module.forward)  # type: ignore[method-assign]
     if not _TORCH_GREATER_EQUAL_2_1:
         orig_module.forward._torchdynamo_inline = orig_module.forward  # https://github.com/pytorch/pytorch/issues/95630
-    orig_module.training_step = model.dynamo_ctx(orig_module.training_step)  # type: ignore[assignment]
-    orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step)  # type: ignore[assignment]
-    orig_module.test_step = model.dynamo_ctx(orig_module.test_step)  # type: ignore[assignment]
-    orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step)  # type: ignore[assignment]
+    orig_module.training_step = model.dynamo_ctx(orig_module.training_step)  # type: ignore[method-assign]
+    orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step)  # type: ignore[method-assign]
+    orig_module.test_step = model.dynamo_ctx(orig_module.test_step)  # type: ignore[method-assign]
+    orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step)  # type: ignore[method-assign]
     return orig_module
 
 
@@ -95,11 +95,11 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod
 
     ctx = model._compiler_ctx
     if ctx is not None:
-        model.forward = ctx["original_forward"]  # type: ignore[assignment]
-        model.training_step = ctx["original_training_step"]  # type: ignore[assignment]
-        model.validation_step = ctx["original_validation_step"]  # type: ignore[assignment]
-        model.test_step = ctx["original_test_step"]  # type: ignore[assignment]
-        model.predict_step = ctx["original_predict_step"]  # type: ignore[assignment]
+        model.forward = ctx["original_forward"]  # type: ignore[method-assign]
+        model.training_step = ctx["original_training_step"]  # type: ignore[method-assign]
+        model.validation_step = ctx["original_validation_step"]  # type: ignore[method-assign]
+        model.test_step = ctx["original_test_step"]  # type: ignore[method-assign]
+        model.predict_step = ctx["original_predict_step"]  # type: ignore[method-assign]
         model._compiler_ctx = None
 
     return model
5 changes: 2 additions & 3 deletions src/pytorch_lightning/utilities/data.py
@@ -40,8 +40,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn, WarningCache
 
-# might be supported in later releases, see https://github.com/python/mypy/pull/13297
-BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]
+BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
 
 warning_cache = WarningCache()
 
@@ -59,7 +58,7 @@ def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]:
         for sample in batch:
             yield from _extract_batch_size(sample)
     elif is_dataclass_instance(batch):
-        for field in fields(batch):
+        for field in fields(batch):  # type: ignore[arg-type]
             yield from _extract_batch_size(getattr(batch, field.name))
     else:
         yield None
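
Dropping the ignore on `BType` works because recursive type aliases are supported by default in newer mypy; they were still experimental in 0.982 (see the removed comment and https://github.com/python/mypy/pull/13297). A minimal standalone sketch with a made-up alias:

from typing import Mapping, Sequence, Union

# A self-referential alias now type-checks without an ignore on mypy 1.x:
JSONValue = Union[str, int, float, bool, None, Mapping[str, "JSONValue"], Sequence["JSONValue"]]

doc: JSONValue = {"name": "lightning", "tags": ["mypy", "typing"]}
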
