Skip to content

Commit

Permalink
refactor(pathline/endpoint plots): support recarray or dataframe (#1888)
Browse files Browse the repository at this point in the history
* refactor/expand tests
* tidy docstrings
  • Loading branch information
wpbonelli authored Aug 1, 2023
1 parent 47e5e35 commit 54d6099
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 50 deletions.
77 changes: 69 additions & 8 deletions autotest/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy as np
import pandas as pd
import pytest
from flaky import flaky
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -387,9 +388,41 @@ def modpath_model(function_tmpdir, example_data_path):


@requires_exe("mf2005", "mp6")
def test_xc_plot_particle_pathlines(modpath_model):
def test_plot_map_view_mp6_plot_pathline(modpath_model):
ml, mp, sim = modpath_model
mp.write_input()
mp.run_model(silent=False)

pthobj = PathlineFile(os.path.join(mp.model_ws, "ex6.mppth"))
well_pathlines = pthobj.get_destination_pathline_data(
dest_cells=[(4, 12, 12)]
)

def test_plot(pl):
mx = PlotMapView(model=ml)
mx.plot_grid()
mx.plot_bc("WEL", kper=2, color="blue")
pth = mx.plot_pathline(pl, colors="red")
# plt.show()
assert isinstance(pth, LineCollection)
assert len(pth._paths) == 114

# support pathlines as list of recarrays
test_plot(well_pathlines)

# support pathlines as list of dataframes
test_plot([pd.DataFrame(pl) for pl in well_pathlines])

# support pathlines as single recarray
test_plot(np.concatenate(well_pathlines))

# support pathlines as single dataframe
test_plot(pd.DataFrame(np.concatenate(well_pathlines)))


@requires_exe("mf2005", "mp6")
def test_plot_cross_section_mp6_plot_pathline(modpath_model):
ml, mp, sim = modpath_model
mp.write_input()
mp.run_model(silent=False)

Expand All @@ -398,24 +431,52 @@ def test_xc_plot_particle_pathlines(modpath_model):
dest_cells=[(4, 12, 12)]
)

mx = PlotCrossSection(model=ml, line={"row": 4})
mx.plot_bc("WEL", kper=2, color="blue")
pth = mx.plot_pathline(well_pathlines, method="cell", colors="red")
def test_plot(pl):
mx = PlotCrossSection(model=ml, line={"row": 4})
mx.plot_bc("WEL", kper=2, color="blue")
pth = mx.plot_pathline(pl, method="cell", colors="red")
assert isinstance(pth, LineCollection)
assert len(pth._paths) == 6

# support pathlines as list of recarrays
test_plot(well_pathlines)

# support pathlines as list of dataframes
test_plot([pd.DataFrame(pl) for pl in well_pathlines])

# support pathlines as single recarray
test_plot(np.concatenate(well_pathlines))

assert isinstance(pth, LineCollection)
assert len(pth._paths) == 6
# support pathlines as single dataframe
test_plot(pd.DataFrame(np.concatenate(well_pathlines)))


@requires_exe("mf2005", "mp6")
def test_map_plot_particle_endpoints(modpath_model):
def test_plot_map_view_mp6_endpoint(modpath_model):
ml, mp, sim = modpath_model
mp.write_input()
mp.run_model(silent=False)

pthobj = EndpointFile(os.path.join(mp.model_ws, "ex6.mpend"))
endpts = pthobj.get_alldata()

# color kwarg as scalar
# support endpoints as recarray
assert isinstance(endpts, np.recarray)
mv = PlotMapView(model=ml)
mv.plot_bc("WEL", kper=2, color="blue")
ep = mv.plot_endpoint(endpts, direction="ending")
# plt.show()
assert isinstance(ep, PathCollection)

# support endpoints as dataframe
mv = PlotMapView(model=ml)
mv.plot_bc("WEL", kper=2, color="blue")
ep = mv.plot_endpoint(pd.DataFrame(endpts), direction="ending")
# plt.show()
assert isinstance(ep, PathCollection)

# test various possibilities for endpoint color configuration.
# first, color kwarg as scalar
mv = PlotMapView(model=ml)
mv.plot_bc("WEL", kper=2, color="blue")
ep = mv.plot_endpoint(endpts, direction="ending", color="red")
Expand Down
7 changes: 7 additions & 0 deletions flopy/plot/crosssection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Polygon

