diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 813cb34826eb22..8cfd12d3ba330c 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1027,7 +1027,10 @@ struct DataOpTranscriber : public FeedOpTranscriber { const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { - int allocate_type = paddle::get(op_desc.GetAttr("place")); + int allocate_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place")); + int var_dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); + auto phi_dtype = phi::TransToPhiDataType(var_dtype); + auto& attribute_translator = AttributeTranslator::instance(); pir::Attribute shape = attribute_translator( "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("shape")); @@ -1036,8 +1039,7 @@ struct DataOpTranscriber : public FeedOpTranscriber { pir::StrAttribute::get(ctx, op_desc.GetAttrIfExists("name"))}, {"shape", shape}, - {"dtype", - paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32)}, + {"dtype", paddle::dialect::DataTypeAttribute::get(ctx, phi_dtype)}, {"place", paddle::dialect::PlaceAttribute::get( ctx, phi::Place(static_cast(allocate_type)))}, diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 4d297d02afe539..3a5716877a59d8 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -43,6 +43,7 @@ #include "paddle/fluid/pir/transforms/inplace_pass.h" #include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/attribute.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/program.h" @@ -64,6 +65,7 @@ namespace py = pybind11; using paddle::dialect::APIBuilder; using paddle::dialect::DenseTensorType; using paddle::dialect::SelectedRowsType; +using pir::Attribute; using pir::Block; using pir::Operation; using pir::OpOperand; @@ -804,6 +806,15 @@ void BindType(py::module *m) { }); } +void BindAttribute(py::module *m) { + py::class_ ir_attr(*m, "Attribute", py::module_local()); + ir_attr.def("__str__", [](Attribute &self) { + std::ostringstream print_stream; + print_stream << self; + return print_stream.str(); + }); +} + Operation *BuildOpFrom( Operation *to_copy_op, std::unordered_map &value_map) { // NOLINT @@ -1472,6 +1483,7 @@ void BindPir(pybind11::module *module) { BindOpOperand(&ir_module); BindOpResult(&ir_module); BindType(&ir_module); + BindAttribute(&ir_module); BindUtils(&ir_module); BindIrPass(&ir_module); BindPassManager(&ir_module); diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index 37356defd92196..1f3792f62d6e38 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -21,7 +21,9 @@ import numpy as np -from ..pir import OpResult, Value, translate_to_pir +from ..pir import OpResult +from ..pir import Program as PirProgram +from ..pir import Value, translate_to_pir from . import compiler, core, framework, get_flags, set_flags, unique_name from .data_feeder import convert_dtype from .framework import ( @@ -594,6 +596,10 @@ def _to_str(var): return str(var) elif isinstance(var, Operator): return str(id(var)) + elif isinstance(var, OpResult): + return str(var) + elif isinstance(var, Value): + return str(var) else: raise TypeError(str(var) + " should be Variable, Operator or str") @@ -628,11 +634,18 @@ def _prepare_fleet_executor(): def _get_strong_program_cache_key_for_new_exe(program, scope, feed, fetch_list): - return ( - program.desc.cached_hash_str() - + str(scope.raw_address()) - + _get_program_cache_key(feed, fetch_list) - ) + if isinstance(program, PirProgram): + return ( + str(program) + + str(scope.raw_address()) + + _get_program_cache_key(feed, fetch_list) + ) + else: + return ( + program.desc.cached_hash_str() + + str(scope.raw_address()) + + _get_program_cache_key(feed, fetch_list) + ) def _get_strong_program_cache_key(program, feed, fetch_list): @@ -876,6 +889,9 @@ def __init__(self): self._get_cached_program_and_executor = lru_cache(maxsize=8)( self._get_program_and_executor ) + self._get_cached_program_and_executor_pir_mode = lru_cache(maxsize=8)( + self._get_pir_program_and_executor + ) def clear(self): self._get_cached_program_and_executor.cache_clear() @@ -1033,6 +1049,27 @@ def get_pir_program_and_executor( place, scope, ): + return self._get_cached_program_and_executor_pir_mode( + self._CachedData( + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + place, + scope, + ) + ) + + def _get_pir_program_and_executor(self, cached_data): + program = cached_data.program + feed = cached_data.feed + fetch_list = cached_data.fetch_list + feed_var_name = cached_data.feed_var_name + fetch_var_name = cached_data.fetch_var_name + place = cached_data.place + scope = cached_data.scope + _add_pir_fetch_ops( program, fetch_list=fetch_list, fetch_var_name=fetch_var_name ) diff --git a/test/ir/pir/test_special_op_translator.py b/test/ir/pir/test_special_op_translator.py index 1c46d89a54c17a..415ff2513b2f1e 100644 --- a/test/ir/pir/test_special_op_translator.py +++ b/test/ir/pir/test_special_op_translator.py @@ -519,6 +519,24 @@ def test_program(self): ), "share_buffer should be translated to share_data" +class TestDataOp(unittest.TestCase): + def test_data_op(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + _ = paddle.static.data(name="y", shape=[3, 9, 5], dtype="int64") + l = pir.translate_to_pir(main_program.desc) + self.assertTrue(len(l.global_block().ops) > 0) + self.assertTrue(l.global_block().ops[0].name() == "pd_op.data") + data_op = l.global_block().ops[0] + self.assertIn("dtype", data_op.attrs()) + self.assertEqual(str(data_op.attrs()["dtype"]), "DataType.INT64") + + class TestCheckUnregisteredOp(unittest.TestCase): def test_program(self): main_program = paddle.static.Program()