This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Validation during training (version 2) #828

Merged (4 commits) on Sep 29, 2019
Changes from 3 commits
42 changes: 42 additions & 0 deletions configs/e2e_mask_rcnn_R_50_FPN_1x_periodically_testing.yaml
@@ -0,0 +1,42 @@
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  ROI_MASK_HEAD:
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
    PREDICTOR: "MaskRCNNC4Predictor"
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 2
    RESOLUTION: 28
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
@qihao-huang commented on May 30, 2019:

I don't understand why we have to add two datasets in TRAIN: ("coco_2014_train", "coco_2014_valminusminival").

Only one dataset will be returned by the build_dataset function in maskrcnn_benchmark/data/build.py:

    # for training, concatenate all datasets into a single one
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)
    return [dataset]

datasets is a list, so dataset is coco_2014_train, right?

And, Question 2:
Why did you delete the VAL? From my perspective, TEST is TEST and VAL is VAL; they are datasets with different distributions, right?

Thank you so much for your work.

Author replied:

Regarding Question 1:
In the highlighted code snippet the datasets are concatenated when there is more than one dataset in the TRAIN field, so training draws samples from both.
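A minimal sketch of that concatenation, using tiny TensorDatasets as illustrative stand-ins for the two COCO splits (here D is torch.utils.data; in build.py it is the repo's own datasets module):

    # Illustrative only: tensors stand in for coco_2014_train / coco_2014_valminusminival.
    import torch
    import torch.utils.data as D

    train_split = D.TensorDataset(torch.arange(10))
    valminusminival_split = D.TensorDataset(torch.arange(4))

    datasets = [train_split, valminusminival_split]
    dataset = datasets[0]
    if len(datasets) > 1:                    # two names in TRAIN, so this branch runs
        dataset = D.ConcatDataset(datasets)  # samples are drawn from both splits
    assert len(dataset) == 10 + 4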

Regarding Question 2:
As discussed in #785 (proposed by @fmassa), a separate validation dataset is rarely needed in this case, because you do not change hyperparameters once the training script works. After tuning the network you can pick the best model variant (evaluated on the validation dataset, which is TEST here) and then run tools/test_net.py on another dataset.
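A hedged sketch of that final step: override DATASETS.TEST before evaluation (the dataset name below is a hypothetical placeholder and would have to exist in your DatasetCatalog):

    # Assumption: cfg is the global yacs config object used across maskrcnn-benchmark.
    from maskrcnn_benchmark.config import cfg

    cfg.merge_from_file("configs/e2e_mask_rcnn_R_50_FPN_1x_periodically_testing.yaml")
    cfg.merge_from_list(["DATASETS.TEST", "('coco_2014_heldout_placeholder',)"])
    # tools/test_net.py applies the same merges from its command line, so the
    # equivalent override should work there without editing the YAML.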

@qihao-huang replied:

Thanks for your patient and kind reply :)

  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  TEST_PERIOD: 2500

A commenter asked:

I hope this isn't a silly question, but can you explain why you decided to set BASE_LR: 0.02 (default: 0.001), WEIGHT_DECAY: 0.0001 (default: 0.0005), and STEPS: (60000, 80000)? If this has been answered in a previous issue, I wouldn't mind being pointed to that discussion. Thank you for your time!

1 change: 1 addition & 0 deletions maskrcnn_benchmark/config/defaults.py
@@ -406,6 +406,7 @@
_C.SOLVER.WARMUP_METHOD = "linear"

_C.SOLVER.CHECKPOINT_PERIOD = 2500
_C.SOLVER.TEST_PERIOD = 0

# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
4 changes: 2 additions & 2 deletions maskrcnn_benchmark/data/build.py
@@ -104,7 +104,7 @@ def make_batch_data_sampler(
return batch_sampler


-def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
+def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0, is_for_period=False):
    num_gpus = get_world_size()
    if is_train:
        images_per_batch = cfg.SOLVER.IMS_PER_BATCH
@@ -152,7 +152,7 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):

    # If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later
    transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
-    datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)
+    datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train or is_for_period)

    data_loaders = []
    for dataset in datasets:
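Presumably the flag is combined as `is_train or is_for_period` because the periodic validation pass computes losses and therefore still needs ground-truth annotations; a hedged sketch of how the validation loaders would be requested (mirroring tools/train_net.py further below):

    # Assumption: cfg has been loaded as in tools/train_net.py.
    from maskrcnn_benchmark.config import cfg
    from maskrcnn_benchmark.data import make_data_loader

    data_loaders_val = make_data_loader(
        cfg,
        is_train=False,      # build from DATASETS.TEST with test-time batching
        is_distributed=False,
        is_for_period=True,  # but keep annotations so validation losses can be computed
    )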
60 changes: 59 additions & 1 deletion maskrcnn_benchmark/engine/trainer.py
@@ -1,13 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import os
import time

import torch
import torch.distributed as dist

-from maskrcnn_benchmark.utils.comm import get_world_size
+from maskrcnn_benchmark.utils.comm import get_world_size, synchronize
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.engine.inference import inference

from apex import amp

@@ -37,13 +39,16 @@ def reduce_loss_dict(loss_dict):


def do_train(
    cfg,
    model,
    data_loader,
    data_loaders_val,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    test_period,
    arguments,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
@@ -54,6 +59,14 @@
    model.train()
    start_training_time = time.time()
    end = time.time()

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)
    dataset_names = cfg.DATASETS.TEST

    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
@@ -107,6 +120,51 @@
        )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if data_loaders_val is not None and test_period > 0 and iteration % test_period == 0:
            synchronize()
            for dataset_name, data_loader_val in zip(dataset_names, data_loaders_val):
                _ = inference(  # The result can be used for additional logging, e.g. to TensorBoard
                    model,
                    data_loader_val,
                    dataset_name=dataset_name,
                    iou_types=iou_types,
                    box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
                    device=cfg.MODEL.DEVICE,
                    expected_results=cfg.TEST.EXPECTED_RESULTS,
                    expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                    output_folder=None,
                )
            synchronize()
            model.train()
            meters_val = MetricLogger(delimiter=" ")
            with torch.no_grad():
                for idx_val, (images_val, targets_val, _) in enumerate(data_loaders_val[0]):
                    images_val = images_val.to(device)
                    targets_val = [target.to(device) for target in targets_val]
                    loss_dict = model(images_val, targets_val)
                    losses = sum(loss for loss in loss_dict.values())
                    loss_dict_reduced = reduce_loss_dict(loss_dict)
                    losses_reduced = sum(loss for loss in loss_dict_reduced.values())
                    meters_val.update(loss=losses_reduced, **loss_dict_reduced)
            synchronize()
            logger.info(
                meters.delimiter.join(

A reviewer asked:

Should meters be meters_val?

Author replied:

It does not matter here because meters and meters_val have the same delimiter, but yes, ideally meters_val should be here. I'll fix this.

                    [
                        "[Validation]: ",
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters_val),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

9 changes: 9 additions & 0 deletions tools/train_net.py
@@ -72,16 +72,25 @@ def train(cfg, local_rank, distributed):
        start_iter=arguments["iteration"],
    )

    test_period = cfg.SOLVER.TEST_PERIOD
    if test_period > 0:
        data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed, is_for_period=True)
    else:
        data_loaders_val = None

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        cfg,
        model,
        data_loader,
        data_loaders_val,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        test_period,
        arguments,
    )
