Skip to content

Commit

Permalink
some extra comments
Browse files Browse the repository at this point in the history
  • Loading branch information
TomGeorge1234 committed Jan 29, 2024
1 parent 85d8722 commit ba679c4
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ def sample_positions(self, n=10, method="uniform_jitter"):
return positions

elif self.dimensionality == "2D":
if method == "random":
if method == "random":
# random scatter positions and check they aren't in holes etc.
positions = np.zeros((n, 2))
positions[:, 0] = np.random.uniform(
self.extent[0], self.extent[1], size=n
Expand All @@ -589,6 +590,7 @@ def sample_positions(self, n=10, method="uniform_jitter"):
) # this recursive call must pass eventually, assuming the env is sufficiently large. this is why we don't need a while loop
positions[i] = pos
elif method[:7] == "uniform":
# uniformly scatter positions on a square grid and check they aren't in holes etc.
ex = self.extent
area = (ex[1] - ex[0]) * (ex[3] - ex[2])
if (self.has_holes is True):
Expand All @@ -599,12 +601,13 @@ def sample_positions(self, n=10, method="uniform_jitter"):
positions = np.array(np.meshgrid(x, y)).reshape(2, -1).T

if (self.is_rectangular is False) or (self.has_holes is True):
# in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole).
# in this case, the positions you have sampled within the extent of the environment may not actually fall within it's legal area (i.e. they could be outside the polygon boundary or inside a hole). delete those that do for resampling later.
delpos = [i for (i,pos) in enumerate(positions) if self.check_if_position_is_in_environment(pos) == False]
positions = np.delete(positions,delpos,axis=0) # this will delete illegal positions

n_uniformly_distributed = positions.shape[0]
if method[7:] == "_jitter":
# add jitter to the uniformly distributed positions
positions += np.random.uniform(
-0.45 * delta, 0.45 * delta, positions.shape
)
Expand Down

0 comments on commit ba679c4

Please sign in to comment.