diff --git a/train.py b/train.py index 3f99e629..49b4b5f3 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,7 @@ # System libs import os import time +from distutils.util import strtobool # import math import random import argparse @@ -248,7 +249,7 @@ def main(args): help='maxmimum downsampling rate of the network') parser.add_argument('--segm_downsampling_rate', default=8, type=int, help='downsampling rate of the segmentation label') - parser.add_argument('--random_flip', default=True, type=bool, + parser.add_argument('--random_flip', default=True, type=lambda x: bool(strtobool(x)), help='if horizontally flip images when training') # Misc arguments