Skip to content

Commit

Permalink
【PIR】modify segment_fault of Swintransformer model (#69036)
Browse files Browse the repository at this point in the history
* modify concat infermeta

* modify dynamic inplace_map
  • Loading branch information
xiaoguoguo626807 authored Oct 31, 2024
1 parent 56cb613 commit 973956f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
17 changes: 10 additions & 7 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def _setitem_static(x, indices, values):
values(Tensor|Number|Ndarray): values to be assigned to the x.
"""
from . import in_dynamic_or_pir_mode
from .framework import Variable, default_main_program, in_pir_mode
from .framework import Variable, in_pir_mode

is_tensor_array = is_tensor_array_type(x)

Expand Down Expand Up @@ -557,7 +557,9 @@ def _setitem_static(x, indices, values):
_global_inplace_map,
)

_global_inplace_map.add(default_main_program(), x, output)
_global_inplace_map.add(
paddle.static.default_main_program(), x, output
)
return output
else:
helper = paddle.base.layer_helper.LayerHelper(
Expand All @@ -572,7 +574,7 @@ def _setitem_static(x, indices, values):
output = helper.create_variable_for_type_inference(
dtype=x.dtype
)
cur_block = default_main_program().current_block()
cur_block = paddle.static.default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
Expand Down Expand Up @@ -680,7 +682,9 @@ def _setitem_static(x, indices, values):
_global_inplace_map,
)

_global_inplace_map.add(default_main_program(), x, output)
_global_inplace_map.add(
paddle.static.default_main_program(), x, output
)
else:
helper = paddle.base.layer_helper.LayerHelper(
'set_value', **locals()
Expand All @@ -694,7 +698,7 @@ def _setitem_static(x, indices, values):
output = helper.create_variable_for_type_inference(
dtype=x.dtype
)
cur_block = default_main_program().current_block()
cur_block = paddle.static.default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
Expand Down Expand Up @@ -792,9 +796,8 @@ def get_tensor_with_basic_indexing(
attrs['decrease_axis'],
)
else:
from .framework import default_main_program

target_block = default_main_program().current_block()
target_block = paddle.static.default_main_program().current_block()

slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,14 @@ def convert_load(x):

# get the new output of the var
if isinstance(x, Value):
cur_block = default_main_program().current_block()

from paddle.jit.pir_dy2static.parameter_recorder import (
_global_inplace_map,
)

new_var = _global_inplace_map.get(cur_block.program, x)
new_var = _global_inplace_map.get(
paddle.static.default_main_program(), x
)
if new_var is not None:
return new_var

Expand Down

0 comments on commit 973956f

Please sign in to comment.