From 06da95d7959b3116cd86af8ce66930ead9796119 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Sat, 24 Feb 2024 20:02:26 -0800
Subject: [PATCH 1/3] refactor: Vimeo90kDataset extract constants

---
 compressai/datasets/vimeo90k.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/compressai/datasets/vimeo90k.py b/compressai/datasets/vimeo90k.py
index 770f49af..f606533b 100644
--- a/compressai/datasets/vimeo90k.py
+++ b/compressai/datasets/vimeo90k.py
@@ -64,6 +64,9 @@ class Vimeo90kDataset(Dataset):
         tuplet (int): order of dataset tuplet (e.g. 3 for "triplet" dataset)
     """
 
+    TUPLET_PREFIX = {3: "tri", 7: "sep"}
+    SPLIT_TO_LIST_SUFFIX = {"train": "trainlist", "valid": "testlist"}
+
     def __init__(self, root, transform=None, split="train", tuplet=3):
         list_path = Path(root) / self._list_filename(split, tuplet)
 
@@ -94,6 +97,6 @@ def __len__(self):
         return len(self.samples)
 
     def _list_filename(self, split: str, tuplet: int) -> str:
-        tuplet_prefix = {3: "tri", 7: "sep"}[tuplet]
-        list_suffix = {"train": "trainlist", "valid": "testlist"}[split]
+        tuplet_prefix = self.TUPLET_PREFIX[tuplet]
+        list_suffix = self.SPLIT_TO_LIST_SUFFIX[split]
         return f"{tuplet_prefix}_{list_suffix}.txt"

From 1a4ce599f1339bcbd1a251fc021154a342c0a822 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Sat, 24 Feb 2024 23:23:10 -0800
Subject: [PATCH 2/3] feat: Vimeo90kDataset support image/video modes

---
 compressai/datasets/vimeo90k.py | 71 ++++++++++++++++++++++++---------
 1 file changed, 52 insertions(+), 19 deletions(-)

diff --git a/compressai/datasets/vimeo90k.py b/compressai/datasets/vimeo90k.py
index f606533b..579c90a1 100644
--- a/compressai/datasets/vimeo90k.py
+++ b/compressai/datasets/vimeo90k.py
@@ -29,6 +29,8 @@
 
 from pathlib import Path
 
+import torch
+
 from PIL import Image
 from torch.utils.data import Dataset
 
@@ -58,43 +60,74 @@ class Vimeo90kDataset(Dataset):
 
     Args:
         root (string): root directory of the dataset
-        transform (callable, optional): a function or transform that takes in a
-            PIL image and returns a transformed version
+        transform (callable, optional): a function for image/sequence transformation
+        transform_frame (callable, optional): a function for frame transformation
         split (string): split mode ('train' or 'valid')
         tuplet (int): order of dataset tuplet (e.g. 3 for "triplet" dataset)
+        mode (string): item grouping mode ('image' or 'video'). If 'image', each
+            item is a single frame. If 'video', each item is a sequence of frames.
     """
 
     TUPLET_PREFIX = {3: "tri", 7: "sep"}
     SPLIT_TO_LIST_SUFFIX = {"train": "trainlist", "valid": "testlist"}
 
-    def __init__(self, root, transform=None, split="train", tuplet=3):
+    def __init__(
+        self,
+        root,
+        transform=None,
+        transform_frame=None,
+        split="train",
+        tuplet=3,
+        mode="image",
+    ):
+        self.mode = mode
+        self.tuplet = tuplet
+
         list_path = Path(root) / self._list_filename(split, tuplet)
 
         with open(list_path) as f:
-            self.samples = [
-                f"{root}/sequences/{line.rstrip()}/im{idx}.png"
-                for line in f
-                if line.strip() != ""
-                for idx in range(1, tuplet + 1)
+            self.sequences = [
+                f"{root}/sequences/{line.rstrip()}" for line in f if line.strip() != ""
             ]
 
+        self.frames = [
+            f"{seq}/im{idx}.png"
+            for seq in self.sequences
+            for idx in range(1, tuplet + 1)
+        ]
+
         self.transform = transform
+        self.transform_frame = transform_frame  # Suggested: transforms.ToTensor()
 
     def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-
-        Returns:
-            img: `PIL.Image.Image` or transformed `PIL.Image.Image`.
-        """
-        img = Image.open(self.samples[index]).convert("RGB")
+        if self.mode == "image":
+            item = self._get_frame(self.frames[index])
+        elif self.mode == "video":
+            item = torch.stack(
+                [
+                    self._get_frame(f"{self.sequences[index]}/im{idx}.png")
+                    for idx in range(1, self.tuplet + 1)
+                ]
+            )
+        else:
+            raise ValueError(f"Invalid mode {self.mode}. Must be 'image' or 'video'.")
         if self.transform:
-            return self.transform(img)
-        return img
+            item = self.transform(item)
+        return item
+
+    def _get_frame(self, filename):
+        frame = Image.open(filename).convert("RGB")
+        if self.transform_frame:
+            frame = self.transform_frame(frame)
+        return frame
 
     def __len__(self):
-        return len(self.samples)
+        if self.mode == "image":
+            return len(self.frames)
+        elif self.mode == "video":
+            return len(self.sequences)
+        else:
+            raise ValueError(f"Invalid mode {self.mode}. Must be 'image' or 'video'.")
 
     def _list_filename(self, split: str, tuplet: int) -> str:
         tuplet_prefix = self.TUPLET_PREFIX[tuplet]

From 9618484d882316a46e1ef2e2b3618256434e4667 Mon Sep 17 00:00:00 2001
From: Mateen Ulhaq
Date: Wed, 21 Feb 2024 20:01:57 -0800
Subject: [PATCH 3/3] feat: PreGeneratedMemmapDataset support image/video modes

---
 compressai/datasets/pregenerated.py | 48 ++++++++++++++++++++++++++---
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/compressai/datasets/pregenerated.py b/compressai/datasets/pregenerated.py
index 043514c1..c54ab2de 100644
--- a/compressai/datasets/pregenerated.py
+++ b/compressai/datasets/pregenerated.py
@@ -31,6 +31,7 @@
 from typing import Tuple, Union
 
 import numpy as np
+import torch
 
 from PIL import Image
 from torch.utils.data import Dataset
@@ -57,20 +58,27 @@ class PreGeneratedMemmapDataset(Dataset):
         batch_size (int): batch size.
         num_workers (int): number of CPU thread workers.
         pin_memory (bool): pin memory.
+        mode (string): item grouping mode ('image' or 'video'). If 'image', each
+            item is a single frame. If 'video', each item is a sequence of frames.
+        frames_per_sample (int): number of frames per sample (only for 'video' mode).
     """
 
     def __init__(
         self,
         root: str,
         transform=None,
+        transform_frame=None,
         split: str = "train",
         image_size: _size_2_t = (256, 256),
+        mode: str = "image",
+        frames_per_sample: int = 1,
     ):
         if not Path(root).is_dir():
             raise RuntimeError(f"Invalid path {root}")
 
         self.split = split
         self.transform = transform
+        self.mode = mode
 
         self.shuffle = False
 
@@ -84,14 +92,44 @@ def __init__(
         data: np.ndarray = np.memmap(path, mode="r", dtype="uint8")
         assert data.size > 0
         image_size = _coerce_size_2_t(image_size)
-        self.data = data.reshape((-1, image_size[0], image_size[1], 3))
+
+        if self.mode == "image":
+            shape = (-1, image_size[0], image_size[1], 3)
+        elif self.mode == "video":
+            shape = (-1, frames_per_sample, image_size[0], image_size[1], 3)
+        else:
+            raise ValueError(f"Invalid mode {self.mode}. Must be 'image' or 'video'.")
+
+        self.data = data.reshape(shape)
+
+        self.transform = transform
+        self.transform_frame = transform_frame  # Suggested: transforms.ToTensor()
 
     def __getitem__(self, index):
-        sample = self.data[index]
-        sample = Image.fromarray(sample)
+        item = self.data[index]
+
+        if self.mode == "image":
+            item = Image.fromarray(item)
+        elif self.mode == "video":
+            item = [Image.fromarray(frame) for frame in item]
+
+        if self.mode == "image":
+            if self.transform_frame:
+                item = self.transform_frame(item)
+        elif self.mode == "video":
+            if self.transform_frame:
+                item = [self.transform_frame(frame) for frame in item]
+            if isinstance(item[0], torch.Tensor):
+                item = torch.stack(item)
+            elif isinstance(item[0], np.ndarray):
+                item = np.stack(item)
+            else:
+                raise ValueError("Expected items to be tensors or numpy arrays.")
+
         if self.transform:
-            item = self.transform(item)
+
+        return item
 
     def __len__(self):
         return self.data.shape[0]
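
Usage sketch (illustrative only, not applied by the patches above): a minimal example of how the new `mode` switch might be exercised, assuming `Vimeo90kDataset` is exported from `compressai.datasets`, that torchvision is available, and that the Vimeo-90K triplet data sits under a placeholder path. In 'video' mode a per-frame transform that returns tensors (e.g. transforms.ToTensor(), as suggested in the code comments) is needed so the frames can be stacked with torch.stack.

    from torchvision import transforms

    from compressai.datasets import Vimeo90kDataset

    # "image" mode (the default): one frame per item, as before these patches.
    image_dataset = Vimeo90kDataset(
        "/data/vimeo_triplet",  # hypothetical dataset root
        transform=transforms.ToTensor(),
        split="train",
        tuplet=3,
        mode="image",
    )

    # "video" mode: each item stacks the tuplet's frames into one tensor.
    video_dataset = Vimeo90kDataset(
        "/data/vimeo_triplet",  # hypothetical dataset root
        transform_frame=transforms.ToTensor(),  # applied to every frame before stacking
        split="train",
        tuplet=3,
        mode="video",
    )

    frame = image_dataset[0]   # tensor of shape (C, H, W)
    frames = video_dataset[0]  # tensor of shape (tuplet, C, H, W)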