Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consider splits and merges in tdating #349

Merged
merged 13 commits into from
Jul 22, 2024
11 changes: 11 additions & 0 deletions examples/thunderstorm_detection_and_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@
# Properties of one of the identified cells:
print(cells_id.iloc[0])

###############################################################################
# Optionally, one can also ask to consider splits and merges of thunderstorm cells.
# A cell at time t is considered to split if it will overlap more than 10% with more than
# one cell at time t+1. Conversely, a cell is considered to be a merge, if more
# than one cell from time t will overlap more than 10% with it.

cells_id, labels = tstorm_detect.detection(
input_image, time=time, output_splits_merges=True
)
print(cells_id.iloc[0])

###############################################################################
# Example of thunderstorm tracking over a timeseries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
59 changes: 55 additions & 4 deletions pysteps/feature/tstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def detection(
minmax=41,
mindis=10,
output_feat=False,
output_splits_merges=False,
time="000000000",
):
"""
Expand Down Expand Up @@ -93,6 +94,10 @@ def detection(
smaller distance will be merged. The default is 10 km.
output_feat: bool, optional
Set to True to return only the cell coordinates.
output_splits_merges: bool, optional
Set to True to return additional columns in the dataframe for describing the
splitting and merging of cells. Note that columns are initialized with None,
and the information needs to be analyzed while tracking.
time: string, optional
Date and time as string. Used to label time in the resulting dataframe.
The default is '000000000'.
Expand Down Expand Up @@ -166,7 +171,15 @@ def detection(

areas, lines = breakup(input_image, np.nanmin(input_image.flatten()), maxima_dis)

cells_id, labels = get_profile(areas, binary, input_image, loc_max, time, minref)
cells_id, labels = get_profile(
areas,
binary,
input_image,
loc_max,
time,
minref,
output_splits_merges=output_splits_merges,
)

if max_num_features is not None:
idx = np.argsort(cells_id.area.to_numpy())[::-1]
Expand Down Expand Up @@ -225,10 +238,12 @@ def longdistance(loc_max, mindis):
return new_max


def get_profile(areas, binary, ref, loc_max, time, minref):
def get_profile(areas, binary, ref, loc_max, time, minref, output_splits_merges=False):
"""
This function returns the identified cells in a dataframe including their x,y
locations, location of their maxima, maximum reflectivity and contours.
Optionally, the dataframe can include columns for storing information regarding
splitting and merging of cells.
"""
cells = areas * binary
cell_labels = cells[loc_max]
Expand All @@ -255,11 +270,47 @@ def get_profile(areas, binary, ref, loc_max, time, minref):
"area": len(x),
}
)
if output_splits_merges:
cells_id[-1].update(
{
"splitted": None,
"split_IDs": None,
"merged": None,
"merged_IDs": None,
"results_from_split": None,
"will_merge": None,
}
)
labels[cells == cell_labels[n]] = this_id

columns = [
"ID",
"time",
"x",
"y",
"cen_x",
"cen_y",
"max_ref",
"cont",
"area",
]
if output_splits_merges:
columns.extend(
[
"splitted",
"split_IDs",
"merged",
"merged_IDs",
"results_from_split",
"will_merge",
]
)
cells_id = pd.DataFrame(
data=cells_id,
index=range(len(cell_labels)),
columns=["ID", "time", "x", "y", "cen_x", "cen_y", "max_ref", "cont", "area"],
columns=columns,
)

if output_splits_merges:
cells_id["split_IDs"] = cells_id["split_IDs"].astype("object")
cells_id["merged_IDs"] = cells_id["merged_IDs"].astype("object")
return cells_id, labels
65 changes: 56 additions & 9 deletions pysteps/tests/test_feature_tstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,29 @@
except ModuleNotFoundError:
pass

arg_names = ("source", "output_feat", "dry_input", "max_num_features")
arg_names = (
"source",
"output_feat",
"dry_input",
"max_num_features",
"output_split_merge",
)

arg_values = [
("mch", False, False, None),
("mch", False, False, 5),
("mch", True, False, None),
("mch", True, False, 5),
("mch", False, True, None),
("mch", False, True, 5),
("mch", False, False, None, False),
("mch", False, False, 5, False),
("mch", True, False, None, False),
("mch", True, False, 5, False),
("mch", False, True, None, False),
("mch", False, True, 5, False),
("mch", False, False, None, True),
]


@pytest.mark.parametrize(arg_names, arg_values)
def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_features):
def test_feature_tstorm_detection(
source, output_feat, dry_input, max_num_features, output_split_merge
):
pytest.importorskip("skimage")
pytest.importorskip("pandas")

Expand All @@ -36,7 +45,11 @@

time = "000"
output = detection(
input, time=time, output_feat=output_feat, max_num_features=max_num_features
input,
time=time,
output_feat=output_feat,
max_num_features=max_num_features,
output_splits_merges=output_split_merge,
)

if output_feat:
Expand All @@ -45,6 +58,40 @@
assert output.shape[1] == 2
if max_num_features is not None:
assert output.shape[0] <= max_num_features
elif output_split_merge:
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], DataFrame)
assert isinstance(output[1], np.ndarray)
if max_num_features is not None:
assert output[0].shape[0] <= max_num_features

Check warning on line 67 in pysteps/tests/test_feature_tstorm.py

View check run for this annotation

Codecov / codecov/patch

pysteps/tests/test_feature_tstorm.py#L67

Added line #L67 was not covered by tests
assert output[0].shape[1] == 15
assert list(output[0].columns) == [
"ID",
"time",
"x",
"y",
"cen_x",
"cen_y",
"max_ref",
"cont",
"area",
"splitted",
"split_IDs",
"merged",
"merged_IDs",
"results_from_split",
"will_merge",
]
assert (output[0].time == time).all()
assert output[1].ndim == 2
assert output[1].shape == input.shape
if not dry_input:
assert output[0].shape[0] > 0
assert sorted(list(output[0].ID)) == sorted(list(np.unique(output[1]))[1:])
else:
assert output[0].shape[0] == 0
assert output[1].sum() == 0

Check warning on line 94 in pysteps/tests/test_feature_tstorm.py

View check run for this annotation

Codecov / codecov/patch

pysteps/tests/test_feature_tstorm.py#L93-L94

Added lines #L93 - L94 were not covered by tests
else:
assert isinstance(output, tuple)
assert len(output) == 2
Expand Down
32 changes: 21 additions & 11 deletions pysteps/tests/test_tracking_tdating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@
from pysteps.utils import to_reflectivity
from pysteps.tests.helpers import get_precipitation_fields

arg_names = ("source", "dry_input")
arg_names = ("source", "dry_input", "output_splits_merges")

arg_values = [
("mch", False),
("mch", False),
("mch", True),
("mch", False, False),
("mch", False, False),
("mch", True, False),
("mch", False, True),
]

arg_names_multistep = ("source", "len_timesteps")
arg_names_multistep = ("source", "len_timesteps", "output_splits_merges")
arg_values_multistep = [
("mch", 6),
("mch", 6, False),
("mch", 6, True),
]


@pytest.mark.parametrize(arg_names_multistep, arg_values_multistep)
def test_tracking_tdating_dating_multistep(source, len_timesteps):
def test_tracking_tdating_dating_multistep(source, len_timesteps, output_splits_merges):
pytest.importorskip("skimage")

input_fields, metadata = get_precipitation_fields(
Expand All @@ -37,6 +39,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
input_fields[0 : len_timesteps // 2],
timelist[0 : len_timesteps // 2],
mintrack=1,
output_splits_merges=output_splits_merges,
)
# Second half of timesteps
tracks_2, cells, _ = dating(
Expand All @@ -46,6 +49,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
start=2,
cell_list=cells,
label_list=labels,
output_splits_merges=output_splits_merges,
)

# Since we are adding cells, number of tracks should increase
Expand All @@ -67,7 +71,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):


@pytest.mark.parametrize(arg_names, arg_values)
def test_tracking_tdating_dating(source, dry_input):
def test_tracking_tdating_dating(source, dry_input, output_splits_merges):
pytest.importorskip("skimage")
pandas = pytest.importorskip("pandas")

Expand All @@ -80,7 +84,13 @@ def test_tracking_tdating_dating(source, dry_input):

timelist = metadata["timestamps"]

output = dating(input, timelist, mintrack=1)
cell_column_length = 9
if output_splits_merges:
cell_column_length = 15

output = dating(
input, timelist, mintrack=1, output_splits_merges=output_splits_merges
)

# Check output format
assert isinstance(output, tuple)
Expand All @@ -92,12 +102,12 @@ def test_tracking_tdating_dating(source, dry_input):
assert len(output[2]) == input.shape[0]
assert isinstance(output[1][0], pandas.DataFrame)
assert isinstance(output[2][0], np.ndarray)
assert output[1][0].shape[1] == 9
assert output[1][0].shape[1] == cell_column_length
assert output[2][0].shape == input.shape[1:]
if not dry_input:
assert len(output[0]) > 0
assert isinstance(output[0][0], pandas.DataFrame)
assert output[0][0].shape[1] == 9
assert output[0][0].shape[1] == cell_column_length
else:
assert len(output[0]) == 0
assert output[1][0].shape[0] == 0
Expand Down
Loading