From 9993fad60641e03b579f16e2e74b7666561548f7 Mon Sep 17 00:00:00 2001 From: Mingzhi Zheng Date: Mon, 28 Oct 2019 23:06:22 -0700 Subject: [PATCH] fix issues in new quac-kd runner --- examples/run_quac_kd.py | 34 +++++++++++++++++++++++++++------- examples/utils_quac_kd.py | 8 ++------ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/examples/run_quac_kd.py b/examples/run_quac_kd.py index 5b5c503575b665..085e50fd4f66ad 100644 --- a/examples/run_quac_kd.py +++ b/examples/run_quac_kd.py @@ -45,8 +45,9 @@ from transformers import AdamW, WarmupLinearSchedule from utils_quac import (read_quac_examples, convert_examples_to_features, - RawResult, write_predictions, write_predictions_v2, - RawResultExtended, write_predictions_extended) + RawResult, write_predictions, write_predictions_v2, + RawResultExtended, write_predictions_extended, + RawResultV2) # The follwing import is the official QuAC evaluation script (2.0). # You can remove it from the dependencies if you are using this script outside of the library @@ -249,7 +250,7 @@ def evaluate(args, model, tokenizer, prefix=""): result = RawResultV2(unique_id = unique_id, start_logits = to_list(outputs[0][i]), end_logits = to_list(outputs[1][i]), - kd_end_logits = to_list(outputs[2][i]), + kd_start_logits = to_list(outputs[2][i]), kd_end_logits = to_list(outputs[3][i])) else: result = RawResult(unique_id = unique_id, @@ -290,7 +291,7 @@ def evaluate(args, model, tokenizer, prefix=""): return results def predict(args, model, tokenizer, prefix=""): - dataset = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=False) + dataset, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=False) if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: os.makedirs(args.output_dir) @@ -342,7 +343,26 @@ def predict(args, model, tokenizer, prefix=""): end_logits = to_list(outputs[1][i])) all_results.append(result) - update_features_v2(features, all_results) + input_file = args.predict_file if evaluate else args.train_file + updated_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format( + 'dev' if evaluate else 'train', + list(filter(None, args.model_name_or_path.split('/'))).pop(), + str(args.max_seq_length))) + + logger.info("Writing updated feature to: %s" % (updated_features_file)) + + result_lookup = { result.unique_id: result for result in all_results } + + updated_features = [] + for feature in features: + if feature.unique_id not in result_lookup: + continue + + result = result_lookup[feature.unique_id] + feature.start_targets = result.kd_start_logits + feature.end_targets = result.kd_end_logits + + torch.save(updated_features, updated_features_file) def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): if args.local_rank not in [-1, 0] and not evaluate: @@ -406,7 +426,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal if output_examples: return dataset, examples, features - return dataset + return dataset, features def main(): @@ -572,7 +592,7 @@ def main(): # Training if args.do_train: - train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) + train_dataset, _ = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) global_step, tr_loss = train(args, train_dataset, model, tokenizer) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) diff --git a/examples/utils_quac_kd.py b/examples/utils_quac_kd.py index 23c342e3854ba1..8880c2e65c42e2 100644 --- a/examples/utils_quac_kd.py +++ b/examples/utils_quac_kd.py @@ -1208,12 +1208,8 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s return out_eval -def update_features_v2(all_features, all_results, output_feature_file): - """Write updated feature to the json file""" - logger.info("Writing updated feature to: %s" % (output_feature_file)) - - with open(output_prediction_file, "w") as file: - file.write(json.dumps(all_features, indent=2) + "\n") +RawResultV2 = collections.namedtuple("RawResultV2", + ["unique_id", "start_logits", "end_logits", "kd_start_logits", "kd_end_logits"]) def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): """Project the tokenized prediction back to the original text."""