Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
TomGeorge1234 committed Oct 21, 2023
1 parent 0dca070 commit 3474d98
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 16 deletions.
Binary file modified .images/demos/rat_target.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified .images/demos/riab_target.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion demos/actor_critic_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"from ratinabox.Agent import Agent\n",
"from ratinabox.Neurons import PlaceCells\n",
"from ratinabox.contribs.NeuralNetworkNeurons import NeuralNetworkNeurons #for the Actor and Critic\n",
"ratinabox.autosave_plots = True; ratinabox.figure_directory = \"./figures/\"; ratinabox.stylize_plots()\n",
"ratinabox.autosave_plots = True; ratinabox.figure_directory = \"../figures/\"; ratinabox.stylize_plots()\n",
"\n",
"#misc\n",
"import torch \n",
Expand Down
11 changes: 6 additions & 5 deletions ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def __init__(self, params={}):
}
self.n_object_types = 0
self.object_colormap = "rainbow"
self.plot_objects = True

# make some other attributes
left = min([c[0] for c in b])
Expand Down Expand Up @@ -255,7 +254,8 @@ def plot_environment(self,
fig=None,
ax=None,
gridlines=False,
autosave=None):
autosave=None,
**kwargs,):
"""Plots the environment on the x axis, dark grey lines show the walls
Args:
fig,ax: the fig and ax to plot on (can be None)
Expand All @@ -264,7 +264,6 @@ def plot_environment(self,
Returns:
fig, ax: the environment figures, can be used for further downstream plotting.
"""

if self.dimensionality == "1D":
extent = self.extent
if fig is None and ax is None:
Expand Down Expand Up @@ -350,8 +349,10 @@ def plot_environment(self,
zorder=2,
)

# plot objects
if self.plot_objects == True:
# plot objects if there isn't a kwarg setting it to false
if 'plot_objects' in kwargs and kwargs['plot_objects'] == False:
pass
else:
object_cmap = matplotlib.colormaps[self.object_colormap]
for i, object in enumerate(self.objects["objects"]):
object_color = object_cmap(
Expand Down
18 changes: 11 additions & 7 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ class Neurons:
• GridCells()
• BoundaryVectorCells()
• ObjectVectorCells()
• FieldOfViewBVCs()
• FieldOfViewOVCs()
• VelocityCells()
• HeadDirectionCells()
• SpeedCells()
• FeedForwardLayer()
• RandomSpatialNeurons()
as well as (in the contribs)
• ValueNeuron()
FieldOfViewNeurons()
NeuralNetworkNeurons()
The unique function in each child classes is get_state(). Whenever Neurons.update() is called Neurons.get_state() is then called to calculate and return the firing rate of the cells at the current moment in time. This is then saved. In order to make your own Neuron subclass you will need to write a class with the following mandatory structure:
Expand Down Expand Up @@ -381,7 +384,7 @@ def plot_rate_map(
else:
Nx, Ny = shape[0], shape[1]
env_fig, env_ax = self.Agent.Environment.plot_environment(
autosave=False
autosave=False, **kwargs,
)
width, height = env_fig.get_size_inches()
plt.close(env_fig)
Expand Down Expand Up @@ -415,7 +418,7 @@ def plot_rate_map(
cax = divider.append_axes("right", size="5%", pad=0.05)
for i, ax_ in enumerate(axes):
_, ax_ = self.Agent.Environment.plot_environment(
fig, ax_, autosave=False
fig, ax_, autosave=False, **kwargs
)
if len(chosen_neurons) != axes.size:
print(
Expand Down Expand Up @@ -463,9 +466,9 @@ def plot_rate_map(
cbar = plt.colorbar(ims[-1], cax=cax)
cbar.ax.tick_params(length=0)
cbar.set_label("Firing rate / Hz",labelpad=-10)
lim_v = vmax if vmax > -vmin else vmin
cbar.set_ticks([0, lim_v])
cbar.set_ticklabels([0.0, round(lim_v, 1)])
# lim_v = vmax if vmax > -vmin else vmin
cbar.set_ticks([vmin,vmax])
cbar.set_ticklabels([f"{vmin:.1f}", f"{vmax:.1f}"])
cbar.outline.set_visible(False)

if spikes is True:
Expand Down Expand Up @@ -2386,7 +2389,8 @@ class RandomSpatialNeurons(Neurons):
'max_fr':1, #maximum firing rate
'min_fr':0, #minimum firing rate
'n':10, #number of neurons
'wall_geometry':'geodesic' #how to account for walls when calculating distance between points (only relevant in 2D)
'wall_geometry':'geodesic', #how to account for walls when calculating distance between points (only relevant in 2D)
'name':'RandomSpatialNeurons', #name of the class
}

def __init__(self, Agent, params={}):
Expand Down
2 changes: 1 addition & 1 deletion ratinabox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

DARKGREY = [0.3,0.3,0.3,1]
GREY = [0.5,0.5,0.5,1]
LIGHTGREY = [0.92,0.92,0.92,1]
LIGHTGREY = [0.9,0.9,0.9,1]

from .Environment import *
from .Agent import *
Expand Down
5 changes: 4 additions & 1 deletion ratinabox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ def mountain_plot(
fig=None,
ax=None,
norm_by="max",
linewidth=1,
width=ratinabox.MOUNTAIN_PLOT_WIDTH_MM,
overlap=ratinabox.MOUNTAIN_PLOT_OVERLAP,
shift=ratinabox.MOUNTAIN_PLOT_SHIFT_MM,
Expand All @@ -590,6 +591,8 @@ def mountain_plot(
ax (_type_, optional): ax to plot on if desider. Defaults to None.
norm_by: what to normalise each line of the mountainplot by.
If "max", norms by the maximum firing rate found across all the neurons. Otherwise, pass a float (useful if you want to compare different neural datsets apples-to-apples)
linewidth: width of lines
width: width of figure in mm
overlap: how much each plots overlap by (> 1 = overlap, < 1 = no overlap) (overlap is not relevant if you also set "norm_by")
shift: distance between lines in mm
Expand All @@ -615,7 +618,7 @@ def mountain_plot(

zorder = 1
for i in range(len(NbyX)):
ax.plot(X, NbyX[i] + i + 1, c=c, zorder=zorder)
ax.plot(X, NbyX[i] + i + 1, c=c, zorder=zorder, lw=linewidth)
zorder -= 0.01
ax.fill_between(
X, NbyX[i] + i + 1, i + 1, color=fc, zorder=zorder, alpha=0.8, linewidth=0
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = ratinabox
version = 1.10.0
version = 1.10.1
author = Tom George
author_email = tomgeorge1@btinternet.com
project_urls =
Expand Down

0 comments on commit 3474d98

Please sign in to comment.