Skip to content

Commit

Permalink
added the option to give a single Column to the replace_column method…
Browse files Browse the repository at this point in the history
… and updated tests
  • Loading branch information
robmeth committed May 26, 2023
1 parent 06aa7ad commit 9aefec9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
9 changes: 7 additions & 2 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def rename_column(self, old_name: str, new_name: str) -> Table:
new_df.columns = self._schema.column_names
return Table._from_pandas_dataframe(new_df.rename(columns={old_name: new_name}))

def replace_column(self, old_column_name: str, new_columns: list[Column]) -> Table:
def replace_column(self, old_column_name: str, new_columns: Column | list[Column] | Table) -> Table:
"""
Return a copy of the table with the specified old column replaced by a list of new columns. Keeps the order of columns.
Expand Down Expand Up @@ -900,14 +900,19 @@ def replace_column(self, old_column_name: str, new_columns: list[Column]) -> Tab
if old_column_name not in self._schema.column_names:
raise UnknownColumnNameError([old_column_name])

if isinstance(new_columns, Column):
new_columns = [new_columns]
elif isinstance(new_columns, Table):
new_columns = new_columns.to_columns()

columns = list[Column]()
for old_column in self.schema.column_names:
if old_column == old_column_name:
for new_column in new_columns:
if new_column.name in self._schema.column_names and new_column.name != old_column_name:
raise DuplicateColumnNameError(new_column.name)

if self.number_of_rows != new_column.number_of_rows:
if self.number_of_rows != new_column._data.size:
raise ColumnSizeError(str(self.number_of_rows), str(new_column._data.size))
columns.append(new_column)
else:
Expand Down
41 changes: 32 additions & 9 deletions tests/safeds/data/tabular/containers/_table/test_replace_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
},
),
"C",
[Column("D", ["d", "e", "f"])],
Column("D", ["d", "e", "f"]),
Table(
{
"A": [1, 2, 3],
Expand All @@ -47,26 +47,50 @@
},
),
),
(
Table(
{
"A": [1, 2, 3],
"B": [4, 5, 6],
"C": ["a", "b", "c"],
},
),
"B",
Table(
{
"D": [7, 8, 9],
"E": ["c", "b", "a"],
},
),
Table(
{
"A": [1, 2, 3],
"D": [7, 8, 9],
"E": ["c", "b", "a"],
"C": ["a", "b", "c"],
},
),
),
],
ids=["list[Column]", "Column", "Table"],
)
def test_should_replace_column(table: Table, column_name: str, columns: list[Column], expected: Table) -> None:
def test_should_replace_column(table: Table, column_name: str, columns: Column | list[Column] | Table, expected: Table) -> None:
result = table.replace_column(column_name, columns)
assert result == expected


@pytest.mark.parametrize(
("old_column_name", "column_values", "column_name", "error"),
("old_column_name", "column", "error"),
[
("D", ["d", "e", "f"], "C", UnknownColumnNameError),
("C", ["d", "e", "f"], "B", DuplicateColumnNameError),
("C", ["d", "e"], "D", ColumnSizeError),
("D", Column("C", ["d", "e", "f"]), UnknownColumnNameError),
("C", [Column("B", ["d", "e", "f"]), Column("D", [3, 2, 1])], DuplicateColumnNameError),
("C", Table({"D": [7, 8], "E": ["c", "b"]}), ColumnSizeError),
],
ids=["UnknownColumnNameError", "DuplicateColumnNameError", "ColumnSizeError"],
)
def test_should_raise_error(
old_column_name: str,
column_values: list[str],
column_name: str,
column: Column | list[Column] | Table,
error: type[Exception],
) -> None:
input_table: Table = Table(
Expand All @@ -76,7 +100,6 @@ def test_should_raise_error(
"C": ["a", "b", "c"],
},
)
column = [Column(column_name, column_values)]

with pytest.raises(error):
input_table.replace_column(old_column_name, column)

0 comments on commit 9aefec9

Please sign in to comment.