
Commit

Yada Pruksachatkun committed May 23, 2019
2 parents 0e622dd + b43fba6 commit 978d8b7
Showing 6 changed files with 185 additions and 77 deletions.
4 changes: 3 additions & 1 deletion config/superglue-bert.conf
@@ -8,7 +8,6 @@ exp_name = "bert-large-cased"
max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating
// But not 512 as that is really hard to fit in memory.
tokenizer = "bert-large-cased"

// Model settings
bert_model_name = "bert-large-cased"
bert_embeddings_mode = "top"
@@ -42,3 +41,6 @@ do_full_eval = 1
write_preds = "val,test"
write_strict_glue_format = 1

// For WSC
classifier_loss_fn = "softmax"
classifier_span_pooling = "attn"
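
The two WSC options above select attention pooling over each span (classifier_span_pooling = "attn") and a softmax, i.e. cross-entropy, loss (classifier_loss_fn = "softmax"). As a rough, hypothetical sketch only, not the module jiant actually builds from this config, that combination could look like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttnSpanPoolClassifier(nn.Module):
    """Hypothetical sketch of classifier_span_pooling = "attn" + classifier_loss_fn = "softmax"."""

    def __init__(self, d_model: int, n_classes: int):
        super().__init__()
        self.attn = nn.Linear(d_model, 1)          # per-token attention score
        self.proj = nn.Linear(2 * d_model, n_classes)

    def pool(self, hidden, span):
        # hidden: [seq_len, d_model]; span: half-open token interval [i, j)
        h = hidden[span[0]:span[1]]
        w = F.softmax(self.attn(h), dim=0)         # attention weights over the span's tokens
        return (w * h).sum(dim=0)                  # weighted sum = pooled span vector

    def forward(self, hidden, span1, span2):
        pooled = torch.cat([self.pool(hidden, span1), self.pool(hidden, span2)])
        return self.proj(pooled)                   # logits for a cross-entropy ("softmax") loss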
66 changes: 29 additions & 37 deletions src/tasks/tasks.py
@@ -27,7 +27,13 @@
from ..allennlp_mods.correlation import Correlation
from ..allennlp_mods.numeric_field import NumericField
from ..utils import utils
from ..utils.data_loaders import get_tag_list, load_diagnostic_tsv, load_tsv, process_sentence
from ..utils.data_loaders import (
get_tag_list,
load_diagnostic_tsv,
load_span_data,
load_tsv,
process_sentence,
)
from ..utils.tokenizers import get_tokenizer
from .registry import register_task # global task registry

@@ -1995,24 +2001,13 @@ def load_data(self):

class SpanClassificationTask(Task):
"""
Generic class for span tasks.
Acts as a classifier, but with multiple targets for each input text.
Targets are of the form (span1, span2,..., span_n, label), where the spans are
half-open token intervals [i, j).
The number of spans is constant across examples.
"""

@property
def _tokenizer_suffix(self):
""""
Suffix to make sure we use the correct source files,
based on the given tokenizer.
"""
if self.tokenizer_name:
return ".retokenized." + self.tokenizer_name
else:
return ""

def tokenizer_is_supported(self, tokenizer_name):
""" Check if the tokenizer is supported for this task. """
# Assume all tokenizers supported; if retokenized data not found
@@ -2049,8 +2044,7 @@ def __init__(
assert label_file is not None
assert files_by_split is not None
self._files_by_split = {
split: os.path.join(path, fname) + self._tokenizer_suffix
for split, fname in files_by_split.items()
split: os.path.join(path, fname) for split, fname in files_by_split.items()
}
self.num_spans = num_spans
self.max_seq_len = max_seq_len
@@ -2089,15 +2083,6 @@ def _stream_records(self, filename):
filename,
)

def load_data(self):
iters_by_split = collections.OrderedDict()
for split, filename in self._files_by_split.items():
iter = list(self._stream_records(filename))
iters_by_split[split] = iter
self._iters_by_split = iters_by_split
self.all_labels = list(utils.load_lines(self.label_file))
self.n_classes = len(self.all_labels)

def get_split_text(self, split: str):
"""
Get split text as iterable of records.
@@ -2139,19 +2124,15 @@ def make_instance(self, record, idx, indexers) -> Type[Instance]:

for i in range(self.num_spans):
example["span" + str(i + 1) + "s"] = ListField(
[
self._make_span_field(t["span" + str(i + 1)], text_field, 1)
for t in record["targets"]
]
[self._make_span_field(record["target"]["span" + str(i + 1)], text_field, 1)]
)

labels = [utils.wrap_singleton_string(t["label"]) for t in record["targets"]]
example["labels"] = ListField(
[
MultiLabelField(
label_set, label_namespace=self._label_namespace, skip_indexing=False
[str(record["label"])],
label_namespace=self._label_namespace,
skip_indexing=False,
)
for label_set in labels
]
)
return Instance(example)
@@ -2533,17 +2514,28 @@ def get_metrics(self, reset=False):
@register_task("winograd-coreference", rel_path="winograd-coref")
class WinogradCoreferenceTask(SpanClassificationTask):
def __init__(self, path, **kw):
self._files_by_split = {
"train": "train.jsonl",
"val": "val.jsonl",
"test": "test_with_labels.jsonl",
}
self._files_by_split = {"train": "train.jsonl", "val": "val.jsonl", "test": "test.jsonl"}
self.num_spans = 2
super().__init__(
files_by_split=self._files_by_split, label_file="labels.txt", path=path, **kw
)
self.n_classes = 2
self.val_metric = "%s_acc" % self.name

def load_data(self):
iters_by_split = collections.OrderedDict()
for split, filename in self._files_by_split.items():
if filename.endswith("test.jsonl"):
iters_by_split[split] = load_span_data(
self.tokenizer_name, filename, has_labels=False
)
else:
iters_by_split[split] = load_span_data(self.tokenizer_name, filename)
self._iters_by_split = iters_by_split

def get_all_labels(self):
return ["True", "False"]

def update_metrics(self, logits, labels, tagmask=None):
logits, labels = logits.detach(), labels.detach()

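For orientation, here is an illustrative record in the shape the updated make_instance expects once load_span_data (below) has run; the field names (text, target, span1/span2, label) come from the diff, while the concrete values are invented:

record = {
    "text": "Mr . Por ter is nice",          # already retokenized and space-joined
    "target": {
        "span1": [0, 4],                     # half-open token interval for "Mr. Porter"
        "span1_text": "Mr. Porter",
        "span2": [5, 6],                     # half-open token interval for "nice"
        "span2_text": "nice",
    },
    "label": True,                           # written as str(record["label"]) into a MultiLabelField
}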
2 changes: 1 addition & 1 deletion src/trainer.py
@@ -1080,7 +1080,7 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best_macro=Fals
training_state,
os.path.join(
self._serialization_dir,
"pretraining_state_{}_epoch_{}{}.th".format(phase, epoch, best_str),
"metric_state_{}_epoch_{}{}.th".format(phase, epoch, best_str),
),
)

28 changes: 28 additions & 0 deletions src/utils/data_loaders.py
@@ -11,11 +11,39 @@
from allennlp.data import vocabulary

from .tokenizers import get_tokenizer
from .retokenize import realign_spans

BERT_CLS_TOK, BERT_SEP_TOK = "[CLS]", "[SEP]"
SOS_TOK, EOS_TOK = "<SOS>", "<EOS>"


def load_span_data(tokenizer_name, file_name, label_fn=None, has_labels=True):
"""
Load a span-related task file in .jsonl format, does re-alignment of spans, and tokenizes the text.
Re-alignment of spans involves transforming the spans so that it matches the text after
tokenization.
For example, given the original text: [Mr., Porter, is, nice] and bert-base-cased tokenization, we get
[Mr, ., Por, ter, is, nice ]. If the original span indices was [0,2], under the new tokenization,
it becomes [0, 3].
The task file should of be of the following form:
text: str,
label: bool
target: dict that contains the spans
Args:
tokenizer_name: str,
file_name: str,
label_fn: function that expects a row and outputs a transformed row with labels tarnsformed.
Returns:
List of dictionaries of the aligned spans and tokenized text.
"""
rows = pd.read_json(file_name, lines=True)
# realign spans
rows = rows.apply(lambda x: realign_spans(x, tokenizer_name), axis=1)
if has_labels is False:
rows["label"] = False
return list(rows.T.to_dict().values())


def load_tsv(
tokenizer_name,
data_file,
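A hedged usage sketch of the new loader, mirroring WinogradCoreferenceTask.load_data above; the import path assumes the repository root is on PYTHONPATH and the file paths are illustrative:

from src.utils.data_loaders import load_span_data

# labeled splits: spans are re-aligned to the tokenizer and the text is tokenized
rows = load_span_data("bert-large-cased", "winograd-coref/train.jsonl")
first = rows[0]
print(first["text"], first["target"]["span1"], first["label"])

# the unlabeled test split gets a placeholder label of False filled in
test_rows = load_span_data("bert-large-cased", "winograd-coref/test.jsonl", has_labels=False)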
73 changes: 73 additions & 0 deletions src/utils/retokenize.py
@@ -94,6 +94,79 @@ def _mat_from_spans_sparse(spans: Sequence[Tuple[int, int]], n_chars: int) -> Ma
return sparse.csr_matrix((data, (ridxs, cidxs)), shape=(len(spans), n_chars))


def realign_spans(record, tokenizer_name):
"""
Builds the index alignment while tokenizing the input piece by piece.
Only BERT and Moses tokenization are supported currently.
Parameters
-----------------------
record: dict with the below fields
text: str
target: dict with the below fields
label: bool
span1_index: int, start index of the first span (in space-tokenized text)
span1_text: str, text of the first span
span2_index: int, start index of the second span (in space-tokenized text)
span2_text: str, text of the second span
tokenizer_name: str
Returns
------------------------
record: dict with the below fields:
text: str in tokenized form
target: dict with the below fields
label: bool
span1: (int, int) of token indices, half-open
span1_text: str, the original span string
span2: (int, int) of token indices, half-open
span2_text: str, the original span string
"""

# find span indices and text
text = record["text"].split()
span1 = record["target"]["span1_index"]
span1_text = record["target"]["span1_text"]
span2 = record["target"]["span2_index"]
span2_text = record["target"]["span2_text"]

# construct end spans given span text space-tokenized length
span1 = [span1, span1 + len(span1_text.strip().split())]
span2 = [span2, span2 + len(span2_text.strip().split())]
indices = [span1, span2]

sorted_indices = sorted(indices, key=lambda x: x[0])
current_tokenization = []
span_mapping = {}

# align first span to tokenized text
aligner_fn = get_aligner_fn(tokenizer_name)
_, new_tokens = aligner_fn(" ".join(text[: sorted_indices[0][0]]))
current_tokenization.extend(new_tokens)
new_span1start = len(current_tokenization)
_, span_tokens = aligner_fn(" ".join(text[sorted_indices[0][0] : sorted_indices[0][1]]))
current_tokenization.extend(span_tokens)
new_span1end = len(current_tokenization)
span_mapping[sorted_indices[0][0]] = [new_span1start, new_span1end]

# re-indexing second span
_, new_tokens = aligner_fn(" ".join(text[sorted_indices[0][1] : sorted_indices[1][0]]))
current_tokenization.extend(new_tokens)
new_span2start = len(current_tokenization)
_, span_tokens = aligner_fn(" ".join(text[sorted_indices[1][0] : sorted_indices[1][1]]))
current_tokenization.extend(span_tokens)
new_span2end = len(current_tokenization)
span_mapping[sorted_indices[1][0]] = [new_span2start, new_span2end]

# save back into record
_, all_text = aligner_fn(" ".join(text))
record["target"]["span1"] = span_mapping[record["target"]["span1_index"]]
record["target"]["span2"] = span_mapping[record["target"]["span2_index"]]
record["text"] = " ".join(all_text)
return record
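
To make the index bookkeeping in realign_spans concrete, a hedged walk-through of the Mr. Porter example from load_span_data's docstring; the exact word pieces depend on the tokenizer vocabulary, and the import path is assumed:

from src.utils.retokenize import realign_spans

record = {
    "text": "Mr. Porter is nice",
    "target": {
        "span1_index": 0, "span1_text": "Mr. Porter",   # space-tokens 0-1
        "span2_index": 3, "span2_text": "nice",         # space-token 3
    },
}
record = realign_spans(record, "bert-base-cased")
# assuming "Mr." -> [Mr, .] and "Porter" -> [Por, ter]:
# record["text"]            -> "Mr . Por ter is nice"
# record["target"]["span1"] -> [0, 4]   (half-open over the new tokens)
# record["target"]["span2"] -> [5, 6]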


class TokenAligner(object):
"""Align two similiar tokenizations.