Skip to content

Commit

Permalink
Adds drop_columns method to TimeSeries (#1040)
Browse files Browse the repository at this point in the history
* Add drop_columns method

* Add drop columns test

* Indentation fix

Co-authored-by: Julien Herzen <julien@unit8.co>
  • Loading branch information
shaido987 and hrzn authored Jun 27, 2022
1 parent a632b37 commit 8eb331a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
12 changes: 12 additions & 0 deletions darts/tests/test_timeseries_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,15 @@ def test_first_last_values(self):
self.assertEqual(
self.series3.univariate_component(1).last_values().tolist(), [20]
)

def test_drop_column(self):
# testing dropping a single column
seriesA = self.series1.drop_columns("0")
self.assertNotIn("0", seriesA.columns.values)
self.assertEqual(seriesA.columns.tolist(), ["1", "2"])
self.assertEqual(len(seriesA.columns), 2)

# testing dropping multiple columns
seriesB = self.series1.drop_columns(["0", "1"])
self.assertIn("2", seriesB.columns.values)
self.assertEqual(len(seriesB.columns), 1)
26 changes: 26 additions & 0 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2660,6 +2660,32 @@ def stack(self, other: "TimeSeries") -> "TimeSeries":
"""
return concatenate([self, other], axis=1)

def drop_columns(self, col_names: Union[List[str], str]) -> "TimeSeries":
"""
Return a new ``TimeSeries`` instance with dropped columns/components.
Parameters
-------
col_names
String or list of strings corresponding the the columns to be dropped.
Returns
-------
TimeSeries
A new TimeSeries instance with specified columns dropped.
"""
if isinstance(col_names, str):
col_names = [col_names]

raise_if_not(
all([(x in self.columns.to_list()) for x in col_names]),
"Some column names in col_names don't exist in the time series.",
logger,
)

new_xa = self._xa.drop_sel({"component": col_names})
return self.__class__(new_xa)

def univariate_component(self, index: Union[str, int]) -> "TimeSeries":
"""
Retrieve one of the components of the series
Expand Down

0 comments on commit 8eb331a

Please sign in to comment.