Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR OpTest Fix No.15】 fix test match matrix tensor op #60277

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions paddle/fluid/operators/match_matrix_tensor_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,93 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
auto* tmp = ctx.Output<phi::DenseTensor>("Tmp");

int dim_t = ctx.Attr<int>("dim_t");

const auto& x_lod = x->lod();
PADDLE_ENFORCE_EQ(x_lod.empty(),
false,
platform::errors::InvalidArgument(
"The Input(X) should hold LoD information, but "
"received Input(X).lod() is empty."));
const auto& x_lod_0 = x_lod[0];
PADDLE_ENFORCE_GE(x_lod_0.size(),
2,
platform::errors::InvalidArgument(
"The dimensions of Input(X)'s LoD data should be "
"equal to 2, but received %d.",
x_lod_0.size()));
auto x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims[0],
static_cast<int64_t>(x_lod_0.back()),
platform::errors::InvalidArgument(
"The last element of Input(X)'s LoD data should be "
"equal to the first dimension of Input(X). "
"But received the last element of Input(X)'s LoD "
"data is %d, the first dimension of Input(X) is %d.",
x_lod_0.back(),
x_dims[0]));
const auto& y_lod = y->lod();
PADDLE_ENFORCE_EQ(y_lod.empty(),
false,
platform::errors::InvalidArgument(
"The Input(Y) should hold LoD information, but "
"received Input(Y).lod() is empty."));
const auto& y_lod_0 = y_lod[0];
PADDLE_ENFORCE_GE(y_lod_0.size(),
2,
platform::errors::InvalidArgument(
"The dimensions of Input(Y)'s LoD data should be "
"equal to 2, but received %d.",
y_lod_0.size()));
auto y_dims = y->dims();
PADDLE_ENFORCE_EQ(y_dims[0],
static_cast<int64_t>(y_lod_0.back()),
platform::errors::InvalidArgument(
"The last element of Input(Y)'s LoD data should be "
"equal to the first dimension of Input(Y). "
"But received the last element of Input(Y)'s LoD "
"data is %d, the first dimension of Input(Y) is %d.",
y_lod_0.back(),
y_dims[0]));

PADDLE_ENFORCE_EQ(x_lod_0.size(),
y_lod_0.size(),
platform::errors::InvalidArgument(
"The dimensions of Input(X)'s and Input(Y)'s LoD "
"data should be equal. "
"But received the dimensions of Input(X)'s LoD is "
"%d, the dimensions of Input(Y)'s LoD is %d.",
x_lod_0.size(),
y_lod_0.size()));

int64_t out_dim_0 = 0;
int64_t tmp_dim_0 = -1;
for (size_t i = 1; i < x_lod_0.size(); i++) {
int64_t x_len = x_lod_0[i] - x_lod_0[i - 1];
int64_t y_len = y_lod_0[i] - y_lod_0[i - 1];
out_dim_0 += (x_len * y_len);
}
out_dim_0 *= dim_t;

tmp_dim_0 = x_dims[0] * dim_t * x_dims[1];
std::vector<int64_t> out_dims_vec{out_dim_0};
out_dims_vec.push_back(1);
std::vector<int64_t> tmp_dims_vec{tmp_dim_0};
tmp_dims_vec.push_back(1);

auto& out_meta = out->meta();
phi::DenseTensorMeta new_out_meta(out_meta.dtype,
common::make_ddim(out_dims_vec),
out_meta.layout,
out_meta.lod);
out->set_meta(new_out_meta);

auto& tmp_meta = tmp->meta();
phi::DenseTensorMeta new_tmp_meta(tmp_meta.dtype,
common::make_ddim(tmp_dims_vec),
tmp_meta.layout,
tmp_meta.lod);
tmp->set_meta(new_tmp_meta);

int64_t dim_in = x->dims()[1];

const auto& offset_l = x->lod()[0];
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
'sparse_momentum',
'soft_relu',
'uniform_random_batch_size_like',
'match_matrix_tensor',
'c_reduce_min',
'c_reduce_min_',
]
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,15 @@
data_type: param
optional: master_param, master_param_out

- op: match_matrix_tensor
args: (Tensor x, Tensor y, Tensor w, int dim_t=1)
output: Tensor(out), Tensor(tmp)
infer_meta:
func: MatchMatrixTensorInferMeta
kernel:
func: match_matrix_tensor
backward: match_matrix_tensor_grad

- op: nce
args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
output: Tensor(cost), Tensor(sample_logits), Tensor(sample_labels)
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,15 @@
func: fused_elemwise_add_activation_grad
optional : x, intermediate_out

- backward_op: match_matrix_tensor_grad
forward: match_matrix_tensor (Tensor x, Tensor y, Tensor w, int dim_t=1) -> Tensor(out), Tensor(tmp)
args: (Tensor x, Tensor y, Tensor w, Tensor tmp, Tensor out_grad, int dim_t=1)
output: Tensor(x_grad), Tensor(y_grad), Tensor(w_grad)
infer_meta:
func: MatchMatrixTensorGradInferMeta
kernel:
func: match_matrix_tensor_grad

