Skip to content

Commit 072f8ee

Browse files
author
levi131
committed
update for assign op
1 parent 2783e63 commit 072f8ee

File tree

1 file changed

+71
-19
lines changed

1 file changed

+71
-19
lines changed

examples/laplace2d/transform.py

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ def program_transform(program):
125125
_create_op_desc_('tanh_p', {'X': [in_names[0]]},
126126
{'Y': [out_names[0]]}, {}))
127127

128+
elif op_desc.type() == 'assign':
129+
tmp_1 = name_gen.get_var(new_block, block.var(in_names[0]))
130+
to_insert.append(
131+
_create_op_desc_('fill_constant_p', {}, {'Y': [tmp_1]}, {
132+
'shape': block.var(in_names[0]).shape,
133+
'value': 0.0
134+
}))
135+
to_insert.append(
136+
_create_op_desc_('add_p', {'X': [in_names[0]],
137+
'Y': [tmp_1]},
138+
{'Z': [out_names[0]]}, {}))
139+
128140
elif op_desc.type() == 'reshape2':
129141
to_insert.append(
130142
_create_op_desc_('reshape_p', {'X': [in_names[0]]}, {
@@ -138,14 +150,33 @@ def program_transform(program):
138150
}, {'axis': op_desc.attr('axis')}))
139151

140152
elif op_desc.type() == 'slice':
141-
to_insert.append(
142-
_create_op_desc_('slice_select_p', {'X': [in_names[0]]},
143-
{'Y': [out_names[0]]}, {
144-
'axis': op_desc.attr('axes'),
145-
'starts': op_desc.attr('starts'),
146-
'ends': op_desc.attr('ends'),
147-
'strides': 1,
148-
}))
153+
if op_desc.attr('decrease_axis') is None:
154+
to_insert.append(
155+
_create_op_desc_('slice_select_p', {'X': [in_names[0]]},
156+
{'Y': [out_names[0]]}, {
157+
'axis': op_desc.attr('axes'),
158+
'starts': op_desc.attr('starts'),
159+
'ends': op_desc.attr('ends'),
160+
'strides': 1,
161+
}))
162+
else:
163+
tmp_shape = list(block.var(in_names[0]).shape)
164+
for axis in op_desc.attr('decrease_axis'):
165+
tmp_shape[axis] = 1
166+
tmp_0 = name_gen.get_var(
167+
new_block, block.var(in_names[0]), shape=tuple(tmp_shape))
168+
to_insert.append(
169+
_create_op_desc_('slice_select_p', {'X': [in_names[0]]},
170+
{'Y': [tmp_0]}, {
171+
'axis': op_desc.attr('axes'),
172+
'starts': op_desc.attr('starts'),
173+
'ends': op_desc.attr('ends'),
174+
'strides': 1,
175+
}))
176+
to_insert.append(
177+
_create_op_desc_('reshape_p', {'X': [tmp_0]}, {
178+
'Y': [out_names[0]]
179+
}, {'shape': block.var(out_names[0]).shape}))
149180

150181
elif op_desc.type() == 'slice_grad':
151182
tmp_1 = name_gen.get_var(new_block, block.var(in_names[0]))
@@ -154,15 +185,36 @@ def program_transform(program):
154185
'shape': block.var(in_names[0]).shape,
155186
'value': 0.0
156187
}))
157-
to_insert.append(
158-
_create_op_desc_('slice_assign_p',
159-
{'X': [tmp_1],
160-
'Y': [in_names[1]]}, {'Z': [out_names[0]]}, {
161-
'axis': op_desc.attr('axes'),
162-
'starts': op_desc.attr('starts'),
163-
'ends': op_desc.attr('ends'),
164-
'strides': 1
165-
}))
188+
if op_desc.attr('decrease_axis') is None:
189+
to_insert.append(
190+
_create_op_desc_('slice_assign_p', {
191+
'X': [tmp_1],
192+
'Y': [in_names[1]]
193+
}, {'Z': [out_names[0]]}, {
194+
'axis': op_desc.attr('axes'),
195+
'starts': op_desc.attr('starts'),
196+
'ends': op_desc.attr('ends'),
197+
'strides': 1
198+
}))
199+
else:
200+
tmp_shape = list(block.var(in_names[1]).shape)
201+
for axis in op_desc.attr('decrease_axis'):
202+
assert axis == 1
203+
tmp_shape.append(1)
204+
tmp_2 = name_gen.get_var(
205+
new_block, block.var(in_names[1]), shape=tuple(tmp_shape))
206+
to_insert.append(
207+
_create_op_desc_('reshape_p', {'X': [in_names[1]]},
208+
{'Y': [tmp_2]}, {'shape': tmp_shape}))
209+
to_insert.append(
210+
_create_op_desc_('slice_assign_p',
211+
{'X': [tmp_1],
212+
'Y': [tmp_2]}, {'Z': [out_names[0]]}, {
213+
'axis': op_desc.attr('axes'),
214+
'starts': op_desc.attr('starts'),
215+
'ends': op_desc.attr('ends'),
216+
'strides': 1
217+
}))
166218

167219
elif op_desc.type() == 'concat_grad':
168220
to_insert.append(
@@ -759,8 +811,8 @@ def program_transform(program):
759811
'add_p', {'X': [in_names[2]],
760812
'Y': [tmp_1]}, {'Z': [out_names[1]]}, {}))
761813
else:
762-
print(op_desc.type())
763-
# assert op_desc.type() in {'adam', 'shape', 'fill_constant'}
814+
# print(op_desc.type())
815+
assert op_desc.type() in {'adam', 'shape', 'fill_constant'}
764816
to_insert.append(op_desc)
765817

766818
for new_op_desc in to_insert:

0 commit comments

Comments
 (0)