diff --git a/ratinabox/Neurons.py b/ratinabox/Neurons.py index 090743f..2b4c717 100644 --- a/ratinabox/Neurons.py +++ b/ratinabox/Neurons.py @@ -2603,7 +2603,7 @@ def lambda_activation_function(activation, other_args): f"FeedForwardLayer initialised with {len(self.inputs.keys())} layers. To add another layer use FeedForwardLayer.add_input_layer().\nTo set the weights manually edit them by changing self.inputs['layer_name']['w']" ) - def add_input(self, input_layer, w=None, w_init_scale=1, **kwargs): + def add_input(self, input_layer, w=None, w_init_scale=1, recurrent=False, **kwargs): """Adds an input layer to the class. Each input layer is stored in a dictionary of self.inputs. Each has an associated matrix of weights which are initialised randomly. Note the inputs are stored in a dictionary. The keys are taken to be the name of each layer passed (input_layer.name). Make sure you set this correctly (and uniquely). @@ -2612,6 +2612,7 @@ def add_input(self, input_layer, w=None, w_init_scale=1, **kwargs): • input_layer (_type_): the layer intself. Must be a Neurons() class object (e.g. can be PlaceCells(), etc...). • w: the weight matrix. If None these will be drawn randomly, see next argument. • w_init_scale: initial weights drawn from zero-centred gaussian with std w_init_scale / sqrt(N_in) + • recurrent: if True the input layer introduces recurrency. This can be taken into account when computing groundtruth rate maps to avoid infinite recursion. If a circuit includes a recurrent loop, mark the minimum necessary number of input connections as recurrent to prevent infinite looping. • **kwargs any extra kwargs will get saved into the inputs dictionary in case you need these """ @@ -2633,6 +2634,7 @@ def add_input(self, input_layer, w=None, w_init_scale=1, **kwargs): self.inputs[name]["w_init"] = w.copy() self.inputs[name]["I"] = I self.inputs[name]["n"] = input_layer.n # a copy for convenience + self.inputs[name]["recurrent"] = recurrent for key, value in kwargs.items(): self.inputs[name][key] = value if ratinabox.verbose is True: @@ -2640,12 +2642,14 @@ def add_input(self, input_layer, w=None, w_init_scale=1, **kwargs): f'An input layer called {name} was added. The weights can be accessed with "self.inputs[{name}]["w"]"' ) - def get_state(self, evaluate_at="last", **kwargs): + def get_state(self, evaluate_at="last", max_recurrence=None, **kwargs): """Returns the firing rate of the feedforward layer cells. By default this layer uses the last saved firingrate from its input layers. Alternatively evaluate_at and kwargs can be set to be anything else which will just be passed to the input layer for evaluation. Once the firing rate of the inout layers is established these are multiplied by the weight matrices and then activated to obtain the firing rate of this FeedForwardLayer. Args: evaluate_at (str, optional). Defaults to 'last'. + max_recurrence: The maximum number of time get_state() recursively calls recurrent inputs (prevents infinite recursion error). + **kwargs: any extra kwargs will get passed to the input layer get_state() call for evaluation. Returns: firingrate: array of firing rates """ @@ -2659,11 +2663,16 @@ def get_state(self, evaluate_at="last", **kwargs): V = np.zeros((self.n, kwargs["pos"].shape[0])) for inputlayer in self.inputs.values(): + pass_max_recurrence = max_recurrence + if max_recurrence is not None and inputlayer['recurrent']: + if max_recurrence <= 0: + continue + pass_max_recurrence = max_recurrence - 1 w = inputlayer["w"] if evaluate_at == "last": I = inputlayer["layer"].firingrate else: # kick can down the road let input layer decide how to evaluate the firingrate. this is core to feedforward layer as this recursive call will backprop through the upstraem layers until it reaches a "core" (e.g. place cells) layer which will then evaluate the firingrate. - I = inputlayer["layer"].get_state(evaluate_at, **kwargs) + I = inputlayer["layer"].get_state(evaluate_at, max_recurrence=pass_max_recurrence, **kwargs) inputlayer["I_temp"] = I V += np.matmul(w, I) @@ -2686,7 +2695,17 @@ def get_state(self, evaluate_at="last", **kwargs): return firingrate + def plot_rate_map(self, method="groundtruth", max_recurrence=None, **kwargs): + """ + If groundtruth rate maps are plotted, then a maximum recursion depth is passed. + + max_recurrence: The maximum number of time get_state() recursively calls recurrent inputs (prevents infinite recursion error). + """ + if method.startswith("groundtruth"): + return super().plot_rate_map(method=method, max_recurrence=max_recurrence, **kwargs) + else: + return super().plot_rate_map(method=method, **kwargs)