diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py
index 08ea83512..dff5eba8e 100644
--- a/src/llmcompressor/transformers/finetune/data/base.py
+++ b/src/llmcompressor/transformers/finetune/data/base.py
@@ -107,8 +107,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:

         if "input_ids" not in get_columns(dataset):
             # tokenize/ process
-            dataset = self.filter_tokenizer_args(dataset)
-            logger.debug(f"Tokenizer args after filtering: {get_columns(dataset)}")
+            dataset = self.filter_processor_args(dataset)
+            logger.debug(f"Processor args after filtering: {get_columns(dataset)}")
             dataset = self.map(
                 dataset,
                 self.tokenize,
@@ -215,26 +215,33 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]:
     def rename_columns(self, dataset: DatasetType) -> DatasetType:
         # rename columns to match processor/tokenizer kwargs
         column_names = get_columns(dataset)
-        if self.data_args.text_column in column_names and "text" not in column_names:
-            logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
-            dataset = dataset.rename_column(self.data_args.text_column, "text")
+        for from_, to_ in self.data_args.rename_columns.items():
+            if from_ not in column_names:
+                raise ValueError(
+                    f"Cannot rename {from_} to {to_} from columns {column_names}"
+                )
+            dataset = dataset.rename_column(from_, to_)
+            column_names.remove(from_)
+            column_names.append(to_)

         return dataset

-    def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
-        # assumes that inputs are not passed via self.processor.__call__ args and kwargs
-        signature = inspect.signature(self.processor.__call__)
-        tokenizer_args = set(
-            key
-            for key, param in signature.parameters.items()
-            if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
-        )
+    def filter_processor_args(self, dataset: DatasetType) -> DatasetType:
+        processor_kwargs = self.data_args.processor_kwargs
+        if processor_kwargs is None:
+            # assumes that inputs are not passed via args and kwargs
+            signature = inspect.signature(self.processor.__call__)
+            processor_kwargs = set(
+                key
+                for key, param in signature.parameters.items()
+                if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
+            )

         logger.debug(
-            f"Found processor args `{tokenizer_args}`. Removing all other columns"
+            f"Found processor args `{processor_kwargs}`. Removing all other columns"
        )
         column_names = get_columns(dataset)
-        return dataset.remove_columns(list(set(column_names) - set(tokenizer_args)))
+        return dataset.remove_columns(list(set(column_names) - set(processor_kwargs)))

     def tokenize(self, data: LazyRow) -> Dict[str, Any]:
         # separate prompt
diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py
index 7d0bc14ce..9b410b512 100644
--- a/src/llmcompressor/transformers/finetune/data/data_args.py
+++ b/src/llmcompressor/transformers/finetune/data/data_args.py
@@ -38,14 +38,35 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
         metadata={
             "help": (
                 "Optional key to be used as the `text` input to tokenizer/processor "
-                "after dataset preprocesssing"
+                "after dataset preprocessing (deprecated, please use "
+                "`rename_columns` instead)"
             )
         },
     )

     remove_columns: Union[None, str, List] = field(
         default=None,
-        metadata={"help": "Column names to remove after preprocessing (deprecated)"},
+        metadata={
+            "help": (
+                "Column names to remove after preprocessing (deprecated, please use "
+                "`rename_columns` instead)"
+            )
+        },
+    )
+
+    rename_columns: Optional[Dict[str, str]] = field(
+        default_factory=dict,
+        metadata={
+            "help": "Optional mapping to rename dataset columns after preprocessing"
+        },
+    )
+
+    tokenizer_kwargs: Optional[List[str]] = field(
+        default=None, metadata={"help": "Alias for `processor_kwargs`"}
+    )
+
+    processor_kwargs: Optional[List[str]] = field(
+        default=None, metadata={"help": "Optional list of processor argument names"}
     )

     preprocessing_func: Union[None, str, Callable] = field(
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 5a06b302f..b808eeffd 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -141,8 +141,14 @@ def parse_args(**kwargs):

     # raise depreciation warnings
     if data_args.remove_columns is not None:
         warnings.warn(
-            "`remove_columns` argument is depreciated. When tokenizing datasets, all "
-            "columns which are invalid inputs the tokenizer will be removed",
+            "`remove_columns` is deprecated, please use `rename_columns` and "
+            "`processor_kwargs` instead",
+            DeprecationWarning,
+        )
+
+    if data_args.text_column is not None:
+        warnings.warn(
+            "`text_column` is deprecated, please use `rename_columns` instead",
             DeprecationWarning,
         )
@@ -153,6 +159,12 @@
         model_args.processor = model_args.tokenizer
         model_args.tokenizer = None

+    if data_args.tokenizer_kwargs:
+        if data_args.processor_kwargs:
+            raise ValueError("Cannot use both `tokenizer_kwargs` and `processor_kwargs`")
+        data_args.processor_kwargs = data_args.tokenizer_kwargs
+        data_args.tokenizer_kwargs = None
+
     return model_args, data_args, training_args
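
For context on the new arguments: the `rename_columns` mapping and the explicit `processor_kwargs` list together replace the deprecated `text_column`/`remove_columns` flow. Below is a minimal standalone sketch of the resulting column handling, mirroring the logic added to `base.py` above; the `prepare_columns` helper is hypothetical (not part of this PR) and assumes only the Hugging Face `datasets` API.

```python
from datasets import Dataset


def prepare_columns(dataset, rename_columns, processor_kwargs):
    # Hypothetical helper mirroring the PR's rename + filter steps.
    # Rename columns per the user-supplied mapping, failing loudly on typos
    for from_, to_ in rename_columns.items():
        if from_ not in dataset.column_names:
            raise ValueError(
                f"Cannot rename {from_} to {to_} from columns {dataset.column_names}"
            )
        dataset = dataset.rename_column(from_, to_)

    # Keep only the columns the processor accepts; drop everything else
    extras = set(dataset.column_names) - set(processor_kwargs)
    return dataset.remove_columns(list(extras))


ds = Dataset.from_dict({"question": ["What is 2 + 2?"], "meta": ["unused"]})
ds = prepare_columns(ds, rename_columns={"question": "text"}, processor_kwargs=["text"])
print(ds.column_names)  # ['text']
```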
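When `processor_kwargs` is not supplied, the new `filter_processor_args` falls back to inspecting the signature of `processor.__call__`, keeping every named parameter and excluding `*args`/`**kwargs`. A small illustration of that fallback, with a stand-in function in place of a real processor so it runs on its own:

```python
import inspect

Kind = inspect.Parameter  # same alias as used in base.py


def fake_processor(text=None, images=None, *args, **kwargs):
    """Hypothetical stand-in for processor.__call__ (for illustration only)."""
    return {"input_ids": [0]}


signature = inspect.signature(fake_processor)
# Collect named parameters; VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
# kinds are excluded, matching the filtering done in filter_processor_args
processor_kwargs = {
    key
    for key, param in signature.parameters.items()
    if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
}
print(processor_kwargs)  # {'text', 'images'} (set order may vary)
```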