Skip to content

Commit

Permalink
Merge pull request #58 from Visual-Behavior/50-train-with-samples
Browse files Browse the repository at this point in the history
50 train with samples
  • Loading branch information
thibo73800 authored Sep 10, 2021
2 parents c64dae3 + 5923c59 commit e63d100
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 39 deletions.
18 changes: 10 additions & 8 deletions alodataset/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,11 @@ def __init__(
super(BaseDataset, self).__init__(**kwargs)
self.name = name
self.sample = sample
if not self.sample:
self.items = []
self.dataset_dir = self.get_dataset_dir()
self.dataset_dir = self.get_dataset_dir()
if self.sample:
self.items = self.download_sample()
self.dataset_dir = os.path.join(self.vb_folder, "samples")
else:
self.items = []
self.transform_fn = transform_fn
self.ignore_errors = ignore_errors
self.print_errors = print_errors
Expand Down Expand Up @@ -212,6 +211,9 @@ def get_dataset_dir(self) -> str:
"""Look for dataset_dir based on the given name. To work properly a alodataset_config.json
file must be save into /home/USER/.aloception/alodataset_config.json
"""
if self.sample:
return os.path.join(self.vb_folder, "samples")

streaming_dt_config = os.path.join(self.vb_folder, "alodataset_config.json")
if not os.path.exists(streaming_dt_config):
self.set_dataset_dir(None)
Expand Down Expand Up @@ -247,19 +249,19 @@ def set_dataset_dir(self, dataset_dir: str):

if dataset_dir is None:
dataset_dir = _user_prompt(
f"{self.name} does not exist in config file."
f"{self.name} does not exist in config file. "
+ "Do you want to download and use a sample?: (Y)es or (N)o: "
)
if dataset_dir.lower() in ["y", "yes"]:
if dataset_dir.lower() in ["y", "yes"]: # Download sample and change root directory
self.sample = True
return
return os.path.join(self.vb_folder, "samples")
dataset_dir = _user_prompt(f"Please write a new root directory for {self.name} dataset: ")
dataset_dir = os.path.expanduser(dataset_dir)

# Save the config
if not os.path.exists(dataset_dir):
dataset_dir = _user_prompt(
f"[WARNING] {dataset_dir} path does not exists for dataset: {self.name}."
f"[WARNING] {dataset_dir} path does not exists for dataset: {self.name}. "
+ "Please write a new directory:"
)
dataset_dir = os.path.expanduser(dataset_dir)
Expand Down
5 changes: 3 additions & 2 deletions alodataset/coco_detection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(
if "sample" not in kwargs:
kwargs["sample"] = False

if not kwargs["sample"]:
self.sample = kwargs["sample"]
if not self.sample:
assert img_folder is not None, "When sample = False, img_folder must be given."
assert ann_file is not None, "When sample = False, ann_file must be given."

Expand All @@ -91,8 +92,8 @@ def __init__(
img_folder = os.path.join(dataset_dir, img_folder)
ann_file = os.path.join(dataset_dir, ann_file)
stuff_ann_file = None if stuff_ann_file is None else os.path.join(dataset_dir, stuff_ann_file)
kwargs["sample"] = self.sample

self.sample = kwargs["sample"]
super(CocoDetectionDataset, self).__init__(name=name, root=img_folder, annFile=ann_file, **kwargs)
if self.sample:
return
Expand Down
3 changes: 0 additions & 3 deletions alonet/deformable_detr/train_on_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ def get_arg_parser():
parser = ArgumentParser(conflict_handler="resolve")
parser = alonet.common.add_argparse_args(parser) # Common alonet parser
parser = CocoDetection2Detr.add_argparse_args(parser) # Coco detection parser
parser.add_argument(
"--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)"
)
parser = LitDeformableDetr.add_argparse_args(parser) # LitDeformableDetr training parser
# parser = pl.Trainer.add_argparse_args(parser) # Pytorch lightning Parser
return parser
Expand Down
34 changes: 19 additions & 15 deletions alonet/detr/coco_data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(
train_ann: str = "annotations/instances_train2017.json",
val_folder: str = "val2017",
val_ann: str = "annotations/instances_val2017.json",
sample: bool = False,
**kwargs
):
"""LightningDataModule to use coco dataset in Detr models
Expand Down Expand Up @@ -58,6 +57,7 @@ def __init__(
Arguments entered by the user (kwargs) will replace those stored in args attribute
"""
# Update class attributes with args and kwargs inputs

super().__init__()
alonet.common.pl_helpers.params_update(self, args, kwargs)

Expand All @@ -71,17 +71,16 @@ def __init__(
# Split=Split.TRAIN if not self.train_on_val else Split.VAL,
classes=classes,
name=name,
sample=sample,
)
self.val_loader_kwargs = dict(
img_folder=val_folder,
ann_file=val_ann,
# split=Split.VAL,
classes=classes,
name=name,
sample=sample,
)
self.args = args
self.val_check() # Check val loader and set some previous parameters

@staticmethod
def add_argparse_args(parent_parser):
Expand All @@ -104,18 +103,11 @@ def add_argparse_args(parent_parser):
nargs="+",
help="If no augmentation (--no_augmentation) is used, --size can be used to resize all the frame.",
)
# parser.add_argument("--classes", type=str, default=None, nargs="+", help="List to classes to be filtered in dataset. (%(default)s by default)")
parser.add_argument(
"--sample", action="store_true", help="Download a sample for train/val process (Default: %(default)s)"
)
return parent_parser

@property
def CATEGORIES(self):
if not hasattr(self, "coco_train"):
self.setup()
if not hasattr(self, "coco_train"):
return None
else:
return self.coco_train.CATEGORIES if hasattr(self.coco_train, "CATEGORIES") else None

def train_transform(self, frame, same_on_sequence: bool = True, same_on_frames: bool = False):
if self.no_augmentation:
if self.size[0] is not None and self.size[1] is not None:
Expand Down Expand Up @@ -158,15 +150,27 @@ def val_transform(

return frame.norm_resnet()

def val_check(self):
# Instance a default loader to set attributes
self.coco_val = alodataset.CocoDetectionDataset(
transform_fn=self.val_transform, sample=self.sample, **self.val_loader_kwargs,
)
self.sample = self.coco_val.sample or self.sample # Update sample if user prompt is given
self.CATEGORIES = self.coco_val.CATEGORIES if hasattr(self.coco_val, "CATEGORIES") else None

def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit" or stage is None:
# Setup train/val loaders
self.coco_train = alodataset.CocoDetectionDataset(
transform_fn=self.train_transform, **self.train_loader_kwargs
transform_fn=self.train_transform, sample=self.sample, **self.train_loader_kwargs
)
self.coco_val = alodataset.CocoDetectionDataset(
transform_fn=self.val_transform, sample=self.sample, **self.val_loader_kwargs
)
self.coco_val = alodataset.CocoDetectionDataset(transform_fn=self.val_transform, **self.val_loader_kwargs)

def train_dataloader(self):
"""Train dataloader"""
# Init training loader
if not hasattr(self, "coco_train"):
self.setup()
return self.coco_train.train_loader(batch_size=self.batch_size, num_workers=self.num_workers)
Expand Down
5 changes: 1 addition & 4 deletions alonet/detr/train_on_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ def get_arg_parser():
parser = ArgumentParser(conflict_handler="resolve")
parser = alonet.common.add_argparse_args(parser) # Common alonet parser
parser = CocoDetection2Detr.add_argparse_args(parser) # Coco detection parser
parser.add_argument(
"--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)"
)
parser = LitDetr.add_argparse_args(parser) # LitDetr training parser
# parser = pl.Trainer.add_argparse_args(parser) # Pytorch lightning Parser
return parser
Expand All @@ -24,7 +21,7 @@ def main():

# Init the Detr model with the dataset
detr = LitDetr(args)
coco_loader = CocoDetection2Detr(args, sample=args.use_sample)
coco_loader = CocoDetection2Detr(args)

detr.run_train(data_loader=coco_loader, args=args, project="detr", expe_name="detr_50")

Expand Down
3 changes: 2 additions & 1 deletion alonet/raft/data_modules/chairs2raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def __init__(self, args):
def train_dataloader(self):
split = Split.VAL if self.train_on_val else Split.TRAIN
dataset = FlyingChairs2Dataset(split=split, transform_fn=self.train_transform, sample=self.sample)
self.sample = self.sample or dataset.sample
sampler = SequentialSampler if self.sequential else RandomSampler
return dataset.train_loader(batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler)

def val_dataloader(self):
dataset = FlyingChairs2Dataset(split=Split.VAL, transform_fn=self.val_transform, sample=self.sample)

self.sample = self.sample or dataset.sample
return dataset.train_loader(batch_size=1, num_workers=self.num_workers, sampler=SequentialSampler)


Expand Down
3 changes: 3 additions & 0 deletions alonet/raft/data_modules/data2raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def add_argparse_args(parent_parser):
parser.add_argument("--num_workers", type=int, default=8, help="num_workers to use on the dataset")
parser.add_argument("--limit_val_batches", type=_int_or_float_type, default=100)
parser.add_argument("--sequential_sampler", action="store_true", help="sample data sequentially (no shuffle)")
parser.add_argument(
"--sample", action="store_true", help="Download a sample for train/val process (Default: %(default)s)"
)
return parent_parser

def train_transform(self, frame):
Expand Down
3 changes: 0 additions & 3 deletions alonet/raft/train_on_chairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ def get_args_parser():
parser = argparse.ArgumentParser(conflict_handler="resolve")
parser = alonet.common.add_argparse_args(parser, add_pl_args=True)
parser = Chairs2RAFT.add_argparse_args(parser)
parser.add_argument(
"--use_sample", action="store_true", help="Download a sample for train process (Default: %(default)s)"
)
parser = LitRAFT.add_argparse_args(parser)
return parser

Expand Down
6 changes: 3 additions & 3 deletions unittest/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_argparse_defaults(parser):
detr_args["weights"] = "detr-r50"
detr_args["train_on_val"] = True
detr_args["fast_dev_run"] = True
detr_args["use_sample"] = True
detr_args["sample"] = True


@mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**detr_args))
Expand All @@ -38,7 +38,7 @@ def test_detr(mock_args):
def_detr_args["model_name"] = "deformable-detr-r50"
def_detr_args["train_on_val"] = True
def_detr_args["fast_dev_run"] = True
def_detr_args["use_sample"] = True
def_detr_args["sample"] = True


@mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**def_detr_args))
Expand All @@ -54,7 +54,7 @@ def test_deformable_detr(mock_args):
raft_args["weights"] = "raft-things"
raft_args["train_on_val"] = True
raft_args["fast_dev_run"] = True
raft_args["use_sample"] = True
raft_args["sample"] = True


@mock.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(**raft_args))
Expand Down

0 comments on commit e63d100

Please sign in to comment.