Skip to content

Commit

Permalink
Enhance HolidayTransform (#1340)
Browse files Browse the repository at this point in the history
* Enhance holiday transform

* Tests and changelog

* Fix linters

* Fixes for PR

* Fixes for PR

* Fix tests

* Fix PR 2

* Fix PR 3

* Fix PR 4
  • Loading branch information
malodetz authored Aug 7, 2023
1 parent f4bcc29 commit 9824888
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 19 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Implementation of PatchTS model ([#1277](https://github.com/tinkoff-ai/etna/pull/1277))

### Changed
-
-
- Add modes `binary` and `category` to `HolidayTransform` ([#763](https://github.com/tinkoff-ai/etna/pull/763))
- Add sorting by timestamp before the fit in `CatBoostPerSegmentModel` and `CatBoostMultiSegmentModel` ([#1337](https://github.com/tinkoff-ai/etna/pull/1337))
- Speed up metrics computation by optimizing segment validation, forbid NaNs during metrics computation ([#1338](https://github.com/tinkoff-ai/etna/pull/1338))
- Unify errors, warnings and checks in models ([#1312](https://github.com/tinkoff-ai/etna/pull/1312))
Expand Down
42 changes: 36 additions & 6 deletions etna/transforms/timestamp/holiday.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from enum import Enum
from typing import List
from typing import Optional

Expand All @@ -10,23 +11,47 @@
from etna.transforms.base import IrreversibleTransform


class HolidayTransformMode(str, Enum):
    """Enum for the output modes of ``HolidayTransform``.

    ``binary`` flags whether a day is a holiday, ``category`` gives the name of the holiday.
    """

    binary = "binary"
    category = "category"

    @classmethod
    def _missing_(cls, value):
        """Reject a value that does not match any member.

        ``Enum`` calls this hook when a lookup by value fails.

        Parameters
        ----------
        value:
            value that was not found among the members

        Raises
        ------
        NotImplementedError:
            always; the message lists the supported modes
        """
        supported = ", ".join(repr(member.value) for member in cls)
        raise NotImplementedError(f"{value} is not a valid {cls.__name__}. Supported mode: {supported}")


class HolidayTransform(IrreversibleTransform, FutureMixin):
"""HolidayTransform generates series that indicates holidays in given dataframe."""
"""
HolidayTransform generates series that indicates holidays in given dataset.
def __init__(self, iso_code: str = "RUS", out_column: Optional[str] = None):
In ``binary`` mode shows the presence of holiday in that day. In ``category`` mode shows the name of the holiday
with value "NO_HOLIDAY" reserved for days without holidays.
"""

_no_holiday_name: str = "NO_HOLIDAY"

def __init__(self, iso_code: str = "RUS", mode: str = "binary", out_column: Optional[str] = None):
    """
    Create instance of HolidayTransform.

    Parameters
    ----------
    iso_code:
        internationally recognised code of the country whose holidays should be looked up
    mode:
        ``binary`` to indicate whether each day is a holiday,
        ``category`` to give the name of the holiday for each day
    out_column:
        name of added column. Use ``self.__repr__()`` if not given.

    Raises
    ------
    NotImplementedError:
        if ``mode`` is not a value supported by ``HolidayTransformMode``
    """
    super().__init__(required_features=["target"])
    self.iso_code = iso_code
    # ``mode`` keeps the raw constructor argument; ``_mode`` holds the validated enum member.
    self.mode = mode
    self._mode = HolidayTransformMode(mode)
    # Build the holiday calendar once; the earlier ``holidays.CountryHoliday`` assignment
    # was immediately overwritten by this one and has been dropped.
    self.holidays = holidays.country_holidays(iso_code)
    self.out_column = out_column

def _get_column_name(self) -> str:
Expand All @@ -48,7 +73,7 @@ def _fit(self, df: pd.DataFrame) -> "HolidayTransform":

def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform data from df with HolidayTransform and generate a column of holidays flags.
Transform data from df with HolidayTransform and generate a column of holidays flags or its titles.
Parameters
----------
Expand All @@ -64,9 +89,14 @@ def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
raise ValueError("Frequency of data should be no more than daily.")

cols = df.columns.get_level_values("segment").unique()

out_column = self._get_column_name()
encoded_matrix = np.array([int(x in self.holidays) for x in df.index])

if self._mode is HolidayTransformMode.category:
encoded_matrix = np.array(
[self.holidays[x] if x in self.holidays else self._no_holiday_name for x in df.index]
)
else:
encoded_matrix = np.array([int(x in self.holidays) for x in df.index])
encoded_matrix = encoded_matrix.reshape(-1, 1).repeat(len(cols), axis=1)
encoded_df = pd.DataFrame(
encoded_matrix,
Expand Down
18 changes: 12 additions & 6 deletions tests/test_transforms/test_inference/test_inverse_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def _test_inverse_transform_train_subset_segments(self, ts, transform, segments)
# timestamp
(DateFlagsTransform(), "regular_ts"),
(FourierTransform(period=7, order=2), "regular_ts"),
(HolidayTransform(), "regular_ts"),
(HolidayTransform(mode="binary"), "regular_ts"),
(HolidayTransform(mode="category"), "regular_ts"),
(SpecialDaysTransform(), "regular_ts"),
(TimeFlagsTransform(), "regular_ts"),
],
Expand Down Expand Up @@ -427,7 +428,8 @@ def _test_inverse_transform_future_subset_segments(self, ts, transform, segments
# timestamp
(DateFlagsTransform(), "regular_ts"),
(FourierTransform(period=7, order=2), "regular_ts"),
(HolidayTransform(), "regular_ts"),
(HolidayTransform(mode="binary"), "regular_ts"),
(HolidayTransform(mode="category"), "regular_ts"),
(SpecialDaysTransform(), "regular_ts"),
(TimeFlagsTransform(), "regular_ts"),
],
Expand Down Expand Up @@ -654,7 +656,8 @@ def _test_inverse_transform_train_new_segments(self, ts, transform, train_segmen
"regular_ts",
{},
),
(HolidayTransform(out_column="res"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down Expand Up @@ -981,7 +984,8 @@ def _test_inverse_transform_future_new_segments(self, ts, transform, train_segme
"regular_ts",
{},
),
(HolidayTransform(out_column="res"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down Expand Up @@ -1465,7 +1469,8 @@ def _test_inverse_transform_future_with_target(
"regular_ts",
{},
),
(HolidayTransform(out_column="res"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down Expand Up @@ -1883,7 +1888,8 @@ def _test_inverse_transform_future_without_target(
"regular_ts",
{},
),
(HolidayTransform(out_column="res"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down
18 changes: 12 additions & 6 deletions tests/test_transforms/test_inference/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ def _test_transform_train_subset_segments(self, ts, transform, segments):
# timestamp
(DateFlagsTransform(), "regular_ts"),
(FourierTransform(period=7, order=2), "regular_ts"),
(HolidayTransform(), "regular_ts"),
(HolidayTransform(mode="binary"), "regular_ts"),
(HolidayTransform(mode="category"), "regular_ts"),
(SpecialDaysTransform(), "regular_ts"),
(TimeFlagsTransform(), "regular_ts"),
],
Expand Down Expand Up @@ -409,7 +410,8 @@ def _test_transform_future_subset_segments(self, ts, transform, segments, horizo
# timestamp
(DateFlagsTransform(), "regular_ts"),
(FourierTransform(period=7, order=2), "regular_ts"),
(HolidayTransform(), "regular_ts"),
(HolidayTransform(mode="binary"), "regular_ts"),
(HolidayTransform(mode="category"), "regular_ts"),
(SpecialDaysTransform(), "regular_ts"),
(TimeFlagsTransform(), "regular_ts"),
],
Expand Down Expand Up @@ -605,7 +607,8 @@ def _test_transform_train_new_segments(self, ts, transform, train_segments, expe
"regular_ts",
{"create": {"res_1", "res_2", "res_3", "res_4"}},
),
(HolidayTransform(out_column="res"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {"create": {"res"}}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down Expand Up @@ -925,7 +928,8 @@ def _test_transform_future_new_segments(self, ts, transform, train_segments, exp
"regular_ts",
{"create": {"res_1", "res_2", "res_3", "res_4"}},
),
(HolidayTransform(out_column="res"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {"create": {"res"}}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down Expand Up @@ -1326,7 +1330,8 @@ def _test_transform_future_with_target(self, ts, transform, expected_changes, ga
"regular_ts",
{"create": {"res_1", "res_2", "res_3", "res_4"}},
),
(HolidayTransform(out_column="res"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {"create": {"res"}}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down Expand Up @@ -1705,7 +1710,8 @@ def _test_transform_future_without_target(self, ts, transform, expected_changes,
"regular_ts",
{"create": {"res_1", "res_2", "res_3", "res_4"}},
),
(HolidayTransform(out_column="res"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="binary"), "regular_ts", {"create": {"res"}}),
(HolidayTransform(out_column="res", mode="category"), "regular_ts", {"create": {"res"}}),
(
TimeFlagsTransform(out_column="res"),
"regular_ts",
Expand Down
28 changes: 28 additions & 0 deletions tests/test_transforms/test_timestamp/test_holiday_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ def two_segments_simple_ts_min(simple_constant_df_min: pd.DataFrame):
return ts


@pytest.fixture()
def uk_holiday_names_daily():
    """Expected category-mode output for UK: two named holidays followed by 13 days without one."""
    named_days = ["New Year's Day", "New Year Holiday [Scotland]"]
    regular_days = ["NO_HOLIDAY"] * 13
    return np.array(named_days + regular_days)


@pytest.fixture()
def us_holiday_names_daily():
    """Expected category-mode output for US: one named holiday followed by 14 days without one."""
    return np.array(["New Year's Day", *(["NO_HOLIDAY"] * 14)])


def test_holiday_with_regressors(simple_ts_with_regressors: TSDataset):
holiday = HolidayTransform(out_column="holiday")
new = holiday.fit_transform(simple_ts_with_regressors)
Expand Down Expand Up @@ -136,6 +148,22 @@ def test_holidays_day(iso_code: str, answer: np.array, two_segments_simple_ts_da
assert np.array_equal(df[segment]["regressor_holidays"].values, answer)


def test_uk_holidays_day_category(uk_holiday_names_daily: np.array, two_segments_simple_ts_daily: TSDataset):
    """Check that category mode with the UK calendar produces the expected names in every segment."""
    out_column = "regressor_holidays"
    transform = HolidayTransform(iso_code="UK", mode="category", out_column=out_column)
    transformed = transform.fit_transform(two_segments_simple_ts_daily).to_pandas()
    segment_names = transformed.columns.get_level_values("segment").unique()
    for name in segment_names:
        assert np.array_equal(transformed[name][out_column].values, uk_holiday_names_daily)


def test_us_holidays_day_category(us_holiday_names_daily: np.array, two_segments_simple_ts_daily: TSDataset):
    """Check that category mode with the US calendar produces the expected names in every segment."""
    out_column = "regressor_holidays"
    transform = HolidayTransform(iso_code="US", mode="category", out_column=out_column)
    transformed = transform.fit_transform(two_segments_simple_ts_daily).to_pandas()
    segment_names = transformed.columns.get_level_values("segment").unique()
    for name in segment_names:
        assert np.array_equal(transformed[name][out_column].values, us_holiday_names_daily)


@pytest.mark.parametrize(
"iso_code,answer",
(
Expand Down

1 comment on commit 9824888

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.