diff --git a/train.py b/train.py index eeabc578..f29cc80a 100644 --- a/train.py +++ b/train.py @@ -1,11 +1,11 @@ import os -import yaml import time import shutil import torch import random import argparse import numpy as np +from yaml import load, Loader from torch.utils import data from tqdm import tqdm @@ -214,7 +214,7 @@ def train(cfg, writer, logger): args = parser.parse_args() with open(args.config) as fp: - cfg = yaml.load(fp) + cfg = load(fp, Loader=Loader) run_id = random.randint(1, 100000) logdir = os.path.join("runs", os.path.basename(args.config)[:-4], str(run_id))