Skip to content

Commit 5341426

Browse files
authored
Implement compute_output_spec() for tokenizers with vocabulary. (#1523)
* Implement compute_output_spec() for tokenizers with vocabulary. (restarted from new point in master branch) * Remove type annotation from compute_output_spec() in tokenizers
1 parent e8f75c8 commit 5341426

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

keras_nlp/tokenizers/byte_pair_tokenizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Iterable
2525
from typing import List
2626

27+
import keras
2728
import regex as re
2829
import tensorflow as tf
2930

@@ -605,6 +606,11 @@ def detokenize(self, inputs):
605606
outputs = tf.squeeze(outputs, 0)
606607
return outputs
607608

609+
def compute_output_spec(self, input_spec):
610+
return keras.KerasTensor(
611+
input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
612+
)
613+
608614
def _transform_bytes(self, tokens):
609615
"""Map token bytes to unicode using `byte2unicode`."""
610616
split_bytes = tf.strings.bytes_split(tokens)

keras_nlp/tokenizers/sentence_piece_tokenizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
from typing import List
1919

20+
import keras
2021
import tensorflow as tf
2122

2223
from keras_nlp.api_export import keras_nlp_export
@@ -255,3 +256,8 @@ def detokenize(self, inputs):
255256
if unbatched:
256257
outputs = tf.squeeze(outputs, 0)
257258
return outputs
259+
260+
def compute_output_spec(self, input_spec):
261+
return keras.KerasTensor(
262+
input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
263+
)

keras_nlp/tokenizers/word_piece_tokenizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Iterable
1818
from typing import List
1919

20+
import keras
2021
import tensorflow as tf
2122

2223
from keras_nlp.api_export import keras_nlp_export
@@ -528,3 +529,8 @@ def detokenize(self, inputs):
528529
if unbatched:
529530
outputs = tf.squeeze(outputs, 0)
530531
return outputs
532+
533+
def compute_output_spec(self, input_spec):
534+
return keras.KerasTensor(
535+
input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
536+
)

0 commit comments

Comments
 (0)