diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py index 0b3d6517c70869..982dbf9cc71bdc 100755 --- a/examples/pytorch/text-classification/run_classification.py +++ b/examples/pytorch/text-classification/run_classification.py @@ -422,7 +422,7 @@ def main(): for split in raw_datasets.keys(): for column in data_args.remove_columns.split(","): logger.info(f"removing column {column} from split {split}") - raw_datasets[split].remove_columns(column) + raw_datasets[split] = raw_datasets[split].remove_columns(column) if data_args.label_column_name is not None and data_args.label_column_name != "label": for key in raw_datasets.keys():