- backward_op: shuffle_batch_grad
forward: shuffle_batch (Tensor x, Tensor seed, int startup_seed=0) -> Tensor(out), Tensor(shuffle_idx), Tensor(seed_out)
args: (Tensor shuffle_idx, Tensor out_grad,int startup_seed=0)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const std::unordered_set<std::string> LegacyOpList = {
RowConvGradOp::name(),
SoftReluOp::name(),
SoftReluGradOp::name(),
MatchMatrixTensorOp::name(),
MatchMatrixTensorGradOp::name(),
NceOp::name(),
NceGradOp::name(),
CReduceMinOp::name()};
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3511,6 +3511,13 @@
attrs:
pivot : pivots

- op: match_matrix_tensor
backward: match_matrix_tensor_grad
inputs:
{x : X, y : Y, w : W}
outputs:
{out : Out, tmp : Tmp}

- op: memcpy
inputs:
x: X
Expand Down
26 changes: 25 additions & 1 deletion paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,31 @@ void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
logits_grad->set_dtype(softmax.dtype());
}

void MatchMatrixTensorGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
const MetaTensor& tmp,
const MetaTensor& out_grad,
int dim_t,
MetaTensor* x_grad,
MetaTensor* y_grad,
MetaTensor* w_grad) {
if (x_grad != nullptr) {
x_grad->set_dims(x.dims());
x_grad->share_lod(x);
x_grad->set_dtype(x.dtype());
}
if (y_grad != nullptr) {
y_grad->set_dims(y.dims());
y_grad->share_lod(y);
y_grad->set_dtype(y.dtype());
}
if (w_grad != nullptr) {
w_grad->set_dims(w.dims());
w_grad->set_dtype(w.dtype());
}
}

void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
Expand Down Expand Up @@ -1352,5 +1377,4 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
value_grad->share_lod(values);
}
}

} // namespace phi
10 changes: 10 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,16 @@ void MarginCrossEntropyGradInferMeta(const MetaTensor& logits,
float scale,
MetaTensor* logits_grad);

void MatchMatrixTensorGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
const MetaTensor& tmp,
const MetaTensor& out_grad,
int dim_t,
MetaTensor* x_grad,
MetaTensor* y_grad,
MetaTensor* w_grad);

void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const MetaTensor& mask,
const MetaTensor& dout,
Expand Down
73 changes: 73 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,79 @@ void LinspaceInferMeta(const MetaTensor& start,
LinspaceRawInferMeta(start, stop, number, out);
}

void MatchMatrixTensorInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
int dim_t,
MetaTensor* out,
MetaTensor* tmp,
MetaConfig config) {
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(x_dims.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of Input(X) should be equal to 2, "
"but received %d.",
x_dims.size()));

auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(y_dims.size(),
2,
phi::errors::InvalidArgument(
"The dimensions of Input(Y) should be equal to 2, "
"but received %d.",
y_dims.size()));

auto w_dims = w.dims();
PADDLE_ENFORCE_EQ(w_dims.size(),
3,
phi::errors::InvalidArgument(
"The dimensions of Input(W) should be equal to 3, "
"but received %d.",
w_dims.size()));

PADDLE_ENFORCE_EQ(
w_dims[0],
x_dims[1],
phi::errors::InvalidArgument(
"The first dimension of Input(W) should be equal to the second "
"dimension of Input(X). But received the first dimension of Input(W) "
"is %d, the second dimension of Input(X) is %d.",
w_dims[0],
x_dims[1]));
PADDLE_ENFORCE_EQ(
w_dims[1],
dim_t,
phi::errors::InvalidArgument(
"The second dimension of Input(W) should be equal to 'dim_t', but "
"received the second dimension of Input(W) is %d, 'dim_t' is %d.",
w_dims[1],
dim_t));
PADDLE_ENFORCE_EQ(
w_dims[2],
y_dims[1],
phi::errors::InvalidArgument(
"The last dimension of Input(W) should be equal to "
"the second dimension of Input(Y). But received the last dimension "
"of Input(W) is %d, the second dimension of Input(Y) is %d.",
w_dims[2],
y_dims[1]));

int64_t out_dim_0 = -1;
int64_t tmp_dim_0 = -1;
if (!config.is_runtime) {
out->share_lod(x);
std::vector<int64_t> out_dims_vec{out_dim_0};
out_dims_vec.push_back(1);
std::vector<int64_t> tmp_dims_vec{tmp_dim_0};
tmp_dims_vec.push_back(1);
out->set_dims(common::make_ddim(out_dims_vec));
out->set_dtype(x.dtype());
tmp->set_dims(common::make_ddim(tmp_dims_vec));
tmp->set_dtype(x.dtype());
}
}

void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ void LinspaceInferMeta(const MetaTensor& start,
DataType dtype,
MetaTensor* out);

void MatchMatrixTensorInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
int dim_t,
MetaTensor* out,
MetaTensor* tmp,
MetaConfig config = MetaConfig());

void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ test_lu_op
test_lu_unpack_op
test_margin_cross_entropy_op
test_masked_select_op
test_match_matrix_tensor_op
test_matmul_v2_op
test_matmul_v2_op_static_build
test_matrix_nms_op
Expand Down