Skip to content

Commit

Permalink
Added a warning if number of neurons will be changed in the initializ…
Browse files Browse the repository at this point in the history
…ation of a neuron layer.

Fixed a small bug in VectorCells init: self.Agent was not set before being called.
Removed default value of None for `Other_Agent` in AgentVectorCells init, as None is not an allowed value.
Raised an error during setting of tuning types in ObjectVectorCells if there are not objects in the environment.
  • Loading branch information
colleenjg committed Feb 1, 2024
1 parent ba679c4 commit a2f100d
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,8 @@ def __init__(self, Agent, params={}):
Agent. The RatInABox Agent these cells belong to.
params (dict, optional). Defaults to {}.
"""
self.Agent = Agent

assert (
self.Agent.Environment.dimensionality == "2D"
), "Vector cells only possible in 2D"
Expand Down Expand Up @@ -1245,18 +1247,28 @@ def __init__(self, Agent, params={}):
self.sigma_angles = None
self.sigma_distances = None


# set the parameters of each cell.
(self.tuning_distances,
self.tuning_angles,
self.sigma_distances,
self.sigma_angles) = self.set_tuning_parameters(**self.params)
self.n = len(self.tuning_distances) #ensure n is correct

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)

# raises a warning if n was passed as a parameter, but will change.
if self._warn_n_change and params["n"] != len(self.tuning_distances):
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed, and setting number of {self.name} neurons to {len(self.tuning_distances)}, inferred from the cell arrangement parameter.")

self.n = len(self.tuning_distances) # ensure n is correct


self.firingrate = np.zeros(self.n)
self.noise = np.zeros(self.n)
self.cell_colors = None


def set_tuning_parameters(self, **kwargs):
"""Get the tuning parameters for the vector cells.
Args:
Expand Down Expand Up @@ -1438,6 +1450,10 @@ def __init__(self, Agent, params={}):
self.params = copy.deepcopy(__class__.default_params)
self.params.update(params)

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)

super().__init__(Agent, self.params)

assert (
Expand Down Expand Up @@ -1483,7 +1499,6 @@ def __init__(self, Agent, params={}):
return



def get_state(self, evaluate_at="agent", **kwargs):
"""
Here we implement the same type if boundary vector cells as de Cothi et al. (2020),
Expand Down Expand Up @@ -1789,6 +1804,10 @@ def __init__(self, Agent, params={}):
self.Agent.Environment.dimensionality == "2D"
), "object vector cells only possible in 2D"

# records whether n was passed as a parameter.
if not hasattr(self, "_warn_n_change"):
self._warn_n_change = ("n" in params.keys() and params["n"] is not None)

super().__init__(Agent, self.params)

self.object_locations = self.Agent.Environment.objects["objects"]
Expand Down Expand Up @@ -1834,6 +1853,8 @@ def set_tuning_types(self, tuning_types=None):

if tuning_types == "random":
self.object_types = self.Agent.Environment.objects["object_types"]
if len(self.object_types) == 0:
raise RuntimeError("Cannot initialize object vector cells randomly, as no objects were found in the environment.")
self.tuning_types = np.random.choice(
np.unique(self.object_types), replace=True, size=(self.n,)
)
Expand Down Expand Up @@ -2006,7 +2027,6 @@ def __init__(self,Agent,params={}):
warnings.warn("For FieldOfViewOVCs you must specify the object type they are selective for with the 'object_tuning_type' parameter. This can be 'random' (each cell in the field of view chooses a random object type) or any integer (all cells have the same preference for this type). For now defaulting to params['object_tuning_type'] = 0.")
self.params["object_tuning_type"] = 0


self.params["reference_frame"] = "egocentric"
assert self.params["cell_arrangement"] is not None, "cell_arrangement must be set for FOV Neurons"

Expand Down Expand Up @@ -2034,7 +2054,7 @@ class AgentVectorCells(VectorCells):

def __init__(self,
Agent,
Other_Agent = None, #this must be another riab Agent object
Other_Agent, #this must be another riab Agent object
params={}):

self.Agent = Agent
Expand Down Expand Up @@ -2269,6 +2289,8 @@ def __init__(self, Agent, params={}):
[self.params["angular_spread_degrees"] * np.pi / 180] * self.n
)
if self.Agent.Environment.dimensionality == "1D":
if "n" in params.keys() and params["n"] != 2:
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed for {self.params['name']}. Only 2 head direction cells are needed for a 1D environment.")
self.n = 2 # one left, one right
self.params["n"] = self.n
super().__init__(Agent, self.params)
Expand Down Expand Up @@ -2476,7 +2498,11 @@ def __init__(self, Agent, params={}):
self.params.update(params)

super().__init__(Agent, self.params)

if "n" in params.keys() and params["n"] != 1:
warnings.warn(f"Ignoring 'n' parameter value ({params['n']}) that was passed for {self.name}. Only 1 speed cell is needed.")
self.n = 1

self.one_sigma_speed = self.Agent.speed_mean + self.Agent.speed_std

if ratinabox.verbose is True:
Expand Down

0 comments on commit a2f100d

Please sign in to comment.