Skip to content

Commit

Permalink
add: video swin backbone test
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Mar 3, 2024
1 parent d126b7c commit 61303be
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions keras_cv/models/backbones/video_swin/video_swin_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import pytest

from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.backbones.video_swin.video_swin_aliases import VideoSwinSBackbone
from keras_cv.tests.test_case import TestCase

class TestViTDetBackbone(TestCase):
@pytest.mark.large
def test_call(self):
model = VideoSwinSBackbone()
x = np.ones((1, 32, 224, 224, 3))
x_out = ops.convert_to_numpy(model(x))
num_parameters = sum(
np.prod(tuple(x.shape)) for x in model.trainable_variables
)
self.assertEqual(x_out.shape, (1, 16, 7, 7, 768))
self.assertEqual(num_parameters, 49_509_078)

@pytest.mark.extra_large
def teat_save(self):
# saving test
model = VideoSwinSBackbone()
x = np.ones((1, 32, 224, 224, 3))
x_out = ops.convert_to_numpy(model(x))
path = os.path.join(self.get_temp_dir(), "model.keras")
model.save(path)
loaded_model = keras.saving.load_model(path)
x_out_loaded = ops.convert_to_numpy(loaded_model(x))
self.assertAllClose(x_out, x_out_loaded)

@pytest.mark.extra_large
def test_fit(self):
model = VideoSwinSBackbone()
x = np.ones((1, 32, 224, 224, 3))
y = np.zeros((1, 16, 7, 7, 768))
model.compile(optimizer="adam", loss="mse", metrics=["mse"])
model.fit(x, y, epochs=1)

def test_pyramid_level_inputs_error(self):
model = VideoSwinSBackbone()
with self.assertRaises(NotImplementedError, msg="doesn't compute"):
model.pyramid_level_inputs

0 comments on commit 61303be

Please sign in to comment.