Skip to content

Commit 3bcf0a5

Browse files
authored
Merge pull request PaddlePaddle#6 from zhaoyinglia/develop-arch
update '__solve_static_auto_dist' predict
2 parents 4293643 + 4508b40 commit 3bcf0a5

File tree

1 file changed

+53
-11
lines changed

1 file changed

+53
-11
lines changed

paddlescience/solver/solver.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def __init__(self, pde, algo, ninputs, inputs_attr, nlabels, labels_attr):
4646
self.algo.net.make_network_static()
4747

4848
def forward(self, *inputs_labels):
49+
for input in inputs_labels:
50+
input.stop_gradient = False
4951

5052
loss, outs = self.algo.compute(
5153
*inputs_labels,
@@ -303,6 +305,10 @@ def __solve_static(self, num_epoch, bs, checkpoint_freq):
303305
fetch_list=fetches)
304306
print("static epoch: " + str(epoch + 1), "loss: ", rslt[0])
305307

308+
rslt = self.exe.run(self.main_program,
309+
feed=feeds,
310+
fetch_list=fetches)
311+
306312
return rslt[1:]
307313

308314
# def __solve_static_dist(self, num_epoch, bs, checkpoint_freq):
@@ -384,9 +390,9 @@ def __solve_static_auto_dist(self, num_epoch, bs, checkpoint_freq):
384390
labels_attr)
385391

386392
inputs_labels_spec = list()
387-
for i in inputs_labels:
393+
for i, data in enumerate(inputs_labels):
388394
inputs_labels_spec.append(
389-
InputSpec(i.shape, self._dtype, 'input' + str(i)))
395+
InputSpec(data.shape, 'float32', 'input' + str(i)))
390396

391397
labels_spec = None
392398

@@ -404,17 +410,53 @@ def __solve_static_auto_dist(self, num_epoch, bs, checkpoint_freq):
404410

405411
# train
406412
engine.prepare(optimizer=self.opt, loss=loss_func)
407-
rslt = engine.fit(train_dataset, sample_generator=False)
413+
engine.fit(train_dataset, sample_generator=False)
408414

409415
print("\n ********** engine predict start **** \n")
410416

411-
# predict
412-
test_dataset = DataSetStatic(1, inputs_labels)
413-
engine.prepare(optimizer=self.opt, loss=loss_func, mode='predict')
414-
rslt = engine.predict(test_dataset, sample_generator=False)
415-
416-
print("\n ********** engine done **** \n")
417+
# test
418+
inputs_labels = list()
419+
test_program = paddle.fluid.Program()
420+
start_program = paddle.fluid.Program()
421+
with paddle.fluid.program_guard(test_program, start_program):
422+
with paddle.fluid.unique_name.guard():
423+
self.algo.net.make_network_static()
424+
for i in range(len(inputs)):
425+
#inputs
426+
input = paddle.static.data(
427+
name='input' + str(i),
428+
shape=inputs[i].shape,
429+
dtype='float32')
430+
inputs_labels.append(input)
431+
for i in range(len(labels)):
432+
#labels
433+
label = paddle.static.data(
434+
name='label' + str(i),
435+
shape=labels[i].shape,
436+
dtype='float32')
437+
inputs_labels.append(label)
438+
439+
_, outputs = self.algo.compute(
440+
*inputs_labels,
441+
ninputs=ninputs,
442+
inputs_attr=inputs_attr,
443+
nlabels=nlabels,
444+
labels_attr=labels_attr,
445+
pde=self.pde)
417446

418-
# print(rslt[0][1:])
447+
# feeds inputs
448+
feeds = dict()
449+
for i in range(len(inputs)):
450+
feeds['input' + str(i)] = inputs[i]
451+
# feeds labels
452+
for i in range(len(labels)):
453+
feeds['label' + str(i)] = labels[i]
454+
# fetch_list
455+
fetches = []
456+
for out in outputs:
457+
fetches.append(out.name)
458+
rslt = engine._executor.run(test_program,
459+
feed=feeds,
460+
fetch_list=fetches)
419461

420-
return rslt[0][1:]
462+
return rslt

0 commit comments

Comments
 (0)