diff --git a/train.py b/train.py
index eeabc578..2723a5bd 100644
--- a/train.py
+++ b/train.py
@@ -38,6 +38,14 @@ def train(cfg, writer, logger):
     data_aug = get_composed_augmentations(augmentations)
 
     # Setup Dataloader
+    dataloader_args = cfg["data"].copy()
+    dataloader_args.pop('dataset')
+    dataloader_args.pop('train_split')
+    dataloader_args.pop('val_split')
+    dataloader_args.pop('img_rows')
+    dataloader_args.pop('img_cols')
+    dataloader_args.pop('path')
+
     data_loader = get_loader(cfg["data"]["dataset"])
     data_path = cfg["data"]["path"]
 
@@ -47,6 +55,7 @@ def train(cfg, writer, logger):
         split=cfg["data"]["train_split"],
         img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
         augmentations=data_aug,
+	**dataloader_args
     )
 
     v_loader = data_loader(
@@ -54,6 +63,7 @@ def train(cfg, writer, logger):
         is_transform=True,
         split=cfg["data"]["val_split"],
         img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
+	**dataloader_args
     )
 
     n_classes = t_loader.n_classes
@@ -61,7 +71,7 @@ def train(cfg, writer, logger):
         t_loader,
         batch_size=cfg["training"]["batch_size"],
         num_workers=cfg["training"]["n_workers"],
-        shuffle=True,
+        shuffle=True
     )
 
     valloader = data.DataLoader(