
How do I compute validation loss during training? #810

Closed
tshead2 opened this issue Feb 5, 2020 · 56 comments

Comments

@tshead2

tshead2 commented Feb 5, 2020

How do I compute validation loss during training?

I'm trying to compute the loss on a validation dataset for each iteration during training. To do so, I've created my own hook:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_test_loader(config, dataset_name)
        
    def after_step(self):
        for batch in self._loader:
            loss = self.trainer.model(batch)
            log.debug(f"validation loss: {loss}")

... which I register with a DefaultTrainer. The hook code is called during training, but fails with the following:

INFO:detectron2.engine.train_loop:Starting training from iteration 0
ERROR:detectron2.engine.train_loop:Exception during training:
Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 133, in train
    self.after_step()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 153, in after_step
    h.after_step()
  File "<ipython-input-6-63b308743b7d>", line 8, in after_step
    loss = self.trainer.model(batch)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 123, in forward
    proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 164, in forward
    losses = {k: v * self.loss_weight for k, v in outputs.losses().items()}
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn_outputs.py", line 322, in losses
    gt_objectness_logits, gt_anchor_deltas = self._get_ground_truth()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn_outputs.py", line 262, in _get_ground_truth
    for image_size_i, anchors_i, gt_boxes_i in zip(self.image_sizes, anchors, self.gt_boxes):
TypeError: zip argument #3 must support iteration
INFO:detectron2.engine.hooks:Total training time: 0:00:00 (0:00:00 on hooks)

The traceback seems to imply that ground truth data is missing, which made me think that the data loader was the problem. However, switching to a training loader produces a different error:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_train_loader(config, dataset_name)
        
    def after_step(self):
        for batch in self._loader:
            loss = self.trainer.model(batch)
            log.debug(f"validation loss: {loss}")
INFO:detectron2.engine.train_loop:Starting training from iteration 0
ERROR:detectron2.engine.train_loop:Exception during training:
Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 133, in train
    self.after_step()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 153, in after_step
    h.after_step()
  File "<ipython-input-6-e0d2c509cc72>", line 7, in after_step
    for batch in self._loader:
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/data/common.py", line 109, in __iter__
    for d in self.dataset:
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/data/common.py", line 39, in __getitem__
    data = self._map_func(self._dataset[cur_idx])
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/utils/serialize.py", line 23, in __call__
    return self._obj(*args, **kwargs)
TypeError: 'str' object is not callable

INFO:detectron2.engine.hooks:Total training time: 0:00:00 (0:00:00 on hooks)

As a sanity check, inference works just fine:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_test_loader(config, dataset_name)
        
    def after_step(self):
        for batch in self._loader:
            with detectron2.evaluation.inference_context(self.trainer.model):
                loss = self.trainer.model(batch)
                log.debug(f"validation loss: {loss}")
INFO:detectron2.engine.train_loop:Starting training from iteration 0
DEBUG:root:validation loss: [{'instances': Instances(num_instances=100, image_height=720, image_width=720, fields=[pred_boxes = Boxes(tensor([[4.4867e+02, 1.9488e+02, 5.1496e+02, 3.9878e+02],
        [4.2163e+02, 1.1204e+02, 6.1118e+02, 5.5378e+02],
        [8.7323e-01, 3.0374e+02, 9.2917e+01, 3.8698e+02],
        [4.3202e+02, 2.0296e+02, 5.7938e+02, 3.6817e+02],
        ...

... but that isn't what I want, of course. Any thoughts?

Thanks in advance,
Tim

@ppwwyyxx
Contributor

ppwwyyxx commented Feb 5, 2020

The traceback seems to imply that ground truth data is missing,

That's correct and it's because the default dataloader for test-set does not include ground truth:

if not self.is_train:
    dataset_dict.pop("annotations", None)
    dataset_dict.pop("sem_seg_file_name", None)
    return dataset_dict

You can provide mapper= to create a dataloader that loads test data with ground truth.
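For example, a minimal sketch (assuming a registered dataset named "my_dataset_val"; the same pattern appears further down in this thread):

from detectron2.data import DatasetMapper, build_detection_test_loader

# is_train=True keeps the "annotations" field, so the model (in training mode) returns a loss dict
val_loader = build_detection_test_loader(cfg, "my_dataset_val", mapper=DatasetMapper(cfg, is_train=True))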

However, switching to a training loader produces a different error:

That's because you're not calling data.build_detection_train_loader following its API: https://detectron2.readthedocs.io/modules/data.html#detectron2.data.build_detection_train_loader
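A minimal sketch of one way to call it, cloning the config and pointing DATASETS.TRAIN at the validation set (as several answers below do; "my_dataset_val" is a placeholder name):

from detectron2.data import build_detection_train_loader

val_cfg = cfg.clone()
val_cfg.defrost()
val_cfg.DATASETS.TRAIN = ("my_dataset_val",)
val_loader = build_detection_train_loader(val_cfg)  # yields batched training-style data with ground truth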

@tshead2
Author

tshead2 commented Feb 5, 2020

Ah, copy-and-paste error. It's working now, thanks for the assist.

Cheers,
Tim

@tshead2 tshead2 closed this as completed Feb 5, 2020
@GorkemP

GorkemP commented Feb 21, 2020

Hi @tshead2,

After creating the hook class, I performed the following:

valLoss = ValidationLoss(cfg, 'my_validation_set')  
hooks = [valLoss]  
trainer.register_hooks(hooks)  
DefaultTrainer.build_test_loader(cfg, "my_validation_set")  

I still get the same error. Do I have to create my own mapper function? Can you provide me with a template?

Thanks.

@mnslarcher

Hi,

I have a hacky solution for this; I'll leave it here in case anyone needs it or someone has suggestions on how to improve it.

import torch

from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

cfg.DATASETS.VAL = ("voc_2007_val",)


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))
        
    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in 
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced, 
                                                 **loss_dict_reduced)

And then

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg) 
val_loss = ValidationLoss(cfg)  
trainer.register_hooks([val_loss])
# swap the order of PeriodicWriter and ValidationLoss
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()
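If the index-based swap feels fragile, an equivalent and more explicit reordering (the same idea marijnl uses further down in this thread) is to filter on the hook type; a small sketch:

from detectron2.engine.hooks import PeriodicWriter

# make sure PeriodicWriter runs last, so it also sees the validation scalars
periodic_writers = [h for h in trainer._hooks if isinstance(h, PeriodicWriter)]
other_hooks = [h for h in trainer._hooks if not isinstance(h, PeriodicWriter)]
trainer._hooks = other_hooks + periodic_writers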

@nihal-rao

nihal-rao commented Mar 21, 2020

Ah, copy-and-paste error. It's working now, thanks for the assist.

Cheers,
Tim

Hi @tshead2, could you please mention the copy-paste error? How did you get it to work using build_detection_train_loader?

@ortegatron

Hi, I have written it and commented the code, you can see it here:
https://medium.com/@apofeniaco/training-on-detectron2-with-a-validation-set-and-plot-loss-on-it-to-avoid-overfitting-6449418fbf4e
or just the gist here:
https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b

@alono88

alono88 commented Mar 24, 2020

@ortegatron Aren't you accumulating gradients in your implementation?

@ortegatron

@alono88 can you please explain how that could be happening? At the code level, I'm just doing the sum on each iteration, but maybe I'm missing something at a general level about how gradients behave.

@alono88

alono88 commented Mar 24, 2020

@ortegatron My mistake. It seems from this discussion that using torch.no_grad() only affects memory, and no intermediate tensors are stored.

@wesleylp

wesleylp commented Mar 25, 2020

@ortegatron, first thank you for your code, it's very helpful !

I have the same question as @alono88. In your code, shouldn't the model be switched to eval mode (model.eval()) somewhere so that you don't accumulate the gradients?
[inference_on_dataset does that by calling inference_context in evaluator.py].

Or did I miss something?

once again,
thank you!

@alono88

alono88 commented Mar 26, 2020

@wesleylp eval mode does not affect gradient accumulation; it adjusts layers such as dropout. In addition, using eval mode will cause the model to output predictions instead of loss values, so you will have nothing to write.

@ortegatron I was trying to run your code using multiple GPUs and it does not work. Have you had experience with such setting or did you run it on a single gpu?

@ortegatron

Hi Alono, nice of you to answer, I was about to look into wesleylp's question.

I have only tried it on a single GPU; no idea what changes multiple GPUs would imply, sorry.

@dangmanhtruong1995

dangmanhtruong1995 commented May 30, 2020

Hi Alono, nice that you answer, I was about to research about wesleylp question.

I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Hi, I tried your code, but after running validation it just hangs and does not run anything else. Please help me, thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

@pvti

pvti commented Jun 5, 2020

@dangmanhtruong1995 hi, have you solved it?

Hi Alono, nice that you answer, I was about to research about wesleylp question.
I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Hi I tried your code but after running validation it just hangs and does not run anything else. Please help me. Thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

@dangmanhtruong1995

@dangmanhtruong1995 hi, have you solved it?

Hi Alono, nice that you answer, I was about to research about wesleylp question.
I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Hi I tried your code but after running validation it just hangs and does not run anything else. Please help me. Thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

Hi, I have not been able to solve it.

@cognitiveRobot

cognitiveRobot commented Aug 26, 2020

I copied the idea from @mnslarcher and wrote the following two functions for my keypoint detector (resnet50) algorithm.

def build_valid_loader(cfg):
    _cfg = cfg.clone()
    _cfg.defrost()  # make this cfg mutable.
    _cfg.DATASETS.TRAIN = cfg.DATASETS.TEST
    return build_detection_train_loader(_cfg)

def store_valid_loss(model, data, storage):
    training_mode = model.training
    with torch.no_grad():
        loss_dict = model(data)
        losses = sum(loss_dict.values())
        assert torch.isfinite(losses).all(), loss_dict

        loss_dict_reduced = {k: v.item()
                                for k, v in comm.reduce_dict(loss_dict).items()}
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        if comm.is_main_process():
            storage.put_scalars(val_loss=losses_reduced, **loss_dict_reduced)
    model.train(training_mode)

Then in plain_train_net.py I am calling them as below:

    val_data_loader = build_valid_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, val_data, iteration in zip(data_loader, val_data_loader, range(start_iter, max_iter)):
            iteration = iteration + 1
            ..
            ..
           #At the end of the for loop.
           # Calculate and log validation loss.
            store_valid_loss(model, val_data, storage)

After 1k iterations, loss_keypoint is increasing, but total_loss is the same compared to a run without the store_valid_loss call. What am I missing? Can anyone please help me understand?

@bconsolvo-zvelo

Hi,

I have an hacky solution for this, I'll leave it here in case anyone needs it or someone has suggestions on how to improve it.

from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

cfg.DATASETS.VAL = ("voc_2007_val",)


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))
        
    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in 
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced, 
                                                 **loss_dict_reduced)

And then

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg) 
val_loss = ValidationLoss(cfg)  
trainer.register_hooks([val_loss])
# swap the order of PeriodicWriter and ValidationLoss
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()

I just wondered if there was a way to only calculate the validation loss every 500 iterations instead of every 20? I found that your code even works on my multi-GPU setup, but calculating the validation loss every 20 iterations is very costly time-wise.

@mnslarcher

mnslarcher commented Sep 24, 2020

Hi,
I have an hacky solution for this, I'll leave it here in case anyone needs it or someone has suggestions on how to improve it.

from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

cfg.DATASETS.VAL = ("voc_2007_val",)


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))
        
    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in 
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced, 
                                                 **loss_dict_reduced)

And then

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg) 
val_loss = ValidationLoss(cfg)  
trainer.register_hooks([val_loss])
# swap the order of PeriodicWriter and ValidationLoss
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()

I just wondered if there was a way to only calculate the validation loss every 500 iterations instead of every 20? I found that your code even works on my multi-GPU setup, but calculating the validation loss every 20 iterations is very costly time-wise.

Hi @bconsolvo-zvelo, it's been a while since I've played with this library, but something like this PROBABLY (I'm not 100% sure) works:

YOUR_MAGIC_NUMBER = 42

class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))
        self.num_steps = 0
        
    def after_step(self):
        self.num_steps += 1
        if self.num_steps % YOUR_MAGIC_NUMBER == 0:
            data = next(self._loader)
            with torch.no_grad():
                loss_dict = self.trainer.model(data)
            
                losses = sum(loss_dict.values())
                assert torch.isfinite(losses).all(), loss_dict

                loss_dict_reduced = {"val_" + k: v.item() for k, v in 
                                     comm.reduce_dict(loss_dict).items()}
                losses_reduced = sum(loss for loss in loss_dict_reduced.values())
                if comm.is_main_process():
                    self.trainer.storage.put_scalars(total_val_loss=losses_reduced, 
                                                     **loss_dict_reduced)
        else:
            pass

I'm sure you can do a lot better than this; for example, you probably don't have to re-define a concept like "num_steps", and instead of hardcoding a number you can have something like this:

cfg.VAL_INTERVAL = 42
...
 if self.num_steps % self.cfg.VAL_INTERVAL == 0:

I didn't test this solution, so sorry if you find out that for some reason it doesn't work. In case it does work, or in case you find a better solution, please comment here so others can also benefit from it.
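For reference, a variant of the same idea that reuses the trainer's own iteration counter instead of a separate num_steps (an untested sketch; cfg.TEST.EVAL_PERIOD is just one convenient place to store the interval):

def after_step(self):
    next_iter = self.trainer.iter + 1
    if self.cfg.TEST.EVAL_PERIOD > 0 and next_iter % self.cfg.TEST.EVAL_PERIOD == 0:
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict
            loss_dict_reduced = {"val_" + k: v.item()
                                 for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                 **loss_dict_reduced)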

@bconsolvo-zvelo

Thank you for your comments. For whatever reason, I am finding that calculating the validation loss at different intervals produces drastically different validation loss results (orders of magnitude apart). I am not sure whether it is my setup/data or something inherent in the code, but I am trying to resolve it.

I have also heard the suggestion of not using hooks, but rather using "run_step" as seen here: https://tshafer.com/blog/2020/06/detectron2-eval-loss

Still investigating. Thank you for your prompt reply.

@TobiasWeis

For some reason tensorboard will not display the validation loss at all when using mnslarcher's code (only running if self.num_steps % MAGIC_NUM == 0): the val losses are computed and shown in the console, but somehow tensorboard does not pick them up. Validation losses show up fine if the computation runs on every call.

@bconsolvo-zvelo

bconsolvo-zvelo commented Oct 2, 2020

@mnslarcher
On a 4-GPU setup, if I tell it to calculate the validation loss on the same iteration where I calculate my coco_eval results, it hangs indefinitely, just before finishing the inference calculation. Every other iteration works except the exact one where it is calculating the coco_eval inference. Just very strange behaviour. It also seems a bit odd that I now have to run inference on all of my validation data twice: once for the coco_eval results, and then on another iteration for calculating the validation loss. Both are doing inference and comparing against ground truth: coco_eval produces AP results, and the other produces just validation losses. It would be nice to combine them somehow, and to figure out why it breaks whenever I set cfg.TEST.EVAL_PERIOD to the same iteration where I am telling it to calculate the validation loss.

Some other questions:

  1. On the iteration where I tell it to calculate the validation loss, is it just not calculating the normal total loss, and only calculating validation loss?
  2. Can you elaborate on what this does below? I am confused by why you have to index things this way?
    trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
  3. Is there any way I can verify that it is really getting losses from all 4 GPUs and combining them?
  4. Why do you not use comm.synchronize()? I thought this was necessary for 4 GPUs.

Thanks!

@mnslarcher

Hi @bconsolvo-zvelo

  1. This hook will add "calculate val loss" to whatever your code is already doing. The training loss is calculated on each iteration by default, otherwise it would not be possible to back-propagate (calculated does not mean displayed).
  2. It's a trick to change the order in which the hooks are applied. As written in the comment "swap the order of PeriodicWriter and ValidationLoss": first calculate the loss, and then write everything (including the val loss).
    At the time I needed a solution for a single-GPU setting (Colab), so I didn't check whether it was also multi-GPU compatible. Having said that, consider that:
    • I took inspiration from "official code" (probably the code that calculates the train loss, I don't remember) that is supposed to also work multi-GPU
    • I suspect that comm.reduce_dict does the trick

That said, I'm not sure; I'm not an expert on this library and I haven't used it in a long time, so it's better to open a specific issue so that someone more expert than me can answer your questions.

@Syzygianinfern0

Hi Alono, nice that you answer, I was about to research about wesleylp question.
I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Hi I tried your code but after running validation it just hangs and does not run anything else. Please help me. Thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

Did you find the solution to this issue?

@dangmanhtruong1995

Hi Alono, nice that you answer, I was about to research about wesleylp question.
I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Hi I tried your code but after running validation it just hangs and does not run anything else. Please help me. Thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

Did you find the solution to this issue?

Hi, no I have not.

@Revist

Revist commented Mar 29, 2021

Hi Alono, nice that you answer, I was about to research about wesleylp question.

I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Fix to the issue.

def build_hooks(self):
    hooks = super().build_hooks()
    hooks.insert(-1, LossEvalHook(
        self.cfg.TEST.EVAL_PERIOD,
        self.model,
        build_detection_test_loader(
            self.cfg,
            self.cfg.DATASETS.TEST[0],
            DatasetMapper(self.cfg, True)
        )
    ))
    # swap the order of PeriodicWriter and ValidationLoss
    # the code hangs when the number of GPUs is > 1 if this line is removed
    hooks = hooks[:-2] + hooks[-2:][::-1]
    return hooks
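For context, LossEvalHook above is the hook from ortegatron's gist linked earlier in this thread; a minimal, simplified sketch of what such a hook can look like (names and details are assumptions, see the gist for the full version):

import torch
import detectron2.utils.comm as comm
from detectron2.engine import HookBase

class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._period = eval_period
        self._model = model
        self._data_loader = data_loader

    def _do_loss_eval(self):
        # run the model in training mode (so it returns losses) over the whole validation loader
        losses = []
        with torch.no_grad():
            for inputs in self._data_loader:
                loss_dict = self._model(inputs)
                loss_dict_reduced = comm.reduce_dict(loss_dict)
                losses.append(sum(v.item() for v in loss_dict_reduced.values()))
        mean_loss = sum(losses) / max(len(losses), 1)
        if comm.is_main_process():
            self.trainer.storage.put_scalar("validation_loss", mean_loss)
        comm.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            self._do_loss_eval()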

@Revist

Revist commented Mar 29, 2021

Hi Alono, nice that you answer, I was about to research about wesleylp question.
I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Hi I tried your code but after running validation it just hangs and does not run anything else. Please help me. Thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

Did you find the solution to this issue?

Fix to the issue.

def build_hooks(self):
    hooks = super().build_hooks()
    hooks.insert(-1, LossEvalHook(
        self.cfg.TEST.EVAL_PERIOD,
        self.model,
        build_detection_test_loader(
            self.cfg,
            self.cfg.DATASETS.TEST[0],
            DatasetMapper(self.cfg, True)
        )
    ))
    # swap the order of PeriodicWriter and ValidationLoss
    # the code hangs when the number of GPUs is > 1 if this line is removed
    hooks = hooks[:-2] + hooks[-2:][::-1]
    return hooks

@marijnl

marijnl commented Oct 4, 2021

I extended the code above to log both the train and val loss in the same graph in tensorboard. I put it here because I think it could be useful for others ending up here.

This is what your TB log will eventually look like (screenshot: train and val loss plotted together in one graph).

To do this, first create a custom tensorboard writer:

import os
from torch.utils.tensorboard import SummaryWriter
from detectron2.utils.events import EventWriter, get_event_storage


class CustomTensorboardXWriter(EventWriter):
    """
    Writes scalars and images based on storage key to train or val tensorboard file.
    """

    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
        """
        Args:
            log_dir (str): the base directory to save the output events. This class creates two subdirs in log_dir
            window_size (int): the scalars will be median-smoothed by this window size

            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
        """
        self._window_size = window_size
        
        # separate the writers into a train and a val writer
        train_writer_path = os.path.join(log_dir,"train")
        os.makedirs(train_writer_path, exist_ok=True)
        self._writer_train = SummaryWriter(train_writer_path, **kwargs)
        
        val_writer_path = os.path.join(log_dir,"val")
        os.makedirs(val_writer_path, exist_ok=True)
        self._writer_val = SummaryWriter(val_writer_path, **kwargs)

    def write(self):

        storage = get_event_storage()
        for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
            if k.startswith("val_"):
                k = k.replace("val_","")
                self._writer_val.add_scalar(k, v, iter)
            else:
                self._writer_train.add_scalar(k, v, iter)

        if len(storage._vis_data) >= 1:
            for img_name, img, step_num in storage._vis_data:
                if img_name.startswith("val_"):
                    img_name = img_name.replace("val_","")
                    self._writer_val.add_image(img_name, img, step_num)
                else:
                    self._writer_train.add_image(img_name, img, step_num)
            # Storage stores all image data and rely on this writer to clear them.
            # As a result it assumes only one writer will use its image data.
            # An alternative design is to let storage store limited recent
            # data (e.g. only the most recent image) that all writers can access.
            # In that case a writer may not see all image data if its period is long.
            storage.clear_images()

        if len(storage._histograms) >= 1:
            for params in storage._histograms:
                self._writer_train.add_histogram_raw(**params)
            storage.clear_histograms()

    def close(self):
        if hasattr(self, "_writer_train"):  # doesn't exist when the code fails at import
            self._writer_train.close()
            self._writer_val.close()

Then register this writer in your trainer. It will plot the train and val metrics in the same graph.

class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR,"inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)
    
    def build_writers(self):
        """
        Overwrites the default writers to contain our custom tensorboard writer

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return [
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            CustomTensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]
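For completeness, the two snippets above assume roughly these imports (hedged; adjust to your detectron2 version):

import os
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.utils.events import CommonMetricPrinter, JSONWriter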

@cpereir1

cpereir1 commented Oct 4, 2021

Hi, all!
Thanks for this great work.
@marijnl, could you also share your train_net and your ValidationHook and how those tie in together?
Many thanks!

@marijnl

marijnl commented Oct 6, 2021

As the validation loss hook I use a slightly modified version of the code earlier in this thread.

import torch
from detectron2.data.build import build_detection_test_loader
from detectron2.engine import HookBase
import detectron2.utils.comm as comm

class ValLossHook(HookBase):
    
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_test_loader(self.cfg, "my_dataset_val"))
        
    def after_step(self):
        """
            After each step calculates the validation loss and adds it to the train storage
        """
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(val_total_loss=losses_reduced, 
                                                 **loss_dict_reduced)

Then to tie things together I use:

    # setup trainer
    trainer = Trainer(cfg)

    # creates a hook that after each iter calculates the validation loss on the next batch
    # Register the hooks
    trainer.register_hooks(
        [ValLossHook(cfg)]
    )

    # The PeriodicWriter needs to be the last hook, otherwise it won't have access to the val loss metrics
    # Ensure PeriodicWriter is the last called hook
    periodic_writer_hook = [hook for hook in trainer._hooks if isinstance(hook, PeriodicWriter)]
    all_other_hooks = [hook for hook in trainer._hooks if not isinstance(hook, PeriodicWriter)]
    trainer._hooks = all_other_hooks + periodic_writer_hook

    trainer.resume_or_load(resume=args.resume)
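The snippet above assumes the corresponding imports (a guess at the usual ones; a NameError about PeriodicWriter, as reported further down, usually just means this import is missing):

from detectron2.engine import DefaultTrainer
from detectron2.engine.hooks import PeriodicWriter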

@marijnl

marijnl commented Oct 18, 2021

You need a custom data mapper, something like this:

import copy
import torch
from detectron2.data import detection_utils as utils
from detectron2.data.build import build_detection_test_loader

    def custom_test_mapper(dataset_dict):
        # it will be modified by code below
        dataset_dict = copy.deepcopy(dataset_dict)
        image = utils.read_image(dataset_dict["file_name"], format="BGR")
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).copy())
        annos = dataset_dict.pop("annotations")
        instances = utils.annotations_to_instances(annos, image.shape[:2])
        dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict

    def build_test_loader(cls, cfg, dataset_name="my_dataset_val"):
        return build_detection_test_loader(cfg, dataset_name, mapper=custom_test_mapper)

@MLDeep414

Hi,

I have an hacky solution for this, I'll leave it here in case anyone needs it or someone has suggestions on how to improve it.

from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

cfg.DATASETS.VAL = ("voc_2007_val",)


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))
        
    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in 
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced, 
                                                 **loss_dict_reduced)

And then

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = Trainer(cfg) 
val_loss = ValidationLoss(cfg)  
trainer.register_hooks([val_loss])
# swap the order of PeriodicWriter and ValidationLoss
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()

Hi,

How can I implement early stopping based on the validation loss?

@zensenlon

It doesn't work for me.

@HuygheB

HuygheB commented Jul 5, 2022

As the Validation loss hook i use a slightly modified version of the code earlier in this thread.

import torch
from detectron2.data.build import build_detection_test_loader
from detectron2.engine import HookBase
import detectron2.utils.comm as comm

class ValLossHook(HookBase):
    
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_test_loader(self.cfg, "my_dataset_val"))
        
    def after_step(self):
        """
            After each step calculates the validation loss and adds it to the train storage
        """
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(val_total_loss=losses_reduced, 
                                                 **loss_dict_reduced)

Then to tie things together i use

    # setup trainer
    trainer = Trainer(cfg)

    # creates a hook that after each iter calculates the validation loss on the next batch
    # Register the hoooks
    trainer.register_hooks(
        [ValLossHook(cfg)]
    )

    # The PeriodicWriter needs to be the last hook, otherwise it wont have access to valloss metrics 
    # Ensure PeriodicWriter is the last called hook
    periodic_writer_hook = [hook for hook in trainer._hooks if isinstance(hook, PeriodicWriter)]
    all_other_hooks = [hook for hook in trainer._hooks if not isinstance(hook, PeriodicWriter)]
    trainer._hooks = all_other_hooks + periodic_writer_hook

    trainer.resume_or_load(resume=args.resume)

@marijnl
Hi, I tried implementing your custom tensorboard writer along with the Validation Loss hook and training settings you provided. However, I get the following error:
'NameError: name 'PeriodicWriter' is not defined'
Any idea what the solution is to this?

@poonnatuch

poonnatuch commented Aug 29, 2022

Another solution for this problem, without creating a new custom dataset loader, is to tell DatasetMapper() to load the ground truth along with the images.

I prefer this way because I don't want to manipulate the config, which helps if you, like me, freeze your config node.

import torch
import detectron2.utils.comm as comm
from detectron2.data import DatasetMapper, build_detection_test_loader
from detectron2.engine import HookBase

class ValLossHook(HookBase):
    def __init__(self, cfg, validation_set_key):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_test_loader(self.cfg, validation_set_key, mapper=DatasetMapper(self.cfg, is_train=True)))

    def after_step(self):
        """
        After each step calculates the validation loss and adds it to the train storage
        """
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"validation_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(validation_total_loss=losses_reduced, **loss_dict_reduced)
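Registering it then looks like the earlier examples (a sketch; "my_dataset_val" is a placeholder dataset name, and DefaultTrainer / PeriodicWriter are the usual detectron2 classes):

from detectron2.engine import DefaultTrainer
from detectron2.engine.hooks import PeriodicWriter

trainer = DefaultTrainer(cfg)
trainer.register_hooks([ValLossHook(cfg, "my_dataset_val")])
# as in the earlier comments, keep the PeriodicWriter last so it also writes the validation scalars
trainer._hooks = [h for h in trainer._hooks if not isinstance(h, PeriodicWriter)] + [
    h for h in trainer._hooks if isinstance(h, PeriodicWriter)]
trainer.resume_or_load(resume=False)
trainer.train()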

@ricber

ricber commented Mar 13, 2023

@mnslarcher don't you have to call model.eval() in your after_step method to notify the batchnorm and dropout layers to work in eval mode? Otherwise you get inconsistent results in different runs...

@urbanophile

As the Validation loss hook i use a slightly modified version of the code earlier in this thread.

import torch
from detectron2.data.build import build_detection_test_loader
from detectron2.engine import HookBase
import detectron2.utils.comm as comm

class ValLossHook(HookBase):
    
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_test_loader(self.cfg, "my_dataset_val"))
        
    def after_step(self):
        """
            After each step calculates the validation loss and adds it to the train storage
        """
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(val_total_loss=losses_reduced, 
                                                 **loss_dict_reduced)

Then to tie things together i use

    # setup trainer
    trainer = Trainer(cfg)

    # creates a hook that after each iter calculates the validation loss on the next batch
    # Register the hoooks
    trainer.register_hooks(
        [ValLossHook(cfg)]
    )

    # The PeriodicWriter needs to be the last hook, otherwise it wont have access to valloss metrics 
    # Ensure PeriodicWriter is the last called hook
    periodic_writer_hook = [hook for hook in trainer._hooks if isinstance(hook, PeriodicWriter)]
    all_other_hooks = [hook for hook in trainer._hooks if not isinstance(hook, PeriodicWriter)]
    trainer._hooks = all_other_hooks + periodic_writer_hook

    trainer.resume_or_load(resume=args.resume)

@marijnl Hi, I tried implementing your custom tensorboard writer along with the Validation Loss hook and training settings you provided. However, I get the following error: 'NameError: name 'PeriodicWriter' is not defined' Any idea what the solution is to this?

PeriodicWriter is a child class of HookBase, as defined in
https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/hooks.py
so the NameError just means the import is missing, e.g. from detectron2.engine.hooks import PeriodicWriter.

@ravijo

ravijo commented Apr 20, 2023

With DatasetMapper and is_train=True, the code throws a StopIteration exception.

Please see below the ValLossHook:

class ValLossHook(HookBase):
    
    def __init__(self, cfg, validation_set_key):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_test_loader(self.cfg, validation_set_key, 
                                                        mapper=DatasetMapper(self.cfg, is_train=True), 
                                                        num_workers=1))

    def after_step(self):
        """
            After each step calculates the validation loss and adds it to the train storage
        """
        print(type(self._loader), len(self._loader)) # just for debugging
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(val_total_loss=losses_reduced, 
                                                 **loss_dict_reduced)

Below is the error trace:

<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
ERROR [04/20 18:17:12 d2.engine.train_loop]: Exception during training:
Traceback (most recent call last):
  File "/home/ravi/.local/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 150, in train
    self.after_step()
  File "/home/ravi/.local/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 180, in after_step
    h.after_step()
  File "/home/ravi/detectron2_examples/train/val_loss_hook.py", line 37, in after_step
    data = next(self._loader)
  File "/home/ravi/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/home/ravi/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1176, in _next_data
    raise StopIteration
StopIteration

What is wrong here?

Thanks a lot

@Lihewin

Lihewin commented Apr 22, 2023

With DatasetMapper and setting is_train=True, the code is throwing StopIteration exception.

Please see below the ValLossHook:

class ValLossHook(HookBase):
    
    def __init__(self, cfg, validation_set_key):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_test_loader(self.cfg, validation_set_key, 
                                                        mapper=DatasetMapper(self.cfg, is_train=True), 
                                                        num_workers=1))

    def after_step(self):
        """
            After each step calculates the validation loss and adds it to the train storage
        """
        print(type(self._loader), len(self._loader)) # just for debugging
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(val_total_loss=losses_reduced, 
                                                 **loss_dict_reduced)

Below is the error trace:

<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
<class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'> 5
ERROR [04/20 18:17:12 d2.engine.train_loop]: Exception during training:
Traceback (most recent call last):
  File "/home/ravi/.local/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 150, in train
    self.after_step()
  File "/home/ravi/.local/lib/python3.6/site-packages/detectron2/engine/train_loop.py", line 180, in after_step
    h.after_step()
  File "/home/ravi/detectron2_examples/train/val_loss_hook.py", line 37, in after_step
    data = next(self._loader)
  File "/home/ravi/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/home/ravi/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1176, in _next_data
    raise StopIteration
StopIteration

What is wrong here?

Thanks a lot

I think it is because you are using the build_detection_test_loader() function here, which returns a torchdata.DataLoader that goes through the validation set only once, so data = next(self._loader) fails to get data once the validation set has been consumed. You can confirm this by making a validation set that contains only one picture. My solution is to replace it with build_detection_train_loader(), which produces batched data indefinitely; this works in my project.
Please refer to build.py
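If you do want to keep build_detection_test_loader (for example, to iterate the validation set in order), another option is to store the loader itself and restart its iterator when it is exhausted; an untested sketch (assuming the hook keeps the DataLoader in self._data_loader and its iterator in self._loader):

def _next_val_batch(self):
    try:
        return next(self._loader)
    except StopIteration:
        # the test loader is finite: restart it once the validation set has been consumed
        self._loader = iter(self._data_loader)
        return next(self._loader)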

@ravijo

ravijo commented Apr 23, 2023

@Lihewin

Thanks a lot. build_detection_train_loader works like a charm.

BTW, do you have any suggestions/comment on #4922

I appriciate your time!

@geotsl

geotsl commented Apr 25, 2023

@Lihewin are you aware how with your code I could run the validation evaluation not after each iteration but for example after 1000?

@ravijo

ravijo commented Apr 25, 2023

@geotsl

One idea is to keep a counter inside ValLossHook and then use a conditional statement inside after_step to enable validation only periodically. However, tensorboard may not be happy with it. A workaround is to write the validation loss to the terminal.

@urbanophile

@ppwwyyxx I'm sorry to drag you back to this closed issue, but the continued activity in this thread suggests users are still having difficulty with this problem. Trying to calculate the validation loss is such a common use-case that it would be very useful to have a canonical response in the documentation. Would you be receptive to a pull request formalising one of the solutions above for inclusion in the docs?

@pkhateri

@ppwwyyxx
I would like to second this request!

@ppwwyyxx I'm sorry to drag you back to this closed issue but the continued activity in this thread suggests users are still having difficulty with this problem. Trying to calculate the validation loss is such a common use-case, it would be very useful to have an canonical response in the documentation. Would you be receptive to a pull request formalising one of the solutions above for inclusion in the docs?

@xxxming730

@mnslarcher On a 4xGPU setup, if I tell it to calculate the validation loss on the same iteration as I calculate my coco_eval results, it hangs indefinitely, just before finishing the inference calculation. Every other iteration works except on the exact one where it is calculating the coco_eval inference. Just very strange behaviour. It also seems a bit odd that now I have to calculate inference on all of my validation data twice: once for the coco_eval results, and then on another iteration for calculating the validation loss. Both are doing inference and comparing them to ground truth: coco_eval produces AP results, and the other produces just validation losses. Would be nice to combine somehow, and figure out why it is breaking whenever I put the cfg.TEST.EVAL_PERIOD as the same iteration as where I am telling it to calculate the validation loss.

Some other questions:

  1. On the iteration where I tell it to calculate the validation loss, is it just not calculating the normal total loss, and only calculating validation loss?
  2. Can you elaborate on what this does below? I am confused by why you have to index things this way?
    trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
  3. Is there any way I can verify that it is really getting losses from all 4 GPUs and combining them?
  4. Why do you not use comm.synchronize()? I thought this was necessary for 4 GPUs.

Thanks!

Hi, have you solved the problem? Hope you are doing well!

@ravijo

ravijo commented Jul 27, 2023

@xxxming730

What problem exactly are you trying to solve? Based on the info in this post and others, I was able to compute the validation loss. I used only a single GPU, however.

A working code is available here: ravijo/detectron2_tutorial

@xxxming730

xxxming730 commented Jul 27, 2023 via email

@ravijo

ravijo commented Jul 27, 2023

@xxxming730

I see. I can't say about the training on multiple graphics cards, but maybe you want to try on a single GPU first and then scale it to multiple GPUs. Using ravijo/detectron2_tutorial, I was able to get the plots on Tensorboard. My training was enough for a single GPU, so I did not explore multiple GPUs.

Hope it helps

@xxxming730

@xxxming730

I see. I can't say about the training on multiple graphics cards, but maybe you want to try on a single GPU first and then scale it to multiple GPUs. Using ravijo/detectron2_tutorial, I was able to get the plots on Tensorboard. My training was enough for a single GPU, so I did not explore multiple GPUs.

Hope it helps

Thank you very much, I will try as you said! Hope you are doing well every day!

@xxxming730

@xxxming730

I see. I can't say about the training on multiple graphics cards, but maybe you want to try on a single GPU first and then scale it to multiple GPUs. Using ravijo/detectron2_tutorial, I was able to get the plots on Tensorboard. My training was enough for a single GPU, so I did not explore multiple GPUs.

Hope it helps

@ravijo Hi, I am using this method below:

Hi, I have written it and commented the code, you can see it here: https://medium.com/@apofeniaco/training-on-detectron2-with-a-validation-set-and-plot-loss-on-it-to-avoid-overfitting-6449418fbf4e or just the gist here: https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b

A single GPU is fine: I can output validation_loss and evaluation results and log them, but when training on multiple GPUs, the logged output is wrong. I think I just need to look into and modify the part of the code that handles output storage. Thanks again for helping me!

@edoardounali

I extended the code above to log both the train and val loss in the same graph in tensorboard. I put it here because i think it could be useful for others ending up here.

This is what your TB log will look like eventually image

To do this, first create a custom tensorboard writer:

import os
from torch.utils.tensorboard import SummaryWriter
from detectron2.utils.events import EventWriter, get_event_storage


class CustomTensorboardXWriter(EventWriter):
    """
    Writes scalars and images based on storage key to train or val tensorboard file.
    """

    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
        """
        Args:
            log_dir (str): the base directory to save the output events. This class creates two subdirs in log_dir
            window_size (int): the scalars will be median-smoothed by this window size

            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
        """
        self._window_size = window_size
        
        # separate the writers into a train and a val writer
        train_writer_path = os.path.join(log_dir,"train")
        os.makedirs(train_writer_path, exist_ok=True)
        self._writer_train = SummaryWriter(train_writer_path, **kwargs)
        
        val_writer_path = os.path.join(log_dir,"val")
        os.makedirs(val_writer_path, exist_ok=True)
        self._writer_val = SummaryWriter(val_writer_path, **kwargs)

    def write(self):

        storage = get_event_storage()
        for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
            if k.startswith("val_"):
                k = k.replace("val_","")
                self._writer_val.add_scalar(k, v, iter)
            else:
                self._writer_train.add_scalar(k, v, iter)

        if len(storage._vis_data) >= 1:
            for img_name, img, step_num in storage._vis_data:
                if k.startswith("val_"):
                    k = k.replace("val_","")
                    self._writer_val.add_image(img_name, img, step_num)
                else:
                    self._writer_train.add_image(img_name, img, step_num)
            # Storage stores all image data and rely on this writer to clear them.
            # As a result it assumes only one writer will use its image data.
            # An alternative design is to let storage store limited recent
            # data (e.g. only the most recent image) that all writers can access.
            # In that case a writer may not see all image data if its period is long.
            storage.clear_images()

        if len(storage._histograms) >= 1:
            for params in storage._histograms:
                self._writer_train.add_histogram_raw(**params)
            storage.clear_histograms()

    def close(self):
        if hasattr(self, "_writer"):  # doesn't exist when the code fails at import
            self._writer_train.close()
            self._writer_val.close()

Then register this writer in your trainer. It will write plot train and val metrics in the same graph

class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR,"inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)
    
    def build_writers(self):
        """
        Overwrites the default writers to contain our custom tensorboard writer

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return [
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            CustomTensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

Hi, I followed every step, but the writer saves the val loss only every period defined by

cfg.TEST.EVAL_PERIOD = 50

and not every 20 steps (window_size).

Furthermore, the tensorboard scalars are plotted on two different graphs, not together under total_loss.

@ravijo

ravijo commented Oct 13, 2023

@edoardounali

tensorboard scalars are plotted on two different graphs, not together under total_loss.

For this issue, you may check the following working demo: detectron2_tutorial

@edoardounali

@edoardounali

tensorboard scalars are plotted on two different graphs, not together under total_loss.

For this issue, you may check the following working demo: detectron2_tutorial

Solved! Thanks you so much!

@CA4GitHub

@ravijo is there a reason your ValLossHook class after_step method doesn't loop over the batches returned by the data loader (i.e. self._loader)? I expected the method to loop over all the batches and compute an average or total loss.

@CHN-001

CHN-001 commented Jun 18, 2024

Question 1:
File "/home/server/anaconda3/envs/bcnet/lib/python3.7/site-packages/yacs/config.py", line 147, in __setattr__
    name, value
AttributeError: Attempted to set TRAIN to ('coco_my_val',), but CfgNode is immutable

Solution:
I solved it by modifying the code:

class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.defrost()  # Unfreeze the config to allow modifications
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self.cfg.freeze()  # Refreeze the config after modifications
        self._loader = iter(build_detection_train_loader(self.cfg))

Question 2:
File "/home/server/anaconda3/envs/bcnet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 2 required positional arguments: 'c_iter' and 'max_iter'

Solution:
I solved it by modifying the code:

class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.defrost()  # Unfreeze the config to allow modifications
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self.cfg.freeze()  # Refreeze the config after modifications
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            c_iter = self.trainer.iter
            max_iter = self.trainer.max_iter  # get the current iteration and the max iteration

            # pass c_iter and max_iter when calling the model
            loss_dict = self.trainer.model(data, c_iter=c_iter, max_iter=max_iter)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                 **loss_dict_reduced)

@ahmadsadeed

Hi Alono, nice that you answer, I was about to research about wesleylp question.
I have only tried it con single gpu, no idea what changes would multiple cpu imply, sorry

Hi I tried your code but after running validation it just hangs and does not run anything else. Please help me. Thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

Did you find the solution to this issue?

Fix to the issue.

def build_hooks(self):
    hooks = super().build_hooks()
    hooks.insert(-1,LossEvalHook(
        cfg.TEST.EVAL_PERIOD,
        self.model,
        build_detection_test_loader(
            self.cfg,
            self.cfg.DATASETS.TEST[0],
            DatasetMapper(self.cfg,True)
        )
    ))
    # swap the order of PeriodicWriter and ValidationLoss
    # code hangs with no GPUs > 1 if this line is removed
    hooks = hooks[:-2] + hooks[-2:][::-1]
    return hooks

Why not use hooks.append(), so that the swapping isn't needed?
