diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 829fa2560fd..12771a11efa 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -30,25 +30,26 @@ def __init__( def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: for _, file in self.datapipe: - lines = (line.decode() for line in file) - - if self.fieldnames: - fieldnames = self.fieldnames - else: - # The first row is skipped, because it only contains the number of samples - next(lines) - - # Empty field names are filtered out, because some files have an extra white space after the header - # line, which is recognized as extra column - fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name] - # Some files do not include a label for the image ID column - if fieldnames[0] != "image_id": - fieldnames.insert(0, "image_id") - - for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"): - yield line.pop("image_id"), line - - file.close() + try: + lines = (line.decode() for line in file) + + if self.fieldnames: + fieldnames = self.fieldnames + else: + # The first row is skipped, because it only contains the number of samples + next(lines) + + # Empty field names are filtered out, because some files have an extra white space after the header + # line, which is recognized as extra column + fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name] + # Some files do not include a label for the image ID column + if fieldnames[0] != "image_id": + fieldnames.insert(0, "image_id") + + for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"): + yield line.pop("image_id"), line + finally: + file.close() NAME = "celeba" diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index c13836a8c4c..97d729d530b 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -37,27 +37,28 @@ def __init__( def __iter__(self) -> Iterator[torch.Tensor]: for _, file in self.datapipe: - read = functools.partial(fromfile, file, byte_order="big") + try: + read = functools.partial(fromfile, file, byte_order="big") - magic = int(read(dtype=torch.int32, count=1)) - dtype = self._DTYPE_MAP[magic // 256] - ndim = magic % 256 - 1 + magic = int(read(dtype=torch.int32, count=1)) + dtype = self._DTYPE_MAP[magic // 256] + ndim = magic % 256 - 1 - num_samples = int(read(dtype=torch.int32, count=1)) - shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else [] - count = prod(shape) if shape else 1 + num_samples = int(read(dtype=torch.int32, count=1)) + shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else [] + count = prod(shape) if shape else 1 - start = self.start or 0 - stop = min(self.stop, num_samples) if self.stop else num_samples + start = self.start or 0 + stop = min(self.stop, num_samples) if self.stop else num_samples - if start: - num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 - file.seek(num_bytes_per_value * count * start, 1) + if start: + num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 + file.seek(num_bytes_per_value * count * start, 1) - for _ in range(stop - start): - yield read(dtype=dtype, count=count).reshape(shape) - - file.close() + for _ in range(stop - start): + yield read(dtype=dtype, count=count).reshape(shape) + finally: + file.close() class _MNISTBase(Dataset): diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 3a9fe6e9031..f533ba18084 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -28,12 +28,13 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]: import h5py for _, handle in self.datapipe: - with h5py.File(handle) as data: - if self.key is not None: - data = data[self.key] - yield from data - - handle.close() + try: + with h5py.File(handle) as data: + if self.key is not None: + data = data[self.key] + yield from data + finally: + handle.close() _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))