Adding Image Dataloaders and Flax Resnet18 model #1

Merged · 7 commits · Jun 13, 2022
23 changes: 23 additions & 0 deletions .gitignore
@@ -0,0 +1,23 @@
*__pycache__*
*.DS_Store*
.vscode
colabs/figures

# Distribution / packaging
*.egg-info/


results

# Logging
wandb

# Linting
.trunk
.flake8
.markdownlint.yaml
.isort.cfg

# Checkpoints
checkpoints
artifacts
2 changes: 1 addition & 1 deletion README.md
@@ -1,2 +1,2 @@
-# jax-utils
+# jaxutils
Common utilities in JAX/Flax to use across research projects
10 changes: 10 additions & 0 deletions data/__init__.py
@@ -0,0 +1,10 @@

__all__ = [
    'get_image_dataset',
    'train_val_split_sizes',
    'NumpyLoader',
    'METADATA',
]

from .image import get_image_dataset, train_val_split_sizes, METADATA
from .utils import NumpyLoader
Binary file added data/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added data/__pycache__/image.cpython-39.pyc
Binary file not shown.
228 changes: 228 additions & 0 deletions data/image.py
@@ -0,0 +1,228 @@
"""Image dataset loading functionality."""
from pathlib import Path
from typing import Tuple, Union

import numpy as np
import torch
from torch.utils import data
from torchvision import datasets, transforms

METADATA = {
    'image_shape': {
        'MNIST': (28, 28, 1),
        'FashionMNIST': (28, 28, 1),
        'KMNIST': (28, 28, 1),
        'SVHN': (32, 32, 3),
        'CIFAR10': (32, 32, 3),
        'CIFAR100': (32, 32, 3),
        'Imagenet': (224, 224, 3),
    },
    'num_train': {
        'MNIST': 60_000,
        'FashionMNIST': 60_000,
        'KMNIST': 60_000,
        'SVHN': 73_257,
        'CIFAR10': 50_000,
        'CIFAR100': 50_000,
    },
    'num_test': {
        'MNIST': 10_000,
        'FashionMNIST': 10_000,
        'KMNIST': 10_000,
        'SVHN': 26_032,
        'CIFAR10': 10_000,
        'CIFAR100': 10_000,
    },
    'mean': {
Collaborator: Nice addition!

        'MNIST': (0.1307,),
        'FashionMNIST': (0.2860,),
        'SVHN': (0.4377, 0.4438, 0.4728),
        'CIFAR10': (0.4914, 0.4822, 0.4465),
        'CIFAR100': (0.5071, 0.4866, 0.4409),
        'Imagenet': (0.485, 0.456, 0.406),
    },
    'std': {
        'MNIST': (0.3081,),
        'FashionMNIST': (0.3530,),
        'SVHN': (0.1980, 0.2010, 0.1970),
        'CIFAR10': (0.2470, 0.2435, 0.2616),
        'CIFAR100': (0.2673, 0.2564, 0.2762),
        'Imagenet': (0.229, 0.224, 0.225),
    },
}


TRAIN_TRANSFORMATIONS = {
Collaborator: This is also good!

    'MNIST': [transforms.RandomCrop(28, padding=2)],
    'FashionMNIST': [transforms.RandomCrop(28, padding=2)],
    'SVHN': [transforms.RandomCrop(32, padding=4)],
    'CIFAR10': [transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip()],
    'CIFAR100': [transforms.RandomCrop(32, padding=4),
                 transforms.RandomHorizontalFlip()],
    'Imagenet': [transforms.RandomResizedCrop(224),
                 transforms.RandomHorizontalFlip()],
}

TEST_TRANSFORMATIONS = {
    'MNIST': [],
    'FashionMNIST': [],
    'SVHN': [],
    'CIFAR10': [],
    'CIFAR100': [],
    'Imagenet': [transforms.Resize(256), transforms.CenterCrop(224)],
}


class Flatten:
    """Transform to flatten an image for use with MLPs."""

    def __call__(self, array: np.ndarray) -> np.ndarray:
        return np.ravel(array)


class MoveChannelDim:
    """Transform to change from PyTorch image ordering to Jax/TF ordering."""

    def __call__(self, array: np.ndarray) -> np.ndarray:
        return np.moveaxis(array, 0, -1)


class ToNumpy:
    """Transform to convert from a PyTorch Tensor to a Numpy ndarray."""

    def __call__(self, tensor: torch.Tensor) -> np.ndarray:
        return np.array(tensor, dtype=np.float32)
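
These transforms exist because PyTorch and JAX disagree on memory layout: `ToTensor` yields channel-first (CHW) tensors, while Flax convolutions default to channel-last (HWC). A minimal sketch (not part of the diff) of what `MoveChannelDim` does:

```python
import numpy as np

chw = np.zeros((3, 32, 32), dtype=np.float32)  # PyTorch channel-first layout
hwc = np.moveaxis(chw, 0, -1)                  # the MoveChannelDim operation
assert hwc.shape == (32, 32, 3)                # Jax/TF channel-last layout
```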


def get_image_dataset(
    dataset_name: str,
    data_dir: str = "../raw_data",
    flatten_img: bool = False,
    val_percent: float = 0.1,
    random_seed: int = 42,
    perform_augmentations: bool = True,
) -> Union[
    Tuple[data.Dataset, data.Dataset],
    Tuple[data.Dataset, data.Dataset, data.Dataset],
]:
    """Provides PyTorch `Dataset`s for the specified image dataset_name.

    Args:
        dataset_name: the `str` name of the dataset, e.g. `'MNIST'`.
        data_dir: the `str` directory where the datasets should be downloaded
            to and loaded from. (Default: `'../raw_data'`)
        flatten_img: a `bool` indicating whether images should be flattened.
            (Default: `False`)
        val_percent: the `float` fraction of the training data to hold out
            for validation. (Default: `0.1`)
        random_seed: the `int` random seed for splitting off the val data and
            applying random affine transformations. (Default: `42`)
        perform_augmentations: a `bool` indicating whether to apply random
            transformations to the training data. (Default: `True`)

    Returns:
        `(train_dataset, test_dataset)` if `val_percent` is `0.0`, otherwise
        `(train_dataset, test_dataset, val_dataset)`.
    """
    dataset_choices = [
        "MNIST",
        "FashionMNIST",
        "SVHN",
        "CIFAR10",
        "CIFAR100",
        "Imagenet",
    ]
    if dataset_name not in dataset_choices:
        msg = f"Dataset should be one of {dataset_choices} but was {dataset_name} instead."
        raise RuntimeError(msg)

    if dataset_name in ["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100", "Imagenet"]:
        train_kwargs = {"train": True}
        test_kwargs = {"train": False}

    elif dataset_name == "SVHN":
        train_kwargs = {"split": "train"}
        test_kwargs = {"split": "test"}

    data_dir = Path(data_dir).resolve()

    common_transforms = [
        ToNumpy(),
        MoveChannelDim(),
    ]

    if flatten_img:
        common_transforms += [Flatten()]

    # We need to disable train augmentations when fitting the mode of a linear
    # model, and for sample-then-optimise posterior sampling.
    if perform_augmentations:
        train_augmentations = TRAIN_TRANSFORMATIONS[dataset_name]
    else:
        train_augmentations = TEST_TRANSFORMATIONS[dataset_name]
Collaborator: For the non-Imagenet cases this makes sense, since the test augmentations are empty; but for Imagenet, does it make sense to apply transformations when the caller has set perform_augmentations=False?

Owner Author: The reason we need Resize(256) and CenterCrop(224) for Imagenet is that, by default, Imagenet test images come in varying sizes rather than one uniform size. We still need some deterministic preprocessing to ensure all images end up as 224x224x3.
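
To make that point concrete, a minimal sketch (not part of this PR) showing that the deterministic eval transform produces a fixed output size regardless of input resolution:

```python
from PIL import Image
from torchvision import transforms

eval_tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])

# Inputs of different resolutions all come out 224x224, so downstream
# batching sees uniform shapes even though no augmentation is applied.
for size in [(500, 375), (640, 480), (1024, 768)]:
    assert eval_tf(Image.new("RGB", size)).size == (224, 224)
```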


    transform_train = transforms.Compose(
        train_augmentations
        + [
            transforms.ToTensor(),
            transforms.Normalize(
                METADATA['mean'][dataset_name], METADATA['std'][dataset_name]),
        ]
        + common_transforms
    )

    transform_test = transforms.Compose(
        TEST_TRANSFORMATIONS[dataset_name]
        + [
            transforms.ToTensor(),
            transforms.Normalize(
                METADATA['mean'][dataset_name], METADATA['std'][dataset_name]),
        ]
        + common_transforms
    )

    if dataset_name == "Imagenet":
        train_dir = data_dir / "imagenet/train"
        val_dir = data_dir / "imagenet/val"

        train_dataset = datasets.ImageFolder(
            train_dir, transform=transform_train)

        test_dataset = datasets.ImageFolder(val_dir, transform=transform_test)

    else:
        dataset = getattr(datasets, dataset_name)
        train_dataset = dataset(
            **train_kwargs,
            transform=transform_train,
            download=True,
            root=data_dir,
        )
        test_dataset = dataset(
            **test_kwargs,
            transform=transform_test,
            download=True,
            root=data_dir,
        )

    if val_percent != 0.0:
        num_train, num_val = train_val_split_sizes(
            len(train_dataset), val_percent)

        train_dataset, val_dataset = data.random_split(
            train_dataset,
            [num_train, num_val],
            torch.Generator().manual_seed(random_seed)
            if random_seed is not None
            else None,
        )

        return train_dataset, test_dataset, val_dataset
    else:
        return train_dataset, test_dataset


def train_val_split_sizes(num_train: int, val_percent: float) -> Tuple[int, int]:
    """Splits `num_train` into train and validation sizes for `val_percent`."""
    num_val = int(val_percent * num_train)
    num_train = num_train - num_val

    return num_train, num_val
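
Taken together, a minimal usage sketch of the API added in this file (the CIFAR10 choice and the data_dir path are illustrative assumptions, not part of the diff):

```python
from data import get_image_dataset, NumpyLoader

# Downloads CIFAR10 into data_dir if needed and returns train/test/val
# Datasets with a deterministic 90/10 train/val split.
train_set, test_set, val_set = get_image_dataset(
    "CIFAR10", data_dir="../raw_data", val_percent=0.1, random_seed=42)

train_loader = NumpyLoader(train_set, batch_size=128, shuffle=True)
images, labels = next(iter(train_loader))
print(images.shape)  # (128, 32, 32, 3): NHWC float32 batches, ready for JAX
```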
49 changes: 49 additions & 0 deletions data/utils.py
@@ -0,0 +1,49 @@
"""A PyTorch dataloader that returns `np.ndarray` batches.
Taken from: colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
"""
import numpy as np
from torch.utils import data


def _numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [_numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
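
A quick sketch (not from the diff) of how `_numpy_collate` handles a batch of `(image, label)` pairs: it transposes the batch, stacks the images, and converts the labels to an array.

```python
import numpy as np

batch = [(np.zeros((28, 28, 1), dtype=np.float32), 7),
         (np.ones((28, 28, 1), dtype=np.float32), 3)]
images, labels = _numpy_collate(batch)
print(images.shape, labels)  # (2, 28, 28, 1) [7 3]
```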


class NumpyLoader(data.DataLoader):
    """A PyTorch dataloader that returns `np.ndarray` batches, which can be
    used with Jax!
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        generator=None,
    ):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=_numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            generator=generator,
        )
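
And a short sketch of how these batches flow into JAX (the MNIST choice is an illustrative assumption):

```python
import jax.numpy as jnp
from data import get_image_dataset, NumpyLoader

train_set, _, _ = get_image_dataset("MNIST", val_percent=0.1)
loader = NumpyLoader(train_set, batch_size=64, shuffle=True)

for images, labels in loader:
    # _numpy_collate already produced np.ndarrays, so this is just a
    # host-to-device transfer, with no torch.Tensor conversion needed.
    x, y = jnp.asarray(images), jnp.asarray(labels)
    break
```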
Empty file added models/__init__.py
Empty file.