From 690da602416658756cc29776ecd0747699e9cc2d Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 16 Apr 2024 11:47:17 -0400 Subject: [PATCH 1/2] lazy run --- dacapo/experiments/run.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index a0089e712..e88ac4200 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -93,6 +93,7 @@ def __init__(self, run_config, load_starter_model: bool = True): """ self.name = run_config.name + self._config = run_config self.train_until = run_config.num_iterations self.validation_interval = run_config.validation_interval @@ -106,7 +107,10 @@ def __init__(self, run_config, load_starter_model: bool = True): self.task = task_type(run_config.task_config) self.architecture = architecture_type(run_config.architecture_config) self.trainer = trainer_type(run_config.trainer_config) - self.datasplit = datasplit_type(run_config.datasplit_config) + + # lazy load datasplit + self._datasplit = None + # combined pieces self.model = self.task.create_model(self.architecture) @@ -114,9 +118,7 @@ def __init__(self, run_config, load_starter_model: bool = True): # tracking self.training_stats = TrainingStats() - self.validation_scores = ValidationScores( - self.task.parameters, self.datasplit.validate, self.task.evaluation_scores - ) + self._validation_scores = None if not load_starter_model: self.start = None @@ -142,6 +144,22 @@ def __init__(self, run_config, load_starter_model: bool = True): self.start.initialize_weights(self.model, new_head=new_head) + @property + def datasplit(self): + if self._datasplit is None: + self._datasplit = self._config.datasplit_config.datasplit_type( + self._config.datasplit_config + ) + return self._datasplit + + @property + def validation_scores(self): + if self._validation_scores is None: + self._validation_scores = ValidationScores( + self.task.parameters, self.datasplit.validate, 
self.task.evaluation_scores + ) + return self._validation_scores + @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ From 0b443d4cb8f0b66ee161afeeee72e27cfdb2daa7 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Tue, 16 Apr 2024 11:48:00 -0400 Subject: [PATCH 2/2] black format --- dacapo/experiments/datasplits/datasets/arrays/zarr_array.py | 5 ++--- dacapo/experiments/datasplits/datasplit_generator.py | 4 ++-- dacapo/experiments/run.py | 5 +++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index c976f7d96..f61bf0cd4 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -108,7 +108,6 @@ def __init__(self, array_config): self._attributes = self.data.attrs self._axes = array_config._axes self.snap_to_grid = array_config.snap_to_grid - def __str__(self): """ @@ -144,7 +143,7 @@ def __repr__(self): """ return f"ZarrArray({self.file_name}, {self.dataset})" - + @property def mode(self): if not hasattr(self, "_mode"): @@ -421,7 +420,7 @@ def create_from_array_identifier( num_channels, voxel_size, dtype, - mode = "a", + mode="a", write_size=None, name=None, overwrite=False, diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index 54df330b6..a69fd633f 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -125,7 +125,7 @@ def get_right_resolution_array_config( file_name=container, dataset=str(current_dataset_path), snap_to_grid=target_resolution, - mode = "r" + mode="r", ) zarr_array = ZarrArray(zarr_config) while ( @@ -138,7 +138,7 @@ def get_right_resolution_array_config( file_name=container, dataset=str(Path(dataset, f"s{level}")), snap_to_grid=target_resolution, - mode = "r" + 
mode="r", ) zarr_array = ZarrArray(zarr_config) diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index e88ac4200..e7afb03ad 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -110,7 +110,6 @@ def __init__(self, run_config, load_starter_model: bool = True): # lazy load datasplit self._datasplit = None - # combined pieces self.model = self.task.create_model(self.architecture) @@ -156,7 +155,9 @@ def datasplit(self): def validation_scores(self): if self._validation_scores is None: self._validation_scores = ValidationScores( - self.task.parameters, self.datasplit.validate, self.task.evaluation_scores + self.task.parameters, + self.datasplit.validate, + self.task.evaluation_scores, ) return self._validation_scores