diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 763e8db830bc11..813cb34826eb22 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -2395,6 +2395,7 @@ struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber { return op_inputs; } }; + OpTranslator::OpTranslator() { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index e22712bb67c348..f78168a527d6a9 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -128,7 +128,8 @@ std::string GetValueInfo(Value v) { } void BindProgram(py::module *m) { - py::class_> program(*m, "Program", R"DOC( + py::class_> program( + *m, "Program", py::dynamic_attr(), R"DOC( Create Python Program. Program is an abstraction of model structure, divided into computational graphs and weights. The Program has a main block that stores the computational graphs. diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 9c6484a1d46117..92ac2fcbb5c348 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -33,11 +33,12 @@ # the illogical implement in the monkey-patch methods later. from .framework import monkey_patch_variable from .framework import monkey_patch_math_tensor -from .pir import monkey_patch_opresult +from .pir import monkey_patch_opresult, monkey_patch_program monkey_patch_variable() monkey_patch_math_tensor() monkey_patch_opresult() +monkey_patch_program() from .framework import ( disable_signal_handler, diff --git a/python/paddle/pir/__init__.py b/python/paddle/pir/__init__.py index c0e8e96a86f50e..0a28dbadd6f880 100644 --- a/python/paddle/pir/__init__.py +++ b/python/paddle/pir/__init__.py @@ -38,5 +38,6 @@ from . import core from .math_op_patch import monkey_patch_opresult +from .program_patch import monkey_patch_program __all__ = [] diff --git a/python/paddle/pir/program_patch.py b/python/paddle/pir/program_patch.py new file mode 100644 index 00000000000000..4de46a647259a5 --- /dev/null +++ b/python/paddle/pir/program_patch.py @@ -0,0 +1,34 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import Program + +_already_patch_program = False + +global_prog_seed = 0 + + +def monkey_patch_program(): + def global_seed(self, seed=0): + global global_prog_seed + global_prog_seed = seed + self._seed = global_prog_seed + + Program.global_seed = global_seed + global global_prog_seed + Program._seed = global_prog_seed + + global _already_patch_program + if not _already_patch_program: + _already_patch_program = True diff --git a/test/ir/pir/test_ir_pybind.py b/test/ir/pir/test_ir_pybind.py index 8d843dddf6924f..4c42c0f6f77ae8 100644 --- a/test/ir/pir/test_ir_pybind.py +++ b/test/ir/pir/test_ir_pybind.py @@ -198,6 +198,13 @@ def test_get_output_intermediate_status(self): results = unsqueeze_op.get_output_intermediate_status() self.assertEqual(results, [False, True]) + def test_prog_seed(self): + p = pir.Program() + self.assertEqual(p._seed, 0) + + p.global_seed(10) + self.assertEqual(p._seed, 10) + if __name__ == "__main__": unittest.main()