Skip to content

Commit

Permalink
[PIR] Add cache for pir program (PaddlePaddle#58666)
Browse files Browse the repository at this point in the history
* seperate the translation context bewteen different blocks

* fix bug

* remove debug code

* fix pir interpreter

* fix

* add_cache_for_pir_program

* translate_dtype_for_data_op

* polish

* bind attribute and fix unittest

* fix

* fix
  • Loading branch information
kangguangli authored Nov 7, 2023
1 parent 46f211d commit f100de1
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 9 deletions.
8 changes: 5 additions & 3 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,10 @@ struct DataOpTranscriber : public FeedOpTranscriber {
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc) override {
int allocate_type = paddle::get<int>(op_desc.GetAttr("place"));
int allocate_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place"));
int var_dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
auto phi_dtype = phi::TransToPhiDataType(var_dtype);

auto& attribute_translator = AttributeTranslator::instance();
pir::Attribute shape = attribute_translator(
"paddle::dialect::IntArrayAttribute", op_desc.GetAttr("shape"));
Expand All @@ -1036,8 +1039,7 @@ struct DataOpTranscriber : public FeedOpTranscriber {
pir::StrAttribute::get(ctx,
op_desc.GetAttrIfExists<std::string>("name"))},
{"shape", shape},
{"dtype",
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32)},
{"dtype", paddle::dialect::DataTypeAttribute::get(ctx, phi_dtype)},
{"place",
paddle::dialect::PlaceAttribute::get(
ctx, phi::Place(static_cast<phi::AllocationType>(allocate_type)))},
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/pir/core/attribute.h"
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/program.h"
Expand All @@ -64,6 +65,7 @@ namespace py = pybind11;
using paddle::dialect::APIBuilder;
using paddle::dialect::DenseTensorType;
using paddle::dialect::SelectedRowsType;
using pir::Attribute;
using pir::Block;
using pir::Operation;
using pir::OpOperand;
Expand Down Expand Up @@ -804,6 +806,15 @@ void BindType(py::module *m) {
});
}

void BindAttribute(py::module *m) {
py::class_<Attribute> ir_attr(*m, "Attribute", py::module_local());
ir_attr.def("__str__", [](Attribute &self) {
std::ostringstream print_stream;
print_stream << self;
return print_stream.str();
});
}

Operation *BuildOpFrom(
Operation *to_copy_op,
std::unordered_map<pir::Value, pir::Value> &value_map) { // NOLINT
Expand Down Expand Up @@ -1472,6 +1483,7 @@ void BindPir(pybind11::module *module) {
BindOpOperand(&ir_module);
BindOpResult(&ir_module);
BindType(&ir_module);
BindAttribute(&ir_module);
BindUtils(&ir_module);
BindIrPass(&ir_module);
BindPassManager(&ir_module);
Expand Down
49 changes: 43 additions & 6 deletions python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import numpy as np

from ..pir import OpResult, Value, translate_to_pir
from ..pir import OpResult
from ..pir import Program as PirProgram
from ..pir import Value, translate_to_pir
from . import compiler, core, framework, get_flags, set_flags, unique_name
from .data_feeder import convert_dtype
from .framework import (
Expand Down Expand Up @@ -594,6 +596,10 @@ def _to_str(var):
return str(var)
elif isinstance(var, Operator):
return str(id(var))
elif isinstance(var, OpResult):
return str(var)
elif isinstance(var, Value):
return str(var)
else:
raise TypeError(str(var) + " should be Variable, Operator or str")

Expand Down Expand Up @@ -628,11 +634,18 @@ def _prepare_fleet_executor():


def _get_strong_program_cache_key_for_new_exe(program, scope, feed, fetch_list):
return (
program.desc.cached_hash_str()
+ str(scope.raw_address())
+ _get_program_cache_key(feed, fetch_list)
)
if isinstance(program, PirProgram):
return (
str(program)
+ str(scope.raw_address())
+ _get_program_cache_key(feed, fetch_list)
)
else:
return (
program.desc.cached_hash_str()
+ str(scope.raw_address())
+ _get_program_cache_key(feed, fetch_list)
)


def _get_strong_program_cache_key(program, feed, fetch_list):
Expand Down Expand Up @@ -876,6 +889,9 @@ def __init__(self):
self._get_cached_program_and_executor = lru_cache(maxsize=8)(
self._get_program_and_executor
)
self._get_cached_program_and_executor_pir_mode = lru_cache(maxsize=8)(
self._get_pir_program_and_executor
)

def clear(self):
self._get_cached_program_and_executor.cache_clear()
Expand Down Expand Up @@ -1033,6 +1049,27 @@ def get_pir_program_and_executor(
place,
scope,
):
return self._get_cached_program_and_executor_pir_mode(
self._CachedData(
program,
feed,
fetch_list,
feed_var_name,
fetch_var_name,
place,
scope,
)
)

def _get_pir_program_and_executor(self, cached_data):
program = cached_data.program
feed = cached_data.feed
fetch_list = cached_data.fetch_list
feed_var_name = cached_data.feed_var_name
fetch_var_name = cached_data.fetch_var_name
place = cached_data.place
scope = cached_data.scope

_add_pir_fetch_ops(
program, fetch_list=fetch_list, fetch_var_name=fetch_var_name
)
Expand Down
18 changes: 18 additions & 0 deletions test/ir/pir/test_special_op_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,24 @@ def test_program(self):
), "share_buffer should be translated to share_data"


class TestDataOp(unittest.TestCase):
def test_data_op(self):
place = core.Place()
place.set_place(paddle.CPUPlace())

new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
_ = paddle.static.data(name="y", shape=[3, 9, 5], dtype="int64")
l = pir.translate_to_pir(main_program.desc)
self.assertTrue(len(l.global_block().ops) > 0)
self.assertTrue(l.global_block().ops[0].name() == "pd_op.data")
data_op = l.global_block().ops[0]
self.assertIn("dtype", data_op.attrs())
self.assertEqual(str(data_op.attrs()["dtype"]), "DataType.INT64")


class TestCheckUnregisteredOp(unittest.TestCase):
def test_program(self):
main_program = paddle.static.Program()
Expand Down

0 comments on commit f100de1

Please sign in to comment.