diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 696673fc65d93..e758f5588b1d0 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -943,45 +943,13 @@ def box_coder(prior_box,
                                               box_normalized=False,
                                               axis=1)
     """
-    check_variable_and_dtype(prior_box, 'prior_box', ['float32', 'float64'],
-                             'box_coder')
-    check_variable_and_dtype(target_box, 'target_box', ['float32', 'float64'],
-                             'box_coder')
-    if in_dygraph_mode():
-        if isinstance(prior_box_var, Variable):
-            box_coder_op = _C_ops.box_coder(prior_box, prior_box_var,
-                                            target_box, code_type,
-                                            box_normalized, axis, [])
-        elif isinstance(prior_box_var, list):
-            box_coder_op = _C_ops.box_coder(prior_box, None, target_box,
-                                            code_type, box_normalized, axis,
-                                            prior_box_var)
-        else:
-            raise TypeError(
-                "Input variance of box_coder must be Variable or lisz")
-        return box_coder_op
-    helper = LayerHelper("box_coder", **locals())
-
-    output_box = helper.create_variable_for_type_inference(
-        dtype=prior_box.dtype)
-
-    inputs = {"PriorBox": prior_box, "TargetBox": target_box}
-    attrs = {
-        "code_type": code_type,
-        "box_normalized": box_normalized,
-        "axis": axis
-    }
-    if isinstance(prior_box_var, Variable):
-        inputs['PriorBoxVar'] = prior_box_var
-    elif isinstance(prior_box_var, list):
-        attrs['variance'] = prior_box_var
-    else:
-        raise TypeError("Input variance of box_coder must be Variable or lisz")
-    helper.append_op(type="box_coder",
-                     inputs=inputs,
-                     attrs=attrs,
-                     outputs={"OutputBox": output_box})
-    return output_box
+    return paddle.vision.ops.box_coder(prior_box=prior_box,
+                                       prior_box_var=prior_box_var,
+                                       target_box=target_box,
+                                       code_type=code_type,
+                                       box_normalized=box_normalized,
+                                       axis=axis,
+                                       name=name)
 
 
 @templatedoc()
@@ -1903,68 +1871,19 @@ def prior_box(
             # [6L, 9L, 1L, 4L]
     """
-
-    if in_dygraph_mode():
-        step_w, step_h = steps
-        if max_sizes == None:
-            max_sizes = []
-        return _C_ops.prior_box(input, image, min_sizes, aspect_ratios,
-                                variance, max_sizes, flip, clip, step_w, step_h,
-                                offset, min_max_aspect_ratios_order)
-    helper = LayerHelper("prior_box", **locals())
-    dtype = helper.input_dtype()
-    check_variable_and_dtype(input, 'input',
-                             ['uint8', 'int8', 'float32', 'float64'],
-                             'prior_box')
-
-    def _is_list_or_tuple_(data):
-        return (isinstance(data, list) or isinstance(data, tuple))
-
-    if not _is_list_or_tuple_(min_sizes):
-        min_sizes = [min_sizes]
-    if not _is_list_or_tuple_(aspect_ratios):
-        aspect_ratios = [aspect_ratios]
-    if not (_is_list_or_tuple_(steps) and len(steps) == 2):
-        raise ValueError('steps should be a list or tuple ',
-                         'with length 2, (step_width, step_height).')
-
-    min_sizes = list(map(float, min_sizes))
-    aspect_ratios = list(map(float, aspect_ratios))
-    steps = list(map(float, steps))
-
-    attrs = {
-        'min_sizes': min_sizes,
-        'aspect_ratios': aspect_ratios,
-        'variances': variance,
-        'flip': flip,
-        'clip': clip,
-        'step_w': steps[0],
-        'step_h': steps[1],
-        'offset': offset,
-        'min_max_aspect_ratios_order': min_max_aspect_ratios_order
-    }
-    if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0:
-        if not _is_list_or_tuple_(max_sizes):
-            max_sizes = [max_sizes]
-        attrs['max_sizes'] = max_sizes
-
-    box = helper.create_variable_for_type_inference(dtype)
-    var = helper.create_variable_for_type_inference(dtype)
-    helper.append_op(
-        type="prior_box",
-        inputs={
-            "Input": input,
-            "Image": image
-        },
-        outputs={
-            "Boxes": box,
-            "Variances": var
-        },
-        attrs=attrs,
-    )
-    box.stop_gradient = True
-    var.stop_gradient = True
-    return box, var
+    return paddle.vision.ops.prior_box(
+        input=input,
+        image=image,
+        min_sizes=min_sizes,
+        max_sizes=max_sizes,
+        aspect_ratios=aspect_ratios,
+        variance=variance,
+        flip=flip,
+        clip=clip,
+        steps=steps,
+        offset=offset,
+        min_max_aspect_ratios_order=min_max_aspect_ratios_order,
+        name=name)
 
 
 def density_prior_box(input,
diff --git a/python/paddle/fluid/tests/unittests/test_box_coder_op.py b/python/paddle/fluid/tests/unittests/test_box_coder_op.py
index 0e70f9904b50a..824d8155b7f28 100644
--- a/python/paddle/fluid/tests/unittests/test_box_coder_op.py
+++ b/python/paddle/fluid/tests/unittests/test_box_coder_op.py
@@ -282,5 +282,57 @@ def run(place):
         run(place)
 
 
+class TestBoxCoderAPI(unittest.TestCase):
+
+    def setUp(self):
+        np.random.seed(678)
+        self.prior_box_np = np.random.random((80, 4)).astype('float32')
+        self.prior_box_var_np = np.random.random((80, 4)).astype('float32')
+        self.target_box_np = np.random.random((20, 80, 4)).astype('float32')
+
+    def test_dygraph_with_static(self):
+        paddle.enable_static()
+        prior_box = paddle.static.data(name='prior_box',
+                                       shape=[80, 4],
+                                       dtype='float32')
+        prior_box_var = paddle.static.data(name='prior_box_var',
+                                           shape=[80, 4],
+                                           dtype='float32')
+        target_box = paddle.static.data(name='target_box',
+                                        shape=[20, 80, 4],
+                                        dtype='float32')
+
+        boxes = paddle.vision.ops.box_coder(prior_box=prior_box,
+                                            prior_box_var=prior_box_var,
+                                            target_box=target_box,
+                                            code_type="decode_center_size",
+                                            box_normalized=False)
+
+        exe = paddle.static.Executor()
+        boxes_np = exe.run(paddle.static.default_main_program(),
+                           feed={
+                               'prior_box': self.prior_box_np,
+                               'prior_box_var': self.prior_box_var_np,
+                               'target_box': self.target_box_np,
+                           },
+                           fetch_list=[boxes])
+
+        paddle.disable_static()
+        prior_box_dy = paddle.to_tensor(self.prior_box_np)
+        prior_box_var_dy = paddle.to_tensor(self.prior_box_var_np)
+        target_box_dy = paddle.to_tensor(self.target_box_np)
+
+        boxes_dy = paddle.vision.ops.box_coder(prior_box=prior_box_dy,
+                                               prior_box_var=prior_box_var_dy,
+                                               target_box=target_box_dy,
+                                               code_type="decode_center_size",
+                                               box_normalized=False)
+        boxes_dy_np = boxes_dy.numpy()
+
+        np.testing.assert_allclose(boxes_np[0], boxes_dy_np)
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_prior_box_op.py b/python/paddle/fluid/tests/unittests/test_prior_box_op.py
index 67d468ac626e6..0f04f6c5f511e 100644
--- a/python/paddle/fluid/tests/unittests/test_prior_box_op.py
+++ b/python/paddle/fluid/tests/unittests/test_prior_box_op.py
@@ -107,8 +107,6 @@ def init_test_params(self):
         self.flip = True
         self.set_min_max_aspect_ratios_order()
         self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
-        self.aspect_ratios = np.array(self.aspect_ratios,
-                                      dtype=np.float64).flatten()
         self.variances = [0.1, 0.1, 0.2, 0.2]
         self.variances = np.array(self.variances, dtype=np.float64).flatten()
 
@@ -217,6 +215,54 @@ def set_min_max_aspect_ratios_order(self):
         self.min_max_aspect_ratios_order = True
 
 
+class TestPriorBoxAPI(unittest.TestCase):
+
+    def setUp(self):
+        np.random.seed(678)
+        self.input_np = np.random.rand(2, 10, 32, 32).astype('float32')
+        self.image_np = np.random.rand(2, 10, 40, 40).astype('float32')
+        self.min_sizes = [2.0, 4.0]
+
+    def test_dygraph_with_static(self):
+        paddle.enable_static()
+        input = paddle.static.data(name='input',
+                                   shape=[2, 10, 32, 32],
+                                   dtype='float32')
+        image = paddle.static.data(name='image',
+                                   shape=[2, 10, 40, 40],
+                                   dtype='float32')
+
+        box, var = paddle.vision.ops.prior_box(input=input,
+                                               image=image,
+                                               min_sizes=self.min_sizes,
+                                               clip=True,
+                                               flip=True)
+
+        exe = paddle.static.Executor()
+        box_np, var_np = exe.run(paddle.static.default_main_program(),
+                                 feed={
+                                     'input': self.input_np,
+                                     'image': self.image_np,
+                                 },
+                                 fetch_list=[box, var])
+
+        paddle.disable_static()
+        inputs_dy = paddle.to_tensor(self.input_np)
+        image_dy = paddle.to_tensor(self.image_np)
+
+        box_dy, var_dy = paddle.vision.ops.prior_box(input=inputs_dy,
+                                                     image=image_dy,
+                                                     min_sizes=self.min_sizes,
+                                                     clip=True,
+                                                     flip=True)
+        box_dy_np = box_dy.numpy()
+        var_dy_np = var_dy.numpy()
+
+        np.testing.assert_allclose(box_np, box_dy_np)
+        np.testing.assert_allclose(var_np, var_dy_np)
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py
index 112d233cc494d..01387823c17f5 100755
--- a/python/paddle/vision/ops.py
+++ b/python/paddle/vision/ops.py
@@ -18,15 +18,29 @@
 from ..fluid.layers import nn, utils
 from ..nn import Layer, Conv2D, Sequential, ReLU, BatchNorm2D
 from ..fluid.initializer import Normal
-from ..fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
+from ..fluid.framework import Variable, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
 from paddle import _C_ops, _legacy_C_ops
 from ..framework import _current_expected_place
 
 __all__ = [ #noqa
-    'yolo_loss', 'yolo_box', 'deform_conv2d', 'DeformConv2D',
-    'distribute_fpn_proposals', 'generate_proposals', 'read_file',
-    'decode_jpeg', 'roi_pool', 'RoIPool', 'psroi_pool', 'PSRoIPool',
-    'roi_align', 'RoIAlign', 'nms', 'matrix_nms'
+    'yolo_loss',
+    'yolo_box',
+    'prior_box',
+    'box_coder',
+    'deform_conv2d',
+    'DeformConv2D',
+    'distribute_fpn_proposals',
+    'generate_proposals',
+    'read_file',
+    'decode_jpeg',
+    'roi_pool',
+    'RoIPool',
+    'psroi_pool',
+    'PSRoIPool',
+    'roi_align',
+    'RoIAlign',
+    'nms',
+    'matrix_nms',
 ]
 
 
@@ -407,6 +421,313 @@ def yolo_box(x,
     return boxes, scores
 
 
+def prior_box(input,
+              image,
+              min_sizes,
+              max_sizes=None,
+              aspect_ratios=[1.],
+              variance=[0.1, 0.1, 0.2, 0.2],
+              flip=False,
+              clip=False,
+              steps=[0.0, 0.0],
+              offset=0.5,
+              min_max_aspect_ratios_order=False,
+              name=None):
+    r"""
+
+    This op generates prior boxes for the SSD (Single Shot MultiBox Detector) algorithm.
+
+    Each position of the input produces N prior boxes, where N is determined by
+    the count of min_sizes, max_sizes and aspect_ratios. The size of each box
+    lies in the (min_size, max_size) interval, and the boxes are generated in
+    sequence according to the aspect_ratios.
+
+    Args:
+        input (Tensor): 4-D tensor (NCHW), the data type should be float32 or float64.
+        image (Tensor): 4-D tensor (NCHW), the input image data of PriorBoxOp,
+            the data type should be float32 or float64.
+        min_sizes (list|tuple|float): the min sizes of generated prior boxes.
+        max_sizes (list|tuple|None, optional): the max sizes of generated prior boxes.
+            Default: None.
+        aspect_ratios (list|tuple|float, optional): the aspect ratios of generated
+            prior boxes. Default: [1.].
+        variance (list|tuple, optional): the variances to be encoded in prior boxes.
+            Default: [0.1, 0.1, 0.2, 0.2].
+        flip (bool, optional): Whether to flip aspect ratios. Default: False.
+        clip (bool, optional): Whether to clip out-of-boundary boxes. Default: False.
+        steps (list|tuple, optional): Prior box steps across width and height. If
+            steps[0] or steps[1] equals 0.0, the step across the width or height
+            of the input will be calculated automatically. Default: [0., 0.].
+        offset (float, optional): Prior box center offset. Default: 0.5.
+        min_max_aspect_ratios_order (bool, optional): If set to True, the output
+            prior boxes are in the order of [min, max, aspect_ratios], which is
+            consistent with Caffe. Note that this order affects the weight order
+            of the convolution layer that follows, but does not affect the final
+            detection results. Default: False.
+        name (str, optional): The default value is None. Normally there is no need
+            for the user to set this property. For more information, please refer
+            to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: the output prior boxes and the expanded variances of PriorBox.
+            The prior boxes are a 4-D tensor with layout [H, W, num_priors, 4],
+            where num_priors is the total number of boxes generated at each
+            position of the input. The expanded variances are a 4-D tensor with
+            the same shape as the prior boxes.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            input = paddle.rand((1, 3, 6, 9), dtype=paddle.float32)
+            image = paddle.rand((1, 3, 9, 12), dtype=paddle.float32)
+
+            box, var = paddle.vision.ops.prior_box(
+                input=input,
+                image=image,
+                min_sizes=[2.0, 4.0],
+                clip=True,
+                flip=True)
+
+    """
+    helper = LayerHelper("prior_box", **locals())
+    dtype = helper.input_dtype()
+    check_variable_and_dtype(input, 'input',
+                             ['uint8', 'int8', 'float32', 'float64'],
+                             'prior_box')
+
+    def _is_list_or_tuple_(data):
+        return (isinstance(data, list) or isinstance(data, tuple))
+
+    if not _is_list_or_tuple_(min_sizes):
+        min_sizes = [min_sizes]
+    if not _is_list_or_tuple_(aspect_ratios):
+        aspect_ratios = [aspect_ratios]
+    if not _is_list_or_tuple_(steps):
+        steps = [steps]
+    if len(steps) != 2:
+        raise ValueError('steps should be (step_w, step_h)')
+
+    min_sizes = list(map(float, min_sizes))
+    aspect_ratios = list(map(float, aspect_ratios))
+    steps = list(map(float, steps))
+
+    cur_max_sizes = None
+    if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0:
+        if not _is_list_or_tuple_(max_sizes):
+            max_sizes = [max_sizes]
+        cur_max_sizes = max_sizes
+
+    if in_dygraph_mode():
+        step_w, step_h = steps
+        if max_sizes is None:
+            max_sizes = []
+        box, var = _C_ops.prior_box(input, image, min_sizes, aspect_ratios,
+                                    variance, max_sizes, flip, clip, step_w,
+                                    step_h, offset, min_max_aspect_ratios_order)
+        return box, var
+
+    if _in_legacy_dygraph():
+        attrs = ('min_sizes', min_sizes, 'aspect_ratios', aspect_ratios,
+                 'variances', variance, 'flip', flip, 'clip', clip, 'step_w',
+                 steps[0], 'step_h', steps[1], 'offset', offset,
+                 'min_max_aspect_ratios_order', min_max_aspect_ratios_order)
+        if cur_max_sizes is not None:
+            attrs += ('max_sizes', cur_max_sizes)
+        box, var = _legacy_C_ops.prior_box(input, image, *attrs)
+        return box, var
+    else:
+        attrs = {
+            'min_sizes': min_sizes,
+            'aspect_ratios': aspect_ratios,
+            'variances': variance,
+            'flip': flip,
+            'clip': clip,
+            'step_w': steps[0],
+            'step_h': steps[1],
+            'offset': offset,
+            'min_max_aspect_ratios_order': min_max_aspect_ratios_order
+        }
+        if cur_max_sizes is not None:
+            attrs['max_sizes'] = cur_max_sizes
+
+        box = helper.create_variable_for_type_inference(dtype)
+        var = helper.create_variable_for_type_inference(dtype)
+        helper.append_op(
+            type="prior_box",
+            inputs={
+                "Input": input,
+                "Image": image
+            },
+            outputs={
+                "Boxes": box,
+                "Variances": var
+            },
+            attrs=attrs,
+        )
+        box.stop_gradient = True
+        var.stop_gradient = True
+        return box, var
+
+
+def box_coder(prior_box,
+              prior_box_var,
+              target_box,
+              code_type="encode_center_size",
+              box_normalized=True,
+              axis=0,
+              name=None):
+    r"""
+    Encode/Decode the target bounding box with the prior box information.
+
+    The encoding schema is described below:
+
+    .. math::
+
+        ox &= (tx - px) / pw / pxv
+
+        oy &= (ty - py) / ph / pyv
+
+        ow &= log(abs(tw / pw)) / pwv
+
+        oh &= log(abs(th / ph)) / phv
+
+    The decoding schema is described below:
+
+    .. math::
+
+        ox &= (pw * pxv * tx + px) - tw / 2
+
+        oy &= (ph * pyv * ty + py) - th / 2
+
+        ow &= exp(pwv * tw) * pw + tw / 2
+
+        oh &= exp(phv * th) * ph + th / 2
+
+    where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
+    width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
+    the prior box's (anchor) center coordinates, width and height. `pxv`,
+    `pyv`, `pwv`, `phv` denote the variance of the prior box and `ox`, `oy`,
+    `ow`, `oh` denote the encoded/decoded coordinates, width and height.
+    During box decoding, two broadcast modes are supported. Say the target
+    box has shape [N, M, 4], and the shape of the prior box can be [N, 4] or
+    [M, 4]. Then the prior box will broadcast to the target box along the
+    assigned axis.
+
+    Args:
+        prior_box (Tensor): Box list prior_box is a 2-D Tensor with shape
+            [M, 4] which holds M boxes and whose data type is float32 or
+            float64. Each box is represented as [xmin, ymin, xmax, ymax]:
+            [xmin, ymin] is the left top coordinate of the anchor box (if the
+            input is an image feature map, it is close to the origin of the
+            coordinate system) and [xmax, ymax] is the right bottom coordinate
+            of the anchor box.
+        prior_box_var (List|Tensor|None): prior_box_var supports three types
+            of input. One is a Tensor with shape [M, 4] which holds M groups of
+            variances, with data type float32 or float64. The second is a list
+            of 4 elements shared by all boxes, with data type float32 or
+            float64. The third is None, in which case the variance is not
+            involved in the calculation.
+        target_box (Tensor): This input can be a 2-D LoDTensor with shape
+            [N, 4] when code_type is 'encode_center_size'. This input also can
+            be a 3-D Tensor with shape [N, M, 4] when code_type is
+            'decode_center_size'. Each box is represented as
+            [xmin, ymin, xmax, ymax]. The data type is float32 or float64.
+        code_type (str, optional): The code type used with the target box. It can be
+            `encode_center_size` or `decode_center_size`. `encode_center_size`
+            by default.
+        box_normalized (bool, optional): Whether to treat the prior box as a
+            normalized box. Set to True by default.
+        axis (int, optional): Which axis in PriorBox to broadcast for box decode,
+            for example, if axis is 0 and TargetBox has shape [N, M, 4] and
+            PriorBox has shape [M, 4], then PriorBox will broadcast to [N, M, 4]
+            for decoding. It is only valid when code type is
+            `decode_center_size`. Set to 0 by default.
+        name (str, optional): For detailed information, please refer
+            to :ref:`api_guide_Name`. Usually there is no need to set a name;
+            None by default.
+
+    Returns:
+        Tensor: output boxes. When code_type is 'encode_center_size', the
+            output tensor of box_coder_op has shape [N, M, 4], representing the
+            result of N target boxes encoded with M prior boxes and variances.
+            When code_type is 'decode_center_size', N represents the batch size
+            and M represents the number of decoded boxes.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            # For encode
+            prior_box_encode = paddle.rand((80, 4), dtype=paddle.float32)
+            prior_box_var_encode = paddle.rand((80, 4), dtype=paddle.float32)
+            target_box_encode = paddle.rand((20, 4), dtype=paddle.float32)
+            output_encode = paddle.vision.ops.box_coder(
+                prior_box=prior_box_encode,
+                prior_box_var=prior_box_var_encode,
+                target_box=target_box_encode,
+                code_type="encode_center_size")
+
+            # For decode
+            prior_box_decode = paddle.rand((80, 4), dtype=paddle.float32)
+            prior_box_var_decode = paddle.rand((80, 4), dtype=paddle.float32)
+            target_box_decode = paddle.rand((20, 80, 4), dtype=paddle.float32)
+            output_decode = paddle.vision.ops.box_coder(
+                prior_box=prior_box_decode,
+                prior_box_var=prior_box_var_decode,
+                target_box=target_box_decode,
+                code_type="decode_center_size",
+                box_normalized=False)
+
+    """
+    check_variable_and_dtype(prior_box, 'prior_box', ['float32', 'float64'],
+                             'box_coder')
+    check_variable_and_dtype(target_box, 'target_box', ['float32', 'float64'],
+                             'box_coder')
+
+    if in_dygraph_mode():
+        if isinstance(prior_box_var, Variable):
+            output_box = _C_ops.box_coder(prior_box, prior_box_var, target_box,
+                                          code_type, box_normalized, axis, [])
+        elif isinstance(prior_box_var, list):
+            output_box = _C_ops.box_coder(prior_box, None, target_box,
+                                          code_type, box_normalized, axis,
+                                          prior_box_var)
+        else:
+            raise TypeError("Input prior_box_var must be Variable or list")
+        return output_box
+
+    if _in_legacy_dygraph():
+        if isinstance(prior_box_var, Variable):
+            output_box = _legacy_C_ops.box_coder(prior_box, prior_box_var,
+                                                 target_box, "code_type",
+                                                 code_type, "box_normalized",
+                                                 box_normalized, "axis", axis)
+        elif isinstance(prior_box_var, list):
+            output_box = _legacy_C_ops.box_coder(prior_box, None, target_box,
+                                                 "code_type", code_type,
+                                                 "box_normalized",
+                                                 box_normalized, "axis", axis,
+                                                 "variance", prior_box_var)
+        else:
+            raise TypeError("Input prior_box_var must be Variable or list")
+        return output_box
+    else:
+        helper = LayerHelper("box_coder", **locals())
+
+        output_box = helper.create_variable_for_type_inference(
+            dtype=prior_box.dtype)
+
+        inputs = {"PriorBox": prior_box, "TargetBox": target_box}
+        attrs = {
+            "code_type": code_type,
+            "box_normalized": box_normalized,
+            "axis": axis
+        }
+        if isinstance(prior_box_var, Variable):
+            inputs['PriorBoxVar'] = prior_box_var
+        elif isinstance(prior_box_var, list):
+            attrs['variance'] = prior_box_var
+        else:
+            raise TypeError("Input prior_box_var must be Variable or list")
+        helper.append_op(type="box_coder",
+                         inputs=inputs,
+                         attrs=attrs,
+                         outputs={"OutputBox": output_box})
+        return output_box
+
+
 def deform_conv2d(x,
                   offset,
                   weight,
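Note: for reviewers, a minimal smoke test of the two relocated public APIs, mirroring what the new unit tests above exercise. This is a sketch assuming the patch is applied and eager (dygraph) mode is active (the default in paddle 2.x); the shapes follow the docstrings:

    import paddle

    # prior_box consumes a feature map and the image it was computed from (both NCHW).
    feat = paddle.rand([1, 3, 6, 9], dtype='float32')
    img = paddle.rand([1, 3, 9, 12], dtype='float32')
    boxes, variances = paddle.vision.ops.prior_box(input=feat,
                                                   image=img,
                                                   min_sizes=[2.0, 4.0],
                                                   clip=True,
                                                   flip=True)
    # Both outputs have layout [H, W, num_priors, 4], i.e. [6, 9, num_priors, 4] here.
    print(boxes.shape, variances.shape)

    # box_coder decodes 20 sets of offsets against 80 anchors ([N, M, 4] broadcast).
    anchors = paddle.rand([80, 4], dtype='float32')
    anchor_var = paddle.rand([80, 4], dtype='float32')
    offsets = paddle.rand([20, 80, 4], dtype='float32')
    decoded = paddle.vision.ops.box_coder(prior_box=anchors,
                                          prior_box_var=anchor_var,
                                          target_box=offsets,
                                          code_type="decode_center_size",
                                          box_normalized=False)
    print(decoded.shape)  # [20, 80, 4]

The fluid wrappers in detection.py now delegate to these implementations, so both entry points run the same kernels.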