diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 8cfd12d3ba330c..a52154ea8bea80 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -2324,6 +2324,45 @@ struct ShareBufferOpTranscriber : public OpTranscriber {
   }
 };
 
+struct RandIntOpTranscriber : public OpTranscriber {
+  std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
+      pir::IrContext* ctx,
+      const OpDesc& op_desc,
+      const OpOutputInfoList& output_infos) {
+    OpOutputMapping arg_to_idx;
+    OpOutputTypeList op_output_types = {};
+
+    auto& type_translator = TypeTranslator::instance();
+
+    const BlockDesc* block = op_desc.Block();
+    std::string legacy_output_name = "Out";
+    const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
+    auto& var_name = legacy_output_vars[0];
+    VarDesc* var = block->FindVarRecursive(var_name);
+    IR_ENFORCE(var != nullptr,
+               "[op:%s] Output %s should not be null",
+               op_desc.Type(),
+               var_name);
+    int dtype_attr_val = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
+
+    paddle::framework::proto::VarType::Type var_type =
+        static_cast<paddle::framework::proto::VarType::Type>(dtype_attr_val);
+
+    pir::Type dtype = type_translator[var_type](ctx, *var);
+    paddle::dialect::DenseTensorTypeStorage::Dim dim =
+        phi::make_ddim(var->GetShape());
+    paddle::dialect::DenseTensorTypeStorage::DataLayout layout =
+        paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED;
+    paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
+    size_t offset = 0;
+    pir::Type translated_var_type = paddle::dialect::DenseTensorType::get(
+        ctx, dtype, dim, layout, lod, offset);
+    arg_to_idx[var_name] = {0, 0};
+    op_output_types.push_back(translated_var_type);
+    return {op_output_types, arg_to_idx};
+  }
+};
+
 struct RepeatInterLeaveOpTranscriber : public OpTranscriber {
   pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
                             const OpDesc& op_desc) override {
@@ -2418,6 +2457,7 @@ OpTranslator::OpTranslator() {
   special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
   special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
   special_handlers["one_hot_v2"] = OneHotTranscriber();
+  special_handlers["randint"] = RandIntOpTranscriber();
   special_handlers["reduce_all"] = ReduceOpTranscriber();
   special_handlers["reduce_any"] = ReduceOpTranscriber();
   special_handlers["repeat_interleave"] = RepeatInterLeaveOpTranscriber();
diff --git a/test/white_list/pir_op_test_no_check_list b/test/white_list/pir_op_test_no_check_list
index a5996049fc8e91..8363980af03472 100644
--- a/test/white_list/pir_op_test_no_check_list
+++ b/test/white_list/pir_op_test_no_check_list
@@ -1,2 +1,3 @@
 test_exponential_op
+test_randint_op
 test_seed_op
diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list
index 81a41e28d66ecd..d3ff77d26da669 100644
--- a/test/white_list/pir_op_test_white_list
+++ b/test/white_list/pir_op_test_white_list
@@ -165,6 +165,7 @@ test_prelu_op
 test_prior_box_op
 test_psroi_pool_op
 test_put_along_axis_op
+test_randint_op
 test_range
 test_reduce_op
 test_reduce_op_static_build
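
Note (not part of the patch): the legacy `randint` op that RandIntOpTranscriber handles stores its output element type as an integer `dtype` attribute and writes a single `Out` variable, which is why GenerateOperationOutput reads `op_desc.GetAttr("dtype")` and rebuilds the result type as a DenseTensorType. The test is also listed in pir_op_test_no_check_list, presumably because the op's output is random and cannot be compared value-for-value. The following minimal sketch, which assumes only the public paddle.static / paddle.randint Python API and is not taken from this PR, shows the kind of legacy program whose `randint` OpDesc this transcriber would translate when the program is lowered to PIR:

    # Minimal sketch: build a legacy static program containing a `randint` op.
    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # `dtype` ends up as the integer VarType attribute that the
        # transcriber reads via op_desc.GetAttr("dtype").
        out = paddle.randint(low=0, high=10, shape=[2, 3], dtype='int64')

    # The program now holds a `randint` OpDesc with a single "Out" output,
    # whose dtype/shape the transcriber reconstructs as a DenseTensorType.
    print(main_prog)
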