Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jul 11, 2024
1 parent 9a45a68 commit f7c40e3
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 62 deletions.
52 changes: 8 additions & 44 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ jobs:
fail-fast: false
matrix:
backend: [tensorflow, jax, torch]
version: [latest]
include:
- backend: torch
version: 3.1
runs-on: ubuntu-latest
env:
KERAS_BACKEND: ${{ matrix.backend }}
Expand All @@ -42,51 +46,11 @@ jobs:
run: |
pip install -r requirements.txt --progress-bar off
pip install --no-deps -e "." --progress-bar off
- name: Test with pytest
run: |
pytest keras_nlp/
- name: Run integration tests
run: |
python pip_build.py --install
cd integration_tests && pytest . -k "not NoTensorflow"
- name: Run no tensorflow integration test
if: ${{ matrix.backend != 'tensorflow'}}
- name: Pin Keras version
if: ${{ matrix.version == '3.1'}}
run: |
pip uninstall -y tensorflow-text tensorflow
cd integration_tests && pytest . -k "NoTensorflow"
run_tests_with_keras_3_1_0:
name: Test the code with Keras 3.1.0
strategy:
fail-fast: false
matrix:
backend: [tensorflow, jax, torch]
runs-on: ubuntu-latest
env:
KERAS_BACKEND: ${{ matrix.backend }}
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: 3.9
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies with Keras 3.1.0
run: |
pip install -r requirements.txt --progress-bar off
pip install --no-deps -e "." --progress-bar off
pip uninstall -y keras
pip install keras==3.1.0 --progress-bar off
pip uninstall -y keras
pip install keras==3.1.0 --progress-bar off
- name: Test with pytest
run: |
pytest keras_nlp/
Expand Down
7 changes: 2 additions & 5 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.utils.keras_utils import assert_quantization_support


@keras_nlp_export("keras_nlp.layers.ReversibleEmbedding")
Expand Down Expand Up @@ -237,11 +238,7 @@ def _int8_call(self, inputs, reverse=False):
def quantize(self, mode, type_check=True):
import gc

if parse(keras.version()) < parse("3.4.0"):
raise ValueError(
"`quantize` in KerasNLP requires Keras >= 3.4.0 to function "
f"correctly. Received: '{keras.version()}'"
)
assert_quantization_support()
if type_check and type(self) is not ReversibleEmbedding:
raise NotImplementedError(
f"Layer {self.__class__.__name__} does not have a `quantize()` "
Expand Down
10 changes: 5 additions & 5 deletions keras_nlp/src/layers/modeling/reversible_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from absl.testing import parameterized
from keras import ops
from keras import random
from packaging.version import parse

from keras_nlp.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.utils.keras_utils import has_quantization_support


class ReversibleEmbeddingTest(TestCase):
Expand Down Expand Up @@ -104,8 +104,8 @@ def test_reverse_dtype(self):
("tie_weights", True), ("untie_weights", False)
)
def test_quantize_int8(self, tie_weights):
if parse(keras.version()) < parse("3.4.0"):
self.skipTest("This test needs keras>=3.4.0.")
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
Expand Down Expand Up @@ -155,8 +155,8 @@ def test_quantize_int8(self, tie_weights):
("untie_weights", False),
)
def test_quantize_dtype_argument(self, tie_weights):
if parse(keras.version()) < parse("3.4.0"):
self.skipTest("This test needs keras>=3.4.0.")
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

self.run_layer_test(
cls=ReversibleEmbedding,
Expand Down
8 changes: 2 additions & 6 deletions keras_nlp/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import os

import keras
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.utils.keras_utils import assert_quantization_support
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
Expand Down Expand Up @@ -109,11 +109,7 @@ def token_embedding(self, value):
self._token_embedding = value

def quantize(self, mode, **kwargs):
if parse(keras.version()) < parse("3.4.0"):
raise ValueError(
"`quantize` in KerasNLP requires Keras >= 3.4.0 to function "
f"correctly. Received: keras.version()={keras.version()}"
)
assert_quantization_support()
return super().quantize(mode, **kwargs)

def get_config(self):
Expand Down
4 changes: 2 additions & 2 deletions keras_nlp/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from absl.testing import parameterized
from keras import ops
from keras import tree
from packaging.version import parse

from keras_nlp.src import layers as keras_nlp_layers
from keras_nlp.src.tokenizers.tokenizer import Tokenizer
from keras_nlp.src.utils.keras_utils import has_quantization_support
from keras_nlp.src.utils.tensor_utils import is_float_dtype


Expand Down Expand Up @@ -446,7 +446,7 @@ def run_backbone_test(
self.run_precision_test(cls, init_kwargs, input_data)

# Check quantization.
if run_quantization_check and parse(keras.version()) >= parse("3.4.0"):
if run_quantization_check and has_quantization_support():
self.run_quantization_test(backbone, cls, init_kwargs, input_data)

def run_task_test(
Expand Down
13 changes: 13 additions & 0 deletions keras_nlp/src/utils/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import keras
from absl import logging
from packaging.version import parse

from keras_nlp.src.utils.tensor_utils import is_tensor_type

Expand Down Expand Up @@ -102,3 +103,15 @@ def print_msg(message, line_break=True):
@keras.saving.register_keras_serializable(package="keras_nlp")
def gelu_approximate(x):
return keras.activations.gelu(x, approximate=True)


def has_quantization_support():
return False if parse(keras.version()) < parse("3.4.0") else True


def assert_quantization_support():
if not has_quantization_support():
raise ValueError(
"Quantization API requires Keras >= 3.4.0 to function "
f"correctly. Received: '{keras.version()}'"
)

0 comments on commit f7c40e3

Please sign in to comment.