From 18966f78fae9a84406a75c7d68181b0c8af42daa Mon Sep 17 00:00:00 2001 From: BrotherHa Date: Wed, 4 Dec 2024 10:03:40 +0100 Subject: [PATCH 1/6] Update logger_specs.py Added close method to logger_specs.py --- pysr/logger_specs.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pysr/logger_specs.py b/pysr/logger_specs.py index 354a21de..c8f8c23b 100644 --- a/pysr/logger_specs.py +++ b/pysr/logger_specs.py @@ -19,6 +19,11 @@ def write_hparams(self, logger: AnyValue, hparams: dict[str, Any]) -> None: """Write hyperparameters to the logger.""" pass # pragma: no cover + @abstractmethod + def close(self, logger: AnyValue) -> None: + """Close the logger instance.""" + pass # pragma: no cover + @dataclass class TensorBoardLoggerSpec(AbstractLoggerSpec): @@ -74,3 +79,18 @@ def write_hparams(self, logger: AnyValue, hparams: dict[str, Any]) -> None: ], ), ) + + def close(self, logger: AnyValue) -> None: + base_logger = jl.SymbolicRegression.get_logger(logger) + close_logger = jl.seval( + """ + function close_files!(lg::TensorBoardLogger.TBLogger) + # close open streams + for k=keys(lg.all_files) + close(lg.all_files[k]) + end + end + """ + ) + close_logger(base_logger) + From c1ded8c977ed9a9a706f38c64770e4a11d09b4d9 Mon Sep 17 00:00:00 2001 From: BrotherHa Date: Wed, 4 Dec 2024 10:05:13 +0100 Subject: [PATCH 2/6] Update sr.py Close logger at the end of run. --- pysr/sr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pysr/sr.py b/pysr/sr.py index aa9492e2..e7d8156e 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -2058,6 +2058,7 @@ def _run( ) if self.logger_spec is not None: self.logger_spec.write_hparams(logger, self.get_params()) + self.logger_spec.close(logger) self.julia_state_stream_ = jl_serialize(out) From cad563660fff87e80f74fcbd150ef63674ed766f Mon Sep 17 00:00:00 2001 From: BrotherHa Date: Wed, 4 Dec 2024 10:09:35 +0100 Subject: [PATCH 3/6] Update logger_specs.py --- pysr/logger_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysr/logger_specs.py b/pysr/logger_specs.py index c8f8c23b..5a8a328b 100644 --- a/pysr/logger_specs.py +++ b/pysr/logger_specs.py @@ -84,7 +84,7 @@ def close(self, logger: AnyValue) -> None: base_logger = jl.SymbolicRegression.get_logger(logger) close_logger = jl.seval( """ - function close_files!(lg::TensorBoardLogger.TBLogger) + function close_logger(lg::TensorBoardLogger.TBLogger) # close open streams for k=keys(lg.all_files) close(lg.all_files[k]) From 6f557d1f9ac0143972e4f939bee115a61d7c1a6a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 09:21:38 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pysr/logger_specs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pysr/logger_specs.py b/pysr/logger_specs.py index 5a8a328b..c3e4e014 100644 --- a/pysr/logger_specs.py +++ b/pysr/logger_specs.py @@ -93,4 +93,3 @@ def close(self, logger: AnyValue) -> None: """ ) close_logger(base_logger) - From 8a851ed155383764dddc28a975ef459e28ed9da4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 6 Dec 2024 20:34:53 +0000 Subject: [PATCH 5/6] refactor: cleaner logging --- pysr/logger_specs.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/pysr/logger_specs.py b/pysr/logger_specs.py index c3e4e014..f4c0a3df 100644 --- a/pysr/logger_specs.py +++ b/pysr/logger_specs.py @@ -82,14 +82,4 @@ def write_hparams(self, logger: AnyValue, hparams: dict[str, Any]) -> None: def close(self, logger: AnyValue) -> None: base_logger = jl.SymbolicRegression.get_logger(logger) - close_logger = jl.seval( - """ - function close_logger(lg::TensorBoardLogger.TBLogger) - # close open streams - for k=keys(lg.all_files) - close(lg.all_files[k]) - end - end - """ - ) - close_logger(base_logger) + jl.close(base_logger) From 1e86ec5351857330ad272c9c23bf60c650175811 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 6 Dec 2024 20:54:39 +0000 Subject: [PATCH 6/6] feat: persist logger via warm start --- pysr/sr.py | 13 ++++++-- pysr/test/test_main.py | 69 ++++++++++++++++++++++++------------------ 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/pysr/sr.py b/pysr/sr.py index e7d8156e..3cae1735 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -703,6 +703,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): stored as an array of uint8, produced by Julia's Serialization.serialize function. julia_options_stream_ : ndarray The serialized julia options, stored as an array of uint8, + logger_ : AnyValue | None + The logger instance used for this fit, if any. expression_spec_ : AbstractExpressionSpec The expression specification used for this fit. This is equal to `self.expression_spec` if provided, or `ExpressionSpec()` otherwise. @@ -765,6 +767,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): output_directory_: str julia_state_stream_: NDArray[np.uint8] | None julia_options_stream_: NDArray[np.uint8] | None + logger_: AnyValue | None equation_file_contents_: list[pd.DataFrame] | None show_pickle_warnings_: bool @@ -1917,7 +1920,12 @@ def _run( jl.seval(self.complexity_mapping) if self.complexity_mapping else None ) - logger = self.logger_spec.create_logger() if self.logger_spec else None + if hasattr(self, "logger_") and self.logger_ is not None and self.warm_start: + logger = self.logger_ + else: + logger = self.logger_spec.create_logger() if self.logger_spec else None + + self.logger_ = logger # Call to Julia backend. # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl @@ -2058,7 +2066,8 @@ def _run( ) if self.logger_spec is not None: self.logger_spec.write_hparams(logger, self.get_params()) - self.logger_spec.close(logger) + if not self.warm_start: + self.logger_spec.close(logger) self.julia_state_stream_ = jl_serialize(out) diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py index 3264383a..52772042 100644 --- a/pysr/test/test_main.py +++ b/pysr/test/test_main.py @@ -647,41 +647,52 @@ def test_tensorboard_logger(self): self.skipTest("TensorBoard not installed. Skipping test.") y = self.X[:, 0] - with tempfile.TemporaryDirectory() as tmpdir: - logger_spec = TensorBoardLoggerSpec( - log_dir=tmpdir, log_interval=2, overwrite=True - ) - model = PySRRegressor( - **self.default_test_kwargs, - logger_spec=logger_spec, - early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1", - ) - model.fit(self.X, y) + for warm_start in [False, True]: + with tempfile.TemporaryDirectory() as tmpdir: + logger_spec = TensorBoardLoggerSpec( + log_dir=tmpdir, log_interval=2, overwrite=True + ) + model = PySRRegressor( + **self.default_test_kwargs, + logger_spec=logger_spec, + early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1", + warm_start=warm_start, + ) + model.fit(self.X, y) + logger = model.logger_ + # Should restart from same logger if warm_start is True + model.fit(self.X, y) + logger2 = model.logger_ + + if warm_start: + self.assertEqual(logger, logger2) + else: + self.assertNotEqual(logger, logger2) - # Verify log directory exists and contains TensorBoard files - log_dir = Path(tmpdir) - assert log_dir.exists() - files = list(log_dir.glob("events.out.tfevents.*")) - assert len(files) == 1 + # Verify log directory exists and contains TensorBoard files + log_dir = Path(tmpdir) + assert log_dir.exists() + files = list(log_dir.glob("events.out.tfevents.*")) + assert len(files) == 1 if warm_start else 2 - # Load and verify TensorBoard events - event_acc = EventAccumulator(str(log_dir)) - event_acc.Reload() + # Load and verify TensorBoard events + event_acc = EventAccumulator(str(log_dir)) + event_acc.Reload() - # Check that we have the expected scalar summaries - scalars = event_acc.Tags()["scalars"] - self.assertIn("search/data/summaries/pareto_volume", scalars) - self.assertIn("search/data/summaries/min_loss", scalars) + # Check that we have the expected scalar summaries + scalars = event_acc.Tags()["scalars"] + self.assertIn("search/data/summaries/pareto_volume", scalars) + self.assertIn("search/data/summaries/min_loss", scalars) - # Check that we have multiple events for each summary - pareto_events = event_acc.Scalars("search/data/summaries/pareto_volume") - min_loss_events = event_acc.Scalars("search/data/summaries/min_loss") + # Check that we have multiple events for each summary + pareto_events = event_acc.Scalars("search/data/summaries/pareto_volume") + min_loss_events = event_acc.Scalars("search/data/summaries/min_loss") - self.assertGreater(len(pareto_events), 0) - self.assertGreater(len(min_loss_events), 0) + self.assertGreater(len(pareto_events), 0) + self.assertGreater(len(min_loss_events), 0) - # Verify model still works as expected - self.assertLessEqual(model.get_best()["loss"], 1e-4) + # Verify model still works as expected + self.assertLessEqual(model.get_best()["loss"], 1e-4) def manually_create_model(equations, feature_names=None):