diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py
index 3db9be173..a9ecb67a7 100644
--- a/src/llmcompressor/pytorch/model_load/helpers.py
+++ b/src/llmcompressor/pytorch/model_load/helpers.py
@@ -9,6 +9,7 @@
 
 from llmcompressor.core import active_session, create_session, pre_initialize_structure
 from llmcompressor.pytorch.utils import ModuleSparsificationInfo
+from llmcompressor.typing import Processor
 
 COMPLETED_STAGES_FILENAME = "completed_stages.json"
 
@@ -92,15 +93,16 @@ def initialize_recipe(model: Module, recipe_path: str):
 def save_model_and_recipe(
     model: Module,
     save_path: str,
-    tokenizer: Optional[Any] = None,
+    processor: Optional[Processor] = None,
     save_safetensors: bool = False,
     save_compressed: bool = False,
 ):
     """
-    Save a model, tokenizer and the currently loaded recipe to file
+    Save a model, processor and the currently loaded recipe to file
+
     :param model: pytorch model to save
     :param save_path: path to save output to
-    :param tokenizer: model tokenizer to save
+    :param processor: model processor or tokenizer to save
     :param save_safetensors: whether to save as safetensors or pickle (bin)
     :param save_compressed: whether to compress sparse weights on disk
     """
@@ -111,8 +113,8 @@ def save_model_and_recipe(
         save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
     )
 
-    if tokenizer is not None:
-        tokenizer.save_pretrained(save_path)
+    if processor is not None:
+        processor.save_pretrained(save_path)
 
     logger.info("Saving output to {}".format(os.path.abspath(save_path)))
 
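Reviewer note (illustrative, not part of the patch): the helper body is unchanged apart from the rename because tokenizers and processors share the same `save_pretrained` interface. A minimal sketch, assuming a text-only checkpoint where `AutoProcessor` falls back to the tokenizer; the model id and output paths are placeholders.

```python
# Sketch only: both call paths write the same artifacts when the "processor"
# is really just a tokenizer, which is the common text-only case.
from transformers import AutoProcessor, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)  # resolves to the tokenizer here

tokenizer.save_pretrained("./out-tokenizer")  # old call path
processor.save_pretrained("./out-processor")  # new call path used by save_model_and_recipe
```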
diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py
index d4c3a6222..3b68e0fc1 100644
--- a/src/llmcompressor/transformers/finetune/data/base.py
+++ b/src/llmcompressor/transformers/finetune/data/base.py
@@ -3,7 +3,6 @@
 from compressed_tensors.registry import RegistryMixin
 from datasets import Dataset, IterableDataset
 from loguru import logger
-from transformers import AutoTokenizer
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
@@ -11,6 +10,7 @@
     get_custom_datasets_from_path,
     get_raw_dataset,
 )
+from llmcompressor.typing import Processor
 
 
 class TextGenerationDataset(RegistryMixin):
@@ -30,10 +30,10 @@ def __init__(
         text_column: str,
         data_args: DataTrainingArguments,
         split: str,
-        tokenizer: AutoTokenizer,
+        processor: Processor,
     ):
         self.text_column = text_column
-        self.tokenizer = tokenizer
+        self.processor = processor
         self.data_args = data_args
         self.raw_kwargs = data_args.raw_kwargs or {}
         self.split = split
@@ -50,20 +50,38 @@ def __init__(
         else:
             self.padding = False
 
-        if self.tokenizer:
+        # get the tokenizer; processors expose one via .tokenizer
+        self.tokenizer = getattr(self.processor, "tokenizer", self.processor)
+
+        if self.tokenizer is not None:
+            # fill in pad token
             if not self.tokenizer.pad_token:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
 
-        # configure sequence length
-        max_seq_length = data_args.max_seq_length
-        model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length
-        if self.tokenizer and max_seq_length > model_max_length:
-            logger.warning(
-                f"The max_seq_length passed ({max_seq_length}) is larger than "
-                f"the maximum length for the model ({tokenizer.model_max_length}). "
-                f"Using max_seq_length={tokenizer.model_max_length}."
+            # configure sequence length
+            max_seq_length = data_args.max_seq_length
+            if data_args.max_seq_length > self.tokenizer.model_max_length:
+                logger.warning(
+                    f"The max_seq_length passed ({max_seq_length}) is larger than "
+                    f"maximum length for model ({self.tokenizer.model_max_length}). "
+                    f"Using max_seq_length={self.tokenizer.model_max_length}."
+                )
+            self.max_seq_length = min(
+                data_args.max_seq_length, self.tokenizer.model_max_length
+            )
+
+            # configure padding
+            self.padding = (
+                False
+                if self.data_args.concatenate_data
+                else "max_length"
+                if self.data_args.pad_to_max_length
+                else False
             )
-        self.max_seq_length = min(data_args.max_seq_length, model_max_length)
+
+        else:
+            self.max_seq_length = None
+            self.padding = False
 
     def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset:
         """
diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py
index 37eeceae6..91cbc58e8 100644
--- a/src/llmcompressor/transformers/finetune/data/c4.py
+++ b/src/llmcompressor/transformers/finetune/data/c4.py
@@ -10,12 +10,12 @@ class C4Dataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "allenai/c4"
         super().__init__(
-            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
+            text_column="text", data_args=data_args, split=split, processor=processor
         )
diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
index 64755de4a..dcebe7573 100644
--- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
+++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -24,18 +24,18 @@ class CNNDailyMailDataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
     SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "cnn_dailymail"
         data_args.dataset_config_name = "3.0.0"
 
         super().__init__(
-            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
+            text_column="text", data_args=data_args, split=split, processor=processor
         )
 
     def get_raw_dataset(self, cache_dir: Optional[str] = None):
diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py
index e849594e7..817cb34de 100644
--- a/src/llmcompressor/transformers/finetune/data/custom.py
+++ b/src/llmcompressor/transformers/finetune/data/custom.py
@@ -32,17 +32,17 @@ class CustomDataset(TextGenerationDataset):
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
         Can also be set to None to load all the splits
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
 
     """
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         super().__init__(
             text_column=data_args.text_column,
             data_args=data_args,
             split=split,
-            tokenizer=tokenizer,
+            processor=processor,
         )
         self.preprocessing_func = data_args.preprocessing_func
         self.remove_columns = data_args.remove_columns
diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
index 9529d3115..66505f117 100644
--- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
+++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py
@@ -24,7 +24,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
     EVOL_ALPACA_TEMPLATE = (
@@ -34,11 +34,11 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
         "\n\n### Response:\n"
     )
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "theblackcat102/evol-codealpaca-v1"
         super().__init__(
-            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
+            text_column="text", data_args=data_args, split=split, processor=processor
         )
 
     def get_raw_dataset(self, cache_dir: Optional[str] = None):
diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py
index f9a94bcf4..299ae1bb2 100644
--- a/src/llmcompressor/transformers/finetune/data/gsm8k.py
+++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -11,16 +11,16 @@ class GSM8KDataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
     GSM_TEMPLATE = "Question: {question}\nAnswer:"
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "gsm8k"
         super().__init__(
-            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
+            text_column="text", data_args=data_args, split=split, processor=processor
         )
 
     def get_raw_dataset(self, cache_dir: Optional[str] = None):
diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py
index 55e54cbce..7a17c6fde 100644
--- a/src/llmcompressor/transformers/finetune/data/open_platypus.py
+++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -24,7 +24,7 @@ class OpenPlatypusDataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
     ALPACA_TEMPLATE = {
@@ -37,11 +37,11 @@ class OpenPlatypusDataset(TextGenerationDataset):
         "instruction}\n\n### Response:\n",
     }
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "garage-bAInd/Open-Platypus"
         super().__init__(
-            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
+            text_column="text", data_args=data_args, split=split, processor=processor
         )
 
     def get_raw_dataset(self, cache_dir: Optional[str] = None):
diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py
index 6f502edaf..8519f023c 100644
--- a/src/llmcompressor/transformers/finetune/data/ptb.py
+++ b/src/llmcompressor/transformers/finetune/data/ptb.py
@@ -10,15 +10,15 @@ class PtbDataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "ptb_text_only"
         super().__init__(
             text_column="sentence",
             data_args=data_args,
             split=split,
-            tokenizer=tokenizer,
+            processor=processor,
         )
diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
index 5b2e66ab5..30607847d 100644
--- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
+++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
@@ -24,7 +24,7 @@ class UltraChatDataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
     DEFAULT_CHAT_TEMPLATE = (
@@ -40,7 +40,7 @@ class UltraChatDataset(TextGenerationDataset):
         "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
     )
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "HuggingFaceH4/ultrachat_200k"
 
@@ -51,13 +51,15 @@ def __init__(self, data_args, split, tokenizer):
             text_column="messages",
             data_args=data_args,
             split=split,
-            tokenizer=tokenizer,
+            processor=processor,
         )
 
         if (
             not hasattr(self.tokenizer, "chat_template")
             or self.tokenizer.chat_template is None
         ):
+            # note that since tokenizer is a member of processor,
+            # this change affects processor.apply_chat_template
             self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE
 
     def get_raw_dataset(self, cache_dir: Optional[str] = None):
@@ -75,7 +77,7 @@ def restructure_fn(sample):
             if sample["messages"][0]["role"] != "system":
                 sample["messages"].insert(0, {"role": "system", "content": ""})
 
-            sample["messages"] = self.tokenizer.apply_chat_template(
+            sample["messages"] = self.processor.apply_chat_template(
                 sample["messages"], tokenize=False, add_generation_prompt=False
             )
             return sample
diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py
index 034d58ba2..25280589c 100644
--- a/src/llmcompressor/transformers/finetune/data/wikitext.py
+++ b/src/llmcompressor/transformers/finetune/data/wikitext.py
@@ -8,10 +8,10 @@ class WikiTextDataset(TextGenerationDataset):
 
     :param data_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
-    :param tokenizer: tokenizer to use on dataset
+    :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args, split, tokenizer):
+    def __init__(self, data_args, split, processor):
         super().__init__(
-            text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
+            text_column="text", data_args=data_args, split=split, processor=processor
         )
diff --git a/src/llmcompressor/transformers/finetune/model_args.py b/src/llmcompressor/transformers/finetune/model_args.py
index d3d8e974f..c81900ee2 100644
--- a/src/llmcompressor/transformers/finetune/model_args.py
+++ b/src/llmcompressor/transformers/finetune/model_args.py
@@ -34,6 +34,12 @@ class ModelArguments:
             "help": "Pretrained tokenizer name or path if not the same as model_name"
         },
     )
+    processor: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained processor name or path if not the same as model_name"
+        },
+    )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Where to store the pretrained data from huggingface.co"},
diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index 6344b1a2b..131180199 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -6,7 +6,6 @@
 import torch
 from loguru import logger
 from torch.utils.data import Dataset
-from transformers import AutoTokenizer
 
 from llmcompressor.core import active_session
 from llmcompressor.pytorch.model_load.helpers import (
@@ -24,6 +23,7 @@
 )
 from llmcompressor.transformers.finetune.model_args import ModelArguments
 from llmcompressor.transformers.finetune.training_args import TrainingArguments
+from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe
 
 
@@ -38,7 +38,7 @@ class StageRunner:
         - set_trainer()
         - train() / evaluate() / predict()
 
-    :param model_args: Arguments pertaining to model/config/tokenizer
+    :param model_args: Arguments pertaining to model/config/processor
     :param data_args: Arguments pertaining to what data to use for different flows
     :param training_args: Arguments pertaining to training loop configuration
     :model: unwrapped model to run flows on
@@ -56,11 +56,11 @@ def __init__(
 
         self.datasets = {}
         self.trainer = None
-        self.tokenizer = None
+        self.processor = None
         self.parent_output_dir = self._training_args.output_dir
         self._output_dir = self._training_args.output_dir
 
-    def populate_datasets(self, tokenizer: "AutoTokenizer", add_labels: bool = True):
+    def populate_datasets(self, processor: Processor, add_labels: bool = True):
         """
         Loads datasets for each flow based on data_args, stores a Dataset for each
         enabled flow in self.datasets
@@ -68,7 +68,7 @@ def populate_datasets(self, tokenizer: "AutoTokenizer", add_labels: bool = True)
-        :param tokenizer: tokenizer to use for dataset tokenization
+        :param processor: processor or tokenizer to use for dataset tokenization
         """
         if self._data_args.dataset is None:
-            self.tokenizer = self._model_args.tokenizer
+            self.processor = self._model_args.processor
             logger.info(
                 "Running oneshot without calibration data. This is expected for "
                 "weight-only and dynamic quantization"
@@ -102,7 +102,7 @@ def _get_split_name(inp_str):
                 registry_id,
                 data_args=self._data_args,
                 split=split_str,
-                tokenizer=tokenizer,
+                processor=processor,
             )
 
             dataset = self._data_args.dataset
@@ -124,7 +124,7 @@ def _get_split_name(inp_str):
             do_predict=self._training_args.do_predict,
             do_oneshot=self._training_args.do_oneshot,
         )
-        self.tokenizer = tokenizer
+        self.processor = processor
 
     def get_dataset_split(self, split_name: str) -> Dataset:
         """
@@ -266,7 +266,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
                 save_model_and_recipe(
                     model=self.trainer.model,
                     save_path=self._output_dir,
-                    tokenizer=self.tokenizer,
+                    processor=self.processor,
                     save_safetensors=self._training_args.save_safetensors,
                     save_compressed=self._training_args.save_compressed,
                 )
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index b1ac57b95..27860aeb4 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -487,8 +487,9 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
             )
 
         self.save_state()
-        if self.tokenizer is not None:
-            self.tokenizer.save_pretrained(output_dir)
+        processor = getattr(self, "processing_class", self.tokenizer)
+        if processor is not None:
+            processor.save_pretrained(output_dir)
 
         if not self.recipe:
             return
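Reviewer note (illustrative, not part of the patch): recent transformers releases renamed `Trainer.tokenizer` to `Trainer.processing_class`, and the `getattr` above tolerates both. A slightly more defensive sketch of the same lookup, with dummy trainers standing in for the real class:

```python
# Sketch: prefer the renamed attribute, fall back to the legacy one.
def get_processor_for_saving(trainer):
    return getattr(trainer, "processing_class", getattr(trainer, "tokenizer", None))


class _LegacyTrainer:  # exposes only the old attribute
    tokenizer = "tokenizer-object"


class _CurrentTrainer:  # exposes only the renamed attribute
    processing_class = "processor-object"


assert get_processor_for_saving(_LegacyTrainer()) == "tokenizer-object"
assert get_processor_for_saving(_CurrentTrainer()) == "processor-object"
```

This also lines up with `text_generation.py` below, which now passes `processing_class=processor` when constructing the trainer.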
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 85aa6d82c..f0e3a6b16 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -24,9 +24,10 @@
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoTokenizer,
+    AutoProcessor,
     DefaultDataCollator,
     HfArgumentParser,
+    PreTrainedModel,
     set_seed,
 )
 
@@ -49,9 +50,10 @@
     patch_tied_tensors_bug,
 )
 from llmcompressor.transformers.sparsification.sparse_model import (
-    get_shared_tokenizer_src,
+    get_shared_processor_src,
 )
 from llmcompressor.transformers.utils.helpers import detect_last_checkpoint
+from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model
 
 
@@ -134,6 +136,13 @@ def parse_args(**kwargs):
                 arg_dict[key] = value
             training_args.recipe_args = arg_dict
 
+    # backwards compatibility: route the legacy tokenizer argument to processor
+    if model_args.tokenizer:
+        if model_args.processor:
+            raise ValueError("Cannot use both a tokenizer and processor")
+        model_args.processor = model_args.tokenizer
+    model_args.tokenizer = None
+
     return model_args, data_args, training_args
 
 
@@ -226,11 +235,13 @@ def initialize_model_from_path(
     return teacher, model_path, model
 
 
-def initialize_tokenizer_from_path(model_args, model, teacher):
-    tokenizer_src = model_args.tokenizer
-    tokenizer_src = tokenizer_src or get_shared_tokenizer_src(model, teacher)
-    tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_src,
+def initialize_processor_from_path(
+    model_args: ModelArguments, model: PreTrainedModel, teacher: PreTrainedModel
+) -> Processor:
+    processor_src = model_args.processor
+    processor_src = processor_src or get_shared_processor_src(model, teacher)
+    processor = AutoProcessor.from_pretrained(
+        processor_src,
         cache_dir=model_args.cache_dir,
         use_fast=True,
         revision=model_args.model_revision,
@@ -238,7 +249,7 @@ def initialize_tokenizer_from_path(model_args, model, teacher):
         trust_remote_code=model_args.trust_remote_code_model,
     )
 
-    return tokenizer
+    return processor
 
 
 def main(
@@ -299,11 +310,9 @@ def main(
     # Detecting last checkpoint.
     last_checkpoint = None
     teacher = model_args.distill_teacher
-    model = model_args.model
-    # Load tokenizer
-    # distill TODO: support for different tokenizer for teacher?
-    tokenizer = model_args.tokenizer
+    # distill TODO: support for different processor for teacher?
 
+    model = model_args.model
     if isinstance(model, str) or isinstance(model, PosixPath):
         (teacher, _model_path, model) = initialize_model_from_path(
             model_args,
@@ -317,8 +326,9 @@ def main(
     if teacher is not None:
         teacher.eval()
 
-    if isinstance(tokenizer, str) or tokenizer is None:
-        tokenizer = initialize_tokenizer_from_path(model_args, model, teacher)
+    processor = model_args.processor
+    if isinstance(processor, str) or processor is None:
+        processor = initialize_processor_from_path(model_args, model, teacher)
 
     pre_initialize_structure(model=model)
 
@@ -330,7 +340,7 @@ def main(
         model_args=model_args, data_args=data_args, training_args=training_args
     )
     add_labels = training_args.do_train or training_args.run_stages
-    stage_runner.populate_datasets(tokenizer=tokenizer, add_labels=add_labels)
+    stage_runner.populate_datasets(processor=processor, add_labels=add_labels)
     train_dataset = stage_runner.get_dataset_split("train")
     eval_dataset = stage_runner.get_dataset_split("validation")
     calib_dataset = stage_runner.get_dataset_split("calibration")
@@ -346,13 +356,13 @@ def main(
         data_args=data_args,
         train_dataset=train_dataset or calib_dataset,
         eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
+        processing_class=processor,
         data_collator=data_collator,
     )
 
     # wrap model.save_pretrained
     if is_fsdp_model(model):
-        modify_fsdp_model_save_pretrained(trainer, tokenizer)
+        modify_fsdp_model_save_pretrained(trainer, processor)
     else:
         modify_save_pretrained(model)
 
@@ -396,8 +406,8 @@ def main(
         model.save_pretrained(
             training_args.output_dir, save_compressed=training_args.save_compressed
         )
-        if tokenizer is not None:
-            tokenizer.save_pretrained(training_args.output_dir)
+        if processor is not None:
+            processor.save_pretrained(training_args.output_dir)
 
     # Clean up the CompressionSession before exit if requested
     if training_args.clear_sparse_session:
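Reviewer note (illustrative, not part of the patch): the `parse_args()` remapping accepts the legacy tokenizer argument by routing it onto `processor`, and rejects callers that set both. A stand-in sketch of the three combinations, using a fake dataclass rather than the real `ModelArguments`:

```python
# Sketch of the tokenizer -> processor remapping behavior.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeModelArgs:  # stand-in for ModelArguments, illustration only
    tokenizer: Optional[str] = None
    processor: Optional[str] = None


def remap(args: FakeModelArgs) -> FakeModelArgs:
    if args.tokenizer:
        if args.processor:
            raise ValueError("Cannot use both a tokenizer and processor")
        args.processor = args.tokenizer  # legacy argument keeps working
    args.tokenizer = None
    return args


assert remap(FakeModelArgs(tokenizer="some/tokenizer")).processor == "some/tokenizer"
assert remap(FakeModelArgs(processor="some/processor")).processor == "some/processor"
# setting both raises ValueError, matching the guard added in parse_args()
```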
diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index 759098894..ce4ae7fb2 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -25,6 +25,7 @@
     SparsityConfigMetadata,
 )
 from llmcompressor.transformers.utils import RECIPE_FILE_NAME
+from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import (
     find_and_move_state_dicts_to_cpu,
     unwrap_and_export_model,
@@ -33,7 +34,7 @@
 __all__ = ["modify_save_pretrained", "modify_fsdp_model_save_pretrained"]
 
 
-def modify_fsdp_model_save_pretrained(trainer, tokenizer):
+def modify_fsdp_model_save_pretrained(trainer, processor: Processor):
     """
     Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
     supports compression for fsdp model
@@ -78,7 +79,7 @@ def save_pretrained_wrapper(
                     model=trainer.model,
                     accelerator=trainer.accelerator,
                     output_dir=save_directory,
-                    tokenizer=tokenizer,
+                    processor=processor,
                 )
                 # only allow the main process move the state
                 # dicts to cpu
diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index bf09396d7..d7abc323a 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -7,7 +7,7 @@
 
 __all__ = [
     "SparseAutoModelForCausalLM",
-    "get_shared_tokenizer_src",
+    "get_shared_processor_src",
 ]
 
 
@@ -20,14 +20,14 @@ def from_pretrained(*args, **kwargs):
         return AutoModelForCausalLM.from_pretrained(*args, **kwargs)
 
 
-def get_shared_tokenizer_src(student: Module, teacher: Optional[Module]) -> str:
+def get_shared_processor_src(student: Module, teacher: Optional[Module]) -> str:
     """
-    Get a tokenizer source used for both student and teacher, assuming
+    Get a processor/tokenizer source used for both student and teacher, assuming
     that they could be shared
 
     :param student: the student model
     :param teacher: the teacher model
-    :return: the source for the tokenizer shared between teacher and model
+    :return: the source for the processor/tokenizer shared between teacher and model
     """
 
     if teacher is not None and teacher not in ("disable", "self"):
diff --git a/src/llmcompressor/transformers/utils/preprocessing_functions.py b/src/llmcompressor/transformers/utils/preprocessing_functions.py
index cadec88f0..6bf6ade42 100644
--- a/src/llmcompressor/transformers/utils/preprocessing_functions.py
+++ b/src/llmcompressor/transformers/utils/preprocessing_functions.py
@@ -1,14 +1,17 @@
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 from compressed_tensors.registry import RegistryMixin
 
+if TYPE_CHECKING:
+    from llmcompressor.transformers.finetune.data.base import TextGenerationDataset
+
 
 class PreprocessingFunctionRegistry(RegistryMixin):
     pass
 
 
 @PreprocessingFunctionRegistry.register()
-def custom_evolved_codealpaca_dataset(data: Dict):
+def custom_evolved_codealpaca_dataset(self: "TextGenerationDataset", data: Dict):
     PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
     data["prompt"] = PROMPT_DICT.format_map(data)
     data["text"] = data["prompt"] + data["output"]
diff --git a/src/llmcompressor/typing.py b/src/llmcompressor/typing.py
new file mode 100644
index 000000000..1050f7138
--- /dev/null
+++ b/src/llmcompressor/typing.py
@@ -0,0 +1,17 @@
+from typing import Union
+
+from datasets import Dataset, DatasetDict, IterableDataset
+from transformers import (
+    BaseImageProcessor,
+    FeatureExtractionMixin,
+    PreTrainedTokenizer,
+    ProcessorMixin,
+)
+
+# Tokenizer or Processor. Processors do not inherit from a unified base class
+Processor = Union[
+    PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin
+]
+
+# Supported dataset types; an IterableDataset is streamed lazily
+DatasetType = Union[Dataset, DatasetDict, IterableDataset]
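Reviewer note (illustrative, not part of the patch): a short example of how the new alias might be used at a call site; the helper below is hypothetical, not an API added by this PR.

```python
# Hypothetical helper, illustration only: Processor covers tokenizers, image
# processors, feature extractors, and full ProcessorMixin objects.
from typing import Optional

from llmcompressor.typing import Processor


def resolve_pad_token(processor: Optional[Processor]) -> Optional[str]:
    tokenizer = getattr(processor, "tokenizer", processor)  # same trick as base.py
    if tokenizer is None or not hasattr(tokenizer, "pad_token"):
        return None
    return tokenizer.pad_token or getattr(tokenizer, "eos_token", None)
```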
diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py
index 8cc0f5405..3a3248fa5 100644
--- a/src/llmcompressor/utils/fsdp/helpers.py
+++ b/src/llmcompressor/utils/fsdp/helpers.py
@@ -18,6 +18,7 @@
 
 from llmcompressor.core.state import State
 from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe
+from llmcompressor.typing import Processor
 from llmcompressor.utils.pytorch import set_layer
 
 __all__ = [
@@ -71,7 +72,7 @@ def set_wrapped_model(state: State, wrapped_model: Module):
         state.model = wrapped_model
 
 
-def unwrap_and_export_model(model, accelerator, output_dir, tokenizer):
+def unwrap_and_export_model(model, accelerator, output_dir: str, processor: Processor):
     """
     Recursively unwraps an FSDP model, then saves the unwrapped model and the
     currently active recipe to disk
@@ -79,7 +80,7 @@ def unwrap_and_export_model(model, accelerator, output_dir, tokenizer):
     :param model: model to unwrap
     :param accelerator: Accelerator instance used to perform unwrapping
     :param output_dir: where to save output model
-    :param tokenizer: tokenizer used by the model
+    :param processor: processor used by the model
     """
     full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
     with FullyShardedDataParallel.state_dict_type(
@@ -95,7 +96,7 @@ def unwrap_and_export_model(model, accelerator, output_dir, tokenizer):
         save_model_and_recipe(
             model=unwrapped_model,
             save_path=output_dir,
-            tokenizer=tokenizer,
+            processor=processor,
         )
 
 
diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py
index c0f0d2c02..9b82e5d50 100644
--- a/tests/llmcompressor/transformers/compression/test_quantization.py
+++ b/tests/llmcompressor/transformers/compression/test_quantization.py
@@ -132,7 +132,7 @@ def _get_dataloader(self, data_args, tokenizer):
             data_args.dataset,
             data_args=data_args,
             split="train_gen[:5%]",
-            tokenizer=tokenizer,
+            processor=tokenizer,
         )
         calib_dataset = dataset_manager.tokenize_and_process(
             dataset_manager.get_raw_dataset()
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py
index a602c4828..7d6fa38da 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py
@@ -28,7 +28,7 @@ def test_concatenation_tokenization(self):
             self.data_args.dataset,
             data_args=self.data_args,
             split="train[:5%]",
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
         raw_dataset = wiki_manager.get_raw_dataset()
         self.assertGreater(len(raw_dataset), 0)
@@ -60,7 +60,7 @@ def test_no_padding_tokenization(self):
             self.data_args.dataset,
             data_args=self.data_args,
             split="train[5%:10%]",
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
         raw_dataset = op_manager.get_raw_dataset()
         self.assertGreater(len(raw_dataset), 0)
@@ -95,7 +95,7 @@ def test_max_seq_len_clipped(self):
             self.data_args.dataset,
             data_args=self.data_args,
             split="train[80%:]",
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
 
         self.assertEqual(
@@ -124,7 +124,7 @@ def test_dataset_kwargs_and_percentages(self):
             self.data_args.dataset,
             data_args=self.data_args,
             split="train[5%:10%]",
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
         raw_dataset_a = c4_manager_a.get_raw_dataset()
 
@@ -132,7 +132,7 @@ def test_dataset_kwargs_and_percentages(self):
             self.data_args.dataset,
             data_args=self.data_args,
             split="train[5%:15%]",
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
         raw_dataset_b = c4_manager_b.get_raw_dataset()
 
@@ -163,7 +163,7 @@ def test_datasets(self, dataset_key, dataset_config, split, do_concat):
             data_args.dataset,
             data_args=data_args,
             split=split,
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
         raw_dataset = manager.get_raw_dataset()
         self.assertGreater(len(raw_dataset), 0)
@@ -203,7 +203,7 @@ def test_evol(self):
             self.data_args.dataset,
             data_args=self.data_args,
             split="train[:2%]",
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
         raw_dataset = evol_manager.get_raw_dataset()
         self.assertGreater(len(raw_dataset), 0)
@@ -237,7 +237,7 @@ def test_stream_loading(self):
             self.data_args.dataset,
             data_args=self.data_args,
             split="train",
-            tokenizer=self.tiny_llama_tokenizer,
+            processor=self.tiny_llama_tokenizer,
         )
 
         raw_dataset = manager.get_raw_dataset()
@@ -275,7 +275,7 @@ def test_split_loading(self, split_def):
         stage_runner = StageRunner(
             model_args=model_args, data_args=data_args, training_args=training_args
         )
-        stage_runner.populate_datasets(tokenizer=self.tiny_llama_tokenizer)
+        stage_runner.populate_datasets(processor=self.tiny_llama_tokenizer)
 
         train_dataset = stage_runner.get_dataset_split("train")
         assert train_dataset is not None
@@ -318,7 +318,7 @@ def preprocess(sample):
             ),
             training_args=TrainingArguments(do_oneshot=True),
         )
-        stage_runner.populate_datasets(tokenizer=None)
+        stage_runner.populate_datasets(processor=None)
         calib_dataset = stage_runner.get_dataset_split("calibration")
         self.assertEqual(len(calib_dataset), self.num_calib_samples)
         data_cols = calib_dataset.column_names
diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py
index e4c804c07..3350d0a79 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_registry.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py
@@ -16,7 +16,7 @@ def test_c4_initializes(tiny_llama_tokenizer):
         data_args.dataset,
         data_args=data_args,
         split=None,
-        tokenizer=tiny_llama_tokenizer,
+        processor=tiny_llama_tokenizer,
     )
     assert isinstance(c4_manager, TextGenerationDataset)
     assert isinstance(c4_manager, C4Dataset)
@@ -34,7 +34,7 @@ def test_wikitext_initializes(tiny_llama_tokenizer):
         data_args.dataset,
         data_args=data_args,
         split=None,
-        tokenizer=tiny_llama_tokenizer,
+        processor=tiny_llama_tokenizer,
     )
     assert isinstance(wiki_manager, TextGenerationDataset)
     assert isinstance(wiki_manager, WikiTextDataset)
@@ -50,7 +50,7 @@ def test_open_platypus_initializes(tiny_llama_tokenizer):
         data_args.dataset,
         data_args=data_args,
         split=None,
-        tokenizer=tiny_llama_tokenizer,
+        processor=tiny_llama_tokenizer,
     )
     assert isinstance(op_manager, TextGenerationDataset)
     assert isinstance(op_manager, OpenPlatypusDataset)
diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py
index cb7f64943..f49a02bd1 100644
--- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py
+++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py
@@ -37,7 +37,7 @@ def labeled_dataloader(self, dataset_name, model_name):
             data_args.dataset,
             data_args=data_args,
             split="train",
-            tokenizer=tokenizer,
+            processor=tokenizer,
         )
         calib_dataset = dataset_manager.tokenize_and_process(
             dataset_manager.get_raw_dataset()
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 07b166013..a6103a73c 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -9,7 +9,7 @@
 
 import yaml
 from datasets import Dataset
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizer
 
 from tests.data import CustomTestConfig, TestConfig
 
@@ -126,7 +126,7 @@ def run_cli_command(cmd: List[str], cwd: Optional[Union[str, Path]] = None):
 
 
 def preprocess_tokenize_dataset(
-    ds: Dataset, tokenizer: AutoTokenizer, max_seq_length: int
+    ds: Dataset, tokenizer: PreTrainedTokenizer, max_seq_length: int
 ) -> Dataset:
     """
     Helper function to preprocess and tokenize a dataset according to presets