Skip to content

Commit

Permalink
Move new affine_grid api to functional
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
wanghaoshuang committed Aug 19, 2020
1 parent a949914 commit d91f592
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 7 deletions.
8 changes: 2 additions & 6 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9081,8 +9081,7 @@ def _attr_offsets_check(offset_val):
return out


def affine_grid(theta, out_shape, name=None, align_corners=True,
use_cudnn=True):
def affine_grid(theta, out_shape, name=None):
"""
It generates a grid of (x,y) coordinates using the parameters of
the affine transformation that correspond to a set of points where
Expand All @@ -9096,9 +9095,6 @@ def affine_grid(theta, out_shape, name=None, align_corners=True,
``out_shape`` can be a Tensor or a list or tuple. The data
type must be int32.
name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True.
use_cudnn(bool): It will ignore `align_corners` and compute in align corners mode when use_cudnn is true.
Default: True.

Returns:
Variable: A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
Expand Down Expand Up @@ -9140,7 +9136,7 @@ def affine_grid(theta, out_shape, name=None, align_corners=True,

out = helper.create_variable_for_type_inference(theta.dtype)
ipts = {'Theta': theta}
attrs = {"align_corners": align_corners, "use_cudnn": use_cudnn}
attrs = {}
if isinstance(out_shape, Variable):
ipts['OutputShape'] = out_shape
check_variable_and_dtype(out_shape, 'out_shape', ['int32'],
Expand Down
134 changes: 134 additions & 0 deletions python/paddle/fluid/tests/unittests/test_affine_grid_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from paddle import fluid, nn
import paddle.fluid.dygraph as dg
import paddle.nn.functional as F
import paddle.fluid.initializer as I
import unittest


class AffineGridTestCase(unittest.TestCase):
def __init__(self,
methodName='runTest',
theta_shape=(20, 2, 3),
output_shape=[20, 2, 5, 7],
align_corners=True,
dtype="float32"):
super(AffineGridTestCase, self).__init__(methodName)

self.theta_shape = theta_shape
self.output_shape = output_shape
self.align_corners = align_corners
self.dtype = dtype

def setUp(self):
self.theta = np.random.randn(*(self.theta_shape)).astype(self.dtype)

def fluid_layer(self, place):
# align_corners = True
main = fluid.Program()
start = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, start):
theta_var = fluid.data(
"input", self.theta_shape, dtype=self.dtype)
y_var = fluid.layers.affine_grid(theta_var, self.output_shape)
feed_dict = {"input": self.theta}
exe = fluid.Executor(place)
exe.run(start)
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
return y_np

def functional(self, place):
main = fluid.Program()
start = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, start):
theta_var = fluid.data(
"input", self.theta_shape, dtype=self.dtype)
y_var = F.affine_grid(
theta_var,
self.output_shape,
align_corners=self.align_corners)
feed_dict = {"input": self.theta}
exe = fluid.Executor(place)
exe.run(start)
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
return y_np

def paddle_dygraph_layer(self):
theta_var = dg.to_variable(self.theta)
y_var = F.affine_grid(
theta_var, self.output_shape, align_corners=self.align_corners)
y_np = y_var.numpy()
return y_np

def _test_equivalence(self, place):
place = fluid.CPUPlace()
result1 = self.fluid_layer(place)
result2 = self.functional(place)
with dg.guard(place):
result3 = self.paddle_dygraph_layer()
if self.align_corners:
np.testing.assert_array_almost_equal(result1, result2)
np.testing.assert_array_almost_equal(result2, result3)

def runTest(self):
place = fluid.CPUPlace()
self._test_equivalence(place)

if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
self._test_equivalence(place)


class AffineGridErrorTestCase(AffineGridTestCase):
def runTest(self):
place = fluid.CPUPlace()
with dg.guard(place):
with self.assertRaises(ValueError):
self.paddle_dygraph_layer()


def add_cases(suite):
suite.addTest(AffineGridTestCase(methodName='runTest'))
suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=True))

suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=False))

suite.addTest(
AffineGridTestCase(
methodName='runTest',
theta_shape=(20, 2, 3),
output_shape=[20, 1, 7, 7],
align_corners=True))


def add_error_cases(suite):
suite.addTest(
AffineGridErrorTestCase(
methodName='runTest', output_shape="not_valid"))


def load_tests(loader, standard_tests, pattern):
suite = unittest.TestSuite()
add_cases(suite)
add_error_cases(suite)
return suite


if __name__ == '__main__':
unittest.main()
90 changes: 89 additions & 1 deletion python/paddle/nn/functional/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ...device import get_cudnn_version
from ...fluid.framework import core, in_dygraph_mode, Variable
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype

# TODO: define specitial functions used in computer vision task
from ...fluid.layers import affine_channel #DEFINE_ALIAS
from ...fluid.layers import affine_grid #DEFINE_ALIAS
from ...fluid.layers import anchor_generator #DEFINE_ALIAS
from ...fluid.layers import bipartite_match #DEFINE_ALIAS
from ...fluid.layers import box_clip #DEFINE_ALIAS
Expand Down Expand Up @@ -89,3 +93,87 @@
'yolo_box',
'yolov3_loss'
]


def affine_grid(theta, out_shape, align_corners=True, name=None):
"""
It generates a grid of (x,y) coordinates using the parameters of
the affine transformation that correspond to a set of points where
the input feature map should be sampled to produce the transformed
output feature map.
Args:
theta (Variable) - A Tensor with shape [N, 2, 3]. It contains a batch of affine transform parameters.
The data type can be float32 or float64.
out_shape (Variable | list | tuple): The shape of target output with format [batch_size, channel, height, width].
``out_shape`` can be a Tensor or a list or tuple. The data
type must be int32.
align_corners(bool): Whether to align corners of target feature map and source feature map. Default: True.
name(str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor with shape [batch_size, H, W, 2] while 'H' and 'W' are the height and width of feature map in affine transformation. The data type is the same as `theta`.
Raises:
ValueError: If the type of arguments is not supported.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
place = paddle.CPUPlace()
theta_shape = [20, 2, 3]
theta = np.random.randn(*theta_shape).astype("float32")
theta_var = paddle.to_variable(theta)
y_var = F.affine_grid(
theta_var,
[20, 2, 5, 5],
align_corners=False)
y_np = y_var.numpy()
print(y_np)
"""
helper = LayerHelper('affine_grid')

check_variable_and_dtype(theta, 'theta', ['float32', 'float64'],
'affine_grid')

if get_cudnn_version() >= 6000 and align_corners:
use_cudnn = True
else:
use_cudnn = False

if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
isinstance(out_shape, Variable)):
raise ValueError("The out_shape should be a list, tuple or Variable.")

if in_dygraph_mode():
_out_shape = out_shape.numpy().tolist() if isinstance(
out_shape, Variable) else out_shape
return core.ops.affine_grid(theta, "output_shape", _out_shape,
"align_corners", align_corners, "use_cudnn",
use_cudnn)

if not isinstance(theta, Variable):
raise ValueError("The theta should be a Variable.")

out = helper.create_variable_for_type_inference(theta.dtype)
ipts = {'Theta': theta}
attrs = {"align_corners": align_corners, "use_cudnn": use_cudnn}
if isinstance(out_shape, Variable):
ipts['OutputShape'] = out_shape
check_variable_and_dtype(out_shape, 'out_shape', ['int32'],
'affine_grid')
else:
attrs['output_shape'] = out_shape

helper.append_op(
type='affine_grid',
inputs=ipts,
outputs={'Output': out},
attrs=None if len(attrs) == 0 else attrs)
return out

0 comments on commit d91f592

Please sign in to comment.