
Merge pull request huggingface#9 from stevezheng23/dev/zheng/quac
fix issues in new quac-kd runner
stevezheng23 authored Oct 29, 2019
2 parents 758af24 + 9993fad commit f2fed0f
Showing 2 changed files with 29 additions and 13 deletions.
34 changes: 27 additions & 7 deletions examples/run_quac_kd.py
@@ -45,8 +45,9 @@
 from transformers import AdamW, WarmupLinearSchedule
 
 from utils_quac import (read_quac_examples, convert_examples_to_features,
-                        RawResult, write_predictions, write_predictions_v2,
-                        RawResultExtended, write_predictions_extended)
+                        RawResult, write_predictions, write_predictions_v2,
+                        RawResultExtended, write_predictions_extended,
+                        RawResultV2)
 
 # The following import is the official QuAC evaluation script (2.0).
 # You can remove it from the dependencies if you are using this script outside of the library
@@ -249,7 +250,7 @@ def evaluate(args, model, tokenizer, prefix=""):
                 result = RawResultV2(unique_id = unique_id,
                                      start_logits = to_list(outputs[0][i]),
                                      end_logits = to_list(outputs[1][i]),
-                                     kd_end_logits = to_list(outputs[2][i]),
+                                     kd_start_logits = to_list(outputs[2][i]),
                                      kd_end_logits = to_list(outputs[3][i]))
             else:
                 result = RawResult(unique_id = unique_id,
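
Note on the hunk above: the old call passed the kd_end_logits keyword twice (for outputs[2] and outputs[3]), which Python rejects at compile time with "SyntaxError: keyword argument repeated", so this evaluate path could not run at all. A minimal self-contained sketch of the corrected construction, using stand-in logits in place of real model outputs:

    import collections

    RawResultV2 = collections.namedtuple(
        "RawResultV2",
        ["unique_id", "start_logits", "end_logits",
         "kd_start_logits", "kd_end_logits"])

    # Stand-ins for outputs[0..3]: start/end logits and their distillation counterparts.
    outputs = [[[0.1, 0.9]], [[0.8, 0.2]], [[0.3, 0.7]], [[0.6, 0.4]]]
    i = 0
    result = RawResultV2(unique_id=1000000000 + i,  # hypothetical id scheme
                         start_logits=outputs[0][i],
                         end_logits=outputs[1][i],
                         kd_start_logits=outputs[2][i],  # previously a second kd_end_logits
                         kd_end_logits=outputs[3][i])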
@@ -290,7 +291,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     return results
 
 def predict(args, model, tokenizer, prefix=""):
-    dataset = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=False)
+    dataset, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=False)
 
     if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
         os.makedirs(args.output_dir)
@@ -342,7 +343,26 @@ def predict(args, model, tokenizer, prefix=""):
                                end_logits = to_list(outputs[1][i]))
             all_results.append(result)
 
-    update_features_v2(features, all_results)
+    input_file = args.predict_file if evaluate else args.train_file
+    updated_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
+        'dev' if evaluate else 'train',
+        list(filter(None, args.model_name_or_path.split('/'))).pop(),
+        str(args.max_seq_length)))
+
+    logger.info("Writing updated feature to: %s" % (updated_features_file))
+
+    result_lookup = { result.unique_id: result for result in all_results }
+
+    updated_features = []
+    for feature in features:
+        if feature.unique_id not in result_lookup:
+            continue
+
+        result = result_lookup[feature.unique_id]
+        feature.start_targets = result.kd_start_logits
+        feature.end_targets = result.kd_end_logits
+
+    torch.save(updated_features, updated_features_file)
 
 def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     if args.local_rank not in [-1, 0] and not evaluate:
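
The 20 added lines above replace the removed update_features_v2 helper (deleted from utils_quac_kd.py below) with an inline update-and-cache step: derive a cache path next to the input file, index results by unique_id, copy the distillation logits onto each feature, and torch.save the collection. Two caveats as shown: inside predict() the name evaluate resolves to the evaluate function (always truthy), so the predict_file branch is always taken, and the loop never fills updated_features. A toy, self-contained sketch of the intended pattern; all names and values are stand-ins, and the final append is an assumption about intent:

    import collections, os, tempfile
    import torch

    RawResultV2 = collections.namedtuple(
        "RawResultV2",
        ["unique_id", "start_logits", "end_logits",
         "kd_start_logits", "kd_end_logits"])

    class Feature:  # stand-in for the cached QuAC feature objects
        def __init__(self, unique_id):
            self.unique_id = unique_id
            self.start_targets = None
            self.end_targets = None

    features = [Feature(1000000000), Feature(1000000001)]
    all_results = [RawResultV2(1000000000, [0.1], [0.2], [0.3], [0.4])]

    # Cache path built the same way as above, with hypothetical values.
    input_file = os.path.join(tempfile.gettempdir(), "dev-v2.0.json")
    updated_features_file = os.path.join(
        os.path.dirname(input_file),
        "cached_{}_{}_{}".format("dev", "bert-base-uncased", 384))

    result_lookup = {r.unique_id: r for r in all_results}

    updated_features = []
    for feature in features:
        if feature.unique_id not in result_lookup:
            continue
        result = result_lookup[feature.unique_id]
        feature.start_targets = result.kd_start_logits
        feature.end_targets = result.kd_end_logits
        updated_features.append(feature)  # assumed; missing in the committed loop

    torch.save(updated_features, updated_features_file)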
@@ -406,7 +426,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
 
     if output_examples:
         return dataset, examples, features
-    return dataset
+    return dataset, features
 
 
 def main():
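
For reference, the return-shape convention this hunk (together with its predict() and main() call sites above and below) settles on, as a toy sketch with stand-in values:

    def load_and_cache_examples(evaluate=False, output_examples=False):
        # Stand-ins for the real TensorDataset, examples, and features.
        dataset, examples, features = "dataset", "examples", "features"
        if output_examples:
            return dataset, examples, features
        return dataset, features  # features are now returned even without examples

    dataset, features = load_and_cache_examples(evaluate=True)  # predict() path
    train_dataset, _ = load_and_cache_examples()                # training discards features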
@@ -572,7 +592,7 @@ def main():
 
     # Training
     if args.do_train:
-        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
+        train_dataset, _ = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
         global_step, tr_loss = train(args, train_dataset, model, tokenizer)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
8 changes: 2 additions & 6 deletions examples/utils_quac_kd.py
@@ -1208,12 +1208,8 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
 
     return out_eval
 
-def update_features_v2(all_features, all_results, output_feature_file):
-    """Write updated feature to the json file"""
-    logger.info("Writing updated feature to: %s" % (output_feature_file))
-
-    with open(output_prediction_file, "w") as file:
-        file.write(json.dumps(all_features, indent=2) + "\n")
+RawResultV2 = collections.namedtuple("RawResultV2",
+    ["unique_id", "start_logits", "end_logits", "kd_start_logits", "kd_end_logits"])
 
 def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
     """Project the tokenized prediction back to the original text."""
