Skip to content

Commit

Permalink
Reduce memory consumption for BasNet tests (#2325)
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel authored Jan 30, 2024
1 parent d04fbcc commit c67a0c7
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions keras_cv/models/segmentation/basnet/basnet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class BASNetTest(TestCase):
def test_basnet_construction(self):
backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
)
model.compile(
optimizer="adam",
Expand All @@ -44,17 +44,17 @@ def test_basnet_construction(self):
def test_basnet_call(self):
backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
)
images = np.random.uniform(size=(2, 288, 288, 3))
images = np.random.uniform(size=(2, 64, 64, 3))
_ = model(images)
_ = model.predict(images)

@pytest.mark.large
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_weights_change(self):
input_size = [288, 288, 3]
target_size = [288, 288, 1]
input_size = [64, 64, 3]
target_size = [64, 64, 1]

images = np.ones([1] + input_size)
labels = np.random.uniform(size=[1] + target_size)
Expand All @@ -64,7 +64,7 @@ def test_weights_change(self):

backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
)
model_metrics = ["accuracy"]
if keras_3():
Expand All @@ -77,7 +77,7 @@ def test_weights_change(self):
)

original_weights = model.refinement_head.get_weights()
model.fit(ds, epochs=1)
model.fit(ds, epochs=1, batch_size=1)
updated_weights = model.refinement_head.get_weights()

for w1, w2 in zip(original_weights, updated_weights):
Expand All @@ -98,11 +98,11 @@ def test_with_model_preset_forward_pass(self):

@pytest.mark.large
def test_saved_model(self):
target_size = [288, 288, 3]
target_size = [64, 64, 3]

backbone = ResNet18Backbone()
model = BASNet(
input_shape=[288, 288, 3], backbone=backbone, num_classes=1
input_shape=[64, 64, 3], backbone=backbone, num_classes=1
)

input_batch = np.ones(shape=[2] + target_size)
Expand Down

0 comments on commit c67a0c7

Please sign in to comment.