Skip to content

Commit

Permalink
Enable to use mmcv.jit in parrots environment (open-mmlab#4192)
Browse files Browse the repository at this point in the history
* add mmcv.jit

* only use derivate and coderize

* fix for isort

* small modify for decorator order
  • Loading branch information
lml131 authored Feb 21, 2021
1 parent 7b3fae6 commit e1599e7
Show file tree
Hide file tree
Showing 15 changed files with 43 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mmdet/core/bbox/coder/bucketing_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -90,6 +91,7 @@ def decode(self, bboxes, pred_bboxes, max_shape=None):
return decoded_bboxes


@mmcv.jit(coderize=True)
def generat_buckets(proposals, num_buckets, scale_factor=1.0):
"""Generate buckets w.r.t bucket number and scale factor of proposals.
Expand Down Expand Up @@ -138,6 +140,7 @@ def generat_buckets(proposals, num_buckets, scale_factor=1.0):
return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets


@mmcv.jit(coderize=True)
def bbox2bucket(proposals,
gt,
num_buckets,
Expand Down Expand Up @@ -261,6 +264,7 @@ def bbox2bucket(proposals,
return offsets, offsets_weights, bucket_labels, bucket_cls_weights


@mmcv.jit(coderize=True)
def bucket2bbox(proposals,
cls_preds,
offset_preds,
Expand Down
3 changes: 3 additions & 0 deletions mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mmcv
import numpy as np
import torch

Expand Down Expand Up @@ -75,6 +76,7 @@ def decode(self,
return decoded_bboxes


@mmcv.jit(coderize=True)
def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
"""Compute deltas of proposals w.r.t. gt.
Expand Down Expand Up @@ -120,6 +122,7 @@ def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
return deltas


@mmcv.jit(coderize=True)
def delta2bbox(rois,
deltas,
means=(0., 0., 0., 0.),
Expand Down
3 changes: 3 additions & 0 deletions mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mmcv
import numpy as np
import torch

Expand Down Expand Up @@ -79,6 +80,7 @@ def decode(self,
return decoded_bboxes


@mmcv.jit(coderize=True)
def legacy_bbox2delta(proposals,
gt,
means=(0., 0., 0., 0.),
Expand Down Expand Up @@ -127,6 +129,7 @@ def legacy_bbox2delta(proposals,
return deltas


@mmcv.jit(coderize=True)
def legacy_delta2bbox(rois,
deltas,
means=(0., 0., 0., 0.),
Expand Down
3 changes: 3 additions & 0 deletions mmdet/core/bbox/coder/tblr_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mmcv
import torch

from ..builder import BBOX_CODERS
Expand Down Expand Up @@ -68,6 +69,7 @@ def decode(self, bboxes, pred_bboxes, max_shape=None):
return decoded_bboxes


@mmcv.jit(coderize=True)
def bboxes2tblr(priors, gts, normalizer=4.0, normalize_by_wh=True):
"""Encode ground truth boxes to tblr coordinate.
Expand Down Expand Up @@ -114,6 +116,7 @@ def bboxes2tblr(priors, gts, normalizer=4.0, normalize_by_wh=True):
return loc / normalizer


@mmcv.jit(coderize=True)
def tblr2bboxes(priors,
tblr,
normalizer=4.0,
Expand Down
3 changes: 3 additions & 0 deletions mmdet/core/bbox/coder/yolo_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mmcv
import torch

from ..builder import BBOX_CODERS
Expand All @@ -21,6 +22,7 @@ def __init__(self, eps=1e-6):
super(BaseBBoxCoder, self).__init__()
self.eps = eps

@mmcv.jit(coderize=True)
def encode(self, bboxes, gt_bboxes, stride):
"""Get box regression transformation deltas that can be used to
transform the ``bboxes`` into the ``gt_bboxes``.
Expand Down Expand Up @@ -55,6 +57,7 @@ def encode(self, bboxes, gt_bboxes, stride):
[x_center_target, y_center_target, w_target, h_target], dim=-1)
return encoded_bboxes

@mmcv.jit(coderize=True)
def decode(self, bboxes, pred_bboxes, stride):
"""Apply transformation `pred_bboxes` to `boxes`.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/models/losses/accuracy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import mmcv
import torch.nn as nn


@mmcv.jit(coderize=True)
def accuracy(pred, target, topk=1, thresh=None):
"""Calculate accuracy according to the prediction and target.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/models/losses/ae_loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@mmcv.jit(derivate=True, coderize=True)
def ae_loss_per_image(tl_preds, br_preds, match):
"""Associative Embedding Loss in one image.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/models/losses/balanced_l1_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mmcv
import numpy as np
import torch
import torch.nn as nn
Expand All @@ -6,6 +7,7 @@
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def balanced_l1_loss(pred,
target,
Expand Down
2 changes: 2 additions & 0 deletions mmdet/models/losses/gaussian_focal_loss.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import mmcv
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
Expand Down
3 changes: 3 additions & 0 deletions mmdet/models/losses/gfocal_loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import mmcv
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
Expand Down Expand Up @@ -49,6 +51,7 @@ def quality_focal_loss(pred, target, beta=2.0):
return loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def distribution_focal_loss(pred, label):
r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
Expand Down
6 changes: 6 additions & 0 deletions mmdet/models/losses/iou_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math

import mmcv
import torch
import torch.nn as nn

Expand All @@ -8,6 +9,7 @@
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def iou_loss(pred, target, linear=False, eps=1e-6):
"""IoU loss.
Expand All @@ -34,6 +36,7 @@ def iou_loss(pred, target, linear=False, eps=1e-6):
return loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
"""BIoULoss.
Expand Down Expand Up @@ -79,6 +82,7 @@ def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
return loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def giou_loss(pred, target, eps=1e-7):
r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
Expand All @@ -98,6 +102,7 @@ def giou_loss(pred, target, eps=1e-7):
return loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def diou_loss(pred, target, eps=1e-7):
r"""`Implementation of Distance-IoU Loss: Faster and Better
Expand Down Expand Up @@ -152,6 +157,7 @@ def diou_loss(pred, target, eps=1e-7):
return loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def ciou_loss(pred, target, eps=1e-7):
r"""`Implementation of paper `Enhancing Geometric Factors into
Expand Down
3 changes: 3 additions & 0 deletions mmdet/models/losses/pisa_loss.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import mmcv
import torch

from mmdet.core import bbox_overlaps


@mmcv.jit(derivate=True, coderize=True)
def isr_p(cls_score,
bbox_pred,
bbox_targets,
Expand Down Expand Up @@ -116,6 +118,7 @@ def isr_p(cls_score,
return bbox_targets


@mmcv.jit(derivate=True, coderize=True)
def carl_loss(cls_score,
labels,
bbox_pred,
Expand Down
3 changes: 3 additions & 0 deletions mmdet/models/losses/smooth_l1_loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import mmcv
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
"""Smooth L1 loss.
Expand All @@ -26,6 +28,7 @@ def smooth_l1_loss(pred, target, beta=1.0):
return loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def l1_loss(pred, target):
"""L1 loss.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/models/losses/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools

import mmcv
import torch.nn.functional as F


Expand All @@ -23,6 +24,7 @@ def reduce_loss(loss, reduction):
return loss.sum()


@mmcv.jit(derivate=True, coderize=True)
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
"""Apply element-wise weight and reduce loss.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/models/losses/varifocal_loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import mmcv
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


@mmcv.jit(derivate=True, coderize=True)
def varifocal_loss(pred,
target,
weight=None,
Expand Down

0 comments on commit e1599e7

Please sign in to comment.