Skip to content

Commit

Permalink
Try to reduce training CI time
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Apr 8, 2024
1 parent c0db085 commit 4c0b7ca
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/nn/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def cfg():
cfg = TrainingJobConfig()
cfg.data.instance_cropping.center_on_part = "A"
cfg.model.backbone.unet = UNetConfig(
max_stride=8, output_stride=1, filters=8, filters_rate=1.0
max_stride=8, output_stride=1, filters=2, filters_rate=1.0
)
cfg.optimization.preload_data = False
cfg.optimization.batch_size = 1
Expand Down Expand Up @@ -251,12 +251,12 @@ def test_train_bottomup_with_offset(training_labels, cfg):

def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg):
labels = min_tracks_2node_labels
cfg.data.preprocessing.input_scaling = 0.5
cfg.data.preprocessing.input_scaling = 0.25
cfg.model.heads.multi_class_bottomup = sleap.nn.config.MultiClassBottomUpConfig(
confmaps=sleap.nn.config.MultiInstanceConfmapsHeadConfig(
output_stride=2, offset_refinement=False
output_stride=4, offset_refinement=False
),
class_maps=sleap.nn.config.ClassMapsHeadConfig(output_stride=2),
class_maps=sleap.nn.config.ClassMapsHeadConfig(output_stride=4),
)
trainer = sleap.nn.training.BottomUpMultiClassModelTrainer.from_config(
cfg, training_labels=labels
Expand All @@ -266,8 +266,8 @@ def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg):

assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead"
assert trainer.keras_model.output_names[1] == "ClassMapsHead"
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 256, 256, 2)
assert tuple(trainer.keras_model.outputs[1].shape) == (None, 256, 256, 2)
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 64, 64, 2)
assert tuple(trainer.keras_model.outputs[1].shape) == (None, 64, 64, 2)


def test_train_topdown_multiclass(min_tracks_2node_labels, cfg):
Expand Down

0 comments on commit 4c0b7ca

Please sign in to comment.