Skip to content

Commit 2783e63

Browse files
author
levi131
committed
update Hessian and print untransfromed op
1 parent bfd59b9 commit 2783e63

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

examples/laplace2d/laplace2d_static.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919
import paddle
20-
from paddle.autograd.functional import Hessian
20+
from paddle.incubate.autograd import Hessian
2121
from transform import program_transform
2222

2323
paddle.enable_static()
@@ -96,8 +96,8 @@ def GenSolution(xy, bc_index):
9696
outputs = net.nn_func(inputs)
9797

9898
# eq_loss
99-
hes = Hessian(net.nn_func, inputs, batch=True)
100-
eq_loss = paddle.norm(hes[0, 0] + hes[1, 1], p=2)
99+
hes = Hessian(net.nn_func, inputs, is_batched=True)
100+
eq_loss = paddle.norm(hes[:, 0, 0] + hes[:, 1, 1], p=2)
101101

102102
# bc_loss
103103
bc_index = paddle.static.data(name='bc_idx', shape=[40], dtype='int64')

examples/laplace2d/transform.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def program_transform(program):
144144
'axis': op_desc.attr('axes'),
145145
'starts': op_desc.attr('starts'),
146146
'ends': op_desc.attr('ends'),
147-
'strides': op_desc.attr('decrease_axis')
147+
'strides': 1,
148148
}))
149149

150150
elif op_desc.type() == 'slice_grad':
@@ -161,7 +161,7 @@ def program_transform(program):
161161
'axis': op_desc.attr('axes'),
162162
'starts': op_desc.attr('starts'),
163163
'ends': op_desc.attr('ends'),
164-
'strides': op_desc.attr('decrease_axis')
164+
'strides': 1
165165
}))
166166

167167
elif op_desc.type() == 'concat_grad':
@@ -759,7 +759,8 @@ def program_transform(program):
759759
'add_p', {'X': [in_names[2]],
760760
'Y': [tmp_1]}, {'Z': [out_names[1]]}, {}))
761761
else:
762-
assert op_desc.type() in {'adam', 'shape', 'fill_constant'}
762+
print(op_desc.type())
763+
# assert op_desc.type() in {'adam', 'shape', 'fill_constant'}
763764
to_insert.append(op_desc)
764765

765766
for new_op_desc in to_insert:

0 commit comments

Comments
 (0)