update from Tencent/TencentPretrain (#79)
* add inference/run_classifier_mt_infer.py

* alter inference/run_classifier_mt_infer.py

* alter inference/run_classifier_mt_infer.py

* standardize inference/run_classifier_mt_infer.py
Winter523 authored Aug 9, 2023
1 parent ea1f69b commit d24ce29
Showing 1 changed file with 4 additions and 22 deletions.
inference/run_classifier_mt_infer.py
@@ -1,5 +1,5 @@
"""
This script provides an example to wrap TencentPretrain for classification multi-tasking inference.
This script provides an example to wrap TencentPretrain for multi-task classification inference.
"""
import sys
import os
@@ -17,7 +17,6 @@
 from tencentpretrain.utils import *
 from tencentpretrain.utils.config import load_hyperparam
 from tencentpretrain.utils.seed import set_seed
-from tencentpretrain.utils.logging import init_logger
 from tencentpretrain.utils.misc import pooling
 from tencentpretrain.model_loader import *
 from tencentpretrain.opts import infer_opts, tokenizer_opts, log_opts
@@ -57,14 +56,6 @@ def forward(self, src, tgt, seg, soft_tgt=None):
         return None, logits
 
 
-def load_or_initialize_parameters(args, model):
-    assert args.load_model_path is not None, "load_model_path cat not be None!"
-    # Initialize with pretrained model.
-    keys_info = model.load_state_dict(torch.load(args.load_model_path, map_location="cpu"), strict=False)
-    args.logger.info("missing_keys: {0}".format(keys_info.missing_keys))
-    args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys))
-
-
 def read_dataset(args, path):
     dataset, columns = [], {}
     with open(path, mode="r", encoding="utf-8") as f:
@@ -133,17 +124,13 @@ def main():
     # Build classification model and load parameters.
     args.soft_targets, args.soft_alpha = False, False
     model = MultitaskClassifier(args)
-
-    # Get logger
-    args.logger = init_logger(args)
-
-    load_or_initialize_parameters(args, model)
+    model = load_model(model, args.load_model_path)
 
     # For simplicity, we use DataParallel wrapper to use multiple GPUs.
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = model.to(device)
     if torch.cuda.device_count() > 1:
-        args.logger.info("{0} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
+        print("{0} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
         model = torch.nn.DataParallel(model)
 
     dataset = read_dataset(args, args.test_path)
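
Here load_model (from tencentpretrain.model_loader) takes over the job of the deleted load_or_initialize_parameters helper. As a minimal sketch mirroring only the behavior of the deleted code (the real implementation lives in TencentPretrain and may handle more cases), it amounts to:

    import torch

    def load_model(model, load_model_path):
        # Sketch only: load a CPU-mapped checkpoint and copy matching
        # parameters into the model; strict=False tolerates missing and
        # unexpected keys, exactly as the deleted helper did.
        assert load_model_path is not None, "load_model_path cannot be None!"
        state_dict = torch.load(load_model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
        return model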
@@ -154,11 +141,10 @@ def main():
     batch_size = args.batch_size
     instances_num = src.size()[0]
 
-    args.logger.info("The number of prediction instances: {0}".format(instances_num))
+    print("The number of prediction instances: {0}".format(instances_num))
 
     model.eval()
 
-    pred_num = 0
     with open(args.prediction_path, mode="w", encoding="utf-8") as f:
         f.write("label")
         if args.output_logits:
@@ -167,8 +153,6 @@
             f.write("\t" + "prob")
         f.write("\n")
         for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
-            pred_num += batch_size
-            args.logger.info("pre {0}/{1}".format(pred_num, instances_num))
             src_batch = src_batch.to(device)
             seg_batch = seg_batch.to(device)
             with torch.no_grad():
@@ -189,8 +173,6 @@
                     f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in prob]))
                 f.write("\n")
     f.close()
-    args.logger.info("Done.")
-    args.logger.info("Saved in {0}".format(args.prediction_path))
 
 
 if __name__ == "__main__":
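
When args.output_logits or args.output_prob is set, each data row of args.prediction_path carries per-task value groups joined by "|" with per-class floats joined by spaces (see the "{0:.4f}" write above). A hedged sketch for reading such a row; the exact label-column format is an assumption, since only the header and prob writes are visible in this diff:

    def parse_prediction_row(row):
        # Columns per the writes above: label, then optionally logits
        # and prob, all tab-separated.
        fields = row.rstrip("\n").split("\t")
        labels = fields[0].split("|")  # assumed: one predicted label per task
        float_blocks = [
            [[float(w) for w in group.split()] for group in field.split("|")]
            for field in fields[1:]
        ]
        return labels, float_blocks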
