Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions ignite/distributed/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,19 +284,19 @@ def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: O

def __iter__(self) -> Iterator:
# deterministically shuffle based on epoch
torch.manual_seed(self.epoch) # type: ignore[attr-defined]
torch.manual_seed(self.epoch)

indices = [] # type: List
while len(indices) < self.total_size: # type: ignore[attr-defined]
while len(indices) < self.total_size:
indices += list(self.sampler)

if len(indices) > self.total_size: # type: ignore[attr-defined]
indices = indices[: self.total_size] # type: ignore[attr-defined]
if len(indices) > self.total_size:
indices = indices[: self.total_size]

# subsample
indices = indices[self.rank : self.total_size : self.num_replicas] # type: ignore[attr-defined]
if len(indices) != self.num_samples: # type: ignore[attr-defined]
raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) # type: ignore[attr-defined]
indices = indices[self.rank : self.total_size : self.num_replicas]
if len(indices) != self.num_samples:
raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))

return iter(indices)

Expand Down
30 changes: 13 additions & 17 deletions ignite/distributed/comp_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def _apply_op(
return tensor

def _collective_op(
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, Number, List[Number], List[str]]:
self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, float, List[float], List[str]]:
tensor_to_number = tensor_to_str = False
device = self.device()
if isinstance(tensor, Number):
if isinstance(tensor, (Number, float)):
tensor_to_number = True
tensor = torch.tensor(tensor, device=device, dtype=self._collective_op_dtype)
elif isinstance(tensor, str):
Expand All @@ -150,28 +150,26 @@ def _collective_op(

if tensor_to_number:
if tensor.numel() == 1:
return cast(Number, tensor.item())
return tensor.item()
else:
return tensor.tolist()
elif tensor_to_str:
return self._decode_str(tensor)
return tensor

def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Union[torch.Tensor, float]:
if not isinstance(tensor, (torch.Tensor, Number)):
raise TypeError("Unhandled input type {}".format(type(tensor)))

return cast(Union[torch.Tensor, Number], self._collective_op(tensor, self._do_all_reduce, op))
return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op))

def all_gather(
self, tensor: Union[torch.Tensor, Number, str]
) -> Union[torch.Tensor, Number, List[Number], List[str]]:
def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
if not isinstance(tensor, (torch.Tensor, Number, str)):
raise TypeError("Unhandled input type {}".format(type(tensor)))

return self._collective_op(tensor, self._do_all_gather)

def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
if not isinstance(tensor, (torch.Tensor, Number, str)):
raise TypeError("Unhandled input type {}".format(type(tensor)))

Expand All @@ -196,7 +194,7 @@ def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> U
tensor = self._apply_op(tensor, device, self._do_broadcast, src)

