Skip to content

Commit

Permalink
Merge pull request #27 from TomGeorge1234/dev
Browse files Browse the repository at this point in the history
Minor changes
  • Loading branch information
TomGeorge1234 authored Mar 16, 2023
2 parents ec84ef0 + 851963c commit eb3b0f0
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ fig, ax = PCs.plot_rate_timeseries()

We have a dedicated [contribs](./ratinabox/contribs/) directory where you can safely add awesome scripts and new `Neurons` classes etc.

*Questions?* Can't figure out how something works. If you can't fgure it out from the readme, demos, code comments etc. then ask me! Open an issue, I'm usually pretty quick to respond.
*Questions?* If you can't figure out how something works from the readme, demos, code comments etc. then ask! Open an issue, I'm usually pretty quick to respond. Here's our [official theme tune](https://www.youtube.com/watch?v=dY-FOI-9SOE) by the way.



Expand All @@ -338,5 +338,4 @@ Formatted:
```
Tom M George, William de Cothi, Claudia Clopath, Kimberly Stachenfeld, Caswell Barry. "RatInABox: A toolkit for modelling locomotion and neuronal activity in continuous environments" (2022).
```
The research paper corresponding to the above citation can be found [here](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v3).

The research paper corresponding to the above citation can be found [here](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v3).
25 changes: 21 additions & 4 deletions ratinabox/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
plt.rcParams["animation.html"] = "jshtml" #for animations


from ratinabox import utils

Expand Down Expand Up @@ -65,6 +67,7 @@ def __init__(self, Environment, params={}):
"thigmotaxis": 0.5, # tendency for agents to linger near walls [0 = not at all, 1 = max]
"wall_repel_distance": 0.1,
"walls_repel": True, # whether or not the walls repel

}
self.Environment = Environment
default_params.update(params)
Expand Down Expand Up @@ -370,6 +373,7 @@ def update(self, dt=None, drift_velocity=None, drift_to_random_strength_ratio=1)
save_velocity
)

# TO DO: make this a function call
# write to history
self.history["t"].append(self.t)
self.history["pos"].append(list(self.pos))
Expand Down Expand Up @@ -583,8 +587,10 @@ def animate_trajectory(
self, t_start=None, t_end=None, fps=15, speed_up=1, **kwargs
):
"""Returns an animation (anim) of the trajectory, 25fps.
Should be saved using comand like
anim.save("./where_to_save/animations.gif",dpi=300)
Should be saved using command like
>>> anim.save("./where_to_save/animations.gif",dpi=300)
To display in jupyter notebook, call it:
>>> anim
Args:
t_start: Agent time at which to start animation
Expand All @@ -602,7 +608,7 @@ def animate_trajectory(
if t_end == None:
t_end = self.history["t"][-1]

def animate_(i, fig, ax, t_start, t_max, speed_up, dt):
def animate_(i, fig, ax, t_start, t_max, speed_up, dt, additional_plot_func, **kwargs):
t_end = t_start + (i + 1) * speed_up * dt
ax.clear()
if self.Environment.dimensionality == "2D":
Expand All @@ -616,6 +622,12 @@ def animate_(i, fig, ax, t_start, t_max, speed_up, dt):
xlim=t_max / 60,
**kwargs,
)
if additional_plot_func is not None:
fig, ax = additional_plot_func(fig=fig,
ax=ax,
t=t_end, #the current time
**kwargs)

plt.close()
return

Expand All @@ -624,14 +636,19 @@ def animate_(i, fig, ax, t_start, t_max, speed_up, dt):
)

from matplotlib import animation
# if passed, after plotting the trajectory fig, ax are passed through this function.
# use it to add other things ontop of the animation
additional_plot_func = None
if 'additional_plot_func' in kwargs.keys():
additional_plot_func = kwargs['additional_plot_func']

anim = matplotlib.animation.FuncAnimation(
fig,
animate_,
interval=1000 * dt,
frames=int((t_end - t_start) / (dt * speed_up)),
blit=False,
fargs=(fig, ax, t_start, t_end, speed_up, dt),
fargs=(fig, ax, t_start, t_end, speed_up, dt, additional_plot_func),
)
return anim

Expand Down
2 changes: 1 addition & 1 deletion ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def sample_positions(self, n=10, method="uniform_jitter"):
positions = np.vstack((positions, positions_remaining))

if (self.is_rectangular) or (self.has_holes is True):
# in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute for this by randomly resampling these oints until all fall within the env.
# in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute force this by randomly resampling these points until all fall within the env.
for (i, pos) in enumerate(positions):
if self.check_if_position_is_in_environment(pos) == False:
pos = self.sample_positions(n=1, method="random").reshape(
Expand Down
12 changes: 9 additions & 3 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,10 @@ def animate_rate_timeseries(
**kwargs,
):
"""Returns an animation (anim) of the firing rates, 25fps.
Should be saved using comand like
anim.save("./where_to_save/animations.gif",dpi=300)
Should be saved using command like:
>>> anim.save("./where_to_save/animations.gif",dpi=300) #or ".mp4" etc...
To display within jupyter notebook, just call it:
>>> anim
Args:
• t_end (_type_, optional): _description_. Defaults to None.
Expand All @@ -478,6 +480,9 @@ def animate_rate_timeseries(
Returns:
animation
"""

plt.rcParams["animation.html"] = "jshtml" #for animation rendering in jupyter

dt = 1 / fps
if t_start == None:
t_start = self.history["t"][0]
Expand Down Expand Up @@ -570,6 +575,7 @@ class PlaceCells(Neurons):
• diff_of_gaussians
• top_hat
• one_hot
#TO-DO • tanni_harland https://pubmed.ncbi.nlm.nih.gov/33770492/
List of functions:
• get_state()
Expand Down Expand Up @@ -1013,7 +1019,7 @@ def get_state(self, evaluate_at="agent", **kwargs):
# if egocentric references frame shift angle into coordinate from of heading direction of agent
if self.reference_frame == "egocentric":
if evaluate_at == "agent":
vel = self.Agent.pos
vel = self.Agent.velocity
elif "vel" in kwargs.keys():
vel = kwargs["vel"]
else:
Expand Down

0 comments on commit eb3b0f0

Please sign in to comment.