Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New features] add early stop in training #3558

Merged
merged 3 commits on
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions paddleseg/core/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def train(model,
save_dir='output',
iters=10000,
batch_size=2,
early_stop=0,
resume_model=None,
save_interval=1000,
log_iters=10,
Expand Down Expand Up @@ -117,6 +118,7 @@ def train(model,
local_rank = paddle.distributed.ParallelEnv().local_rank

start_iter = 0
stop_count = 0
if resume_model is not None:
start_iter = resume(model, optimizer, resume_model)

Expand Down Expand Up @@ -354,15 +356,23 @@ def train(model,

if val_dataset is not None:
if mean_iou > best_mean_iou:
stop_count = 0
best_mean_iou = mean_iou
best_model_iter = iter
best_model_dir = os.path.join(save_dir, "best_model")
paddle.save(
model.state_dict(),
os.path.join(best_model_dir, 'model.pdparams'))
logger.info(
'[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
.format(best_mean_iou, best_model_iter))
elif mean_iou < best_mean_iou:
stop_count += 1
if early_stop > 0 and stop_count >= early_stop:
logger.info(
'Early stopping at iter {}. The best mean IoU is {:.4f}.'
.format(iter, best_mean_iou))
else:
logger.info(
'[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
.format(best_mean_iou, best_model_iter))
if use_ema:
if ema_mean_iou > best_ema_mean_iou:
best_ema_mean_iou = ema_mean_iou
Expand Down
6 changes: 6 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def parse_args():
help='Maximum number of checkpoints to save.',
type=int,
default=5)
parser.add_argument(
'--early_stop',
help='Whether to early stop when loss is not decreasing and max numbers.',
type=int,
default=0)

# Other params
parser.add_argument(
Expand Down Expand Up @@ -187,6 +192,7 @@ def main(args):
save_dir=args.save_dir,
iters=cfg.iters,
batch_size=cfg.batch_size,
early_stop=args.early_stop,
resume_model=args.resume_model,
save_interval=args.save_interval,
log_iters=args.log_iters,
Expand Down