Skip to content

Commit

Permalink
ENH: add kwarg to generate_locations to filter for activities (#493)
Browse files Browse the repository at this point in the history
* ENH: add kwarg to generate_locations to filter activities

* ENH: raise error if activity column is missing

* ENH: move keywords of `generate_locations`
  • Loading branch information
bifbof authored Jul 31, 2023
1 parent 64b39db commit 07b201a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/preprocessing/test_staypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,31 @@ def test_method_error(self, example_staypoints):
example_staypoints.as_staypoints.generate_locations(method="unknown")
assert error_msg == str(e.value)

def test_activity_flag(self, example_staypoints):
    """Test if only activity staypoints are used if flag is set.

    Staypoint 6 is flagged as a non-activity, every other staypoint as an
    activity. With ``activities_only=True`` staypoints 1 and 15 are expected
    to share a location while staypoint 2 is expected to have none.
    """
    # take out staypoint 6 that should have been merged with 2, 15
    sp = example_staypoints
    data = [True, True, True, True, False, True, True, True]
    idx = [1, 2, 3, 5, 6, 7, 15, 80]
    activities = pd.Series(data, index=idx)
    sp["activity"] = activities
    sp, _ = sp.as_staypoints.generate_locations(
        method="dbscan",
        epsilon=10,
        num_samples=2,
        distance_metric="haversine",
        agg_level="user",
        activities_only=True,
    )
    assert sp.loc[1, "location_id"] == sp.loc[15, "location_id"]
    # Use pd.isna instead of an identity check against pd.NA: the missing
    # marker may be np.nan or pd.NA depending on the dtype of "location_id",
    # and `x is pd.NA` is False for np.nan.
    assert pd.isna(sp.loc[2, "location_id"])

def test_activity_flag_missing_column(self, example_staypoints):
    """Check that `activities_only=True` raises a KeyError without an `activity` column."""
    expected_msg = 'staypoints must contain column "activity" if "activities_only" flag is set.'
    # the call itself must stay inside the context manager so the KeyError is captured
    with pytest.raises(KeyError, match=expected_msg):
        example_staypoints.as_staypoints.generate_locations(
            activities_only=True,
        )


class TestMergeStaypoints:
def test_merge_staypoints(self, example_staypoints_merge):
Expand Down
14 changes: 14 additions & 0 deletions trackintel/preprocessing/staypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def generate_locations(
num_samples=1,
distance_metric="haversine",
agg_level="user",
activities_only=False,
print_progress=False,
n_jobs=1,
):
Expand Down Expand Up @@ -47,6 +48,10 @@ def generate_locations(
- 'user' : locations are generated independently per-user.
- 'dataset' : shared locations are generated for all users.
activities_only: bool, default False (requires "activity" column)
Flag to set if locations should be generated only from staypoints on which the value for "activity" is True.
Useful if activities represent more significant places.
print_progress : bool, default False
If print_progress is True, the progress bar is displayed
Expand Down Expand Up @@ -75,6 +80,12 @@ def generate_locations(

# initialize the return GeoDataFrames
sp = staypoints.copy()
non_activities = None
if activities_only:
if "activity" not in sp.columns:
raise KeyError('staypoints must contain column "activity" if "activities_only" flag is set.')
non_activities = sp[~sp["activity"]]
sp = sp[sp["activity"]]
sp = sp.sort_values(["user_id", "started_at"])
geo_col = sp.geometry.name

Expand Down Expand Up @@ -166,6 +177,9 @@ def generate_locations(
# staypoints not linked to a location receive np.nan in 'location_id'
sp.loc[sp["location_id"] == -1, "location_id"] = np.nan

# merge non_activities back if "activities_only" flag is set
sp = pd.concat([sp, non_activities])

if len(locs) > 0:
locs.as_locations # empty location is not valid
else:
Expand Down

0 comments on commit 07b201a

Please sign in to comment.