@@ -38,6 +38,15 @@ def train(cfg, writer, logger):
38
38
data_aug = get_composed_augmentations (augmentations )
39
39
40
40
# Setup Dataloader
41
+ dataloader_args = cfg ["data" ].copy ()
42
+ dataloader_args .pop ('dataset' )
43
+ dataloader_args .pop ('train_split' )
44
+ dataloader_args .pop ('val_split' )
45
+ dataloader_args .pop ('img_rows' )
46
+ dataloader_args .pop ('img_cols' )
47
+ dataloader_args .pop ('path' )
48
+
49
+
41
50
data_loader = get_loader (cfg ["data" ]["dataset" ])
42
51
data_path = cfg ["data" ]["path" ]
43
52
@@ -47,21 +56,23 @@ def train(cfg, writer, logger):
47
56
split = cfg ["data" ]["train_split" ],
48
57
img_size = (cfg ["data" ]["img_rows" ], cfg ["data" ]["img_cols" ]),
49
58
augmentations = data_aug ,
59
+ ** dataloader_args
50
60
)
51
61
52
62
v_loader = data_loader (
53
63
data_path ,
54
64
is_transform = True ,
55
65
split = cfg ["data" ]["val_split" ],
56
66
img_size = (cfg ["data" ]["img_rows" ], cfg ["data" ]["img_cols" ]),
67
+ ** dataloader_args
57
68
)
58
69
59
70
n_classes = t_loader .n_classes
60
71
trainloader = data .DataLoader (
61
72
t_loader ,
62
73
batch_size = cfg ["training" ]["batch_size" ],
63
74
num_workers = cfg ["training" ]["n_workers" ],
64
- shuffle = True ,
75
+ shuffle = True
65
76
)
66
77
67
78
valloader = data .DataLoader (
0 commit comments