From 15fb989c2db5a47df008db54c16c55801fe05eb8 Mon Sep 17 00:00:00 2001 From: Johan MEJIA Date: Wed, 8 Sep 2021 23:49:39 +0200 Subject: [PATCH] fix bug in train with samples --- alodataset/base_dataset.py | 18 +++++++------ alodataset/coco_detection_dataset.py | 9 +++---- alonet/deformable_detr/train_on_coco.py | 3 --- alonet/detr/coco_data_modules.py | 34 ++++++++++++++----------- alonet/detr/train_on_coco.py | 5 +--- alonet/raft/data_modules/chairs2raft.py | 3 ++- alonet/raft/data_modules/data2raft.py | 3 +++ alonet/raft/train_on_chairs.py | 3 --- unittest/test_train.py | 6 ++--- 9 files changed, 42 insertions(+), 42 deletions(-) diff --git a/alodataset/base_dataset.py b/alodataset/base_dataset.py index a56ab78e..ed7e1e45 100644 --- a/alodataset/base_dataset.py +++ b/alodataset/base_dataset.py @@ -131,12 +131,11 @@ def __init__( super(BaseDataset, self).__init__(**kwargs) self.name = name self.sample = sample - if not self.sample: - self.items = [] - self.dataset_dir = self.get_dataset_dir() + self.dataset_dir = self.get_dataset_dir() if self.sample: self.items = self.download_sample() - self.dataset_dir = os.path.join(self.vb_folder, "samples") + else: + self.items = [] self.transform_fn = transform_fn self.ignore_errors = ignore_errors self.print_errors = print_errors @@ -210,6 +209,9 @@ def get_dataset_dir(self) -> str: """Look for dataset_dir based on the given name. To work properly a alodataset_config.json file must be save into /home/USER/.aloception/alodataset_config.json """ + if self.sample: + return os.path.join(self.vb_folder, "samples") + streaming_dt_config = os.path.join(self.vb_folder, "alodataset_config.json") if not os.path.exists(streaming_dt_config): self.set_dataset_dir(None) @@ -245,19 +247,19 @@ def set_dataset_dir(self, dataset_dir: str): if dataset_dir is None: dataset_dir = _user_prompt( - f"{self.name} does not exist in config file." + f"{self.name} does not exist in config file. " + "Do you want to download and use a sample?: (Y)es or (N)o: " ) - if dataset_dir.lower() in ["y", "yes"]: + if dataset_dir.lower() in ["y", "yes"]: # Download sample and change root directory self.sample = True - return + return os.path.join(self.vb_folder, "samples") dataset_dir = _user_prompt(f"Please write a new root directory for {self.name} dataset: ") dataset_dir = os.path.expanduser(dataset_dir) # Save the config if not os.path.exists(dataset_dir): dataset_dir = _user_prompt( - f"[WARNING] {dataset_dir} path does not exists for dataset: {self.name}." + f"[WARNING] {dataset_dir} path does not exists for dataset: {self.name}. " + "Please write a new directory:" ) dataset_dir = os.path.expanduser(dataset_dir) diff --git a/alodataset/coco_detection_dataset.py b/alodataset/coco_detection_dataset.py index 2b0a2b9b..e9996700 100644 --- a/alodataset/coco_detection_dataset.py +++ b/alodataset/coco_detection_dataset.py @@ -80,7 +80,8 @@ def __init__( if "sample" not in kwargs: kwargs["sample"] = False - if not kwargs["sample"]: + self.sample = kwargs["sample"] + if not self.sample: assert img_folder is not None, "When sample = False, img_folder must be given." assert ann_file is not None, "When sample = False, ann_file must be given." @@ -88,8 +89,8 @@ def __init__( dataset_dir = BaseDataset.get_dataset_dir(self) img_folder = os.path.join(dataset_dir, img_folder) ann_file = os.path.join(dataset_dir, ann_file) + kwargs["sample"] = self.sample - self.sample = kwargs["sample"] super(CocoDetectionDataset, self).__init__(name=name, root=img_folder, annFile=ann_file, **kwargs) if self.sample: return @@ -282,9 +283,7 @@ def show_random_frame(coco_loader): def main(): """Main""" logging.basicConfig( - level=logging.INFO, - format="[%(asctime)s][%(levelname)s] %(message)s", - datefmt="%d-%m-%y %H:%M:%S", + level=logging.INFO, format="[%(asctime)s][%(levelname)s] %(message)s", datefmt="%d-%m-%y %H:%M:%S", ) log = logging.getLogger("aloception") diff --git a/alonet/deformable_detr/train_on_coco.py b/alonet/deformable_detr/train_on_coco.py index a3db5485..43281134 100644 --- a/alonet/deformable_detr/train_on_coco.py +++ b/alonet/deformable_detr/train_on_coco.py @@ -15,9 +15,6 @@ def get_arg_parser(): parser = ArgumentParser(conflict_handler="resolve") parser = alonet.common.add_argparse_args(parser) # Common alonet parser parser = CocoDetection2Detr.add_argparse_args(parser) # Coco detection parser - parser.add_argument( - "--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)" - ) parser = LitDeformableDetr.add_argparse_args(parser) # LitDeformableDetr training parser # parser = pl.Trainer.add_argparse_args(parser) # Pytorch lightning Parser return parser diff --git a/alonet/detr/coco_data_modules.py b/alonet/detr/coco_data_modules.py index d61c8d15..18d25641 100644 --- a/alonet/detr/coco_data_modules.py +++ b/alonet/detr/coco_data_modules.py @@ -19,7 +19,6 @@ def __init__( train_ann: str = "annotations/instances_train2017.json", val_folder: str = "val2017", val_ann: str = "annotations/instances_val2017.json", - sample: bool = False, **kwargs ): """LightningDataModule to use coco dataset in Detr models @@ -58,6 +57,7 @@ def __init__( Arguments entered by the user (kwargs) will replace those stored in args attribute """ # Update class attributes with args and kwargs inputs + super().__init__() alonet.common.pl_helpers.params_update(self, args, kwargs) @@ -71,7 +71,6 @@ def __init__( # Split=Split.TRAIN if not self.train_on_val else Split.VAL, classes=classes, name=name, - sample=sample, ) self.val_loader_kwargs = dict( img_folder=val_folder, @@ -79,9 +78,9 @@ def __init__( # split=Split.VAL, classes=classes, name=name, - sample=sample, ) self.args = args + self.val_check() # Check val loader and set some previous parameters @staticmethod def add_argparse_args(parent_parser): @@ -104,18 +103,11 @@ def add_argparse_args(parent_parser): nargs="+", help="If no augmentation (--no_augmentation) is used, --size can be used to resize all the frame.", ) - # parser.add_argument("--classes", type=str, default=None, nargs="+", help="List to classes to be filtered in dataset. (%(default)s by default)") + parser.add_argument( + "--sample", action="store_true", help="Download a sample for train/val process (Default: %(default)s)" + ) return parent_parser - @property - def CATEGORIES(self): - if not hasattr(self, "coco_train"): - self.setup() - if not hasattr(self, "coco_train"): - return None - else: - return self.coco_train.CATEGORIES if hasattr(self.coco_train, "CATEGORIES") else None - def train_transform(self, frame, same_on_sequence: bool = True, same_on_frames: bool = False): if self.no_augmentation: if self.size[0] is not None and self.size[1] is not None: @@ -158,15 +150,27 @@ def val_transform( return frame.norm_resnet() + def val_check(self): + # Instance a default loader to set attributes + self.coco_val = alodataset.CocoDetectionDataset( + transform_fn=self.val_transform, sample=self.sample, **self.val_loader_kwargs, + ) + self.sample = self.coco_val.sample or self.sample # Update sample if user prompt is given + self.CATEGORIES = self.coco_val.CATEGORIES if hasattr(self.coco_val, "CATEGORIES") else None + def setup(self, stage: Optional[str] = None) -> None: if stage == "fit" or stage is None: # Setup train/val loaders self.coco_train = alodataset.CocoDetectionDataset( - transform_fn=self.train_transform, **self.train_loader_kwargs + transform_fn=self.train_transform, sample=self.sample, **self.train_loader_kwargs + ) + self.coco_val = alodataset.CocoDetectionDataset( + transform_fn=self.val_transform, sample=self.sample, **self.val_loader_kwargs ) - self.coco_val = alodataset.CocoDetectionDataset(transform_fn=self.val_transform, **self.val_loader_kwargs) def train_dataloader(self): + """Train dataloader""" + # Init training loader if not hasattr(self, "coco_train"): self.setup() return self.coco_train.train_loader(batch_size=self.batch_size, num_workers=self.num_workers) diff --git a/alonet/detr/train_on_coco.py b/alonet/detr/train_on_coco.py index 0d265157..07232ad6 100644 --- a/alonet/detr/train_on_coco.py +++ b/alonet/detr/train_on_coco.py @@ -9,9 +9,6 @@ def get_arg_parser(): parser = ArgumentParser(conflict_handler="resolve") parser = alonet.common.add_argparse_args(parser) # Common alonet parser parser = CocoDetection2Detr.add_argparse_args(parser) # Coco detection parser - parser.add_argument( - "--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)" - ) parser = LitDetr.add_argparse_args(parser) # LitDetr training parser # parser = pl.Trainer.add_argparse_args(parser) # Pytorch lightning Parser return parser @@ -24,7 +21,7 @@ def main(): # Init the Detr model with the dataset detr = LitDetr(args) - coco_loader = CocoDetection2Detr(args, sample=args.use_sample) + coco_loader = CocoDetection2Detr(args) detr.run_train(data_loader=coco_loader, args=args, project="detr", expe_name="detr_50") diff --git a/alonet/raft/data_modules/chairs2raft.py b/alonet/raft/data_modules/chairs2raft.py index 7c15a82b..a31d65b7 100644 --- a/alonet/raft/data_modules/chairs2raft.py +++ b/alonet/raft/data_modules/chairs2raft.py @@ -13,12 +13,13 @@ def __init__(self, args): def train_dataloader(self): split = Split.VAL if self.train_on_val else Split.TRAIN dataset = FlyingChairs2Dataset(split=split, transform_fn=self.train_transform, sample=self.sample) + self.sample = self.sample or dataset.sample sampler = SequentialSampler if self.sequential else RandomSampler return dataset.train_loader(batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler) def val_dataloader(self): dataset = FlyingChairs2Dataset(split=Split.VAL, transform_fn=self.val_transform, sample=self.sample) - + self.sample = self.sample or dataset.sample return dataset.train_loader(batch_size=1, num_workers=self.num_workers, sampler=SequentialSampler) diff --git a/alonet/raft/data_modules/data2raft.py b/alonet/raft/data_modules/data2raft.py index 727938ab..e14bd579 100644 --- a/alonet/raft/data_modules/data2raft.py +++ b/alonet/raft/data_modules/data2raft.py @@ -25,6 +25,9 @@ def add_argparse_args(parent_parser): parser.add_argument("--num_workers", type=int, default=8, help="num_workers to use on the dataset") parser.add_argument("--limit_val_batches", type=_int_or_float_type, default=100) parser.add_argument("--sequential_sampler", action="store_true", help="sample data sequentially (no shuffle)") + parser.add_argument( + "--sample", action="store_true", help="Download a sample for train/val process (Default: %(default)s)" + ) return parent_parser def train_transform(self, frame): diff --git a/alonet/raft/train_on_chairs.py b/alonet/raft/train_on_chairs.py index 941549b4..5c664fd4 100644 --- a/alonet/raft/train_on_chairs.py +++ b/alonet/raft/train_on_chairs.py @@ -9,9 +9,6 @@ def get_args_parser(): parser = argparse.ArgumentParser(conflict_handler="resolve") parser = alonet.common.add_argparse_args(parser, add_pl_args=True) parser = Chairs2RAFT.add_argparse_args(parser) - parser.add_argument( - "--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)" - ) parser = LitRAFT.add_argparse_args(parser) return parser diff --git a/unittest/test_train.py b/unittest/test_train.py index a1c3a7c8..40ac6773 100644 --- a/unittest/test_train.py +++ b/unittest/test_train.py @@ -18,7 +18,7 @@ def get_argparse_defaults(parser): detr_args["weights"] = "detr-r50" detr_args["train_on_val"] = True detr_args["fast_dev_run"] = True -detr_args["use_sample"] = True +detr_args["sample"] = True @mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**detr_args)) @@ -38,7 +38,7 @@ def test_detr(mock_args): def_detr_args["model_name"] = "deformable-detr-r50" def_detr_args["train_on_val"] = True def_detr_args["fast_dev_run"] = True -def_detr_args["use_sample"] = True +def_detr_args["sample"] = True @mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**def_detr_args)) @@ -54,7 +54,7 @@ def test_deformable_detr(mock_args): raft_args["weights"] = "raft-things" raft_args["train_on_val"] = True raft_args["fast_dev_run"] = True -raft_args["use_sample"] = True +raft_args["sample"] = True @mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**raft_args))