if tensor_to_number:
return cast(Number, tensor.item())
return tensor.item()
if tensor_to_str:
list_str = self._decode_str(tensor)
return list_str[0]
Expand Down Expand Up @@ -273,17 +271,15 @@ def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_Seria
def spawn(*args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Serial computation model does not implement spawn method")

def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
def all_reduce(self, tensor: Union[torch.Tensor, float], op: str = "sum") -> Union[torch.Tensor, float]:
return tensor

def all_gather(
self, tensor: Union[torch.Tensor, Number, str]
) -> Union[torch.Tensor, Number, List[Number], List[str]]:
def all_gather(self, tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
if isinstance(tensor, torch.Tensor):
return tensor
return cast(Union[List[Number], List[str]], [tensor])
return cast(Union[List[float], List[str]], [tensor])

def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
def broadcast(self, tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
return tensor

def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:
Expand Down
7 changes: 3 additions & 4 deletions ignite/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import socket
from functools import wraps
from numbers import Number
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -316,7 +315,7 @@ def train_fn(local_rank, a, b, c, d=12):
)


def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
def all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
"""Helper method to perform all reduce operation.

Args:
Expand All @@ -334,7 +333,7 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[to
return _model.all_reduce(tensor, op)


def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[Number], List[str]]:
def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
"""Helper method to perform all gather operation.

Args:
Expand All @@ -352,7 +351,7 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor,
return _model.all_gather(tensor)


def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
def broadcast(tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
"""Helper method to perform broadcast operation.

Args:
Expand Down
19 changes: 9 additions & 10 deletions ignite/engine/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def state_dict(self) -> OrderedDict:
def _init_run(self) -> None:
self.state.seed = int(torch.randint(0, int(1e9), (1,)).item())
if not hasattr(self.state, "rng_states"):
self.state.rng_states = None # type: ignore[attr-defined]
setattr(self.state, "rng_states", None)

if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
Expand All @@ -203,21 +203,19 @@ def _setup_engine(self) -> None:
# attribute _dataset_kind is introduced since 1.3.0 => before 1.3.0 all datasets are map-like
can_patch_dataloader = True
if hasattr(self.state.dataloader, "_dataset_kind"):
from torch.utils.data.dataloader import _DatasetKind # type: ignore[attr-defined]
from torch.utils.data.dataloader import _DatasetKind

_dataloader_kind = self.state.dataloader._dataset_kind # type: ignore[attr-defined]
_dataloader_kind = self.state.dataloader._dataset_kind
can_patch_dataloader = _dataloader_kind == _DatasetKind.Map
if can_patch_dataloader:
if self._dataloader_len is not None and hasattr(
self.state.dataloader.sampler, "epoch" # type: ignore[attr-defined]
):
if self._dataloader_len is not None and hasattr(self.state.dataloader.sampler, "epoch"):
if self._dataloader_len != self.state.epoch_length:
warnings.warn(
"When defined engine's epoch length is different of input dataloader length, "
"distributed sampler indices can not be setup in a reproducible manner"
)

batch_sampler = self.state.dataloader.batch_sampler # type: ignore[attr-defined]
batch_sampler = self.state.dataloader.batch_sampler
if not (batch_sampler is None or isinstance(batch_sampler, ReproducibleBatchSampler)):
self.state.dataloader = update_dataloader(
self.state.dataloader, ReproducibleBatchSampler(batch_sampler) # type: ignore[arg-type]
Expand All @@ -233,9 +231,10 @@ def _setup_engine(self) -> None:

# restore rng state if in the middle
in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
if (getattr(self.state, "rng_states", None) is not None) and in_the_middle:
_set_rng_states(self.state.rng_states) # type: ignore[attr-defined]
self.state.rng_states = None # type: ignore[attr-defined]
rng_states = getattr(self.state, "rng_states", None)
if rng_states is not None and in_the_middle:
_set_rng_states(rng_states)
setattr(self.state, "rng_states", None)

def _from_iteration(self, iteration: int) -> Iterator:
if self.state.dataloader is None:
Expand Down
2 changes: 1 addition & 1 deletion ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return handler(*args, **kwargs)

# setup input handler as parent to make has_event_handler work
wrapper._parent = weakref.ref(handler) # type: ignore[attr-defined]
setattr(wrapper, "_parent", weakref.ref(handler))
return wrapper

def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kwargs: Any) -> RemovableEventHandle:
Expand Down
2 changes: 1 addition & 1 deletion ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, namedtuple
from tempfile import _TemporaryFileWrapper # type: ignore
from tempfile import _TemporaryFileWrapper # type: ignore[attr-defined]
from typing import Callable, Mapping, Optional, Union

import torch
Expand Down
2 changes: 1 addition & 1 deletion ignite/handlers/terminate_on_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, output_transform: Callable = lambda x: x):
def __call__(self, engine: Engine) -> None:
output = self._output_transform(engine.state.output)

def raise_error(x: Union[numbers.Number, torch.Tensor]) -> None:
def raise_error(x: Union[float, torch.Tensor]) -> None:

if isinstance(x, numbers.Number):
x = torch.tensor(x)
Expand Down
18 changes: 9 additions & 9 deletions ignite/metrics/accumulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numbers
from typing import Any, Callable, Tuple, Union, cast
from typing import Any, Callable, Tuple, Union

import torch

Expand Down Expand Up @@ -57,12 +57,12 @@ def reset(self) -> None:
self.accumulator = torch.tensor(0.0, dtype=torch.float64, device=self._device)
self.num_examples = 0

def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
def _check_output_type(self, output: Union[float, torch.Tensor]) -> None:
if not (isinstance(output, numbers.Number) or isinstance(output, torch.Tensor)):
raise TypeError("Output should be a number or torch.Tensor, but given {}".format(type(output)))

@reinit__is_reduced
def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
def update(self, output: Union[float, torch.Tensor]) -> None:
self._check_output_type(output)

if isinstance(output, torch.Tensor):
Expand Down Expand Up @@ -125,14 +125,14 @@ def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _mean_op(a: Union[float, torch.Tensor], x: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if isinstance(x, torch.Tensor) and x.ndim > 1: # type: ignore[attr-defined]
if isinstance(x, torch.Tensor) and x.ndim > 1:
x = x.sum(dim=0)
return a + x

super(Average, self).__init__(op=_mean_op, output_transform=output_transform, device=device)

@sync_all_reduce("accumulator", "num_examples")
def compute(self) -> Union[torch.Tensor, numbers.Number]:
def compute(self) -> Union[float, torch.Tensor]:
if self.num_examples < 1:
raise NotComputableError(
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
Expand Down Expand Up @@ -172,18 +172,18 @@ class GeometricAverage(VariableAccumulation):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _geom_op(a: torch.Tensor, x: Union[numbers.Number, torch.Tensor]) -> torch.Tensor:
def _geom_op(a: torch.Tensor, x: Union[float, torch.Tensor]) -> torch.Tensor:
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
x = torch.log(x)
if x.ndim > 1: # type: ignore[attr-defined]
if x.ndim > 1:
x = x.sum(dim=0)
return a + x

super(GeometricAverage, self).__init__(op=_geom_op, output_transform=output_transform, device=device)

@sync_all_reduce("accumulator", "num_examples")
def compute(self) -> Union[torch.Tensor, numbers.Number]:
def compute(self) -> Union[float, torch.Tensor]:
if self.num_examples < 1:
raise NotComputableError(
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
Expand All @@ -192,6 +192,6 @@ def compute(self) -> Union[torch.Tensor, numbers.Number]:
tensor = torch.exp(self.accumulator / self.num_examples)

if tensor.numel() == 1:
return cast(numbers.Number, tensor.item())
return tensor.item()

return tensor
2 changes: 1 addition & 1 deletion ignite/metrics/epoch_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def compute(self) -> float:

if ws > 1:
# broadcast result to all processes
result = cast(float, idist.broadcast(result, src=0)) # type: ignore[arg-type]
result = cast(float, idist.broadcast(result, src=0))

return result

Expand Down
4 changes: 2 additions & 2 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,10 @@ def __init__(
)

# Some metrics have a large performance regression when run on XLA devices, so for now, we disallow it.
if torch.device(device).type == "xla": # type: ignore[arg-type]
if torch.device(device).type == "xla":
raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")

self._device = torch.device(device) # type: ignore[arg-type]
self._device = torch.device(device)
self._is_reduced = False
self.reset()

Expand Down
8 changes: 3 additions & 5 deletions ignite/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,16 @@ def reset(self) -> None:
super(_BasePrecisionRecall, self).reset()

def compute(self) -> Union[torch.Tensor, float]:
is_scalar = (
not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0 # type: ignore[attr-defined]
)
is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0
if is_scalar and self._positives == 0:
raise NotComputableError(
"{} must have at least one example before it can be computed.".format(self.__class__.__name__)
)

if not (self._type == "multilabel" and not self._average):
if not self._is_reduced:
self._true_positives = idist.all_reduce(self._true_positives) # type: ignore[arg-type, assignment]
self._positives = idist.all_reduce(self._positives) # type: ignore[arg-type, assignment]
self._true_positives = idist.all_reduce(self._true_positives) # type: ignore[assignment]
self._positives = idist.all_reduce(self._positives) # type: ignore[assignment]
self._is_reduced = True # type: bool

result = self._true_positives / (self._positives + self.eps)
Expand Down