Skip to content

Commit

Permalink
ENH: remove .apply in predict_transport_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
bifbof committed Jan 7, 2024
1 parent a9a2820 commit 54697c5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 78 deletions.
13 changes: 0 additions & 13 deletions tests/analysis/test_label.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os

import numpy as np
import pandas as pd
import pytest

import trackintel as ti
from trackintel.analysis.labelling import _check_categories


class TestCreate_activity_flag:
Expand Down Expand Up @@ -100,14 +98,3 @@ def test_simple_coarse_identification_projected(self):
assert tpls_transport_mode_3.iloc[0]["mode"] == "slow_mobility"
assert tpls_transport_mode_3.iloc[1]["mode"] == "motorized_mobility"
assert tpls_transport_mode_3.iloc[2]["mode"] == "fast_mobility"

def test_check_categories(self):
"""Asserts the correct identification of valid category dictionaries."""
tpls_file = os.path.join("tests", "data", "triplegs_transport_mode_identification.csv")
tpls = ti.read_triplegs_csv(tpls_file, sep=";", index_col="id")
correct_dict = {2: "cat1", 7: "cat2", np.inf: "cat3"}

assert _check_categories(correct_dict)
with pytest.raises(ValueError):
incorrect_dict = {10: "cat1", 5: "cat2", np.inf: "cat3"}
tpls.as_triplegs.predict_transport_mode(method="simple-coarse", categories=incorrect_dict)
80 changes: 15 additions & 65 deletions trackintel/analysis/labelling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime

import numpy as np
import pandas as pd

from trackintel.geogr import get_speed_triplegs

Expand Down Expand Up @@ -81,90 +82,39 @@ def predict_transport_mode(triplegs, method="simple-coarse", **kwargs):
categories = kwargs.pop(
"categories", {15 / 3.6: "slow_mobility", 100 / 3.6: "motorized_mobility", np.inf: "fast_mobility"}
)

return _predict_transport_mode_simple_coarse(triplegs, categories)
triplegs = triplegs.copy()
triplegs["mode"] = _predict_transport_mode_simple_coarse(triplegs, categories)
return triplegs
else:
raise AttributeError(f"Method {method} not known for predicting tripleg transport modes.")


def _predict_transport_mode_simple_coarse(triplegs_in, categories):
def _predict_transport_mode_simple_coarse(triplegs, categories):
"""
Predict a transport mode out of three coarse classes.
Predict a transport mode based on provided categories.
Implements a simple speed based heuristic (over the whole tripleg).
As such, it is very fast, but also very simple and coarse.
Parameters
----------
triplegs_in : Triplegs
triplegs : Triplegs
The triplegs for the transport mode prediction.
categories : dict, optional
The categories for the speed classification {upper_boundary:'category_name'}.
The categories for the speed classification {upper_boundary: 'category_name'}.
The unit for the upper boundary is m/s.
The default is {15/3.6: 'slow_mobility', 100/3.6: 'motorized_mobility', np.inf: 'fast_mobility'}.
Raises
------
ValueError
In case the boundaries of the categories are not in ascending order.
Returns
-------
triplegs : trackintel triplegs GeoDataFrame
the triplegs with added column mode, containing the predicted transport modes.
cuts : pd.Series
Column containing the predicted transport modes.
For additional documentation, see
:func:`trackintel.analysis.labelling.predict_transport_mode`.
"""
if not (_check_categories(categories)):
raise ValueError("the categories must be in increasing order")

triplegs = triplegs_in.copy()

def category_by_speed(speed):
"""
Identify the mode based on the (overall) tripleg speed.
Parameters
----------
speed : float
the speed of one tripleg
Returns
-------
str
the identified mode.
"""
for bound in categories:
if speed < bound:
return categories[bound]

triplegs_speed = get_speed_triplegs(triplegs)

triplegs["mode"] = triplegs_speed["speed"].apply(category_by_speed)
return triplegs


def _check_categories(cat):
"""
Check if the keys of a dictionary are in ascending order.
Parameters
----------
cat : disct
the dictionary to be checked.
Returns
-------
correct : bool
True if dict keys are in ascending order False otherwise.
"""
correct = True
bounds = list(cat.keys())
for i in range(len(bounds) - 1):
if bounds[i] >= bounds[i + 1]:
correct = False
return correct
categories = dict(sorted(categories.items(), key=lambda item: item[0]))
intervals = pd.IntervalIndex.from_breaks([-np.inf] + list(categories.keys()), closed="left")
speed = get_speed_triplegs(triplegs)["speed"]
cuts = pd.cut(speed, intervals)
return cuts.cat.rename_categories(categories.values())

0 comments on commit 54697c5

Please sign in to comment.