From 4c0b7caa639a7e6b4d2a3840ebdd57759ff0f5f0 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Sun, 7 Apr 2024 20:25:19 -0700 Subject: [PATCH] Try to reduce training CI time --- tests/nn/test_training.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index 55f404929..9460b0c0a 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -44,7 +44,7 @@ def cfg(): cfg = TrainingJobConfig() cfg.data.instance_cropping.center_on_part = "A" cfg.model.backbone.unet = UNetConfig( - max_stride=8, output_stride=1, filters=8, filters_rate=1.0 + max_stride=8, output_stride=1, filters=2, filters_rate=1.0 ) cfg.optimization.preload_data = False cfg.optimization.batch_size = 1 @@ -251,12 +251,12 @@ def test_train_bottomup_with_offset(training_labels, cfg): def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg): labels = min_tracks_2node_labels - cfg.data.preprocessing.input_scaling = 0.5 + cfg.data.preprocessing.input_scaling = 0.25 cfg.model.heads.multi_class_bottomup = sleap.nn.config.MultiClassBottomUpConfig( confmaps=sleap.nn.config.MultiInstanceConfmapsHeadConfig( - output_stride=2, offset_refinement=False + output_stride=4, offset_refinement=False ), - class_maps=sleap.nn.config.ClassMapsHeadConfig(output_stride=2), + class_maps=sleap.nn.config.ClassMapsHeadConfig(output_stride=4), ) trainer = sleap.nn.training.BottomUpMultiClassModelTrainer.from_config( cfg, training_labels=labels @@ -266,8 +266,8 @@ def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg): assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead" assert trainer.keras_model.output_names[1] == "ClassMapsHead" - assert tuple(trainer.keras_model.outputs[0].shape) == (None, 256, 256, 2) - assert tuple(trainer.keras_model.outputs[1].shape) == (None, 256, 256, 2) + assert tuple(trainer.keras_model.outputs[0].shape) == (None, 64, 64, 2) + assert tuple(trainer.keras_model.outputs[1].shape) == (None, 64, 64, 2) def test_train_topdown_multiclass(min_tracks_2node_labels, cfg):