diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index d31ea0a5fa1e9..0ba8b4debd8f4 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2816,8 +2816,17 @@ cdef class RecordBatch(_Tabular): Parameters ---------- - names : list of str - List of new column names. + names : list[str] or dict[str, str] + List of new column names or mapping of old column names to new column names. + + If a mapping of old to new column names is passed, then all columns which are + found to match a provided old column name will be renamed to the new column name. + If any column names are not found in the mapping, a KeyError will be raised. + + Raises + ------ + KeyError + If any of the column names passed in the names mapping do not exist. Returns ------- @@ -2838,13 +2847,38 @@ cdef class RecordBatch(_Tabular): ---- n: [2,4,5,100] name: ["Flamingo","Horse","Brittle stars","Centipede"] + >>> new_names = {"n_legs": "n", "animals": "name"} + >>> batch.rename_columns(new_names) + pyarrow.RecordBatch + n: int64 + name: string + ---- + n: [2,4,5,100] + name: ["Flamingo","Horse","Brittle stars","Centipede"] """ cdef: shared_ptr[CRecordBatch] c_batch vector[c_string] c_names - for name in names: - c_names.push_back(tobytes(name)) + if isinstance(names, list): + for name in names: + c_names.push_back(tobytes(name)) + elif isinstance(names, dict): + idx_to_new_name = {} + for name, new_name in names.items(): + indices = self.schema.get_all_field_indices(name) + + if not indices: + raise KeyError("Column {!r} not found".format(name)) + + for index in indices: + idx_to_new_name[index] = new_name + + for i in range(self.num_columns): + new_name = idx_to_new_name.get(i, self.column_names[i]) + c_names.push_back(tobytes(new_name)) + else: + raise TypeError(f"names must be a list or dict not {type(names)!r}") with nogil: c_batch = GetResultValue(self.batch.RenameColumns(move(c_names))) @@ -5215,8 +5249,17 @@ cdef class Table(_Tabular): Parameters ---------- - names : list of str - List of new column names. + names : list[str] or dict[str, str] + List of new column names or mapping of old column names to new column names. + + If a mapping of old to new column names is passed, then all columns which are + found to match a provided old column name will be renamed to the new column name. + If any column names are not found in the mapping, a KeyError will be raised. + + Raises + ------ + KeyError + If any of the column names passed in the names mapping do not exist. Returns ------- @@ -5237,13 +5280,37 @@ cdef class Table(_Tabular): ---- n: [[2,4,5,100]] name: [["Flamingo","Horse","Brittle stars","Centipede"]] + >>> new_names = {"n_legs": "n", "animals": "name"} + >>> table.rename_columns(new_names) + pyarrow.Table + n: int64 + name: string + ---- + n: [[2,4,5,100]] + name: [["Flamingo","Horse","Brittle stars","Centipede"]] """ cdef: shared_ptr[CTable] c_table vector[c_string] c_names - for name in names: - c_names.push_back(tobytes(name)) + if isinstance(names, list): + for name in names: + c_names.push_back(tobytes(name)) + elif isinstance(names, dict): + idx_to_new_name = {} + for name, new_name in names.items(): + indices = self.schema.get_all_field_indices(name) + + if not indices: + raise KeyError("Column {!r} not found".format(name)) + + for index in indices: + idx_to_new_name[index] = new_name + + for i in range(self.num_columns): + c_names.push_back(tobytes(idx_to_new_name.get(i, self.schema[i].name))) + else: + raise TypeError(f"names must be a list or dict not {type(names)!r}") with nogil: c_table = GetResultValue(self.table.RenameColumns(move(c_names))) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 539da0e685381..7a140d4132c50 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -1737,6 +1737,43 @@ def test_table_rename_columns(cls): expected = cls.from_arrays(data, names=['eh', 'bee', 'sea']) assert t2.equals(expected) + message = "names must be a list or dict not " + with pytest.raises(TypeError, match=message): + table.rename_columns('not a list') + + +@pytest.mark.parametrize( + ('cls'), + [ + (pa.Table), + (pa.RecordBatch) + ] +) +def test_table_rename_columns_mapping(cls): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + table = cls.from_arrays(data, names=['a', 'b', 'c']) + assert table.column_names == ['a', 'b', 'c'] + + expected = cls.from_arrays(data, names=['eh', 'b', 'sea']) + t1 = table.rename_columns({'a': 'eh', 'c': 'sea'}) + t1.validate() + assert t1 == expected + + # Test renaming duplicate column names + table = cls.from_arrays(data, names=['a', 'a', 'c']) + expected = cls.from_arrays(data, names=['eh', 'eh', 'sea']) + t2 = table.rename_columns({'a': 'eh', 'c': 'sea'}) + t2.validate() + assert t2 == expected + + # Test column not found + with pytest.raises(KeyError, match=r"Column 'd' not found"): + table.rename_columns({'a': 'eh', 'd': 'sea'}) + def test_table_flatten(): ty1 = pa.struct([pa.field('x', pa.int16()),