Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature]: add continue train on evaluation data #399

Merged
merged 5 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions easy_rec/python/compat/estimator_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
fout.write('Train Done.')

return result


def estimator_train_done(estimator):
    """Report whether the ESTIMATOR_TRAIN_DONE marker exists in the model dir.

    Args:
      estimator: object exposing a ``model_dir`` attribute.

    Returns:
      The result of ``gfile.Exists`` on ``<model_dir>/ESTIMATOR_TRAIN_DONE``.
    """
    marker_path = os.path.join(estimator.model_dir, 'ESTIMATOR_TRAIN_DONE')
    marker_present = gfile.Exists(marker_path)
    return marker_present
24 changes: 17 additions & 7 deletions easy_rec/python/compat/sync_replicas_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
```
"""

sync_que_id = -1

def __init__(self,
opt,
replicas_to_aggregate,
Expand Down Expand Up @@ -299,15 +301,24 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
global_step)

def _get_token_qname():
    """Bump the class-level queue counter and return the matching queue name.

    The first call (counter reaches 0) yields 'sync_token_q'; every later
    call yields 'sync_token_q_<counter>'.
    """
    SyncReplicasOptimizer.sync_que_id += 1
    que_id = SyncReplicasOptimizer.sync_que_id
    if que_id != 0:
        return 'sync_token_q_' + str(que_id)
    return 'sync_token_q'

# Create token queue.
token_qname = _get_token_qname()
logging.info('create sync_token_queue[%s]' % token_qname)
with ops.device(global_step.device), ops.name_scope(''):
sync_token_queue = (
data_flow_ops.FIFOQueue(
-1,
global_step.dtype.base_dtype,
shapes=(),
name='sync_token_q',
shared_name='sync_token_q'))
name=token_qname,
shared_name=token_qname))
self._sync_token_queue = sync_token_queue
self._is_sync_que_closed = sync_token_queue.is_closed()
self._close_sync_que = sync_token_queue.close(
Expand Down Expand Up @@ -342,6 +353,8 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):

self._chief_queue_runner = queue_runner.QueueRunner(
dummy_queue, [sync_op])
ops.add_to_collection(ops.GraphKeys.QUEUE_RUNNERS,
self._chief_queue_runner)
for accum, dev in self._accumulator_list:
with ops.device(dev):
chief_init_ops.append(
Expand Down Expand Up @@ -479,14 +492,12 @@ def begin(self):
self._local_init_op = self._sync_optimizer.chief_init_op
self._ready_for_local_init_op = (
self._sync_optimizer.ready_for_local_init_op)
self._q_runner = self._sync_optimizer.get_chief_queue_runner()
self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
self._num_tokens)
else:
self._local_init_op = self._sync_optimizer.local_step_init_op
self._ready_for_local_init_op = (
self._sync_optimizer.ready_for_local_init_op)
self._q_runner = None
self._init_tokens_op = None

def after_create_session(self, session, coord):
Expand All @@ -500,11 +511,10 @@ def after_create_session(self, session, coord):
'local_init. Init op: %s, error: %s' %
(self._local_init_op.name, msg))
session.run(self._local_init_op)
is_closed = session.run(self._sync_optimizer._is_sync_que_closed)
assert not is_closed, 'sync_que is closed'
if self._init_tokens_op is not None:
session.run(self._init_tokens_op)
if self._q_runner is not None:
self._q_runner.create_threads(
session, coord=coord, daemon=True, start=True)

def end(self, session):
try:
Expand Down
34 changes: 29 additions & 5 deletions easy_rec/python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import math
import os
import time

import six
import tensorflow as tf
Expand Down Expand Up @@ -279,7 +280,9 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):

def _train_and_evaluate_impl(pipeline_config,
continue_train=False,
check_mode=False):
check_mode=False,
fit_on_eval=False,
fit_on_eval_steps=None):
train_config = pipeline_config.train_config
data_config = pipeline_config.data_config
feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
Expand All @@ -301,18 +304,15 @@ def _train_and_evaluate_impl(pipeline_config,
estimator, run_config = _create_estimator(
pipeline_config, distribution=distribution, params=params)

master_stat_file = os.path.join(pipeline_config.model_dir, 'master.stat')
version_file = os.path.join(pipeline_config.model_dir, 'version')
if estimator_utils.is_chief():
_check_model_dir(pipeline_config.model_dir, continue_train)
config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
with gfile.GFile(version_file, 'w') as f:
f.write(easy_rec.__version__ + '\n')
if gfile.Exists(master_stat_file):
gfile.Remove(master_stat_file)

train_steps = None
if train_config.HasField('num_steps'):
if train_config.HasField('num_steps') and train_config.num_steps > 0:
train_steps = train_config.num_steps
assert train_steps is not None or data_config.num_epochs > 0, (
'either num_steps and num_epochs must be set to an integer > 0.')
Expand Down Expand Up @@ -347,6 +347,30 @@ def _train_and_evaluate_impl(pipeline_config,
from easy_rec.python.compat import estimator_train
estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
logging.info('Train and evaluate finish')
if fit_on_eval and (not estimator_utils.is_evaluator()):
tf.reset_default_graph()
logging.info('Start continue training on eval data')
eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
**input_fn_kwargs)
if fit_on_eval_steps is not None:
# wait estimator train done to get the correct train_steps
while not estimator_train.estimator_train_done(estimator):
time.sleep(1)
train_steps = estimator_utils.get_trained_steps(estimator.model_dir)
logging.info('\ttrain_steps=%d fit_on_eval_steps=%d' %
(train_steps, fit_on_eval_steps))
fit_on_eval_steps += train_steps
# Do not use estimator_train.train_and_evaluate as it starts tf.Server,
# which is redundant and reports port not available error.
estimator.train(
input_fn=eval_input_fn,
max_steps=fit_on_eval_steps,
hooks=list(train_spec.hooks),
saving_listeners=train_spec.saving_listeners if hasattr(
train_spec, 'saving_listeners') else None)
logging.info('Finished training on eval data')
# return estimator for custom training using estimator.train
return estimator


def evaluate(pipeline_config,
Expand Down
17 changes: 17 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,23 @@ def test_train_with_ps_worker(self):
'samples/model_config/multi_tower_on_taobao.config', self._test_dir)
self.assertTrue(self._success)

def test_fit_on_eval(self):
    """Run the distributed train/eval helper with fit_on_eval enabled."""
    config_path = 'samples/model_config/multi_tower_on_taobao.config'
    self._success = test_utils.test_distributed_train_eval(
        config_path, self._test_dir, num_evaluator=1, fit_on_eval=True)
    self.assertTrue(self._success)

def test_unbalance_data(self):
    """Run the distributed train/eval helper for one epoch with total_steps=0."""
    config_path = 'samples/model_config/multi_tower_on_taobao_unblanace.config'
    # Keyword options forwarded unchanged to the helper.
    run_options = dict(total_steps=0, num_epoch=1, num_evaluator=1)
    self._success = test_utils.test_distributed_train_eval(
        config_path, self._test_dir, **run_options)
    self.assertTrue(self._success)

def test_train_with_ps_worker_with_evaluator(self):
self._success = test_utils.test_distributed_train_eval(
'samples/model_config/multi_tower_on_taobao.config',
Expand Down
18 changes: 16 additions & 2 deletions easy_rec/python/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@
nargs='*',
default=None,
help='eval data input path')
parser.add_argument(
'--fit_on_eval',
action='store_true',
default=False,
help='Fit evaluation data after fitting and evaluating train data')
parser.add_argument(
'--fit_on_eval_steps',
type=int,
default=None,
help='Fit evaluation data steps')
parser.add_argument(
'--fine_tune_checkpoint',
type=str,
Expand Down Expand Up @@ -169,7 +179,11 @@
has_evaluator=False)
else:
config_util.auto_expand_share_feature_configs(pipeline_config)
_train_and_evaluate_impl(pipeline_config, args.continue_train,
args.check_mode)
_train_and_evaluate_impl(
pipeline_config,
args.continue_train,
args.check_mode,
fit_on_eval=args.fit_on_eval,
fit_on_eval_steps=args.fit_on_eval_steps)
else:
raise ValueError('pipeline_config_path should not be empty when training!')
8 changes: 8 additions & 0 deletions easy_rec/python/utils/estimator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,14 @@ def latest_checkpoint(model_dir):
return None


def get_trained_steps(model_dir):
    """Return the step number encoded in the newest checkpoint path.

    Args:
      model_dir: directory handed to ``latest_checkpoint``.

    Returns:
      The integer after the last '-' of the checkpoint path, or 0 when
      ``latest_checkpoint`` returns None.
    """
    newest_ckpt = latest_checkpoint(model_dir)
    if newest_ckpt is None:
        return 0
    step_suffix = newest_ckpt.rsplit('-', 1)[-1]
    return int(step_suffix)


def master_to_chief():
if 'TF_CONFIG' in os.environ:
tf_config = json.loads(os.environ['TF_CONFIG'])
Expand Down
30 changes: 23 additions & 7 deletions easy_rec/python/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def _replace_data_for_test(data_path):
return data_path


def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
def _load_config_for_test(pipeline_config_path,
test_dir,
total_steps=50,
num_epochs=0):
pipeline_config = config_util.get_configs_from_pipeline_file(
pipeline_config_path)
train_config = pipeline_config.train_config
Expand All @@ -171,7 +174,7 @@ def _load_config_for_test(pipeline_config_path, test_dir, total_steps=50):
pipeline_config.model_dir = os.path.join(test_dir, 'train')
logging.info('test_model_dir %s' % pipeline_config.model_dir)
eval_config.num_examples = max(10, data_config.batch_size)
data_config.num_epochs = 0
data_config.num_epochs = num_epochs
return pipeline_config


Expand Down Expand Up @@ -529,7 +532,9 @@ def _get_ports(num_worker):
def _ps_worker_train(pipeline_config_path,
test_dir,
num_worker,
num_evaluator=0):
num_evaluator=0,
fit_on_eval=False,
fit_on_eval_steps=None):
gpus = get_available_gpus()
# not enough gpus, run on cpu only
if len(gpus) < num_worker:
Expand All @@ -547,6 +552,10 @@ def _ps_worker_train(pipeline_config_path,
os.environ['TF_CONFIG'] = json.dumps(tf_config)
set_gpu_id(gpus[0])
train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s' % pipeline_config_path
if fit_on_eval:
train_cmd += ' --fit_on_eval'
if fit_on_eval_steps is not None:
train_cmd += ' --fit_on_eval_steps ' + str(int(fit_on_eval_steps))
procs[chief_or_master] = run_cmd(
train_cmd, '%s/log_%s.txt' % (test_dir, chief_or_master))
tf_config['task'] = {'type': 'ps', 'index': 0}
Expand Down Expand Up @@ -665,10 +674,12 @@ def test_distributed_train_eval(pipeline_config_path,
total_steps=50,
num_evaluator=0,
edit_config_json=None,
use_hvd=False):
use_hvd=False,
fit_on_eval=False,
num_epoch=0):
logging.info('testing pipeline config %s' % pipeline_config_path)
pipeline_config = _load_config_for_test(pipeline_config_path, test_dir,
total_steps)
total_steps, num_epoch)
if edit_config_json is not None:
config_util.edit_config(pipeline_config, edit_config_json)

Expand All @@ -687,8 +698,13 @@ def test_distributed_train_eval(pipeline_config_path,
return _multi_worker_hvd_train(test_pipeline_config_path, test_dir, 2)
if train_config.train_distribute == DistributionStrategy.NoStrategy:
num_worker = 2
procs = _ps_worker_train(test_pipeline_config_path, test_dir, num_worker,
num_evaluator)
procs = _ps_worker_train(
test_pipeline_config_path,
test_dir,
num_worker,
num_evaluator,
fit_on_eval,
fit_on_eval_steps=int(total_steps // 2))
elif train_config.train_distribute == DistributionStrategy.MultiWorkerMirroredStrategy:
num_worker = 2
procs = _multi_worker_mirror_train(test_pipeline_config_path, test_dir,
Expand Down
Loading