diff --git a/README.md b/README.md index c0a1f09a..14eecadc 100644 --- a/README.md +++ b/README.md @@ -315,7 +315,7 @@ fig, ax = PCs.plot_rate_timeseries() We have a dedicated [contribs](./ratinabox/contribs/) directory where you can safely add awesome scripts and new `Neurons` classes etc. -*Questions?* Can't figure out how something works. If you can't fgure it out from the readme, demos, code comments etc. then ask me! Open an issue, I'm usually pretty quick to respond. +*Questions?* If you can't figure out how something works from the readme, demos, code comments etc. then ask! Open an issue, I'm usually pretty quick to respond. Here's our [official theme tune](https://www.youtube.com/watch?v=dY-FOI-9SOE) by the way. @@ -338,5 +338,4 @@ Formatted: ``` Tom M George, William de Cothi, Claudia Clopath, Kimberly Stachenfeld, Caswell Barry. "RatInABox: A toolkit for modelling locomotion and neuronal activity in continuous environments" (2022). ``` -The research paper corresponding to the above citation can be found [here](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v3). - +The research paper corresponding to the above citation can be found [here](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v3). \ No newline at end of file diff --git a/ratinabox/Agent.py b/ratinabox/Agent.py index 9f99064c..a5bc6691 100644 --- a/ratinabox/Agent.py +++ b/ratinabox/Agent.py @@ -3,6 +3,8 @@ import numpy as np import matplotlib from matplotlib import pyplot as plt +plt.rcParams["animation.html"] = "jshtml" #for animations + from ratinabox import utils @@ -65,6 +67,7 @@ def __init__(self, Environment, params={}): "thigmotaxis": 0.5, # tendency for agents to linger near walls [0 = not at all, 1 = max] "wall_repel_distance": 0.1, "walls_repel": True, # whether or not the walls repel + } self.Environment = Environment default_params.update(params) @@ -370,6 +373,7 @@ def update(self, dt=None, drift_velocity=None, drift_to_random_strength_ratio=1) save_velocity ) + # TO DO: make this a function call # write to history self.history["t"].append(self.t) self.history["pos"].append(list(self.pos)) @@ -583,8 +587,10 @@ def animate_trajectory( self, t_start=None, t_end=None, fps=15, speed_up=1, **kwargs ): """Returns an animation (anim) of the trajectory, 25fps. - Should be saved using comand like - anim.save("./where_to_save/animations.gif",dpi=300) + Should be saved using command like + >>> anim.save("./where_to_save/animations.gif",dpi=300) + To display in jupyter notebook, call it: + >>> anim Args: t_start: Agent time at which to start animation @@ -602,7 +608,7 @@ def animate_trajectory( if t_end == None: t_end = self.history["t"][-1] - def animate_(i, fig, ax, t_start, t_max, speed_up, dt): + def animate_(i, fig, ax, t_start, t_max, speed_up, dt, additional_plot_func, **kwargs): t_end = t_start + (i + 1) * speed_up * dt ax.clear() if self.Environment.dimensionality == "2D": @@ -616,6 +622,12 @@ def animate_(i, fig, ax, t_start, t_max, speed_up, dt): xlim=t_max / 60, **kwargs, ) + if additional_plot_func is not None: + fig, ax = additional_plot_func(fig=fig, + ax=ax, + t=t_end, #the current time + **kwargs) + plt.close() return @@ -624,6 +636,11 @@ def animate_(i, fig, ax, t_start, t_max, speed_up, dt): ) from matplotlib import animation + # if passed, after plotting the trajectory fig, ax are passed through this function. + # use it to add other things ontop of the animation + additional_plot_func = None + if 'additional_plot_func' in kwargs.keys(): + additional_plot_func = kwargs['additional_plot_func'] anim = matplotlib.animation.FuncAnimation( fig, @@ -631,7 +648,7 @@ def animate_(i, fig, ax, t_start, t_max, speed_up, dt): interval=1000 * dt, frames=int((t_end - t_start) / (dt * speed_up)), blit=False, - fargs=(fig, ax, t_start, t_end, speed_up, dt), + fargs=(fig, ax, t_start, t_end, speed_up, dt, additional_plot_func), ) return anim diff --git a/ratinabox/Environment.py b/ratinabox/Environment.py index 2c94e2b4..8ce4a388 100644 --- a/ratinabox/Environment.py +++ b/ratinabox/Environment.py @@ -301,7 +301,7 @@ def sample_positions(self, n=10, method="uniform_jitter"): positions = np.vstack((positions, positions_remaining)) if (self.is_rectangular) or (self.has_holes is True): - # in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute for this by randomly resampling these oints until all fall within the env. + # in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute force this by randomly resampling these points until all fall within the env. for (i, pos) in enumerate(positions): if self.check_if_position_is_in_environment(pos) == False: pos = self.sample_positions(n=1, method="random").reshape( diff --git a/ratinabox/Neurons.py b/ratinabox/Neurons.py index 8dbb22c2..d1d59cc8 100644 --- a/ratinabox/Neurons.py +++ b/ratinabox/Neurons.py @@ -466,8 +466,10 @@ def animate_rate_timeseries( **kwargs, ): """Returns an animation (anim) of the firing rates, 25fps. - Should be saved using comand like - anim.save("./where_to_save/animations.gif",dpi=300) + Should be saved using command like: + >>> anim.save("./where_to_save/animations.gif",dpi=300) #or ".mp4" etc... + To display within jupyter notebook, just call it: + >>> anim Args: • t_end (_type_, optional): _description_. Defaults to None. @@ -478,6 +480,9 @@ def animate_rate_timeseries( Returns: animation """ + + plt.rcParams["animation.html"] = "jshtml" #for animation rendering in jupyter + dt = 1 / fps if t_start == None: t_start = self.history["t"][0] @@ -570,6 +575,7 @@ class PlaceCells(Neurons): • diff_of_gaussians • top_hat • one_hot + #TO-DO • tanni_harland https://pubmed.ncbi.nlm.nih.gov/33770492/ List of functions: • get_state() @@ -1013,7 +1019,7 @@ def get_state(self, evaluate_at="agent", **kwargs): # if egocentric references frame shift angle into coordinate from of heading direction of agent if self.reference_frame == "egocentric": if evaluate_at == "agent": - vel = self.Agent.pos + vel = self.Agent.velocity elif "vel" in kwargs.keys(): vel = kwargs["vel"] else: