diff --git a/recipes/configs/llama3_2_vision/11B_lora.yaml b/recipes/configs/llama3_2_vision/11B_lora.yaml
index 357af64496..1e5c0323ac 100644
--- a/recipes/configs/llama3_2_vision/11B_lora.yaml
+++ b/recipes/configs/llama3_2_vision/11B_lora.yaml
@@ -81,7 +81,7 @@ enable_activation_offloading: False
 dtype: bf16
 
 # Logging
-output_dir: /tmp/full-llama3.2-vision-finetune
+output_dir: /tmp/lora-llama3.2-vision-finetune
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
index f56828c301..88e51aa355 100644
--- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
@@ -80,7 +80,7 @@ enable_activation_offloading: False
 dtype: bf16
 
 # Logging
-output_dir: /tmp/full-llama3.2-vision-finetune
+output_dir: /tmp/lora-llama3.2-vision-finetune
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml
new file mode 100644
index 0000000000..1217fb367a
--- /dev/null
+++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml
@@ -0,0 +1,88 @@
+# Config for multi-device QLoRA finetuning in lora_finetune_distributed.py
+# using a Llama3.2 11B Vision Instruct model
+#
+# This config assumes that you've run the following command before launching:
+#   tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
+#
+# To launch on 2 devices, run the following command from root:
+#   tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training:
+#   tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works best when the model is being fine-tuned on 2+ GPUs.
+# For single device QLoRA finetuning please use 11B_qlora_single_device.yaml
+
+# Model arguments
+model:
+  _component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
+  decoder_trainable: "frozen"
+  encoder_trainable: "lora"
+  fusion_trainable: "lora"
+  lora_attn_modules: ['q_proj', 'v_proj']
+  apply_lora_to_mlp: False
+  apply_lora_to_output: False
+  lora_rank: 8
+  lora_alpha: 16
+  lora_dropout: 0.0
+  image_size: 560 # Make sure this matches the image_size in tokenizer
+
+# Transform
+tokenizer:
+  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
+  path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
+  image_size: 560
+  max_seq_len: 8192
+
+# Checkpointer
+checkpointer:
+  _component_: torchtune.training.FullModelMetaCheckpointer
+  checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/
+  checkpoint_files: [consolidated.pth]
+  recipe_checkpoint: null
+  output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+  model_type: LLAMA3_VISION
+resume_from_checkpoint: False
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.multimodal.the_cauldron_dataset
+  subset: ocrvqa
+seed: null
+shuffle: True
+collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
+
+# Fine-tuning arguments
+epochs: 1
+max_steps_per_epoch: null
+batch_size: 2
+gradient_accumulation_steps: 4
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: True
+  weight_decay: 0.01
+  lr: 2e-5
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
+loss:
+  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+clip_grad_norm: 1.0
+compile: False # set it to True for better memory and performance
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True
+enable_activation_offloading: False
+dtype: bf16
+
+# Logging
+output_dir: /tmp/qlora-llama3.2-vision-finetune
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+log_every_n_steps: 1
+log_peak_memory_stats: False
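
For readers mapping the YAML above onto code: torchtune instantiates the model section by calling the listed _component_ with the remaining keys as keyword arguments. A minimal sketch of the equivalent builder call, assuming the builder accepts exactly the keywords shown in the config:

    from torchtune.models.llama3_2_vision import qlora_llama3_2_vision_11b

    # Mirrors the model section of 11B_qlora.yaml (illustrative sketch only).
    model = qlora_llama3_2_vision_11b(
        decoder_trainable="frozen",              # language decoder stays frozen
        encoder_trainable="lora",                # LoRA adapters on the vision encoder
        fusion_trainable="lora",                 # LoRA adapters on the fusion layers
        lora_attn_modules=["q_proj", "v_proj"],
        apply_lora_to_mlp=False,
        apply_lora_to_output=False,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.0,
        image_size=560,                          # must match tokenizer.image_size
    )
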
diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
new file mode 100644
index 0000000000..b12d51237c
--- /dev/null
+++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
@@ -0,0 +1,113 @@
+# Config for single device QLoRA finetuning in lora_finetune_single_device.py
+# using a Llama3.2 11B Vision Instruct model
+#
+# This config assumes that you've run the following command before launching:
+#   tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
+#
+# To launch on a single device, run the following command from root:
+#   tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training:
+#   tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works only for training on single device.
+
+# Model arguments
+model:
+  _component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
+  decoder_trainable: "frozen"
+  encoder_trainable: "lora"
+  fusion_trainable: "lora"
+  lora_attn_modules: ['q_proj', 'v_proj']
+  apply_lora_to_mlp: False
+  apply_lora_to_output: False
+  lora_rank: 8
+  lora_alpha: 16
+  lora_dropout: 0.0
+  image_size: 560 # Make sure this matches the image_size in tokenizer
+
+# Transform
+tokenizer:
+  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
+  path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
+  image_size: 560
+  max_seq_len: 8192
+
+# Checkpointer
+checkpointer:
+  _component_: torchtune.training.FullModelMetaCheckpointer
+  checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/
+  checkpoint_files: [consolidated.pth]
+  recipe_checkpoint: null
+  output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+  model_type: LLAMA3_VISION
+resume_from_checkpoint: False
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.multimodal.the_cauldron_dataset
+  subset: ocrvqa
+seed: null
+shuffle: True
+collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
+
+# Fine-tuning arguments
+epochs: 1
+max_steps_per_epoch: null
+batch_size: 2
+gradient_accumulation_steps: 16
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: True
+  weight_decay: 0.01
+  lr: 2e-5
+optimizer_in_bwd: False
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
+loss:
+  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+clip_grad_norm: 1.0
+compile: False # set it to True for better memory and performance
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True
+enable_activation_offloading: False
+dtype: bf16
+
+# Logging
+output_dir: /tmp/qlora-llama3.2-vision-finetune
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+log_every_n_steps: 1
+log_peak_memory_stats: False
+
+# Profiler (disabled)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False
+
+  # Output directory of trace artifacts
+  output_dir: ${output_dir}/profiling_outputs
+
+  # `torch.profiler.ProfilerActivity` types to trace
+  cpu: True
+  cuda: True
+
+  # trace options passed to `torch.profiler.profile`
+  profile_memory: True
+  with_stack: False
+  record_shapes: True
+  with_flops: False
+
+  # `torch.profiler.schedule` options:
+  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+  wait_steps: 1
+  warmup_steps: 2
+  active_steps: 1
+  num_cycles: 1
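
The tokenizer section resolves the same way as the model section; a sketch of the transform call it corresponds to, assuming the download path from the header comment:

    from torchtune.models.llama3_2_vision import llama3_2_vision_transform

    # Mirrors the tokenizer section above (illustrative sketch only).
    transform = llama3_2_vision_transform(
        path="/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model",
        image_size=560,     # keep in sync with model.image_size
        max_seq_len=8192,
    )
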
diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py
index e29a87ba57..fcdb81c260 100644
--- a/tests/torchtune/modules/low_precision/test_nf4_linear.py
+++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py
@@ -40,10 +40,6 @@ class TestNF4Linear:
     Class for testing our NF4Linear implementation.
     """
 
-    def test_bias_unsupported(self):
-        with pytest.raises(RuntimeError, match="does not currently support biases"):
-            _ = FrozenNF4Linear(1, 1, bias=True)
-
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     def test_parameters(self, dtype):
         nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
@@ -59,9 +55,10 @@ def test_state_dict(self, dtype):
         assert isinstance(state_dict["weight"], NF4Tensor)
 
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    def test_output_dtype(self, dtype):
+    @pytest.mark.parametrize("bias", [True, False])
+    def test_output_dtype(self, dtype, bias):
         # Test to ensure W4 A16 produces A16 / W4A32 produces A32
-        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype, bias=bias)
         inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
         out = nf4_linear(inp)
         assert out.dtype == dtype
diff --git a/tests/torchtune/modules/peft/test_dora.py b/tests/torchtune/modules/peft/test_dora.py
index 02954fb685..1b48608852 100644
--- a/tests/torchtune/modules/peft/test_dora.py
+++ b/tests/torchtune/modules/peft/test_dora.py
@@ -49,80 +49,77 @@ def inputs(self, in_dim) -> torch.Tensor:
         return inputs
 
     @pytest.fixture
-    def dora_linear(self, in_dim, out_dim) -> DoRALinear:
-        dora_linear = DoRALinear(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            rank=RANK,
-            alpha=ALPHA,
-            use_bias=False,
-        )
+    def dora_linear(self, in_dim, out_dim):
+        def create_dora_linear(use_bias, dtype, in_dim=in_dim, out_dim=out_dim):
+            with training.set_default_dtype(dtype):
+                dora_linear = DoRALinear(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    rank=RANK,
+                    alpha=ALPHA,
+                    use_bias=use_bias,
+                )
 
-        fixed_init_model(dora_linear)
-        return dora_linear
+                fixed_init_model(dora_linear)
+            return dora_linear
+
+        return create_dora_linear
 
     @pytest.fixture
-    def qdora_linear(self, in_dim, out_dim) -> DoRALinear:
-        with training.set_default_dtype(torch.bfloat16):
-            qdora_linear = DoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=False,
-                quantize_base=True,
-            )
-            fixed_init_model(qdora_linear, dtype=torch.bfloat16)
+    def qdora_linear(self):
+        def create_qdora_linear(
+            use_bias=False, dtype=torch.bfloat16, in_dim=512, out_dim=512
+        ):
+            with training.set_default_dtype(dtype):
+                qdora_linear = DoRALinear(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    rank=RANK,
+                    alpha=ALPHA,
+                    use_bias=use_bias,
+                    quantize_base=True,
+                )
+                fixed_init_model(qdora_linear)
             return qdora_linear
 
+        return create_qdora_linear
+
     def test_forward(self, inputs, dora_linear, out_dim) -> None:
+        dora_linear = dora_linear(use_bias=False, dtype=torch.float32)
         expected = torch.tensor(EXPECTED_VAL)
         actual = dora_linear(inputs)
         assert actual.shape == (BSZ, SEQ_LEN, out_dim)
         torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)
 
-    def test_dora_weight_nf4_when_quantized(self, qdora_linear):
+    @pytest.mark.parametrize("use_bias", [True, False])
+    def test_dora_weight_nf4_when_quantized(self, use_bias, qdora_linear):
+        qdora_linear = qdora_linear(use_bias=use_bias, dtype=torch.bfloat16)
         assert isinstance(qdora_linear.weight, NF4Tensor)
-
-    def test_bias_raises(self):
-        with pytest.raises(
-            NotImplementedError, match="DoRALinear does not support using bias"
-        ):
-            DoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=True,
-                quantize_base=False,
-            )
-
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    def test_qdora_parity(self, dtype):
+        if use_bias:
+            assert not isinstance(qdora_linear.bias, NF4Tensor)
+            assert qdora_linear.bias.dtype == torch.bfloat16
+
+    # Note: with bfloat16, F.linear(x, weight, bias) != F.linear(x, weight) + bias,
+    # so the quantized and unquantized paths would produce different results even
+    # without QDoRA involved. We therefore leave out the (use_bias=True, bfloat16) case.
+    @pytest.mark.parametrize(
+        "use_bias, dtype",
+        [(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)],
+    )
+    def test_qdora_parity(self, use_bias, dtype, dora_linear, qdora_linear):
         with training.set_default_dtype(dtype):
-            torch.manual_seed(0)
-            qdora_linear = DoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=False,
-                quantize_base=True,
+            qdora_linear = qdora_linear(
+                use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
             )
-            torch.manual_seed(0)
-            dora_linear = DoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=False,
-                quantize_base=False,
+            dora_linear = dora_linear(
+                use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
             )
 
         # set weight of dora_linear to unquantized weight of qdora_linear and check
         # parity.
         dora_linear.weight.data = qdora_linear.weight.to(dtype)
-
+        if use_bias:
+            dora_linear.bias.data = qdora_linear.bias.detach().clone()
         qdora_linear.initialize_dora_magnitude()
         dora_linear.initialize_dora_magnitude()
 
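
The bfloat16 caveat in the parametrization comment above is easy to reproduce in isolation: the fused and unfused bias additions round at different points, so they are not guaranteed to match bitwise. A standalone sketch (the exact outcome can vary by backend):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(2, 512, dtype=torch.bfloat16)
    w = torch.randn(512, 512, dtype=torch.bfloat16)
    b = torch.randn(512, dtype=torch.bfloat16)

    fused = F.linear(x, w, b)     # bias folded into the fused kernel
    unfused = F.linear(x, w) + b  # matmul rounded to bf16 first, then bias added
    # The two results may differ in the last bits, which is why the
    # (use_bias=True, bfloat16) parity case is skipped in the tests above.
    print(torch.equal(fused, unfused))
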
diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py
index caca54b86b..fc76adea30 100644
--- a/tests/torchtune/modules/peft/test_lora.py
+++ b/tests/torchtune/modules/peft/test_lora.py
@@ -50,30 +50,37 @@ def inputs(self, in_dim) -> torch.Tensor:
 
     @pytest.fixture
     def lora_linear(self, in_dim, out_dim) -> LoRALinear:
-        lora_linear = LoRALinear(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            rank=RANK,
-            alpha=ALPHA,
-            use_bias=True,
-        )
-        fixed_init_model(lora_linear)
-        return lora_linear
+        def create_lora_linear(use_bias, dtype, in_dim=in_dim, out_dim=out_dim):
+            with training.set_default_dtype(dtype):
+                lora_linear = LoRALinear(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    rank=RANK,
+                    alpha=ALPHA,
+                    use_bias=use_bias,
+                )
+                fixed_init_model(lora_linear)
+            return lora_linear
+
+        return create_lora_linear
 
     @pytest.fixture
-    def qlora_linear(self, in_dim, out_dim) -> LoRALinear:
-        with training.set_default_dtype(torch.bfloat16):
-            qlora_linear = LoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=False,
-                quantize_base=True,
-            )
-            fixed_init_model(qlora_linear, dtype=torch.bfloat16)
+    def qlora_linear(self):
+        def create_qlora_linear(use_bias, dtype, in_dim=512, out_dim=512):
+            with training.set_default_dtype(dtype):
+                qlora_linear = LoRALinear(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    rank=RANK,
+                    alpha=ALPHA,
+                    use_bias=use_bias,
+                    quantize_base=True,
+                )
+                fixed_init_model(qlora_linear)
             return qlora_linear
 
+        return create_qlora_linear
+
     @torch.no_grad()
     def set_dummy_weights_for_merge(self, lora_module):
         lora_module.lora_a.weight = nn.Parameter(
@@ -92,55 +99,47 @@ def set_dummy_weights_for_merge(self, lora_module):
         lora_module.lora_b.weight[32, 1] = 12
 
     def test_forward(self, inputs, lora_linear, out_dim) -> None:
+        lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
         expected = torch.tensor(EXPECTED_VAL)
         actual = lora_linear(inputs)
         assert actual.shape == (BSZ, SEQ_LEN, out_dim)
         torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)
 
-    def test_lora_weight_nf4_when_quantized(self, qlora_linear):
+    @pytest.mark.parametrize("use_bias", [True, False])
+    def test_lora_weight_nf4_when_quantized(self, use_bias, qlora_linear):
+        qlora_linear = qlora_linear(use_bias=use_bias, dtype=torch.bfloat16)
         assert isinstance(qlora_linear.weight, NF4Tensor)
-
-    def test_quantize_with_bias_raises(self):
-        with pytest.raises(NotImplementedError, match="does not support bias"):
-            LoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=True,
-                quantize_base=True,
-            )
-
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    def test_qlora_parity(self, dtype):
-        with training.set_default_dtype(dtype):
-            qlora_linear = LoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=False,
-                quantize_base=True,
-            )
-            lora_linear = LoRALinear(
-                in_dim=512,
-                out_dim=512,
-                rank=RANK,
-                alpha=ALPHA,
-                use_bias=False,
-                quantize_base=False,
-            )
+        if use_bias:
+            assert not isinstance(qlora_linear.bias, NF4Tensor)
+            assert qlora_linear.bias.dtype == torch.bfloat16
+
+    # Note: with bfloat16, F.linear(x, weight, bias) != F.linear(x, weight) + bias,
+    # so the quantized and unquantized paths would produce different results even
+    # without QLoRA involved. We therefore leave out the (use_bias=True, bfloat16) case.
+    @pytest.mark.parametrize(
+        "use_bias, dtype",
+        [(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)],
+    )
+    def test_qlora_parity(self, use_bias, dtype, qlora_linear, lora_linear):
+        qlora_linear = qlora_linear(
+            use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
+        )
+        lora_linear = lora_linear(
+            use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
+        )
 
         # set weight of lora_linear to unquantized weight of qlora_linear and check
         # parity.
         lora_linear.weight.data = qlora_linear.weight.to(dtype)
-
+        if use_bias:
+            lora_linear.bias.data = qlora_linear.bias.detach().clone()
         # Ensure forward passes are the same. This is because LoRALinear should use a special
         # quantized linear operator that runs compute in higher prec (but only saves the 4 bit quantized tensor)
         # for autograd.
         inputs = torch.randn(BSZ, SEQ_LEN, 512, dtype=dtype)
         lora_linear_out = lora_linear(inputs)
         qlora_linear_out = qlora_linear(inputs)
+
         torch.testing.assert_close(lora_linear_out, qlora_linear_out)
 
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index 3f6697f593..8a79b6147f 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -216,6 +216,10 @@ class Recipe:
                 name="llama3_2_vision/11B_lora_single_device",
                 file_path="llama3_2_vision/11B_lora_single_device.yaml",
             ),
+            Config(
+                name="llama3_2_vision/11B_qlora_single_device",
+                file_path="llama3_2_vision/11B_qlora_single_device.yaml",
+            ),
         ],
         supports_distributed=False,
     ),
@@ -289,6 +293,10 @@ class Recipe:
                 name="llama3_2_vision/11B_lora",
                 file_path="llama3_2_vision/11B_lora.yaml",
             ),
+            Config(
+                name="llama3_2_vision/11B_qlora",
+                file_path="llama3_2_vision/11B_qlora.yaml",
+            ),
         ],
         supports_distributed=True,
     ),
diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py
index 150261fd23..772d1e32df 100644
--- a/torchtune/models/clip/_component_builders.py
+++ b/torchtune/models/clip/_component_builders.py
@@ -18,8 +18,8 @@
 from torchtune.modules import (
     FeedForward,
     Fp32LayerNorm,
+    FrozenNF4Linear,
     MultiHeadAttention,
-    TanhGate,
     TransformerSelfAttentionLayer,
 )
 
@@ -170,12 +170,12 @@ def clip_mlp(
     gate_proj = (
         nn.Linear(in_dim, hidden_dim)
         if not quantize_base
-        else FrozenNF4Linear(in_dim, hidden_dim)
+        else FrozenNF4Linear(in_dim, hidden_dim, bias=True)
     )
     down_proj = (
         nn.Linear(hidden_dim, out_dim)
         if not quantize_base
-        else FrozenNF4Linear(hidden_dim, out_dim)
+        else FrozenNF4Linear(hidden_dim, out_dim, bias=True)
     )
     return FeedForward(
         gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
diff --git a/torchtune/models/llama3_2_vision/_component_builders.py b/torchtune/models/llama3_2_vision/_component_builders.py
index 111393501d..8881c87531 100644
--- a/torchtune/models/llama3_2_vision/_component_builders.py
+++ b/torchtune/models/llama3_2_vision/_component_builders.py
@@ -4,31 +4,44 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from functools import partial
 from enum import Enum
-from typing import Optional, List
+from functools import partial
+from typing import List, Optional
 
 from torch import nn
+from torchtune.models.clip._component_builders import (
+    clip_mlp,
+    clip_vision_encoder,
+    lora_clip_attention,
+    lora_clip_mlp,
+    lora_clip_vision_encoder,
+)
 
 from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp
-from torchtune.models.llama3_1._component_builders import llama3_mlp, lora_llama3_mlp, lora_llama3_attention
+from torchtune.models.llama3_1._component_builders import (
+    llama3_mlp,
+    lora_llama3_attention,
+    lora_llama3_mlp,
+)
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-from torchtune.models.clip._component_builders import clip_vision_encoder, clip_mlp, lora_clip_attention, lora_clip_mlp, lora_clip_vision_encoder
-from torchtune.models.llama3_2_vision._encoder import Llama3VisionProjectionHead, Llama3VisionEncoder
-
-from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer
+from torchtune.models.llama3_2_vision._encoder import (
+    Llama3VisionEncoder,
+    Llama3VisionProjectionHead,
+)
 from torchtune.modules import (
+    Fp32LayerNorm,
+    MultiHeadAttention,
     RMSNorm,
     TanhGate,
     TransformerCrossAttentionLayer,
-    MultiHeadAttention,
     TransformerDecoder,
     TransformerSelfAttentionLayer,
-    Fp32LayerNorm
 )
 
 from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
 
+from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer
+
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
 
 
@@ -59,7 +72,7 @@ def llama3_2_vision_encoder(
     tile_size: int,
     max_num_tiles: int = 4,
     in_channels: int = 3,
-    ) -> Llama3VisionEncoder:
+) -> Llama3VisionEncoder:
     """
     Build the Llama 3.2 vision encoder by combining the CLIP image model with an additional
     projection head fusion module. This includes:
@@ -76,7 +89,7 @@ def llama3_2_vision_encoder(
         clip_embed_dim (int): The dimensionality of each patch embedding in CLIP.
         clip_num_layers (int): The number of transformer layers.
         clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return
-            to return to the encoder projection head. It will return the intermediate results 
+            to the encoder projection head. It will return the intermediate results
             of the vision transformer layers which will be concatenated with the CLIP output
             and input into the projection head. For example, ``clip_hidden_states=[0,3]`` will
             return the embeddings before they go through the first and fourth layers.
@@ -113,7 +126,7 @@ def llama3_2_vision_encoder(
         num_heads=num_heads,
         decoder_embed_dim=decoder_embed_dim,
         clip_embed_dim=clip_embed_dim,
-        num_hidden_inputs=len(clip_hidden_states or [])
+        num_hidden_inputs=len(clip_hidden_states or []),
     )
 
     return Llama3VisionEncoder(clip=clip, projection_head=projection_head)
@@ -239,6 +252,7 @@ def llama3_2_vision_decoder(
         output=output_proj,
     )
 
+
 def llama3_2_vision_projection_head(
     *,
     num_layers: int,
@@ -306,9 +320,10 @@ def llama3_2_vision_projection_head(
     return Llama3VisionProjectionHead(
         layers=layers,
         output=nn.Linear(proj_in, decoder_embed_dim),
-        num_hidden_inputs=num_hidden_inputs
+        num_hidden_inputs=num_hidden_inputs,
     )
 
+
 # ------------------ LoRA Llama 3.2 Vision ------------------
 
 
@@ -344,7 +359,7 @@ def lora_llama3_2_vision_encoder(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
-    ) -> Llama3VisionEncoder:
+) -> Llama3VisionEncoder:
     """
     Build the Llama 3.2 vision encoder by combining the CLIP image model with an additional
     projection head fusion module. This includes:
@@ -370,7 +385,7 @@ def lora_llama3_2_vision_encoder(
         clip_embed_dim (int): The dimensionality of each patch embedding in CLIP.
         clip_num_layers (int): The number of transformer layers.
         clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return
-            to return to the encoder projection head. It will return the intermediate results 
+            to the encoder projection head. It will return the intermediate results
             of the vision transformer layers which will be concatenated with the CLIP output
             and input into the projection head. For example, ``clip_hidden_states=[0,3]`` will
             return the embeddings before they go through the first and fourth layers.
@@ -388,7 +403,7 @@ def lora_llama3_2_vision_encoder(
         quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base
             weights within linear layers LoRA is applied to. The final output linear projection is not
             supported for quantization currently.
-        
+
 
     Returns:
         Llama3VisionEncoder: Instantiation of Llama 3.2 vision encoder.
@@ -423,7 +438,7 @@ def lora_llama3_2_vision_encoder(
     else:
         clip = clip_vision_encoder(**clip_options)
 
-    # Projection 
+    # Projection
     projection_options = {
         "num_layers": num_layers_projection,
         "num_heads": num_heads,
@@ -432,11 +447,22 @@ def lora_llama3_2_vision_encoder(
         "num_hidden_inputs": len(clip_hidden_states or []),
     }
     if fusion_lora:
-        projection_head = lora_llama3_2_vision_projection_head(**projection_options, **lora_options)
+        projection_head = lora_llama3_2_vision_projection_head(
+            **projection_options, **lora_options
+        )
     else:
         projection_head = lora_llama3_2_vision_projection_head(**projection_options)
 
-    return Llama3VisionEncoder(clip=clip, projection_head=projection_head)
+    encoder = Llama3VisionEncoder(clip=clip, projection_head=projection_head)
+
+    if quantize_base:
+        # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
+        # so as to not increase peak memory
+        encoder._register_state_dict_hook(
+            partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
+        )
+
+    return encoder
 
 
 def lora_llama3_2_vision_decoder(
@@ -458,7 +484,7 @@ def lora_llama3_2_vision_decoder(
     encoder_max_seq_len: int,
     rope_base: int = 500000.0,
     intermediate_dim: Optional[int] = None,
-     # LoRA parameters
+    # LoRA parameters
     lora_rank: int = 8,
     lora_alpha: float = 16,
     lora_dropout: float = 0.0,
@@ -546,7 +572,9 @@ def lora_llama3_2_vision_decoder(
                 use_dora=use_dora,
             )
         else:
-            mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base)
+            mlp = llama3_mlp(
+                dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base
+            )
         decoder_layer = TransformerSelfAttentionLayer(
             attn=self_attn,
             mlp=mlp,
@@ -586,7 +614,9 @@ def lora_llama3_2_vision_decoder(
                     use_dora=use_dora,
                 )
             else:
-                mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base)
+                mlp = llama3_mlp(
+                    dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base
+                )
             xattn_layer = TransformerCrossAttentionLayer(
                 attn=attn,
                 mlp=mlp,
@@ -601,11 +631,17 @@ def lora_llama3_2_vision_decoder(
             layers.append(decoder_layer)
 
     tok_embeddings = FusionEmbedding(vocab_size, num_special_tokens, embed_dim)
-    
-     # TODO: quantize_base is not applied to final output_proj currently.
+
+    # TODO: quantize_base is not applied to final output_proj currently.
     adapter_cls = DoRALinear if use_dora else LoRALinear
     output_proj = (
-        adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
+        adapter_cls(
+            embed_dim,
+            vocab_size,
+            rank=lora_rank,
+            alpha=lora_alpha,
+            dropout=lora_dropout,
+        )
         if apply_lora_to_output
         else nn.Linear(embed_dim, vocab_size, bias=False)
     )
@@ -713,7 +749,7 @@ def lora_llama3_2_vision_projection_head(
                 hidden_dim=hidden_dim,
                 out_dim=clip_embed_dim,
                 activation=nn.GELU(),
-                quantize_base=quantize_base
+                quantize_base=quantize_base,
             )
 
         layer = TransformerSelfAttentionLayer(
@@ -733,7 +769,14 @@ def lora_llama3_2_vision_projection_head(
     proj_in = clip_embed_dim * (num_hidden_inputs + 1)
     adapter_cls = DoRALinear if use_dora else LoRALinear
     output_proj = (
-        adapter_cls(proj_in, decoder_embed_dim, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout, use_bias=True)
+        adapter_cls(
+            proj_in,
+            decoder_embed_dim,
+            rank=lora_rank,
+            alpha=lora_alpha,
+            dropout=lora_dropout,
+            use_bias=True,
+        )
         if apply_lora_to_output
         else nn.Linear(proj_in, decoder_embed_dim)
     )
diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py
index 9b0eaf53a3..0c387ffa0f 100644
--- a/torchtune/modules/low_precision/nf4_linear.py
+++ b/torchtune/modules/low_precision/nf4_linear.py
@@ -18,32 +18,34 @@ class FrozenNF4Linear(nn.Linear):
     NF4Tensor as its weight. This class also freezes its ``weight`` parameter
     and is meant to be used as the base Linear layer for modeling
     use cases such as QLoRA where base model parameters are frozen.
-    NOTE: biases are currently not supported.
 
     Args:
         in_dim (int): input dimension
         out_dim (int): output dimension
         device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default
             device given by `torch.get_default_device()`.
+        bias (bool): whether to include bias in the linear layer. Default: False
         **kwargs: any additional arguments to pass to the underlying Linear layer.
 
-    Raises:
-        RuntimeError: if ``bias`` is set to ``True``
     """
 
     def __init__(
-        self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs
+        self,
+        in_dim: int,
+        out_dim: int,
+        device: Optional[torch.device] = None,
+        bias: bool = False,
+        **kwargs,
     ):
-        if "bias" in kwargs and kwargs.pop("bias"):
-            raise RuntimeError("FrozenNF4Linear does not currently support biases!")
-
-        super().__init__(in_dim, out_dim, device=device, bias=False, **kwargs)
+        super().__init__(in_dim, out_dim, device=device, bias=bias, **kwargs)
         self.weight.requires_grad_(False)
-        self.nf4_weight = to_nf4(self.weight)
+        if self.bias is not None:
+            self.bias.requires_grad_(False)
+        nf4_weight = to_nf4(self.weight)
         # re-register self.weight as the nf4 weight, so that the nf4 weight
         # shows up as expected in .parameters, state_dict, etc.
         torch.utils.swap_tensors(
-            self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False)
+            self.weight, torch.nn.Parameter(nf4_weight, requires_grad=False)
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -57,4 +59,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         Returns:
             Tensor: output tensor
         """
-        return linear_nf4(input=input, weight=self.weight)
+        out = linear_nf4(input=input, weight=self.weight)
+        if self.bias is not None:
+            out = out + self.bias
+        return out
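
With the change above, a frozen bias simply rides along next to the NF4 weight. A minimal usage sketch, following the same shapes the updated tests use:

    import torch
    from torchtune.modules import FrozenNF4Linear

    layer = FrozenNF4Linear(512, 512, dtype=torch.bfloat16, bias=True)
    x = torch.randn(2, 512, dtype=torch.bfloat16)
    out = layer(x)  # linear_nf4(x, weight) + bias
    assert out.dtype == torch.bfloat16
    assert not layer.weight.requires_grad and not layer.bias.requires_grad
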
diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py
index 1a5452daae..907c7a2ed0 100644
--- a/torchtune/modules/model_fusion/_fusion.py
+++ b/torchtune/modules/model_fusion/_fusion.py
@@ -232,10 +232,11 @@ def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
         """Apply extra "embedding" prefix to the state_dict key to
         account for the FusionEmbedding wrapping.
         """
-        key = prefix + "weight"
-        new_key = prefix + "embedding.weight"
-        state_dict[new_key] = state_dict[key]
-        del state_dict[key]
+        if state_dict:
+            key = prefix + "weight"
+            new_key = prefix + "embedding.weight"
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
 
     def fusion_params(self) -> List[str]:
         """
diff --git a/torchtune/modules/peft/dora.py b/torchtune/modules/peft/dora.py
index e0a8fe9788..153b3c78e1 100644
--- a/torchtune/modules/peft/dora.py
+++ b/torchtune/modules/peft/dora.py
@@ -39,8 +39,6 @@ class DoRALinear(nn.Module, AdapterModule):
         quantize_base (bool): Whether to quantize base linear weight or not.
             Default: False
 
-    Raises:
-        NotImplementedError: If use_bias is enabled.
     """
 
     def __init__(
@@ -54,14 +52,16 @@ def __init__(
         quantize_base: bool = False,
     ):
         super().__init__()
-        if use_bias:
-            raise NotImplementedError("DoRALinear does not support using bias")
         self.in_dim = in_dim
         self.out_dim = out_dim
         self.scaling = alpha / rank
+        self.use_bias = use_bias
         self._quantize_base = quantize_base
-        weight = self._create_weight()
+        weight, bias = self._create_weight_and_bias()
         self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
 
         # 'self.disabled' is a flag showing whether to turn off DoRA adapters,
         # this can be used in DPO for treating the dora adapters as the policy model
@@ -90,15 +90,18 @@ def initialize_dora_magnitude(self):
         weight_norm = self._get_weight_norm(base_weight, lora_weight)
         self.magnitude = nn.Parameter(weight_norm, requires_grad=True)
 
-    def _create_weight(self):
+    def _create_weight_and_bias(self):
         """
         Creates a linear weight and bias tensor, using NF4 dtype if we're quantizing
         (indicated via quantize_base=True).
         """
-        in_dim, out_dim = self.in_dim, self.out_dim
-        linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=False)
+        in_dim, out_dim, use_bias = self.in_dim, self.out_dim, self.use_bias
+        linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias)
         weight = linear.weight if not self._quantize_base else to_nf4(linear.weight)
-        return weight
+        bias = None
+        if self.use_bias:
+            bias = linear.bias
+        return weight, bias
 
     def _get_weight_norm(self, weight, lora_weight):
         weight = weight + self.scaling * lora_weight
@@ -123,8 +126,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         if self._quantize_base:
             base_out = linear_nf4(input=x, weight=self.weight)
+            if self.use_bias:
+                base_out = base_out + self.bias
         else:
-            base_out = F.linear(x, self.weight)
+            base_out = F.linear(x, self.weight, self.bias)
         if self.disabled:
             return base_out
 
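
Together with the lora.py change below, a quantized adapter can now carry a bias; only the weight is stored as NF4, while the bias stays in the working dtype. A minimal sketch mirroring the updated tests (the DoRA magnitude is initialized the same way the tests do it):

    import torch
    from torchtune import training
    from torchtune.modules.peft import DoRALinear

    with training.set_default_dtype(torch.bfloat16):
        qdora = DoRALinear(
            in_dim=512,
            out_dim=512,
            rank=8,
            alpha=16,
            use_bias=True,
            quantize_base=True,
        )
    qdora.initialize_dora_magnitude()
    x = torch.randn(2, 512, dtype=torch.bfloat16)
    y = qdora(x)  # quantized base path (NF4 matmul + bf16 bias) plus the DoRA-adjusted low-rank update
    assert qdora.bias.dtype == torch.bfloat16
    assert y.shape == (2, 512)
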
diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py
index 57d72af672..3b90a89306 100644
--- a/torchtune/modules/peft/lora.py
+++ b/torchtune/modules/peft/lora.py
@@ -95,10 +95,6 @@ def _create_weight_and_bias(self):
         weight = linear.weight if not self._quantize_base else to_nf4(linear.weight)
         bias = None
         if self.use_bias:
-            if self._quantize_base:
-                raise NotImplementedError(
-                    "Quantized LoRALinear does not support bias at the moment."
-                )
             bias = linear.bias
         return weight, bias
 
@@ -123,6 +119,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         if self._quantize_base:
             out = linear_nf4(input=x, weight=self.weight)
+            if self.use_bias:
+                out = out + self.bias
         else:
             out = F.linear(x, self.weight, self.bias)
         if self.disabled: