Skip to content

Commit

Permalink
fix randint op (PaddlePaddle#58295)
Browse files Browse the repository at this point in the history
* fix randint op

* Update new_ir_op_test_no_check_list

* fix name

* fix  file name

* fix  file name

* Update pir_op_test_no_check_list
  • Loading branch information
xingmingyyj authored Nov 7, 2023
1 parent f333e4a commit c99dbba
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
40 changes: 40 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2324,6 +2324,45 @@ struct ShareBufferOpTranscriber : public OpTranscriber {
}
};

struct RandIntOpTranscriber : public OpTranscriber {
std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
pir::IrContext* ctx,
const OpDesc& op_desc,
const OpOutputInfoList& output_infos) {
OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types = {};

auto& type_translator = TypeTranslator::instance();

const BlockDesc* block = op_desc.Block();
std::string legacy_output_name = "Out";
const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
auto& var_name = legacy_output_vars[0];
VarDesc* var = block->FindVarRecursive(var_name);
IR_ENFORCE(var != nullptr,
"[op:%s] Output %s should not be null",
op_desc.Type(),
var_name);
int dtype_attr_val = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));

paddle::framework::proto::VarType::Type var_type =
static_cast<paddle::framework::proto::VarType::Type>(dtype_attr_val);

pir::Type dtype = type_translator[var_type](ctx, *var);
paddle::dialect::DenseTensorTypeStorage::Dim dim =
phi::make_ddim(var->GetShape());
paddle::dialect::DenseTensorTypeStorage::DataLayout layout =
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED;
paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
pir::Type translated_var_type = paddle::dialect::DenseTensorType::get(
ctx, dtype, dim, layout, lod, offset);
arg_to_idx[var_name] = {0, 0};
op_output_types.push_back(translated_var_type);
return {op_output_types, arg_to_idx};
}
};

struct RepeatInterLeaveOpTranscriber : public OpTranscriber {
pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
Expand Down Expand Up @@ -2418,6 +2457,7 @@ OpTranslator::OpTranslator() {
special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
special_handlers["one_hot_v2"] = OneHotTranscriber();
special_handlers["randint"] = RandIntOpTranscriber();
special_handlers["reduce_all"] = ReduceOpTranscriber();
special_handlers["reduce_any"] = ReduceOpTranscriber();
special_handlers["repeat_interleave"] = RepeatInterLeaveOpTranscriber();
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_no_check_list
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
test_exponential_op
test_randint_op
test_seed_op
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ test_prelu_op
test_prior_box_op
test_psroi_pool_op
test_put_along_axis_op
test_randint_op
test_range
test_reduce_op
test_reduce_op_static_build
Expand Down

0 comments on commit c99dbba

Please sign in to comment.