Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Prim][PIR] unsequeeze prim sink #59798

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"sqrt",
"squeeze",
"stack",
"unsqueeze",
]

# come into effect in generated file op_decomp.cc
Expand All @@ -53,6 +54,7 @@
"sqrt",
"squeeze",
"stack",
"unsqueeze",
kevincheng2 marked this conversation as resolved.
Show resolved Hide resolved
]


Expand Down
19 changes: 11 additions & 8 deletions paddle/fluid/primitive/composite/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,8 @@ Tensor softmax_decomp(const Tensor& x, const int& axis) {

template <typename T>
Tensor stack_decomp(const std::vector<Tensor>& x, const int& axis) {
auto tensor_dims = x[0].dims();
int tmp_axis = axis;
if (tmp_axis < 0) {
tmp_axis += tensor_dims.size() + 1;
}

auto out_shape = phi::vectorize(tensor_dims);
out_shape.insert(out_shape.begin() + tmp_axis, 1);
std::vector<int64_t> axis_tmp = {axis};
auto out_shape = get_expand_dims(x[0], axis_tmp);

std::vector<Tensor> concat_x;
for (size_t i = 0; i < x.size(); ++i) {
Expand Down Expand Up @@ -318,6 +312,15 @@ std::tuple<Tensor, Tensor> squeeze_decomp(const Tensor& x,
return std::make_tuple(out, xshape);
}

template <typename T>
std::tuple<Tensor, Tensor> unsqueeze_decomp(const Tensor& x,
const IntArray& axis) {
auto out_shape = get_expand_dims(x, axis.GetData());
Tensor out = reshape<T>(x, out_shape);
Tensor xshape;
return std::make_tuple(out, xshape);
}

template <typename T>
Tensor add_n_decomp(const std::vector<Tensor>& x) {
Tensor res = x[0];
Expand Down
23 changes: 23 additions & 0 deletions paddle/fluid/primitive/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,29 @@ static bool is_half_dtype(const DataType& dtype) {
}
}

// This function expands the dimension of origin Tensor based on the value of
// axis
static std::vector<int64_t> get_expand_dims(const Tensor& origin,
const std::vector<int64_t>& axis) {
std::vector<int64_t> result(origin.shape());
for (size_t i = 0; i < axis.size(); ++i) {
int64_t offset = axis[i];
if (offset < 0) {
offset += result.size() + 1;
}

PADDLE_ENFORCE_LE(
offset,
result.size(),
platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
"elements in origin_dims[%lu].",
offset,
result.size()));
result.insert(result.begin() + offset, 1);
kevincheng2 marked this conversation as resolved.
Show resolved Hide resolved
}
return result;
}

// This fucction compute unsqueeze dims for reshape to replace unsqueeze.
static std::vector<int64_t> get_unsqueeze_dims(
const Tensor& origin, const std::vector<int64_t>& axis) {
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/decomposition/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .register import register_decomp


# TODO(kevincheng2): python implementation of prim feature,
# now it has been sinked to c++, waiting for further deletion.
@register_decomp('pd_op.unsqueeze')
def unsqueeze(x, axis):
"""define composite rule of op unsqueeze"""
Expand Down