Skip to content

Commit 753ec68

Browse files
author
Jeremy Fix
committed
forwarding key:value arguments to the data_loader
1 parent 801fb20 commit 753ec68

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

train.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ def train(cfg, writer, logger):
3838
data_aug = get_composed_augmentations(augmentations)
3939

4040
# Setup Dataloader
41+
dataloader_args = cfg["data"].copy()
42+
dataloader_args.pop('dataset')
43+
dataloader_args.pop('train_split')
44+
dataloader_args.pop('val_split')
45+
dataloader_args.pop('img_rows')
46+
dataloader_args.pop('img_cols')
47+
dataloader_args.pop('path')
48+
49+
4150
data_loader = get_loader(cfg["data"]["dataset"])
4251
data_path = cfg["data"]["path"]
4352

@@ -47,21 +56,23 @@ def train(cfg, writer, logger):
4756
split=cfg["data"]["train_split"],
4857
img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
4958
augmentations=data_aug,
59+
**dataloader_args
5060
)
5161

5262
v_loader = data_loader(
5363
data_path,
5464
is_transform=True,
5565
split=cfg["data"]["val_split"],
5666
img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
67+
**dataloader_args
5768
)
5869

5970
n_classes = t_loader.n_classes
6071
trainloader = data.DataLoader(
6172
t_loader,
6273
batch_size=cfg["training"]["batch_size"],
6374
num_workers=cfg["training"]["n_workers"],
64-
shuffle=True,
75+
shuffle=True
6576
)
6677

6778
valloader = data.DataLoader(

0 commit comments

Comments
 (0)