Skip to content

Commit

Permalink
TabularData: Add subset arg to dropna and drop_duplicates.
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Dec 17, 2021
1 parent b283900 commit 3964f3d
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
20 changes: 18 additions & 2 deletions dvc/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
List,
Mapping,
MutableSequence,
Optional,
Sequence,
Set,
Tuple,
Expand Down Expand Up @@ -224,7 +225,9 @@ def as_dict(
{k: self._columns[k][i] for k in keys} for i in range(len(self))
]

def dropna(self, axis: str = "rows", how="any"):
def dropna(
self, axis: str = "rows", how="any", subset: Optional[List] = None
):
if axis not in ["rows", "cols"]:
raise ValueError(
f"Invalid 'axis' value {axis}."
Expand All @@ -242,6 +245,8 @@ def dropna(self, axis: str = "rows", how="any"):

for n_row, row in enumerate(self):
for n_col, col in enumerate(row):
if subset and self.keys()[n_col] not in subset:
continue
if (col == self._fill_value) is match:
if axis == "rows":
match_line.add(n_row)
Expand Down Expand Up @@ -269,7 +274,9 @@ def dropna(self, axis: str = "rows", how="any"):
else:
self.drop(*to_drop)

def drop_duplicates(self, axis: str = "rows"):
def drop_duplicates(
self, axis: str = "rows", subset: Optional[List] = None
):
if axis not in ["rows", "cols"]:
raise ValueError(
f"Invalid 'axis' value {axis}."
Expand All @@ -279,6 +286,8 @@ def drop_duplicates(self, axis: str = "rows"):
if axis == "cols":
cols_to_drop: List[str] = []
for n_col, col in enumerate(self.columns):
if subset and self.keys()[n_col] not in subset:
continue
# Cast to str because Text is not hashable error
unique_vals = {str(x) for x in col if x != self._fill_value}
if len(unique_vals) == 1:
Expand All @@ -289,6 +298,13 @@ def drop_duplicates(self, axis: str = "rows"):
unique_rows = []
rows_to_drop: List[int] = []
for n_row, row in enumerate(self):
if subset:
row = [
col
for n_col, col in enumerate(row)
if self.keys()[n_col] in subset
]

tuple_row = tuple(row)
if tuple_row in unique_rows:
rows_to_drop.append(n_row)
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/test_tabular_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,20 @@ def test_dropna(axis, how, data, expected):
assert list(td) == expected


@pytest.mark.parametrize(
    ("axis", "expected"),
    [
        ("cols", [["foo", ""], ["foo", ""], ["foo", "foobar"]]),
        ("rows", [["foo", "bar", ""], ["foo", "bar", "foobar"]]),
    ],
)
def test_dropna_subset(axis, expected):
    """Check that ``dropna`` only inspects the columns named in ``subset``.

    The table is built from ragged rows; judging by the expected values,
    the missing cells show up as ``""`` (presumably the table's fill value).
    With ``subset=["col-1", "col-2"]``:

    * ``axis="cols"`` removes ``col-2`` but keeps ``col-3``, even though
      ``col-3`` also has missing cells, because it is outside the subset.
    * ``axis="rows"`` removes only the first row (missing ``col-2``); the
      second row survives despite its missing ``col-3`` cell.
    """
    ragged_rows = [["foo"], ["foo", "bar"], ["foo", "bar", "foobar"]]

    table = TabularData(["col-1", "col-2", "col-3"])
    table.extend(ragged_rows)

    table.dropna(axis, subset=["col-1", "col-2"])

    assert list(table) == expected


@pytest.mark.parametrize(
"axis,expected",
[
Expand Down Expand Up @@ -273,6 +287,51 @@ def test_drop_duplicates(axis, expected):
assert list(td) == expected


@pytest.mark.parametrize(
    ("axis", "subset", "expected"),
    [
        # Every row shares the same "col-1" value, so only the first stays.
        ("rows", ["col-1"], [["foo", "foo", "foo", "bar"]]),
        # Rows 1 and 2 agree on ("col-1", "col-3"); row 3 differs on "col-3".
        (
            "rows",
            ["col-1", "col-3"],
            [
                ["foo", "foo", "foo", "bar"],
                ["foo", "bar", "foobar", "bar"],
            ],
        ),
        # "col-1" is constant and in the subset, so it is dropped; "col-4"
        # is constant too but is not in the subset, so it is kept.
        (
            "cols",
            ["col-1", "col-3"],
            [
                ["foo", "foo", "bar"],
                ["bar", "foo", "bar"],
                ["bar", "foobar", "bar"],
            ],
        ),
    ],
)
def test_drop_duplicates_subset(axis, subset, expected):
    """Check that ``drop_duplicates`` restricts its comparison to ``subset``.

    The same three-row, four-column table is used for every case; the
    parametrized ``expected`` value is the table content after calling
    ``drop_duplicates(axis, subset=subset)``.
    """
    initial_rows = [
        ["foo", "foo", "foo", "bar"],
        ["foo", "bar", "foo", "bar"],
        ["foo", "bar", "foobar", "bar"],
    ]

    table = TabularData(["col-1", "col-2", "col-3", "col-4"])
    # Give the table its own row lists so the sanity check below compares
    # against data the table never had a reference to.
    table.extend([list(row) for row in initial_rows])

    # Sanity check: nothing is altered before drop_duplicates runs.
    assert list(table) == initial_rows

    table.drop_duplicates(axis, subset=subset)

    assert list(table) == expected


def test_dropna_invalid_axis():
td = TabularData(["col-1", "col-2", "col-3"])

Expand Down

0 comments on commit 3964f3d

Please sign in to comment.