Commit fc3f200

tf lite support for retvec, update notebooks

MarinaZhang committed Oct 12, 2023
1 parent cf9cedc commit fc3f200

Showing 14 changed files with 589 additions and 544 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ Detailed example colabs for RETVec can be found under [notebooks](notebooks/)
We have the following example colabs:

- Training RETVec-based models using TensorFlow: [train_hello_world_tf.ipynb](notebooks/train_hello_world_tf.ipynb) for GPU/CPU training, and [train_tpu.ipynb](notebooks/train_tpu.ipynb) for a TPU-compatible training example.
- (Coming soon!) Converting RETVec models into TF Lite models to run on-device.
- Converting RETVec models into TF Lite models to run on-device: [tf_lite_retvec.ipynb](notebooks/tf_lite_retvec.ipynb)
- (Coming soon!) Using RETVec JS to deploy RETVec models in the web using TensorFlow.js

## Citing
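As background for the TF Lite bullet above, conversion follows the standard TensorFlow Text flow. The sketch below is illustrative only: the SavedModel path is an assumption, and the authoritative walkthrough is [tf_lite_retvec.ipynb](notebooks/tf_lite_retvec.ipynb).

```python
import tensorflow as tf
import tensorflow_text as tf_text  # provides the ops RETVec's native mode uses
from tensorflow.lite.python import interpreter

# Path is an assumption; point this at any trained RETVec-based SavedModel.
converter = tf.lite.TFLiteConverter.from_saved_model("demo_models/emotion_model")
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.allow_custom_ops = True  # TF Text ops are custom ops to TF Lite
tflite_model = converter.convert()

# Run inference with the TF Text custom-op registrars.
interp = interpreter.InterpreterWithCustomOps(
    model_content=tflite_model,
    custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS,
)
runner = interp.get_signature_runner()  # uses the model's single signature
```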
2 changes: 1 addition & 1 deletion notebooks/demo_models/emotion_model/fingerprint.pb
Binary file not shown.
48 changes: 24 additions & 24 deletions notebooks/demo_models/emotion_model/keras_metadata.pb

Large diffs are not rendered by default.

Binary file modified notebooks/demo_models/emotion_model/saved_model.pb
Binary file not shown.
Binary file not shown.
Binary file modified notebooks/demo_models/emotion_model/variables/variables.index
Binary file not shown.
398 changes: 207 additions & 191 deletions notebooks/tf_lite_retvec.ipynb

Large diffs are not rendered by default.

245 changes: 108 additions & 137 deletions notebooks/train_retvec_model_tf.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion retvec/__init__.py
@@ -14,4 +14,4 @@
limitations under the License.
"""

__version__ = "1.0.1"
__version__ = "1.0.2"
94 changes: 60 additions & 34 deletions retvec/tf/layers/binarizer.py
@@ -14,9 +14,9 @@
limitations under the License.
"""

from typing import Any, Dict, List, Union
import logging
import re
from typing import Any, Dict, List, Union

import tensorflow as tf
from tensorflow import Tensor, TensorShape
@@ -29,9 +29,13 @@
from .integerizer import RETVecIntegerizer


def _reshape_embeddings(embeddings: tf.Tensor, batch_size: int,
sequence_length: int, word_length: int,
encoding_size: int) -> tf.Tensor:
def _reshape_embeddings(
embeddings: tf.Tensor,
batch_size: int,
sequence_length: int,
word_length: int,
encoding_size: int,
) -> tf.Tensor:
if sequence_length > 1:
return tf.reshape(
embeddings,
@@ -43,8 +47,7 @@ def _reshape_embeddings(embeddings: tf.Tensor, batch_size: int,
),
)
else:
return tf.reshape(embeddings,
(batch_size, word_length, encoding_size))
return tf.reshape(embeddings, (batch_size, word_length, encoding_size))


@tf.keras.utils.register_keras_serializable(package="retvec")
@@ -105,10 +108,13 @@ def call(self, inputs: Tensor) -> Tensor:
embeddings = tf.cast(embeddings, dtype="float32")

# reshape back to correct shape
return _reshape_embeddings(embeddings, batch_size=batch_size,
sequence_length=self.sequence_length,
word_length=self.word_length,
encoding_size=self.encoding_size)
return _reshape_embeddings(
embeddings,
batch_size=batch_size,
sequence_length=self.sequence_length,
word_length=self.word_length,
encoding_size=self.encoding_size,
)

def _project(self, chars: Tensor, masks: Tensor) -> Tensor:
"""Project chars in subspace"""
@@ -191,21 +197,29 @@ def __init__(
self.use_native_tf_ops = use_native_tf_ops

# Check if the native `utf8_binarize` op is available for use.
is_utf8_encoding = re.match('^utf-?8$', encoding_type, re.IGNORECASE)
self._native_mode = (use_native_tf_ops and
is_utf8_encoding and
utf8_binarize is not None)
is_utf8_encoding = re.match("^utf-?8$", encoding_type, re.IGNORECASE)
self._native_mode = (
use_native_tf_ops
and is_utf8_encoding
and utf8_binarize is not None
)
if use_native_tf_ops and not self._native_mode:
logging.warning('Native support for `RETVecBinarizer` unavailable. '
'Check `tensorflow_text.utf8_binarize` availability'
' and its parameter contraints.')
logging.warning(
"Native support for `RETVecBinarizer` unavailable. "
"Check `tensorflow_text.utf8_binarize` availability"
" and its parameter contraints."
)

# Set to True when 'binarize()' is called in eager mode
self.eager = False
self._integerizer = None if self._native_mode else RETVecIntegerizer(
word_length=self.word_length,
encoding_type=self.encoding_type,
replacement_char=self.replacement_char,
self._integerizer = (
None
if self._native_mode
else RETVecIntegerizer(
word_length=self.word_length,
encoding_type=self.encoding_type,
replacement_char=self.replacement_char,
)
)

def build(
@@ -215,23 +229,35 @@ def build(

# Initialize int binarizer layer here since we know sequence_length
# only once we know the input_shape
self._int_to_binary = None if self._native_mode else RETVecIntToBinary(
word_length=self.word_length,
sequence_length=self.sequence_length,
encoding_size=self.encoding_size,
self._int_to_binary = (
None
if self._native_mode
else RETVecIntToBinary(
word_length=self.word_length,
sequence_length=self.sequence_length,
encoding_size=self.encoding_size,
)
)

def call(self, inputs: Tensor) -> Tensor:
if self._native_mode:
embeddings = utf8_binarize(inputs,
word_length=self.word_length,
bits_per_char=self.encoding_size,
replacement_char=self.replacement_char)
embeddings = utf8_binarize(
inputs,
word_length=self.word_length,
bits_per_char=self.encoding_size,
replacement_char=self.replacement_char,
)
batch_size = tf.shape(inputs)[0]
return _reshape_embeddings(embeddings, batch_size=batch_size,
sequence_length=self.sequence_length,
word_length=self.word_length,
encoding_size=self.encoding_size)
embeddings = _reshape_embeddings(
embeddings,
batch_size=batch_size,
sequence_length=self.sequence_length,
word_length=self.word_length,
encoding_size=self.encoding_size,
)
# TODO (marinazh): little vs big-endian order mismatch
return tf.reverse(embeddings, axis=[-1])

else:
assert self._integerizer is not None
char_encodings = self._integerizer(inputs)
@@ -245,7 +271,7 @@ def binarize(self, inputs: Tensor) -> Tensor:
"""Return binary encodings for a word or a list of words.
Args:
inputs: A single word or list of words to encode.
inputs: Tensor of a single word or list of words to encode.
Returns:
RETVec binary encodings for the input words(s).
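To make the `binarize()` contract above concrete, here is a minimal usage sketch. The import path mirrors this file's location and the parameter values are assumed defaults, so treat it as illustrative rather than authoritative.

```python
import tensorflow as tf
from retvec.tf.layers.binarizer import RETVecBinarizer

# Parameter values are assumptions matching the defaults discussed above.
binarizer = RETVecBinarizer(
    word_length=16,          # characters kept per word
    encoding_size=24,        # bits per character
    use_native_tf_ops=True,  # logs a warning and falls back if unavailable
)

# binarize() accepts a tensor of a single word or a list of words.
encodings = binarizer.binarize(tf.constant(["hello", "world"]))
# Expected shape here: (2, 16, 24) -- one 16x24 binary encoding per word.
```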
94 changes: 71 additions & 23 deletions retvec/tf/layers/tokenizer.py
@@ -14,6 +14,7 @@
limitations under the License.
"""

import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

@@ -22,11 +23,12 @@
from tensorflow.keras import layers

try:
from tensorflow_text import WhitespaceTokenizer
from tensorflow_text import WhitespaceTokenizer, utf8_binarize
except ImportError:
WhitespaceTokenizer = None
utf8_binarize = None

from .binarizer import RETVecBinarizer
from .binarizer import RETVecBinarizer, _reshape_embeddings
from .embedding import RETVecEmbedding

LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"
@@ -152,7 +154,19 @@ def __init__(
self.trainable = trainable

# Use whitespace tokenizer for TF Lite compatibility
self._native_mode = self.use_native_tf_ops and WhitespaceTokenizer
# TODO (marinazh): use TF Text functions like regex_split to offer
# more flexibility and preprocessing options
self._native_mode = (
self.use_native_tf_ops and WhitespaceTokenizer and utf8_binarize
)

if use_native_tf_ops and not self._native_mode:
logging.warning(
"Native support for `RETVecTokenizer` unavailable. "
"Check `tensorflow_text.utf8_binarize` availability"
" and its parameter contraints."
)

if self._native_mode:
self._whitespace_tokenizer = WhitespaceTokenizer()

@@ -174,7 +188,7 @@ def __init__(
encoding_size=self.char_encoding_size,
encoding_type=self.char_encoding_type,
replacement_char=self.replacement_char,
use_native_tf_ops=use_native_tf_ops
use_native_tf_ops=use_native_tf_ops,
)

# Set to True when 'tokenize()' or 'binarize()' called in eager mode
@@ -215,12 +229,46 @@ def embedding_size(self):

def call(self, inputs: Tensor, training: bool = False) -> Tensor:
inputs = tf.stop_gradient(inputs)
batch_size = tf.shape(inputs)[0]

# if native mode, use whitespace tokenization for tf lite compatibility
if self._native_mode:
rtensor = self._whitespace_tokenizer.tokenize(inputs)
# ensure batch of tf.strings doesn't have extra dim
if len(inputs.shape) == 2:
inputs = tf.squeeze(inputs, axis=1)

# whitespace tokenization
tokenized = self._whitespace_tokenizer.tokenize(inputs)
row_lengths = tokenized.row_lengths()

# apply native binarization op
# NOTE: utf8_binarize used here because RaggedTensorToTensor isn't
# supported in TF Text / TF Lite conversion, this is a workaround
binarized = utf8_binarize(tokenized.flat_values)
binarized = tf.RaggedTensor.from_row_lengths(
values=binarized, row_lengths=row_lengths
)

# convert from RaggedTensor to Tensor
binarized = binarized.to_tensor(
default_value=0,
shape=(
batch_size,
self.sequence_length,
self.word_length * self.char_encoding_size,
),
)

# reshape embeddings to apply the RETVecEmbedding layer
binarized = _reshape_embeddings(
binarized,
batch_size=batch_size,
sequence_length=self.sequence_length,
word_length=self.word_length,
encoding_size=self.char_encoding_size,
)

else:
# standardize and preprocess text
if self.standardize in (LOWER, LOWER_AND_STRIP_PUNCTUATION):
inputs = tf.strings.lower(inputs)
if self.standardize in (
@@ -233,32 +281,32 @@ def call(self, inputs: Tensor, training: bool = False) -> Tensor:
if callable(self.standardize):
inputs = self.standardize(inputs)

# split text on separator
rtensor = tf.strings.split(
inputs, sep=self.sep, maxsplit=self.sequence_length
)

#Handle shape differences between eager and graph mode
if self.eager:
stensor = rtensor.to_tensor(
default_value="",
shape=(rtensor.shape[0], self.sequence_length),
)
else:
stensor = rtensor.to_tensor(
default_value="",
shape=(rtensor.shape[0], 1, self.sequence_length),
)
stensor = tf.squeeze(stensor, axis=1)
# Handle shape differences between eager and graph mode
if self.eager:
stensor = rtensor.to_tensor(
default_value="",
shape=(rtensor.shape[0], self.sequence_length),
)
else:
stensor = rtensor.to_tensor(
default_value="",
shape=(rtensor.shape[0], 1, self.sequence_length),
)
stensor = tf.squeeze(stensor, axis=1)

# apply encoding and REW* model, if set
binarized = self._binarizer(stensor, training=training)
# apply RETVec binarization
binarized = self._binarizer(stensor, training=training)

# embed using RETVec word embedding model, if available
if self._embedding:
embeddings = self._embedding(binarized, training=training)
else:
embsize = (
self._binarizer.encoding_size * self._binarizer.word_length
)
embsize = self.char_encoding_size * self.word_length
embeddings = tf.reshape(
binarized, (tf.shape(inputs)[0], self.sequence_length, embsize)
)
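The native-mode branch above is easier to follow in isolation. Below is a hedged, standalone sketch of the same pipeline using tensorflow_text directly; the sequence/word/bit sizes are assumed defaults, and `utf8_binarize` is applied to the flat word list precisely because RaggedTensorToTensor is not supported in TF Text / TF Lite conversion.

```python
import tensorflow as tf
import tensorflow_text as tf_text

SEQUENCE_LENGTH = 128  # assumed defaults, for illustration only
WORD_LENGTH = 16
BITS_PER_CHAR = 24

inputs = tf.constant(["hello world", "retvec on device"])

# Whitespace tokenization -> RaggedTensor of words.
tokenized = tf_text.WhitespaceTokenizer().tokenize(inputs)
row_lengths = tokenized.row_lengths()

# Binarize the flat word list, then restore raggedness from row lengths,
# sidestepping ragged-to-dense string ops during TF Lite conversion.
binarized = tf_text.utf8_binarize(tokenized.flat_values)
binarized = tf.RaggedTensor.from_row_lengths(binarized, row_lengths)

# Pad/truncate to a fixed (batch, sequence, word_length * bits) tensor,
# ready to be reshaped for the RETVecEmbedding layer.
dense = binarized.to_tensor(
    default_value=0,
    shape=(None, SEQUENCE_LENGTH, WORD_LENGTH * BITS_PER_CHAR),
)
print(dense.shape)  # (2, 128, 384)
```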