Fix and improve documentation for LEDForConditionalGeneration (#12303)
* Replace conditional generation example (fixes #12268)

* Replace model in summarization example with finetuned checkpoint, adapt example text

* Fix typo in new summarization example

* Fix docstring formatting, add missing import statement to example
ionicsolutions authored Jun 22, 2021
1 parent 1498eb9 commit 032d56a
Showing 1 changed file with 37 additions and 15 deletions: src/transformers/models/led/modeling_led.py
@@ -1436,17 +1436,43 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
LED_GENERATION_EXAMPLE = r"""
Summarization example::
- >>> from transformers import LEDTokenizer, LEDForConditionalGeneration, LEDConfig
- >>> model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')
- >>> tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')
- >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
- >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
+ >>> import torch
+ >>> from transformers import LEDTokenizer, LEDForConditionalGeneration
+ >>> model = LEDForConditionalGeneration.from_pretrained('allenai/led-large-16384-arxiv')
+ >>> tokenizer = LEDTokenizer.from_pretrained('allenai/led-large-16384-arxiv')
+ >>> ARTICLE_TO_SUMMARIZE = '''Transformers (Vaswani et al., 2017) have achieved state-of-the-art
+ ... results in a wide range of natural language tasks including generative
+ ... language modeling (Dai et al., 2019; Radford et al., 2019) and discriminative
+ ... language understanding (Devlin et al., 2019). This success is partly due to
+ ... the self-attention component which enables the network to capture contextual
+ ... information from the entire sequence. While powerful, the memory and computational
+ ... requirements of self-attention grow quadratically with sequence length, making
+ ... it infeasible (or very expensive) to process long sequences.
+ ...
+ ... To address this limitation, we present Longformer, a modified Transformer
+ ... architecture with a self-attention operation that scales linearly with the
+ ... sequence length, making it versatile for processing long documents (Fig 1). This
+ ... is an advantage for natural language tasks such as long document classification,
+ ... question answering (QA), and coreference resolution, where existing approaches
+ ... partition or shorten the long context into smaller sequences that fall within the
+ ... typical 512 token limit of BERT-style pretrained models. Such partitioning could
+ ... potentially result in loss of important cross-partition information, and to
+ ... mitigate this problem, existing methods often rely on complex architectures to
+ ... address such interactions. On the other hand, our proposed Longformer is able to
+ ... build contextual representations of the entire context using multiple layers of
+ ... attention, reducing the need for task-specific architectures.'''
+ >>> inputs = tokenizer.encode(ARTICLE_TO_SUMMARIZE, return_tensors='pt')
+ >>> # Global attention on the first token (cf. Beltagy et al. 2020)
+ >>> global_attention_mask = torch.zeros_like(inputs)
+ >>> global_attention_mask[:, 0] = 1
>>> # Generate Summary
- >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
- >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
+ >>> summary_ids = model.generate(inputs, global_attention_mask=global_attention_mask,
+ ... num_beams=3, max_length=32, early_stopping=True)
+ >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
"""

LED_INPUTS_DOCSTRING = r"""
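A note on the global_attention_mask used in the new summarization example: LED tokens attend only within a local window by default, and the mask flags the few positions that should also attend globally (Beltagy et al. 2020). Below is a minimal standalone sketch of that pattern, not part of the commit, with an illustrative input text:

    import torch
    from transformers import LEDTokenizer, LEDForConditionalGeneration

    tokenizer = LEDTokenizer.from_pretrained('allenai/led-large-16384-arxiv')
    model = LEDForConditionalGeneration.from_pretrained('allenai/led-large-16384-arxiv')

    text = "Self-attention scales quadratically with sequence length."  # illustrative stand-in
    input_ids = tokenizer(text, return_tensors='pt').input_ids

    # One mask entry per input token:
    #   0 -> local (windowed) attention only, the default for every token
    #   1 -> global attention: the token attends to, and is attended by, all tokens
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1  # as in the commit: global attention on the first token

    summary_ids = model.generate(input_ids, global_attention_mask=global_attention_mask,
                                 num_beams=3, max_length=32, early_stopping=True)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

The checkpoint swap in this hunk follows the commit message: 'allenai/led-large-16384-arxiv' is fine-tuned for summarization, so the example can produce a meaningful summary, unlike the pretrained-only base checkpoint.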
@@ -2305,13 +2331,9 @@ def forward(
>>> model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')
>>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
- >>> logits = model(input_ids).logits
- >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
- >>> probs = logits[0, masked_index].softmax(dim=0)
- >>> values, predictions = probs.topk(5)
- >>> tokenizer.decode(predictions).split()
+ >>> prediction = model.generate(input_ids)[0]
+ >>> print(tokenizer.decode(prediction, skip_special_tokens=True))
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
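The context above elides where TXT is defined. Assembled as a self-contained sketch of the updated mask-filling usage, with an assumed prompt (the actual TXT value sits outside the lines shown in this hunk):

    from transformers import LEDTokenizer, LEDForConditionalGeneration

    tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')
    model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')

    # Assumed prompt; the real TXT definition is above the lines shown in this hunk.
    TXT = "My friends are <mask> but they eat too many carbs."
    input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']

    # generate() runs the full encoder-decoder loop and returns a complete
    # sequence with the mask filled, replacing the old top-k logits inspection.
    prediction = model.generate(input_ids)[0]
    print(tokenizer.decode(prediction, skip_special_tokens=True))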

