Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

logger fixes: close streams and persist during warm start #763

Merged
merged 7 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pysr/logger_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def write_hparams(self, logger: AnyValue, hparams: dict[str, Any]) -> None:
"""Write hyperparameters to the logger."""
pass # pragma: no cover

@abstractmethod
def close(self, logger: AnyValue) -> None:
    """Close the logger instance, flushing and releasing any underlying streams.

    Concrete logger specs must implement this so that a finished (non-warm-start)
    fit can close its log streams deterministically.
    """
    pass  # pragma: no cover


@dataclass
class TensorBoardLoggerSpec(AbstractLoggerSpec):
Expand Down Expand Up @@ -74,3 +79,7 @@ def write_hparams(self, logger: AnyValue, hparams: dict[str, Any]) -> None:
],
),
)

def close(self, logger: AnyValue) -> None:
    """Close the underlying Julia logger wrapped by this TensorBoard spec."""
    # Unwrap the SymbolicRegression logger to reach the base Julia logger,
    # then close it so the event stream is flushed to disk.
    underlying = jl.SymbolicRegression.get_logger(logger)
    jl.close(underlying)
12 changes: 11 additions & 1 deletion pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
stored as an array of uint8, produced by Julia's Serialization.serialize function.
julia_options_stream_ : ndarray
The serialized julia options, stored as an array of uint8,
logger_ : AnyValue | None
The logger instance used for this fit, if any.
expression_spec_ : AbstractExpressionSpec
The expression specification used for this fit. This is equal to
`self.expression_spec` if provided, or `ExpressionSpec()` otherwise.
Expand Down Expand Up @@ -765,6 +767,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
output_directory_: str
julia_state_stream_: NDArray[np.uint8] | None
julia_options_stream_: NDArray[np.uint8] | None
logger_: AnyValue | None
equation_file_contents_: list[pd.DataFrame] | None
show_pickle_warnings_: bool

Expand Down Expand Up @@ -1917,7 +1920,12 @@ def _run(
jl.seval(self.complexity_mapping) if self.complexity_mapping else None
)

logger = self.logger_spec.create_logger() if self.logger_spec else None
if hasattr(self, "logger_") and self.logger_ is not None and self.warm_start:
logger = self.logger_
else:
logger = self.logger_spec.create_logger() if self.logger_spec else None

self.logger_ = logger

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
Expand Down Expand Up @@ -2058,6 +2066,8 @@ def _run(
)
if self.logger_spec is not None:
self.logger_spec.write_hparams(logger, self.get_params())
if not self.warm_start:
self.logger_spec.close(logger)

self.julia_state_stream_ = jl_serialize(out)

Expand Down
69 changes: 40 additions & 29 deletions pysr/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,41 +647,52 @@ def test_tensorboard_logger(self):
self.skipTest("TensorBoard not installed. Skipping test.")

y = self.X[:, 0]
with tempfile.TemporaryDirectory() as tmpdir:
logger_spec = TensorBoardLoggerSpec(
log_dir=tmpdir, log_interval=2, overwrite=True
)
model = PySRRegressor(
**self.default_test_kwargs,
logger_spec=logger_spec,
early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1",
)
model.fit(self.X, y)
for warm_start in [False, True]:
with tempfile.TemporaryDirectory() as tmpdir:
logger_spec = TensorBoardLoggerSpec(
log_dir=tmpdir, log_interval=2, overwrite=True
)
model = PySRRegressor(
**self.default_test_kwargs,
logger_spec=logger_spec,
early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1",
warm_start=warm_start,
)
model.fit(self.X, y)
logger = model.logger_
# Should restart from same logger if warm_start is True
model.fit(self.X, y)
logger2 = model.logger_

if warm_start:
self.assertEqual(logger, logger2)
else:
self.assertNotEqual(logger, logger2)

# Verify log directory exists and contains TensorBoard files
log_dir = Path(tmpdir)
assert log_dir.exists()
files = list(log_dir.glob("events.out.tfevents.*"))
assert len(files) == 1
# Verify log directory exists and contains TensorBoard files
log_dir = Path(tmpdir)
assert log_dir.exists()
files = list(log_dir.glob("events.out.tfevents.*"))
assert len(files) == 1 if warm_start else 2

# Load and verify TensorBoard events
event_acc = EventAccumulator(str(log_dir))
event_acc.Reload()
# Load and verify TensorBoard events
event_acc = EventAccumulator(str(log_dir))
event_acc.Reload()

# Check that we have the expected scalar summaries
scalars = event_acc.Tags()["scalars"]
self.assertIn("search/data/summaries/pareto_volume", scalars)
self.assertIn("search/data/summaries/min_loss", scalars)
# Check that we have the expected scalar summaries
scalars = event_acc.Tags()["scalars"]
self.assertIn("search/data/summaries/pareto_volume", scalars)
self.assertIn("search/data/summaries/min_loss", scalars)

# Check that we have multiple events for each summary
pareto_events = event_acc.Scalars("search/data/summaries/pareto_volume")
min_loss_events = event_acc.Scalars("search/data/summaries/min_loss")
# Check that we have multiple events for each summary
pareto_events = event_acc.Scalars("search/data/summaries/pareto_volume")
min_loss_events = event_acc.Scalars("search/data/summaries/min_loss")

self.assertGreater(len(pareto_events), 0)
self.assertGreater(len(min_loss_events), 0)
self.assertGreater(len(pareto_events), 0)
self.assertGreater(len(min_loss_events), 0)

# Verify model still works as expected
self.assertLessEqual(model.get_best()["loss"], 1e-4)
# Verify model still works as expected
self.assertLessEqual(model.get_best()["loss"], 1e-4)


def manually_create_model(equations, feature_names=None):
Expand Down
Loading