Add an abstract class for Tokenizer (#53)
* Add an abstract class for tokenizer

* Add sentence piece tokenizer as a subclass of Tokenizer

* Fix decode method for SentencePieceTokenizer

* Fix circular import issue

* fix type annotations

* fix linting issues

* Format files using pyink

* Update the tokenizer decode interface to return ids instead of str

* format using pyink

* Move Tokenizer class to a tokenizer_api.py file

* Update engine.build_tokenizer method to return SentencePieceTokenizer by default
bhavya01 authored Apr 26, 2024
1 parent 3fdacc8 commit f6751d2
Showing 4 changed files with 180 additions and 12 deletions.
12 changes: 4 additions & 8 deletions jetstream/core/orchestrator.py
@@ -93,7 +93,6 @@
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.core.utils import async_multifuture
from jetstream.engine import engine_api
from jetstream.engine import token_utils
import numpy as np


@@ -397,7 +396,7 @@ def _prefill_thread(self, idx: int):
prefill_engine = self._prefill_engines[idx]
prefill_params = self._prefill_params[idx]
metadata = prefill_engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
tokenizer = prefill_engine.build_tokenizer(metadata)
logging.info("---------Prefill params %d loaded.---------", idx)

while self.live:
@@ -429,9 +428,8 @@
is_bos,
request.history_path,
)
padded_tokens, true_length = token_utils.tokenize_and_pad(
padded_tokens, true_length = tokenizer.encode(
request.prefill_text,
vocab,
is_bos=is_bos,
max_prefill_length=prefill_engine.max_prefill_length,
jax_padding=self._jax_padding,
@@ -568,8 +566,7 @@ def _detokenize_thread(self, idx: int):
my_slots = self._generate_slots[idx]

metadata = my_generate_engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

tokenizer = my_generate_engine.build_tokenizer(metadata)
my_live_requests = {
i: None for i in range(my_generate_engine.max_concurrent_decodes)
}
@@ -587,11 +584,10 @@

for slot, request in my_live_requests.items():
if request is not None:
results, complete = token_utils.process_result_tokens(
results, complete = tokenizer.decode(
slot=slot,
slot_max_length=request.max_tokens,
result_tokens=result_tokens,
vocab=vocab,
complete=request.complete,
)
request.complete = complete
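Taken together, the orchestrator changes reduce to the pattern sketched below. This is an illustrative summary rather than the literal orchestrator code: the helper names prefill_tokens and detokenize_slot are invented for this sketch, and the queue and thread plumbing around them is omitted.

# Illustrative sketch of the new flow: the engine builds the tokenizer, and
# the tokenizer now owns padding/BOS handling and result-token processing.
def prefill_tokens(engine, prefill_text: str, is_bos: bool):
  metadata = engine.get_tokenizer()
  tokenizer = engine.build_tokenizer(metadata)
  # Replaces token_utils.tokenize_and_pad; returns padded ids plus the
  # true (unpadded) length.
  return tokenizer.encode(
      prefill_text,
      is_bos=is_bos,
      max_prefill_length=engine.max_prefill_length,
  )


def detokenize_slot(
    tokenizer, slot: int, max_tokens: int, result_tokens, complete
):
  # Replaces token_utils.process_result_tokens; returns per-sample token ids
  # and the updated completion flags.
  return tokenizer.decode(
      slot=slot,
      slot_max_length=max_tokens,
      result_tokens=result_tokens,
      complete=complete,
  )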
12 changes: 11 additions & 1 deletion jetstream/engine/engine_api.py
@@ -26,6 +26,7 @@
import numpy as np

from jetstream.engine import tokenizer_pb2
from jetstream.engine import token_utils


# The model parameters - their partitioning will be unique for different prefill
@@ -39,6 +40,8 @@
DeviceTokens = Any
# Cpus associated with the mesh.
CpuDevices = Any
# Tokenizer used by the engine.
Tokenizer = Any


@struct.dataclass
@@ -200,7 +203,14 @@ def get_prefix_destination_sharding(self) -> Any:
def get_tokenizer(
self,
) -> tokenizer_pb2.TokenizerParameters:
"""Returns the info to construct a sentencepiece tokenizer in py/c++."""
"""Returns the info to construct a tokenizer in py/c++."""

def build_tokenizer(
self,
metadata: tokenizer_pb2.TokenizerParameters,
) -> Tokenizer:
"""Builds a new tokenizer object and returns it."""
return token_utils.SentencePieceTokenizer(metadata)

@abc.abstractmethod
def init_decode_state(self, *args, **kwargs) -> DecodeState:
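With this default in place, a concrete engine only needs to override build_tokenizer when it wants something other than SentencePieceTokenizer. A hypothetical sketch, assuming the abstract base class here is engine_api.Engine; MyEngine and MyTokenizer are placeholder names, and the remaining abstract methods are elided:

from jetstream.engine import engine_api
from jetstream.engine import tokenizer_pb2


class MyEngine(engine_api.Engine):
  # ... prefill/generate/insert and the other abstract methods go here ...

  def build_tokenizer(
      self, metadata: tokenizer_pb2.TokenizerParameters
  ) -> engine_api.Tokenizer:
    # Return any tokenizer_api.Tokenizer implementation; MyTokenizer is a
    # placeholder. Engines that don't override this inherit the
    # SentencePieceTokenizer default from the base class.
    return MyTokenizer(metadata)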
88 changes: 85 additions & 3 deletions jetstream/engine/token_utils.py
@@ -16,16 +16,20 @@

from bisect import bisect_left
import logging
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np
from seqio.vocabularies import SentencePieceVocabulary
from seqio.vocabularies import Vocabulary

from jetstream.engine import engine_api
from jetstream.engine import mock_utils
from jetstream.engine import tokenizer_api
from jetstream.engine import tokenizer_pb2

# ResultTokens class to store token ids.
ResultTokens = Any


def take_nearest_length(lengths: list[int], length: int) -> int:
@@ -112,7 +116,7 @@ def tokenize_and_pad(
def process_result_tokens(
slot: int,
slot_max_length: int,
result_tokens: engine_api.ResultTokens,
result_tokens: ResultTokens,
vocab: Vocabulary,
complete: np.ndarray,
debug: bool = False,
@@ -196,3 +200,81 @@ def load_vocab(path: str, extra_ids: int = 0) -> Vocabulary:
sp_model = vocab.sp_model
del sp_model
return vocab


class SentencePieceTokenizer(tokenizer_api.Tokenizer):
  """Tokenizer to convert strings to token ids and vice-versa."""

  def __init__(self, metadata: tokenizer_pb2.TokenizerParameters):
    self.vocab = load_vocab(metadata.path, metadata.extra_ids)

  def encode(
      self, s: str, **kwargs
  ) -> Tuple[Union[jax.Array, np.ndarray], int]:
    """Tokenize a string.
    Args:
      s: String to tokenize.
      **kwargs: Additional keyword arguments.
    Returns:
      tokens: Tokenized into integers.
      true_length: Actual length of the non-padded sequence
        if padding is used.
    """
    is_bos = kwargs.pop("is_bos", True)
    prefill_lengths = kwargs.pop("prefill_lengths", None)
    max_prefill_length = kwargs.pop("max_prefill_length", None)

    tokens, true_length = tokenize_and_pad(
        s,
        self.vocab,
        is_bos=is_bos,
        prefill_lengths=prefill_lengths,
        max_prefill_length=max_prefill_length,
    )
    return tokens, true_length

  def decode(
      self,
      slot: int,
      slot_max_length: int,
      result_tokens: ResultTokens,
      complete: np.ndarray,
      **kwargs,
  ) -> Tuple[List[List[int]], np.ndarray]:
    """Processes result tokens into a list of token ids, handling multiple
    samples.
    Args:
      slot: The slot at which to draw tokens from.
      slot_max_length: Max length for a sample in the slot.
      result_tokens: The tokens to access by slot.
      complete: Array representing the completion status of each sample in the
        slot.
      kwargs: Additional keyword arguments.
    Returns:
      sample_return: List of token ids, one list per sample.
      complete: Updated complete.
    """
    debug = kwargs.pop("debug", False)
    results, complete = process_result_tokens(
        slot=slot,
        slot_max_length=slot_max_length,
        result_tokens=result_tokens,
        vocab=self.vocab,
        complete=complete,
        debug=debug,
    )
    return results, complete

  @property
  def pad_id(self) -> int:
    """ID of the pad token."""
    return self.vocab.pad_id

  @property
  def eos_id(self) -> int:
    """ID of EOS token."""
    return self.vocab.eos_id
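The new class can also be used on its own, outside an engine. A minimal usage sketch, assuming a sentencepiece model file exists at the placeholder path below:

from jetstream.engine import token_utils
from jetstream.engine import tokenizer_pb2

# Placeholder path; point this at a real sentencepiece model file.
metadata = tokenizer_pb2.TokenizerParameters(
    path="/tmp/tokenizer.model", extra_ids=0
)
tokenizer = token_utils.SentencePieceTokenizer(metadata)

# encode pads to the nearest supported prefill length and also returns the
# true (unpadded) length of the sequence.
padded_tokens, true_length = tokenizer.encode(
    "hello world", is_bos=True, max_prefill_length=1024
)
print(tokenizer.pad_id, tokenizer.eos_id, true_length)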
80 changes: 80 additions & 0 deletions jetstream/engine/tokenizer_api.py
@@ -0,0 +1,80 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the JetStream Tokenizer API."""

import abc
from typing import Any, Tuple, Union

import numpy as np
import jax

# Class to store token ids.
ResultTokens = Any


class Tokenizer(abc.ABC):
  """Tokenizer to convert strings to token ids and vice-versa."""

  @abc.abstractmethod
  def encode(
      self, s: str, **kwargs
  ) -> Tuple[Union[jax.Array, np.ndarray], int]:
    """Tokenize a string.
    Args:
      s: String to tokenize.
      **kwargs: Additional keyword arguments.
    Returns:
      tokens: Tokenized into integers.
      true_length: Actual length of the non-padded sequence
        if padding is used.
    """

  @abc.abstractmethod
  def decode(
      self,
      slot: int,
      slot_max_length: int,
      result_tokens: ResultTokens,
      complete: np.ndarray,
      **kwargs,
  ) -> Tuple[list[list[int]], np.ndarray]:
    """Processes result tokens into a list of token ids, handling multiple
    samples.
    Args:
      slot: The slot at which to draw tokens from.
      slot_max_length: Max length for a sample in the slot.
      result_tokens: The tokens to access by slot.
      complete: Array representing the completion status of each sample in the
        slot.
      **kwargs: Additional keyword arguments.
    Returns:
      sample_return: List of token ids, one list per sample.
      complete: Updated complete.
    """
    # TODO(bbahl): Add an option to return str from decode.

  @property
  @abc.abstractmethod
  def pad_id(self) -> int:
    """ID of the pad token."""

  @property
  @abc.abstractmethod
  def eos_id(self) -> int:
    """ID of EOS token."""

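Anything that implements this interface can be returned from an engine's build_tokenizer. As a rough illustration of the required surface, here is a toy WhitespaceTokenizer invented for this sketch; it is not part of JetStream and not a usable tokenizer:

from typing import List, Tuple

import numpy as np

from jetstream.engine import tokenizer_api


class WhitespaceTokenizer(tokenizer_api.Tokenizer):
  """Toy tokenizer: one fake id per word (illustration only)."""

  def encode(self, s: str, **kwargs) -> Tuple[np.ndarray, int]:
    ids = np.array([hash(w) % 32000 for w in s.split()], dtype=np.int32)
    return ids, len(ids)

  def decode(
      self, slot, slot_max_length, result_tokens, complete, **kwargs
  ) -> Tuple[List[List[int]], np.ndarray]:
    # A real implementation would slice result_tokens by slot and honor
    # slot_max_length; this toy version returns an empty sample.
    del slot, slot_max_length, result_tokens, kwargs
    return [[]], complete

  @property
  def pad_id(self) -> int:
    return 0

  @property
  def eos_id(self) -> int:
    return 1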