diff --git a/examples/geo_schelling/model.py b/examples/geo_schelling/model.py index 486ebf42..3e50b314 100644 --- a/examples/geo_schelling/model.py +++ b/examples/geo_schelling/model.py @@ -56,9 +56,10 @@ def __repr__(self): class GeoSchelling(Model): """Model class for the Schelling segregation model.""" - def __init__(self, density=0.6, minority_pc=0.2): + def __init__(self, density=0.6, minority_pc=0.2, export_data=False): self.density = density self.minority_pc = minority_pc + self.export_data = export_data self.schedule = RandomActivation(self) self.space = GeoSpace(warn_crs_conversion=False) @@ -82,6 +83,11 @@ def __init__(self, density=0.6, minority_pc=0.2): agent.atype = 0 self.schedule.add(agent) + def export_agents_to_file(self) -> None: + self.space.get_agents_as_GeoDataFrame(agent_cls=SchellingAgent).to_crs( + "epsg:4326" + ).to_file("data/schelling_agents.geojson", driver="GeoJSON") + def step(self): """Run one step of the model. @@ -93,3 +99,6 @@ def step(self): if self.happy == self.schedule.get_agent_count(): self.running = False + + if not self.running and self.export_data: + self.export_agents_to_file() diff --git a/examples/geo_schelling/server.py b/examples/geo_schelling/server.py index a172d86b..2807eb67 100644 --- a/examples/geo_schelling/server.py +++ b/examples/geo_schelling/server.py @@ -21,6 +21,7 @@ def render(self, model): model_params = { "density": mesa.visualization.Slider("Agent density", 0.6, 0.1, 1.0, 0.1), "minority_pc": mesa.visualization.Slider("Fraction minority", 0.2, 0.00, 1.0, 0.05), + "export_data": mesa.visualization.Checkbox("Export data after simulation", False), } diff --git a/examples/rainfall/rainfall/model.py b/examples/rainfall/rainfall/model.py index 55b9f7cf..a16a5161 100644 --- a/examples/rainfall/rainfall/model.py +++ b/examples/rainfall/rainfall/model.py @@ -56,10 +56,12 @@ def step(self): class Rainfall(mesa.Model): - def __init__(self, rain_rate=500, water_height=5): + def __init__(self, rain_rate=500, water_height=5, export_data=False, num_steps=20): super().__init__() self.rain_rate = rain_rate self.water_amount = 0 + self.export_data = export_data + self.num_steps = num_steps self.space = CraterLake(crs="epsg:4326", water_height=water_height) self.schedule = mesa.time.RandomActivation(self) @@ -81,6 +83,13 @@ def contained(self): def outflow(self): return self.space.outflow + def export_water_level_to_file(self): + self.space.raster_layer.to_file( + raster_file="data/water_level.asc", + attr_name="water_level", + driver="AAIGrid", + ) + def step(self): for _ in range(self.rain_rate): random_x = np.random.randint(0, self.space.raster_layer.width) @@ -96,3 +105,9 @@ def step(self): self.schedule.step() self.datacollector.collect(self) + + self.num_steps -= 1 + if self.num_steps == 0: + self.running = False + if not self.running and self.export_data: + self.export_water_level_to_file() diff --git a/examples/rainfall/rainfall/server.py b/examples/rainfall/rainfall/server.py index c07cbdd4..cab7e0ac 100644 --- a/examples/rainfall/rainfall/server.py +++ b/examples/rainfall/rainfall/server.py @@ -10,6 +10,8 @@ model_params = { "rain_rate": mesa.visualization.Slider("rain rate", 500, 0, 500, 5), "water_height": mesa.visualization.Slider("water height", 5, 1, 5, 1), + "num_steps": mesa.visualization.Slider("total number of steps", 20, 1, 100, 1), + "export_data": mesa.visualization.Checkbox("export data after simulation", False), } diff --git a/mesa_geo/geospace.py b/mesa_geo/geospace.py index 38e6d3e1..ab9a1e09 100644 --- a/mesa_geo/geospace.py +++ b/mesa_geo/geospace.py @@ -226,6 +226,9 @@ def get_neighbors(self, agent): """Get (touching) neighbors of an agent.""" return self._agent_layer.get_neighbors(agent) + def get_agents_as_GeoDataFrame(self, agent_cls=GeoAgent) -> gpd.GeoDataFrame: + return self._agent_layer.get_agents_as_GeoDataFrame(agent_cls) + class _AgentLayer: """Layer that contains the GeoAgents. Mainly for internal usage within `GeoSpace`. @@ -368,3 +371,19 @@ def get_neighbors(self, agent): neighbors_idx = self._neighborhood.neighbors[idx] neighbors = [self.agents[i] for i in neighbors_idx] return neighbors + + def get_agents_as_GeoDataFrame(self, agent_cls=GeoAgent) -> gpd.GeoDataFrame: + agents_list = [] + crs = None + for agent in self.agents: + if isinstance(agent, agent_cls): + crs = agent.crs + agent_dict = { + attr: value + for attr, value in vars(agent).items() + if attr not in {"model", "pos", "_crs"} + } + agents_list.append(agent_dict) + agents_gdf = gpd.GeoDataFrame.from_records(agents_list, index="unique_id") + agents_gdf.crs = crs + return agents_gdf diff --git a/mesa_geo/raster_layers.py b/mesa_geo/raster_layers.py index 8b9ae3e6..c4d83304 100644 --- a/mesa_geo/raster_layers.py +++ b/mesa_geo/raster_layers.py @@ -154,6 +154,7 @@ class RasterLayer(RasterBase): cells: List[List[Cell]] _neighborhood_cache: Dict[Any, List[Coordinate]] + _attributes: Set[str] def __init__(self, width, height, crs, total_bounds, cell_cls: Type[Cell] = Cell): super().__init__(width, height, crs, total_bounds) @@ -166,8 +167,13 @@ def __init__(self, width, height, crs, total_bounds, cell_cls: Type[Cell] = Cell col.append(self.cell_cls(pos=(x, y), indices=(row_idx, col_idx))) self.cells.append(col) + self._attributes = set() self._neighborhood_cache = {} + @property + def attributes(self) -> Set[str]: + return self._attributes + @overload def __getitem__(self, index: int) -> List[Cell]: ... @@ -233,14 +239,53 @@ def coord_iter(self) -> Iterator[Tuple[Cell, int, int]]: for col in range(self.height): yield self.cells[row][col], row, col # cell, x, y - def apply_raster(self, data: np.ndarray, attr_name: str = None) -> None: - assert data.shape == (1, self.height, self.width) + def apply_raster(self, data: np.ndarray, attr_name: str | None = None) -> None: + """Apply raster data to the cells. + Args: + data: 2D numpy array with shape (1, height, width). + attr_name: name of the attribute to be added to the cells. If None, a random name will be generated. Default is None. + Returns: + None + Raises: + ValueError: if the shape of the data is not (1, height, width). + """ + if data.shape != (1, self.height, self.width): + raise ValueError( + f"Data shape does not match raster shape. " + f"Expected {(1, self.height, self.width)}, received {data.shape}." + ) if attr_name is None: attr_name = f"attribute_{len(self.cell_cls.__dict__)}" + self._attributes.add(attr_name) for x in range(self.width): for y in range(self.height): setattr(self.cells[x][y], attr_name, data[0, self.height - y - 1, x]) + def get_raster(self, attr_name: str | None = None) -> np.ndarray: + """Returns the values of given attribute. + Args: + attr_name: The name of the attribute to return. If None, returns all attributes. Default is None. + Returns: + The values of given attribute. + """ + if attr_name is not None and attr_name not in self.attributes: + raise ValueError( + f"Attribute {attr_name} does not exist. " + f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all." + ) + if attr_name is None: + num_bands = len(self.attributes) + attr_names = self.attributes + else: + num_bands = 1 + attr_names = {attr_name} + data = np.empty((num_bands, self.height, self.width)) + for ind, name in enumerate(attr_names): + for x in range(self.width): + for y in range(self.height): + data[ind, self.height - y - 1, x] = getattr(self.cells[x][y], name) + return data + def iter_neighborhood( self, pos: Coordinate, @@ -399,7 +444,7 @@ def to_image(self, colormap) -> ImageLayer: @classmethod def from_file( - cls, raster_file, cell_cls: Type[Cell] = Cell, attr_name: str = None + cls, raster_file: str, cell_cls: Type[Cell] = Cell, attr_name: str | None = None ) -> RasterLayer: with rio.open(raster_file, "r") as dataset: values = dataset.read() @@ -415,6 +460,27 @@ def from_file( obj.apply_raster(values, attr_name=attr_name) return obj + def to_file(self, raster_file: str, attr_name: str | None = None, driver="GTiff"): + """Writes a raster layer to a file. + Args: + raster_file: Path to the raster file to write. + attr_name: Name of the attribute to write to the raster. If None, all attributes are written. Default is None. + driver: Driver to use for writing the raster. Default is "GTiff" (see GDAL docs at https://gdal.org/drivers/raster/index.html). + """ + data = self.get_raster(attr_name) + with rio.open( + raster_file, + "w", + driver=driver, + width=self.width, + height=self.height, + count=data.shape[0], + dtype=data.dtype, + crs=self.crs, + transform=self.transform, + ) as dataset: + dataset.write(data) + class ImageLayer(RasterBase): _values: np.ndarray diff --git a/tests/test_GeoSpace.py b/tests/test_GeoSpace.py index a473d88f..460b997a 100644 --- a/tests/test_GeoSpace.py +++ b/tests/test_GeoSpace.py @@ -4,6 +4,7 @@ import warnings import numpy as np +import pandas as pd import geopandas as gpd from shapely.geometry import Point @@ -116,3 +117,20 @@ def test_get_neighbors_within_distance(self): self.geo_space.get_neighbors_within_distance(agent_to_check, distance=1.0) ) self.assertEqual(len(neighbors), 7) + + def test_get_agents_as_GeoDataFrame(self): + self.geo_space.add_agents(self.agents) + + agents_list = [ + {"geometry": agent.geometry, "unique_id": agent.unique_id} + for agent in self.agents + ] + agents_gdf = gpd.GeoDataFrame.from_records(agents_list, index="unique_id") + agents_gdf.crs = self.geo_space.crs + + pd.testing.assert_frame_equal( + self.geo_space.get_agents_as_GeoDataFrame(), agents_gdf + ) + self.assertEqual( + self.geo_space.get_agents_as_GeoDataFrame().crs, agents_gdf.crs + ) diff --git a/tests/test_RasterLayer.py b/tests/test_RasterLayer.py index 4e8c116a..856115ea 100644 --- a/tests/test_RasterLayer.py +++ b/tests/test_RasterLayer.py @@ -37,9 +37,40 @@ def test_apple_raster(self): [5, 6]]] """ self.assertEqual(self.raster_layer.cells[0][1].attribute_5, 3) + self.assertEqual(self.raster_layer.attributes, {"attribute_5"}) self.raster_layer.apply_raster(raster_data, attr_name="elevation") self.assertEqual(self.raster_layer.cells[0][1].elevation, 3) + self.assertEqual(self.raster_layer.attributes, {"attribute_5", "elevation"}) + + with self.assertRaises(ValueError): + self.raster_layer.apply_raster(np.empty((1, 100, 100))) + + def test_get_raster(self): + raster_data = np.array([[[1, 2], [3, 4], [5, 6]]]) + self.raster_layer.apply_raster(raster_data) + """ + (x, y) coordinates: + (0, 2), (1, 2) + (0, 1), (1, 1) + (0, 0), (1, 0) + + values: + [[[1, 2], + [3, 4], + [5, 6]]] + """ + self.raster_layer.apply_raster(raster_data, attr_name="elevation") + np.testing.assert_array_equal( + self.raster_layer.get_raster(attr_name="elevation"), raster_data + ) + + self.raster_layer.apply_raster(raster_data) + np.testing.assert_array_equal( + self.raster_layer.get_raster(), np.concatenate((raster_data, raster_data)) + ) + with self.assertRaises(ValueError): + self.raster_layer.get_raster("not_existing_attr") def test_get_min_cell(self): self.raster_layer.apply_raster(