From d0b8452cc1e070f7dd967c50ac1bb6c874756dbd Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Tue, 1 Jun 2021 15:10:01 +0800 Subject: [PATCH] update api to paddle 2.1 (#81) --- plsc/entry.py | 534 ++++++++++++----------- plsc/models/base_model.py | 33 +- plsc/models/dist_algo.py | 305 ++++++------- plsc/models/resnet.py | 19 +- plsc/utils/fp16_utils.py | 22 +- plsc/utils/input_field.py | 12 +- plsc/utils/jpeg_reader.py | 65 ++- plsc/utils/parameter_converter.py | 683 ++++++++---------------------- train.py | 26 ++ train.sh | 15 + 10 files changed, 769 insertions(+), 945 deletions(-) create mode 100644 train.py create mode 100755 train.sh diff --git a/plsc/entry.py b/plsc/entry.py index 04c3c3960f363..4df9b0c903a60 100644 --- a/plsc/entry.py +++ b/plsc/entry.py @@ -29,10 +29,8 @@ import numpy as np import paddle import sklearn -import paddle import paddle.distributed.fleet as fleet - -from paddle.fluid.contrib.slim.quantization.quantize_program_pass import QuantizeProgramPass +from paddle.optimizer import Optimizer paddle.enable_static() @@ -41,7 +39,7 @@ from .models import base_model from .models import resnet from .utils import jpeg_reader as reader -from .utils.parameter_converter import ParameterConverter +from .utils.parameter_converter import rearrange_weight from .utils.verification import evaluate from .utils.input_field import InputField @@ -94,12 +92,11 @@ def __init__(self): self.optimizer = None self.model = None - self.train_reader = None - self.test_reader = None - self.predict_reader = None + self.train_dataset = None + self.test_dataset = None - self.train_program = paddle.static.Program() self.startup_program = paddle.static.Program() + self.train_program = paddle.static.Program() self.test_program = paddle.static.Program() self.predict_program = paddle.static.Program() @@ -113,7 +110,7 @@ def __init__(self): self.has_run_train = False # Whether has run training or not self.test_initialized = False - self.train_pass_id = -1 + self.cur_epoch = -1 
self.use_fp16 = False self.fp16_user_dict = None @@ -126,6 +123,7 @@ def __init__(self): self.scale = self.config.scale self.lr = self.config.lr self.lr_steps = self.config.lr_steps + self.lr_scheduler = None self.train_image_num = self.config.train_image_num self.model_name = self.config.model_name self.emb_dim = self.config.emb_dim @@ -136,6 +134,7 @@ def __init__(self): self.warmup_epochs = self.config.warmup_epochs self.calc_train_acc = False + self.max_last_checkpoint_num = 5 if self.checkpoint_dir: self.checkpoint_dir = os.path.abspath(self.checkpoint_dir) if self.model_save_dir: @@ -147,6 +146,8 @@ def __init__(self): self.lr_decay_factor = 0.1 self.log_period = 200 + self.test_period = 0 + self.cur_steps = 0 self.input_info = [{ 'name': 'image', @@ -154,7 +155,7 @@ def __init__(self): 'dtype': 'float32' }, { 'name': 'label', - 'shape': [-1, 1], + 'shape': [-1], 'dtype': 'int64' }] self.input_field = None @@ -167,6 +168,7 @@ def __init__(self): num_trainers)) logger.info('default lr_decay_factor: {}'.format(self.lr_decay_factor)) logger.info('default log period: {}'.format(self.log_period)) + logger.info('default test period: {}'.format(self.test_period)) logger.info('=' * 30) def set_use_quant(self, quant): @@ -217,6 +219,10 @@ def set_log_period(self, period): self.log_period = period logger.info("Set log period to {}.".format(period)) + def set_test_period(self, period): + self.test_period = period + logger.info("Set test period to {}.".format(period)) + def set_lr_decay_factor(self, factor): self.lr_decay_factor = factor logger.info("Set lr decay factor to {}.".format(factor)) @@ -354,6 +360,13 @@ def set_checkpoint_dir(self, directory): self.checkpoint_dir = directory logger.info("Set checkpoint_dir to {}.".format(directory)) + def set_max_last_checkpoint_num(self, num): + """ + Set the max number of last checkpoint to keep. 
+ """ + self.max_last_checkpoint_num = num + logger.info("Set max_last_checkpoint_num to {}.".format(num)) + def set_warmup_epochs(self, num): self.warmup_epochs = num logger.info("Set warmup_epochs to {}.".format(num)) @@ -386,6 +399,16 @@ def set_distfc_attr(self, param_attr=None, bias_attr=None): logger.info("Set bias_attr for distfc to {}.".format( self.bias_attr)) + def _set_info(self, key, value): + if not hasattr(self, '_info'): + self._info = {} + self._info[key] = value + + def _get_info(self, key): + if hasattr(self, '_info') and key in self._info: + return self._info[key] + return None + def _get_optimizer(self): if not self.optimizer: bd = [step for step in self.lr_steps] @@ -406,18 +429,18 @@ def _get_optimizer(self): logger.info("LR boundaries: {}".format(bd)) logger.info("lr_step: {}".format(lr)) if self.warmup_epochs: - lr_val = paddle.optimizer.lr.LinearWarmup( + self.lr_scheduler = paddle.optimizer.lr.LinearWarmup( paddle.optimizer.lr.PiecewiseDecay( boundaries=bd, values=lr), warmup_steps, start_lr, base_lr) else: - lr_val = paddle.optimizer.lr.PiecewiseDecay( + self.lr_scheduler = paddle.optimizer.lr.PiecewiseDecay( boundaries=bd, values=lr) optimizer = paddle.optimizer.Momentum( - learning_rate=lr_val, + learning_rate=self.lr_scheduler, momentum=0.9, weight_decay=paddle.regularizer.L2Decay(5e-4)) self.optimizer = optimizer @@ -430,7 +453,7 @@ def _get_optimizer(self): loss_type=self.loss_type, fp16_user_dict=self.fp16_user_dict) elif self.use_fp16: - self.optimizer = fluid.contrib.mixed_precision.decorate( + self.optimizer = paddle.static.amp.decorate( optimizer=self.optimizer, init_loss_scaling=self.fp16_user_dict['init_loss_scaling'], incr_every_n_steps=self.fp16_user_dict['incr_every_n_steps'], @@ -443,10 +466,7 @@ def _get_optimizer(self): amp_lists=self.fp16_user_dict['amp_lists']) return self.optimizer - def build_program(self, - is_train=True, - use_parallel_test=False, - dist_strategy=None): + def build_program(self, is_train=True, 
use_parallel_test=False): model_name = self.model_name assert not (is_train and use_parallel_test), \ "is_train and use_parallel_test cannot be set simultaneously." @@ -468,13 +488,13 @@ def build_program(self, emb, loss, prob = model.get_output( input=input_field, + num_classes=self.num_classes, num_ranks=num_trainers, rank_id=trainer_id, is_train=is_train, - num_classes=self.num_classes, - loss_type=self.loss_type, param_attr=self.param_attr, bias_attr=self.bias_attr, + loss_type=self.loss_type, margin=self.margin, scale=self.scale) @@ -485,49 +505,53 @@ def build_program(self, if self.calc_train_acc: shard_prob = loss._get_info("shard_prob") - prob_all = fluid.layers.collective._c_allgather( - shard_prob, - nranks=num_trainers, - use_calc_stream=True) - prob_list = fluid.layers.split( - prob_all, dim=0, num_or_sections=num_trainers) - prob = fluid.layers.concat(prob_list, axis=1) - label_all = fluid.layers.collective._c_allgather( - input_field.label, - nranks=num_trainers, - use_calc_stream=True) - acc1 = fluid.layers.accuracy( - input=prob, label=label_all, k=1) - acc5 = fluid.layers.accuracy( - input=prob, label=label_all, k=5) + prob_list = [] + paddle.distributed.all_gather(prob_list, shard_prob) + prob = paddle.concat(prob_list, axis=1) + label_list = [] + paddle.distributed.all_gather(label_list, + input_field.label) + label_all = paddle.concat(label_list, axis=0) + acc1 = paddle.static.accuracy( + input=prob, + label=paddle.reshape(label_all, (-1, 1)), + k=1) + acc5 = paddle.static.accuracy( + input=prob, + label=paddle.reshape(label_all, (-1, 1)), + k=5) else: if self.calc_train_acc: - acc1 = fluid.layers.accuracy( - input=prob, label=input_field.label, k=1) - acc5 = fluid.layers.accuracy( - input=prob, label=input_field.label, k=5) + acc1 = paddle.static.accuracy( + input=prob, + label=paddle.reshape(input_field.label, (-1, 1)), + k=1) + acc5 = paddle.static.accuracy( + input=prob, + label=paddle.reshape(input_field.label, (-1, 1)), + k=5) optimizer = 
None if is_train: # initialize optimizer optimizer = self._get_optimizer() if self.num_trainers > 1: - dist_optimizer = fleet.distributed_optimizer( - optimizer, strategy=dist_strategy) + dist_optimizer = fleet.distributed_optimizer(optimizer) dist_optimizer.minimize(loss) else: # single card training optimizer.minimize(loss) if "dist" in self.loss_type or self.use_fp16: optimizer = optimizer._optimizer elif use_parallel_test: - emb = fluid.layers.collective._c_allgather( - emb, nranks=num_trainers, use_calc_stream=True) + emb_list = [] + paddle.distributed.all_gather(emb_list, emb) + emb = paddle.concat(emb_list, axis=0) return emb, loss, acc1, acc5, optimizer def get_files_from_hdfs(self): assert self.fs_checkpoint_dir, \ logger.error("Please set the fs_checkpoint_dir paramerters for " - "set_llllllhdfs_info to get models from hdfs.") + "set_hdfs_info to get models from hdfs.") self.fs_checkpoint_dir = os.path.join(self.fs_checkpoint_dir, '*') cmd = "hadoop fs -D fs.default.name=" cmd += self.fs_name + " " @@ -557,24 +581,6 @@ def put_files_to_hdfs(self, local_dir): cmd, stdout=sys.stdout, stderr=subprocess.STDOUT) process.wait() - def process_distributed_params(self, local_dir): - local_dir = os.path.abspath(local_dir) - output_dir = tempfile.mkdtemp() - converter = ParameterConverter(local_dir, output_dir, - self.num_trainers) - converter.process() - - for file in os.listdir(local_dir): - if "dist@" in file and "@rank@" in file: - file = os.path.join(local_dir, file) - os.remove(file) - - for file in os.listdir(output_dir): - if "dist@" in file and "@rank@" in file: - file = os.path.join(output_dir, file) - shutil.move(file, local_dir) - shutil.rmtree(output_dir) - def _append_broadcast_ops(self, program): """ Before test, we broadcast bathnorm-related parameters to all @@ -593,17 +599,66 @@ def _append_broadcast_ops(self, program): outputs={'Out': var}, attrs={'use_calc_stream': True}) - def load_checkpoint(self, - executor, - main_program, - 
use_per_trainer_checkpoint=False, - load_for_train=True): - if use_per_trainer_checkpoint: - checkpoint_dir = os.path.join(self.checkpoint_dir, - str(self.trainer_id)) - else: - checkpoint_dir = self.checkpoint_dir + def save(self, program, epoch=0, for_train=True): + if not self.model_save_dir: + return + + trainer_id = self.trainer_id + model_save_dir = os.path.join(self.model_save_dir, str(epoch)) + if not os.path.exists(model_save_dir): + # may be more than one processes trying + # to create the directory + try: + os.makedirs(model_save_dir) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass + + param_state_dict = program.state_dict(mode='param') + for name, param in param_state_dict.items(): + # for non dist param, we only save their at trainer 0, + # but for dist param, we need to save their at all trainers + if 'dist@' in name and '@rank@' in name or trainer_id == 0: + paddle.save(param, + os.path.join(model_save_dir, name + '.pdparam')) + + if for_train: + opt_state_dict = program.state_dict(mode='opt') + for name, opt in opt_state_dict.items(): + # for non opt var, we only save their at trainer 0, + # but for opt var, we need to save their at all trainers + if 'dist@' in name and '@rank@' in name or trainer_id == 0: + paddle.save(opt, + os.path.join(model_save_dir, name + '.pdopt')) + + if trainer_id == 0: + # save some extra info for resume + # pretrain_nranks, emb_dim, num_classes are used for + # re-split fc weight when gpu setting changed. + # epoch use to restart. 
+ config_file = os.path.join(model_save_dir, 'meta.json') + extra_info = dict() + extra_info["pretrain_nranks"] = self.num_trainers + extra_info["emb_dim"] = self.emb_dim + extra_info['num_classes'] = self.num_classes + extra_info['epoch'] = epoch + extra_info['lr_state'] = self.lr_scheduler.state_dict() + with open(config_file, 'w') as f: + json.dump(extra_info, f) + + logger.info("Save model to {}.".format(self.model_save_dir)) + if trainer_id == 0 and self.max_last_checkpoint_num > 0: + for idx in range(-1, epoch - self.max_last_checkpoint_num + 1): + path = os.path.join(self.model_save_dir, str(idx)) + if os.path.exists(path): + logger.info("Remove checkpoint {}.".format(path)) + shutil.rmtree(path) + + def load(self, program, for_train=True): + checkpoint_dir = os.path.abspath(self.checkpoint_dir) + logger.info("Load checkpoint from '{}'. ".format(checkpoint_dir)) if self.fs_name is not None: if os.path.exists(checkpoint_dir): logger.info("Local dir {} exists, we'll overwrite it.".format( @@ -625,40 +680,99 @@ def load_checkpoint(self, else: break - # Preporcess distributed parameters. 
- file_name = os.path.join(checkpoint_dir, '.lock') - meta_file = os.path.join(checkpoint_dir, 'meta.json') - if not os.path.exists(meta_file): - logger.error("Please make sure the checkpoint dir {} exists, and " - "parameters in that dir are validating.".format( - checkpoint_dir)) - exit() + state_dict = {} + dist_weight_state_dict = {} + dist_weight_velocity_state_dict = {} + dist_bias_state_dict = {} + dist_bias_velocity_state_dict = {} + for path in os.listdir(checkpoint_dir): + path = os.path.join(checkpoint_dir, path) + if not os.path.isfile(path): + continue + + basename = os.path.basename(path) + name, ext = os.path.splitext(basename) + + if ext not in ['.pdopt', '.pdparam']: + continue + + if not for_train and ext == '.pdopt': + continue + + tensor = paddle.load(path, return_numpy=True) + + if 'dist@' in name and '@rank@' in name: + if '.w' in name and 'velocity' not in name: + dist_weight_state_dict[name] = tensor + elif '.w' in name and 'velocity' in name: + dist_weight_velocity_state_dict[name] = tensor + elif '.b' in name and 'velocity' not in name: + dist_bias_state_dict[name] = tensor + elif '.b' in name and 'velocity' in name: + dist_bias_velocity_state_dict[name] = tensor + + else: + state_dict[name] = tensor + distributed = self.loss_type in ["dist_softmax", "dist_arcface"] - if load_for_train and self.trainer_id == 0 and distributed: - self.process_distributed_params(checkpoint_dir) - with open(file_name, 'w') as f: - pass - time.sleep(10) - os.remove(file_name) - elif load_for_train and distributed: - # wait trainer_id (0) to complete - while True: - if not os.path.exists(file_name): - time.sleep(1) - else: - break - def if_exist(var): - has_var = os.path.exists(os.path.join(checkpoint_dir, var.name)) - if has_var: - logger.info('var: %s found' % (var.name)) - return has_var + if for_train or distributed: + meta_file = os.path.join(checkpoint_dir, 'meta.json') + if not os.path.exists(meta_file): + logger.error( + "Please make sure the checkpoint 
dir {} exists, and " + "parameters in that dir are validating.".format( + checkpoint_dir)) + exit() - fluid.io.load_vars( - executor, - checkpoint_dir, - predicate=if_exist, - main_program=main_program) + with open(meta_file, 'r') as handle: + config = json.load(handle) + + # Preporcess distributed parameters. + if distributed: + pretrain_nranks = config['pretrain_nranks'] + assert pretrain_nranks > 0 + emb_dim = config['emb_dim'] + assert emb_dim == self.emb_dim + num_classes = config['num_classes'] + assert num_classes == self.num_classes + + logger.info("Parameters for pre-training: pretrain_nranks ({}), " + "emb_dim ({}), and num_classes ({}).".format( + pretrain_nranks, emb_dim, num_classes)) + logger.info("Parameters for inference or fine-tuning: " + "nranks ({}).".format(self.num_trainers)) + + trainer_id_str = '%05d' % self.trainer_id + + dist_weight_state_dict = rearrange_weight( + dist_weight_state_dict, pretrain_nranks, self.num_trainers) + dist_bias_state_dict = rearrange_weight( + dist_bias_state_dict, pretrain_nranks, self.num_trainers) + for name, value in dist_weight_state_dict.items(): + if trainer_id_str in name: + state_dict[name] = value + for name, value in dist_bias_state_dict.items(): + if trainer_id_str in name: + state_dict[name] = value + + if for_train: + dist_weight_velocity_state_dict = rearrange_weight( + dist_weight_velocity_state_dict, pretrain_nranks, + self.num_trainers) + dist_bias_velocity_state_dict = rearrange_weight( + dist_bias_velocity_state_dict, pretrain_nranks, + self.num_trainers) + for name, value in dist_weight_velocity_state_dict.items(): + if trainer_id_str in name: + state_dict[name] = value + for name, value in dist_bias_velocity_state_dict.items(): + if trainer_id_str in name: + state_dict[name] = value + if for_train: + return {'state_dict': state_dict, 'extra_info': config} + else: + return {'state_dict': state_dict} def convert_for_prediction(self): model_name = self.model_name @@ -669,21 +783,20 @@ def 
convert_for_prediction(self): model = resnet.__dict__[model_name](emb_dim=self.emb_dim) main_program = self.predict_program startup_program = self.startup_program - with fluid.program_guard(main_program, startup_program): - with fluid.unique_name.guard(): + with paddle.static.program_guard(main_program, startup_program): + with paddle.utils.unique_name.guard(): input_field = InputField(self.input_info) input_field.build() emb = model.build_network(input=input_field, is_train=False) gpu_id = int(os.getenv("FLAGS_selected_gpus", 0)) - place = fluid.CUDAPlace(gpu_id) - exe = fluid.Executor(place) + place = paddle.CUDAPlace(gpu_id) + exe = paddle.static.Executor(place) exe.run(startup_program) assert self.checkpoint_dir, "No checkpoint found for converting." - self.load_checkpoint( - executor=exe, main_program=main_program, load_for_train=False) + self.load(program=main_program, for_train=False) assert self.model_save_dir, \ "Does not set model_save_dir for inference model converting." @@ -695,71 +808,22 @@ def convert_for_prediction(self): for name in input_field.feed_list_str: if name == "label": continue feed_var_names.append(name) - fluid.io.save_inference_model( + paddle.static.save_inference_model( self.model_save_dir, - feeded_var_names=feed_var_names, - target_vars=[emb], - executor=exe, - main_program=main_program) + feed_var_names, [emb], + exe, + program=main_program) if self.fs_name: self.put_files_to_hdfs(self.model_save_dir) - def _set_info(self, key, value): - if not hasattr(self, '_info'): - self._info = {} - self._info[key] = value - - def _get_info(self, key): - if hasattr(self, '_info') and key in self._info: - return self._info[key] - return None - - def predict(self): - model_name = self.model_name - # model definition - model = self.model - if model is None: - model = resnet.__dict__[model_name](emb_dim=self.emb_dim) - main_program = self.predict_program - startup_program = self.startup_program - with fluid.program_guard(main_program, 
startup_program): - with fluid.unique_name.guard(): - input_field = InputField(self.input_holder) - input_field.build() - - emb = model.build_network(input=input_field, is_train=False) - - gpu_id = int(os.getenv("FLAGS_selected_gpus", 0)) - place = paddle.CUDAPlace(gpu_id) - exe = paddle.static.Executor(place) - exe.run(startup_program) - - assert self.checkpoint_dir, "No checkpoint found for predicting." - self.load_checkpoint( - executor=exe, main_program=main_program, load_for_train=False) - - if self.predict_reader is None: - predict_reader = reader.arc_train(self.dataset_dir, - self.num_classes) - else: - predict_reader = self.predict_reader - - input_field.loader.set_sample_generator( - predict_reader, batch_size=self.train_batch_size, places=place) - - fetch_list = [emb.name] - for data in input_field.loader: - emb = exe.run(main_program, - feed=data, - fetch_list=fetch_list, - use_program_cache=True) - def _run_test(self, exe, test_list, test_name_list, feeder, fetch_list): trainer_id = self.trainer_id real_test_batch_size = self.global_test_batch_size for i in range(len(test_list)): data_list, issame_list = test_list[i] embeddings_list = [] + # data_list[0] for normalize + # data_list[1] for flip_left_right for j in range(len(data_list)): data = data_list[j] embeddings = None @@ -821,11 +885,12 @@ def _run_test(self, exe, test_list, test_name_list, feeder, fetch_list): embeddings, issame_list, nrof_folds=10) acc, std = np.mean(accuracy), np.std(accuracy) - if self.train_pass_id >= 0: - logger.info('[{}][{}]XNorm: {:.5f}'.format(test_name_list[ - i], self.train_pass_id, xnorm)) - logger.info('[{}][{}]Accuracy-Flip: {:.5f}+-{:.5f}'.format( - test_name_list[i], self.train_pass_id, acc, std)) + if self.cur_epoch >= 0: + logger.info('[{}][{}][{}]XNorm: {:.5f}'.format(test_name_list[ + i], self.cur_epoch, self.cur_steps, xnorm)) + logger.info('[{}][{}][{}]Accuracy-Flip: {:.5f}+-{:.5f}'.format( + test_name_list[ + i], self.cur_epoch, self.cur_steps, acc, std)) 
else: logger.info('[{}]XNorm: {:.5f}'.format(test_name_list[i], xnorm)) @@ -857,10 +922,11 @@ def test(self): worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS") current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT") - config = dist_transpiler.DistributeTranspilerConfig() + #TODO how to transpile + config = paddle.fluid.transpiler.DistributeTranspilerConfig() config.mode = "collective" config.collective_mode = "grad_allreduce" - t = dist_transpiler.DistributeTranspiler(config=config) + t = paddle.fluid.transpiler.DistributeTranspiler(config=config) t.transpile( trainer_id=trainer_id, trainers=worker_endpoints, @@ -876,7 +942,7 @@ def test(self): if not self.has_run_train: exe.run(self.startup_program) - if not self.test_reader: + if not self.test_dataset: test_reader = reader.test else: test_reader = self.test_reader @@ -898,7 +964,8 @@ def test(self): self.load_checkpoint( executor=exe, main_program=test_program, load_for_train=False) - feeder = fluid.DataFeeder( + #TODO paddle.fluid.DataFeeder + feeder = paddle.fluid.DataFeeder( place=place, feed_list=self.input_field.feed_list_str, program=test_program) @@ -918,32 +985,44 @@ def train(self): trainer_id = self.trainer_id num_trainers = self.num_trainers + gpu_id = int(os.getenv("FLAGS_selected_gpus", 0)) + place = paddle.CUDAPlace(gpu_id) + strategy = None if num_trainers > 1: - fleet.init(is_collective=True) strategy = fleet.DistributedStrategy() - strategy.mode = "collective" - strategy.collective_mode = "grad_allreduce" - - emb, loss, acc1, acc5, optimizer = self.build_program( - True, False, dist_strategy=strategy) + strategy.without_graph_optimization = True + fleet.init(is_collective=True, strategy=strategy) + + emb, loss, acc1, acc5, optimizer = self.build_program(True, False) + + # define dataset + if self.train_dataset is None: + train_dataset = reader.TrainDataset( + self.dataset_dir, + self.num_classes, + color_jitter=False, + rotate=False, + rand_mirror=True, + normalize=True) + else: + 
train_dataset = self.train_dataset + + dataloader = paddle.io.DataLoader( + train_dataset, + feed_list=self.input_field.feed_list, + places=place, + return_list=False, + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True, + num_workers=4) global_lr = optimizer._global_learning_rate(program=self.train_program) origin_prog = self.train_program train_prog = self.train_program - if self.use_quant: - qpp = QuantizeProgramPass( - activation_quantize_type='abs_max', - weight_quantize_type='abs_max', - quantizable_op_type=[ - 'conv2d', 'depthwise_conv2d', 'mul', 'pool2d' - ]) - qpp.apply(train_prog, self.startup_program) - - gpu_id = int(os.getenv("FLAGS_selected_gpus", 0)) - place = paddle.CUDAPlace(gpu_id) exe = paddle.static.Executor(place) exe.run(self.startup_program) @@ -951,16 +1030,17 @@ def train(self): load_checkpoint = True else: load_checkpoint = False - if load_checkpoint: - self.load_checkpoint(executor=exe, main_program=origin_prog) - if self.train_reader is None: - train_reader = reader.arc_train(self.dataset_dir, self.num_classes) - else: - train_reader = self.train_reader - - self.input_field.loader.set_sample_generator( - train_reader, batch_size=self.train_batch_size, places=place) + start_epoch = 0 + self.cur_steps = 0 + if load_checkpoint: + checkpoint = self.load(program=origin_prog, for_train=True) + origin_prog.set_state_dict(checkpoint['state_dict']) + start_epoch = checkpoint['extra_info']['epoch'] + 1 + lr_state = checkpoint['extra_info']['lr_state'] + # there last_epoch means last_step in step style + self.cur_steps = lr_state['last_epoch'] + self.lr_scheduler.set_state_dict(lr_state) if self.calc_train_acc: fetch_list = [loss.name, global_lr.name, acc1.name, acc5.name] @@ -971,11 +1051,12 @@ def train(self): nsamples = 0 inspect_steps = self.log_period global_batch_size = self.global_train_batch_size - for pass_id in range(self.train_epochs): - self.train_pass_id = pass_id + for epoch in range(start_epoch, self.train_epochs): + 
self.cur_epoch = epoch train_info = [[], [], [], []] local_train_info = [[], [], [], []] - for batch_id, data in enumerate(self.input_field.loader): + for batch_id, data in enumerate(dataloader): + self.cur_steps += 1 nsamples += global_batch_size t1 = time.time() acc1 = None @@ -990,6 +1071,7 @@ def train(self): feed=data, fetch_list=fetch_list, use_program_cache=True) + self.lr_scheduler.step() t2 = time.time() period = t2 - t1 local_time += period @@ -1004,60 +1086,28 @@ def train(self): if self.calc_train_acc: logger.info("Pass:{} batch:{} lr:{:.8f} loss:{:.6f} " "qps:{:.2f} acc1:{:.6f} acc5:{:.6f}". - format(pass_id, batch_id, avg_lr, avg_loss, + format(epoch, batch_id, avg_lr, avg_loss, speed, acc1[0], acc5[0])) else: logger.info( "Pass:{} batch:{} lr:{:.8f} loss:{:.6f} " - "qps:{:.2f}".format(pass_id, batch_id, avg_lr, + "qps:{:.2f}".format(epoch, batch_id, avg_lr, avg_loss, speed)) local_time = 0 nsamples = 0 local_train_info = [[], [], [], []] + if self.test_period > 0 and self.cur_steps % self.test_period == 0: + if self.with_test: + self.test() + train_loss = np.array(train_info[0]).mean() - logger.info("End pass {}, train_loss {:.6f}".format(pass_id, + logger.info("End pass {}, train_loss {:.6f}".format(epoch, train_loss)) sys.stdout.flush() - if self.with_test: - self.test() - # save model - if self.model_save_dir: - model_save_dir = os.path.join(self.model_save_dir, - str(pass_id)) - if not os.path.exists(model_save_dir): - # may be more than one processes trying - # to create the directory - try: - os.makedirs(model_save_dir) - except OSError as exc: - if exc.errno != errno.EEXIST: - raise - pass - if trainer_id == 0: - fluid.io.save_persistables(exe, model_save_dir, - origin_prog) - else: - - def save_var(var): - to_save = "dist@" in var.name and '@rank@' in var.name - return to_save and var.persistable - - fluid.io.save_vars( - exe, model_save_dir, origin_prog, predicate=save_var) - - # save training info - if self.model_save_dir and trainer_id == 
0: - config_file = os.path.join(self.model_save_dir, - str(pass_id), 'meta.json') - train_info = dict() - train_info["pretrain_nranks"] = self.num_trainers - train_info["emb_dim"] = self.emb_dim - train_info['num_classes'] = self.num_classes - with open(config_file, 'w') as f: - json.dump(train_info, f) + self.save(origin_prog, epoch=epoch) # upload model if self.model_save_dir and self.fs_name and trainer_id == 0: diff --git a/plsc/models/base_model.py b/plsc/models/base_model.py index 91dec9102c91e..99119865f2c2d 100644 --- a/plsc/models/base_model.py +++ b/plsc/models/base_model.py @@ -116,26 +116,26 @@ def _fc_classify(input, label, out_dim, param_attr, bias_attr): param_attr = paddle.ParamAttr( initializer=paddle.nn.initializer.Uniform(-stddev, stddev)) - out = paddle.static.nn.fc(input=input, + out = paddle.static.nn.fc(x=input, size=out_dim, weight_attr=param_attr, bias_attr=bias_attr) loss, prob = paddle.nn.functional.softmax_with_cross_entropy( - logits=out, label=label, return_softmax=True) + logits=out, + label=paddle.reshape(label, (-1, 1)), + return_softmax=True) avg_loss = paddle.mean(x=loss) return avg_loss, prob @staticmethod def _arcface(input, label, out_dim, param_attr, margin, scale): input_norm = paddle.sqrt( - paddle.reduce_sum( - paddle.square(input), dim=1)) - input = paddle.elementwise_div(input, input_norm, axis=0) + paddle.sum(paddle.square(input), axis=1, keepdim=True)) + input = paddle.divide(input, input_norm) if param_attr is None: param_attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.Xavier( - uniform=False, fan_in=0.0)) + initializer=paddle.nn.initializer.XavierNormal(fan_in=0.0)) weight = paddle.static.create_parameter( shape=[input.shape[1], out_dim], dtype='float32', @@ -143,20 +143,21 @@ def _arcface(input, label, out_dim, param_attr, margin, scale): attr=param_attr) weight_norm = paddle.sqrt( - paddle.reduce_sum( - paddle.square(weight), dim=0)) - weight = paddle.elementwise_div(weight, weight_norm, axis=1) - cos = 
paddle.mul(input, weight) + paddle.sum(paddle.square(weight), axis=0, keepdim=True)) + weight = paddle.divide(weight, weight_norm) + cos = paddle.matmul(input, weight) theta = paddle.acos(cos) margin_cos = paddle.cos(theta + margin) - one_hot = paddle.one_hot(label, out_dim) - diff = (margin_cos - cos) * one_hot - target_cos = cos + diff + one_hot = paddle.nn.functional.one_hot(label, out_dim) + diff = paddle.multiply(paddle.subtract(margin_cos, cos), one_hot) + target_cos = paddle.add(cos, diff) logit = paddle.scale(target_cos, scale=scale) - loss, prob = paddle.softmax_with_cross_entropy( - logits=logit, label=label, return_softmax=True) + loss, prob = paddle.nn.functional.softmax_with_cross_entropy( + logits=logit, + label=paddle.reshape(label, (-1, 1)), + return_softmax=True) avg_loss = paddle.mean(x=loss) one_hot.stop_gradient = True diff --git a/plsc/models/dist_algo.py b/plsc/models/dist_algo.py index 79b9aa854355f..527738e98e076 100644 --- a/plsc/models/dist_algo.py +++ b/plsc/models/dist_algo.py @@ -21,8 +21,9 @@ import paddle import paddle.nn as nn import paddle.utils.unique_name as unique_name +from paddle.optimizer import Optimizer from ..utils.fp16_utils import rewrite_program, update_role_var_grad -from ..utils.fp16_utils import update_loss_scaling, move_optimize_ops_back +from ..utils.fp16_utils import update_loss_scaling, move_optimize_ops_back, check_finite_and_unscale from ..utils.fp16_lists import AutoMixedPrecisionLists from six.moves import reduce @@ -38,6 +39,21 @@ class DistributedClassificationOptimizer(Optimizer): classification training of model parallelism. 
""" + def __init__(self, + optimizer, + batch_size, + use_fp16=False, + loss_type='dist_arcface', + fp16_user_dict=None): + super(DistributedClassificationOptimizer, self).__init__( + learning_rate=optimizer._learning_rate) + self._optimizer = optimizer + self._batch_size = batch_size + self._use_fp16 = use_fp16 + + if self._use_fp16: + self.init_fp16_params(loss_type, fp16_user_dict) + def init_fp16_params(self, loss_type, fp16_user_dict): # set default value for fp16_params_dict fp16_params_dict = dict() @@ -51,7 +67,7 @@ def init_fp16_params(self, loss_type, fp16_user_dict): if fp16_user_dict is not None: # update fp16_params_dict for key in fp16_user_dict: - if fp16_params_dict.has_key(key): + if key in fp16_params_dict: fp16_params_dict[key] = fp16_user_dict[key] else: logging.warning( @@ -66,33 +82,29 @@ def init_fp16_params(self, loss_type, fp16_user_dict): self._amp_lists = AutoMixedPrecisionLists() self._loss_type = loss_type + self._init_loss_scaling = fp16_params_dict['init_loss_scaling'] self._loss_scaling = paddle.static.create_global_var( - name=unique_name.generate("loss_scaling"), + name=paddle.utils.unique_name.generate("loss_scaling"), shape=[1], - value=fp16_params_dict['init_loss_scaling'], + value=self._init_loss_scaling, dtype='float32', persistable=True) self._use_dynamic_loss_scaling = fp16_params_dict[ 'use_dynamic_loss_scaling'] if self._use_dynamic_loss_scaling: - self._incr_every_n_steps = paddle.full( - shape=[1], - values=fp16_params_dict['incr_every_n_steps'], - dtype='int32') - self._decr_every_n_nan_or_inf = paddle.full( - shape=[1], - value=fp16_params_dict['decr_every_n_nan_or_inf'], - dtype='int32') + self._incr_every_n_steps = fp16_params_dict['incr_every_n_steps'] + self._decr_every_n_nan_or_inf = fp16_params_dict[ + 'decr_every_n_nan_or_inf'] self._incr_ratio = fp16_params_dict['incr_ratio'] self._decr_ratio = fp16_params_dict['decr_ratio'] self._num_good_steps = paddle.static.create_global_var( - 
name=unique_name.generate("num_good_steps"), + name=paddle.utils.unique_name.generate("num_good_steps"), shape=[1], value=0, dtype='int32', persistable=True) self._num_bad_steps = paddle.static.create_global_var( - name=unique_name.generate("num_bad_steps"), + name=paddle.utils.unique_name.generate("num_bad_steps"), shape=[1], value=0, dtype='int32', @@ -103,25 +115,12 @@ def init_fp16_params(self, loss_type, fp16_user_dict): if isinstance(self._optimizer._learning_rate, float): self._optimizer._learning_rate_map[paddle.static.default_main_program()] = \ paddle.static.create_global_var( - name=unique_name.generate("learning_rate"), + name=paddle.utils.unique_name.generate("learning_rate"), shape=[1], value=float(self._optimizer._learning_rate), dtype='float32', persistable=True) - def __init__(self, - optimizer, - batch_size, - use_fp16=False, - loss_type='dist_arcface', - fp16_user_dict=None): - self._optimizer = optimizer - self._batch_size = batch_size - self._use_fp16 = use_fp16 - - if self._use_fp16: - self.init_fp16_params(loss_type, fp16_user_dict) - def fp16_backward(self, block, scalar, @@ -131,16 +130,34 @@ def fp16_backward(self, callbacks=None): rewrite_program(block.program, self._amp_lists) + if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0: + scaled_scalar = scalar * self._loss_scaling + else: + scaled_scalar = scalar + self._params_grads = self._optimizer.backward( - scalar, startup_program, parameter_list, no_grad_set, callbacks) - update_role_var_grad(block.program, self._params_grads) - move_optimize_ops_back(block.program.global_block()) - scaled_params_grads = [] - for p, g in self._params_grads: - with paddle.static.default_main_program()._optimized_guard([p, g]): - scaled_g = g / self._loss_scaling - scaled_params_grads.append([p, scaled_g]) - return scaled_params_grads + scaled_scalar, startup_program, parameter_list, no_grad_set, + callbacks) + + grads = [g for _, g in self._params_grads] + with 
block.program._optimized_guard(grads): + _, found_inf = check_finite_and_unscale( + grads, self._loss_scaling, name="find_infinite_scale") + if self._use_dynamic_loss_scaling: + with block.program._optimized_guard([]): + update_loss_scaling( + grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + name="update_loss_scaling") + + return self._params_grads def insert_dist_arcface_backward_op( self, block, index, shard_logit, shard_prob, shard_label, @@ -150,17 +167,14 @@ def insert_dist_arcface_backward_op( when loss_type equals dist_arcface. ''' shard_one_hot = block.create_var( - name=fluid.unique_name.generate('shard_one_hot'), - dtype=shard_logit.dtype) - shard_one_hot_fp32 = block.create_var( - name=fluid.unique_name.generate(shard_one_hot.name + '.cast_fp32'), - dtype=shard_logit.dtype) + name=paddle.utils.unique_name.generate('shard_one_hot'), + dtype=paddle.fluid.core.VarDesc.VarType.FP32) # input var of elementwise_add_grad op after scale - shard_logit_grad_fp32 = block.var('tmp_3@GRAD') + shard_logit_grad_fp32 = block.var(shard_logit.name + '@GRAD') block._insert_op( - index - 1, - type='one_hot', + index - 2, + type='one_hot_v2', inputs={'X': shard_label}, outputs={'Out': shard_one_hot}, attrs={ @@ -169,31 +183,21 @@ def insert_dist_arcface_backward_op( op_role_key: backward_role }) block._insert_op( - index, - type="cast", - inputs={"X": shard_one_hot}, - outputs={"Out": shard_one_hot_fp32}, - attrs={ - "in_dtype": fluid.core.VarDesc.VarType.FP16, - "out_dtype": fluid.core.VarDesc.VarType.FP32, - op_role_key: backward_role - }) - block._insert_op( - index + 1, + index - 1, type='elementwise_sub', inputs={'X': shard_prob, - 'Y': shard_one_hot_fp32}, + 'Y': shard_one_hot}, outputs={'Out': shard_logit_grad_fp32}, attrs={op_role_key: backward_role}) block._insert_op( - index + 2, + index, type='elementwise_mul', inputs={'X': 
shard_logit_grad_fp32, 'Y': self._loss_scaling}, outputs={'Out': shard_logit_grad_fp32}, attrs={op_role_key: backward_role}) block._insert_op( - index + 3, + index + 1, type='scale', inputs={'X': shard_logit_grad_fp32}, outputs={'Out': shard_logit_grad_fp32}, @@ -210,39 +214,36 @@ def insert_dist_softmax_backward_op( when loss_type equals dist_softmax. ''' shard_one_hot = block.create_var( - name=fluid.unique_name.generate('shard_one_hot'), - dtype=fluid.core.VarDesc.VarType.FP32) - shard_one_hot_fp32 = block.create_var( - name=fluid.unique_name.generate(shard_one_hot.name + '.cast_fp32'), - dtype=fluid.core.VarDesc.VarType.FP32) - shard_logit_grad_fp32 = block.var(shard_logit.name + ".cast_fp32@GRAD") + name=paddle.utils.unique_name.generate('shard_one_hot'), + dtype=paddle.fluid.core.VarDesc.VarType.FP32) + shard_logit_grad_fp32 = block.var(shard_logit.name + "@GRAD") block._insert_op( - index - 1, - type='one_hot', + index - 2, + type='one_hot_v2', inputs={'X': shard_label}, - outputs={'Out': shard_one_hot_fp32}, + outputs={'Out': shard_one_hot}, attrs={ 'depth': shard_dim, 'allow_out_of_range': True, op_role_key: backward_role }) block._insert_op( - index, + index - 1, type='elementwise_sub', inputs={'X': shard_prob, - 'Y': shard_one_hot_fp32}, + 'Y': shard_one_hot}, outputs={'Out': shard_logit_grad_fp32}, attrs={op_role_key: backward_role}) block._insert_op( - index + 1, + index, type='elementwise_mul', inputs={'X': shard_logit_grad_fp32, 'Y': self._loss_scaling}, outputs={'Out': shard_logit_grad_fp32}, attrs={op_role_key: backward_role}) block._insert_op( - index + 2, + index + 1, type='scale', inputs={'X': shard_logit_grad_fp32}, outputs={'Out': shard_logit_grad_fp32}, @@ -261,10 +262,12 @@ def insert_commom_backward_op(self, block, index, shard_logit, shard_prob, # insert the calculated gradient dtype = shard_logit.dtype - shard_one_hot = fluid.layers.create_tensor(dtype, name='shard_one_hot') + shard_one_hot = block.create_var( + 
name=paddle.utils.unique_name.generate('shard_one_hot'), + dtype=dtype) block._insert_op( index - 1, - type='one_hot', + type='one_hot_v2', inputs={'X': shard_label}, outputs={'Out': shard_one_hot}, attrs={ @@ -272,8 +275,7 @@ def insert_commom_backward_op(self, block, index, shard_logit, shard_prob, 'allow_out_of_range': True, op_role_key: backward_role }) - shard_logit_grad = fluid.layers.create_tensor( - dtype, name=fluid.backward._append_grad_suffix_(shard_logit.name)) + shard_logit_grad = block.var(shard_logit.name + "@GRAD") block._insert_op( index, type='elementwise_sub', @@ -304,15 +306,15 @@ def minimize(self, shard_label = loss._get_info('shard_label') shard_dim = loss._get_info('shard_dim') - op_maker = paddle.core.op_proto_and_checker_maker + op_maker = paddle.fluid.core.op_proto_and_checker_maker op_role_key = op_maker.kOpRoleAttrName() op_role_var_key = op_maker.kOpRoleVarAttrName() backward_role = int(op_maker.OpRole.Backward) loss_backward_role = int(op_maker.OpRole.Loss) | int( op_maker.OpRole.Backward) - # minimize a scalar of reduce_sum to generate the backward network - scalar = paddle.reduce_sum(shard_logit) + # minimize a scalar of sum to generate the backward network + scalar = paddle.sum(shard_logit) block = loss.block if not self._use_fp16: @@ -337,9 +339,9 @@ def minimize(self, op_role_key, backward_role, loss_backward_role) return ret else: - scaled_params_grads = self.fp16_backward( - block, scalar, startup_program, parameter_list, no_grad_set, - callbacks) + params_grads = self.fp16_backward(block, scalar, startup_program, + parameter_list, no_grad_set, + callbacks) index = 0 for i, op in enumerate(block.ops): if op.all_attrs()[op_role_key] == loss_backward_role: @@ -347,60 +349,44 @@ def minimize(self, break if self._loss_type == 'dist_arcface': - assert block.ops[index - 2].type == 'fill_constant' - assert block.ops[index - 1].type == 'reduce_sum' + assert block.ops[index - 2].type == 'reduce_sum' + assert block.ops[index - 1].type == 
'elementwise_mul' assert block.ops[index].type == 'fill_constant' - assert block.ops[index + 1].type == 'reduce_sum_grad' - assert block.ops[index + 2].type == 'scale' - assert block.ops[index + 3].type == 'elementwise_add_grad' + assert block.ops[index + 1].type == 'elementwise_mul_grad' + assert block.ops[index + 2].type == 'reduce_sum_grad' + assert block.ops[index + 3].type == 'scale' + assert block.ops[index + 4].type == 'elementwise_add_grad' block._remove_op(index + 2) block._remove_op(index + 1) block._remove_op(index) block._remove_op(index - 1) + block._remove_op(index - 2) self.insert_dist_arcface_backward_op( block, index, shard_logit, shard_prob, shard_label, shard_dim, op_role_key, backward_role, loss_backward_role) elif self._loss_type == 'dist_softmax': - assert block.ops[index - 1].type == 'reduce_sum' + assert block.ops[index - 2].type == 'reduce_sum' + assert block.ops[index - 1].type == 'elementwise_mul' assert block.ops[index].type == 'fill_constant' - assert block.ops[index + 1].type == 'reduce_sum_grad' - assert block.ops[index + 2].type == 'cast' + assert block.ops[index + 1].type == 'elementwise_mul_grad' + assert block.ops[index + 2].type == 'reduce_sum_grad' assert block.ops[index + 3].type == 'elementwise_add_grad' + block._remove_op(index + 2) block._remove_op(index + 1) block._remove_op(index) block._remove_op(index - 1) + block._remove_op(index - 2) self.insert_dist_softmax_backward_op( block, index, shard_logit, shard_prob, shard_label, shard_dim, op_role_key, backward_role, loss_backward_role) - if self._use_dynamic_loss_scaling: - grads = [ - paddle.reduce_sum(g) for [_, g] in scaled_params_grads - ] - all_grads = paddle.concat(grads) - all_grads_sum = paddle.reduce_sum(all_grads) - is_overall_finite = paddle.isfinite(all_grads_sum) - - update_loss_scaling(is_overall_finite, self._loss_scaling, - self._num_good_steps, self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, 
self._decr_ratio) - - with layers.Switch() as switch: - with switch.case(is_overall_finite): - pass - with switch.default(): - for _, g in scaled_params_grads: - layers.assign(layers.zeros_like(g), g) - - optimize_ops = self._optimizer.apply_gradients(scaled_params_grads) - ret = optimize_ops, scaled_params_grads + optimize_ops = self._optimizer.apply_gradients(params_grads) + ret = optimize_ops, params_grads return ret @@ -410,11 +396,11 @@ class DistributedClassifier(object): full-connected layer is distributed to all trainers """ - def __init__(self, nclasses, nranks, rank_id, layer_helper): + def __init__(self, nclasses, nranks, rank_id, name): self.nclasses = nclasses self.nranks = nranks self.rank_id = rank_id - self._layer_helper = layer_helper + self.name = name self.shard_dim = (nclasses + nranks - 1) // nranks self.padding_dim = 0 @@ -436,11 +422,15 @@ def create_parameter(self, if param_attr is None: stddev = math.sqrt(2.0 / (in_dim + self.nclasses)) param_attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.Normal(scale=stddev)) + initializer=paddle.nn.initializer.Normal(std=stddev)) weight_shape = [self.shard_dim, in_dim ] if transpose_weight else [in_dim, self.shard_dim] weight = paddle.static.create_parameter( - shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False) + shape=weight_shape, + dtype=dtype, + name=self.name, + attr=param_attr, + is_bias=False) # avoid allreducing gradients for distributed parameters weight.is_distributed = True @@ -455,30 +445,38 @@ def create_parameter(self, attr=bias_attr, dtype=dtype, is_bias=True) + # avoid allreducing gradients for distributed parameters bias.is_distributed = True + # avoid broadcasting distributed parameters in startup program paddle.static.default_startup_program().global_block().vars[ bias.name].is_distributed = True return weight, bias def softmax_with_cross_entropy(self, shard_logit, shard_label): - shard_max = paddle.reduce_max(shard_logit, dim=1, keep_dim=True) - global_max = 
paddle.distributed.all_reduce(shard_max, op=ReduceOp.MAX) - shard_logit_new = paddle.elementwise_sub(shard_logit, global_max) + shard_max = paddle.max(shard_logit, axis=1, keepdim=True) + global_max = shard_max + paddle.distributed.all_reduce( + global_max, op=paddle.distributed.ReduceOp.MAX) + shard_logit_new = paddle.subtract(shard_logit, global_max) shard_exp = paddle.exp(shard_logit_new) - shard_demon = paddle.reduce_sum(shard_exp, dim=1, keep_dim=True) - global_demon = paddle.distributed.allreduce(shard_demon) + shard_demon = paddle.sum(shard_exp, axis=1, keepdim=True) + global_demon = shard_demon + paddle.distributed.all_reduce( + global_demon, op=paddle.distributed.ReduceOp.SUM) global_log_demon = paddle.log(global_demon) shard_log_prob = shard_logit_new - global_log_demon shard_prob = paddle.exp(shard_log_prob) - shard_one_hot = paddle.one_hot( - shard_label, depth=self.shard_dim, allow_out_of_range=True) - target_log_prob = paddle.reduce_min( - shard_log_prob * shard_one_hot, dim=1, keep_dim=True) + shard_one_hot = paddle.nn.functional.one_hot( + shard_label, num_classes=self.shard_dim) + target_log_prob = paddle.min(shard_log_prob * shard_one_hot, + axis=1, + keepdim=True) shard_loss = paddle.scale(target_log_prob, scale=-1.0) - global_loss = collective._c_reducescatter( + #TODO paddle.distributed.reducescatter not found + global_loss = paddle.fluid.layers.collective._c_reducescatter( shard_loss, nranks=self.nranks, use_calc_stream=True) return global_loss, shard_prob @@ -496,20 +494,27 @@ def softmax_classify(self, bias_attr=bias_attr, use_bias=use_bias) - paddle.distributed.allgather(x_all, x) - label_all = paddle.distributed.allgather(label_all, label) + x_list = [] + paddle.distributed.all_gather(x_list, x) + x_all = paddle.concat(x_list, axis=0) + + label_list = [] + paddle.distributed.all_gather(label_list, label) + label_all = paddle.concat(label_list, axis=0) label_all.stop_gradient = True - shard_fc = paddle.mul(x_all, weight, x_num_col_dims=1) + 
shard_fc = paddle.matmul(x_all, weight) if use_bias: - shard_fc = paddle.elementwise_add(shard_fc, bias) + shard_fc = paddle.add(shard_fc, bias) + label_all = paddle.reshape(label_all, (-1, 1)) shard_label = paddle.shard_index( label_all, index_num=self.nclasses, nshards=self.nranks, shard_id=self.rank_id, ignore_value=-1) + shard_label = paddle.reshape(shard_label, (-1, )) shard_label.stop_gradient = True global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc, @@ -540,39 +545,47 @@ def arcface_classify(self, use_bias=False) # normalize x - x_l2 = paddle.sqrt(nn.reduce_sum(ops.square(x), dim=1)) - norm_x = paddle.elementwise_div(x, x_l2, axis=0) + x_l2 = paddle.sqrt(paddle.sum(paddle.square(x), axis=1, keepdim=True)) + norm_x = paddle.divide(x, x_l2) + + norm_x_list = [] + paddle.distributed.all_gather(norm_x_list, norm_x) + norm_x_all = paddle.concat(norm_x_list, axis=0) - paddle.distributed.all_gather(norm_x_all, norm_x) - paddle.distributed.all_gather(label_all, label) + label_list = [] + paddle.distributed.all_gather(label_list, label) + label_all = paddle.concat(label_list, axis=0) label_all.stop_gradient = True + + label_all = paddle.reshape(label_all, (-1, 1)) shard_label = paddle.shard_index( label_all, index_num=self.nclasses, nshards=self.nranks, shard_id=self.rank_id, ignore_value=-1) + shard_label = paddle.reshape(shard_label, (-1, )) # TODO check necessary shard_label.stop_gradient = True # normalize weight weight_l2 = paddle.sqrt( - paddle.reduce_sum( - paddle.square(weight), dim=0)) - norm_weight = paddle.elementwise_div(weight, weight_l2, axis=1) + paddle.sum(paddle.square(weight), axis=0, keepdim=True)) + norm_weight = paddle.divide(weight, weight_l2) - shard_cos = paddle.mul(norm_x_all, norm_weight, x_num_col_dims=1) + shard_cos = paddle.matmul(norm_x_all, norm_weight) theta = paddle.acos(shard_cos) margin_cos = paddle.cos(theta + margin) - shard_one_hot = paddle.one_hot( - shard_label, depth=self.shard_dim, allow_out_of_range=True) + 
shard_one_hot = paddle.nn.functional.one_hot( + shard_label, num_classes=self.shard_dim) # TODO check necessary shard_one_hot.stop_gradient = True - diff = (margin_cos - shard_cos) * shard_one_hot - shard_target_cos = shard_cos + diff + diff = paddle.multiply( + paddle.subtract(margin_cos, shard_cos), shard_one_hot) + shard_target_cos = paddle.add(shard_cos, diff) shard_logit = paddle.scale(shard_target_cos, scale=logit_scale) global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit, @@ -620,6 +633,7 @@ def distributed_softmax_classify(x, Examples: .. code-block:: python + #TODO modify example import paddle.fluid as fluid input = fluid.layers.data(name="input", shape=[32, 1024], @@ -638,7 +652,7 @@ def distributed_softmax_classify(x, if name is None: name = 'dist@softmax@rank@%05d' % rank_id - classifier = DistributedClassifier(class_num, nranks, rank_id) + classifier = DistributedClassifier(class_num, nranks, rank_id, name) return classifier.softmax_classify(x, label, param_attr, use_bias, bias_attr) @@ -686,6 +700,7 @@ def distributed_arcface_classify(x, Examples: .. 
code-block:: python + #TODO modify example import paddle.fluid as fluid input = fluid.layers.data(name="input", shape=[32, 1024], @@ -703,7 +718,7 @@ def distributed_arcface_classify(x, """ if name is None: name = 'dist@arcface@rank@%05d' % rank_id - classifier = DistributedClassifier(class_num, nranks, rank_id) + classifier = DistributedClassifier(class_num, nranks, rank_id, name) return classifier.arcface_classify( x=x, label=label, diff --git a/plsc/models/resnet.py b/plsc/models/resnet.py index b03fc3596f005..92491950cefd2 100644 --- a/plsc/models/resnet.py +++ b/plsc/models/resnet.py @@ -63,18 +63,14 @@ def build_network(self, input, is_train=True): epsilon=2e-05, is_test=False if is_train else True) drop = paddle.nn.functional.dropout( - x=bn, - dropout_prob=0.4, - mode='upscale_in_train', - training=True if is_train else False) + x=bn, p=0.4, training=is_train, mode='upscale_in_train') fc = paddle.static.nn.fc( - input=drop, + x=drop, size=self.emb_dim, - param_attr=paddle.ParamAttr( - initializer=paddle.nn.initializer.Xavier( - uniform=False, fan_in=0.0)), + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.XavierNormal(fan_in=0.0)), bias_attr=paddle.ParamAttr( - initializer=paddle.nn.initializer.ConstantInitializer())) + initializer=paddle.nn.initializer.Constant())) emb = paddle.static.nn.batch_norm( input=fc, act=None, @@ -99,8 +95,7 @@ def conv_bn_layer(self, padding=pad, groups=groups, param_attr=paddle.ParamAttr( - initializer=paddle.nn.initializer.Xavier( - uniform=False, fan_in=0.0)), + initializer=paddle.nn.initializer.XavierNormal(fan_in=0.0)), bias_attr=False) if act == 'prelu': bn = paddle.static.nn.batch_norm( @@ -179,7 +174,7 @@ def bottleneck_block(self, input, num_filters, stride, is_train): is_train=is_train) short = self.shortcut(input, num_filters, stride, is_train=is_train) - return paddle.elementwise_add(x=short, y=conv2, act=None) + return paddle.add(short, conv2) def ResNet50(emb_dim=512): diff --git 
a/plsc/utils/fp16_utils.py b/plsc/utils/fp16_utils.py index 1b10afca5c88d..1bf4a52b640d6 100644 --- a/plsc/utils/fp16_utils.py +++ b/plsc/utils/fp16_utils.py @@ -14,8 +14,8 @@ from __future__ import print_function -from paddle import core -from paddle.fluid import layers +from paddle.fluid import core +from paddle.fluid.layer_helper import LayerHelper def _rename_arg(op, old_name, new_name): @@ -350,10 +350,10 @@ def check_finite_and_unscale(x, scale, name=None): x(list|tuple): The input tensors of check_finite_and_unscale operator. scale: The scale of check_finite_and_unscale operator. """ - check_type(x, 'x', (tuple, list), 'check_finite_and_unscale') - for e in x: - check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], - 'check_finite_and_unscale') + #check_type(x, 'x', (tuple, list), 'check_finite_and_unscale') + #for e in x: + # check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], + # 'check_finite_and_unscale') helper = LayerHelper("check_finite_and_unscale", **locals()) found_inf = helper.create_variable_for_type_inference(dtype='bool') @@ -404,12 +404,12 @@ def update_loss_scaling(x, loss scaling. """ - check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling", - ['float32', 'float64'], "update_loss_scaling") - check_type(x, 'x', (tuple, list), 'update_loss_scaling') + #check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling", + # ['float32', 'float64'], "update_loss_scaling") + #check_type(x, 'x', (tuple, list), 'update_loss_scaling') for e in x: - check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], - 'update_loss_scaling') + #check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], + # 'update_loss_scaling') if e.dtype == core.VarDesc.VarType.FP16: assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \ "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." 
diff --git a/plsc/utils/input_field.py b/plsc/utils/input_field.py index 4c7d414e081ff..51c534bc46bfc 100644 --- a/plsc/utils/input_field.py +++ b/plsc/utils/input_field.py @@ -17,7 +17,7 @@ from __future__ import print_function import numpy as np -import paddle.fluid as fluid +import paddle class InputField(object): @@ -74,7 +74,7 @@ def __getattr__(self, name): return self.input_slots[name] - def build(self, dataset, place, batch_size, num_workers=4): + def build(self): for _name, _shape, _dtype, _lod_level in zip( self.names, self.shapes, self.dtypes, self.lod_levels): @@ -83,11 +83,3 @@ def build(self, dataset, place, batch_size, num_workers=4): for name in self.feed_list_str: self.feed_list.append(self.input_slots[name]) - - self.loader = paddle.io.DataLoader( - dataset, - feed_list=self.feed_list, - places=place, - batch_size=batch_size, - num_workers=num_workers, - capacity=capacity) diff --git a/plsc/utils/jpeg_reader.py b/plsc/utils/jpeg_reader.py index ccb4652bbd005..2318e3b6f8fe0 100644 --- a/plsc/utils/jpeg_reader.py +++ b/plsc/utils/jpeg_reader.py @@ -288,9 +288,9 @@ def arc_train(data_dir, class_dim): train_image_list = get_train_image_list(data_dir) return arc_iterator( train_image_list, - shuffle=True, class_dim=class_dim, data_dir=data_dir, + shuffle=True, color_jitter=False, rotate=False, rand_mirror=True, @@ -308,3 +308,66 @@ def test(data_dir, datasets): test_name_list.append(name) print('test', name) return test_list, test_name_list + + +class TrainDataset(paddle.io.Dataset): + def __init__(self, + data_dir, + class_dim, + color_jitter=False, + rotate=False, + rand_mirror=False, + normalize=False): + self.data_dir = data_dir + self.class_dim = class_dim + self.color_jitter = color_jitter + self.rotate = rotate + self.rand_mirror = rand_mirror + self.normalize = normalize + self.sample_list = get_train_image_list(data_dir) + + def __getitem__(self, idx): + img_path, label = self.sample_list[idx] + img_path = os.path.join(self.data_dir, img_path) 
+ img = Image.open(img_path) + if self.rotate: + img = rotate_image(img) + img = random_resized_crop(img, DATA_DIM) + + if self.color_jitter: + img = distort_color(img) + + if self.rand_mirror: + if random.randint(0, 1) == 1: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode != 'RGB': + img = img.convert('RGB') + + img = np.array(img).astype('float32').transpose((2, 0, 1)) + + if self.normalize: + img -= img_mean + img /= img_std + + assert label < self.class_dim, \ + "label of train dataset should be less than the class_dim." + + return img, label + + def __len__(self): + return len(self.sample_list) + + +class TestDataset(paddle.io.Dataset): + def __init__(self, data_dir, datasets): + self.data_dir = data_dir + self.datasets = datasets.split(',') + self.sample_list = [] + + def __getitem__(self, idx): + img_path, label = self.sample_list[idx] + return img, label + + def __len__(self): + return len(self.sample_list) diff --git a/plsc/utils/parameter_converter.py b/plsc/utils/parameter_converter.py index bb22364c8f1fa..715c2f65bc842 100644 --- a/plsc/utils/parameter_converter.py +++ b/plsc/utils/parameter_converter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,532 +11,199 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import numpy as np
-from __future__ import print_function
-import json
-import logging
-import os
-import shutil
-from functools import cmp_to_key
-
-import paddle
-
-logging.basicConfig(
-    level=logging.INFO,
-    format='[%(levelname)s %(asctime)s line:%(lineno)d] %(message)s',
-    datefmt='%d %b %Y %H:%M:%S')
-logger = logging.getLogger()
-
-
-class ParameterConverter(object):
+def rearrange_weight(weight_dict, init_num_rank, new_num_rank):
     """
-    Tool to convert pre-trained distributed fc parameters for inference or
-    fine-tuning. Note that the number of ranks or GPUs for inference or
-    fine-tuning can be different from that for pre-training.
+    A helper function to convert pre-trained distributed fc parameters for
+    inference or fine-tuning. Note that the number of ranks or GPUs for
+    inference or fine-tuning can be different from that for pre-training.
+
+    Args:
+        weight_dict(dict): the dict storing distributed parameters,
+            key: eg. dist@arcface@rank@00000.w_0
+            value: numpy.ndarray
+        init_num_rank(int) : pre-trained weight at init_num_rank gpu device.
+        new_num_rank(int) : want to rearrange weight to new_num_rank gpu device.
+
+    Returns:
+        dict: rearranged weight for new_num_rank gpu device.
     """

-    def __init__(self, model_dir, output_dir, num_trainers):
-        super(ParameterConverter, self).__init__()
-        self.model_dir = model_dir
-        self.output_dir = output_dir
-        self.pretrain_nranks = -1
-        self.emb_dim = -1
-        self.num_classes = -1
-        self.nranks = num_trainers
-
-        self.load_config()
-
-    def load_config(self):
-        """
-        Load config file which contains the following information for
-        pre-training:
-            1. pretrain_nranks (int): number of ranks for pre-training;
-            2. emb_dim (int): embedding dim for pre-training;
-            3. num_classes (int): number of classes for classification.
- """ - meta_file = os.path.join(self.model_dir, 'meta.json') - if not os.path.exists(meta_file): - logger.error( - "Meta file does not exist, make sure your pre-trained " - "models are legal.") - exit() - - with open(meta_file, 'r') as handle: - config = json.load(handle) - - self.pretrain_nranks = config['pretrain_nranks'] - assert self.pretrain_nranks > 0 - self.emb_dim = config['emb_dim'] - assert self.emb_dim > 0 - self.num_classes = config['num_classes'] - assert self.num_classes > 0 - - logger.info("Parameters for pre-training: pretrain_nranks ({}), " - "emb_dim ({}), and num_classes ({}).".format( - self.pretrain_nranks, self.emb_dim, self.num_classes)) - logger.debug("Parameters for inference or fine-tuning: " - "nranks ({}).".format(self.nranks)) - - def find_var_names(self): - """ - Find all names of pre-trained parameters for the distributed fc layer, - e.g., dist@softmax@rank@00000.w_0, dist@softmax@rank@00000.b_0 etc. - We assume that names of distributed fc related parameters start with the - prefix dist@ and have @rank@ in their names. 
- """ - var_names = [] - model_dir = os.path.abspath(self.model_dir) - if not os.path.exists(model_dir): - logger.error("The directory for pre-trained model ({}) does not " - "exist, please check it.".format(model_dir)) - exit() - logger.info("The directory for pre-trained model: {}".format( - model_dir)) - for file in os.listdir(model_dir): - if 'dist@' in file and '@rank@' in file: - var_names.append(file) - assert len(var_names) > 0, \ - logger.error("No distributed fc parameters found.") - logger.info("Number of distributed fc parameters: {}.".format( - len(var_names))) - logger.info("Distributed fc parameters: {}.".format(var_names)) - return var_names - - def split_load_and_save(self, - name_index, - param_names, - save_rank_id, - remainder, - as_bias, - train_nshards, - train_nranks, - nshards, - dtype="float32"): - var2 = None - advance = False - emb_dim = self.emb_dim - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - num_classes = self.num_classes - - load_var_name = param_names[name_index] - save_var_name_list = load_var_name.split('.') - save_var_name_list[0] = save_var_name_list[0].split('@') - save_var_name_list[0][-1] = "%05d" % save_rank_id - save_var_name_list[0] = '@'.join(save_var_name_list[0]) - save_var_name = '.'.join(save_var_name_list) - - last_train_nshards = num_classes - (train_nranks - 1) * train_nshards - - with paddle.static.program_guard(main_program, startup_program): - if name_index == train_nranks - 1: - var_dim = last_train_nshards - else: - var_dim = train_nshards - - shape = [var_dim] if as_bias else [emb_dim, var_dim] - var = paddle.static.create_parameter( - shape, dtype=dtype, name=load_var_name) - - if as_bias: - var = paddle.slice( - var, - axes=[0], - starts=[var.shape[0] - remainder], - ends=[var.shape[0]]) - else: - var = paddle.split( - var, [var.shape[1] - remainder, remainder], dim=1)[1] - - save_var_dim = nshards - if remainder < nshards: - if name_index == train_nranks - 1: - 
save_var_dim = remainder + ret_dict = {} + if init_num_rank == new_num_rank: + return weight_dict + + if len(weight_dict) == 0: + return weight_dict + + # generate name format + name_format = list(weight_dict.keys())[0] + name_format = name_format.split('.') + name_format[0] = name_format[0].split('@') + name_format[0][-1] = '%05d' + name_format[0] = '@'.join(name_format[0]) + name_format = '.'.join(name_format) + + # calculate num class of pretrain shard + # num class of new shard + num_class = sum([ + w.shape[1] if len(w.shape) == 2 else len(w) + for _, w in weight_dict.items() + ]) + init_nshard = (num_class + init_num_rank - 1) // init_num_rank + new_nshard = (num_class + new_num_rank - 1) // new_num_rank + + if new_nshard * (new_num_rank - 1) >= num_class: + raise ValueError( + "num class {} cann't be rationally splited by num rank {}".format( + num_class, new_num_rank)) + + if init_num_rank > new_num_rank: + for new_idx in range(new_num_rank): + start = new_idx * new_nshard + end = min((new_idx + 1) * new_nshard - 1, num_class - 1) + init_shard_idx_start = start // init_nshard + init_shard_idx_end = end // init_nshard + + weight_list = [] + for init_idx in range(init_shard_idx_start, + init_shard_idx_end + 1): + name = name_format % init_idx + init_weight = weight_dict[name] + s = max(start - init_idx * init_nshard, 0) + if init_idx == init_shard_idx_end: + e = min(end - init_idx * init_nshard + 1, init_nshard) else: - name_index += 1 - advance = True - load_var_name = param_names[name_index] - - if name_index == train_nranks - 1: - var_dim = last_train_nshards - else: - var_dim = train_nshards - shape = [var_dim] if as_bias else [emb_dim, var_dim] - var2 = paddle.static.create_parameter( - shape, dtype=dtype, name=load_var_name) - - if remainder + var_dim < nshards: - # The last train rank - save_var_dim = remainder + var_dim - else: - remainder = remainder + var_dim - nshards - elif remainder == nshards: - if name_index == train_nranks - 2: - remainder = 
last_train_nshards - advance = True - elif name_index < train_nranks - 2: - remainder = train_nshards - advance = True - else: - remainder = remainder - nshards - if var2 is not None: - var = paddle.concat([var, var2], axis=0 if as_bias else 1) - - shape = [save_var_dim] if as_bias else [emb_dim, save_var_dim] - to_save_var = paddle.static.create_parameter( - shape, dtype=dtype, name=save_var_name + '_temp') - if save_var_dim != nshards: # get last dim - if as_bias: - temp_var = paddle.slice( - var, - axes=[0], - starts=[var.shape[0] - save_var_dim], - ends=[var.shape[0]]) + e = init_nshard + if len(init_weight.shape) == 2: + weight_list.append(init_weight[:, s:e]) else: - temp_var = paddle.split( - var, [var.shape[1] - save_var_dim, save_var_dim], - dim=1)[1] - paddle.assign(temp_var, to_save_var) - else: - if as_bias: - temp_var = paddle.slice( - var, axes=[0], starts=[0], ends=[nshards]) + weight_list.append(init_weight[s:e]) + + name = name_format % new_idx + # for 2-dimention, we concat at axis=1, + # else for 1-dimention, we concat at axis=0 + ret_dict[name] = np.concatenate( + weight_list, axis=len(weight_list[0].shape) - 1) + else: + for new_idx in range(new_num_rank): + start = new_idx * new_nshard + end = min((new_idx + 1) * new_nshard - 1, num_class - 1) + init_shard_idx_start = start // init_nshard + init_shard_idx_end = end // init_nshard + + if init_shard_idx_start == init_shard_idx_end: + name = name_format % init_shard_idx_start + init_weight = weight_dict[name] + init_start = init_shard_idx_start * init_nshard + s = max(start - init_start, 0) + e = min((init_shard_idx_start + 1) * init_nshard, + end) - init_start + 1 + if len(init_weight.shape) == 2: + new_weight = init_weight[:, s:e] else: - temp_var = paddle.split( - var, [nshards, var.shape[1] - nshards], dim=1)[0] - paddle.assign(temp_var, to_save_var) - - def expected_var(var): - has_var = os.path.exists(os.path.join(self.model_dir, var.name)) - if has_var: - return True - return False - - 
place = paddle.CPUPlace() - exe = paddle.static.Executor(place) - exe.run(startup_program) - paddle.static.load( - main_program, self.model_dir, exe, predicate=expected_var) - exe.run(main_program) - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - paddle.static.save(main_program, self.output_dir) - srcfile = os.path.join(self.output_dir, to_save_var.name) - dstfile = os.path.join(self.output_dir, save_var_name) - shutil.move(srcfile, dstfile) - return remainder, advance - - def split_parameters(self, param_names, as_bias): - """ - Split parameters whose names are in param_names. - Params: - param_names: list of names of parameters to split - as_bias: whether parameters to split are as bias or not - """ - num_classes = self.num_classes - train_nranks = self.pretrain_nranks - nranks = self.nranks - - train_nshards = (num_classes + train_nranks - 1) // train_nranks - nshards = (num_classes + nranks - 1) // nranks - - save_rank_id = 0 - # remainder dim that is not split in a var - remainder_var_dim = train_nshards - name_index = 0 # index of name of pre-trained parameter to process - for save_rank_id in range(nranks): - assert name_index < train_nranks - remainder_var_dim, advance = self.split_load_and_save( - name_index, param_names, save_rank_id, remainder_var_dim, - as_bias, train_nshards, train_nranks, nshards) - name_index += 1 if advance else 0 - processed_var_count = name_index + 1 - - assert processed_var_count == train_nranks, \ - logger.error("Number of pre-trained parameters processed ({}) is " - "not equal to the number of ranks ({}) for " - "pre-training.".format(processed_var_count, - train_nranks)) - assert save_rank_id == nranks - 1, \ - logger.error("Number of saved parameters ({}) is not equal to the " - "number of ranks ({}) for inference or " - "fine-tuning.".format(save_rank_id + 1, nranks)) - - def split_distfc_parameters(self, weight_param_names, - weight_velocity_param_names, bias_param_names, - 
bias_velocity_param_names): - """ - Split each distributed fc-related parameter according to number of ranks - for inference or fine-tuning. - - Params: - weight_param_names: list of names of weight parameters - bias_param_names: list of names of bias parameters - """ - self.split_parameters(weight_param_names, as_bias=False) - self.split_parameters(weight_velocity_param_names, as_bias=False) - if len(bias_param_names) != 0: - self.split_parameters(bias_param_names, as_bias=True) - self.split_parameters(bias_velocity_param_names, as_bias=True) - - def concat_load_and_save(self, - name_index, - param_names, - save_rank_id, - remainder, - as_bias, - train_nshards, - train_nranks, - nshards, - dtype="float32"): - advance = 0 - emb_dim = self.emb_dim - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - num_classes = self.num_classes - - load_var_name = param_names[name_index] - save_var_name_list = load_var_name.split('.') - save_var_name_list[0] = save_var_name_list[0].split('@') - save_var_name_list[0][-1] = "%05d" % save_rank_id - save_var_name_list[0] = '@'.join(save_var_name_list[0]) - save_var_name = '.'.join(save_var_name_list) - - last_train_nshards = num_classes - (train_nranks - 1) * train_nshards - - with paddle.static.program_guard(main_program, startup_program): - if name_index == train_nranks - 1: - var_dim = last_train_nshards + new_weight = init_weight[s:e] else: - var_dim = train_nshards - - shape = [var_dim] if as_bias else [emb_dim, var_dim] - var = paddle.static.create_parameter( - shape, dtype=dtype, name=load_var_name) - - if as_bias: - var = paddle.slice( - var, - axes=[0], - starts=[var.shape[0] - remainder], - ends=[var.shape[0]]) - else: - var = paddle.split( - var, [var.shape[1] - remainder, remainder], dim=1)[1] - to_concat_var_list = [var] - while remainder < nshards and name_index < train_nranks - 1: - name_index += 1 - advance += 1 - load_var_name = param_names[name_index] - if name_index == train_nranks 
- 1: - var_dim = last_train_nshards - else: - var_dim = train_nshards - shape = [var_dim] if as_bias else [emb_dim, var_dim] - var = paddle.static.create_parameter( - shape, dtype=dtype, name=load_var_name) - - to_concat_var_list.append(var) - remainder += var_dim - if len(to_concat_var_list) > 1: - var = paddle.concat( - to_concat_var_list, axis=0 if as_bias else 1) - save_var_dim = nshards - if remainder > nshards: - if as_bias: - var = paddle.slice( - var, axes=[0], starts=[0], ends=[nshards]) + # init_shard_idx_start + 1 == init_shard_idx_end + name = name_format % init_shard_idx_start + init_weight = weight_dict[name] + init_start = init_shard_idx_start * init_nshard + s = max(start - init_start, 0) + if len(init_weight.shape) == 2: + new_weight = init_weight[:, s:] else: - var = paddle.split( - var, [nshards, var.shape[1] - nshards], dim=1)[0] - remainder = remainder - nshards - elif remainder == nshards: - if name_index == train_nranks - 2: - # advance += 1 if len(to_concat_var_list) > 1 else 0 - # to avoid duplicate add - # name_index += 1 if len(to_concat_var_list) > 1 else 0 - # to avoid duplicate add - advance += 1 - name_index += 1 - remainder = last_train_nshards - elif name_index < train_nranks - 2: - # advance += 1 if len(to_concat_var_list) > 1 else 0 - # to avoid duplicate add - # name_index += 1 if len(to_concat_var_list) > 1 else 0 - # to avoid duplicate add - advance += 1 - name_index += 1 - remainder = train_nshards - else: - save_var_dim = remainder - - shape = [save_var_dim] if as_bias else [emb_dim, save_var_dim] - to_save_var = paddle.static.create_parameter( - shape, dtype=dtype, name=save_var_name + '_temp') - - paddle.assign(var, to_save_var) - - def expected_var(var): - has_var = os.path.exists(os.path.join(self.model_dir, var.name)) - if has_var: - return True - return False - - place = paddle.CPUPlace() - exe = paddle.static.Executor(place) - exe.run(startup_program) - paddle.static.load( - main_program, dirname=self.model_dir, exe, 
predicate=expected_var) - exe.run(main_program) - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - paddle.static.save(main_program, self.output_dir) - srcfile = os.path.join(self.output_dir, to_save_var.name) - dstfile = os.path.join(self.output_dir, save_var_name) - shutil.move(srcfile, dstfile) - return remainder, advance - - def concat_parameters(self, param_names, as_bias): - """ - Concat parameters whose names are in param_names. - Params: - param_names: list of names of parameters to concat - as_bias: whether parameters to split are as bias or not - """ - num_classes = self.num_classes - train_nranks = self.pretrain_nranks - nranks = self.nranks + new_weight = init_weight[s:] + + e = end - (init_shard_idx_end * init_nshard) + 1 + if e > 0: + name = name_format % init_shard_idx_end + init_weight = weight_dict[name] + if len(init_weight.shape) == 2: + new_weight2 = init_weight[:, :e] + else: + new_weight2 = init_weight[:e] - train_nshards = (num_classes + train_nranks - 1) // train_nranks - nshards = (num_classes + nranks - 1) // nranks + new_weight = np.concatenate( + [new_weight, new_weight2], + axis=len(new_weight.shape) - 1) + name = name_format % new_idx + ret_dict[name] = new_weight - save_rank_id = 0 - remainder_dim = train_nshards # remainder dim that is not concated - name_index = 0 # index of name of pre-trained parameter to process - for save_rank_id in range(nranks): - assert name_index < train_nranks - remainder_dim, advance = self.concat_load_and_save( - name_index, param_names, save_rank_id, remainder_dim, as_bias, - train_nshards, train_nranks, nshards) - name_index += advance - processed_var_count = name_index + 1 + return ret_dict - assert processed_var_count == train_nranks, \ - logger.error("Number of pre-trained parameters processed ({}) is " - "not equal to the number of ranks ({}) for " - "pre-training.".format(processed_var_count, - train_nranks)) - assert save_rank_id == nranks - 1, \ - logger.error("Number of 
saved parameters ({}) is not equal to the " - "number of ranks ({}) for inference or " - "fine-tuning.".format(save_rank_id + 1, nranks)) - def concat_distfc_parameters(self, weight_param_names, - weight_velocity_param_names, bias_param_names, - bias_velocity_param_names): - """ - Concat distributed fc-related parameters according to number of ranks - for inference or finetuning. +if __name__ == '__main__': - Params: - weight_param_names: list of names of weight parameters - weight_velocity_param_names: list of names of weight velocity - parameters - bias_param_names: list of names of bias parameters - bias_velocity_param_names: list of names of bias velocity parameters - """ - self.concat_parameters(weight_param_names, as_bias=False) - self.concat_parameters(weight_velocity_param_names, as_bias=False) - if len(bias_param_names) != 0: - self.concat_parameters(bias_param_names, as_bias=True) - self.concat_parameters(bias_velocity_param_names, as_bias=True) + def generate_data(num_rank, is_bias=False): + num_dim = 2 + num_class = 16 + nshard = (num_class + num_rank - 1) // num_rank - def process(self): - self.load_config() - var_names = self.find_var_names() - weight_param_names = [ - name for name in var_names - if '.w' in name and 'velocity' not in name - ] - weight_velocity_param_names = [ - name for name in var_names if '.w' in name and 'velocity' in name - ] - bias_param_names = [ - name for name in var_names - if '.b' in name and 'velocity' not in name - ] - bias_velocity_param_names = [ - name for name in var_names if '.b' in name and 'velocity' in name - ] + weight_dict = {} - def parameter_name_compare(x, y): - """ - Compare two parameter names depend on their rank id. - A parameter name is like dist_softmax_rank_00000.w_0, - where 00000 is the rank id. 
- """ - rank_id_x = int(x.split('.')[0].split('@')[-1]) - rank_id_y = int(y.split('.')[0].split('@')[-1]) - if rank_id_x < rank_id_y: - return -1 - elif rank_id_x == rank_id_y: - return 0 + if is_bias: + data = np.array(range(num_class)) + else: + data = np.array(range(num_dim * num_class)).reshape( + (num_dim, num_class)) + print('fc weight:') + print(data) + + print('shard fc weight:') + for i in range(num_rank): + name = 'dist@arcface@rank@%05d.w_0' % i + start = i * nshard + end = min((i + 1) * nshard, num_class) + if is_bias: + weight_dict[name] = data[start:end] else: - return 1 + weight_dict[name] = data[:, start:end] + print(name) + print(weight_dict[name]) - weight_param_names.sort(key=cmp_to_key(parameter_name_compare)) - weight_velocity_param_names.sort( - key=cmp_to_key(parameter_name_compare)) - bias_param_names.sort(key=cmp_to_key(parameter_name_compare)) - bias_velocity_param_names.sort(key=cmp_to_key(parameter_name_compare)) + return weight_dict - assert len(weight_param_names) == self.pretrain_nranks, \ - logger.error( - "Number of distributed fc-related weight parameters ({}) " - "should be equal to the number of ranks ({}) for " - "pre-training.".format(len(weight_param_names), - self.pretrain_nranks)) - assert len(weight_velocity_param_names) == self.pretrain_nranks, \ - logger.error( - "Number of distributed fc-related weight parameters ({}) " - "should be equal to the number of ranks ({}) for " - "pre-training.".format(len(weight_velocity_param_names), - self.pretrain_nranks)) - assert (len(bias_param_names) == 0 or - len(bias_param_names) == self.pretrain_nranks), \ - logger.error( - "Number of distributed fc-related bias parameters ({}) " - "should be 0 or equal to the number of ranks ({}) for " - "pre-training.".format(len(bias_param_names), - self.pretrain_nranks)) - assert (len(bias_velocity_param_names) == 0 or - len(bias_velocity_param_names) == self.pretrain_nranks), \ - logger.error( - "Number of distributed fc-related bias parameters 
({}) " - "should be 0 or equal to the number of ranks ({}) for " - "pre-training.".format(len(bias_velocity_param_names), - self.pretrain_nranks)) - - pretrain_nranks = self.pretrain_nranks - nranks = self.nranks - if pretrain_nranks == nranks: - logger.info( - "Pre-training and inference (or fine-tuning) have the same " - "number of ranks, nothing to do.") - elif pretrain_nranks < nranks: - self.split_distfc_parameters( - weight_param_names, weight_velocity_param_names, - bias_param_names, bias_velocity_param_names) - else: - self.concat_distfc_parameters( - weight_param_names, weight_velocity_param_names, - bias_param_names, bias_velocity_param_names) + def generate_data1(num_rank, is_bias=False): + num_dim = 2 + num_class = 85742 + nshard = (num_class + num_rank - 1) // num_rank - logger.info("Done.") + weight_dict = {} - -if __name__ == "__main__": - converter = ParameterConverter('./trained_model', "./trained_model_temp", - 8) - converter.process() + if is_bias: + data = np.array(range(num_class)) + else: + data = np.array(range(num_dim * num_class)).reshape( + (num_dim, num_class)) + print('fc weight:') + print(data.shape) + + print('shard fc weight:') + for i in range(num_rank): + name = 'dist@arcface@rank@%05d.w_0' % i + start = i * nshard + end = min((i + 1) * nshard, num_class) + if is_bias: + weight_dict[name] = data[start:end] + else: + weight_dict[name] = data[:, start:end] + print(name) + print(weight_dict[name].shape) + + print() + return weight_dict + + init_num_rank = 3 + new_num_rank = 6 + is_bias = False + weight_dict = generate_data1(init_num_rank, is_bias) + + weight_dict = rearrange_weight(weight_dict, init_num_rank, new_num_rank) + num_class = 0 + for n, w in weight_dict.items(): + print(n) + print(w.shape) + num_class += w.shape[-1] + print(num_class) diff --git a/train.py b/train.py new file mode 100644 index 0000000000000..be01045cd379e --- /dev/null +++ b/train.py @@ -0,0 +1,26 @@ +# Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from plsc import Entry + +if __name__ == "__main__": + ins = Entry() + ins.set_dataset_dir('/path/to/your/data/folder/') + ins.set_loss_type('dist_arcface') + #ins.set_mixed_precision(True) + ins.set_train_epochs(50) + ins.set_test_period(2000) + ins.set_calc_acc(True) + ins.set_model_save_dir('./saved_model') + ins.train() diff --git a/train.sh b/train.sh new file mode 100755 index 0000000000000..a5947e232953b --- /dev/null +++ b/train.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +python3 -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py