[AutoParallel] add fetch_list in engine api (#43312)
* add fetch_list

* fix evaluate log

* tiny fix
zhaoyinglia authored Jun 8, 2022
1 parent 07ede11 commit 971e479
Showing 2 changed files with 53 additions and 28 deletions.
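
For context, a minimal usage sketch of the new fetch_list argument. This is a hedged illustration, not part of the commit: `engine` is assumed to be an auto-parallel Engine that has already run engine.prepare(), and MyDataset is the synthetic dataset from the test file changed below; both are elided here. fetch_list entries may be variable names (str) or Variable objects, and names missing from the main program are silently dropped (see the reworked _fetch_list in the diff).

# Hedged usage sketch -- `engine` and MyDataset are assumed, not defined here.
batch_size, batch_num = 2, 10

engine.fit(MyDataset(batch_num * batch_size),
           batch_size=batch_size,
           fetch_list=['label'])        # 'label' is fetched and logged per step

eval_logs = engine.evaluate(MyDataset(batch_size),
                            batch_size,
                            fetch_list=['label'])

outputs = engine.predict(MyDataset(batch_size),
                         batch_size,
                         fetch_list=['label'])
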
74 changes: 49 additions & 25 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -28,7 +28,7 @@
 from paddle.fluid.layers.utils import flatten
 from paddle.fluid.executor import global_scope
 from paddle.fluid.backward import append_backward
-from paddle.fluid.framework import Operator
+from paddle.fluid.framework import Operator, Variable
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed import fleet
@@ -256,6 +256,7 @@ def fit(self,
             train_data,
             batch_size=1,
             epochs=1,
+            fetch_list=None,
             steps_per_epoch=None,
             use_program_cache=False,
             return_numpy=True):
@@ -266,13 +267,14 @@
             "train model is not ready, please call `engine.prepare()` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch)
+        self._usr_fetch_list = fetch_list

         outputs = []
         for epoch in range(epochs):
             for step, data in enumerate(train_dataloader):
-                logs, loss = self._train_step(data, use_program_cache,
+                logs, outs = self._train_step(data, use_program_cache,
                                               return_numpy)
-                outputs.append(loss)
+                outputs.append(outs)
                 train_logs = {
                     "train_" + name: val
                     for name, val in logs.items()
@@ -283,94 +285,116 @@ def evaluate(self,
     def evaluate(self,
                  eval_data,
                  batch_size=1,
+                 fetch_list=None,
                  use_program_cache=False,
                  return_numpy=True):
         self.mode = 'eval'
         assert self.mode in self._dist_main_progs, \
             "eval model is not ready, please call `engine.prepare()` first."
         eval_dataloader = self._create_dataloader(eval_data, batch_size)
+        self._usr_fetch_list = fetch_list

         for step, data in enumerate(eval_dataloader):
             eval_logs = dict()
-            outs = self._eval_step(data, use_program_cache, return_numpy)
+            logs, outs = self._eval_step(data, use_program_cache, return_numpy)
             eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
             for metric in self._metrics:
                 results = metric.accumulate()
                 for i, res in enumerate(to_list(results)):
                     eval_logs["eval_" + metric.name()[i]] = res
+            for name, val in logs.items():
+                eval_logs["eval_" + name] = val
             self._logger.info(eval_logs)
         return eval_logs

     def predict(self,
                 test_data,
                 batch_size=1,
+                fetch_list=None,
                 use_program_cache=False,
                 return_numpy=True):
         self.mode = 'predict'
         assert self.mode in self._dist_main_progs, \
             "predict model is not ready, please call `engine.prepare()` first."
         test_dataloader = self._create_dataloader(test_data, batch_size)
+        self._usr_fetch_list = fetch_list

         outputs = []
         for step, data in enumerate(test_dataloader):
             logs, outs = self._predict_step(data, use_program_cache,
                                             return_numpy)
             outputs.append(outs)
-            predict_logs = {
-                "predict_" + name: val
-                for name, val in logs.items()
-            }
+            predict_logs = {"pred_" + name: val for name, val in logs.items()}
             self._logger.info(predict_logs)
         return outputs

     def _train_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         fetch_vars = self._fetch_vars[self.mode]["loss"]
-        fetch_list = self._fetch_list(fetch_vars)
+        fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
+        fetch_list += usr_fetch_list

-        loss = self._executor.run(self.main_program,
+        outs = self._executor.run(self.main_program,
                                   fetch_list=fetch_list,
                                   use_program_cache=use_program_cache,
                                   return_numpy=return_numpy)
-        logs["loss"] = loss
-        return logs, loss
+        for i, out in enumerate(outs):
+            logs[fetch_list[i]] = out
+        return logs, outs

     def _eval_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         metrics = self._fetch_vars[self.mode]["metrics"]
         losses = self._fetch_vars[self.mode]["loss"]
-        fetch_loss = self._fetch_list(losses)
-        fetch_metrics = self._fetch_list(metrics)
+        fetch_loss, usr_fetch_list = self._fetch_list(losses)
+        fetch_metrics, usr_fetch_list = self._fetch_list(metrics)
         fetch_list = fetch_loss + fetch_metrics

-        res = self._executor.run(self.main_program,
-                                 fetch_list=fetch_list,
-                                 use_program_cache=use_program_cache,
-                                 return_numpy=return_numpy)
-        if not res[len(fetch_loss):]:
-            return res[:len(fetch_loss)]
+        outs = self._executor.run(self.main_program,
+                                  fetch_list=fetch_list + usr_fetch_list,
+                                  use_program_cache=use_program_cache,
+                                  return_numpy=return_numpy)
+        usr_out = outs[len(fetch_list):]
+        for i, out in enumerate(usr_out):
+            logs[usr_fetch_list[i]] = out
+        outs = outs[:len(fetch_list)]
+        if not outs[len(fetch_loss):]:
+            return logs, outs[:len(fetch_loss)]
         for metric in self._metrics:
-            metric.update(*res[len(fetch_loss):])
-        return res[:len(fetch_loss)]
+            metric.update(*outs[len(fetch_loss):])
+        return logs, outs[:len(fetch_loss)]

     def _predict_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         fetch_vars = self._fetch_vars[self.mode]["outputs"]
-        fetch_list = self._fetch_list(fetch_vars)
+        fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
+        fetch_list += usr_fetch_list

         outs = self._executor.run(self.main_program,
                                   fetch_list=fetch_list,
                                   use_program_cache=use_program_cache,
                                   return_numpy=return_numpy)
-        logs["pred"] = outs
+        for i, out in enumerate(outs):
+            logs[fetch_list[i]] = out
         return logs, outs

     def _fetch_list(self, fetch_vars):
         fetch_list = []
         for var in fetch_vars:
             if var.name in self.main_program.global_block().vars:
                 fetch_list.append(var.name)
-        return fetch_list
+        usr_fetch_list = []
+        if self._usr_fetch_list:
+            assert isinstance(self._usr_fetch_list,
+                              list), "'fetch_list' type should be list."
+            for var in self._usr_fetch_list:
+                if isinstance(var, str):
+                    if var in self.main_program.global_block().vars:
+                        usr_fetch_list.append(var)
+                elif isinstance(var, Variable):
+                    if var.name in self.main_program.global_block().vars:
+                        usr_fetch_list.append(var.name)
+        return fetch_list, usr_fetch_list

     def _create_dataloader(self,
                            dataset,
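
The reworked _fetch_list above now returns a pair: the framework's own fetch names plus a filtered list built from the user-supplied fetch_list. A simplified standalone sketch of that filtering follows; this is a hypothetical helper for illustration, not the Engine method, and program_vars stands in for main_program.global_block().vars.

# Simplified, hypothetical sketch of the fetch-list filtering added here.
def split_fetch_lists(fetch_vars, usr_fetch_list, program_vars):
    # Framework fetches: keep only variables still present in the program.
    fetch_list = [v.name for v in fetch_vars if v.name in program_vars]
    # User fetches: accept plain names or Variable-like objects with .name,
    # and drop anything the (possibly pruned) program no longer contains.
    usr_list = []
    for var in usr_fetch_list or []:
        name = var if isinstance(var, str) else var.name
        if name in program_vars:
            usr_list.append(name)
    return fetch_list, usr_list

The step functions then run the executor on the combined list and key the returned values into the per-step logs by variable name, which is how user fetches surface in the train_/eval_/pred_ log dictionaries.
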
@@ -133,15 +133,16 @@ def train():
     train_dataset = MyDataset(batch_num * batch_size)
     engine.fit(train_dataset,
                batch_size=batch_size,
-               steps_per_epoch=batch_num * batch_size)
+               steps_per_epoch=batch_num * batch_size,
+               fetch_list=['label'])

     # eval
     eval_dataset = MyDataset(batch_size)
-    engine.evaluate(eval_dataset, batch_size)
+    engine.evaluate(eval_dataset, batch_size, fetch_list=['label'])

     # predict
     test_dataset = MyDataset(batch_size)
-    engine.predict(test_dataset, batch_size)
+    engine.predict(test_dataset, batch_size, fetch_list=['label'])

     # save
     engine.save('./mlp_inf', training=False, mode='predict')
