How do I compute validation loss during training? #810
That's correct, and it's because the default dataloader for the test set does not include ground truth (see detectron2/detectron2/data/dataset_mapper.py, lines 109 to 112 at commit ee0cbd8).
You can provide a custom mapper to the loader so that the ground-truth annotations are kept.
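A minimal sketch of that (not the maintainers' exact snippet): build the validation loader with a DatasetMapper in training mode so the annotations survive.

```python
from detectron2.data import DatasetMapper, build_detection_test_loader

# Assumption: cfg is your detectron2 config and cfg.DATASETS.TEST[0] is the
# registered validation set; any registered dataset name works here.
val_loader = build_detection_test_loader(
    cfg, cfg.DATASETS.TEST[0], mapper=DatasetMapper(cfg, is_train=True)
)
```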
That's because you're not calling …
Ah, copy-and-paste error. It's working now, thanks for the assist. Cheers,
Hi @tshead2, after creating the hook class, I performed the following:
I still get the same error. Do I have to create my own mapper function? Can you provide me a template? Thanks.
Hi, I have a hacky solution for this; I'll leave it here in case anyone needs it or someone has suggestions on how to improve it.

```python
import torch

from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm

cfg.DATASETS.VAL = ("voc_2007_val",)


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        # Point the cloned config's TRAIN datasets at the validation set, so the
        # train loader (which keeps ground truth) yields validation images.
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                 **loss_dict_reduced)
```

And then:
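(The registration snippet was not preserved in this copy of the thread; below is a sketch of how the hook could be attached, assuming a plain DefaultTrainer. The hook-order swap is only there so the PeriodicWriter sees the validation scalars that were just written.)

```python
from detectron2.engine import DefaultTrainer

trainer = DefaultTrainer(cfg)
val_loss = ValidationLoss(cfg)
trainer.register_hooks([val_loss])
# register_hooks appends ValidationLoss after PeriodicWriter; swap the last two
# hooks so the validation scalars are already in storage when the writer flushes.
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()
```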
Hi @tshead2, could you please mention the copy-paste error? How did you get it to work using …
Hi, I have written it and commented the code, you can see it here: …
@ortegatron Aren't you accumulating gradients in your implementation?
@alono88 Can you please suggest how that could be happening? At the code level, I'm just doing the sum on each iteration, but maybe I'm missing something at a general level about how gradients behave.
@ortegatron My mistake. It seems in this discussion that using …
@ortegatron, first, thank you for your code, it's very helpful! I have the same question as @alono88: in your code, shouldn't the model be switched to eval mode (model.eval()) when computing the validation loss? Or did I miss something? Once again, thanks!
@wesleylp eval mode does not affect gradient accumulation; it adjusts layers such as dropout. In addition, using eval mode will cause the model to output predictions instead of loss values, so you would have nothing to write. @ortegatron I was trying to run your code using multiple GPUs and it does not work. Have you had experience with such a setting, or did you run it on a single GPU?
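To illustrate the point about modes (a sketch, not from the original comment; model and batched_inputs are placeholder names for a detectron2 GeneralizedRCNN and a list of dataset dicts that still contain ground-truth instances):

```python
model.train()                    # training mode: forward() returns a dict of losses
with torch.no_grad():            # no_grad alone is what avoids building the graph
    loss_dict = model(batched_inputs)

model.eval()                     # eval mode: forward() returns predictions instead
with torch.no_grad():
    predictions = model(batched_inputs)   # list of {"instances": Instances}
```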
Hi Alono, nice that you answered; I was about to look into wesleylp's question. I have only tried it on a single GPU, no idea what changes multiple GPUs would imply, sorry.
Hi, I tried your code, but after running validation it just hangs and does not run anything else. Please help me, thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete
@dangmanhtruong1995 Hi, have you solved it?
Hi, I have not been able to solve it.
I copied the idea from @mnslarcher and wrote the following two functions for my keypoint detector (ResNet-50) algorithm:
Then in …
After 1k iterations, …
I just wondered if there is a way to calculate the validation loss only every 500 iterations instead of every 20. I found that your code even works on my multi-GPU setup, but calculating the validation loss every 20 iterations is very costly time-wise.
Hi @bconsolvo-zvelo, it's been a while since I've played with this library, but something like this PROBABLY (I'm not 100% sure) works:

```python
YOUR_MAGIC_NUMBER = 42


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self._loader = iter(build_detection_train_loader(self.cfg))
        self.num_steps = 0

    def after_step(self):
        self.num_steps += 1
        if self.num_steps % YOUR_MAGIC_NUMBER == 0:
            data = next(self._loader)
            with torch.no_grad():
                loss_dict = self.trainer.model(data)

                losses = sum(loss_dict.values())
                assert torch.isfinite(losses).all(), loss_dict

                loss_dict_reduced = {"val_" + k: v.item() for k, v in
                                     comm.reduce_dict(loss_dict).items()}
                losses_reduced = sum(loss for loss in loss_dict_reduced.values())
                if comm.is_main_process():
                    self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                     **loss_dict_reduced)
        else:
            pass
```

Now, I'm sure you can do a lot better than this. For example, you probably don't have to re-define a concept like "num_steps", and instead of hardcoding a number you can have something like this:

```python
cfg.VAL_INTERVAL = 42
...
if self.num_steps % self.cfg.VAL_INTERVAL == 0:
```

I didn't test this solution, so sorry if you find out that for some reason it doesn't work. In case it does work, or in case you find a better solution, please comment here so others can also benefit from it.
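On the "you probably don't have to re-define num_steps" point: detectron2's trainer already tracks the current iteration, so one hedged (untested) variant is to drop the extra counter and key the guard off trainer.iter, keeping the same loss computation as in the hook above:

```python
    def after_step(self):
        # trainer.iter is maintained by the training loop; +1 so the check lines
        # up with "every N completed iterations".
        if (self.trainer.iter + 1) % self.cfg.VAL_INTERVAL != 0:
            return
        # ... same torch.no_grad() loss computation as in the hook above ...
```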
Thank you for your comments. For whatever reason, I am finding that calculating the validation loss at different steps produces different validation loss results (drastic, orders-of-magnitude differences). Not sure if it is my setup/data or something inherent in the code, but I am trying to resolve it. I have also heard the suggestion of not using hooks, but rather using "run_step", as seen here: https://tshafer.com/blog/2020/06/detectron2-eval-loss. Still investigating. Thank you for your prompt reply.
For some reason TensorBoard will not display the validation loss at all when using mnslarcher's code (only run if self.num_steps % MAGIC_NUM == 0): the val_losses are computed and shown in the console, but somehow TensorBoard does not pick them up. Validation losses show fine if the computation runs on every call.
@mnslarcher Some other questions:
Thanks!
That said, I'm not sure; I'm not an expert on this library and I haven't used it in a long time, so it's better if you open a specific issue so someone more expert than me can answer your questions.
Did you find the solution to this issue?
Hi, no I have not. |
Fix to the issue.
Hi, all!
As the validation loss hook I use a slightly modified version of the code earlier in this thread.
Then, to tie things together, I use …
You need a custom data mapper, something like this:
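The snippet in this comment was not preserved; below is a hedged reconstruction of what such a mapper typically looks like (names like val_mapper and the resize parameters are illustrative, not from the original comment):

```python
import copy

import torch
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data import build_detection_test_loader


def val_mapper(dataset_dict):
    # Work on a copy so the cached dataset dicts are not mutated.
    dataset_dict = copy.deepcopy(dataset_dict)
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    image, transforms = T.apply_transform_gens(
        [T.ResizeShortestEdge(short_edge_length=800, max_size=1333)], image)
    dataset_dict["image"] = torch.as_tensor(
        image.transpose(2, 0, 1).astype("float32"))
    # Keep the annotations and turn them into Instances so the model can return
    # losses for this batch instead of predictions.
    annos = [
        utils.transform_instance_annotations(obj, transforms, image.shape[:2])
        for obj in dataset_dict.pop("annotations")
    ]
    dataset_dict["instances"] = utils.annotations_to_instances(annos, image.shape[:2])
    return dataset_dict


# Usage (assumes "my_val_set" is a registered dataset name):
# val_loader = build_detection_test_loader(cfg, "my_val_set", mapper=val_mapper)
```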
Hi, how can I write code to implement early stopping based on the validation loss?
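Not from this thread, but since the question comes up: one rough sketch of an early-stopping hook that watches the scalar a ValidationLoss-style hook writes into the trainer storage. The metric name "total_val_loss" and the patience value are assumptions; adapt them to your setup.

```python
from detectron2.engine import HookBase


class EarlyStoppingHook(HookBase):
    def __init__(self, patience=5, metric="total_val_loss"):
        super().__init__()
        self.patience = patience
        self.metric = metric
        self.best = float("inf")
        self.bad_checks = 0

    def after_step(self):
        try:
            latest = self.trainer.storage.history(self.metric).latest()
        except KeyError:
            return  # the validation hook has not logged anything yet
        # Note: this runs every iteration; if validation is only computed every N
        # steps, gate this check on the same interval so patience counts checks.
        if latest < self.best:
            self.best = latest
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        if self.bad_checks >= self.patience:
            # detectron2 has no built-in "stop training now" switch; raising is the
            # bluntest way out, so catch this around trainer.train() in your script.
            raise RuntimeError("Early stopping: validation loss stopped improving.")
```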
It doesn't work for me.
@marijnl |
Another solution for this problem, without creating a new custom dataset loader, is to tell build_detection_test_loader to use a DatasetMapper with is_train=True. I prefer this way because I don't want to manipulate the config, and it also works if you, like me, freeze your config node.

```python
import torch
import detectron2.utils.comm as comm
from detectron2.engine import HookBase
from detectron2.data import DatasetMapper, build_detection_test_loader


class ValLossHook(HookBase):
    def __init__(self, cfg, validation_set_key):
        super().__init__()
        self.cfg = cfg.clone()
        # Note: this version reads the dataset names from cfg.DATASETS.TEST, not
        # validation_set_key; the is_train=True mapper keeps ground-truth annotations.
        self._loader = iter(build_detection_test_loader(
            self.cfg, self.cfg.DATASETS.TEST,
            mapper=DatasetMapper(self.cfg, is_train=True)))

    def after_step(self):
        """
        After each step calculates the validation loss and adds it to the train storage
        """
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict
            loss_dict_reduced = {"validation_" + k: v.item()
                                 for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(validation_total_loss=losses_reduced,
                                                 **loss_dict_reduced)
```
@mnslarcher don't you have to call …
PeriodicWriter is a child class of HookBase, as defined in detectron2/engine/hooks.py.
With the hook below, I run into an error after a while. Please see below the ValLossHook class I am using:

```python
class ValLossHook(HookBase):
    def __init__(self, cfg, validation_set_key):
        super().__init__()
        self.cfg = cfg.clone()
        self._loader = iter(build_detection_test_loader(
            self.cfg, validation_set_key,
            mapper=DatasetMapper(self.cfg, is_train=True),
            num_workers=1))

    def after_step(self):
        """
        After each step calculates the validation loss and adds it to the train storage
        """
        print(type(self._loader), len(self._loader))  # just for debugging
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict
            loss_dict_reduced = {"val_" + k: v.item()
                                 for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(val_total_loss=losses_reduced,
                                                 **loss_dict_reduced)
```

Below is the error trace:

What is wrong here? Thanks a lot.
I think it is because you are using a …
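If the cause is indeed that build_detection_test_loader gives a finite loader (one pass over the dataset, unlike the infinite train loader), one hedged workaround is to keep the DataLoader itself in the hook and rebuild its iterator whenever it runs out:

```python
# Sketch only: assumes __init__ stores the loader, e.g.
#   self._data_loader = build_detection_test_loader(...)
# and that after_step calls self._next_val_batch() instead of next(self._loader).
def _next_val_batch(self):
    try:
        return next(self._iter)
    except (StopIteration, AttributeError):
        # Start (or restart) a fresh pass over the validation set.
        self._iter = iter(self._data_loader)
        return next(self._iter)
```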
@Lihewin do you know how, with your code, I could run the validation evaluation not after each iteration but, for example, after every 1000?
One idea can be using a counter inside after_step, as in the snippet earlier in this thread.
@ppwwyyxx I'm sorry to drag you back to this closed issue, but the continued activity in this thread suggests users are still having difficulty with this problem. Trying to calculate the validation loss is such a common use case that it would be very useful to have a canonical response in the documentation. Would you be receptive to a pull request formalising one of the solutions above for inclusion in the docs?
@ppwwyyxx
Hi, have you solved the problem? Hope you are doing well!
What problem exactly are you trying to solve? Based on the info in this post and others, I was able to compute the validation loss. I used only a single GPU, however. Working code is available here: ravijo/detectron2_tutorial
First of all, thank you for your reply. I want to train on multiple graphics cards and output the validation loss and evaluation results. After using the above code, the console prints the value of validation_loss, but there is no numerical record of the validation loss in metrics.json, so there is no validation loss curve in the plot after running PlotTogether.py. I am very troubled; it seems that only my output is different from others', because I did not see anyone else with the same question. Thank you again!
I see. I can't say about training on multiple graphics cards, but maybe you want to try on a single GPU first and then scale to multiple GPUs. Using ravijo/detectron2_tutorial, I was able to get the plots in TensorBoard. A single GPU was enough for my training, so I did not explore multiple GPUs. Hope it helps.
Thank you very much, I will try as you said! Hope you are doing well every day!
@ravijo Hi, I am using the method below:
A single GPU works fine: I can output validation_loss and the evaluation results and log them, but when training on multiple GPUs the output records are wrong. I think I just need to find and modify the part of the code that handles output storage. Thanks again for helping me!
For this issue, you may check the following working demo: detectron2_tutorial
Solved! Thank you so much!
@ravijo is there a reason your ValLossHook class's after_step method doesn't loop over the batches returned by the data loader (i.e. self._loader)? I expected the method to loop over all the batches and compute an average or total loss.
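Not ravijo's actual code, but a sketch of how after_step could average the loss over the full validation loader instead of a single batch. It assumes self._data_loader is a finite loader built with a train-mode DatasetMapper, self._period is how often to run it (a full pass is slow), and a single GPU (no comm.reduce_dict here).

```python
    def after_step(self):
        if (self.trainer.iter + 1) % self._period != 0:
            return
        total_loss, num_batches = 0.0, 0
        with torch.no_grad():
            for data in self._data_loader:
                loss_dict = self.trainer.model(data)
                total_loss += sum(v.item() for v in loss_dict.values())
                num_batches += 1
        if comm.is_main_process() and num_batches > 0:
            self.trainer.storage.put_scalar("val_loss_mean", total_loss / num_batches)
```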
Question 1: …
Solution: class ValidationLoss(HookBase): …
Question 2: …
Solution: class ValidationLoss(HookBase): …
Why not use …?
How do I compute validation loss during training?
I'm trying to compute the loss on a validation dataset for each iteration during training. To do so, I've created my own hook:
... which I register with a DefaultTrainer. The hook code is called during training, but fails with the following:
The traceback seems to imply that ground truth data is missing, which made me think that the data loader was the problem. However, switching to a training loader produces a different error:
As a sanity check, inference works just fine:
... but that isn't what I want, of course. Any thoughts?
Thanks in advance,
Tim