From d24ce2967bf6aeedd260acde82bb1c6abb405902 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=99=93=E9=BE=99?= <1225386395@qq.com>
Date: Wed, 9 Aug 2023 21:28:46 +0800
Subject: [PATCH] update from Tencent/TencentPretrain (#79)

* add inference/run_classifier_mt_infer.py

* alter inference/run_classifier_mt_infer.py

* alter inference/run_classifier_mt_infer.py

* standardize inference/run_classifier_mt_infer.py
---
 inference/run_classifier_mt_infer.py | 26 ++++----------------------
 1 file changed, 4 insertions(+), 22 deletions(-)

diff --git a/inference/run_classifier_mt_infer.py b/inference/run_classifier_mt_infer.py
index 243668f..9d15155 100644
--- a/inference/run_classifier_mt_infer.py
+++ b/inference/run_classifier_mt_infer.py
@@ -1,5 +1,5 @@
 """
-  This script provides an example to wrap TencentPretrain for classification multi-tasking inference.
+  This script provides an example to wrap TencentPretrain for multi-task classification inference.
 """
 import sys
 import os
@@ -17,7 +17,6 @@
 from tencentpretrain.utils import *
 from tencentpretrain.utils.config import load_hyperparam
 from tencentpretrain.utils.seed import set_seed
-from tencentpretrain.utils.logging import init_logger
 from tencentpretrain.utils.misc import pooling
 from tencentpretrain.model_loader import *
 from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts
@@ -57,14 +56,6 @@ def forward(self, src, tgt, seg, soft_tgt=None):
         return None, logits
 
 
-def load_or_initialize_parameters(args, model):
-    assert args.load_model_path is not None, "load_model_path cat not be None!"
-    # Initialize with pretrained model.
-    keys_info = model.load_state_dict(torch.load(args.load_model_path, map_location="cpu"), strict=False)
-    args.logger.info("missing_keys: {0}".format(keys_info.missing_keys))
-    args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys))
-
-
 def read_dataset(args, path):
     dataset, columns = [], {}
     with open(path, mode="r", encoding="utf-8") as f:
@@ -133,17 +124,13 @@ def main():
     # Build classification model and load parameters.
     args.soft_targets, args.soft_alpha = False, False
     model = MultitaskClassifier(args)
-
-    # Get logger
-    args.logger = init_logger(args)
-
-    load_or_initialize_parameters(args, model)
+    model = load_model(model, args.load_model_path)
 
     # For simplicity, we use DataParallel wrapper to use multiple GPUs.
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = model.to(device)
     if torch.cuda.device_count() > 1:
-        args.logger.info("{0} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
+        print("{0} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
         model = torch.nn.DataParallel(model)
 
     dataset = read_dataset(args, args.test_path)
@@ -154,11 +141,10 @@ def main():
     batch_size = args.batch_size
     instances_num = src.size()[0]
 
-    args.logger.info("The number of prediction instances: {0}".format(instances_num))
+    print("The number of prediction instances: {0}".format(instances_num))
 
     model.eval()
 
-    pred_num = 0
     with open(args.prediction_path, mode="w", encoding="utf-8") as f:
         f.write("label")
         if args.output_logits:
@@ -167,8 +153,6 @@ def main():
             f.write("\t" + "prob")
         f.write("\n")
         for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
-            pred_num += batch_size
-            args.logger.info("pre {0}/{1}".format(pred_num, instances_num))
             src_batch = src_batch.to(device)
             seg_batch = seg_batch.to(device)
             with torch.no_grad():
@@ -189,8 +173,6 @@ def main():
             f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in prob]))
             f.write("\n")
     f.close()
-    args.logger.info("Done.")
-    args.logger.info("Saved in {0}".format(args.prediction_path))
 
 
 if __name__ == "__main__":
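
Reviewer note: the substantive change above replaces the local
load_or_initialize_parameters helper with load_model from
tencentpretrain.model_loader. Since the removed hunk documents exactly what
the old helper did, here is a minimal sketch of that loading behavior as a
reading aid, assuming a plain state-dict checkpoint. load_model_sketch is a
hypothetical stand-in, not TencentPretrain's actual implementation:

    import torch

    def load_model_sketch(model, model_path):
        # Load the checkpoint on CPU and apply it non-strictly, so missing
        # or unexpected keys are reported rather than raising, mirroring
        # the removed load_or_initialize_parameters helper.
        assert model_path is not None, "load_model_path cannot be None!"
        keys_info = model.load_state_dict(
            torch.load(model_path, map_location="cpu"), strict=False
        )
        print("missing_keys: {0}".format(keys_info.missing_keys))
        print("unexpected_keys: {0}".format(keys_info.unexpected_keys))
        return model

Returning the model keeps the call site a one-liner, matching the patch's
model = load_model(model, args.load_model_path).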