from ..utils import geometry, import_optional_dependency
Expand Down Expand Up @@ -1098,6 +1099,12 @@ def plot_pathline(
else:
pl = [pl]

# make sure each element in pl is a recarray
pl = [
p.to_records(index=False) if isinstance(p, pd.DataFrame) else p
for p in pl
]

marker = kwargs.pop("marker", None)
markersize = kwargs.pop("markersize", None)
markersize = kwargs.pop("ms", markersize)
Expand Down
96 changes: 54 additions & 42 deletions flopy/plot/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection, PathCollection
from matplotlib.path import Path

Expand Down Expand Up @@ -695,33 +696,35 @@ def plot_vector(

def plot_pathline(self, pl, travel_time=None, **kwargs):
"""
Plot the MODPATH pathlines.
Plot MODPATH pathlines.
Parameters
----------
pl : list of rec arrays or a single rec array
rec array or list of rec arrays is data returned from
modpathfile PathlineFile get_data() or get_alldata()
methods. Data in rec array is 'x', 'y', 'z', 'time',
'k', and 'particleid'.
pl : list of recarrays or dataframes, or a single recarray or dataframe
Particle pathline data. If a list of recarrays or dataframes,
each must contain the path of only a single particle. If just
one recarray or dataframe, it should contain the paths of all
particles. Pathline data returned from PathlineFile.get_data()
or get_alldata() can be passed directly as this argument. Data
columns should be 'x', 'y', 'z', 'time', 'k', and 'particleid'
at minimum. Additional columns are ignored. The 'particleid'
column must be unique to each particle path.
travel_time : float or str
travel_time is a travel time selection for the displayed
pathlines. If a float is passed then pathlines with times
less than or equal to the passed time are plotted. If a
string is passed a variety logical constraints can be added
in front of a time value to select pathlines for a select
period of time. Valid logical constraints are <=, <, ==, >=, and
>. For example, to select all pathlines less than 10000 days
travel_time='< 10000' would be passed to plot_pathline.
(default is None)
kwargs : layer, ax, colors. The remaining kwargs are passed
into the LineCollection constructor. If layer='all',
pathlines are output for all layers
Travel time selection. If a float, then pathlines with total
time less than or equal to the given value are plotted. If a
string, the value must be a comparison operator, then a time
value. Valid operators are <=, <, ==, >=, and >. For example,
to filter pathlines with less than 10000 units of total time
traveled, use '< 10000'. (Default is None.)
kwargs : dict
Explicitly supported kwargs are layer, ax, colors.
Any remaining kwargs are passed into the LineCollection
constructor. If layer='all', pathlines are shown for all layers.
Returns
-------
lc : matplotlib.collections.LineCollection
The pathlines added to the plot.
"""

from matplotlib.collections import LineCollection
Expand All @@ -734,6 +737,11 @@ def plot_pathline(self, pl, travel_time=None, **kwargs):
else:
pl = [pl]

pl = [
p.to_records(index=False) if isinstance(p, pd.DataFrame) else p
for p in pl
]

if "layer" in kwargs:
kon = kwargs.pop("layer")
if isinstance(kon, bytes):
Expand Down Expand Up @@ -809,32 +817,35 @@ def plot_pathline(self, pl, travel_time=None, **kwargs):

def plot_timeseries(self, ts, travel_time=None, **kwargs):
"""
Plot the MODPATH timeseries.
Plot MODPATH timeseries.
Parameters
----------
ts : list of rec arrays or a single rec array
rec array or list of rec arrays is data returned from
modpathfile TimeseriesFile get_data() or get_alldata()
methods. Data in rec array is 'x', 'y', 'z', 'time',
'k', and 'particleid'.
ts : list of recarrays or dataframes, or a single recarray or dataframe
Particle timeseries data. If a list of recarrays or dataframes,
each must contain the path of only a single particle. If just
one recarray or dataframe, it should contain the paths of all
particles. Timeseries data returned from TimeseriesFile.get_data()
or get_alldata() can be passed directly as this argument. Data
columns should be 'x', 'y', 'z', 'time', 'k', and 'particleid'
at minimum. Additional columns are ignored. The 'particleid'
column must be unique to each particle path.
travel_time : float or str
travel_time is a travel time selection for the displayed
pathlines. If a float is passed then pathlines with times
less than or equal to the passed time are plotted. If a
string is passed a variety logical constraints can be added
in front of a time value to select pathlines for a select
period of time. Valid logical constraints are <=, <, ==, >=, and
>. For example, to select all pathlines less than 10000 days
travel_time='< 10000' would be passed to plot_pathline.
(default is None)
kwargs : layer, ax, colors. The remaining kwargs are passed
into the LineCollection constructor. If layer='all',
pathlines are output for all layers
Travel time selection. If a float, then pathlines with total
time less than or equal to the given value are plotted. If a
string, the value must be a comparison operator, then a time
value. Valid operators are <=, <, ==, >=, and >. For example,
to filter pathlines with less than 10000 units of total time
traveled, use '< 10000'. (Default is None.)
kwargs : dict
Explicitly supported kwargs are layer, ax, colors.
Any remaining kwargs are passed into the LineCollection
constructor. If layer='all', pathlines are shown for all layers.
Returns
-------
lo : list of Line2D objects
lc : matplotlib.collections.LineCollection
The pathlines added to the plot.
"""
if "color" in kwargs:
kwargs["markercolor"] = kwargs["color"]
Expand All @@ -850,13 +861,13 @@ def plot_endpoint(
**kwargs,
):
"""
Plot the MODPATH endpoints.
Plot MODPATH endpoints.
Parameters
----------
ep : rec array
ep : recarray or dataframe
A numpy recarray with the endpoint particle data from the
MODPATH 6 endpoint file
MODPATH endpoint file
direction : str
String defining if starting or ending particle locations should be
considered. (default is 'ending')
Expand All @@ -882,7 +893,8 @@ def plot_endpoint(
Returns
-------
sp : matplotlib.pyplot.scatter
sp : matplotlib.collections.PathCollection
The PathCollection added to the plot.
"""

Expand Down

0 comments on commit 54d6099

Please sign in to comment.