Skip to content

Commit b4ad94f

Browse files
authored
Merge pull request PaddlePaddle#4 from levi131/laplace_static_update
update Hessian and print untransfromed op
2 parents 272151d + 072f8ee commit b4ad94f

File tree

2 files changed

+73
-20
lines changed

2 files changed

+73
-20
lines changed

examples/laplace2d/laplace2d_static.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import time
1919

2020
import paddle
21-
from paddle.autograd.functional import Hessian
21+
from paddle.incubate.autograd import Hessian
2222
from transform import program_transform
2323

2424
paddle.enable_static()
@@ -98,8 +98,8 @@ def GenSolution(xy, bc_index):
9898
outputs = net.nn_func(inputs)
9999

100100
# eq_loss
101-
hes = Hessian(net.nn_func, inputs, batch=True)
102-
eq_loss = paddle.norm(hes[0, 0] + hes[1, 1], p=2)
101+
hes = Hessian(net.nn_func, inputs, is_batched=True)
102+
eq_loss = paddle.norm(hes[:, 0, 0] + hes[:, 1, 1], p=2)
103103

104104
# bc_loss
105105
bc_index = paddle.static.data(name='bc_idx', shape=[40], dtype='int32')

examples/laplace2d/transform.py

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ def program_transform(program):
125125
_create_op_desc_('tanh_p', {'X': [in_names[0]]},
126126
{'Y': [out_names[0]]}, {}))
127127

128+
elif op_desc.type() == 'assign':
129+
tmp_1 = name_gen.get_var(new_block, block.var(in_names[0]))
130+
to_insert.append(
131+
_create_op_desc_('fill_constant_p', {}, {'Y': [tmp_1]}, {
132+
'shape': block.var(in_names[0]).shape,
133+
'value': 0.0
134+
}))
135+
to_insert.append(
136+
_create_op_desc_('add_p', {'X': [in_names[0]],
137+
'Y': [tmp_1]},
138+
{'Z': [out_names[0]]}, {}))
139+
128140
elif op_desc.type() == 'reshape2':
129141
to_insert.append(
130142
_create_op_desc_('reshape_p', {'X': [in_names[0]]}, {
@@ -138,14 +150,33 @@ def program_transform(program):
138150
}, {'axis': op_desc.attr('axis')}))
139151

140152
elif op_desc.type() == 'slice':
141-
to_insert.append(
142-
_create_op_desc_('slice_select_p', {'X': [in_names[0]]},
143-
{'Y': [out_names[0]]}, {
144-
'axis': op_desc.attr('axes'),
145-
'starts': op_desc.attr('starts'),
146-
'ends': op_desc.attr('ends'),
147-
'strides': op_desc.attr('decrease_axis')
148-
}))
153+
if op_desc.attr('decrease_axis') is None:
154+
to_insert.append(
155+
_create_op_desc_('slice_select_p', {'X': [in_names[0]]},
156+
{'Y': [out_names[0]]}, {
157+
'axis': op_desc.attr('axes'),
158+
'starts': op_desc.attr('starts'),
159+
'ends': op_desc.attr('ends'),
160+
'strides': 1,
161+
}))
162+
else:
163+
tmp_shape = list(block.var(in_names[0]).shape)
164+
for axis in op_desc.attr('decrease_axis'):
165+
tmp_shape[axis] = 1
166+
tmp_0 = name_gen.get_var(
167+
new_block, block.var(in_names[0]), shape=tuple(tmp_shape))
168+
to_insert.append(
169+
_create_op_desc_('slice_select_p', {'X': [in_names[0]]},
170+
{'Y': [tmp_0]}, {
171+
'axis': op_desc.attr('axes'),
172+
'starts': op_desc.attr('starts'),
173+
'ends': op_desc.attr('ends'),
174+
'strides': 1,
175+
}))
176+
to_insert.append(
177+
_create_op_desc_('reshape_p', {'X': [tmp_0]}, {
178+
'Y': [out_names[0]]
179+
}, {'shape': block.var(out_names[0]).shape}))
149180

150181
elif op_desc.type() == 'slice_grad':
151182
tmp_1 = name_gen.get_var(new_block, block.var(in_names[0]))
@@ -154,15 +185,36 @@ def program_transform(program):
154185
'shape': block.var(in_names[0]).shape,
155186
'value': 0.0
156187
}))
157-
to_insert.append(
158-
_create_op_desc_('slice_assign_p',
159-
{'X': [tmp_1],
160-
'Y': [in_names[1]]}, {'Z': [out_names[0]]}, {
161-
'axis': op_desc.attr('axes'),
162-
'starts': op_desc.attr('starts'),
163-
'ends': op_desc.attr('ends'),
164-
'strides': op_desc.attr('decrease_axis')
165-
}))
188+
if op_desc.attr('decrease_axis') is None:
189+
to_insert.append(
190+
_create_op_desc_('slice_assign_p', {
191+
'X': [tmp_1],
192+
'Y': [in_names[1]]
193+
}, {'Z': [out_names[0]]}, {
194+
'axis': op_desc.attr('axes'),
195+
'starts': op_desc.attr('starts'),
196+
'ends': op_desc.attr('ends'),
197+
'strides': 1
198+
}))
199+
else:
200+
tmp_shape = list(block.var(in_names[1]).shape)
201+
for axis in op_desc.attr('decrease_axis'):
202+
assert axis == 1
203+
tmp_shape.append(1)
204+
tmp_2 = name_gen.get_var(
205+
new_block, block.var(in_names[1]), shape=tuple(tmp_shape))
206+
to_insert.append(
207+
_create_op_desc_('reshape_p', {'X': [in_names[1]]},
208+
{'Y': [tmp_2]}, {'shape': tmp_shape}))
209+
to_insert.append(
210+
_create_op_desc_('slice_assign_p',
211+
{'X': [tmp_1],
212+
'Y': [tmp_2]}, {'Z': [out_names[0]]}, {
213+
'axis': op_desc.attr('axes'),
214+
'starts': op_desc.attr('starts'),
215+
'ends': op_desc.attr('ends'),
216+
'strides': 1
217+
}))
166218

167219
elif op_desc.type() == 'concat_grad':
168220
to_insert.append(
@@ -759,6 +811,7 @@ def program_transform(program):
759811
'add_p', {'X': [in_names[2]],
760812
'Y': [tmp_1]}, {'Z': [out_names[1]]}, {}))
761813
else:
814+
# print(op_desc.type())
762815
assert op_desc.type() in {'adam', 'shape', 'fill_constant'}
763816
to_insert.append(op_desc)
764817

0 commit comments

Comments
 (0)