Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flax] Correct shift labels for seq2seq models in Flax #12720

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
f7197df
fix_torch_device_generate_test
patrickvonplaten May 19, 2021
5f70018
remove @
patrickvonplaten May 19, 2021
2da7a31
Merge branch 'master' of https://github.com/patrickvonplaten/transfor…
patrickvonplaten Jun 23, 2021
31c1132
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jun 28, 2021
bf193bc
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jun 30, 2021
2f09a75
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 1, 2021
b42f12f
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 4, 2021
ca3d9d0
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 6, 2021
012433f
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 8, 2021
5f0b7d1
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 9, 2021
ae7ef40
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 9, 2021
016fa5d
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 12, 2021
9f5b0eb
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 13, 2021
55b7109
Merge branch 'master' of https://github.com/huggingface/transformers
patrickvonplaten Jul 14, 2021
50c69bb
push
patrickvonplaten Jul 14, 2021
6c00601
fix marian
patrickvonplaten Jul 14, 2021
a2622db
fix
patrickvonplaten Jul 14, 2021
7855ea8
up
patrickvonplaten Jul 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/transformers/models/bart/modeling_flax_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from functools import partial
from typing import Callable, Optional, Tuple

import numpy as np

import flax.linen as nn
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -212,15 +214,15 @@
"""


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
shifted_input_ids = np.zeros_like(input_ids)
shifted_input_ids[:, 1:] = input_ids[:, :-1]
shifted_input_ids[:, 0] = decoder_start_token_id

shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids


Expand Down
8 changes: 4 additions & 4 deletions src/transformers/models/marian/modeling_flax_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
shifted_input_ids = np.zeros_like(input_ids)
shifted_input_ids[:, 1:] = input_ids[:, :-1]
shifted_input_ids[:, 0] = decoder_start_token_id

shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids


Expand Down
19 changes: 10 additions & 9 deletions src/transformers/models/mbart/modeling_flax_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from functools import partial
from typing import Callable, Optional, Tuple

import numpy as np

import flax.linen as nn
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -217,20 +219,19 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray
Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
have a single `decoder_start_token_id` in contrast to other Bart-like models.
"""
prev_output_tokens = jnp.array(input_ids).clone()
prev_output_tokens = np.array(input_ids).copy()

assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."

# replace possible -100 values in labels by `pad_token_id`
prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids)
index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
decoder_start_tokens = jnp.array(
[prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)]
prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
decoder_start_tokens = np.array(
[prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
).squeeze()
# for loop basically does jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
for i in range(prev_output_tokens.shape[1], 0, -1):
prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., i), prev_output_tokens[:, i - 1])
prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., 0), decoder_start_tokens)

prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
prev_output_tokens[:, 0] = decoder_start_tokens

return prev_output_tokens

Expand Down
11 changes: 6 additions & 5 deletions src/transformers/models/t5/modeling_flax_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@
_TOKENIZER_FOR_DOC = "T5Tokenizer"


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
shifted_input_ids = np.zeros_like(input_ids)
shifted_input_ids[:, 1:] = input_ids[:, :-1]
shifted_input_ids[:, 0] = decoder_start_token_id

shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids


Expand Down