@@ -46,6 +46,8 @@ def __init__(self, pde, algo, ninputs, inputs_attr, nlabels, labels_attr):
4646 self .algo .net .make_network_static ()
4747
4848 def forward (self , * inputs_labels ):
49+ for input in inputs_labels :
50+ input .stop_gradient = False
4951
5052 loss , outs = self .algo .compute (
5153 * inputs_labels ,
@@ -303,6 +305,10 @@ def __solve_static(self, num_epoch, bs, checkpoint_freq):
303305 fetch_list = fetches )
304306 print ("static epoch: " + str (epoch + 1 ), "loss: " , rslt [0 ])
305307
308+ rslt = self .exe .run (self .main_program ,
309+ feed = feeds ,
310+ fetch_list = fetches )
311+
306312 return rslt [1 :]
307313
308314 # def __solve_static_dist(self, num_epoch, bs, checkpoint_freq):
@@ -384,9 +390,9 @@ def __solve_static_auto_dist(self, num_epoch, bs, checkpoint_freq):
384390 labels_attr )
385391
386392 inputs_labels_spec = list ()
387- for i in inputs_labels :
393+ for i , data in enumerate ( inputs_labels ) :
388394 inputs_labels_spec .append (
389- InputSpec (i .shape , self . _dtype , 'input' + str (i )))
395+ InputSpec (data .shape , 'float32' , 'input' + str (i )))
390396
391397 labels_spec = None
392398
@@ -404,17 +410,53 @@ def __solve_static_auto_dist(self, num_epoch, bs, checkpoint_freq):
404410
405411 # train
406412 engine .prepare (optimizer = self .opt , loss = loss_func )
407- rslt = engine .fit (train_dataset , sample_generator = False )
413+ engine .fit (train_dataset , sample_generator = False )
408414
409415 print ("\n ********** engine predict start **** \n " )
410416
411- # predict
412- test_dataset = DataSetStatic (1 , inputs_labels )
413- engine .prepare (optimizer = self .opt , loss = loss_func , mode = 'predict' )
414- rslt = engine .predict (test_dataset , sample_generator = False )
415-
416- print ("\n ********** engine done **** \n " )
417+ # test
418+ inputs_labels = list ()
419+ test_program = paddle .fluid .Program ()
420+ start_program = paddle .fluid .Program ()
421+ with paddle .fluid .program_guard (test_program , start_program ):
422+ with paddle .fluid .unique_name .guard ():
423+ self .algo .net .make_network_static ()
424+ for i in range (len (inputs )):
425+ #inputs
426+ input = paddle .static .data (
427+ name = 'input' + str (i ),
428+ shape = inputs [i ].shape ,
429+ dtype = 'float32' )
430+ inputs_labels .append (input )
431+ for i in range (len (labels )):
432+ #labels
433+ label = paddle .static .data (
434+ name = 'label' + str (i ),
435+ shape = labels [i ].shape ,
436+ dtype = 'float32' )
437+ inputs_labels .append (label )
438+
439+ _ , outputs = self .algo .compute (
440+ * inputs_labels ,
441+ ninputs = ninputs ,
442+ inputs_attr = inputs_attr ,
443+ nlabels = nlabels ,
444+ labels_attr = labels_attr ,
445+ pde = self .pde )
417446
418- # print(rslt[0][1:])
447+ # feeds inputs
448+ feeds = dict ()
449+ for i in range (len (inputs )):
450+ feeds ['input' + str (i )] = inputs [i ]
451+ # feeds labels
452+ for i in range (len (labels )):
453+ feeds ['label' + str (i )] = labels [i ]
454+ # fetch_list
455+ fetches = []
456+ for out in outputs :
457+ fetches .append (out .name )
458+ rslt = engine ._executor .run (test_program ,
459+ feed = feeds ,
460+ fetch_list = fetches )
419461
420- return rslt [ 0 ][ 1 :]
462+ return rslt
0 commit comments