Commit 2845e75

Fix mypy errors attributed to pytorch_lightning.utilities.distributed (#13678)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Parent commit: e23756b

2 files changed: 3 additions & 3 deletions

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -84,7 +84,6 @@ module = [
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
-    "pytorch_lightning.utilities.distributed",
     "pytorch_lightning.utilities.meta",
 ]
 ignore_errors = "True"

src/pytorch_lightning/utilities/distributed.py

Lines changed: 3 additions & 2 deletions
@@ -145,6 +145,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
     if group is None:
         group = torch.distributed.group.WORLD
 
+    op: Optional[ReduceOp]
     if isinstance(reduce_op, str):
         if reduce_op.lower() in ("avg", "mean"):
             op = ReduceOp.SUM
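
The added annotation pre-declares op so that every branch of the following if/elif chain is checked against a single declared type; without it, mypy infers the type from the first assignment and then rejects branches that assign a different one. A minimal sketch of the same pattern outside Lightning (pick_op is a hypothetical helper; assumes a torch build that ships torch.distributed.ReduceOp):

from typing import Optional, Union

from torch.distributed import ReduceOp  # assumes torch built with distributed support


def pick_op(reduce_op: Optional[Union[ReduceOp, str]]) -> Optional[ReduceOp]:
    # Declare once so mypy checks every branch against Optional[ReduceOp];
    # otherwise it infers ``op: ReduceOp`` from the first assignment and
    # flags ``op = None`` with an [assignment] error.
    op: Optional[ReduceOp]
    if isinstance(reduce_op, str):
        op = ReduceOp.SUM  # string names such as "sum"/"mean" reduce via SUM
    elif reduce_op is None:
        op = None  # let the backend apply its default reduction
    else:
        op = reduce_op
    return op
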
@@ -174,7 +175,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
 
 class AllGatherGrad(torch.autograd.Function):
     @staticmethod
-    def forward(
+    def forward(  # type: ignore[override]
         ctx: Any,
         tensor: Tensor,
         group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
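
The override error code fires when a subclass method's signature is incompatible with the one it replaces; AllGatherGrad.forward narrows the generic signature of torch.autograd.Function.forward, and the commit silences exactly that code rather than widening the signature. A minimal sketch with hypothetical classes (not torch or Lightning code):

from typing import Any


class Base:
    @staticmethod
    def forward(x: object) -> Any:
        return x


class Child(Base):
    @staticmethod
    def forward(x: int) -> int:  # type: ignore[override]
        # Narrowing the parameter from ``object`` to ``int`` violates the
        # substitution principle, so mypy reports [override] on this line
        # unless the specific code is suppressed, as the commit does above.
        return x * 2
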
@@ -317,7 +318,7 @@ def register_ddp_comm_hook(
         ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
 
     new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
-    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
+    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)  # type: ignore[operator]
 
 
 def tpu_distributed() -> bool:
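
The operator error code covers expressions whose operands do not support the attempted operation, a common case being an un-narrowed Optional. The diff alone does not show what triggers it on the register_comm_hook call (likely the torch stubs of the time), so the following is only a sketch of that class of error and of a suppression targeted at a single code, using hypothetical names:

from typing import List, Optional


def total(scores: List[int], bonus: Optional[int]) -> int:
    # ``bonus`` may be None, so mypy reports
    # 'Unsupported operand types for + ("int" and "None")' with code
    # [operator]; the narrow ignore silences only that code, not others.
    return sum(scores) + bonus  # type: ignore[operator]
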
