
qwen2 is not supported by QAT #1818

Description

@elfisworking

I am trying to use QAT to quantize the Qwen2 1.5B model.
The error is raised from the call training.load_from_full_model_state_dict(model, model_state_dict, self._device, self._is_rank_zero, strict=True) in recipes/qat_distributed.
I then traced the error to this function:

# torchtune/torchtune/training/_distributed.py
def load_from_full_model_state_dict(
    model: "FSDPModule",  # noqa
    full_sd: Dict[str, Any],
    device: torch.device,
    is_rank_zero: bool,
    strict: bool = False,
    cpu_offload: bool = False,
):
    """
    Converting full state dict into a sharded state dict
    and loading it into FSDP model
    - 'full' means plain tensor
    - 'sharded' means `DTensor` where each rank has a shard of the plain tensor
    - `is_rank_zero` matters if only rank 0 passes in non-empty `full_sd` and
       we need to broadcast from rank 0
    """
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
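        # Reported failure point for Qwen2 + QAT: sharded_meta_param comes back as
        # None here because the prepared model's state dict has no entry for the
        # attention bias keys, so the .dtype access above raises AttributeError.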

This shows that sharded_meta_param is None for some keys (so sharded_meta_param.dtype fails). By adding print statements, I found that meta_sharded_sd does not contain the bias parameters:

sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(151936, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.sa_norm.scale
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536,), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp.w2.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536, 8960), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp.w1.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(8960, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp.w3.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(8960, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.mlp_norm.scale
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536,), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),))  dtype:  torch.bfloat16
param_name: layers.0.attn.k_proj.bias
### the error is raised here (layers.0.attn.k_proj.bias has no entry in meta_sharded_sd)
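
For reference, a quick way to surface all of the missing keys before the crash is to diff the checkpoint keys against the model's meta state dict. This is only a debugging sketch dropped into load_from_full_model_state_dict, not an existing torchtune helper:

    # Debugging sketch: list every checkpoint key that has no counterpart in the
    # FSDP model's (meta) state dict; for this issue they should be the attention
    # bias keys such as layers.0.attn.k_proj.bias.
    meta_sharded_sd = model.state_dict()
    missing_in_model = [k for k in full_sd if k not in meta_sharded_sd]
    if is_rank_zero:
        print(f"keys in checkpoint but not in model: {missing_in_model}")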

My YAML config is as follows:

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /QAT/Qwen2-1.5B/vocab.json
  merges_file: /QAT/Qwen2-1.5b/merges.txt
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  source: parquet
  data_files: /QAT/dataset/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
seed: 42
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /QAT/Qwen2-1.5B/
  checkpoint_files: [
   model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /QAT/Qwen2-1.5B
  model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 8
epochs: 1

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /QAT/Qwen2-1.5B/finetune-logs
log_every_n_steps: 1
log_peak_memory_stats: False
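
As an extra sanity check (a minimal sketch of mine, not code from the recipe), the missing bias keys can be reproduced outside the recipe by preparing the bare Qwen2 model with the same quantizer settings and inspecting its state dict. This assumes quantizer.prepare() behaves on a meta-device model the same way it does inside qat_distributed:

import torch

from torchtune.models.qwen2 import qwen2_1_5b
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

# Build the model on the meta device so no real memory is allocated.
with torch.device("meta"):
    model = qwen2_1_5b()

bias_keys_before = [k for k in model.state_dict() if k.endswith(".bias")]

# Same quantizer settings as in the config above.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = quantizer.prepare(model)

bias_keys_after = [k for k in model.state_dict() if k.endswith(".bias")]
print("bias keys before prepare:", len(bias_keys_before))
# If this drops to 0 (or prepare() errors out on the bias-carrying linears),
# the QAT linear swap is what loses Qwen2's attention biases.
print("bias keys after prepare: ", len(bias_keys_after))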

Labels

enhancement (New feature or request)
