diff --git a/Wrappers/Python/cil/optimisation/utilities/callbacks.py b/Wrappers/Python/cil/optimisation/utilities/callbacks.py index a5c136695..84ea62619 100644 --- a/Wrappers/Python/cil/optimisation/utilities/callbacks.py +++ b/Wrappers/Python/cil/optimisation/utilities/callbacks.py @@ -1,9 +1,16 @@ from abc import ABC, abstractmethod +from datetime import datetime from functools import partialmethod +from os.path import join +from pathlib import Path +from sqlite3 import Connection +from typing import Callable, List, Union +import numpy as np from tqdm.auto import tqdm as tqdm_auto from tqdm.std import tqdm as tqdm_std -import numpy as np + +from ..algorithms import Algorithm class Callback(ABC): @@ -187,3 +194,78 @@ def __call__(self, algorithm): raise StopIteration +class SqliteCallback(Callback): + """ + Callback which evaluates user-defined functions at user-defined times during evaluation. Also saves calculation + state if you want. + Parameters + ---------- + db_address : Union[str, Path] + Location of the sqlite database we're going to write to. + output_folder : Union[str, Path] + Location of folder to which artifacts will be written. + report_progress : Callable[[Algorithm], bool] + Rules which tells us when to report calculation progress using the functions in to_store. Note that + this is in union with grab_soln's reports. + grab_soln : Callable[[Algorithm], bool] + Rule which tells us when to bother saving solution state. Put it a function which is always false + if you don't want it to do that ever. + to_store : List[Callable[[Algorithm], Union[str, int, float]]] + List of functions to get evaluated and stored. + types : List[str] + Return SQL types (TEXT, INT, etc.) of the to_store functions. + names : List[str] + Names of columns for entries to be recorded in database. + calculation_index : int + Index assigned to this calculation run. + table_name : str + Name of table we're going to put results in. + verbose : int + Verbosity for Callback. + """ + def __init__(self, db_address: Union[str, Path], output_folder: Union[str, Path], + report_progress: Callable[[Algorithm], bool], grab_soln: Callable[[Algorithm], bool], + to_store: List[Callable[[Algorithm], Union[str, int, float]]], types: List[str], + names: List[str], calculation_index: int, table_name: str = "cil_calc_results", + verbose: int = 1): + self._db_address = Path(db_address).resolve() + self._output_folder = Path(output_folder).resolve() + self._take_snapshot = grab_soln + self._to_store = to_store + self._types = types + self._names = names + self._calculation_label = calculation_index + self._table_name = table_name + self._report_progress = report_progress + self._insert_heading_command = f"(time_stamp, calc_label, iteration, {', '.join(self._names)}, artifact_location)" + super().__init__(verbose=verbose) + variable_length_column_definition = ", ".join([f"{self._names[output_index]} " + + f"{self._types[output_index]}" + for output_index in range(len(self._to_store))]) + con = Connection(self._db_address) + cur = con.cursor() + cur.execute(f"CREATE TABLE IF NOT EXISTS {self._table_name}(id INTEGER PRIMARY KEY, time_stamp TEXT, " + + f"calc_label INT, iteration INT, {variable_length_column_definition}, artifact_location TEXT);") + con.commit() + con.close() + + def __call__(self, algorithm: Algorithm): + artifact_location = "NULL" + take_snapshot = self._take_snapshot(algorithm) + report_progress = self._report_progress(algorithm) + iteration = algorithm.iteration + if take_snapshot: + artifact_location = join(self._output_folder, f"{self._calculation_label}_{iteration}.npy") + np.save(artifact_location, algorithm.solution.array) + artifact_location = f"\'{artifact_location}\'" + if take_snapshot or report_progress: + storables = [str(f(algorithm)) for f in self._to_store] + current_time = str(datetime.now()) + runstring = (f"INSERT INTO {self._table_name}{self._insert_heading_command} " + + f"VALUES(\'{current_time}\', {self._calculation_label}, {iteration}, " + + f"{', '.join(storables)}, {artifact_location});") + con = Connection(self._db_address) + cur = con.cursor() + cur.execute(runstring) + con.commit() + con.close() \ No newline at end of file