
Proposal: Offset based Token Classification utilities #7019

Closed
talolard opened this issue Sep 8, 2020 · 11 comments
@talolard

talolard commented Sep 8, 2020

🚀 Feature request

Hi. We work a lot with span annotations on text that isn't tokenized and want a "canonical" way to work with that. I have some ideas and rough implementations, so I'm looking for feedback on whether this belongs in the library and whether the proposed implementation is more or less right.

I also think there is a good chance that everything I want already exists, and the only thing needed is slightly clearer documentation. I hope that's the case and would be happy to write the documentation if someone can point me in the right direction.

The Desired Capabilities

What I'd like is a canonical way to:

  • Tokenize the examples in the dataset
  • Align my annotations with the output tokens (see notes below)
  • Have the tokens and labels correctly padded to the max length of an example in the batch or max_sequence_length
  • Have a convenient function that returns predicted offsets (a rough sketch of what I mean follows this list)
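
To make that last point concrete, here's a minimal sketch of the kind of helper I mean (made up for illustration, not an existing transformers utility; the name tags_to_spans is hypothetical). It collapses per-token B-/I- tags plus an offset mapping back into character-level spans:

from typing import Dict, List, Tuple

def tags_to_spans(tags: List[str], offsets: List[Tuple[int, int]]) -> List[Dict]:
    """Collapse per-token B-/I- tags back into character-level span annotations."""
    spans = []
    current = None
    for tag, (start, end) in zip(tags, offsets):
        if start == end:
            continue  # special tokens like [CLS]/[SEP] have empty offsets
        if tag == "O":
            current = None
            continue
        prefix, label = tag.split("-", 1)
        if prefix == "B" or current is None or current["tag"] != label:
            # Start a new span
            current = {"start": start, "end": end, "tag": label}
            spans.append(current)
        else:
            # Continuation of the current entity: extend its end offset
            current["end"] = end
    return spans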

Some Nice To Haves

  • It would be nice if such a utility internally handled tagging schemes like IOB or BIOES and optionally exposed them in the output or "folded" them back to the core entities.
  • It would be nice if there was a recommended/default strategy implemented for handling examples that are longer than max_sequence_length.
  • It would be amazing if we could pass labels to the tokenizer and have the alignment happen in Rust (in parallel). But I don't know Rust and I have a sense this is complicated, so I won't be taking that on myself and am assuming this happens in Python.

Current State and what I'm missing

  • The docs and examples for Token Classification assume that the text is pre-tokenized.
  • For a word that has a label and is tokenized into multiple tokens, the recommendation is to place the label on the first token and "ignore" the following tokens (see the sketch after this list).
  • The example pads all examples to max_sequence_length, which is a big performance hit (compared to bucketing by length and padding dynamically).
  • The example loads the entire dataset into memory at once. I'm not sure if this is a real problem or I'm being nitpicky, but I think "the right way" to do this would be to lazily load a batch or a few batches.
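
For reference, the usual way to implement that "ignore the following tokens" recommendation is the -100 ignore index that PyTorch's cross-entropy loss skips. A minimal sketch, assuming word-level labels, a fast tokenizer, and a transformers version where BatchEncoding.word_ids() is available:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

words = ["Aranesp", "was", "well", "tolerated"]
word_labels = [1, 0, 0, 0]  # one label id per word; the id values here are arbitrary

enc = tokenizer(words, is_split_into_words=True)
labels = []
previous_word_id = None
for word_id in enc.word_ids():
    if word_id is None:
        labels.append(-100)                  # special tokens: ignored by the loss
    elif word_id != previous_word_id:
        labels.append(word_labels[word_id])  # first sub-token carries the word's label
    else:
        labels.append(-100)                  # remaining sub-tokens are ignored
    previous_word_id = word_id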

Alignment

The path to aligning tokens with span annotations is the return_offsets_mapping flag on the tokenizer (which is awesome!).
There are probably a few strategies; I've been using logic like this:

def align_tokens_to_annos(offsets, annos):
    """Assign a tag to each token offset, given character-level span annotations."""
    anno_ix = 0
    results = []
    done = len(annos) == 0
    for offset in offsets:
        if done:
            # No annotations left: everything from here on is outside any span
            results.append(dict(offset=offset, tag='O'))
        else:
            anno = annos[anno_ix]
            start, end = offset
            if end < anno['start']:
                # The token ends before the next annotation starts
                results.append(dict(offset=offset, tag='O'))
            elif start <= anno['start'] and end <= anno['end']:
                # The token overlaps the start of the annotation
                results.append(dict(offset=offset, tag=f'B-{anno["tag"]}'))
            elif start >= anno['start'] and end <= anno['end']:
                # The token is fully inside the annotation
                results.append(dict(offset=offset, tag=f'I-{anno["tag"]}'))
            elif start >= anno['start'] and end > anno['end']:
                # The token runs past the end of the annotation; move on to the next one
                anno_ix += 1
                results.append(dict(offset=offset, tag=f'E-{anno["tag"]}'))
            else:
                raise Exception(f"Funny overlap {offset},{anno}")

            if anno_ix >= len(annos):
                done = True
    return results

And then I call that function inside add_labels here:

res_batch = tokenizer([s['text'] for s in pre_batch], return_offsets_mapping=True, padding=True)
offsets_batch = res_batch.pop('offset_mapping')
res_batch['labels'] = []
for i in range(len(offsets_batch)):
    labels = add_labels(res_batch['input_ids'][i], offsets_batch[i], pre_batch[i]['annotations'])
    res_batch['labels'].append(labels)

This works, and it's nice because the padding is consistent with the longest sentence in the batch, so bucketing gives a big boost. But the add_labels work happens in Python and is thus sequential over the examples and not super fast. I haven't measured this to confirm it's a problem, just bringing it up.
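
To be concrete about the bucketing: a minimal sketch (not my exact code, the helper name is made up) of grouping examples by tokenized length so that dynamic padding per batch stays tight:

import random
from typing import List

def length_buckets(texts: List[str], tokenizer, batch_size: int) -> List[List[int]]:
    """Group example indices by tokenized length so padding per batch stays small."""
    lengths = [len(tokenizer(t)['input_ids']) for t in texts]
    order = sorted(range(len(texts)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    random.shuffle(batches)  # shuffle the batch order, but keep similar lengths together
    return batches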

Desired Solution

I need most of this stuff, so I'm going to build it one way or another; the open question is where it should live.

The current "NER" examples and issues assume that text is pre-tokenized. Our use case is such that the full text is not tokenized and the labels for "NER" come as character offsets. I propose a utility/example to handle that scenario, because I haven't been able to find one.

In practice, most of what I'm asking for doesn't need any modification to transformers itself, and doing what I propose (below) in Rust is beyond me, so this might boil down to a utility class and documentation.

Motivation

I make text annotation tools, and our output is span annotations on untokenized text. I want our users to be able to easily use transformers. I suspect from my (limited) experience that in many non-academic use cases span annotations on untokenized text are the norm, and that others would benefit from this as well.

Possible ways to address this

I can imagine a few scenarios here

  • This is out of scope: maybe this isn't something that should be handled by transformers at all, and it should be delegated to a library and a blog post.
  • This is in scope and just needs documentation: e.g. all the things I mentioned are things transformers should and can already do. In that case the solution would be pointing someone (me) to the right functions and adding some documentation.
  • This is in scope and should be a set of utilities: solving this could be as simple as making a file similar to utils_ner.py. I think that would be the simplest way to get something usable, gather feedback, and see if anyone else cares.
  • This is in scope but should be done in Rust soon: if we want to be performance purists, it would make sense to handle the alignment of span-based labels in Rust. I don't know Rust, so I can't help much, and I don't know if there is any appetite or capacity from someone who does, or whether it's worth the (presumably) additional effort.

Your contribution

I'd be happy to implement and submit a PR, or make an external library or add to a relevant existing one.

Related issues

@thomwolf
Member

thomwolf commented Sep 9, 2020

Hi, this is a very nice issue and I plan to work soon (in the coming 2 weeks) on related things (improving the examples to make full use of the Rust tokenization features). I'll re-read this issue (and all the links) to extract all the details and likely come back to you at that time.

In the meantime, here are two elements for your project:

  • The first is that for the fast tokenizers, the output of the tokenizer (a BatchEncoding instance) is actually a special kind of Python dict with some super fast alignment methods powered by Rust, including a char_to_token alignment method that could maybe make your align_tokens_to_annos method a lot simpler and faster (see the sketch after this list). You can read about it here: https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding.char_to_token
  • The second element is that support for SentencePiece is almost there in tokenizers, so we will soon be able to use fast tokenizers for almost all the models.
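
For illustration, a minimal sketch (hypothetical, just to show the shape of the API) of using char_to_token to go straight from character-level spans to token labels:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

text = "Aspirin reduces fever."
annotations = [{"start": 0, "end": 7, "tag": "drug"}]  # character-level spans

enc = tokenizer(text, return_offsets_mapping=True)
labels = ["O"] * len(enc["input_ids"])
for anno in annotations:
    # Map the first and last character of the span to token indices
    first = enc.char_to_token(anno["start"])
    last = enc.char_to_token(anno["end"] - 1)
    if first is None or last is None:
        continue  # the span fell on whitespace or was truncated away
    labels[first] = f"B-{anno['tag']}"
    for i in range(first + 1, last + 1):
        labels[i] = f"I-{anno['tag']}"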

@talolard
Author

Thanks! That's super helpful.
I also found out I can iterate over the batch which is really nice.

I did find a bug and opened an issue

from transformers import BertTokenizerFast, GPT2TokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')


for i in range(1, 5):
    txt = "💩" * i
    enc = tokenizer(txt, return_offsets_mapping=True)
    token_at_i = enc.char_to_token(i - 1)
    dec = tokenizer.decode(enc['input_ids'])

    print(f" I wrote {txt} but got back '{dec}' and char_to_tokens({i-1}) returned {token_at_i}")
 I wrote 💩 but got back '[CLS] [UNK] [SEP]' and char_to_tokens(0) returned 1
 I wrote 💩💩 but got back '[CLS] [UNK] [SEP]' and char_to_tokens(1) returned 1
 I wrote 💩💩💩 but got back '[CLS] [UNK] [SEP]' and char_to_tokens(2) returned 1
 I wrote 💩💩💩💩 but got back '[CLS] [UNK] [SEP]' and char_to_tokens(3) returned 1

@talolard
Author

Progress - But Am I Doing This Right?

Hi,
I made some progress and got some logic that aligns the tokens to offset annotations. A function call looks like this:

batch, labels = tokenize_with_labels(texts, annotations, tokenizer, label_set=label_set)

And then a visualization of the alignment looks like this (it tries to show annotations spanning multiple tokens):
[image: token/label alignment visualization]

Question

So I'm trying to get the padding/batching working. I think I've got something good, but I'd love some input on how this might subtly fail.

batch, labels = tokenize_with_labels(texts, annotations, tokenizer, label_set=label_set)
batch.data['labels'] = labels  # Put the labels in the BatchEncoding dict

padded = tokenizer.pad(batch, padding='longest')  # Call pad, which ignores the labels and offset_mappings
batch_size = len(padded['input_ids'][0])  # Get the padded sequence length
for i in range(len(padded['labels'])):  # For each example's labels
    ls = padded['labels'][i]
    difference = batch_size - len(ls)  # How much do we need to pad?
    padded['labels'][i] = padded['labels'][i] + [0] * difference  # Pad the labels
    padded['offset_mapping'][i] += [(0, 0)] * difference  # Pad the offset mapping so we can call convert_to_tensors

tensors = padded.convert_to_tensors(tensor_type='pt')  # Convert to tensors

@thomwolf
Member

Hmm, I think we should have an option to pad the labels in tokenizer.pad, either based on the shape of the labels or with an additional flag.
I'll work on the tokenizers this week and will add this to the stack.

@talolard
Author

I think that presupposes that the user already has labels aligned to tokens, or that there is one and only one right way to align labels and tokens, which isn't consistent with the original issue.

When that's not the case, we need to tokenize, then align the labels, and finally pad. (We also need to deal with overflow, but I haven't gotten that far yet.) Notably, the user may want to use BIO, BIOES, or another scheme and needs access to the tokens to modify the labels accordingly.

Something that confused me as I've been working on this is that the _pad function operates explicitly on named attributes of the BatchEncoding dict, whereas as a user I'd expect it to operate on everything in the underlying encoding.data dict. That, however, doesn't work, because the dict includes offset_mappings, which don't tensorize nicely.

Because of the logic involved in alignment, I think the padding of the tokens and labels might be better done outside of the tokenizer, probably with a specialized function/module.
The upside of padding everything in one go is the efficiency of doing it in Rust, but I'd speculate that for token classification the running time would be dominated by training anyway, and the efficiency gains wouldn't justify the API complexity or documentation burden of doing it all in one place.

Also, I think that's a theoretical point, because it seems that the padding is done in Python anyway?

I ended up doing

def tokenize_with_labels(
    texts: List[str],
    raw_labels: List[List[SpanAnnotation]],
    tokenizer: PreTrainedTokenizerFast,
    label_set: LabelSet, #Basically the alignment strategy
):
    batch_encodings = tokenizer(
        texts,
        return_offsets_mapping=True,
        padding="longest",
        max_length=256,
        truncation=True,
    )
    batch_labels: IntListList = []
    for encoding, annotations in zip(batch_encodings.encodings, raw_labels):
        batch_labels.append(label_set.align_labels_to_tokens(encoding, annotations))
    return batch_encodings, batch_labels

where align_labels_to_tokens operates on already padded tokens.
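
For context, a minimal sketch of what a LabelSet with an align_labels_to_tokens method could look like (a simplified illustration, not the exact implementation in my code; it assumes a tokenizers Encoding with its ids and offsets attributes and annotations with start/end/label keys):

from typing import List

class LabelSet:
    def __init__(self, labels: List[str]):
        # "O" plus B-/I- variants for each entity type
        self.labels_to_id = {"O": 0}
        self.ids_to_label = {0: "O"}
        for label in labels:
            for prefix in ("B", "I"):
                i = len(self.labels_to_id)
                self.labels_to_id[f"{prefix}-{label}"] = i
                self.ids_to_label[i] = f"{prefix}-{label}"

    def align_labels_to_tokens(self, encoding, annotations) -> List[int]:
        # Everything starts outside any entity
        aligned = [self.labels_to_id["O"]] * len(encoding.ids)
        for anno in annotations:
            inside = False
            for idx, (start, end) in enumerate(encoding.offsets):
                if start == end:
                    continue  # special or padding tokens have empty offsets
                if start >= anno["start"] and end <= anno["end"]:
                    prefix = "I" if inside else "B"
                    aligned[idx] = self.labels_to_id[f'{prefix}-{anno["label"]}']
                    inside = True
        return aligned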

I found this the most convenient way to get dynamic batches with a collator

@dataclass
class LTCollator:
    tokenizer: PreTrainedTokenizerFast
    label_set: LabelSet
    padding: PaddingStrategy = True
    max_length: Optional[int] = None

    def __call__(self, texts_and_labels: Example) -> BatchEncoding:
        texts: List[str] = []
        annotations: List[List[SpanAnnotation]] = []
        for (text, annos) in texts_and_labels:
            texts.append(text)
            annotations.append(annos)

        batch, labels = tokenize_with_labels(
            texts, annotations, self.tokenizer, label_set=self.label_set
        )
        del batch["offset_mapping"]
        batch.data["labels"] = labels  # Put the labels in the BatchEncoding dict
        tensors = batch.convert_to_tensors(tensor_type="pt")  # convert to a tensor
        return tensors

@talolard
Author

As an example of the end-to-end flow (please, no one use this; it's a probably-buggy work in progress):

from typing import Any, Optional, List, Tuple
from transformers import (
    BertTokenizerFast,
    BertModel,
    BertForMaskedLM,
    BertForTokenClassification,
    TrainingArguments,
)
import torch
from transformers import AdamW, Trainer

from dataclasses import dataclass
from torch.utils.data import Dataset
import json

from torch.utils.data.dataloader import DataLoader
from transformers import PreTrainedTokenizerFast, DataCollatorWithPadding, BatchEncoding
from transformers.tokenization_utils_base import PaddingStrategy

from labelset import LabelSet
from token_types import IntListList, SpanAnnotation
from tokenize_with_labels import tokenize_with_labels

Example = Tuple[str, List[List[SpanAnnotation]]]


@dataclass
class LTCollator:
    tokenizer: PreTrainedTokenizerFast
    label_set: LabelSet
    padding: PaddingStrategy = True
    max_length: Optional[int] = None

    def __call__(self, texts_and_labels: Example) -> BatchEncoding:
        texts: List[str] = []
        annotations: List[List[SpanAnnotation]] = []
        for (text, annos) in texts_and_labels:
            texts.append(text)
            annotations.append(annos)

        batch, labels = tokenize_with_labels(
            texts, annotations, self.tokenizer, label_set=self.label_set
        )
        del batch["offset_mapping"]
        batch.data["labels"] = labels  # Put the labels in the BatchEncoding dict
        tensors = batch.convert_to_tensors(tensor_type="pt")  # convert to a tensor
        return tensors


class LTDataset(Dataset):
    def __init__(
        self, data: Any, tokenizer: PreTrainedTokenizerFast,
    ):
        self.tokenizer = tokenizer
        for example in data["examples"]:
            for a in example["annotations"]:
                a["label"] = a["tag"]
        self.texts = []
        self.annotations = []
        for example in data["examples"]:
            self.texts.append(example["content"])
            self.annotations.append(example["annotations"])

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx) -> Example:

        return self.texts[idx], self.annotations[idx]


@dataclass
class LTDataControls:
    dataset: LTDataset
    collator: LTCollator
    label_set: LabelSet


def lt_data_factory(
    json_path: str, tokenizer: PreTrainedTokenizerFast, max_length=None
):
    data = json.load(open(json_path))
    dataset = LTDataset(data=data, tokenizer=tokenizer)
    tags = list(map(lambda x: x["name"], data["schema"]["tags"]))
    label_set = LabelSet(tags)
    collator = LTCollator(
        max_length=max_length, label_set=label_set, tokenizer=tokenizer
    )
    return LTDataControls(dataset=dataset, label_set=label_set, collator=collator)


if __name__ == "__main__":
    from transformers import BertTokenizerFast, GPT2TokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased",)
    data_controls = lt_data_factory(
        "/home/tal/Downloads/small_gold_no_paragr_location_types_false_5_annotations.json",
        tokenizer=tokenizer,
        max_length=256,
    )
    dl = DataLoader(
        data_controls.dataset, collate_fn=data_controls.collator, batch_size=10
    )
    model = BertForTokenClassification.from_pretrained(
        "bert-base-cased", num_labels=len(data_controls.label_set.ids_to_label.values())
    )
    train = Trainer(
        model=model,
        data_collator=data_controls.collator,
        train_dataset=data_controls.dataset,
        args=TrainingArguments("/tmp/trainer", per_device_train_batch_size=2),
    )
    train.train()

@talolard
Author

Also, I found this comment by @sgugger about the trainer

Note that there are multiple frameworks that provide generic training loops. The goal of Trainer (I'm assuming you're talking about it since there is no train.py file) is not to replace them or compete with them but to provide an easy way to train and finetune Transformers models. Those models don't take nested inputs, so Trainer does not support this. Those models are expected to return the loss as the first item of their output, so Trainer expects it too.

I think that sentiment might make sense here: what I'm looking for is outside the scope of the library. If that's the case, I would have preferred it be written in big bold letters, rather than having the library try to cater to this use case.

@talolard
Author

talolard commented Sep 20, 2020

So,
after going down quite a rabbit hole, I've written a blog post about the considerations when doing alignment/padding/batching and another walking through an implementation.

It even comes with a repo.

So if we have annotated data like this

[{'annotations': [],
  'content': 'No formal drug interaction studies of Aranesp? have been '
             'performed.',
  'metadata': {'original_id': 'DrugDDI.d390.s0'}},
 {'annotations': [{'end': 13, 'label': 'drug', 'start': 6, 'tag': 'drug'},
                  {'end': 60, 'label': 'drug', 'start': 43, 'tag': 'drug'},
                  {'end': 112, 'label': 'drug', 'start': 105, 'tag': 'drug'},
                  {'end': 177, 'label': 'drug', 'start': 164, 'tag': 'drug'},
                  {'end': 194, 'label': 'drug', 'start': 181, 'tag': 'drug'},
                  {'end': 219, 'label': 'drug', 'start': 211, 'tag': 'drug'},
                  {'end': 238, 'label': 'drug', 'start': 227, 'tag': 'drug'}],
  'content': 'Since PLETAL is extensively metabolized by cytochrome P-450 '
             'isoenzymes, caution should be exercised when PLETAL is '
             'coadministered with inhibitors of C.P.A. such as ketoconazole '
             'and erythromycin or inhibitors of CYP2C19 such as omeprazole.',
  'metadata': {'original_id': 'DrugDDI.d452.s0'}},
 {'annotations': [{'end': 58, 'label': 'drug', 'start': 47, 'tag': 'drug'},
                  {'end': 75, 'label': 'drug', 'start': 62, 'tag': 'drug'},
                  {'end': 135, 'label': 'drug', 'start': 124, 'tag': 'drug'},
                  {'end': 164, 'label': 'drug', 'start': 152, 'tag': 'drug'}],
  'content': 'Pharmacokinetic studies have demonstrated that omeprazole and '
             'erythromycin significantly increased the systemic exposure of '
             'cilostazol and/or its major metabolites.',
  'metadata': {'original_id': 'DrugDDI.d452.s1'}}]

We can do this

from sequence_aligner.labelset import LabelSet
from sequence_aligner.dataset import  TrainingDataset
from sequence_aligner.containers import TraingingBatch
import json
raw = json.load(open('./data/ddi_train.json'))
for example in raw:
    for annotation in example['annotations']:
        # We expect the key to be 'label' but the data uses 'tag'
        annotation['label'] = annotation['tag']

from torch.utils.data import DataLoader
from transformers import BertForTokenClassification, AdamW

# NOTE: the construction of `dataset` was omitted here; a TrainingDataset
# (imported above from sequence_aligner.dataset) needs to be built from `raw`,
# a tokenizer, and a LabelSet first -- see the linked repo for the exact
# constructor arguments.
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased", num_labels=len(dataset.label_set.ids_to_label.values())
)
optimizer = AdamW(model.parameters(), lr=5e-6)

dataloader = DataLoader(
    dataset,
    collate_fn=TraingingBatch,
    batch_size=4,
    shuffle=True,
)
for num, batch in enumerate(dataloader):
    optimizer.zero_grad()  # clear gradients from the previous step
    loss, logits = model(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_masks,
        labels=batch.labels,
    )
    loss.backward()
    optimizer.step()


-------------------------------

I think most of this is out of scope for the transformers library itself, so I'm all for closing this issue if no one objects.

@julien-c
Member

(I attempted to fix the links above, let me know if this is correct @talolard)

@talolard
Author

(I attempted to fix the links above, let me know if this is correct @talolard)

Links seem kosher, thanks

@stale

stale bot commented Nov 21, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Nov 21, 2020