Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ExportCallback #130

Merged
merged 6 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 46 additions & 18 deletions elastica/callback_functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
__doc__ = """ Module contains callback classes to save simulation data for rod-like objects """
__all__ = ["CallBackBaseClass", "MyCallBack", "ExportCallBack"]

import warnings
import os
import sys
import numpy as np
import logging

from collections import defaultdict


class CallBackBaseClass:
Expand Down Expand Up @@ -111,52 +113,59 @@ class ExportCallBack(CallBackBaseClass):
def __init__(
self,
step_skip: int,
path: str,
filename: str,
directory: str,
method: str,
initial_file_count: int = 0,
file_save_interval: int = 1e8,
):
"""

Parameters
----------
step_skip : int
Interval to collect simulation data into buffer.
The data will be collected at every `dt * step_skip`
interval.
path : str
Path to save the file. If directories are prepended,
they must exist. The filename depends on the method.
The path is not expected to include extension.
filename : str
Name of the file without extension. The extension will be
determined depend on the method.
directory : str
Path to save the file. If directory doesn't exist, it will
be created. Any existing files will be overwritten.
skim0119 marked this conversation as resolved.
Show resolved Hide resolved
method : str
Method name. Only the name in AVAILABLE_METHOD is
allowed.
initial_file_count : int
Initial file count index that will be appended
file_save_interval : int
Interval, in steps, to export/save collected buffer
as file. (default=1e8)
as file. (default = 1e8)
"""
# Assertions
MIN_STEP_SKIP = 100
if step_skip <= MIN_STEP_SKIP:
warnings.warn(f"We recommend step_skip at least {MIN_STEP_SKIP}")
logging.warning(
f"We recommend (step_skip={step_skip}) at least {MIN_STEP_SKIP}"
)
assert (
method in ExportCallBack.AVAILABLE_METHOD
), f"The exporting method ({method}) is not supported. Please use one of {ExportCallBack.AVAILABLE_METHOD}."
# TODO: Assertion is temporarily disabled. Should be fixed with #127
# assert os.path.exists(path), "The export path does not exist."

# Create directory
if os.path.exists(directory):
logging.warning(
f"Save file already exists in {directory}. Files will be overwritten"
)
skim0119 marked this conversation as resolved.
Show resolved Hide resolved
os.makedirs(directory, exist_ok=True)

# Argument Parameters
self.step_skip = step_skip
self.save_path = path
self.save_path = os.path.join(directory, filename) + "_{:02d}.{}"
self.method = method
self.file_count = initial_file_count
self.file_save_interval = file_save_interval

# Data collector
from collections import defaultdict

self.buffer = defaultdict(list)
self.buffer_size = 0

Expand All @@ -165,16 +174,19 @@ def __init__(
import pickle

self._pickle = pickle
self._ext = "pkl"
elif method == ExportCallBack.AVAILABLE_METHOD[1]:
from numpy import savez

self._savez = savez
self._ext = "npz"
elif method == ExportCallBack.AVAILABLE_METHOD[2]:
import tempfile
import pickle

self._tempfile = tempfile.NamedTemporaryFile(delete=False)
self._pickle = pickle
self._ext = "pkl"

def make_callback(self, system, time, current_step: int):
"""
Expand Down Expand Up @@ -206,7 +218,7 @@ def make_callback(self, system, time, current_step: int):
)

if (
self.buffer_size > ExportCallBack.FILE_SIZE_CUTOFF
self.buffer_size > self.FILE_SIZE_CUTOFF
or (current_step + 1) % self.file_save_interval == 0
):
self._dump()
Expand All @@ -216,7 +228,7 @@ def _dump(self, **kwargs):
Dump dictionary buffer (self.buffer) to a file and clear
the buffer.
"""
file_path = f"{self.save_path}_{self.file_count}.dat"
file_path = self.save_path.format(self.file_count, self._ext)
data = {k: np.array(v) for k, v in self.buffer.items()}
if self.method == ExportCallBack.AVAILABLE_METHOD[0]:
# pickle
Expand All @@ -234,9 +246,25 @@ def _dump(self, **kwargs):
self.buffer_size = 0
self.buffer.clear()

def __del__(self):
def get_last_saved_path(self) -> str:
"""
Save residual buffer on exit
Return last saved file path. If no file has been saved,
return None
"""
if self.file_count == 0:
return None
else:
return self.save_path.format(self.file_count - 1, self._ext)

def close(self):
"""
Save residual buffer
"""
if self.buffer_size:
self._dump()

def clear(self):
"""
Alias to `close`
"""
self.close()
Empty file added tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/test_boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# System imports
import numpy as np
from test_rod.test_rods import MockTestRod
from tests.test_rod.test_rods import MockTestRod
from elastica.boundary_conditions import (
ConstraintBase,
FreeBC,
Expand Down
Loading