Question about GIoU #302

Closed
lyx190 opened this issue May 27, 2019 · 36 comments

Comments

@lyx190

lyx190 commented May 27, 2019

@glenn-jocher Thank you very much for your code.
I am attempting to use GIoU loss with yolov3 on my own single-class dataset, and I am confused about some related problems. Has anyone managed to replace lxy and lwh with GIoU loss to train the net?

I have mapped the x, y, w, h of the predictions and targets to the size of the feature map and then calculated the GIoU loss during training. But after training, recall reaches a high value like 90% while precision always stays very low, at 25%. I don't know which part causes this problem.

Does anyone have an idea, or has anyone managed to implement GIoU loss on yolov3?

@glenn-jocher
Member

@lyx190 ah this is very interesting, I had not heard about GIoU before. I read the paper, and it seems like you may need to do some regularization of the new GIoU loss to balance its magnitude with the existing losses (conf and class).

Independently of this, I'd advise you to test the default repo first so you have a proper way to test the difference, otherwise your metrics mean nothing by themselves.

@lyx190
Author

lyx190 commented May 27, 2019

@glenn-jocher thank you for your reply. I have trained on my own dataset with the default repo before. The result looks good, with 88% recall and 78% precision.

What's more, I am trying to fine-tune the values of 'cls', 'conf', and 'iou_t' in the hyperparameters.

As for my problem with high recall (90%) and low precision (20%), my idea is that it is perhaps caused by an imbalance between true and false samples during the training phase. What do you think?

Edited:
Following the thought above, I have changed 'iou_t' from 0.24 to 0.54, and it is training now. But the result does not look good: at epoch 32, recall is 30% and precision is 11.3%.

Besides, because the GIoU loss value ranges from 0 to 2, I have also normalized it to the range 0 to 1 to match the cls_loss and conf_loss values.
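
To make the 0 to 2 range concrete, here is a minimal pure-Python sketch of GIoU for two corner-format boxes, with the loss taken as 1 - GIoU. The function names and the `normalize` flag are illustrative assumptions, not code from the repository, and the boxes are assumed to have positive area.

```python
def giou_xyxy(a, b):
    """GIoU of two boxes given as (x1, y1, x2, y2); assumes positive-area boxes."""
    # Intersection rectangle (zero area if the boxes do not overlap)
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    # Smallest axis-aligned box that encloses both boxes
    cw = max(a[2], b[2]) - min(a[0], b[0])
    ch = max(a[3], b[3]) - min(a[1], b[1])
    c_area = cw * ch
    return inter / union - (c_area - union) / c_area


def giou_loss(a, b, normalize=False):
    """1 - GIoU, optionally halved so that it stays below 1 (hypothetical flag)."""
    loss = 1.0 - giou_xyxy(a, b)
    return loss / 2.0 if normalize else loss


identical = giou_loss((0, 0, 2, 2), (0, 0, 2, 2))            # perfect match
far_apart = giou_loss((0, 0, 1, 1), (100, 100, 101, 101))    # distant boxes
```

A perfect match gives a loss of 0, while two small boxes far apart approach (but never reach) 2, which is why halving the value maps it into the 0 to 1 range.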

@glenn-jocher
Member

@lyx190 ah ok. Hmm, those are very high numbers indeed you got before, and now the precision has dropped a lot. hyp['iou_t'] in particular affects the precision-recall balance. If you increase this hyperparameter your precision will increase (at the expense of recall). The new GIoU loss needs a hyperparameter to balance it against conf.

class loss in single-class datasets is always zero (if you've configured your *.cfg for 1 class), so in your case hyp['cls'] should not have any effect. Maybe I can try to implement a GIoU loss as well, it seems simple enough if I understand the paper correctly.

@lyx190
Author

lyx190 commented May 27, 2019

@glenn-jocher ah, I forgot to mention that I set hyp['giou'] and hyp['conf'] to 2.4 and 4.3.
But I am confused: would these two parameters really affect recall and precision so strongly?

Also, in the training phase, if I want to map the prediction values to the size of the feature map, is this correct?
for xy:

grid_xy = torch.cat((gi.view(-1, 1), gj.view(-1, 1)), 1).float()
pi_xy = torch.sigmoid(pi[..., :2]) + grid_xy

for wh:
pi_wh = torch.exp(pi[..., 2:4]) * (the size of corresponding anchors)
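
As a scalar illustration of the two mappings above, the following pure-Python sketch decodes one raw prediction into grid (feature-map) units. The function name and the sample anchor size are assumptions made for the example only.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def decode_box(tx, ty, tw, th, gi, gj, anchor_w, anchor_h):
    """Map raw outputs to an (x, y, w, h) box in grid (feature-map) units."""
    x = sigmoid(tx) + gi            # column index plus an offset in (0, 1)
    y = sigmoid(ty) + gj            # row index plus an offset in (0, 1)
    w = math.exp(tw) * anchor_w     # anchor width scaled by exp of the raw value
    h = math.exp(th) * anchor_h     # anchor height scaled by exp of the raw value
    return x, y, w, h


# Zero raw outputs land on the centre of cell (gi=4, gj=7) with exactly the anchor size.
box = decode_box(0.0, 0.0, 0.0, 0.0, gi=4, gj=7, anchor_w=3.0, anchor_h=5.0)
```

Note that gi is paired with x and gj with y here; swapping them silently shifts every box to the wrong cell.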

@glenn-jocher
Member

glenn-jocher commented May 27, 2019

Oh, you should probably make hyp['giou'] much smaller, maybe in the 0.1-0.5 range, since the current box regression hyperparameters are 0.2 and 0.1 for xy and wh.

hyp = {'xy': 0.2,  # xy loss gain
       'wh': 0.1,  # wh loss gain
       'cls': 0.04,  # cls loss gain
       'conf': 4.5,  # conf loss gain
       'iou_t': 0.5,  # iou target-anchor training threshold
       'lr0': 0.001,  # initial learning rate
       'lrf': -4.,  # final learning rate = lr0 * (10 ** lrf)
       'momentum': 0.90,  # SGD momentum
       'weight_decay': 0.0005}  # optimizer weight decay

I think the IOU metric should not be affected by the units used, so the IOUs could be in grid units or pixels. Off the top of my head the equations seem right, but you should be careful because i and j may be swapped for x and y, i.e. it's possible the xy gridpoints are ji and not ij.
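
The unit-independence claim can be checked with a small pure-Python sketch: scaling both boxes by the same factor leaves the IoU unchanged. The helper name and the stride of 32 are assumptions chosen for illustration.

```python
def iou_xyxy(a, b):
    """Plain IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union


stride = 32  # assumed pixels per grid cell, for illustration only
pred_grid = (1.0, 1.0, 3.0, 3.0)
target_grid = (2.0, 2.0, 4.0, 4.0)

# The same boxes expressed in pixels
pred_px = tuple(v * stride for v in pred_grid)
target_px = tuple(v * stride for v in target_grid)

iou_grid = iou_xyxy(pred_grid, target_grid)
iou_px = iou_xyxy(pred_px, target_px)
```

Both values equal 1/7, because intersection and union areas scale by the same factor and the ratio cancels it. This only holds when both boxes are in the same units.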

@lyx190
Author

lyx190 commented May 27, 2019

@glenn-jocher thanks for the reminder. I will check the xy gridpoints again.
Now I am going to test the giou_loss again using all of the current hyperparameters. If I get any results I will post them here.

@glenn-jocher
Member

glenn-jocher commented May 28, 2019

@lyx190 I started looking into a GIoU implementation. The code to export the bounding boxes is already present in models.py (for inference). This will create the boxes in units of pixels. If you want the output in grid units you simply skip the line 164 stride multiplication.

yolov3/models.py

Lines 157 to 167 in 9cf5ab0

else:  # inference
    io = p.clone()  # inference output
    io[..., 0:2] = torch.sigmoid(io[..., 0:2]) + self.grid_xy  # xy
    io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh  # wh yolo method
    # io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh  # wh power method
    io[..., 4:] = torch.sigmoid(io[..., 4:])  # p_conf, p_cls
    # io[..., 5:] = F.softmax(io[..., 5:], dim=4)  # p_cls
    io[..., :4] *= self.stride
    if self.nc == 1:
        io[..., 5] = 1  # single-class model https://github.com/ultralytics/yolov3/issues/235

Have you had any luck with your implementation?

@glenn-jocher
Member

glenn-jocher commented May 28, 2019

@lyx190 BTW you should be extremely careful about building your own grid_xy, it can be very complicated, particularly in the case of rectangular inference. Our validated code which does this in models.py is:

yolov3/models.py

Lines 247 to 250 in 9cf5ab0

# build xy offsets
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
self.grid_xy = torch.stack((xv, yv), 2).to(device).float().view((1, 1, ny, nx, 2))

Note that in the shape of grid_xy the vertical dimension (ny) comes first, before the horizontal dimension (nx), while the last axis stores the offsets in (x, y) order. This is in keeping with typical PyTorch image dimension conventions.
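
The layout can be mirrored with nested lists to see the indexing order without PyTorch. This is only an analogue of the tensor above, and `build_grid_xy` is a hypothetical helper name.

```python
def build_grid_xy(ny, nx):
    """Nested-list analogue of grid_xy: indexed [row j][column i], storing (x, y)."""
    return [[(float(i), float(j)) for i in range(nx)] for j in range(ny)]


grid = build_grid_xy(ny=2, nx=3)

# Index with the row (y) first, then read the stored pair back as (x, y).
x, y = grid[1][2]
```

Here grid[1][2] is the cell in row 1, column 2, and it stores (2.0, 1.0). Indexing is y-first while the stored pair is x-first, which is exactly where ij and ji are easy to mix up.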

@glenn-jocher
Member

glenn-jocher commented May 28, 2019

@lyx190 ok, this doesn't seem too hard. I think the fastest way to try is to output models.py raw values along with inference values. The target boxes are in xywh (normalized), while the inference outputs are in xywh (pixels). So we simply multiply the target boxes by the img_size, (target-anchor matches are already made so no new work needed there), and then simply pass the matched boxes to the existing bbox_iou() function to get IOUs (it can handle xywh format). Then determine convex area and implement GIoU loss, maybe using about 0.3 as the start of a hyperparameter search for hyp['giou'].
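
A minimal pure-Python sketch of the unit handling described above: scale a normalized xywh target to pixels, then compare it with a pixel-space prediction. The helper names and the 416-pixel image size are assumptions for illustration; the repository's bbox_iou() is not reproduced here.

```python
def xywh_to_xyxy(box):
    """Convert a centre-format (x, y, w, h) box to corners (x1, y1, x2, y2)."""
    x, y, w, h = box
    return (x - w / 2, y - h / 2, x + w / 2, y + h / 2)


def iou_xywh(a, b):
    """Plain IoU of two centre-format boxes expressed in the same units."""
    ax1, ay1, ax2, ay2 = xywh_to_xyxy(a)
    bx1, by1, bx2, by2 = xywh_to_xyxy(b)
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union


img_size = 416                                  # assumed square input size
target_norm = (0.5, 0.5, 0.25, 0.25)            # xywh, normalized to 0-1
target_px = tuple(v * img_size for v in target_norm)

pred_px = (234.0, 208.0, 104.0, 104.0)          # xywh, already in pixels
iou = iou_xywh(pred_px, target_px)
```

The prediction is shifted 26 pixels to the right of a 104-pixel-wide target, so the overlap is 78 x 104 and the IoU comes out at 0.6. Passing the normalized target without scaling would instead compare boxes in different units and give a meaningless value.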

@glenn-jocher
Member

@lyx190 I implemented giou loss in a new giou branch:
https://github.com/ultralytics/yolov3/tree/giou

I tuned the weighting hyperparameters manually a bit to hyp['giou'] = 0.02. I compared two trainings on the coco_64img.data file that comes with the repo. I was not able to exceed the default performance on this training set, but the loss seems to work well, and the branch seems to be a good starting point for experimentation. Note that in the branch, the xy loss reported is the giou loss.
[Image: results]

@lyx190
Author

lyx190 commented May 29, 2019

@glenn-jocher ah, that's nice. I am going to check the new branch now. Yesterday I printed the values of
MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) and GIoU(pi_xyxy, gt_xyxy).
The value of xy (MSE) is between 0.03 and 0.08, while GIoU() is between 0.3 and 0.9. So I changed hyp['giou_loss'] from 0.2 to 0.04 and trained the net, but the result was bad: recall and precision were only 30%.
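
As a rough back-of-the-envelope check of the scale mismatch, the sketch below asks what GIoU gain would give the same weighted contribution as the default xy term. It assumes the midpoints of the two reported ranges are representative, which is a simplification and not something measured in this thread.

```python
xy_gain = 0.2                    # hyp['xy'] from the defaults quoted in this thread
mse_mid = (0.03 + 0.08) / 2      # midpoint of the reported xy MSE range
giou_mid = (0.3 + 0.9) / 2       # midpoint of the reported GIoU() range

# Gain that makes gain * giou_mid equal to xy_gain * mse_mid
matched_giou_gain = xy_gain * mse_mid / giou_mid
```

Under these assumptions the matched gain is roughly 0.018, the same order of magnitude as the hyp['giou'] = 0.02 used in the giou branch and well below 0.2. It ignores the wh term that GIoU also replaces, so it is only a starting point for a search.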

@lyx190
Author

lyx190 commented May 29, 2019

@glenn-jocher BTW, I am going to take a look at the code you implemented. Your result looks much better than mine, so I think there may be something wrong with my implementation. My version modified for GIoU is below:

def compute_loss(p, targets, model, device):  # predictions, targets, model
    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
    lcls, lconf, lgiou = ft([0]), ft([0]), ft([0])
    tcls, indices, xyOriginal, whOriginal, grids, anchorsFeatureMap = build_targets(model, targets)

    # Define criteria
    # MSE = nn.MSELoss()
    CE = nn.CrossEntropyLoss()  # (weight=model.class_weights)
    BCE = nn.BCEWithLogitsLoss()
    GIoU = GIoU_loss()

    # Compute losses
    h = model.hyp  # hyperparameters
    bs = p[0].shape[0]  # batch size
    k = bs  # loss gain
    for i, pi0 in enumerate(p):  # layer i predictions, i

        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tconf = torch.zeros_like(pi0[..., 0])  # conf

        # Compute losses
        if len(b):  # number of targets
            pi = pi0[b, a, gj, gi]  # predictions closest to anchors
            tconf[b, a, gj, gi] = 1  # conf
            # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)

            grid_xy = torch.cat((gi.view(-1, 1), gj.view(-1, 1)), 1).float().to(device)
            pi_xy = (torch.sigmoid(pi[..., :2]) + grid_xy) / grids[i][0]
            pi_wh = torch.exp(pi[..., 2:4]) * anchorsFeatureMap[i][a] / grids[i][0]
            

            pi_xy_view = pi_xy.view(-1, 2)
            pi_wh_view = pi_wh.view(-1, 2)
            pi_xywh = torch.cat((pi_xy_view, pi_wh_view), 1)
            pi_xyxy = xywh2xyxy(pi_xywh)

            gt_xy = xyOriginal[i]
            gt_wh = whOriginal[i]
            gt_xywh = torch.cat((gt_xy, gt_wh), 1)
            gt_xyxy = xywh2xyxy(gt_xywh)
            # liou = (k * h['iou_loss']) * GIoU(pi_xyxy, gt_xyxy)[0]
            lgiou = (k * h['giou_loss']) * (GIoU(pi_xyxy, gt_xyxy)[1] / 2)

            # lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss
            # lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss
            lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i])  # class_conf loss

        # pos_weight = ft([gp[i] / min(gp) * 4.])
        # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        lconf += (k * h['conf']) * BCE(pi0[..., 4], tconf)  # obj_conf loss
    loss = lgiou + lconf + lcls

    return loss, torch.cat((lgiou, lconf, lcls, loss)).detach()


def build_targets(model, targets):
    # targets = [image, class, x, y, w, h]
    iou_thres = model.hyp['iou_t']  # hyperparameter
    if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
        model = model.module

    nt = len(targets)  # number of bounding boxes
    tcls, indices, original_xy, original_wh, grid_layers, anchor_vec = [], [], [], [], [], []
    for i in model.yolo_layers:
        layer = model.module_list[i][0]

        # iou of targets-anchors
        t, a = targets, []
        gwh = targets[:, 4:6] * layer.ng
        if nt:
            iou = [wh_iou(x, gwh) for x in layer.anchor_vec]  # anchor_vec: anchor box sizes mapped to the feature map
            iou, a = torch.stack(iou, 0).max(0)  # best iou and anchor

            # reject below threshold ious (OPTIONAL, increases P, lowers R)
            reject = True
            if reject:
                j = iou > iou_thres
                t, a, gwh = targets[j], a[j], gwh[j]

        # Indices
        b, c = t[:, :2].long().t()  # target image, class
        gxy = t[:, 2:4] * layer.ng  # grid x, y
        gi, gj = gxy.long().t()  # grid x, y indices: cell containing the ground-truth centre in the feature map
        indices.append((b, a, gj, gi)) 

        gxy_original = t[:, 2:4]
        gwh_original = t[:, 4:6]
        anchor_vec.append(layer.anchor_vec)
        grid_layers.append(layer.ng)
        original_xy.append(gxy_original)
        original_wh.append(gwh_original)

        # XY coordinates
        # txy.append(gxy - gxy.floor()) # In the feature map, the offset of ground truth respect to the grid it belongs.

        # Width and height
        # twh.append(torch.log(gwh / layer.anchor_vec[a]))  # wh yolo method
        # twh.append((gwh / layer.anchor_vec[a]) ** (1 / 3) / 2)  

        # Class
        tcls.append(c)
        if c.shape[0]:
            assert c.max() <= layer.nc, 'Target classes exceed model classes'

    return tcls, indices, original_xy, original_wh, grid_layers, anchor_vec

@lyx190
Author

lyx190 commented May 29, 2019

@glenn-jocher Now I am using your code to train on my own dataset, but I have a question about the calculation of GIoU:

yolov3/utils/utils.py

Lines 297 to 298 in 243344a

pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box
giou = bbox_iou(pbox.t(), tbox[i], GIoU=True)

In these two lines, the xy coordinates of the ground truth are already resized to the size of the feature map, so is it unnecessary to map torch.sigmoid(pi[..., :2]) to the size of the feature map?
If so, I think we should also modify the xy coordinates during inference (without adding self.grid_xy), right?

@glenn-jocher
Member

@lyx190 inference works correctly, as you can see from the mAP comparisons to published results in the README.

The predicted and target boxes always share a common origin, so adding the grid location is redundant: it cancels out in the IoU calculation. In other words, there's no point in adding the same constant to both xy values.

@lyx190
Author

lyx190 commented May 30, 2019

@glenn-jocher Yes, you are right. But I am confused about why the result got worse when I added grid_xy to the predicted xy coordinates.

Besides, I have trained on my own dataset with the new giou branch. Below is a comparison of the results with and without giou:
[Image: results]

Both runs were trained on my own dataset. The blue line uses the default hyp[] settings. The orange line uses the hyp[] settings from the 'giou' branch.

@glenn-jocher
Member

glenn-jocher commented May 30, 2019

@lyx190 grid_xy is a constant offset. You can add it to both boxes you pass to giou, or add it to none, but you can not add it to only one. Your results look similar to mine above, the giou results are worse unfortunately. In my example the difference is minimal though, only a few percent lower.
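
The offset rule can be demonstrated with a short pure-Python sketch using plain IoU for brevity (the same reasoning applies to GIoU). The helper names, the sample boxes, and the cell indices are assumptions made for the example.

```python
def iou_xywh(a, b):
    """Plain IoU of two centre-format (x, y, w, h) boxes in the same units."""
    ax1, ay1, ax2, ay2 = a[0] - a[2] / 2, a[1] - a[3] / 2, a[0] + a[2] / 2, a[1] + a[3] / 2
    bx1, by1, bx2, by2 = b[0] - b[2] / 2, b[1] - b[3] / 2, b[0] + b[2] / 2, b[1] + b[3] / 2
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    return inter / (a[2] * a[3] + b[2] * b[3] - inter)


def shift(box, dx, dy):
    """Translate the box centre by (dx, dy), leaving width and height alone."""
    x, y, w, h = box
    return (x + dx, y + dy, w, h)


pred = (0.5, 0.5, 2.0, 2.0)      # cell-relative centre, size in grid units
target = (0.6, 0.4, 2.0, 2.0)
gi, gj = 7, 4                    # grid cell shared by both boxes

neither = iou_xywh(pred, target)
both = iou_xywh(shift(pred, gi, gj), shift(target, gi, gj))
only_pred = iou_xywh(shift(pred, gi, gj), target)
```

Adding the offset to neither box or to both gives the same IoU (about 0.82 here), while adding it to the prediction alone moves the boxes apart until they no longer overlap and the IoU drops to 0.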

@glenn-jocher
Member

Your second run looks a little more promising though (the area between 220 to 300 I assume is a new run).

@lyx190
Author

lyx190 commented May 30, 2019

@glenn-jocher I don't know what happened between epochs 220 and 300; everything dropped on its own from epoch 220 onward, so I stopped the run at epoch 300.
Do you think adjusting the weight of giou_loss is the way to improve the result?

@glenn-jocher
Member

glenn-jocher commented May 30, 2019

I don't know. It's a very nice idea, it's too bad we can't reproduce the paper results yet. I suppose the conclusions one might draw from these results are that:

  • giou is not as effective as normal YOLO loss (paper claims otherwise)
  • there is a mistake in this giou implementation (always possible)
  • the hyperparameters need tuning after the introduction of giou (I've tried this already)
  • the loss function in general needs more work (it is not precisely aligned with darknet)
  • giou may help or hurt certain datasets depending on the specifics of those datasets

@lyx190
Author

lyx190 commented May 30, 2019

@glenn-jocher thank you for your ideas, I will keep working on it. If I get a better result I will come back here again.

@glenn-jocher
Member

@lyx190 it's a very appealing idea to wrap all of the regression losses into one. Hopefully we can get it to work better. I'll leave the branch open, and if you make any discoveries let me know!

@lyx190
Author

lyx190 commented May 30, 2019

@glenn-jocher ok!

@lyx190
Author

lyx190 commented Jun 8, 2019

@glenn-jocher Hello, I am still struggling with the implementation. I have found two things:

  1. In the code in branch giou, when calculating the value of giou, I think the keyword x1y1x2y2 of function bbox_iou should be set to False.
  2. I found that the predicted width and height values sometimes become very large or very small, like 2e+34 or -2e+34, which causes NaN. I don't know why, do you have any idea?

@glenn-jocher
Member

glenn-jocher commented Jun 9, 2019

> @glenn-jocher Hello, I am still struggling with the implementation. I have found two things:
>
>   1. In the code in branch giou, when calculating the value of giou, I think the keyword x1y1x2y2 of function bbox_iou should be set to False.

Can you provide an exact line number of which True should be set to False?

>   2. I found that the predicted width and height values sometimes become very large or very small, like 2e+34 or -2e+34, which causes NaN. I don't know why, do you have any idea?

I haven't observed this when training COCO or any of the smaller COCO datasets like coco_64img.data. What dataset are you using? Can you provide code for a minimum reproducible example?

@lyx190
Author

lyx190 commented Jun 9, 2019

@glenn-jocher

  1. giou = bbox_iou(pbox.t(), tbox[i], GIoU=True)
    This is the line where I think "x1y1x2y2=False" needs to be added.
  2. I was using my own dataset and corresponding anchor sizes, which performed very well with the default code. When the predicted w and h became very large, I also checked the values of the corresponding targets, and they looked normal and reasonable.

@glenn-jocher
Member

glenn-jocher commented Jun 10, 2019

@lyx190 hey I think you are right!! That line should read like this instead, because both the targets and prediction boxes are in xywh format, not xyxy format. This is a huge correction, it should hopefully improve the results dramatically!! Have you tried it out with the correction?

giou = bbox_iou(pbox.t(), tbox[i], GIoU=True, x1y1x2y2=False)
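
To see why the flag matters, here is a simplified pure-Python stand-in for a bbox_iou-style helper. It is not the repository's function; only the x1y1x2y2 keyword name is borrowed from the thread, and the sample boxes are made up. The same xywh boxes are scored once interpreted correctly and once misread as corners.

```python
def iou(box1, box2, x1y1x2y2=True):
    """Simplified stand-in for a bbox_iou-style helper (not the repo's function)."""
    if not x1y1x2y2:
        # Inputs are centre-format (x, y, w, h): convert to corners first
        box1 = (box1[0] - box1[2] / 2, box1[1] - box1[3] / 2,
                box1[0] + box1[2] / 2, box1[1] + box1[3] / 2)
        box2 = (box2[0] - box2[2] / 2, box2[1] - box2[3] / 2,
                box2[0] + box2[2] / 2, box2[1] + box2[3] / 2)
    iw = max(0.0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    ih = max(0.0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    inter = iw * ih
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    return inter / (area1 + area2 - inter)


pred_xywh = (0.5, 0.5, 4.0, 2.0)
target_xywh = (0.5, 0.5, 2.0, 2.0)

correct = iou(pred_xywh, target_xywh, x1y1x2y2=False)   # boxes read as xywh
wrong = iou(pred_xywh, target_xywh)                     # same numbers misread as xyxy
```

With the correct interpretation the IoU is 0.5, while the misread version returns about 0.43 without raising any error, so the training signal is quietly wrong rather than obviously broken.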

Your #2 is a pretty well known issue with the YOLO layer, which is that wh loss can diverge early on during training, becoming infinity or nan (search the issues for wh divergence). This is the reason for the burnin period in train.py when using the SGD optimizer. I had hoped that GIOU loss would fix this issue though. Are you saying that your wh predictions are still diverging to nan even with GIOU loss? Perhaps the error above is the cause. Can you test again with the correction? I'll submit a commit right now, then you just need to git pull.

UPDATE1: I re-ran coco_64img.data to compare to my previous results. The updated results incorporating the x1y1x2y2=False fix are below as results_giou_v2. There is only a slight improvement unfortunately :(
[Image: results]

@lyx190
Author

lyx190 commented Jun 10, 2019

@glenn-jocher I have tried the new code on my own dataset. Unfortunately, it also showed only a small improvement. I am now trying the loss in the form lgiou = -log((1 + giou) / 2). If I get a better result, I will let you know.

About my second problem: the very high predicted wh value occurred in only one bounding box in the prediction array. The other predicted boxes looked normal, around 1.000, and only one was very large. Because we apply mean() to the giou loss, the whole loss turned into nan.
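
Both points can be sketched in pure Python: the log form of the loss, and how a single non-finite per-box value propagates through a mean. The function names and the sample GIoU values are illustrative assumptions, not values from either training run.

```python
import math


def log_giou_loss(giou):
    """-log((1 + giou) / 2): zero at giou = 1, unbounded as giou approaches -1."""
    return -math.log((1.0 + giou) / 2.0)


def mean(values):
    return sum(values) / len(values)


# Three well-behaved boxes (made-up GIoU values)
per_box = [log_giou_loss(g) for g in (0.9, 0.8, 0.85)]
batch_loss = mean(per_box)

# One box whose loss evaluated to nan is enough to poison the batch mean
poisoned = mean(per_box + [float("nan")])
```

The first mean is a small finite number, while the second is nan even though three of the four boxes are fine, which matches the behaviour described above for a single diverged wh prediction.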

@glenn-jocher
Member

@lyx190 we trained yolov3-spp fully with giou for 68 epochs and ended up at 0.464 mAP, about the same as regular xy and wh implementation. Any luck on your side?

@lyx190
Author

lyx190 commented Jun 23, 2019

@glenn-jocher hello, I have trained on my own dataset and got the same result as regular yolov3. Which dataset did you train on with giou? I am going to test it on the same dataset as yours.

@glenn-jocher
Member

@lyx190 I trained coco for one epoch. Also got same results as default. I've integrated giou into the main branch now, you can train with it using --giou like this:

python3 train.py --data data/coco.data --img-size 320 --epochs 1 --giou

@lyx190
Author

lyx190 commented Jun 24, 2019

@glenn-jocher Okay, thank you. I will train it on the coco 64 image dataset and then post the result here.

@glenn-jocher
Member

@lyx190 GIoU is now integrated as the default regression loss for this repository, so simply running the default training command will use GIoU. Closing this issue as resolved.

@developer0hye
Contributor

@glenn-jocher
So, your conclusion is that GIoU helps the model train well?

@glenn-jocher
Member

@developer0hye yes GIoU is a definite improvement over the individual xywh MSE losses. It stabilizes the wh loss, and since it merges 4 losses into 1 it is much easier to work with and tune hyperparameters with.

@developer0hye
Contributor

developer0hye commented Feb 21, 2020

@glenn-jocher
Thank you for the answer. When I applied it to my custom dataset and yolo implementation, its performance was worse than with the xywh loss. Okay, I will retry!

Ah, one more thing! Have you tried optimizing the model with DIoU or CIoU?

@glenn-jocher
Member

@developer0hye hmm, interesting. I suppose results will vary across datasets. The ultralytics results in https://github.com/ultralytics/yolov3#map were all trained from scratch using this repo with GIoU.
