Skip to content

Commit

Permalink
Merge branch 'main' into bugfix/coco-typecheck
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Mar 13, 2024
2 parents 60f86b6 + 6d64cb3 commit 3187381
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 16 deletions.
21 changes: 21 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,10 @@ def inject_fake_data(self, tmpdir, config):
num_examples_total += num_examples
classes.append(cls)

if config.pop("make_empty_class", False):
os.makedirs(pathlib.Path(tmpdir) / "empty_class")
classes.append("empty_class")

return dict(num_examples=num_examples_total, classes=classes)

def _file_name_fn(self, cls, ext, idx):
Expand All @@ -1649,6 +1653,23 @@ def test_classes(self, config):
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])

def test_allow_empty(self):
    """Check that ``allow_empty`` decides whether an empty class folder is accepted."""
    # The ``make_empty_class`` flag is consumed by the fake-data injection, which
    # creates an extra ``empty_class`` directory containing no files.
    config = dict(extensions=self._EXTENSIONS, make_empty_class=True)

    # allow_empty=True: the empty folder must be listed like any other class.
    config["allow_empty"] = True
    with self.create_dataset(config) as (dataset, info):
        assert "empty_class" in dataset.classes
        expected_classes = info["classes"]
        assert len(dataset.classes) == len(expected_classes)
        assert all(found == expected for found, expected in zip(dataset.classes, expected_classes))

    # allow_empty=False: building the dataset must fail because of the empty folder.
    config["allow_empty"] = False
    with pytest.raises(FileNotFoundError, match="Found no valid file"), self.create_dataset(config) as (dataset, info):
        pass


class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageFolder
Expand Down
8 changes: 4 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,8 +1614,8 @@ def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale):
def test_random_crop():
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
img = torch.ones(3, height, width, dtype=torch.uint8)
result = transforms.Compose(
[
Expand Down Expand Up @@ -1664,8 +1664,8 @@ def test_random_crop():
def test_center_crop():
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2

img = torch.ones(3, height, width, dtype=torch.uint8)
oh1 = (height - oheight) // 2
Expand Down
12 changes: 6 additions & 6 deletions test/test_transforms_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def test_random_crop_video(self):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose(
[
Expand All @@ -41,8 +41,8 @@ def test_random_resized_crop_video(self):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2
clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
result = Compose(
[
Expand All @@ -59,8 +59,8 @@ def test_center_crop_video(self):
numFrames = random.randint(4, 128)
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
oheight = random.randint(5, (height - 2) // 2) * 2
owidth = random.randint(5, (width - 2) // 2) * 2

clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
oh1 = (height - oheight) // 2
Expand Down
25 changes: 22 additions & 3 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def make_dataset(
class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
Expand Down Expand Up @@ -95,7 +96,7 @@ def is_valid_file(x: str) -> bool:
available_classes.add(target_class)

empty_classes = set(class_to_idx.keys()) - available_classes
if empty_classes:
if empty_classes and not allow_empty:
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
if extensions is not None:
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
Expand Down Expand Up @@ -123,6 +124,8 @@ class DatasetFolder(VisionDataset):
is_valid_file (callable, optional): A function that takes the path of a file
and checks if the file is a valid file (used to check for corrupt files).
extensions and is_valid_file should not both be passed.
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Attributes:
classes (list): List of the class names sorted alphabetically.
Expand All @@ -139,10 +142,17 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
samples = self.make_dataset(
self.root,
class_to_idx=class_to_idx,
extensions=extensions,
is_valid_file=is_valid_file,
allow_empty=allow_empty,
)

self.loader = loader
self.extensions = extensions
Expand All @@ -158,6 +168,7 @@ def make_dataset(
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
Expand All @@ -172,6 +183,8 @@ def make_dataset(
and checks if the file is a valid file
(used to check for corrupt files). extensions and
is_valid_file should not both be passed. Defaults to None.
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Raises:
ValueError: In case ``class_to_idx`` is empty.
Expand All @@ -186,7 +199,9 @@ def make_dataset(
# find_classes() function, instead of using that of the find_classes() method, which
# is potentially overridden and thus could have a different logic.
raise ValueError("The class_to_idx parameter cannot be None.")
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
return make_dataset(
directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
)

def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Find the class folders in a dataset structured as follows::
Expand Down Expand Up @@ -291,6 +306,8 @@ class ImageFolder(DatasetFolder):
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes the path of an image file
and checks if the file is a valid file (used to check for corrupt files).
allow_empty (bool, optional): If True, empty folders are considered to be valid classes.
An error is raised on empty folders if False (default).
Attributes:
classes (list): List of the class names sorted alphabetically.
Expand All @@ -305,6 +322,7 @@ def __init__(
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
allow_empty: bool = False,
):
super().__init__(
root,
Expand All @@ -313,5 +331,6 @@ def __init__(
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file,
allow_empty=allow_empty,
)
self.imgs = self.samples
7 changes: 4 additions & 3 deletions torchvision/models/maxvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ class MaxVit(nn.Module):
stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value.
squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4.
norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.99)`).
norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.01)`).
activation_layer (Callable[..., nn.Module]): Activation function Default: nn.GELU.
head_dim (int): Dimension of the attention heads.
mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4.
Expand Down Expand Up @@ -623,7 +623,7 @@ def __init__(
# https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030
# for the exact parameters used in batchnorm
if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99)
norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)

# Make sure input size will be divisible by the partition size in all blocks
# Undefined behavior if H or W are not divisible by p
Expand Down Expand Up @@ -788,7 +788,8 @@ class MaxVit_T_Weights(WeightsEnum):
},
"_ops": 5.558,
"_file_size": 118.769,
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.
They were trained with a BatchNorm2D momentum of 0.99 instead of the more correct 0.01.""",
},
)
DEFAULT = IMAGENET1K_V1
Expand Down

0 comments on commit 3187381

Please sign in to comment.