Validation during training (version 2) #828
Changes from 3 commits
@@ -0,0 +1,42 @@
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+  RESNETS:
+    BACKBONE_OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 2000
+    PRE_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+  ROI_MASK_HEAD:
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
+    PREDICTOR: "MaskRCNNC4Predictor"
+    POOLER_RESOLUTION: 14
+    POOLER_SAMPLING_RATIO: 2
+    RESOLUTION: 28
+    SHARE_BOX_FEATURE_EXTRACTOR: False
+  MASK_ON: True
+DATASETS:
+  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
+  TEST: ("coco_2014_minival",)
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+SOLVER:
+  BASE_LR: 0.02
+  WEIGHT_DECAY: 0.0001
+  STEPS: (60000, 80000)
+  MAX_ITER: 90000
+  TEST_PERIOD: 2500
I hope this isn't a silly question, but can you explain why you decided to change BASE_LR: 0.02 (default: 0.001), WEIGHT_DECAY: 0.0001 (default: 0.0005), and STEPS: (60000, 80000)? If this has been answered in a previous issue, I wouldn't mind being pointed to that discussion. Thank you for your time!
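For reference, the new SOLVER.TEST_PERIOD key is what drives the periodic evaluation added to the trainer below. Here is a minimal sketch of how it is read through the project's yacs-based config; the config file name is hypothetical, and the sketch assumes the PR also registers SOLVER.TEST_PERIOD (defaulting to 0, i.e. disabled) in maskrcnn_benchmark/config/defaults.py, a hunk not shown in this excerpt:

from maskrcnn_benchmark.config import cfg

# Load the training config shown above (file name is hypothetical).
cfg.merge_from_file("configs/e2e_mask_rcnn_R_50_FPN_1x_periodic_eval.yaml")
cfg.freeze()

# With MAX_ITER = 90000 and TEST_PERIOD = 2500, the validation pass in
# do_train() runs 90000 / 2500 = 36 times over the course of training.
print(cfg.SOLVER.TEST_PERIOD)                         # 2500
print(cfg.SOLVER.MAX_ITER // cfg.SOLVER.TEST_PERIOD)  # 36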
@@ -1,13 +1,15 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import datetime
 import logging
+import os
 import time
 
 import torch
 import torch.distributed as dist
 
-from maskrcnn_benchmark.utils.comm import get_world_size
+from maskrcnn_benchmark.utils.comm import get_world_size, synchronize
 from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+from maskrcnn_benchmark.engine.inference import inference
 
 from apex import amp
 
@@ -37,13 +39,16 @@ def reduce_loss_dict(loss_dict):
 
 
 def do_train(
+    cfg,
     model,
     data_loader,
+    data_loaders_val,
     optimizer,
     scheduler,
     checkpointer,
     device,
     checkpoint_period,
+    test_period,
     arguments,
 ):
     logger = logging.getLogger("maskrcnn_benchmark.trainer")
@@ -54,6 +59,14 @@ def do_train(
     model.train()
     start_training_time = time.time()
     end = time.time()
+
+    iou_types = ("bbox",)
+    if cfg.MODEL.MASK_ON:
+        iou_types = iou_types + ("segm",)
+    if cfg.MODEL.KEYPOINT_ON:
+        iou_types = iou_types + ("keypoints",)
+    dataset_names = cfg.DATASETS.TEST
+
     for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
         data_time = time.time() - end
         iteration = iteration + 1
@@ -107,6 +120,51 @@ def do_train(
                 )
             )
         if iteration % checkpoint_period == 0:
             checkpointer.save("model_{:07d}".format(iteration), **arguments)
+        if data_loaders_val is not None and test_period > 0 and iteration % test_period == 0:
+            synchronize()
+            for dataset_name, data_loader_val in zip(dataset_names, data_loaders_val):
+                _ = inference(  # The result can be used for additional logging, e.g. to TensorBoard
+                    model,
+                    data_loader_val,
+                    dataset_name=dataset_name,
+                    iou_types=iou_types,
+                    box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
+                    device=cfg.MODEL.DEVICE,
+                    expected_results=cfg.TEST.EXPECTED_RESULTS,
+                    expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
+                    output_folder=None,
+                )
+            synchronize()
+            model.train()
+            meters_val = MetricLogger(delimiter="  ")
+            with torch.no_grad():
+                for idx_val, (images_val, targets_val, _) in enumerate(data_loaders_val[0]):
+                    images_val = images_val.to(device)
+                    targets_val = [target.to(device) for target in targets_val]
+                    loss_dict = model(images_val, targets_val)
+                    losses = sum(loss for loss in loss_dict.values())
+                    loss_dict_reduced = reduce_loss_dict(loss_dict)
+                    losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+                    meters_val.update(loss=losses_reduced, **loss_dict_reduced)
+            synchronize()
+            logger.info(
+                meters.delimiter.join(
+                    [
+                        "[Validation]: ",
+                        "eta: {eta}",
+                        "iter: {iter}",
+                        "{meters}",
+                        "lr: {lr:.6f}",
+                        "max mem: {memory:.0f}",
+                    ]
+                ).format(
+                    eta=eta_string,
+                    iter=iteration,
+                    meters=str(meters_val),
+                    lr=optimizer.param_groups[0]["lr"],
+                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+                )
+            )
         if iteration == max_iter:
             checkpointer.save("model_final", **arguments)

Inline comments on the meters.delimiter.join( line:
Comment: Should …
Reply: It does not matter here because …
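The corresponding change to tools/train_net.py is not part of this excerpt. Below is a minimal sketch of how the validation loaders could be built and passed to the new do_train() signature. It assumes the existing make_data_loader(cfg, is_train=False, is_distributed=...) helper is used to build one loader per dataset in DATASETS.TEST, and the surrounding train() function is simplified, so the exact wiring in the PR may differ:

import torch

from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.trainer import do_train


def train(cfg, distributed, model, optimizer, scheduler, checkpointer, arguments):
    """Simplified sketch of the training driver; model/optimizer/etc. are built elsewhere."""
    device = torch.device(cfg.MODEL.DEVICE)

    # Loader over DATASETS.TRAIN (multiple names are concatenated into one dataset).
    data_loader = make_data_loader(
        cfg, is_train=True, is_distributed=distributed, start_iter=arguments["iteration"]
    )

    # Validation loaders (one per name in DATASETS.TEST), built only when
    # periodic testing is requested via SOLVER.TEST_PERIOD.
    test_period = cfg.SOLVER.TEST_PERIOD
    data_loaders_val = (
        make_data_loader(cfg, is_train=False, is_distributed=distributed)
        if test_period > 0
        else None
    )

    do_train(
        cfg,
        model,
        data_loader,
        data_loaders_val,
        optimizer,
        scheduler,
        checkpointer,
        device,
        cfg.SOLVER.CHECKPOINT_PERIOD,
        test_period,
        arguments,
    )
    return model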
I don't understand why we have to add two datasets in TRAIN: ("coco_2014_train", "coco_2014_valminusminival"). Only one dataset will be returned by the build_dataset function in maskrcnn_benchmark/data/build.py: datasets is a list, so the dataset is coco_2014_train, right?

And question 2: why did you delete VAL? From my point of view, TEST is TEST and VAL is VAL; they are datasets from different distributions, right?

Thank you so much for your work.
Regarding Question 1: in the highlighted code snippet, the datasets are concatenated if there is more than one dataset in the TRAIN field.

Regarding Question 2: as discussed in #785 (proposed by @fmassa), a separate validation dataset is rarely needed in this case, because you do not change hyperparameters once the training script works. After tuning the network you can take the best model variant (evaluated on the validation dataset, which is TEST here) and run tools/test_net.py with another dataset.
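For reference, the concatenation mentioned in the answer happens inside build_dataset in maskrcnn_benchmark/data/build.py. The toy sketch below shows just that step, not the real function's full signature or the catalog lookup it performs first:

import torch.utils.data as data


def concat_for_training(datasets, is_train=True):
    """Toy stand-in for the tail of build_dataset: merge datasets only for training."""
    # For testing/validation, keep one dataset (and later one loader) per name.
    if not is_train:
        return datasets

    # For training, merge everything listed under DATASETS.TRAIN into a single
    # dataset, so ("coco_2014_train", "coco_2014_valminusminival") behaves as one set.
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = data.ConcatDataset(datasets)
    return [dataset]

This is why only one (combined) dataset comes back for training even though two names are listed under TRAIN.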
Thank you for your patient and kind reply :)