Skip to content

Commit

Permalink
use rename_columns and processor_kwargs args
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
  • Loading branch information
kylesayrs committed Dec 20, 2024
1 parent 7c54bed commit f2b3c99
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
37 changes: 22 additions & 15 deletions src/llmcompressor/transformers/finetune/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:

if "input_ids" not in get_columns(dataset):
# tokenize/ process
dataset = self.filter_tokenizer_args(dataset)
logger.debug(f"Tokenizer args after filtering: {get_columns(dataset)}")
dataset = self.filter_processor_args(dataset)
logger.debug(f"Processor args after filtering: {get_columns(dataset)}")
dataset = self.map(
dataset,
self.tokenize,
Expand Down Expand Up @@ -215,26 +215,33 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]:
def rename_columns(self, dataset: DatasetType) -> DatasetType:
# rename columns to match processor/tokenizer kwargs
column_names = get_columns(dataset)
if self.data_args.text_column in column_names and "text" not in column_names:
logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
dataset = dataset.rename_column(self.data_args.text_column, "text")
for from_, to_ in self.data_args.rename_columns:
if from_ not in column_names:
raise ValueError(
f"Cannot rename {from_} to {to_} from columns {column_names}"
)
dataset = dataset.rename_column(from_, to_)
column_names.remove(from_)
column_names.append(to_)

return dataset

def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
# assumes that inputs are not passed via self.processor.__call__ args and kwargs
signature = inspect.signature(self.processor.__call__)
tokenizer_args = set(
key
for key, param in signature.parameters.items()
if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
)
def filter_processor_args(self, dataset: DatasetType) -> DatasetType:
processor_kwargs = self.data_args.processor_kwargs
if processor_kwargs is None:
# assumes that inputs are not passed via args and kwargs
signature = inspect.signature(self.processor.__call__)
processor_kwargs = set(
key
for key, param in signature.parameters.items()
if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
)
logger.debug(
f"Found processor args `{tokenizer_args}`. Removing all other columns"
f"Found processor args `{processor_kwargs}`. Removing all other columns"
)

column_names = get_columns(dataset)
return dataset.remove_columns(list(set(column_names) - set(tokenizer_args)))
return dataset.remove_columns(list(set(column_names) - set(processor_kwargs)))

def tokenize(self, data: LazyRow) -> Dict[str, Any]:
# separate prompt
Expand Down
25 changes: 23 additions & 2 deletions src/llmcompressor/transformers/finetune/data/data_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,35 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
metadata={
"help": (
"Optional key to be used as the `text` input to tokenizer/processor "
"after dataset preprocessing"
"after dataset preprocessing (deprecated, please use "
"`rename_columns` instead)"
)
},
)

remove_columns: Union[None, str, List] = field(
default=None,
metadata={"help": "Column names to remove after preprocessing (deprecated)"},
metadata={
"help": (
"Column names to remove after preprocessing (deprecated, please use "
"`rename_columns` instead)"
)
},
)

rename_columns: Optional[Dict[str, str]] = field(
default_factory=dict,
metadata={
"help": "Optional mapping to rename dataset columns after preprocessing"
},
)

tokenizer_kwargs: Optional[List[str]] = field(
default=None, metadata={"help": "Alias for `processor_kwargs`"}
)

processor_kwargs: Optional[List[str]] = field(
default=None, metadata={"help": "Optional list of processor argument names"}
)

preprocessing_func: Union[None, str, Callable] = field(
Expand Down
16 changes: 14 additions & 2 deletions src/llmcompressor/transformers/finetune/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,14 @@ def parse_args(**kwargs):
# raise deprecation warnings
if data_args.remove_columns is not None:
warnings.warn(
"`remove_columns` argument is deprecated. When tokenizing datasets, all "
"columns which are invalid inputs to the tokenizer will be removed",
"`remove_columns` is deprecated, please use `rename_columns` and "
"`processor_kwargs` instead",
DeprecationWarning,
)

if data_args.text_column is not None:
warnings.warn(
"`text_column` is deprecated, please use `rename_columns` instead",
DeprecationWarning,
)

Expand All @@ -153,6 +159,12 @@ def parse_args(**kwargs):
model_args.processor = model_args.tokenizer
model_args.tokenizer = None

if data_args.tokenizer_kwargs:
if data_args.processor_kwargs:
raise ValueError("Cannot use both `tokenizer_kwargs` and `processor_kwargs`")
data_args.processor_kwargs = data_args.tokenizer_kwargs
data_args.tokenizer_kwargs = None

return model_args, data_args, training_args


Expand Down

0 comments on commit f2b3c99

Please sign in to comment.