Update mypy job to torch 2.0 (#16933)
(cherry picked from commit c2b28a0)
carmocca authored and Borda committed Mar 30, 2023
1 parent 14d1bde commit 6d4303f
Showing 17 changed files with 58 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/code-checks.yml
@@ -45,7 +45,7 @@ jobs:
FREEZE_REQUIREMENTS: 1
run: |
# todo: adjust requirements for both code-bases
- pip install -e '.[extra,ui,cloud]' -r requirements/typing.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+ pip install -e '.[extra,ui,cloud]' -r requirements/typing.txt
pip list
- name: Check typing
3 changes: 2 additions & 1 deletion requirements/typing.txt
@@ -1,5 +1,6 @@
mypy==0.982
- torch==1.12.0
+ -f https://download.pytorch.org/whl/test/cpu/torch_test.html --pre
+ torch==2.0.0

types-Markdown
types-PyYAML
4 changes: 2 additions & 2 deletions src/lightning_fabric/plugins/collectives/collective.py
@@ -70,11 +70,11 @@ def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[T
...

@abstractmethod
- def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
+ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
...

@abstractmethod
- def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
+ def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
...

@abstractmethod
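
Note on the send/recv signature change above: torch.distributed's stubs declare the tag parameter of send/recv as a plain int, so a `tag: Optional[int] = 0` default pushes a possibly-None value into an int-only call; tightening it to `tag: int = 0` fixes the annotation at the source rather than silencing it (the scoped ignores that remain in the torch backend appear to cover the group argument instead). A minimal, illustrative sketch of the mismatch, with made-up function names:

from typing import Optional

def dist_send(tensor: object, dst: int, tag: int = 0) -> None:  # stand-in for torch.distributed.send's `tag: int`
    print(f"sending to rank {dst} with tag={tag}")

def send(tensor: object, dst: int, tag: Optional[int] = 0) -> None:
    # mypy: Argument "tag" to "dist_send" has incompatible type "Optional[int]"; expected "int"  [arg-type]
    dist_send(tensor, dst, tag=tag)
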
14 changes: 7 additions & 7 deletions src/lightning_fabric/plugins/collectives/torch_collective.py
@@ -28,11 +28,11 @@ def __init__(self) -> None:
@property
def rank(self) -> int:
# local rank
- return dist.get_rank(self.group)
+ return dist.get_rank(self.group)  # type: ignore[arg-type]

@property
def world_size(self) -> int:
- return dist.get_world_size(self.group)
+ return dist.get_world_size(self.group)  # type: ignore[arg-type]

def broadcast(self, tensor: Tensor, src: int) -> Tensor:
dist.broadcast(tensor, src, group=self.group)
@@ -71,11 +71,11 @@ def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[T
dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
return output_tensor_list

- def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
-     dist.send(tensor, dst, tag=tag, group=self.group)
+ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
+     dist.send(tensor, dst, tag=tag, group=self.group)  # type: ignore[arg-type]

- def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
-     dist.recv(tensor, src, tag=tag, group=self.group)
+ def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
+     dist.recv(tensor, src, tag=tag, group=self.group)  # type: ignore[arg-type]
return tensor

def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]:
@@ -168,7 +168,7 @@ def new_group(cls, **kwargs: Any) -> CollectibleGroup:
def destroy_group(cls, group: CollectibleGroup) -> None:
# can be called by all processes in the default group, group will be `object()` if they are not part of the
# current group
- dist.destroy_process_group(group)
+ dist.destroy_process_group(group)  # type: ignore[arg-type]

@classmethod
def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
2 changes: 1 addition & 1 deletion src/lightning_fabric/strategies/ddp.py
@@ -205,5 +205,5 @@ def no_backward_sync(self, module: Module) -> Generator:
f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `DistributedDataParallel`."
f" Got: {module.__class__.__name__}."
)
- with module.no_sync():  # type: ignore[operator]
+ with module.no_sync():
yield
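
The hunk above removes a `# type: ignore[operator]` instead of adding one: under torch 2.0's annotations the `module.no_sync()` call now type-checks cleanly, and if the mypy configuration enables warn_unused_ignores (a common companion to this kind of upgrade, assumed here rather than stated in the diff) a stale ignore is reported as an error of its own. A tiny illustrative snippet, not taken from this repository:

# with `warn_unused_ignores = True` in the mypy config:
value: int = "oops"  # type: ignore[assignment]  # still needed: silences a real error
count: int = 5  # type: ignore[assignment]  # flagged: 'Unused "type: ignore" comment'
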
5 changes: 4 additions & 1 deletion src/lightning_fabric/strategies/parallel.py
@@ -94,7 +94,10 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
bool: The reduced boolean decision.
"""
decision = torch.tensor(int(decision), device=self.root_device)
- decision = self.all_reduce(decision, reduce_op=ReduceOp.SUM)
+ decision = self.all_reduce(
+     decision,
+     reduce_op=ReduceOp.SUM,  # type: ignore[arg-type]
+ )
decision = bool(decision == self.world_size) if all else bool(decision)
return decision

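
For context on reduce_boolean_decision: each rank contributes 0 or 1, the values are summed across processes, and with all=True the decision only holds when every rank voted True (sum == world size); with all=False any single True vote wins. A single-process sketch of the same logic in plain Python, with no distributed backend involved:

from typing import List

def reduce_boolean_decision(votes: List[bool], all: bool = True) -> bool:
    total = sum(int(v) for v in votes)  # stands in for all_reduce(decision, reduce_op=SUM)
    world_size = len(votes)
    return total == world_size if all else total > 0

assert reduce_boolean_decision([True, True, False], all=True) is False
assert reduce_boolean_decision([True, True, False], all=False) is True
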
9 changes: 6 additions & 3 deletions src/lightning_fabric/utilities/cloud_io.py
@@ -37,15 +37,18 @@ def _load(
"""
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similar
- return torch.load(path_or_url, map_location=map_location)
+ return torch.load(
+     path_or_url,
+     map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+ )
if str(path_or_url).startswith("http"):
return torch.hub.load_state_dict_from_url(
str(path_or_url),
- map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+ map_location=map_location,  # type: ignore[arg-type]
)
fs = get_filesystem(path_or_url)
with fs.open(path_or_url, "rb") as f:
- return torch.load(f, map_location=map_location)
+ return torch.load(f, map_location=map_location)  # type: ignore[arg-type]


def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
2 changes: 1 addition & 1 deletion src/lightning_fabric/utilities/distributed.py
@@ -130,7 +130,7 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
op: Optional[ReduceOp]
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
- op = ReduceOp.SUM
+ op = ReduceOp.SUM  # type: ignore[assignment]
divide_by_world_size = True
else:
op = getattr(ReduceOp, reduce_op.upper())
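
The _sync_ddp hunk maps "avg"/"mean" onto ReduceOp.SUM and sets divide_by_world_size because not every backend offers a native mean reduction; the new scoped ignore appears to be needed because torch 2.0's stubs type ReduceOp.SUM as ReduceOp.RedOpType rather than ReduceOp (an assumption inferred from the diff, not stated in it). The arithmetic being emulated, as a plain sketch:

def emulated_mean(per_rank_values):  # one entry per rank
    total = sum(per_rank_values)  # what all_reduce(..., op=SUM) leaves on every rank
    return total / len(per_rank_values)  # the divide_by_world_size step

assert emulated_mean([1.0, 2.0, 6.0]) == 3.0
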
10 changes: 6 additions & 4 deletions src/lightning_fabric/utilities/types.py
@@ -17,7 +17,7 @@
import torch
from torch import Tensor
from torch.optim import Optimizer
- from typing_extensions import Protocol, runtime_checkable
+ from typing_extensions import Protocol, runtime_checkable, TypeAlias

from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_2_0

@@ -30,7 +30,7 @@
if torch.distributed.is_available():
from torch.distributed import ProcessGroup, ReduceOp

- RedOpType = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object
+ RedOpType: TypeAlias = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object  # type: ignore[misc]
else:
ProcessGroup = Any # type: ignore[assignment,misc]
ReduceOp = object # type: ignore[assignment,misc] # we are using isinstance check once
@@ -74,8 +74,10 @@ def step(self, epoch: Optional[int] = None) -> None:
...


- _TORCH_LRSCHEDULER = (
-     torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_2_0 else torch.optim.lr_scheduler._LRScheduler
+ _TORCH_LRSCHEDULER: TypeAlias = (
+     torch.optim.lr_scheduler.LRScheduler  # type: ignore[misc]
+     if _TORCH_GREATER_EQUAL_2_0
+     else torch.optim.lr_scheduler._LRScheduler
)


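
The types.py changes work around a mypy limitation: an alias whose right-hand side is chosen at runtime (here by a torch version check) is not something mypy can fully verify, so the conditional form gets an explicit TypeAlias annotation plus a scoped ignore. A self-contained illustration with made-up names; the exact error code mypy emits for this pattern can vary by version, so the ignore below simply mirrors the one used in the diff:

from typing_extensions import TypeAlias

_HAS_NEW_SCHEDULER = True  # stands in for _TORCH_GREATER_EQUAL_2_0

_Scheduler: TypeAlias = int if _HAS_NEW_SCHEDULER else float  # type: ignore[misc]
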
12 changes: 6 additions & 6 deletions src/pytorch_lightning/core/module.py
@@ -2116,14 +2116,14 @@ def to_uncompiled(cls, model: Union["pl.LightningModule", "torch._dynamo.Optimiz
else:
raise ValueError("`model` must either be an instance of OptimizedModule or LightningModule")

- model.forward = model._compiler_ctx["original_forward"]
- model.training_step = model._compiler_ctx["original_training_step"]
- model.validation_step = model._compiler_ctx["original_validation_step"]
- model.test_step = model._compiler_ctx["original_test_step"]
- model.predict_step = model._compiler_ctx["original_predict_step"]
+ model.forward = model._compiler_ctx["original_forward"]  # type: ignore[assignment,index]
+ model.training_step = model._compiler_ctx["original_training_step"]  # type: ignore[assignment,index]
+ model.validation_step = model._compiler_ctx["original_validation_step"]  # type: ignore[assignment,index]
+ model.test_step = model._compiler_ctx["original_test_step"]  # type: ignore[assignment,index]
+ model.predict_step = model._compiler_ctx["original_predict_step"]  # type: ignore[assignment,index]
model._compiler_ctx = None

- return model
+ return model  # type: ignore[return-value]


@contextmanager
4 changes: 2 additions & 2 deletions src/pytorch_lightning/overrides/distributed.py
@@ -53,7 +53,7 @@ def _find_tensors(
def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
# `prepare_for_backward` is `DistributedDataParallel` specific.
if torch.is_grad_enabled() and model.require_backward_grad_sync:
- model.require_forward_param_sync = True  # type: ignore[assignment]
+ model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
@@ -64,7 +64,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
reducer._rebuild_buckets() # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
reducer.prepare_for_backward(args)
else:
- model.require_forward_param_sync = False  # type: ignore[assignment]
+ model.require_forward_param_sync = False


class UnrepeatedDistributedSampler(DistributedSampler):
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/colossalai.py
@@ -459,7 +459,7 @@ def reduce(

if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
- reduce_op = ReduceOp.SUM
+ reduce_op = ReduceOp.SUM  # type: ignore[assignment]
div_factor = gpc.get_world_size(parallel_mode=ParallelMode.GLOBAL)
with torch.no_grad():
tensor = tensor / div_factor
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ddp.py
@@ -467,7 +467,7 @@ def teardown(self) -> None:
if (
_TORCH_GREATER_EQUAL_1_11
and not self.model.static_graph
- and self.model._get_ddp_logging_data().get("can_set_static_graph")  # type: ignore[operator]
+ and self.model._get_ddp_logging_data().get("can_set_static_graph")
):
rank_zero_info(
"Your model can run with static graph optimizations. For future training runs, we suggest you"
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ddp_spawn.py
@@ -344,7 +344,7 @@ def teardown(self) -> None:
if (
_TORCH_GREATER_EQUAL_1_11
and not self.model.static_graph
- and self.model._get_ddp_logging_data().get("can_set_static_graph")  # type: ignore[operator]
+ and self.model._get_ddp_logging_data().get("can_set_static_graph")
):
rank_zero_info(
"Your model can run with static graph optimizations. For future training runs, we suggest you"
5 changes: 4 additions & 1 deletion src/pytorch_lightning/strategies/parallel.py
@@ -102,7 +102,10 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
bool: The reduced boolean decision.
"""
decision = torch.tensor(int(decision), device=self.root_device)
- decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
+ decision = self.reduce(
+     decision,
+     reduce_op=ReduceOp.SUM,  # type: ignore[arg-type]
+ )
decision = bool(decision == self.world_size) if all else bool(decision)
return decision

18 changes: 12 additions & 6 deletions src/pytorch_lightning/utilities/compile.py
@@ -79,6 +79,10 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod

if isinstance(model, OptimizedModule):
model = model._orig_mod
+ if not isinstance(model, pl.LightningModule):
+     raise TypeError(
+         f"Unexpected error, the wrapped model should be a LightningModule, found {type(model).__name__}"
+     )

elif isinstance(model, pl.LightningModule):
if model._compiler_ctx is None:
@@ -89,12 +93,14 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod
else:
raise ValueError("`model` must either be an instance of OptimizedModule or LightningModule")

- model.forward = model._compiler_ctx["original_forward"]
- model.training_step = model._compiler_ctx["original_training_step"]
- model.validation_step = model._compiler_ctx["original_validation_step"]
- model.test_step = model._compiler_ctx["original_test_step"]
- model.predict_step = model._compiler_ctx["original_predict_step"]
- model._compiler_ctx = None
+ ctx = model._compiler_ctx
+ if ctx is not None:
+     model.forward = ctx["original_forward"]  # type: ignore[assignment]
+     model.training_step = ctx["original_training_step"]  # type: ignore[assignment]
+     model.validation_step = ctx["original_validation_step"]  # type: ignore[assignment]
+     model.test_step = ctx["original_test_step"]  # type: ignore[assignment]
+     model.predict_step = ctx["original_predict_step"]  # type: ignore[assignment]
+     model._compiler_ctx = None

return model

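
The compile.py rewrite replaces blanket ignores with checks mypy can reason about: after unwrapping OptimizedModule._orig_mod, an isinstance check both guards the runtime type and narrows it for the type checker, and reading _compiler_ctx into a local name lets mypy see it is not None before indexing. A generic sketch of the narrowing pattern with invented class names:

from typing import Optional, Union

class Model:
    compiler_ctx: Optional[dict] = None

class Wrapper:
    def __init__(self, inner: object) -> None:
        self._orig_mod = inner  # statically just `object`

def unwrap(model: Union[Wrapper, Model]) -> Model:
    if isinstance(model, Wrapper):
        inner = model._orig_mod
        if not isinstance(inner, Model):  # narrows `inner` from object to Model
            raise TypeError(f"expected a Model, found {type(inner).__name__}")
        model = inner
    ctx = model.compiler_ctx
    if ctx is not None:  # narrows Optional[dict] to dict before indexing
        _ = ctx["original_forward"]
    return model

assert isinstance(unwrap(Wrapper(Model())), Model)
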
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/distributed.py
@@ -128,7 +128,7 @@ def register_ddp_comm_hook(
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
- model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)  # type: ignore[operator]
+ model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)


def _broadcast_object_list(obj: Any, rank: int) -> Any: