[Fix PIR Unittest No.213, 277] Support to decompose squeeze_grad and unsqueeze_grad #64277

Merged · 5 commits · May 16, 2024
Changes from 2 commits
2 changes: 2 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -101,8 +101,10 @@
    'scatter_grad',
    'scatter_nd_add_grad',
    'slice_grad',
    'squeeze_grad',
    'tile_grad',
    'topk_grad',
    'unsqueeze_grad',
]

# whole vjp list of primitive op vjp
45 changes: 45 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -877,6 +877,51 @@ void softmax_grad(const Tensor& out,
}
}

template <typename T>
void squeeze_grad(const Tensor& xshape,
                  const Tensor& out_grad,
                  const IntArray& axis,
                  Tensor* x_grad) {
  if (x_grad) {
    auto x_grad_out = unsqueeze<T>(out_grad, axis);
    set_output<T>(x_grad_out, x_grad);
  }
}
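The rule above is the whole decomposition: the gradient of squeeze is simply the upstream gradient with the squeezed axes re-inserted. A minimal NumPy sketch of the same identity, with hypothetical shapes, for illustration only:

import numpy as np

# Forward: squeeze axes (0, 2) of a (1, 3, 1, 40) tensor.
x = np.random.rand(1, 3, 1, 40)
out = np.squeeze(x, axis=(0, 2))                 # shape (3, 40)

# Backward: unsqueeze the upstream gradient along the same axes.
out_grad = np.ones_like(out)
x_grad = np.expand_dims(out_grad, axis=(0, 2))   # shape (1, 3, 1, 40)
assert x_grad.shape == x.shape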

template <typename T>
void unsqueeze_grad(const Tensor& xshape,
                    const Tensor& out_grad,
                    const IntArray& axis,
                    Tensor* x_grad) {
  // For x with shape [10, 2, 5] and axis = [3, 1, 1], out_grad.shape is
  // [10, 1, 1, 2, 5, 1]; the computed squeeze axes are [5, 2, 1].
  const auto& IncreaseAxis = [](std::vector<int64_t>* axis_data,
                                int64_t pivot) {
    for (size_t i = 0; i < axis_data->size(); ++i) {
      if ((*axis_data)[i] >= pivot) (*axis_data)[i] += 1;
    }
  };
  const auto& GetRealAxis = [&](const IntArray& axis) -> decltype(auto) {
    // For axis = [0, 3, 3] the output is [0, 4, 3]: the earlier 3 is
    // shifted up by one, because unsqueeze supports duplicated axes.
    std::vector<int64_t> output_axis;
    const int64_t x_rank = xshape.dims().size() - 1;
    const std::vector<int64_t> axis_data = axis.GetData();
    for (size_t i = 0; i < axis_data.size(); ++i) {
      int64_t value = axis_data[i];
      if (value < 0) value += (x_rank + i + 1);
      IncreaseAxis(&output_axis, value);
      output_axis.push_back(value);
    }
    return output_axis;
  };

  if (x_grad) {
    auto x_grad_out = squeeze<T>(out_grad, GetRealAxis(axis));
    set_output<T>(x_grad_out, x_grad);
  }
}
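The only subtle part is GetRealAxis, which turns the (possibly negative, possibly duplicated) unsqueeze axes into the axes to squeeze out of out_grad. A pure-Python transcription of the same algorithm, given here as a reading aid rather than Paddle code, reproduces both worked examples from the comments:

def get_real_axis(axis, x_rank):
    # Mirror of IncreaseAxis/GetRealAxis: normalize negative values,
    # then shift previously recorded axes displaced by a duplicate.
    output_axis = []
    for i, value in enumerate(axis):
        if value < 0:
            value += x_rank + i + 1
        output_axis = [a + 1 if a >= value else a for a in output_axis]
        output_axis.append(value)
    return output_axis

# x.shape = [10, 2, 5] (rank 3), axis = [3, 1, 1] -> squeeze axes [5, 2, 1]
assert get_real_axis([3, 1, 1], 3) == [5, 2, 1]
# axis = [0, 3, 3] -> [0, 4, 3]: the earlier 3 is shifted up by one
assert get_real_axis([0, 3, 3], 3) == [0, 4, 3]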

template <typename T>
void matmul_grad(const Tensor& x,
                 const Tensor& y,
10 changes: 0 additions & 10 deletions paddle/phi/infermeta/unary.cc
@@ -4286,16 +4286,6 @@ void SqueezeInferMeta(const MetaTensor& x,
                      MetaTensor* out,
                      MetaConfig config) {
  const auto& x_dims = x.dims();
  // Check input tensor dims (<6) Eigen limit.
  PADDLE_ENFORCE_LE(x_dims.size(),
                    6,
                    phi::errors::InvalidArgument(
                        "The dimensions of Input(X) "
                        "should be in the range of [1, 6] (Eigen limit)."
                        "But received X's dimensions = %d, X's shape = [%s].",
                        x_dims.size(),
                        x_dims));

  if (!config.is_runtime && axes.FromTensor()) {
    // compile time infershape, set all elements to -1.
    int output_size = static_cast<int>(x.dims().size() - axes.GetData().size());
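The branch kept above sets every compile-time output dimension to the placeholder -1 when the axes come from a tensor, since their values are unknown until runtime. A small sketch of that bookkeeping (the shapes here are hypothetical):

# x has 4 dims; 2 squeeze axes arrive via a tensor, so only the output
# rank is known at compile time and each dim becomes the placeholder -1.
x_rank = 4
num_tensor_axes = 2
output_size = x_rank - num_tensor_axes
compile_time_out_shape = [-1] * output_size
assert compile_time_out_shape == [-1, -1]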
2 changes: 1 addition & 1 deletion python/paddle/decomposition/decomp.py
@@ -213,7 +213,7 @@ def decompose(
blacklist (frozenset): The operators that will be excluded when decomposing into primitives.
whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives.
start_index (int): The start index of decomposed operator in global block, default 0;
end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed.
end_index (int): The end index of the decomposed operator in the global block; default -1 means all ops will be decomposed. start_index and end_index follow left-closed, right-open semantics, i.e. [start_index, end_index).

Returns:
dst_vars (list): A list containing all vars that replace the original ones in src_vars.
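A minimal illustration of the half-open interval, in plain Python and independent of the decompose() API itself:

ops = ["op0", "op1", "op2", "op3"]
start_index, end_index = 1, 3
# [start_index, end_index) is left-closed, right-open: op1 and op2 are
# selected for decomposition; op3 at end_index is not.
assert ops[start_index:end_index] == ["op1", "op2"]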
2 changes: 0 additions & 2 deletions test/deprecated/legacy_test/CMakeLists.txt
@@ -772,11 +772,9 @@ set(TEST_CINN_OPS
    test_top_k_v2_op
    test_elementwise_mul_op
    test_gather_nd_op
    test_squeeze2_op
    test_elementwise_pow_op
    test_transpose_op
    test_reshape_op
    test_unsqueeze2_op
    test_meshgrid_op
    test_scale_op
    test_scatter_op
206 changes: 0 additions & 206 deletions test/deprecated/legacy_test/test_squeeze2_op.py
100755 → 100644
@@ -15,182 +15,14 @@
import os
import unittest

import numpy as np
from op_test import OpTest, convert_float_to_uint16
from test_attribute_var import UnittestBase

import paddle
from paddle.base import core
from paddle.base.framework import Program, program_guard
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()


# Correct: General.
class TestSqueezeOp(OpTest):
    def setUp(self):
        self.op_type = "squeeze2"
        self.prim_op_type = "comp"
        self.python_api = paddle.squeeze
        self.public_python_api = paddle.squeeze
        self.python_out_sig = [
            "Out"
        ]  # python out sig is customized output signature.
        self.init_test_case()
        self.init_dtype()
        self.if_enable_cinn()
        x = np.random.random(self.ori_shape).astype("float64")
        xshape = np.random.random(self.ori_shape).astype("float64")
        if hasattr(self, "dtype") and self.dtype == np.uint16:
            x = convert_float_to_uint16(x.astype(np.float32))
            xshape = convert_float_to_uint16(xshape.astype(np.float32))
        self.inputs = {"X": x}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": xshape,
        }

    def if_enable_cinn(self):
        pass

    def test_check_output(self):
        self.check_output(
            no_check_set=['XShape'],
            check_prim=True,
            check_pir=True,
            check_prim_pir=True,
        )

    def test_check_grad(self):
        self.check_grad(
            ["X"],
            "Out",
            check_prim=True,
            check_pir=True,
            check_prim_pir=True,
        )

    def init_dtype(self):
        self.dtype = np.float64

    def init_test_case(self):
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, 2)
        self.new_shape = (3, 40)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA and do not support bfloat16",
)
class TestSqueezeOpBF16OP(TestSqueezeOp):
    def init_dtype(self):
        self.dtype = np.uint16


# Correct: There is a negative axis.
class TestSqueezeOp1(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (1, 20, 1, 5)
        self.axes = (0, -2)
        self.new_shape = (20, 5)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA and do not support bfloat16",
)
class TestSqueezeOp1BF16Op(TestSqueezeOp):
    def init_dtype(self):
        self.dtype = np.uint16


class TestSqueezeOp_ZeroDim1(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = ()
        self.axes = (0,)
        self.new_shape = ()


class TestSqueezeOp_ZeroDim2(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (1, 1, 1)
        self.axes = (0, 1, 2)
        self.new_shape = ()


# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
    def setUp(self):
        self.op_type = "squeeze2"
        self.prim_op_type = "comp"
        self.python_api = paddle.squeeze
        self.public_python_api = paddle.squeeze
        self.python_out_sig = [
            "Out"
        ]  # python out sig is customized output signature.
        self.init_test_case()
        self.init_dtype()
        self.if_enable_cinn()
        x = np.random.random(self.ori_shape).astype("float64")
        xshape = np.random.random(self.ori_shape).astype("float64")
        if hasattr(self, "dtype") and self.dtype == np.uint16:
            x = convert_float_to_uint16(x.astype(np.float32))
            xshape = convert_float_to_uint16(xshape.astype(np.float32))
        self.inputs = {"X": x}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": xshape,
        }

    def if_enable_cinn(self):
        pass

    def init_dtype(self):
        self.dtype = np.float64

    def init_test_case(self):
        self.ori_shape = (1, 20, 1, 5)
        self.axes = ()
        self.new_shape = (20, 5)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA and do not support bfloat16",
)
class TestSqueezeOp2BF16Op(TestSqueezeOp):
    def init_dtype(self):
        self.dtype = np.uint16


# Correct: Only part of the axes are squeezed.
class TestSqueezeOp3(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (6, 1, 5, 1, 4, 1)
        self.axes = (1, -1)
        self.new_shape = (6, 5, 1, 4)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA and do not support bfloat16",
)
class TestSqueezeOp3BF16Op(TestSqueezeOp):
    def init_dtype(self):
        self.dtype = np.uint16


class TestSqueeze2AxesTensor(UnittestBase):
    def init_info(self):
        self.shapes = [[2, 3, 4]]
@@ -266,43 +98,5 @@ def test_static(self):
        self.assertEqual(infer_out.shape, (2, 3, 10))


# test api
class TestSqueezeAPI(unittest.TestCase):
    def setUp(self):
        self.executed_api()

    def executed_api(self):
        self.squeeze = paddle.squeeze

    def test_api(self):
        paddle.disable_static()
        input_data = np.random.random([3, 2, 1]).astype("float32")
        x = paddle.to_tensor(input_data)
        out = self.squeeze(x, axis=2)
        out.backward()

        self.assertEqual(out.shape, [3, 2])

        paddle.enable_static()

    @test_with_pir_api
    def test_error(self):
        def test_axes_type():
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x2 = paddle.static.data(
                    name="x2", shape=[2, 1, 25], dtype="int32"
                )
                self.squeeze(x2, axis=2.1)

        self.assertRaises(TypeError, test_axes_type)


class TestSqueezeInplaceAPI(TestSqueezeAPI):
    def executed_api(self):
        self.squeeze = paddle.squeeze_


if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions test/legacy_test/CMakeLists.txt
@@ -943,6 +943,8 @@ set(TEST_CINN_OPS
    test_group_norm_op
    test_tile_op
    test_sum_op
    test_squeeze2_op
    test_unsqueeze2_op
    test_elementwise_min_op
    test_take_along_axis_op
    test_strided_slice_op