Skip to content

Commit

Permalink
[PIR] Allow pir program dynamically add attr (PaddlePaddle#58660)
Browse files Browse the repository at this point in the history
* allow pir::Program dynamically add attribute

* add seed for pir::Program

* polish code
  • Loading branch information
kangguangli authored Nov 6, 2023
1 parent fc39738 commit a45df38
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2395,6 +2395,7 @@ struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber {
return op_inputs;
}
};

OpTranslator::OpTranslator() {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ std::string GetValueInfo(Value v) {
}

void BindProgram(py::module *m) {
py::class_<Program, std::shared_ptr<Program>> program(*m, "Program", R"DOC(
py::class_<Program, std::shared_ptr<Program>> program(
*m, "Program", py::dynamic_attr(), R"DOC(
Create Python Program. Program is an abstraction of model structure, divided into
computational graphs and weights. The Program has a main block that stores the computational
graphs.
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@
# the illogical implement in the monkey-patch methods later.
from .framework import monkey_patch_variable
from .framework import monkey_patch_math_tensor
from .pir import monkey_patch_opresult
from .pir import monkey_patch_opresult, monkey_patch_program

monkey_patch_variable()
monkey_patch_math_tensor()
monkey_patch_opresult()
monkey_patch_program()

from .framework import (
disable_signal_handler,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/pir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@
from . import core

from .math_op_patch import monkey_patch_opresult
from .program_patch import monkey_patch_program

__all__ = []
34 changes: 34 additions & 0 deletions python/paddle/pir/program_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import Program

_already_patch_program = False

global_prog_seed = 0


def monkey_patch_program():
def global_seed(self, seed=0):
global global_prog_seed
global_prog_seed = seed
self._seed = global_prog_seed

Program.global_seed = global_seed
global global_prog_seed
Program._seed = global_prog_seed

global _already_patch_program
if not _already_patch_program:
_already_patch_program = True
7 changes: 7 additions & 0 deletions test/ir/pir/test_ir_pybind.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ def test_get_output_intermediate_status(self):
results = unsqueeze_op.get_output_intermediate_status()
self.assertEqual(results, [False, True])

def test_prog_seed(self):
p = pir.Program()
self.assertEqual(p._seed, 0)

p.global_seed(10)
self.assertEqual(p._seed, 10)


if __name__ == "__main__":
unittest.main()

0 comments on commit a45df38

Please sign in to comment.