Skip to content

Commit

Permalink
feat: Init tokenizer from filehandle (#76)
Browse files Browse the repository at this point in the history
* feat: allow creating JurassicTokenizer from model file handle

* fix: Add default for model_path and model_file_handle

* feat: Add JurassicTokenizer.from_file_path classmethod

* fix: remove model_path=None in JurassicTokenizer.from_file_handle

* fix: rename _assert_exactly_one to _validate_init and make it not static

* refactor: semantics

* test: Added tests

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
  • Loading branch information
tomeras91 and asafgardin authored Jan 2, 2024
1 parent 7b8348d commit dcb73a7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 4 deletions.
28 changes: 25 additions & 3 deletions ai21_tokenizer/jurassic_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from dataclasses import dataclass
from typing import List, Union, Optional, Dict, Any, Tuple
from typing import List, Union, Optional, Dict, Any, Tuple, BinaryIO

import sentencepiece as spm

Expand All @@ -19,11 +19,16 @@ class SpaceSymbol:
class JurassicTokenizer(BaseTokenizer):
def __init__(
self,
model_path: PathLike,
model_path: Optional[PathLike] = None,
model_file_handle: Optional[BinaryIO] = None,
config: Optional[Dict[str, Any]] = None,
):
self._validate_init(model_path=model_path, model_file_handle=model_file_handle)

model_proto = load_binary(model_path) if model_path else model_file_handle.read()

# noinspection PyArgumentList
self._sp = spm.SentencePieceProcessor(model_proto=load_binary(model_path))
self._sp = spm.SentencePieceProcessor(model_proto=model_proto)
config = config or {}

self.pad_id = config.get("pad_id")
Expand Down Expand Up @@ -52,6 +57,13 @@ def __init__(
self._space_mode = config.get("space_mode")
self._space_tokens = self._map_space_tokens()

def _validate_init(self, model_path: Optional[PathLike], model_file_handle: Optional[BinaryIO]) -> None:
    """Ensure that exactly one model source was supplied.

    Args:
        model_path: Filesystem path to a serialized sentencepiece model, or None.
        model_file_handle: An open binary handle to the model, or None.

    Raises:
        ValueError: If neither or both of the two sources are given.
    """
    # Count how many of the two mutually-exclusive sources the caller provided.
    sources_given = sum(source is not None for source in (model_path, model_file_handle))
    if sources_given == 0:
        raise ValueError("Must provide exactly one of model_path or model_file_handle. Got none.")
    if sources_given == 2:
        raise ValueError("Must provide exactly one of model_path or model_file_handle. Got both.")

def _map_space_tokens(self) -> List[SpaceSymbol]:
res = []
for count in range(32, 0, -1):
Expand Down Expand Up @@ -226,3 +238,13 @@ def convert_ids_to_tokens(self, token_ids: Union[int, List[int]], **kwargs) -> U
return self._id_to_token(token_ids)

return [self._id_to_token(token_id) for token_id in token_ids]

@classmethod
def from_file_handle(
    cls, model_file_handle: BinaryIO, config: Optional[Dict[str, Any]] = None
) -> JurassicTokenizer:
    """Alternate constructor: build a tokenizer from an already-open binary file handle.

    Args:
        model_file_handle: Open binary handle to a serialized sentencepiece model.
        config: Optional tokenizer configuration mapping.

    Returns:
        A JurassicTokenizer initialized from the handle's contents.
    """
    tokenizer = cls(model_file_handle=model_file_handle, config=config)
    return tokenizer

@classmethod
def from_file_path(cls, model_path: PathLike, config: Optional[Dict[str, Any]] = None) -> JurassicTokenizer:
    """Alternate constructor: build a tokenizer from a model file on disk.

    Args:
        model_path: Path to a serialized sentencepiece model file.
        config: Optional tokenizer configuration mapping.

    Returns:
        A JurassicTokenizer initialized from the file at *model_path*.
    """
    tokenizer = cls(model_path=model_path, config=config)
    return tokenizer
76 changes: 75 additions & 1 deletion tests/test_jurassic_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from pathlib import Path
from typing import Union, List
from typing import Union, List, BinaryIO, Optional

import pytest

from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
from ai21_tokenizer.utils import PathLike

_LOCAL_RESOURCES_PATH = Path(__file__).parents[1] / "ai21_tokenizer" / "resources" / "j2-tokenizer"


def test_tokenizer_encode_decode(tokenizer: JurassicTokenizer):
Expand Down Expand Up @@ -87,3 +90,74 @@ def test_tokenizer__convert_tokens_to_ids(
actual_ids = tokenizer.convert_tokens_to_ids(tokens)

assert actual_ids == expected_ids


def test_tokenizer__from_file_handle():
    """A tokenizer built via from_file_handle should round-trip encode/decode text."""
    expected_text = "Hello world!"
    config = dict(
        vocab_size=262144,
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        add_dummy_prefix=False,
        newline_piece="<|newline|>",
        number_mode="right_keep",
        space_mode="left",
    )

    model_file = _LOCAL_RESOURCES_PATH / "j2-tokenizer.model"
    with model_file.open("rb") as tokenizer_file:
        tokenizer = JurassicTokenizer.from_file_handle(model_file_handle=tokenizer_file, config=config)

    token_ids = tokenizer.encode(expected_text)
    decoded_text = tokenizer.decode(token_ids)

    assert decoded_text == expected_text


def test_tokenizer__from_file_path():
    """A tokenizer built via from_file_path should round-trip encode/decode text."""
    expected_text = "Hello world!"
    config = dict(
        vocab_size=262144,
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        add_dummy_prefix=False,
        newline_piece="<|newline|>",
        number_mode="right_keep",
        space_mode="left",
    )

    tokenizer = JurassicTokenizer.from_file_path(
        model_path=(_LOCAL_RESOURCES_PATH / "j2-tokenizer.model"), config=config
    )

    token_ids = tokenizer.encode(expected_text)
    decoded_text = tokenizer.decode(token_ids)

    assert decoded_text == expected_text


@pytest.mark.parametrize(
    ids=[
        "when_model_path_and_file_handle_are_none__should_raise_value_error",
        "when_model_path_and_file_handle_are_not_none__should_raise_value_error",
    ],
    argnames=["model_path", "model_file_handle", "expected_error_message"],
    argvalues=[
        (None, None, "Must provide exactly one of model_path or model_file_handle. Got none."),
        (
            Path("some_path"),
            "some_file_handle",
            "Must provide exactly one of model_path or model_file_handle. Got both.",
        ),
    ],
)
def test_tokenizer__init_with_invalid_model_source__should_raise_value_error(
    model_path: Optional[PathLike], model_file_handle: Optional[BinaryIO], expected_error_message: str
):
    """JurassicTokenizer requires exactly one of model_path / model_file_handle.

    Supplying neither or both must fail fast with a descriptive ValueError
    before any model loading is attempted.
    """
    with pytest.raises(ValueError) as error:
        JurassicTokenizer(model_file_handle=model_file_handle, model_path=model_path, config={})

    # args[0] is the message passed to ValueError in _validate_init.
    assert error.value.args[0] == expected_error_message

0 comments on commit dcb73a7

Please sign in to comment.