
Merge pull request #6 from jklj077/patch-6
fix merge reading and update tests
JustinLin610 authored Jan 13, 2024
2 parents f419098 + 6b3247b commit 68968f5
Showing 2 changed files with 27 additions and 38 deletions.
src/transformers/models/qwen2/tokenization_qwen2.py (8 changes: 6 additions & 2 deletions)
@@ -178,9 +178,13 @@ def __init__(
         self.errors = errors  # how to handle errors in decoding
         self.byte_encoder = bytes_to_unicode()
         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_merges = []
         with open(merges_file, encoding="utf-8") as merges_handle:
-            bpe_merges = merges_handle.read().split("\n")[1:-1]
-        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+            for line in merges_handle:
+                line = line.strip()
+                if not line or line.startswith("#"):
+                    continue
+                bpe_merges.append(tuple(line.split()))
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         # NOTE: the cache can grow without bound and will get really large for long running processes
         # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
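For context, here is a minimal sketch (not part of the diff) of why the merges reading changes. The old code unconditionally dropped the first and last entries of the file, which assumes a "#version" header line and a trailing newline; the new loop keeps every non-blank, non-comment line. The string merges_text below is an assumed example, not data from the repository:

# Assumed example: a merges file with a header line but no trailing newline.
merges_text = "#version: 0.2\n\u0120 l\ne r"

# Old reading: drop the first and last entries unconditionally.
old_merges = [tuple(merge.split()) for merge in merges_text.split("\n")[1:-1]]
print(old_merges)  # [('Ġ', 'l')]  -- the final merge "e r" is silently lost

# New reading: skip blank lines and "#" comments, keep everything else.
new_merges = []
for line in merges_text.split("\n"):
    line = line.strip()
    if not line or line.startswith("#"):
        continue
    new_merges.append(tuple(line.split()))
print(new_merges)  # [('Ġ', 'l'), ('e', 'r')]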
tests/models/qwen2/test_tokenization_qwen2.py (57 changes: 21 additions & 36 deletions)
@@ -19,7 +19,7 @@
 import unittest

 from transformers import Qwen2Tokenizer, Qwen2TokenizerFast
-from transformers.models.qwen2.tokenization_qwen2 import VOCAB_FILES_NAMES
+from transformers.models.qwen2.tokenization_qwen2 import VOCAB_FILES_NAMES, bytes_to_unicode
 from transformers.testing_utils import require_tokenizers

 from ...test_tokenization_common import TokenizerTesterMixin
@@ -38,39 +38,24 @@ class Qwen2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def setUp(self):
         super().setUp()

-        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
-        vocab = [
-            "l",
-            "o",
-            "w",
-            "e",
-            "r",
-            "s",
-            "t",
-            "i",
-            "d",
-            "n",
-            "\u0120",
-            "\u0120l",
-            "\u0120n",
-            "\u0120lo",
-            "\u0120low",
-            "er",
-            "\u0120lowest",
-            "\u0120newer",
-            "\u0120wider",
-            "0",
-            "1",
-            "01",
-            "}",
-            ";",
-            "\u010a",
-            ";}",
-            ";}\u010a",
-            "\u0315",
-            "\u00cf",
-            "\u00cf\u0135",
-        ]
+        vocab = list(bytes_to_unicode().values())
+        vocab.extend(
+            [
+                "\u0120l",
+                "\u0120n",
+                "\u0120lo",
+                "\u0120low",
+                "er",
+                "\u0120lowest",
+                "\u0120newer",
+                "\u0120wider",
+                "01",
+                ";}",
+                ";}\u010a",
+                "\u00cf\u0135",
+            ]
+        )

         vocab_tokens = dict(zip(vocab, range(len(vocab))))

         merges = [
@@ -86,7 +71,7 @@ def setUp(self):
         ]

         # unk_token is needed, because this stub tokenizer is not complete at the byte level
-        self.special_tokens_map = {"eos_token": "<|endoftext|>", "pad_token": "<|endoftext|>", "unk_token": "<|unk|>"}
+        self.special_tokens_map = {"eos_token": "<|endoftext|>"}

         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
@@ -139,5 +124,5 @@ def test_python_full_tokenizer(self):
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens
-        input_bpe_tokens = [0, 1, 2, 15, 14, 15, 10, 9, 3, 2, 15, 10, 19, 20, 19, 26, 30, 29]
+        input_bpe_tokens = [75, 78, 86, 260, 259, 260, 220, 77, 68, 86, 260, 220, 15, 16, 15, 266, 268, 267]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
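For context, here is a minimal sketch (not part of the diff) of where the new expected ids come from. The test vocabulary now begins with the 256 byte-level symbols returned by bytes_to_unicode(), so the twelve extension tokens from setUp land at ids 256-267, and "<|endoftext|>", registered afterwards as a special token, becomes 268. The printout assumes the vocabulary is built exactly as in the updated setUp:

from transformers.models.qwen2.tokenization_qwen2 import bytes_to_unicode

# ids 0..255: the byte-level alphabet, in bytes_to_unicode() order
vocab = list(bytes_to_unicode().values())
# ids 256..267: the extension tokens added in setUp
vocab.extend(["\u0120l", "\u0120n", "\u0120lo", "\u0120low", "er", "\u0120lowest",
              "\u0120newer", "\u0120wider", "01", ";}", ";}\u010a", "\u00cf\u0135"])

print(vocab.index("l"))         # 75  (printable bytes start at "!", so "l" is 108 - 33)
print(vocab.index("er"))        # 260
print(vocab.index(";}\u010a"))  # 266
# "<|endoftext|>" is added afterwards as a special token, so it receives id 268.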
