From 4af646dc20a574577f79de8d12be4b8dfba0c74a Mon Sep 17 00:00:00 2001 From: Tom George Date: Mon, 19 Feb 2024 19:08:56 +0000 Subject: [PATCH] self._history_arrays and its getter-function allow for fasting animations which repeatedly converting lists to arrays --- ratinabox/Agent.py | 41 +++++++++++++++++++------------------- ratinabox/Neurons.py | 47 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/ratinabox/Agent.py b/ratinabox/Agent.py index 78aafd5..19b43c7 100644 --- a/ratinabox/Agent.py +++ b/ratinabox/Agent.py @@ -45,7 +45,7 @@ class Agent: • initialise_position_and_velocity() • get_history_slice() • get_all_default_params() - • cache_history_as_arrays() + • get_history_arrays() The default params for this agent are: default_params = { @@ -115,7 +115,9 @@ def __init__(self, Environment, params={}): self.history["vel"] = [] self.history["rot_vel"] = [] self.history["head_direction"] = [] - self.history_array_cache = {"last_cache_time":None} # this is used to cache the history data as an arrays for faster plotting/animating + + self._last_history_array_cache_time = None + self._history_arrays = {} # this is used to cache the history data as an arrays for faster plotting/animating self.Neurons = [] # each new Neurons class belonging to this Agent will append itself to this list @@ -711,11 +713,10 @@ def plot_trajectory( #get times and trjectory from history data (normal) t_end = t_end or self_.history["t"][-1] slice = self_.get_history_slice(t_start=t_start, t_end=t_end, framerate=framerate) - if (self_.history_array_cache["last_cache_time"] != self.t): - self_.cache_history_as_arrays() - time = self_.history_array_cache["t"][slice] - trajectory = self_.history_array_cache["pos"][slice] - head_direction = self_.history_array_cache["head_direction"][slice] + history_data = self.get_history_arrays() # gets history dataframe as dictionary of arrays (only recomputing arrays from lists if necessary) + time = history_data["t"][slice] + trajectory = history_data["pos"][slice] + head_direction = history_data["head_direction"][slice] else: # data has been passed in manually t_start, t_end = time[0], time[-1] @@ -1068,9 +1069,7 @@ def get_history_slice(self, t_start=None, t_end=None, framerate=None): • t_end: end time in seconds (default = self.history["t"][-1]) • framerate: frames per second (default = None --> step=0 so, just whatever the data frequency (1/Ag.dt) is) """ - if self.history_array_cache["last_cache_time"] != self.t: - self.cache_history_as_arrays() - t = self.history_array_cache["t"] + t = self.get_history_arrays()["t"] t_start = t_start or t[0] startid = np.nanargmin(np.abs(t - (t_start))) t_end = t_end or t[-1] @@ -1081,14 +1080,14 @@ def get_history_slice(self, t_start=None, t_end=None, framerate=None): skiprate = max(1, int((1 / framerate) / self.dt)) return slice(startid, endid, skiprate) - - def cache_history_as_arrays(self): - """Converts anything in the current history dictionary into a numpy array along with the time this cache was made. This is useful for speeding up animating functions which require slicing the history data but repeatedly converting to arrays is expensive. This is called automatically by the plot_trajectory function if the history data has not been cached yet. - TODO This should probably be improved, right now it will convert and cache _all_ history data, even if only some of it is needed.""" - self.history_array_cache = {} - self.history_array_cache["last_cache_time"] = self.t - for key in self.history.keys(): - try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function - self.history_array_cache[key] = np.array(self.history[key]) - except: pass - return \ No newline at end of file + + def get_history_arrays(self): + """Returns the history dataframe as a dictionary of numpy arrays (as opposed to lists). This getter-function only updates the self._history_arrays if the Agent/Neuron has updates since the last time it was called. This avoids expensive repeated conversion of lists to arrays during animations.""" + if (self._last_history_array_cache_time != self.t): + self._history_arrays = {} + self._last_history_array_cache_time = self.t + for key in self.history.keys(): + try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function + self._history_arrays[key] = np.array(self.history[key]) + except: pass + return self._history_arrays \ No newline at end of file diff --git a/ratinabox/Neurons.py b/ratinabox/Neurons.py index 2de99a6..8354a50 100644 --- a/ratinabox/Neurons.py +++ b/ratinabox/Neurons.py @@ -124,6 +124,9 @@ def __init__(self, Agent, params={}): self.history["firingrate"] = [] self.history["spikes"] = [] + self._last_history_array_cache_time = None + self._history_arrays = {} # this dictionary is the same as self.history except the data is in arrays not lists BUT it should only be accessed via its getter-function `self.get_history_arrays()`. This is because the lists are only converted to arrays when they are accessed, not on every step, so as to save time. + self.colormap = "inferno" # default colormap for plotting ratemaps if ratinabox.verbose is True: @@ -347,6 +350,7 @@ def plot_rate_map( """ #Set kwargs (TODO make lots of params accessible here as kwargs) spikes_color = kwargs.get("spikes_color", self.color) or "C1" + bin_size = kwargs.get("bin_size", 0.04) #only relevant if you are plotting by method="history" # GET DATA @@ -367,25 +371,25 @@ def plot_rate_map( method = "history" if method == "history" or spikes == True: - t = np.array(self.history["t"]) + history_data = self.get_history_arrays() # converts lists to arrays (if this wasn't just done) and returns them in a dict same as self.history but with arrays not lists + t = history_data["t"] # times to plot if len(t) == 0: - print( - "Can't plot rate map by method='history' since there is no available data to plot. " - ) + print("Can't plot rate map by method='history', nor plot spikes, since there is no available data to plot. ") return t_end = t_end or t[-1] position_data_agent = kwargs.get("position_data_agent", self.Agent) # In rare cases you may like to plot this cells rate/spike data against the position of a diffferent Agent. This kwarg enables that. + position_agent_history_data = position_data_agent.get_history_arrays() slice = position_data_agent.get_history_slice(t_start, t_end) - pos = np.array(position_data_agent.history["pos"])[slice] + pos = position_agent_history_data["pos"][slice] t = t[slice] if method == "history": - rate_timeseries = np.array(self.history["firingrate"])[slice].T + rate_timeseries = history_data["firingrate"][slice].T if len(rate_timeseries) == 0: print("No historical data with which to calculate ratemap.") if spikes == True: - spike_data = np.array(self.history["spikes"])[slice].T + spike_data = history_data["spikes"][slice].T if len(spike_data) == 0: print("No historical data with which to plot spikes.") if method == "ratemaps_provided": @@ -468,14 +472,19 @@ def plot_rate_map( ) im = ax_.imshow(rate_map, extent=ex, zorder=0, cmap=self.colormap) elif method == "history": + bin_size = kwargs.get("bin_size", 0.05) rate_timeseries_ = rate_timeseries[chosen_neurons[i], :] - rate_map = utils.bin_data_for_histogramming( + rate_map, zero_bins = utils.bin_data_for_histogramming( data=pos, extent=ex, - dx=0.05, + dx=bin_size, weights=rate_timeseries_, norm_by_bincount=True, + return_zero_bins=True, ) + #rather than just "nan-ing" the regions where no data was observed we'll plot ontop a "mask" overlay which blocks with a grey square regions where no data was observed. The benefit of this technique is it still allows us to use "bicubic" interpolation which is much smoother than the default "nearest" interpolation. + binary_colors = [(0,0,0,0),ratinabox.LIGHTGREY] #transparent if theres data, grey if there isn't + binary_cmap = matplotlib.colors.ListedColormap(binary_colors) im = ax_.imshow( rate_map, extent=ex, @@ -483,6 +492,13 @@ def plot_rate_map( interpolation="bicubic", zorder=0, ) + no_data_mask = ax_.imshow( + zero_bins, + extent=ex, + cmap=binary_cmap, + interpolation="nearest", + zorder=0.001, + ) ims.append(im) vmin, vmax = ( min(vmin, np.min(rate_map)), @@ -749,7 +765,18 @@ def return_list_of_neurons(self, chosen_neurons="all"): chosen_neurons = list(chosen_neurons.astype(int)) return chosen_neurons - + + def get_history_arrays(self): + """Returns the history dataframe as a dictionary of numpy arrays (as opposed to lists). This getter-function only updates the self._history_arrays if the Agent/Neuron has updates since the last time it was called. This avoids expensive repeated conversion of lists to arrays during animations.""" + if (self._last_history_array_cache_time != self.Agent.t): + self._history_arrays = {} + self._last_history_array_cache_time = self.Agent.t + for key in self.history.keys(): + try: #will skip if for any reason this key cannot be converted to an array, so you can still save random stuff into the history dict without breaking this function + self._history_arrays[key] = np.array(self.history[key]) + except: pass + return self._history_arrays + """Specific subclasses """