From 58af3dc7699d0fe8b39aeba2c3dcb217a9476584 Mon Sep 17 00:00:00 2001 From: Wang Boyu Date: Sat, 9 Sep 2023 19:58:42 -0400 Subject: [PATCH] fix AttributeError in GeoSpace.agents_at() --- mesa_geo/geospace.py | 19 ++++-- tests/test_GeoSpace.py | 134 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 7 deletions(-) diff --git a/mesa_geo/geospace.py b/mesa_geo/geospace.py index a73f0cef..7cfe6567 100644 --- a/mesa_geo/geospace.py +++ b/mesa_geo/geospace.py @@ -180,8 +180,8 @@ def get_relation(self, agent, relation): """ yield from self._agent_layer.get_relation(agent, relation) - def get_intersecting_agents(self, agent, other_agents=None): - return self._agent_layer.get_intersecting_agents(agent, other_agents) + def get_intersecting_agents(self, agent): + return self._agent_layer.get_intersecting_agents(agent) def get_neighbors_within_distance( self, agent, distance, center=False, relation="intersects" @@ -329,10 +329,13 @@ def get_relation(self, agent, relation): possible_agents = self._get_rtree_intersections(agent.geometry) for other_agent in possible_agents: - if getattr(agent.geometry, relation)(other_agent.geometry): + if ( + getattr(agent.geometry, relation)(other_agent.geometry) + and other_agent.unique_id != agent.unique_id + ): yield other_agent - def get_intersecting_agents(self, agent, other_agents=None): + def get_intersecting_agents(self, agent): intersecting_agents = self.get_relation(agent, "intersects") return intersecting_agents @@ -357,12 +360,16 @@ def get_neighbors_within_distance( def agents_at(self, pos): """ - Return a list of agents at given pos. + Return a generator of agents at given pos. """ if not isinstance(pos, Point): pos = Point(pos) - return self.get_relation(pos, "within") + + possible_agents = self._get_rtree_intersections(pos) + for other_agent in possible_agents: + if pos.within(other_agent.geometry): + yield other_agent def distance(self, agent_a, agent_b): """ diff --git a/tests/test_GeoSpace.py b/tests/test_GeoSpace.py index 079fc98d..13fb0897 100644 --- a/tests/test_GeoSpace.py +++ b/tests/test_GeoSpace.py @@ -6,7 +6,7 @@ import geopandas as gpd import numpy as np import pandas as pd -from shapely.geometry import Point +from shapely.geometry import Point, Polygon import mesa_geo as mg @@ -21,6 +21,24 @@ def setUp(self) -> None: ) for geometry in self.geometries ] + self.polygon_agent = mg.GeoAgent( + unique_id=uuid.uuid4().int, + model=None, + geometry=Polygon([(0, 0), (0, 2), (2, 2), (2, 0)]), + crs="epsg:3857", + ) + self.touching_agent = mg.GeoAgent( + unique_id=uuid.uuid4().int, + model=None, + geometry=Polygon([(2, 0), (2, 2), (4, 2), (4, 0)]), + crs="epsg:3857", + ) + self.disjoint_agent = mg.GeoAgent( + unique_id=uuid.uuid4().int, + model=None, + geometry=Polygon([(10, 10), (10, 12), (12, 12), (12, 10)]), + crs="epsg:3857", + ) self.image_layer = mg.ImageLayer( values=np.random.uniform(low=0, high=255, size=(3, 500, 500)), crs="epsg:4326", @@ -133,3 +151,117 @@ def test_get_agents_as_GeoDataFrame(self): self.assertEqual( self.geo_space.get_agents_as_GeoDataFrame().crs, agents_gdf.crs ) + + def test_get_relation_contains(self): + self.geo_space.add_agents(self.polygon_agent) + self.assertEqual( + list(self.geo_space.get_relation(self.polygon_agent, relation="contains")), + [], + ) + + self.geo_space.add_agents(self.agents) + agents_id = {agent.unique_id for agent in self.agents} + contained_agents_id = { + agent.unique_id + for agent in self.geo_space.get_relation( + self.polygon_agent, relation="contains" + ) + } + self.assertEqual(contained_agents_id, agents_id) + + def test_get_relation_within(self): + self.geo_space.add_agents(self.agents[0]) + self.assertEqual( + list(self.geo_space.get_relation(self.agents[0], relation="within")), [] + ) + self.geo_space.add_agents(self.polygon_agent) + within_agent = list( + self.geo_space.get_relation(self.agents[0], relation="within") + )[0] + self.assertEqual(within_agent.unique_id, self.polygon_agent.unique_id) + + def test_get_relation_touches(self): + self.geo_space.add_agents(self.polygon_agent) + self.assertEqual( + list(self.geo_space.get_relation(self.polygon_agent, relation="touches")), + [], + ) + self.geo_space.add_agents(self.touching_agent) + self.assertEqual( + len( + list( + self.geo_space.get_relation(self.polygon_agent, relation="touches") + ) + ), + 1, + ) + self.assertEqual( + list(self.geo_space.get_relation(self.polygon_agent, relation="touches"))[ + 0 + ].unique_id, + self.touching_agent.unique_id, + ) + + def test_get_relation_intersects(self): + self.geo_space.add_agents(self.polygon_agent) + self.assertEqual( + list( + self.geo_space.get_relation(self.polygon_agent, relation="intersects") + ), + [], + ) + + self.geo_space.add_agents(self.agents) + agents_id = {agent.unique_id for agent in self.agents} + intersecting_agents_id = { + agent.unique_id + for agent in self.geo_space.get_relation( + self.polygon_agent, relation="intersects" + ) + } + self.assertEqual(intersecting_agents_id, agents_id) + + # disjoint agent should not be returned since it is not intersecting + self.geo_space.add_agents(self.disjoint_agent) + intersecting_agents_id = { + agent.unique_id + for agent in self.geo_space.get_relation( + self.polygon_agent, relation="intersects" + ) + } + self.assertEqual(intersecting_agents_id, agents_id) + + def test_get_intersecting_agents(self): + self.geo_space.add_agents(self.polygon_agent) + self.assertEqual( + list(self.geo_space.get_intersecting_agents(self.polygon_agent)), + [], + ) + + self.geo_space.add_agents(self.agents) + agents_id = {agent.unique_id for agent in self.agents} + intersecting_agents_id = { + agent.unique_id + for agent in self.geo_space.get_intersecting_agents(self.polygon_agent) + } + self.assertEqual(intersecting_agents_id, agents_id) + + # disjoint agent should not be returned since it is not intersecting + self.geo_space.add_agents(self.disjoint_agent) + intersecting_agents_id = { + agent.unique_id + for agent in self.geo_space.get_intersecting_agents(self.polygon_agent) + } + self.assertEqual(intersecting_agents_id, agents_id) + + def test_agents_at(self): + self.geo_space.add_agents(self.agents) + self.assertEqual( + len(list(self.geo_space.agents_at(self.agents[0].geometry))), + len(self.agents), + ) + agents_id = {agent.unique_id for agent in self.agents} + agents_id_found = { + agent.unique_id for agent in self.geo_space.agents_at((1, 1)) + } + self.assertEqual(agents_id_found, agents_id)