Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tokenizer] Unify tokenizer _pad #9280

Merged
Merged
Next Next commit
update attention_mask padding
DrownFish19 committed Oct 16, 2024
commit a6bfd7492c121c8dbb1959af89151213fe3c02a9
23 changes: 20 additions & 3 deletions paddlenlp/transformers/tokenizer_utils_base.py
Original file line number Diff line number Diff line change
@@ -3189,8 +3189,16 @@ def _pad(

if self.padding_side == "right":
if return_attention_mask:

encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if len(encoded_inputs["attention_mask"].shape) > 2:
# attention_mask shape [1,seq_len,seq_len]
encoded_inputs["attention_mask"] = np.pad(
encoded_inputs["attention_mask"],
pad_width=[(0, 0), (0, difference), (0, difference)],
mode="constant",
constant_values=0,
)
else:
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
@@ -3209,7 +3217,16 @@ def _pad(
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if len(encoded_inputs["attention_mask"].shape) > 2:
# attention_mask shape [1,seq_len,seq_len]
encoded_inputs["attention_mask"] = np.pad(
encoded_inputs["attention_mask"],
pad_width=[(0, 0), (difference, 0), (difference, 0)],
mode="constant",
constant_values=0,
)
else:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"