diff --git a/train.py b/train.py index eeabc578..2723a5bd 100644 --- a/train.py +++ b/train.py @@ -38,6 +38,14 @@ def train(cfg, writer, logger): data_aug = get_composed_augmentations(augmentations) # Setup Dataloader + dataloader_args = cfg["data"].copy() + dataloader_args.pop('dataset') + dataloader_args.pop('train_split') + dataloader_args.pop('val_split') + dataloader_args.pop('img_rows') + dataloader_args.pop('img_cols') + dataloader_args.pop('path') + data_loader = get_loader(cfg["data"]["dataset"]) data_path = cfg["data"]["path"] @@ -47,6 +55,7 @@ def train(cfg, writer, logger): split=cfg["data"]["train_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), augmentations=data_aug, + **dataloader_args ) v_loader = data_loader( @@ -54,6 +63,7 @@ def train(cfg, writer, logger): is_transform=True, split=cfg["data"]["val_split"], img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), + **dataloader_args ) n_classes = t_loader.n_classes @@ -61,7 +71,7 @@ def train(cfg, writer, logger): t_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"], - shuffle=True, + shuffle=True ) valloader = data.DataLoader(