Skip to content

Commit

Permalink
Add image and audio converter classes (#1813)
Browse files Browse the repository at this point in the history
* Add image and audio converter classes

These classes will occupy the same role as tokenizers for text models.
They will transform raw inputs to model inputs in a way that is not
task specific.

* Fix some tests

* Input conversion fixes

* Torch property fixes

* Another fix

* Address comments

* Add assets on kaggle; bump preset versions

* Fix last failing test
  • Loading branch information
mattdangerw authored Sep 10, 2024
1 parent 84a6b66 commit 23815d6
Show file tree
Hide file tree
Showing 66 changed files with 930 additions and 644 deletions.
1 change: 0 additions & 1 deletion keras_nlp/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from keras_nlp.api import models
from keras_nlp.api import samplers
from keras_nlp.api import tokenizers
from keras_nlp.api import utils
from keras_nlp.src.utils.preset_utils import upload_preset
from keras_nlp.src.version_utils import __version__
from keras_nlp.src.version_utils import version
11 changes: 11 additions & 0 deletions keras_nlp/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
)
from keras_nlp.src.layers.modeling.transformer_decoder import TransformerDecoder
from keras_nlp.src.layers.modeling.transformer_encoder import TransformerEncoder
from keras_nlp.src.layers.preprocessing.audio_converter import AudioConverter
from keras_nlp.src.layers.preprocessing.image_converter import ImageConverter
from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import (
MaskedLMMaskGenerator,
)
Expand All @@ -44,4 +46,13 @@
)
from keras_nlp.src.layers.preprocessing.random_deletion import RandomDeletion
from keras_nlp.src.layers.preprocessing.random_swap import RandomSwap
from keras_nlp.src.layers.preprocessing.resizing_image_converter import (
ResizingImageConverter,
)
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_nlp.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
6 changes: 0 additions & 6 deletions keras_nlp/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,7 @@
from keras_nlp.src.models.text_classifier_preprocessor import (
TextClassifierPreprocessor,
)
from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import (
WhisperAudioFeatureExtractor,
)
from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_nlp.src.models.whisper.whisper_preprocessor import (
WhisperPreprocessor,
)
from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer
from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import (
XLMRobertaBackbone,
Expand Down
121 changes: 121 additions & 0 deletions keras_nlp/src/layers/preprocessing/audio_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_nlp.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import find_subclass
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.layers.AudioConverter")
class AudioConverter(PreprocessingLayer):
"""Convert raw audio for models that support audio input.
This class converts from raw audio tensors of any length, to preprocessed
audio for pretrained model inputs. It is meant to be a convenient way to
write custom preprocessing code that is not model specific. This layer
should be instantiated via the `from_preset()` constructor, which will
create the correct subclass of this layer for the model preset.
The layer will take as input a raw audio tensor with shape `(batch_size,
num_samples)`, and output a preprocessed audio input for modeling. The exact
structure of the preprocessed input will vary per model. Preprocessing
will often include computing a spectogram of the raw audio signal.
Examples:
```python
# Load an audio converter from a preset.
converter = keras_nlp.layers.AudioConverter.from_preset("whisper_base_en")
# Convert some raw audio input.
converter(np.ones(2, 1_000))
```
"""

backbone_cls = None

@classproperty
def presets(cls):
"""List built-in presets for a `Task` subclass."""
presets = list_presets(cls)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets

@classmethod
def from_preset(
cls,
preset,
**kwargs,
):
"""Instantiate a `keras_nlp.layers.AudioConverter` from a model preset.
A preset is a directory of configs, weights and other file assets used
to save and load a pre-trained model. The `preset` can be passed as
one of:
1. a built-in preset identifier like `'whisper_base_en'`
2. a Kaggle Models handle like
`'kaggle://user/whisper/keras/whisper_base_en'`
3. a Hugging Face handle like `'hf://user/whisper_base_en'`
4. a path to a local preset directory like `'./whisper_base_en'`
You can run `cls.presets.keys()` to list all built-in presets available
on the class.
This constructor can be called in one of two ways. Either from the base
class like `keras_nlp.models.AudioConverter.from_preset()`, or from a
model class like `keras_nlp.models.WhisperAudioConverter.from_preset()`.
If calling from the base class, the subclass of the returning object
will be inferred from the config in the preset directory.
Args:
preset: string. A built-in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
load_weights: bool. If `True`, the weights will be loaded into the
model architecture. If `False`, the weights will be randomly
initialized.
Examples:
```python
# Load an audio converter from a preset.
converter = keras_nlp.layers.AudioConverter.from_preset(
"whisper_base_en"
)
# Convert some raw mono channel audio input.
converter(np.ones(2, 1_000))
```
"""
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_audio_converter(cls, **kwargs)

def save_to_preset(self, preset_dir):
"""Save audio converter to a preset directory.
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=AUDIO_CONVERTER_CONFIG_FILE,
)
69 changes: 69 additions & 0 deletions keras_nlp/src/layers/preprocessing/audio_converter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib

import numpy as np
import pytest

from keras_nlp.src.layers.preprocessing.audio_converter import AudioConverter
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
from keras_nlp.src.tests.test_case import TestCase


class AudioConverterTest(TestCase):
def test_preset_accessors(self):
pali_gemma_presets = set(WhisperAudioConverter.presets.keys())
all_presets = set(AudioConverter.presets.keys())
self.assertContainsSubset(pali_gemma_presets, all_presets)

@pytest.mark.large
def test_from_preset(self):
self.assertIsInstance(
AudioConverter.from_preset("whisper_tiny_en"),
WhisperAudioConverter,
)

@pytest.mark.large
def test_from_preset_errors(self):
with self.assertRaises(ValueError):
AudioConverter.from_preset("bert_tiny_en_uncased")
with self.assertRaises(ValueError):
# No loading on a non-keras model.
AudioConverter.from_preset("hf://spacy/en_core_web_sm")

@pytest.mark.large
def test_save_to_preset(self):
save_dir = self.get_temp_dir()
converter = AudioConverter.from_preset(
"whisper_tiny_en",
num_mels=40,
)
converter.save_to_preset(save_dir)
# Save a backbone so the preset is valid.
backbone = Backbone.from_preset("whisper_tiny_en", load_weights=False)
backbone.save_to_preset(save_dir)

# Check existence of files.
path = pathlib.Path(save_dir)
self.assertTrue(os.path.exists(path / "audio_converter.json"))

# Check loading.
restored = AudioConverter.from_preset(save_dir)
test_audio = np.random.rand(1_000)
self.assertAllClose(restored(test_audio), converter(test_audio))
130 changes: 130 additions & 0 deletions keras_nlp/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_nlp.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import find_subclass
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.layers.ImageConverter")
class ImageConverter(PreprocessingLayer):
"""Convert raw image for models that support image input.
This class converts from raw images of any size, to preprocessed
images for pretrained model inputs. It is meant to be a convenient way to
write custom preprocessing code that is not model specific. This layer
should be instantiated via the `from_preset()` constructor, which will
create the correct subclass of this layer for the model preset.
The layer will take as input a raw image tensor in the channels last or
channels first format, and output a preprocessed image input for modeling.
The exact structure of the output will vary per model, though in most cases
this layer will simply resize the image to the size needed by the model
input.
Examples:
```python
# Resize images for `"pali_gemma_3b_224"`.
converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_224")
converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3)
# Resize images for `"pali_gemma_3b_448"`.
converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_448")
converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 448, 448, 3)
```
"""

backbone_cls = None

@classproperty
def presets(cls):
"""List built-in presets for a `Task` subclass."""
presets = list_presets(cls)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets

@classmethod
def from_preset(
cls,
preset,
**kwargs,
):
"""Instantiate a `keras_nlp.layers.ImageConverter` from a model preset.
A preset is a directory of configs, weights and other file assets used
to save and load a pre-trained model. The `preset` can be passed as
one of:
1. a built-in preset identifier like `'pali_gemma_3b_224'`
2. a Kaggle Models handle like
`'kaggle://user/paligemma/keras/pali_gemma_3b_224'`
3. a Hugging Face handle like `'hf://user/pali_gemma_3b_224'`
4. a path to a local preset directory like `'./pali_gemma_3b_224'`
You can run `cls.presets.keys()` to list all built-in presets available
on the class.
This constructor can be called in one of two ways. Either from the base
class like `keras_nlp.models.ImageConverter.from_preset()`, or from a
model class like
`keras_nlp.models.PaliGemmaImageConverter.from_preset()`. If calling
from the base class, the subclass of the returning object will be
inferred from the config in the preset directory.
Args:
preset: string. A built-in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
load_weights: bool. If `True`, the weights will be loaded into the
model architecture. If `False`, the weights will be randomly
initialized.
Examples:
```python
# Resize images for `"pali_gemma_3b_224"`.
converter = keras_nlp.layers.ImageConverter.from_preset(
"pali_gemma_3b_224"
)
converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3)
# Override arguments on the base class.
converter = keras_nlp.layers.ImageConverter.from_preset(
"pali_gemma_3b_448",
crop_to_aspect_ratio=False,
)
converter(np.ones(2, 512, 512, 3)) # (2, 448, 448, 3)
```
"""
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_image_converter(cls, **kwargs)

def save_to_preset(self, preset_dir):
"""Save image converter to a preset directory.
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=IMAGE_CONVERTER_CONFIG_FILE,
)
Loading

0 comments on commit 23815d6

Please sign in to comment.