From 085111db2c29ddb0874c6554d81e5a9bd4fc0d40 Mon Sep 17 00:00:00 2001 From: "Jack FitzGerald (jgmf)" Date: Wed, 22 Jun 2022 13:04:07 -0600 Subject: [PATCH] fixes for validation engine and for using torchrun --- scripts/hpo.py | 3 +++ scripts/test.py | 3 +++ scripts/train.py | 3 +++ src/massive/utils/training_utils.py | 5 +++++ 4 files changed, 14 insertions(+) diff --git a/scripts/hpo.py b/scripts/hpo.py index 951d06c..3c6b799 100644 --- a/scripts/hpo.py +++ b/scripts/hpo.py @@ -17,6 +17,7 @@ import argparse import datetime import logging +import os import sys import datasets @@ -50,6 +51,8 @@ def main(): trainer_args = MASSIVETrainingArguments(**conf.get('train_val.trainer_args')) if args.local_rank: trainer_args.local_rank = int(args.local_rank) + elif os.getenv('LOCAL_RANK'): + trainer_args.local_rank = int(os.environ['LOCAL_RANK']) # Setup logging logging.basicConfig( diff --git a/scripts/test.py b/scripts/test.py index bb67bb0..ee25361 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -17,6 +17,7 @@ import argparse import datetime import logging +import os import pprint import sys import time @@ -53,6 +54,8 @@ def main(): trainer_args = MASSIVETrainingArguments(**conf.get('test.trainer_args')) if args.local_rank: trainer_args.local_rank = int(args.local_rank) + elif os.getenv('LOCAL_RANK'): + trainer_args.local_rank = int(os.environ['LOCAL_RANK']) # Setup logging logging.basicConfig( diff --git a/scripts/train.py b/scripts/train.py index 9ec36ff..1bfe579 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -17,6 +17,7 @@ import argparse import datetime import logging +import os import sys import datasets @@ -49,6 +50,8 @@ def main(): trainer_args = MASSIVETrainingArguments(**conf.get('train_val.trainer_args')) if args.local_rank: trainer_args.local_rank = int(args.local_rank) + elif os.getenv('LOCAL_RANK'): + trainer_args.local_rank = int(os.environ['LOCAL_RANK']) # Setup logging logging.basicConfig( diff --git a/src/massive/utils/training_utils.py b/src/massive/utils/training_utils.py index 276ac22..a8e4273 100644 --- a/src/massive/utils/training_utils.py +++ b/src/massive/utils/training_utils.py @@ -470,6 +470,11 @@ def eval_preds(pred_intents=None, lab_intents=None, pred_slots=None, lab_slots=N if type(pred) == list: pred = pred[:len(lab)] + [pad]*(len(lab) - len(pred)) + # Fix for Issue 21 -- subwords after the first one from a word should be ignored + for i, x in enumerate(lab): + if x == -100: + pred[i] = -100 + # convert to BIO bio_slot_labels.append( convert_to_bio(lab, outside=labels_ignore, labels_merge=labels_merge)