Skip to content

Commit

Permalink
fix AttributeError in GeoSpace.agents_at()
Browse files Browse the repository at this point in the history
  • Loading branch information
wang-boyu authored and rht committed Sep 12, 2023
1 parent 66831b6 commit 58af3dc
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 7 deletions.
19 changes: 13 additions & 6 deletions mesa_geo/geospace.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def get_relation(self, agent, relation):
"""
yield from self._agent_layer.get_relation(agent, relation)

def get_intersecting_agents(self, agent, other_agents=None):
return self._agent_layer.get_intersecting_agents(agent, other_agents)
def get_intersecting_agents(self, agent):
return self._agent_layer.get_intersecting_agents(agent)

def get_neighbors_within_distance(
self, agent, distance, center=False, relation="intersects"
Expand Down Expand Up @@ -329,10 +329,13 @@ def get_relation(self, agent, relation):

possible_agents = self._get_rtree_intersections(agent.geometry)
for other_agent in possible_agents:
if getattr(agent.geometry, relation)(other_agent.geometry):
if (
getattr(agent.geometry, relation)(other_agent.geometry)
and other_agent.unique_id != agent.unique_id
):
yield other_agent

def get_intersecting_agents(self, agent, other_agents=None):
def get_intersecting_agents(self, agent):
intersecting_agents = self.get_relation(agent, "intersects")
return intersecting_agents

Expand All @@ -357,12 +360,16 @@ def get_neighbors_within_distance(

def agents_at(self, pos):
"""
Return a list of agents at given pos.
Return a generator of agents at given pos.
"""

if not isinstance(pos, Point):
pos = Point(pos)
return self.get_relation(pos, "within")

possible_agents = self._get_rtree_intersections(pos)
for other_agent in possible_agents:
if pos.within(other_agent.geometry):
yield other_agent

def distance(self, agent_a, agent_b):
"""
Expand Down
134 changes: 133 additions & 1 deletion tests/test_GeoSpace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
from shapely.geometry import Point
from shapely.geometry import Point, Polygon

import mesa_geo as mg

Expand All @@ -21,6 +21,24 @@ def setUp(self) -> None:
)
for geometry in self.geometries
]
self.polygon_agent = mg.GeoAgent(
unique_id=uuid.uuid4().int,
model=None,
geometry=Polygon([(0, 0), (0, 2), (2, 2), (2, 0)]),
crs="epsg:3857",
)
self.touching_agent = mg.GeoAgent(
unique_id=uuid.uuid4().int,
model=None,
geometry=Polygon([(2, 0), (2, 2), (4, 2), (4, 0)]),
crs="epsg:3857",
)
self.disjoint_agent = mg.GeoAgent(
unique_id=uuid.uuid4().int,
model=None,
geometry=Polygon([(10, 10), (10, 12), (12, 12), (12, 10)]),
crs="epsg:3857",
)
self.image_layer = mg.ImageLayer(
values=np.random.uniform(low=0, high=255, size=(3, 500, 500)),
crs="epsg:4326",
Expand Down Expand Up @@ -133,3 +151,117 @@ def test_get_agents_as_GeoDataFrame(self):
self.assertEqual(
self.geo_space.get_agents_as_GeoDataFrame().crs, agents_gdf.crs
)

def test_get_relation_contains(self):
self.geo_space.add_agents(self.polygon_agent)
self.assertEqual(
list(self.geo_space.get_relation(self.polygon_agent, relation="contains")),
[],
)

self.geo_space.add_agents(self.agents)
agents_id = {agent.unique_id for agent in self.agents}
contained_agents_id = {
agent.unique_id
for agent in self.geo_space.get_relation(
self.polygon_agent, relation="contains"
)
}
self.assertEqual(contained_agents_id, agents_id)

def test_get_relation_within(self):
self.geo_space.add_agents(self.agents[0])
self.assertEqual(
list(self.geo_space.get_relation(self.agents[0], relation="within")), []
)
self.geo_space.add_agents(self.polygon_agent)
within_agent = list(
self.geo_space.get_relation(self.agents[0], relation="within")
)[0]
self.assertEqual(within_agent.unique_id, self.polygon_agent.unique_id)

def test_get_relation_touches(self):
self.geo_space.add_agents(self.polygon_agent)
self.assertEqual(
list(self.geo_space.get_relation(self.polygon_agent, relation="touches")),
[],
)
self.geo_space.add_agents(self.touching_agent)
self.assertEqual(
len(
list(
self.geo_space.get_relation(self.polygon_agent, relation="touches")
)
),
1,
)
self.assertEqual(
list(self.geo_space.get_relation(self.polygon_agent, relation="touches"))[
0
].unique_id,
self.touching_agent.unique_id,
)

def test_get_relation_intersects(self):
self.geo_space.add_agents(self.polygon_agent)
self.assertEqual(
list(
self.geo_space.get_relation(self.polygon_agent, relation="intersects")
),
[],
)

self.geo_space.add_agents(self.agents)
agents_id = {agent.unique_id for agent in self.agents}
intersecting_agents_id = {
agent.unique_id
for agent in self.geo_space.get_relation(
self.polygon_agent, relation="intersects"
)
}
self.assertEqual(intersecting_agents_id, agents_id)

# disjoint agent should not be returned since it is not intersecting
self.geo_space.add_agents(self.disjoint_agent)
intersecting_agents_id = {
agent.unique_id
for agent in self.geo_space.get_relation(
self.polygon_agent, relation="intersects"
)
}
self.assertEqual(intersecting_agents_id, agents_id)

def test_get_intersecting_agents(self):
self.geo_space.add_agents(self.polygon_agent)
self.assertEqual(
list(self.geo_space.get_intersecting_agents(self.polygon_agent)),
[],
)

self.geo_space.add_agents(self.agents)
agents_id = {agent.unique_id for agent in self.agents}
intersecting_agents_id = {
agent.unique_id
for agent in self.geo_space.get_intersecting_agents(self.polygon_agent)
}
self.assertEqual(intersecting_agents_id, agents_id)

# disjoint agent should not be returned since it is not intersecting
self.geo_space.add_agents(self.disjoint_agent)
intersecting_agents_id = {
agent.unique_id
for agent in self.geo_space.get_intersecting_agents(self.polygon_agent)
}
self.assertEqual(intersecting_agents_id, agents_id)

def test_agents_at(self):
self.geo_space.add_agents(self.agents)
self.assertEqual(
len(list(self.geo_space.agents_at(self.agents[0].geometry))),
len(self.agents),
)
agents_id = {agent.unique_id for agent in self.agents}
agents_id_found = {
agent.unique_id for agent in self.geo_space.agents_at((1, 1))
}
self.assertEqual(agents_id_found, agents_id)

0 comments on commit 58af3dc

Please sign in to comment.