Update mypy job to torch 2.0 #16933

Merged
merged 4 commits on Mar 6, 2023
Changes from 2 commits
2 changes: 1 addition & 1 deletion .github/workflows/code-checks.yml
@@ -45,7 +45,7 @@ jobs:
FREEZE_REQUIREMENTS: 1
run: |
# todo: adjust requirements for both code-bases
- pip install -e '.[extra,ui,cloud]' -r requirements/typing.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+ pip install -e '.[extra,ui,cloud]' -r requirements/typing.txt
pip list

- name: Check typing
3 changes: 2 additions & 1 deletion requirements/typing.txt
@@ -1,5 +1,6 @@
mypy==0.982
- torch==1.12.0
+ -f https://download.pytorch.org/whl/test/cpu/torch_test.html --pre
+ torch==2.0.0

types-Markdown
types-PyYAML
4 changes: 2 additions & 2 deletions src/lightning/fabric/plugins/collectives/collective.py
@@ -70,11 +70,11 @@ def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[T
...

@abstractmethod
- def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
+ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
...

@abstractmethod
- def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
+ def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
...

@abstractmethod
14 changes: 7 additions & 7 deletions src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -28,11 +28,11 @@ def __init__(self) -> None:
@property
def rank(self) -> int:
# local rank
- return dist.get_rank(self.group)
+ return dist.get_rank(self.group) # type: ignore[arg-type]

@property
def world_size(self) -> int:
- return dist.get_world_size(self.group)
+ return dist.get_world_size(self.group) # type: ignore[arg-type]

def broadcast(self, tensor: Tensor, src: int) -> Tensor:
dist.broadcast(tensor, src, group=self.group)
@@ -71,11 +71,11 @@ def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[T
dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
return output_tensor_list

- def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
-     dist.send(tensor, dst, tag=tag, group=self.group)
+ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
+     dist.send(tensor, dst, tag=tag, group=self.group) # type: ignore[arg-type]

- def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
-     dist.recv(tensor, src, tag=tag, group=self.group)
+ def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
+     dist.recv(tensor, src, tag=tag, group=self.group) # type: ignore[arg-type]
return tensor

def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]:
@@ -168,7 +168,7 @@ def new_group(cls, **kwargs: Any) -> CollectibleGroup:
def destroy_group(cls, group: CollectibleGroup) -> None:
# can be called by all processes in the default group, group will be `object()` if they are not part of the
# current group
- dist.destroy_process_group(group)
+ dist.destroy_process_group(group) # type: ignore[arg-type]

@classmethod
def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
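The `arg-type` ignores added in the file above all come from the same mismatch: the collective stores its group under the structural `CollectibleGroup` type, while the torch 2.0 stubs annotate the `torch.distributed` helpers as taking `Optional[ProcessGroup]`. Below is a minimal, self-contained sketch of that situation, using stand-in classes rather than the real Lightning or torch types:

```python
from typing import Optional, Protocol


class Group(Protocol):
    """Stand-in for the structural CollectibleGroup protocol."""

    def rank(self) -> int:
        ...


class ProcessGroup:
    """Stand-in for the nominal class the torch 2.0 stubs expect."""

    def rank(self) -> int:
        return 0


def get_rank(group: Optional[ProcessGroup] = None) -> int:
    # Mirrors the shape of the stub annotation: it wants the nominal class, not the protocol.
    return 0 if group is None else group.rank()


class Collective:
    def __init__(self, group: Group) -> None:
        self.group = group

    @property
    def rank(self) -> int:
        # mypy: Argument 1 to "get_rank" has incompatible type "Group";
        # expected "Optional[ProcessGroup]"  [arg-type] -- hence the ignore.
        return get_rank(self.group)  # type: ignore[arg-type]


print(Collective(ProcessGroup()).rank)  # 0 at runtime; only the static check objects
```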
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/ddp.py
@@ -209,5 +209,5 @@ def no_backward_sync(self, module: Module) -> Generator:
f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `DistributedDataParallel`."
f" Got: {module.__class__.__name__}."
)
- with module.no_sync(): # type: ignore[operator]
+ with module.no_sync():
yield
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -279,7 +279,7 @@ def clip_gradients_norm( # type: ignore[override]
"""Clip gradients by norm."""
rank_zero_warn("Gradient Clipping by Norm is currently experimental for FSDP. Proceed with Caution!")
self.precision.unscale_gradients(optimizer)
- return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type) # type: ignore[return-value]
+ return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)

def clip_gradients_value( # type: ignore[override]
self, module: "FullyShardedDataParallel", optimizer: Optimizer, clip_val: Union[float, int]
5 changes: 4 additions & 1 deletion src/lightning/fabric/strategies/parallel.py
@@ -94,7 +94,7 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
bool: The reduced boolean decision.
"""
decision = torch.tensor(int(decision), device=self.root_device)
- decision = self.all_reduce(decision, reduce_op=ReduceOp.SUM)
+ decision = self.all_reduce(
+     decision,
+     reduce_op=ReduceOp.SUM, # type: ignore[arg-type]
+ )
decision = bool(decision == self.world_size) if all else bool(decision)
return decision

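For context, the decision logic wrapped by `reduce_boolean_decision` can be simulated without a process group: every rank contributes 0 or 1, the values are summed (the `ReduceOp.SUM` all-reduce above), and with `all=True` the decision only holds when the sum equals the world size. A single-process sketch with a hypothetical helper name:

```python
import torch


def reduce_boolean_decision_sim(per_rank_decisions: list, require_all: bool = True) -> bool:
    """Simulate a SUM all-reduce over one boolean vote per rank."""
    world_size = len(per_rank_decisions)
    summed = torch.stack([torch.tensor(int(d)) for d in per_rank_decisions]).sum()
    # all=True requires a unanimous vote; otherwise any single True is enough.
    return bool(summed == world_size) if require_all else bool(summed > 0)


assert reduce_boolean_decision_sim([True, True, True]) is True
assert reduce_boolean_decision_sim([True, False, True]) is False
assert reduce_boolean_decision_sim([True, False, True], require_all=False) is True
```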
9 changes: 6 additions & 3 deletions src/lightning/fabric/utilities/cloud_io.py
@@ -37,15 +37,18 @@ def _load(
"""
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similar
- return torch.load(path_or_url, map_location=map_location)
+ return torch.load(
+     path_or_url,
+     map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
+ )
if str(path_or_url).startswith("http"):
return torch.hub.load_state_dict_from_url(
str(path_or_url),
- map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
+ map_location=map_location, # type: ignore[arg-type]
)
fs = get_filesystem(path_or_url)
with fs.open(path_or_url, "rb") as f:
- return torch.load(f, map_location=map_location)
+ return torch.load(f, map_location=map_location) # type: ignore[arg-type]


def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
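As a small aside on the first branch that now carries an ignore: `torch.load` accepts an in-memory buffer in place of a path, and `map_location` picks the target device; the ignore only silences the `map_location` stub annotation that the inline comment calls out as incorrect. A tiny standalone example (no Lightning involved):

```python
import io

import torch

buffer = io.BytesIO()
torch.save({"weight": torch.ones(2)}, buffer)
buffer.seek(0)

# The BytesIO branch of `_load`: any file-like object works.
state = torch.load(buffer, map_location="cpu")
print(state["weight"])  # tensor([1., 1.])
```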
2 changes: 1 addition & 1 deletion src/lightning/fabric/utilities/distributed.py
@@ -130,7 +130,7 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
op: Optional[ReduceOp]
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
- op = ReduceOp.SUM
+ op = ReduceOp.SUM # type: ignore[assignment]
divide_by_world_size = True
else:
op = getattr(ReduceOp, reduce_op.upper())
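The branch annotated above implements "mean"/"avg" as a SUM reduce followed by a division by the world size (`divide_by_world_size = True`). A rough single-process sketch of that arithmetic, with a hypothetical helper rather than the real `_sync_ddp`:

```python
import torch


def sync_mean_sim(per_rank_results: list) -> torch.Tensor:
    """Stand-in for all_reduce(op=ReduceOp.SUM) followed by the divide-by-world-size step."""
    world_size = len(per_rank_results)
    summed = torch.stack(per_rank_results).sum(dim=0)
    return summed / world_size


print(sync_mean_sim([torch.tensor(1.0), torch.tensor(3.0)]))  # tensor(2.)
```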
9 changes: 6 additions & 3 deletions src/lightning/fabric/utilities/types.py
@@ -17,6 +17,7 @@
import torch
from torch import Tensor
from torch.optim import Optimizer
+ from typing_extensions import TypeAlias

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_2_0

@@ -29,7 +30,7 @@
if torch.distributed.is_available():
from torch.distributed import ProcessGroup, ReduceOp

- RedOpType = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object
+ RedOpType: TypeAlias = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object # type: ignore[misc]
else:
ProcessGroup = Any # type: ignore[assignment,misc]
ReduceOp = object # type: ignore[assignment,misc] # we are using isinstance check once
@@ -73,8 +74,10 @@ def step(self, epoch: Optional[int] = None) -> None:
...


- _TORCH_LRSCHEDULER = (
-     torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_2_0 else torch.optim.lr_scheduler._LRScheduler
+ _TORCH_LRSCHEDULER: TypeAlias = (
+     torch.optim.lr_scheduler.LRScheduler # type: ignore[misc]
+     if _TORCH_GREATER_EQUAL_2_0
+     else torch.optim.lr_scheduler._LRScheduler
)


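Both `TypeAlias` annotations added above address the same mypy complaint: a name assigned one of two classes depending on a runtime flag is not accepted as an implicit alias, so the alias is declared explicitly and the conditional value gets an ignore. A minimal sketch of the pattern with made-up class names:

```python
from typing_extensions import TypeAlias

_NEW_API_AVAILABLE = False


class NewScheduler:
    ...


class LegacyScheduler:
    ...


# Declaring the alias explicitly keeps it usable in annotations; the ignore silences
# mypy's objection to the value being chosen at runtime.
Scheduler: TypeAlias = NewScheduler if _NEW_API_AVAILABLE else LegacyScheduler  # type: ignore[misc]


def make_scheduler() -> Scheduler:
    return Scheduler()


print(type(make_scheduler()).__name__)  # LegacyScheduler
```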
4 changes: 2 additions & 2 deletions src/lightning/pytorch/overrides/distributed.py
@@ -41,7 +41,7 @@ def _find_tensors(
def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
# `prepare_for_backward` is `DistributedDataParallel` specific.
if torch.is_grad_enabled() and model.require_backward_grad_sync:
- model.require_forward_param_sync = True # type: ignore[assignment]
+ model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
@@ -52,7 +52,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
reducer._rebuild_buckets() # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
reducer.prepare_for_backward(args)
else:
- model.require_forward_param_sync = False # type: ignore[assignment]
+ model.require_forward_param_sync = False


class UnrepeatedDistributedSampler(DistributedSampler):
4 changes: 1 addition & 3 deletions src/lightning/pytorch/strategies/ddp.py
@@ -408,9 +408,7 @@ def teardown(self) -> None:

pl_module = self.lightning_module
if isinstance(self.model, DistributedDataParallel):
- if not self.model.static_graph and self.model._get_ddp_logging_data().get( # type: ignore[operator]
-     "can_set_static_graph"
- ):
+ if not self.model.static_graph and self.model._get_ddp_logging_data().get("can_set_static_graph"):
rank_zero_info(
"Your model can run with static graph optimizations. For future training runs, we suggest you"
f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
5 changes: 4 additions & 1 deletion src/lightning/pytorch/strategies/parallel.py
@@ -99,7 +99,10 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
bool: The reduced boolean decision.
"""
decision = torch.tensor(int(decision), device=self.root_device)
- decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
+ decision = self.reduce(
+     decision,
+     reduce_op=ReduceOp.SUM, # type: ignore[arg-type]
+ )
decision = bool(decision == self.world_size) if all else bool(decision)
return decision

18 changes: 12 additions & 6 deletions src/lightning/pytorch/utilities/compile.py
@@ -78,6 +78,10 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod

if isinstance(model, OptimizedModule):
model = model._orig_mod
+ if not isinstance(model, pl.LightningModule):
+     raise TypeError(
+         f"Unexpected error, the wrapped model should be a LightningModule, found {type(model).__name__}"
+     )

elif isinstance(model, pl.LightningModule):
if model._compiler_ctx is None:
@@ -88,12 +92,14 @@
else:
raise ValueError("`model` must either be an instance of OptimizedModule or LightningModule")

- model.forward = model._compiler_ctx["original_forward"]
- model.training_step = model._compiler_ctx["original_training_step"]
- model.validation_step = model._compiler_ctx["original_validation_step"]
- model.test_step = model._compiler_ctx["original_test_step"]
- model.predict_step = model._compiler_ctx["original_predict_step"]
- model._compiler_ctx = None
+ ctx = model._compiler_ctx
+ if ctx is not None:
+     model.forward = ctx["original_forward"] # type: ignore[assignment]
+     model.training_step = ctx["original_training_step"] # type: ignore[assignment]
+     model.validation_step = ctx["original_validation_step"] # type: ignore[assignment]
+     model.test_step = ctx["original_test_step"] # type: ignore[assignment]
+     model.predict_step = ctx["original_predict_step"] # type: ignore[assignment]
+     model._compiler_ctx = None

return model

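The rewritten block narrows `model._compiler_ctx` to a non-None value before indexing it, which is what lets mypy accept the method restores. The stash-and-restore idea itself looks roughly like the toy below (hypothetical names, not the Lightning or `torch._dynamo` API):

```python
from typing import Any, Callable, Dict, Optional


class ToyModule:
    def __init__(self) -> None:
        self._compiler_ctx: Optional[Dict[str, Callable[..., Any]]] = None

    def forward(self, x: int) -> int:
        return x


def compile_toy(model: ToyModule) -> ToyModule:
    """Pretend-compile: stash the original method, then shadow it with a wrapper."""
    original = model.forward
    model._compiler_ctx = {"original_forward": original}
    model.forward = lambda x: original(x) + 1  # type: ignore[assignment]
    return model


def uncompile_toy(model: ToyModule) -> ToyModule:
    ctx = model._compiler_ctx
    if ctx is not None:  # the Optional-narrowing step the diff above introduces
        model.forward = ctx["original_forward"]  # type: ignore[assignment]
        model._compiler_ctx = None
    return model


m = compile_toy(ToyModule())
assert m.forward(1) == 2
assert uncompile_toy(m).forward(1) == 1
```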
2 changes: 1 addition & 1 deletion src/lightning/pytorch/utilities/distributed.py
@@ -119,7 +119,7 @@ def register_ddp_comm_hook(
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
- model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook) # type: ignore[operator]
+ model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)


def _broadcast_object_list(obj: Any, rank: int) -> Any: