Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sqlite Callback #1988

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion Wrappers/Python/cil/optimisation/utilities/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from abc import ABC, abstractmethod
from datetime import datetime
from functools import partialmethod
from os.path import join
from pathlib import Path
from sqlite3 import Connection
from typing import Callable, List, Union

import numpy as np
from tqdm.auto import tqdm as tqdm_auto
from tqdm.std import tqdm as tqdm_std
import numpy as np

from ..algorithms import Algorithm


class Callback(ABC):
Expand Down Expand Up @@ -187,3 +194,78 @@ def __call__(self, algorithm):
raise StopIteration


class SqliteCallback(Callback):
"""
Callback which evaluates user-defined functions at user-defined times during evaluation. Also saves calculation
state if you want.
Parameters
----------
db_address : Union[str, Path]
Location of the sqlite database we're going to write to.
output_folder : Union[str, Path]
Location of folder to which artifacts will be written.
report_progress : Callable[[Algorithm], bool]
Rules which tells us when to report calculation progress using the functions in to_store. Note that
this is in union with grab_soln's reports.
grab_soln : Callable[[Algorithm], bool]
Rule which tells us when to bother saving solution state. Put it a function which is always false
if you don't want it to do that ever.
to_store : List[Callable[[Algorithm], Union[str, int, float]]]
List of functions to get evaluated and stored.
types : List[str]
Return SQL types (TEXT, INT, etc.) of the to_store functions.
names : List[str]
Names of columns for entries to be recorded in database.
calculation_index : int
Index assigned to this calculation run.
table_name : str
Name of table we're going to put results in.
verbose : int
Verbosity for Callback.
"""
def __init__(self, db_address: Union[str, Path], output_folder: Union[str, Path],
report_progress: Callable[[Algorithm], bool], grab_soln: Callable[[Algorithm], bool],
to_store: List[Callable[[Algorithm], Union[str, int, float]]], types: List[str],
names: List[str], calculation_index: int, table_name: str = "cil_calc_results",
verbose: int = 1):
self._db_address = Path(db_address).resolve()
self._output_folder = Path(output_folder).resolve()
self._take_snapshot = grab_soln
self._to_store = to_store
self._types = types
self._names = names
self._calculation_label = calculation_index
self._table_name = table_name
self._report_progress = report_progress
self._insert_heading_command = f"(time_stamp, calc_label, iteration, {', '.join(self._names)}, artifact_location)"
super().__init__(verbose=verbose)
variable_length_column_definition = ", ".join([f"{self._names[output_index]} " +
f"{self._types[output_index]}"
for output_index in range(len(self._to_store))])
con = Connection(self._db_address)
cur = con.cursor()
cur.execute(f"CREATE TABLE IF NOT EXISTS {self._table_name}(id INTEGER PRIMARY KEY, time_stamp TEXT, " +
f"calc_label INT, iteration INT, {variable_length_column_definition}, artifact_location TEXT);")
con.commit()
con.close()

def __call__(self, algorithm: Algorithm):
artifact_location = "NULL"
take_snapshot = self._take_snapshot(algorithm)
report_progress = self._report_progress(algorithm)
iteration = algorithm.iteration
if take_snapshot:
artifact_location = join(self._output_folder, f"{self._calculation_label}_{iteration}.npy")
np.save(artifact_location, algorithm.solution.array)
artifact_location = f"\'{artifact_location}\'"
if take_snapshot or report_progress:
storables = [str(f(algorithm)) for f in self._to_store]
current_time = str(datetime.now())
runstring = (f"INSERT INTO {self._table_name}{self._insert_heading_command} " +
f"VALUES(\'{current_time}\', {self._calculation_label}, {iteration}, " +
f"{', '.join(storables)}, {artifact_location});")
con = Connection(self._db_address)
cur = con.cursor()
cur.execute(runstring)
con.commit()
con.close()
Loading