From 94bf36467df0dc61ad6a4bc67ad1030bcbe32457 Mon Sep 17 00:00:00 2001 From: Gaston Sivori Date: Mon, 22 Jan 2024 15:18:37 +0900 Subject: [PATCH 1/3] sample positions considering holes in the environment. The function also now has a boolean flag that watches whether to force samplig method. --- ratinabox/Environment.py | 39 ++++++++++++++++++++++++++------------- ratinabox/utils.py | 9 +++++++++ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/ratinabox/Environment.py b/ratinabox/Environment.py index 2fb2468..22e0035 100644 --- a/ratinabox/Environment.py +++ b/ratinabox/Environment.py @@ -548,7 +548,7 @@ def plot_environment(self, return fig, ax - def sample_positions(self, n=10, method="uniform_jitter"): + def sample_positions(self, n=10, method="uniform_jitter",force_method=False): """Scatters 'n' locations across the environment which can act as, for example, the centres of gaussian place fields, or as a random starting position. If method == "uniform" an evenly spaced grid of locations is returned. If method == "uniform_jitter" these locations are jittered slightly (i.e. random but span the space). Note; if n doesn't uniformly divide the size (i.e. n is not a square number in a square environment) then the largest number that can be scattered uniformly are found, the remaining are randomly placed. Args: @@ -581,33 +581,46 @@ def sample_positions(self, n=10, method="uniform_jitter"): positions[:, 1] = np.random.uniform( self.extent[2], self.extent[3], size=n ) + if (self.is_rectangular is False) or (self.has_holes is True): + # in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute force this by randomly resampling these points until all fall within the env. + for i, pos in enumerate(positions): + if self.check_if_position_is_in_environment(pos) == False: + pos = self.sample_positions(n=1, method="random").reshape( + -1 + ) # this recursive call must pass eventually, assuming the env is sufficiently large. this is why we don't need a while loop + positions[i] = pos elif method[:7] == "uniform": ex = self.extent area = (ex[1] - ex[0]) * (ex[3] - ex[2]) + if (self.has_holes is True): + area -= sum(polygon_area(hole) for hole in self.holes) delta = np.sqrt(area / n) x = np.linspace(ex[0] + delta /2, ex[1] - delta /2, int((ex[1] - ex[0])/delta)) y = np.linspace(ex[2] + delta /2, ex[3] - delta /2, int((ex[3] - ex[2])/delta)) positions = np.array(np.meshgrid(x, y)).reshape(2, -1).T + + if (self.is_rectangular is False) or (self.has_holes is True): + # in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). + delpos = [i for (i,pos) in enumerate(positions) if self.check_if_position_is_in_environment(pos) == False] + positions = np.delete(positions,delpos,axis=0) # this will delete illegal positions + n_uniformly_distributed = positions.shape[0] if method[7:] == "_jitter": positions += np.random.uniform( -0.45 * delta, 0.45 * delta, positions.shape - ) + ) n_remaining = n - n_uniformly_distributed if n_remaining > 0: - positions_remaining = self.sample_positions( - n=n_remaining, method="random" - ) + if force_method: + # resample from available positions (repeating sampled positions) + positions_remaining = [positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=False)] + else: + # or brute force this by randomly resampling these points until all fall within the env. + positions_remaining = self.sample_positions( + n=n_remaining, method="random" + ) positions = np.vstack((positions, positions_remaining)) - if (self.is_rectangular is False) or (self.has_holes is True): - # in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). Brute force this by randomly resampling these points until all fall within the env. - for i, pos in enumerate(positions): - if self.check_if_position_is_in_environment(pos) == False: - pos = self.sample_positions(n=1, method="random").reshape( - -1 - ) # this recursive call must pass eventually, assuming the env is sufficiently large. this is why we don't need a while loop - positions[i] = pos return positions def discretise_environment(self, dx=None): diff --git a/ratinabox/utils.py b/ratinabox/utils.py index f5d8952..4806bf9 100644 --- a/ratinabox/utils.py +++ b/ratinabox/utils.py @@ -13,6 +13,15 @@ """OTHER USEFUL FUNCTIONS""" """Geometry functions""" +def polygon_area(hole): + """Given 4-point list defining a hole in the environment, returns its area. + Args: + hole (array): list of list of points defining the hole. + Returns: + scalar: area of the hole. + """ + x, y = zip(*hole) + return round(0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))),2) def get_perpendicular(a=None): """Given 2-vector, a, returns its perpendicular From d407cb953b6f4e1dc9e06270836883970a2b02a9 Mon Sep 17 00:00:00 2001 From: Gaston Sivori Date: Wed, 24 Jan 2024 18:57:24 +0900 Subject: [PATCH 2/3] Modified to utilize shapely Polygon area method instead of utils.py polygon area function. Included 'force_method' parameter reference in the docstring of sample_positions() --- ratinabox/Environment.py | 3 ++- ratinabox/utils.py | 9 --------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/ratinabox/Environment.py b/ratinabox/Environment.py index 22e0035..07fd650 100644 --- a/ratinabox/Environment.py +++ b/ratinabox/Environment.py @@ -555,6 +555,7 @@ def sample_positions(self, n=10, method="uniform_jitter",force_method=False): n (int): number of features method: "uniform", "uniform_jittered" or "random" for how points are distributed true_random: if True, just randomly scatters point + force_method: if True, forces sampling 'method'. if False, illegal sampled positions will be resampled randomly. Returns: array: (n x dimensionality) of positions """ @@ -593,7 +594,7 @@ def sample_positions(self, n=10, method="uniform_jitter",force_method=False): ex = self.extent area = (ex[1] - ex[0]) * (ex[3] - ex[2]) if (self.has_holes is True): - area -= sum(polygon_area(hole) for hole in self.holes) + area -= sum(shapely.geometry.Polygon(hole).area for hole in self.holes) delta = np.sqrt(area / n) x = np.linspace(ex[0] + delta /2, ex[1] - delta /2, int((ex[1] - ex[0])/delta)) y = np.linspace(ex[2] + delta /2, ex[3] - delta /2, int((ex[3] - ex[2])/delta)) diff --git a/ratinabox/utils.py b/ratinabox/utils.py index 4806bf9..f5d8952 100644 --- a/ratinabox/utils.py +++ b/ratinabox/utils.py @@ -13,15 +13,6 @@ """OTHER USEFUL FUNCTIONS""" """Geometry functions""" -def polygon_area(hole): - """Given 4-point list defining a hole in the environment, returns its area. - Args: - hole (array): list of list of points defining the hole. - Returns: - scalar: area of the hole. - """ - x, y = zip(*hole) - return round(0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))),2) def get_perpendicular(a=None): """Given 2-vector, a, returns its perpendicular From f224975609ab0dd8f4b64eec887e6ca5aec6b1e9 Mon Sep 17 00:00:00 2001 From: Gaston Sivori Date: Fri, 26 Jan 2024 10:49:44 +0900 Subject: [PATCH 3/3] Resampling of remaining positions now at a smaller delta (delta/2). --- ratinabox/Environment.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/ratinabox/Environment.py b/ratinabox/Environment.py index 07fd650..9ee3690 100644 --- a/ratinabox/Environment.py +++ b/ratinabox/Environment.py @@ -548,14 +548,12 @@ def plot_environment(self, return fig, ax - def sample_positions(self, n=10, method="uniform_jitter",force_method=False): + def sample_positions(self, n=10, method="uniform_jitter"): """Scatters 'n' locations across the environment which can act as, for example, the centres of gaussian place fields, or as a random starting position. If method == "uniform" an evenly spaced grid of locations is returned. If method == "uniform_jitter" these locations are jittered slightly (i.e. random but span the space). Note; if n doesn't uniformly divide the size (i.e. n is not a square number in a square environment) then the largest number that can be scattered uniformly are found, the remaining are randomly placed. Args: n (int): number of features method: "uniform", "uniform_jittered" or "random" for how points are distributed - true_random: if True, just randomly scatters point - force_method: if True, forces sampling 'method'. if False, illegal sampled positions will be resampled randomly. Returns: array: (n x dimensionality) of positions """ @@ -612,14 +610,12 @@ def sample_positions(self, n=10, method="uniform_jitter",force_method=False): ) n_remaining = n - n_uniformly_distributed if n_remaining > 0: - if force_method: - # resample from available positions (repeating sampled positions) - positions_remaining = [positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=False)] - else: - # or brute force this by randomly resampling these points until all fall within the env. - positions_remaining = self.sample_positions( - n=n_remaining, method="random" - ) + # sample remaining from available positions with further jittering (delta = delta/2) + positions_remaining = np.array([positions[i] for i in np.random.choice(range(len(positions)),n_remaining, replace=False)]) + delta /= 2 + positions_remaining += np.random.uniform( + -0.45 * delta, 0.45 * delta, positions_remaining.shape + ) positions = np.vstack((positions, positions_remaining)) return positions