Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Update Image training code for downstream datasets (iWildcam, transfer datasets) #2310

Closed
wants to merge 6 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/sparseml/pytorch/torchvision/presets.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ def __init__(
self,
*,
crop_size,
resize_size=None,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
@@ -20,7 +21,14 @@ def __init__(
augmix_severity=3,
random_erase_prob=0.0,
):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if resize_size is not None:
trans = [
transforms.Resize(resize_size, interpolation=interpolation),
]
else:
trans = [
transforms.RandomResizedCrop(crop_size, interpolation=interpolation)
]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
@@ -73,7 +81,6 @@ def __init__(
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
):

self.transforms = transforms.Compose(
[
transforms.Resize(resize_size, interpolation=interpolation),
108 changes: 83 additions & 25 deletions src/sparseml/pytorch/torchvision/train.py
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader, default_collate
from torchvision.datasets import DTD, FGVCAircraft, Flowers102 # noqa: F401
from torchvision.transforms.functional import InterpolationMode

import click
@@ -278,11 +279,21 @@ def _get_cache_path(filepath):


def load_data(traindir, valdir, args):
if args.transfer_dataset is not None:
if args.transfer_dataset not in ("FGVCAircraft", "DTD", "Flowers102"):
raise ValueError(
"FGVCAircraft, DTD, and Flowers102 are allowed as transfer_datasets."
)
# Data loading code
_LOGGER.info("Loading data")
val_resize_size, val_crop_size, train_crop_size = (
if args.train_resize_size is not None:
args.train_resize_size = [args.train_resize_size, args.train_resize_size]
if args.val_resize_size is not None:
args.val_resize_size = [args.val_resize_size, args.val_resize_size]
val_resize_size, val_crop_size, train_resize_size, train_crop_size = (
args.val_resize_size,
args.val_crop_size,
args.train_resize_size,
args.train_crop_size,
)
interpolation = InterpolationMode(args.interpolation)
@@ -299,19 +310,26 @@ def load_data(traindir, valdir, args):
random_erase_prob = getattr(args, "random_erase", 0.0)
ra_magnitude = args.ra_magnitude
augmix_severity = args.augmix_severity
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=train_crop_size,
mean=args.rgb_mean,
std=args.rgb_std,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
),
train_transforms = presets.ClassificationPresetTrain(
crop_size=train_crop_size,
resize_size=train_resize_size,
mean=args.rgb_mean,
std=args.rgb_std,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
)
if args.transfer_dataset is None:
dataset = torchvision.datasets.ImageFolder(traindir, train_transforms)
else:
dataset = eval(args.transfer_dataset)(
root=f"/tmp/{args.transfer_dataset}",
split=args.transfer_dataset_train_split,
transform=train_transforms,
download=True,
)
if args.cache_dataset:
_LOGGER.info(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
@@ -333,10 +351,18 @@ def load_data(traindir, valdir, args):
interpolation=interpolation,
)

dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
if args.transfer_dataset is None:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
else:
dataset_test = eval(args.transfer_dataset)(
root=f"/tmp/{args.transfer_dataset}",
split=args.transfer_dataset_test_split,
transform=preprocessing,
download=True,
)
if args.cache_dataset:
_LOGGER.info(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
@@ -389,9 +415,15 @@ def main(args):
dataset, dataset_test, train_sampler, test_sampler = load_data(
train_dir, val_dir, args
)

collate_fn = None
num_classes = len(dataset.classes)
try:
num_classes = len(dataset.classes)
except: # noqa: E722
# For some reason, the classes method is not implemented for Flowers102.
if args.transfer_dataset == "Flowers102":
num_classes = 102
else:
raise ValueError(f"unknown number of classes for {args.transfer_dataset}")
mixup_transforms = []
if args.mixup_alpha > 0.0:
mixup_transforms.append(
@@ -404,7 +436,7 @@ def main(args):
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

def collate_fn(batch):
def collate_fn(batch): # noqa: F811
return mixupcutmix(*default_collate(batch))

data_loader = torch.utils.data.DataLoader(
@@ -475,9 +507,9 @@ def collate_fn(batch):
model,
args.weight_decay,
norm_weight_decay=args.norm_weight_decay,
custom_keys_weight_decay=custom_keys_weight_decay
if len(custom_keys_weight_decay) > 0
else None,
custom_keys_weight_decay=(
custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None
),
)

opt_name = args.opt.lower()
@@ -986,6 +1018,26 @@ def new_func(*args, **kwargs):
help="json parsable dict of recipe variable names to values to overwrite with",
)
@click.option("--dataset-path", required=True, type=str, help="dataset path")
@click.option(
"--transfer-dataset",
required=False,
type=str,
help="Dataset to be loaded using torchvision class.",
)
@click.option(
"--transfer-dataset-train-split",
required=False,
type=str,
default="train",
help="Train split name for transfer dataset",
)
@click.option(
"--transfer-dataset-test-split",
required=False,
type=str,
default="test",
help="Test split name for transfer dataset",
)
@click.option(
"--arch-key",
default=None,
@@ -1205,19 +1257,25 @@ def new_func(*args, **kwargs):
"--val-resize-size",
default=256,
type=int,
help="the resize size used for validation",
help="the resize size used for validation (always square)",
)
@click.option(
"--val-crop-size",
default=224,
type=int,
help="the central crop size used for validation",
)
@click.option(
"--train-resize-size",
default=None,
type=int,
help="If set, the resize size used for training (always square)",
)
@click.option(
"--train-crop-size",
default=224,
type=int,
help="the random crop size used for training",
help="the random crop size used for training (always square)",
)
@click.option(
"--clip-grad-norm",