Skip to content

Commit

Permalink
self._history_arrays and its getter-function allow for fasting animat…
Browse files Browse the repository at this point in the history
…ions which repeatedly converting lists to arrays
  • Loading branch information
TomGeorge1234 committed Feb 19, 2024
1 parent abc66be commit 4af646d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 31 deletions.
41 changes: 20 additions & 21 deletions ratinabox/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Agent:
• initialise_position_and_velocity()
• get_history_slice()
• get_all_default_params()
cache_history_as_arrays()
get_history_arrays()
The default params for this agent are:
default_params = {
Expand Down Expand Up @@ -115,7 +115,9 @@ def __init__(self, Environment, params={}):
self.history["vel"] = []
self.history["rot_vel"] = []
self.history["head_direction"] = []
self.history_array_cache = {"last_cache_time":None} # this is used to cache the history data as an arrays for faster plotting/animating

self._last_history_array_cache_time = None
self._history_arrays = {} # this is used to cache the history data as an arrays for faster plotting/animating

self.Neurons = [] # each new Neurons class belonging to this Agent will append itself to this list

Expand Down Expand Up @@ -711,11 +713,10 @@ def plot_trajectory(
#get times and trjectory from history data (normal)
t_end = t_end or self_.history["t"][-1]
slice = self_.get_history_slice(t_start=t_start, t_end=t_end, framerate=framerate)
if (self_.history_array_cache["last_cache_time"] != self.t):
self_.cache_history_as_arrays()
time = self_.history_array_cache["t"][slice]
trajectory = self_.history_array_cache["pos"][slice]
head_direction = self_.history_array_cache["head_direction"][slice]
history_data = self.get_history_arrays() # gets history dataframe as dictionary of arrays (only recomputing arrays from lists if necessary)
time = history_data["t"][slice]
trajectory = history_data["pos"][slice]
head_direction = history_data["head_direction"][slice]
else:
# data has been passed in manually
t_start, t_end = time[0], time[-1]
Expand Down Expand Up @@ -1068,9 +1069,7 @@ def get_history_slice(self, t_start=None, t_end=None, framerate=None):
• t_end: end time in seconds (default = self.history["t"][-1])
• framerate: frames per second (default = None --> step=0 so, just whatever the data frequency (1/Ag.dt) is)
"""
if self.history_array_cache["last_cache_time"] != self.t:
self.cache_history_as_arrays()
t = self.history_array_cache["t"]
t = self.get_history_arrays()["t"]
t_start = t_start or t[0]
startid = np.nanargmin(np.abs(t - (t_start)))
t_end = t_end or t[-1]
Expand All @@ -1081,14 +1080,14 @@ def get_history_slice(self, t_start=None, t_end=None, framerate=None):
skiprate = max(1, int((1 / framerate) / self.dt))

return slice(startid, endid, skiprate)

def cache_history_as_arrays(self):
"""Converts anything in the current history dictionary into a numpy array along with the time this cache was made. This is useful for speeding up animating functions which require slicing the history data but repeatedly converting to arrays is expensive. This is called automatically by the plot_trajectory function if the history data has not been cached yet.
TODO This should probably be improved, right now it will convert and cache _all_ history data, even if only some of it is needed."""
self.history_array_cache = {}
self.history_array_cache["last_cache_time"] = self.t
for key in self.history.keys():
try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function
self.history_array_cache[key] = np.array(self.history[key])
except: pass
return
def get_history_arrays(self):
"""Returns the history dataframe as a dictionary of numpy arrays (as opposed to lists). This getter-function only updates the self._history_arrays if the Agent/Neuron has updates since the last time it was called. This avoids expensive repeated conversion of lists to arrays during animations."""
if (self._last_history_array_cache_time != self.t):
self._history_arrays = {}
self._last_history_array_cache_time = self.t
for key in self.history.keys():
try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function
self._history_arrays[key] = np.array(self.history[key])
except: pass
return self._history_arrays
47 changes: 37 additions & 10 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def __init__(self, Agent, params={}):
self.history["firingrate"] = []
self.history["spikes"] = []

self._last_history_array_cache_time = None
self._history_arrays = {} # this dictionary is the same as self.history except the data is in arrays not lists BUT it should only be accessed via its getter-function `self.get_history_arrays()`. This is because the lists are only converted to arrays when they are accessed, not on every step, so as to save time.

self.colormap = "inferno" # default colormap for plotting ratemaps

if ratinabox.verbose is True:
Expand Down Expand Up @@ -347,6 +350,7 @@ def plot_rate_map(
"""
#Set kwargs (TODO make lots of params accessible here as kwargs)
spikes_color = kwargs.get("spikes_color", self.color) or "C1"
bin_size = kwargs.get("bin_size", 0.04) #only relevant if you are plotting by method="history"


# GET DATA
Expand All @@ -367,25 +371,25 @@ def plot_rate_map(
method = "history"

if method == "history" or spikes == True:
t = np.array(self.history["t"])
history_data = self.get_history_arrays() # converts lists to arrays (if this wasn't just done) and returns them in a dict same as self.history but with arrays not lists
t = history_data["t"]
# times to plot
if len(t) == 0:
print(
"Can't plot rate map by method='history' since there is no available data to plot. "
)
print("Can't plot rate map by method='history', nor plot spikes, since there is no available data to plot. ")
return
t_end = t_end or t[-1]
position_data_agent = kwargs.get("position_data_agent", self.Agent) # In rare cases you may like to plot this cells rate/spike data against the position of a diffferent Agent. This kwarg enables that.
position_agent_history_data = position_data_agent.get_history_arrays()
slice = position_data_agent.get_history_slice(t_start, t_end)
pos = np.array(position_data_agent.history["pos"])[slice]
pos = position_agent_history_data["pos"][slice]
t = t[slice]

if method == "history":
rate_timeseries = np.array(self.history["firingrate"])[slice].T
rate_timeseries = history_data["firingrate"][slice].T
if len(rate_timeseries) == 0:
print("No historical data with which to calculate ratemap.")
if spikes == True:
spike_data = np.array(self.history["spikes"])[slice].T
spike_data = history_data["spikes"][slice].T
if len(spike_data) == 0:
print("No historical data with which to plot spikes.")
if method == "ratemaps_provided":
Expand Down Expand Up @@ -468,21 +472,33 @@ def plot_rate_map(
)
im = ax_.imshow(rate_map, extent=ex, zorder=0, cmap=self.colormap)
elif method == "history":
bin_size = kwargs.get("bin_size", 0.05)
rate_timeseries_ = rate_timeseries[chosen_neurons[i], :]
rate_map = utils.bin_data_for_histogramming(
rate_map, zero_bins = utils.bin_data_for_histogramming(
data=pos,
extent=ex,
dx=0.05,
dx=bin_size,
weights=rate_timeseries_,
norm_by_bincount=True,
return_zero_bins=True,
)
#rather than just "nan-ing" the regions where no data was observed we'll plot ontop a "mask" overlay which blocks with a grey square regions where no data was observed. The benefit of this technique is it still allows us to use "bicubic" interpolation which is much smoother than the default "nearest" interpolation.
binary_colors = [(0,0,0,0),ratinabox.LIGHTGREY] #transparent if theres data, grey if there isn't
binary_cmap = matplotlib.colors.ListedColormap(binary_colors)
im = ax_.imshow(
rate_map,
extent=ex,
cmap=self.colormap,
interpolation="bicubic",
zorder=0,
)
no_data_mask = ax_.imshow(
zero_bins,
extent=ex,
cmap=binary_cmap,
interpolation="nearest",
zorder=0.001,
)
ims.append(im)
vmin, vmax = (
min(vmin, np.min(rate_map)),
Expand Down Expand Up @@ -749,7 +765,18 @@ def return_list_of_neurons(self, chosen_neurons="all"):
chosen_neurons = list(chosen_neurons.astype(int))

return chosen_neurons


def get_history_arrays(self):
"""Returns the history dataframe as a dictionary of numpy arrays (as opposed to lists). This getter-function only updates the self._history_arrays if the Agent/Neuron has updates since the last time it was called. This avoids expensive repeated conversion of lists to arrays during animations."""
if (self._last_history_array_cache_time != self.Agent.t):
self._history_arrays = {}
self._last_history_array_cache_time = self.Agent.t
for key in self.history.keys():
try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function
self._history_arrays[key] = np.array(self.history[key])
except: pass
return self._history_arrays


"""Specific subclasses """

Expand Down

0 comments on commit 4af646d

Please sign in to comment.