diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py
index 08ea8351..9736ca0e 100644
--- a/src/llmcompressor/transformers/finetune/data/base.py
+++ b/src/llmcompressor/transformers/finetune/data/base.py
@@ -234,7 +234,9 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
         )
 
         column_names = get_columns(dataset)
-        return dataset.remove_columns(list(set(column_names) - set(tokenizer_args)))
+        return dataset.remove_columns(
+            list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY]))
+        )
 
     def tokenize(self, data: LazyRow) -> Dict[str, Any]:
         # separate prompt
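
For context, a minimal sketch of the set arithmetic this hunk changes; the column names and the PROMPT_KEY value below are hypothetical stand-ins, and only the filtering logic mirrors the diff:

    # Hypothetical values; only the set arithmetic mirrors the diff above.
    tokenizer_args = {"input_ids", "attention_mask", "text"}
    column_names = ["text", "label", "prompt"]
    PROMPT_KEY = "prompt"  # assumed value of self.PROMPT_KEY

    # Before the fix: the prompt column falls into the removal set along
    # with every other column the tokenizer does not consume.
    removed_before = list(set(column_names) - set(tokenizer_args))
    assert "prompt" in removed_before  # prompt would be dropped

    # After the fix: PROMPT_KEY is subtracted from the removal set,
    # so the prompt column survives remove_columns().
    removed_after = list(set(column_names) - set(tokenizer_args) - {PROMPT_KEY})
    assert "prompt" not in removed_after  # prompt is kept

In other words, the change narrows what gets passed to remove_columns so that the prompt column is no longer stripped alongside the columns the tokenizer cannot use.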