From dc2203828eb3aeda2ef1ce4dc7eecd19226084a6 Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Wed, 25 Oct 2023 20:37:41 +0000
Subject: [PATCH 1/2] Add preset tests and weights URLs

---
 keras_nlp/models/backbone.py             | 20 +++++++++++-----
 keras_nlp/models/t5/t5_backbone_test.py  | 29 ++++++++++++++++++++++++
 keras_nlp/models/t5/t5_presets.py        | 24 ++++++++++++++++++++
 keras_nlp/models/t5/t5_tokenizer_test.py | 20 ++++++++++++++++
 4 files changed, 87 insertions(+), 6 deletions(-)

diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index b7a7ba2119..2425345495 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -112,12 +112,20 @@ def from_preset(
         if not load_weights:
             return model
 
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
+        if metadata["weights_url"].endswith(".weights.h5"):
+            weights = keras.utils.get_file(
+                "model.weights.h5",
+                metadata["weights_url"],
+                cache_subdir=os.path.join("models", preset),
+                file_hash=metadata["weights_hash"],
+            )
+        else:
+            weights = keras.utils.get_file(
+                "model.h5",
+                metadata["weights_url"],
+                cache_subdir=os.path.join("models", preset),
+                file_hash=metadata["weights_hash"],
+            )
         model.load_weights(weights)
         return model
 
diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py
index b8041e876e..9006925f10 100644
--- a/keras_nlp/models/t5/t5_backbone_test.py
+++ b/keras_nlp/models/t5/t5_backbone_test.py
@@ -53,3 +53,32 @@ def test_saved_model(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
         )
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=T5Backbone,
+            preset="t5_small_multi",
+            input_data=self.input_data,
+            expected_output_shape={
+                "encoder_sequence_output": (2, 3, 512),
+                "decoder_sequence_output": (2, 3, 512),
+            },
+            expected_partial_output={
+                "encoder_sequence_output": ops.array(
+                    [-0.0034, 0.0293, -0.0827, -0.1076]
+                ),
+                "decoder_sequence_output": ops.array(
+                    [0.0097, 0.3576, -0.1508, 0.0150]
+                ),
+            },
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in T5Backbone.presets:
+            self.run_preset_test(
+                cls=T5Backbone,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/models/t5/t5_presets.py b/keras_nlp/models/t5/t5_presets.py
index cbdde0391a..1c737a863b 100644
--- a/keras_nlp/models/t5/t5_presets.py
+++ b/keras_nlp/models/t5/t5_presets.py
@@ -38,6 +38,10 @@
             "layer_norm_epsilon": 1e-06,
         },
         "preprocessor_config": {},
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/model.weights.h5",
+        "weights_hash": "5a241ea61142eaf96ac1805898a2f2d1",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_small_multi/v1/vocab.spm",
+        "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
     "t5_base_multi": {
         "metadata": {
@@ -62,6 +66,10 @@
             "layer_norm_epsilon": 1e-06,
         },
         "preprocessor_config": {},
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/model.weights.h5",
+        "weights_hash": "9bef4c6650d91d1ea438ee4a2bea47ad",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_base_multi/v1/vocab.spm",
+        "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
     "t5_large_multi": {
         "metadata": {
@@ -86,6 +94,10 @@
             "layer_norm_epsilon": 1e-06,
         },
         "preprocessor_config": {},
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/model.weights.h5",
+        "weights_hash": "eab8eee1bad033e65324a71cd6e5a8e9",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/t5_large_multi/v1/vocab.spm",
+        "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
     "flan_small_multi": {
         "metadata": {
@@ -111,6 +123,10 @@
             "layer_norm_epsilon": 1e-06,
         },
         "preprocessor_config": {},
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/model.weights.h5",
+        "weights_hash": "4e39b0bab56606a9ab2b8e52a6bc7a9f",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_small_multi/v1/vocab.spm",
+        "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
     "flan_base_multi": {
         "metadata": {
@@ -135,6 +151,10 @@
             "layer_norm_epsilon": 1e-06,
         },
         "preprocessor_config": {},
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/model.weights.h5",
+        "weights_hash": "b529270c5361db36d359a46403532b5c",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_base_multi/v1/vocab.spm",
+        "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
     "flan_large_multi": {
         "metadata": {
@@ -159,5 +179,9 @@
             "layer_norm_epsilon": 1e-06,
         },
         "preprocessor_config": {},
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/model.weights.h5",
+        "weights_hash": "50b8d3c88fc10db07e495d79ff29a1b6",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/flan_large_multi/v1/vocab.spm",
+        "vocabulary_hash": "9d15ef55d09d5a425ceb63fa31f7cae3",
     },
 }
diff --git a/keras_nlp/models/t5/t5_tokenizer_test.py b/keras_nlp/models/t5/t5_tokenizer_test.py
index f8cef35c30..9f6f4e9e8f 100644
--- a/keras_nlp/models/t5/t5_tokenizer_test.py
+++ b/keras_nlp/models/t5/t5_tokenizer_test.py
@@ -14,6 +14,7 @@
 
 import io
 
+import pytest
 import sentencepiece
 import tensorflow as tf
 
@@ -64,3 +65,22 @@ def test_errors_missing_special_tokens(self):
         )
         with self.assertRaises(ValueError):
             T5Tokenizer(proto=bytes_io.getvalue())
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        for preset in T5Tokenizer.presets:
+            self.run_preset_test(
+                cls=T5Tokenizer,
+                preset=preset,
+                input_data=["The quick brown fox."],
+                expected_output=[[1996, 4248, 2829, 4419, 1012]],
+            )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in T5Tokenizer.presets:
+            self.run_preset_test(
+                cls=T5Tokenizer,
+                preset=preset,
+                input_data=self.input_data,
+            )

From 471f65985d209955f98fef6f4072c3312561f80b Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Wed, 25 Oct 2023 21:18:07 +0000
Subject: [PATCH 2/2] Change filename conditional

---
 keras_nlp/models/backbone.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index 2425345495..a55f767394 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -112,20 +112,14 @@ def from_preset(
         if not load_weights:
             return model
 
-        if metadata["weights_url"].endswith(".weights.h5"):
-            weights = keras.utils.get_file(
-                "model.weights.h5",
-                metadata["weights_url"],
-                cache_subdir=os.path.join("models", preset),
-                file_hash=metadata["weights_hash"],
-            )
-        else:
-            weights = keras.utils.get_file(
-                "model.h5",
-                metadata["weights_url"],
-                cache_subdir=os.path.join("models", preset),
-                file_hash=metadata["weights_hash"],
-            )
+        filename = os.path.basename(metadata["weights_url"])
+        weights = keras.utils.get_file(
+            filename,
+            metadata["weights_url"],
+            cache_subdir=os.path.join("models", preset),
+            file_hash=metadata["weights_hash"],
+        )
+
         model.load_weights(weights)
         return model