Adding Image Dataloaders and Flax Resnet18 model #1
@@ -0,0 +1,23 @@
*__pycache__*
*.DS_Store*
.vscode
colabs/figures

# Distribution / packaging
*.egg-info/


results

# Logging
wandb

# Linting
.trunk
.flake8
.markdownlint.yaml
.isort.cfg

# Checkpoints
checkpoints
artifacts
@@ -1,2 +1,2 @@
-# jax-utils
+# jaxutils
 Common utilities in JAX/Flax to use across research projects
@@ -0,0 +1,10 @@
__all__ = [
    'get_image_dataset',
    'train_val_split_sizes',
    'NumpyLoader',
    'METADATA',
]

from .image import get_image_dataset, train_val_split_sizes, METADATA
from .utils import NumpyLoader
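For orientation, a minimal import sketch; the subpackage path `jaxutils.data` is a guess from the repository name and these exports, since the diff does not show file paths:

# Hypothetical import path -- the diff does not show the package layout.
from jaxutils.data import METADATA, NumpyLoader, get_image_dataset

print(METADATA['image_shape']['CIFAR10'])  # (32, 32, 3)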
@@ -0,0 +1,228 @@
"""Image dataset loading functionality."""
from pathlib import Path
from typing import Tuple, Union

import numpy as np
import torch
from torch.utils import data
from torchvision import datasets, transforms

METADATA = {
    'image_shape': {
        'MNIST': (28, 28, 1),
        'FashionMNIST': (28, 28, 1),
        'KMNIST': (28, 28, 1),
        'SVHN': (32, 32, 3),
        'CIFAR10': (32, 32, 3),
        'CIFAR100': (32, 32, 3),
        'Imagenet': (224, 224, 3),
    },
    'num_train': {
        'MNIST': 60_000,
        'FashionMNIST': 60_000,
        'KMNIST': 60_000,
        'SVHN': 73_257,
        'CIFAR10': 50_000,
        'CIFAR100': 50_000,
    },
    'num_test': {
        'MNIST': 10_000,
        'FashionMNIST': 10_000,
        'KMNIST': 10_000,
        'SVHN': 26_032,
        'CIFAR10': 10_000,
        'CIFAR100': 10_000,
    },
    'mean': {
        'MNIST': (0.1307,),
        'FashionMNIST': (0.2860,),
        'SVHN': (0.4377, 0.4438, 0.4728),
        'CIFAR10': (0.4914, 0.4822, 0.4465),
        'CIFAR100': (0.5071, 0.4866, 0.4409),
        'Imagenet': (0.485, 0.456, 0.406),
    },
    'std': {
        'MNIST': (0.3081,),
        'FashionMNIST': (0.3530,),
        'SVHN': (0.1980, 0.2010, 0.1970),
        'CIFAR10': (0.2470, 0.2435, 0.2616),
        'CIFAR100': (0.2673, 0.2564, 0.2762),
        'Imagenet': (0.229, 0.224, 0.225),
    },
}

TRAIN_TRANSFORMATIONS = {
    'MNIST': [transforms.RandomCrop(28, padding=2)],
    'FashionMNIST': [transforms.RandomCrop(28, padding=2)],
    'SVHN': [transforms.RandomCrop(32, padding=4)],
    'CIFAR10': [transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip()],
    'CIFAR100': [transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip()],
    'Imagenet': [transforms.RandomResizedCrop(224),
                 transforms.RandomHorizontalFlip()],
}

Reviewer comment on TRAIN_TRANSFORMATIONS: This is also good!

TEST_TRANSFORMATIONS = {
    'MNIST': [],
    'FashionMNIST': [],
    'SVHN': [],
    'CIFAR10': [],
    'CIFAR100': [],
    'Imagenet': [transforms.Resize(256), transforms.CenterCrop(224)],
}

class Flatten:
    """Transform to flatten an image for use with MLPs."""

    def __call__(self, array: np.ndarray) -> np.ndarray:
        return np.ravel(array)


class MoveChannelDim:
    """Transform to change from PyTorch image ordering to Jax/TF ordering."""

    def __call__(self, array: np.ndarray) -> np.ndarray:
        return np.moveaxis(array, 0, -1)


class ToNumpy:
    """Transform to convert from a PyTorch Tensor to a Numpy ndarray."""

    def __call__(self, tensor: torch.Tensor) -> np.ndarray:
        return np.array(tensor, dtype=np.float32)

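To make the data layout concrete, a small sketch (not part of the diff) of how these transforms undo PyTorch's channel-first convention:

import numpy as np
from torchvision import transforms

# Compose the PR's transforms after ToTensor: CHW torch.Tensor ->
# CHW float32 ndarray -> HWC ndarray, the layout JAX/TF models expect.
pipeline = transforms.Compose([
    transforms.ToTensor(),  # HWC uint8 in [0, 255] -> CHW float in [0, 1]
    ToNumpy(),
    MoveChannelDim(),
])

img = np.zeros((32, 32, 3), dtype=np.uint8)  # dummy CIFAR-sized image
out = pipeline(img)
print(out.shape, out.dtype)  # (32, 32, 3) float32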
def get_image_dataset(
    dataset_name: str,
    data_dir: str = "../raw_data",
    flatten_img: bool = False,
    val_percent: float = 0.1,
    random_seed: int = 42,
    perform_augmentations: bool = True,
) -> Union[
    Tuple[data.Dataset, data.Dataset],
    Tuple[data.Dataset, data.Dataset, data.Dataset],
]:
    """Provides PyTorch `Dataset`s for the specified image dataset_name.

    Args:
        dataset_name: the `str` name of the dataset, e.g. `'MNIST'`.
        data_dir: the `str` directory where the datasets should be downloaded
            to and loaded from. (Default: `'../raw_data'`)
        flatten_img: a `bool` indicating whether images should be flattened.
            (Default: `False`)
        val_percent: the `float` percentage of training data to use for
            validation. (Default: `0.1`)
        random_seed: the `int` random seed for splitting the val data and
            applying random affine transformations. (Default: `42`)
        perform_augmentations: a `bool` indicating whether to apply random
            transformations to the training data. (Default: `True`)

    Returns:
        `(train_dataset, test_dataset)` if `val_percent` is 0, otherwise
        `(train_dataset, test_dataset, val_dataset)`.
    """
    dataset_choices = [
        "MNIST",
        "FashionMNIST",
        "SVHN",
        "CIFAR10",
        "CIFAR100",
        "Imagenet",
    ]
    if dataset_name not in dataset_choices:
        msg = (f"Dataset should be one of {dataset_choices} "
               f"but was {dataset_name} instead.")
        raise RuntimeError(msg)

    if dataset_name in ["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100",
                        "Imagenet"]:
        train_kwargs = {"train": True}
        test_kwargs = {"train": False}

    elif dataset_name == "SVHN":
        train_kwargs = {"split": "train"}
        test_kwargs = {"split": "test"}

    data_dir = Path(data_dir).resolve()

    common_transforms = [
        ToNumpy(),
        MoveChannelDim(),
    ]

    if flatten_img:
        common_transforms += [Flatten()]

    # We need to disable train augmentations when fitting the mode of a
    # linear model and for sample-then-optimise posterior sampling.
    if perform_augmentations:
        train_augmentations = TRAIN_TRANSFORMATIONS[dataset_name]
    else:
        train_augmentations = TEST_TRANSFORMATIONS[dataset_name]

Reviewer comment: For the non-Imagenet cases this makes sense, since the test augmentations are empty, but for Imagenet, does it make sense to be applying transformations when the user of the function has set `perform_augmentations=False`?

Author reply: The reason we need `Resize(256)` and `CenterCrop(224)` for Imagenet is that, by default, test images in Imagenet are all of random sizes, not uniform. So we still need some deterministic preprocessing to ensure all images are of size 224x224x3.
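A quick sketch illustrating the reply (not part of the diff): the Imagenet test transforms are deterministic and normalise away the varying input resolutions.

from PIL import Image
from torchvision import transforms

eval_tf = transforms.Compose([
    transforms.Resize(256),      # shorter side -> 256, aspect ratio kept
    transforms.CenterCrop(224),  # deterministic central 224x224 crop
    transforms.ToTensor(),
])

# Whatever the source resolution, the output shape is fixed.
for size in [(500, 300), (640, 480), (1024, 768)]:
    img = Image.new("RGB", size)
    print(eval_tf(img).shape)  # torch.Size([3, 224, 224]) each time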

    transform_train = transforms.Compose(
        train_augmentations
        + [
            transforms.ToTensor(),
            transforms.Normalize(
                METADATA['mean'][dataset_name],
                METADATA['std'][dataset_name]),
        ]
        + common_transforms
    )

    transform_test = transforms.Compose(
        TEST_TRANSFORMATIONS[dataset_name]
        + [
            transforms.ToTensor(),
            transforms.Normalize(
                METADATA['mean'][dataset_name],
                METADATA['std'][dataset_name]),
        ]
        + common_transforms
    )

    if dataset_name == "Imagenet":
        train_dir = data_dir / "imagenet/train"
        val_dir = data_dir / "imagenet/val"

        train_dataset = datasets.ImageFolder(
            train_dir, transform=transform_train)
        test_dataset = datasets.ImageFolder(val_dir, transform=transform_test)

    else:
        dataset = getattr(datasets, dataset_name)
        train_dataset = dataset(
            **train_kwargs,
            transform=transform_train,
            download=True,
            root=data_dir,
        )
        test_dataset = dataset(
            **test_kwargs,
            transform=transform_test,
            download=True,
            root=data_dir,
        )

    if val_percent != 0.0:
        num_train, num_val = train_val_split_sizes(
            len(train_dataset), val_percent)

        train_dataset, val_dataset = data.random_split(
            train_dataset,
            [num_train, num_val],
            torch.Generator().manual_seed(random_seed)
            if random_seed is not None
            else None,
        )

        return train_dataset, test_dataset, val_dataset
    else:
        return train_dataset, test_dataset


def train_val_split_sizes(num_train: int, val_percent: float) -> Tuple[int, int]:
    """Splits `num_train` examples into train and validation counts."""
    num_val = int(val_percent * num_train)
    num_train = num_train - num_val

    return num_train, num_val
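A usage sketch for the loader above (the data directory is a placeholder; CIFAR-10 is downloaded on first use):

train_ds, test_ds, val_ds = get_image_dataset(
    "CIFAR10",
    data_dir="./raw_data",  # placeholder path
    val_percent=0.1,
)
# 10% of the 50,000 training images go to validation.
print(len(train_ds), len(val_ds), len(test_ds))  # 45000 5000 10000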
@@ -0,0 +1,49 @@
"""A PyTorch dataloader that returns `np.ndarray` batches.

Taken from: colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
"""
import numpy as np
from torch.utils import data


def _numpy_collate(batch):
    # Stack ndarray samples into a batch; recurse into (input, target)
    # tuples; fall back to np.array for scalars such as integer labels.
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [_numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
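For intuition, what the collate function produces for a batch of (image, label) pairs (a sketch, not part of the diff):

import numpy as np

batch = [(np.ones((28, 28, 1)), 0), (np.zeros((28, 28, 1)), 1)]
images, labels = _numpy_collate(batch)
print(images.shape)  # (2, 28, 28, 1) -- ndarrays are stacked
print(labels)        # [0 1]         -- integer labels become one array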

class NumpyLoader(data.DataLoader):
    """A PyTorch dataloader that returns `np.ndarray` batches, which can be
    used with Jax!
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        generator=None,
    ):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=_numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            generator=generator,
        )
Reviewer comment: Nice addition!
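Putting the two files together, a hypothetical end-to-end sketch of how the loader feeds JAX (the path and batch size are placeholders):

import jax.numpy as jnp

train_ds, test_ds, val_ds = get_image_dataset("MNIST", data_dir="./raw_data")
loader = NumpyLoader(train_ds, batch_size=128, shuffle=True)

for images, labels in loader:
    x = jnp.asarray(images)       # hand the NumPy batch to JAX
    print(x.shape, labels.shape)  # (128, 28, 28, 1) (128,)
    break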