diff --git a/main.py b/main.py index 8626170..98d6255 100644 --- a/main.py +++ b/main.py @@ -119,8 +119,7 @@ def model_load(fn): if args.resume: print('Resuming model ...') model_load(args.resume) - optimizer.param_groups[0]['lr'] = args.lr - model.dropouti, model.dropouth, model.dropout, args.dropoute = args.dropouti, args.dropouth, args.dropout, args.dropoute + model.dropouti, model.dropouth, model.dropout, model.dropoute = args.dropouti, args.dropouth, args.dropout, args.dropoute if args.wdrop: from weight_drop import WeightDrop for rnn in model.rnns: