Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix randint op #58295

Merged
merged 9 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2492,6 +2492,45 @@ struct ShareBufferOpTranscriber : public OpTranscriber {
}
};

struct RandIntOpTranscriber : public OpTranscriber {
std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
pir::IrContext* ctx,
const OpDesc& op_desc,
const OpOutputInfoList& output_infos) {
OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types = {};

auto& type_translator = TypeTranslator::instance();

const BlockDesc* block = op_desc.Block();
std::string legacy_output_name = "Out";
const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
auto& var_name = legacy_output_vars[0];
VarDesc* var = block->FindVarRecursive(var_name);
IR_ENFORCE(var != nullptr,
"[op:%s] Output %s should not be null",
op_desc.Type(),
var_name);
int dtype_attr_val = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));

paddle::framework::proto::VarType::Type var_type =
static_cast<paddle::framework::proto::VarType::Type>(dtype_attr_val);

pir::Type dtype = type_translator[var_type](ctx, *var);
paddle::dialect::DenseTensorTypeStorage::Dim dim =
phi::make_ddim(var->GetShape());
paddle::dialect::DenseTensorTypeStorage::DataLayout layout =
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED;
paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
pir::Type translated_var_type = paddle::dialect::DenseTensorType::get(
ctx, dtype, dim, layout, lod, offset);
arg_to_idx[var_name] = {0, 0};
op_output_types.push_back(translated_var_type);
return {op_output_types, arg_to_idx};
}
};

struct RepeatInterLeaveOpTranscriber : public OpTranscriber {
pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
const OpDesc& op_desc) override {
Expand Down Expand Up @@ -2565,6 +2604,7 @@ struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber {
return op_inputs;
}
};

OpTranslator::OpTranslator() {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand All @@ -2585,6 +2625,7 @@ OpTranslator::OpTranslator() {
special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
special_handlers["one_hot_v2"] = OneHotTranscriber();
special_handlers["randint"] = RandIntOpTranscriber();
special_handlers["reduce_all"] = ReduceOpTranscriber();
special_handlers["reduce_any"] = ReduceOpTranscriber();
special_handlers["repeat_interleave"] = RepeatInterLeaveOpTranscriber();
Expand Down
1 change: 1 addition & 0 deletions test/white_list/new_ir_op_test_no_check_list
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
test_exponential_op
test_randint_op
1 change: 1 addition & 0 deletions test/white_list/new_ir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ test_prelu_op
test_prior_box_op
test_psroi_pool_op
test_put_along_axis_op
test_randint_op
test_range
test_reduce_op
test_reduce_op_static_build
Expand Down