Skip to content

Commit

Permalink
Merge pull request #95 from colleenjg/cjg-dev
Browse files Browse the repository at this point in the history
Avoid infinite recursion in plot_rate_map()
  • Loading branch information
TomGeorge1234 authored Nov 30, 2023
2 parents c6aeb22 + 2f3880c commit 16375e1
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2603,7 +2603,7 @@ def lambda_activation_function(activation, other_args):
f"FeedForwardLayer initialised with {len(self.inputs.keys())} layers. To add another layer use FeedForwardLayer.add_input_layer().\nTo set the weights manually edit them by changing self.inputs['layer_name']['w']"
)

def add_input(self, input_layer, w=None, w_init_scale=1, **kwargs):
def add_input(self, input_layer, w=None, w_init_scale=1, recurrent=False, **kwargs):
"""Adds an input layer to the class. Each input layer is stored in a dictionary of self.inputs. Each has an associated matrix of weights which are initialised randomly.
Note the inputs are stored in a dictionary. The keys are taken to be the name of each layer passed (input_layer.name). Make sure you set this correctly (and uniquely).
Expand All @@ -2612,6 +2612,7 @@ def add_input(self, input_layer, w=None, w_init_scale=1, **kwargs):
• input_layer (_type_): the layer intself. Must be a Neurons() class object (e.g. can be PlaceCells(), etc...).
• w: the weight matrix. If None these will be drawn randomly, see next argument.
• w_init_scale: initial weights drawn from zero-centred gaussian with std w_init_scale / sqrt(N_in)
• recurrent: if True the input layer introduces recurrency. This can be taken into account when computing groundtruth rate maps to avoid infinite recursion. If a circuit includes a recurrent loop, mark the minimum necessary number of input connections as recurrent to prevent infinite looping.
• **kwargs any extra kwargs will get saved into the inputs dictionary in case you need these
"""
Expand All @@ -2633,19 +2634,22 @@ def add_input(self, input_layer, w=None, w_init_scale=1, **kwargs):
self.inputs[name]["w_init"] = w.copy()
self.inputs[name]["I"] = I
self.inputs[name]["n"] = input_layer.n # a copy for convenience
self.inputs[name]["recurrent"] = recurrent
for key, value in kwargs.items():
self.inputs[name][key] = value
if ratinabox.verbose is True:
print(
f'An input layer called {name} was added. The weights can be accessed with "self.inputs[{name}]["w"]"'
)

def get_state(self, evaluate_at="last", **kwargs):
def get_state(self, evaluate_at="last", max_recurrence=None, **kwargs):
"""Returns the firing rate of the feedforward layer cells. By default this layer uses the last saved firingrate from its input layers. Alternatively evaluate_at and kwargs can be set to be anything else which will just be passed to the input layer for evaluation.
Once the firing rate of the inout layers is established these are multiplied by the weight matrices and then activated to obtain the firing rate of this FeedForwardLayer.
Args:
evaluate_at (str, optional). Defaults to 'last'.
max_recurrence: The maximum number of time get_state() recursively calls recurrent inputs (prevents infinite recursion error).
**kwargs: any extra kwargs will get passed to the input layer get_state() call for evaluation.
Returns:
firingrate: array of firing rates
"""
Expand All @@ -2659,11 +2663,16 @@ def get_state(self, evaluate_at="last", **kwargs):
V = np.zeros((self.n, kwargs["pos"].shape[0]))

for inputlayer in self.inputs.values():
pass_max_recurrence = max_recurrence
if max_recurrence is not None and inputlayer['recurrent']:
if max_recurrence <= 0:
continue
pass_max_recurrence = max_recurrence - 1
w = inputlayer["w"]
if evaluate_at == "last":
I = inputlayer["layer"].firingrate
else: # kick can down the road let input layer decide how to evaluate the firingrate. this is core to feedforward layer as this recursive call will backprop through the upstraem layers until it reaches a "core" (e.g. place cells) layer which will then evaluate the firingrate.
I = inputlayer["layer"].get_state(evaluate_at, **kwargs)
I = inputlayer["layer"].get_state(evaluate_at, max_recurrence=pass_max_recurrence, **kwargs)
inputlayer["I_temp"] = I
V += np.matmul(w, I)

Expand All @@ -2686,7 +2695,17 @@ def get_state(self, evaluate_at="last", **kwargs):
return firingrate


def plot_rate_map(self, method="groundtruth", max_recurrence=None, **kwargs):
"""
If groundtruth rate maps are plotted, then a maximum recursion depth is passed.
max_recurrence: The maximum number of time get_state() recursively calls recurrent inputs (prevents infinite recursion error).
"""

if method.startswith("groundtruth"):
return super().plot_rate_map(method=method, max_recurrence=max_recurrence, **kwargs)
else:
return super().plot_rate_map(method=method, **kwargs)



Expand Down

0 comments on commit 16375e1

Please sign in to comment.