Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance replace_column to accept a list of new columns #312

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def add_row(self, row: Row) -> Table:
result = Table._from_pandas_dataframe(new_df)

for column in int_columns:
result = result.replace_column(column, result.get_column(column).transform(lambda it: int(it)))
result = result.replace_column(column, [result.get_column(column).transform(lambda it: int(it))])

return result

Expand Down Expand Up @@ -768,7 +768,7 @@ def add_rows(self, rows: list[Row] | Table) -> Table:
result = Table._from_pandas_dataframe(new_df)

for column in int_columns:
result = result.replace_column(column, result.get_column(column).transform(lambda it: int(it)))
result = result.replace_column(column, [result.get_column(column).transform(lambda it: int(it))])

return result

Expand Down Expand Up @@ -1001,9 +1001,9 @@ def rename_column(self, old_name: str, new_name: str) -> Table:
new_df.columns = self._schema.column_names
return Table._from_pandas_dataframe(new_df.rename(columns={old_name: new_name}))

def replace_column(self, old_column_name: str, new_column: Column) -> Table:
def replace_column(self, old_column_name: str, new_columns: list[Column]) -> Table:
robmeth marked this conversation as resolved.
Show resolved Hide resolved
"""
Return a copy of the table with the specified old column replaced by a new column. Keeps the order of columns.
Return a copy of the table with the specified old column replaced by a list of new columns. Keeps the order of columns.
robmeth marked this conversation as resolved.
Show resolved Hide resolved

This table is not modified.

Expand All @@ -1012,44 +1012,42 @@ def replace_column(self, old_column_name: str, new_column: Column) -> Table:
old_column_name : str
The name of the column to be replaced.

new_column : Column
The new column replacing the old column.
new_columns : list[Column]
The list of new columns replacing the old column.

Returns
-------
result : Table
A table with the old column replaced by the new column.
A table with the old column replaced by the new columns.

Raises
------
UnknownColumnNameError
If the old column does not exist.

DuplicateColumnNameError
If the new column already exists and the existing column is not affected by the replacement.
If at least one of the new columns already exists and the existing column is not affected by the replacement.

ColumnSizeError
If the size of the column does not match the amount of rows.
If the size of at least one of the new columns does not match the amount of rows.
"""
if old_column_name not in self._schema.column_names:
raise UnknownColumnNameError([old_column_name])

if new_column.name in self._schema.column_names and new_column.name != old_column_name:
raise DuplicateColumnNameError(new_column.name)

if self.number_of_rows != new_column._data.size:
raise ColumnSizeError(str(self.number_of_rows), str(new_column._data.size))
columns = list[Column]()
for old_column in self.column_names:
if old_column == old_column_name:
for new_column in new_columns:
if new_column.name in self.column_names and new_column.name != old_column_name:
raise DuplicateColumnNameError(new_column.name)

if old_column_name != new_column.name:
renamed_table = self.rename_column(old_column_name, new_column.name)
result = renamed_table._data
result.columns = renamed_table._schema.column_names
else:
result = self._data.copy()
result.columns = self._schema.column_names
if self.number_of_rows != new_column.number_of_rows:
raise ColumnSizeError(str(self.number_of_rows), str(new_column.number_of_rows))
columns.append(new_column)
else:
columns.append(self.get_column(old_column))

result[new_column.name] = new_column._data
return Table._from_pandas_dataframe(result)
return Table.from_columns(columns)

def shuffle_rows(self) -> Table:
"""
Expand Down Expand Up @@ -1251,7 +1249,7 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl
"""
if self.has_column(name):
items: list = [transformer(item) for item in self.to_rows()]
result: Column = Column(name, items)
result: list[Column] = [Column(name, items)]
return self.replace_column(name, result)
raise UnknownColumnNameError([name])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@pytest.mark.parametrize(
("table", "column_name", "column", "expected"),
("table", "column_name", "columns", "expected"),
[
(
Table(
Expand All @@ -18,13 +18,14 @@
"C": ["a", "b", "c"],
},
),
"C",
Column("C", ["d", "e", "f"]),
"B",
[Column("B", ["d", "e", "f"]), Column("D", [3, 4, 5])],
Table(
{
"A": [1, 2, 3],
"B": [4, 5, 6],
"C": ["d", "e", "f"],
"B": ["d", "e", "f"],
"D": [3, 4, 5],
"C": ["a", "b", "c"],
},
),
),
Expand All @@ -37,7 +38,7 @@
},
),
"C",
Column("D", ["d", "e", "f"]),
[Column("D", ["d", "e", "f"])],
Table(
{
"A": [1, 2, 3],
Expand All @@ -47,26 +48,36 @@
),
),
],
ids=["multiple Columns", "one Column"],
)
def test_should_replace_column(table: Table, column_name: str, column: Column, expected: Table) -> None:
result = table.replace_column(column_name, column)
assert result.schema == expected.schema
robmeth marked this conversation as resolved.
Show resolved Hide resolved
def test_should_replace_column(table: Table, column_name: str, columns: list[Column], expected: Table) -> None:
result = table.replace_column(column_name, columns)
assert result._schema == expected._schema
assert result == expected


@pytest.mark.parametrize(
("old_column_name", "column_values", "column_name", "error", "error_message"),
("old_column_name", "column", "error", "error_message"),
[
("D", ["d", "e", "f"], "C", UnknownColumnNameError, r"Could not find column\(s\) 'D'"),
("C", ["d", "e", "f"], "B", DuplicateColumnNameError, r"Column 'B' already exists."),
("C", ["d", "e"], "D", ColumnSizeError, r"Expected a column of size 3 but got column of size 2."),
("D", [Column("C", ["d", "e", "f"])], UnknownColumnNameError, r"Could not find column\(s\) 'D'"),
(
"C",
[Column("B", ["d", "e", "f"]), Column("D", [3, 2, 1])],
DuplicateColumnNameError,
r"Column 'B' already exists.",
),
(
"C",
[Column("D", [7, 8]), Column("E", ["c", "b"])],
ColumnSizeError,
r"Expected a column of size 3 but got column of size 2.",
),
],
ids=["UnknownColumnNameError", "DuplicateColumnNameError", "ColumnSizeError"],
)
def test_should_raise_error(
old_column_name: str,
column_values: list[str],
column_name: str,
column: list[Column],
error: type[Exception],
error_message: str,
) -> None:
Expand All @@ -77,12 +88,11 @@ def test_should_raise_error(
"C": ["a", "b", "c"],
},
)
column = Column(column_name, column_values)

with pytest.raises(error, match=error_message):
input_table.replace_column(old_column_name, column)


def test_should_fail_on_empty_table() -> None:
with pytest.raises(UnknownColumnNameError):
Table().replace_column("col", Column("a", [1, 2]))
Table().replace_column("col", [Column("a", [1, 2])])