I am trying to use QAT to quantize the Qwen2 1.5B model.
The error is raised from the call training.load_from_full_model_state_dict(model, model_state_dict, self._device, self._is_rank_zero, strict=True) in recipes/qat_distributed.
I then traced the error to this function:
# torchtune/torchtune/training/_distributed.py
def load_from_full_model_state_dict(
    model: "FSDPModule",  # noqa
    full_sd: Dict[str, Any],
    device: torch.device,
    is_rank_zero: bool,
    strict: bool = False,
    cpu_offload: bool = False,
):
    """
    Converting a full state dict into a sharded state dict
    and loading it into an FSDP model
    - 'full' means plain tensor
    - 'sharded' means `DTensor` where each rank has a shard of the plain tensor
    - `is_rank_zero` matters if only rank 0 passes in a non-empty `full_sd` and
      we need to broadcast from rank 0
    """
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
The failure happens because sharded_meta_param is None, so accessing sharded_meta_param.dtype raises an AttributeError.
By adding print statements, I found that meta_sharded_sd does not contain the bias parameters:
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(151936, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),)) dtype: torch.bfloat16
param_name: layers.0.sa_norm.scale
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536,), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),)) dtype: torch.bfloat16
param_name: layers.0.mlp.w2.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536, 8960), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),)) dtype: torch.bfloat16
param_name: layers.0.mlp.w1.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(8960, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),)) dtype: torch.bfloat16
param_name: layers.0.mlp.w3.weight
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(8960, 1536), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),)) dtype: torch.bfloat16
param_name: layers.0.mlp_norm.scale
sharded_meta_param DTensor(local_tensor=tensor(..., device='meta', size=(1536,), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0]), placements=(Shard(dim=0),)) dtype: torch.bfloat16
param_name: layers.0.attn.k_proj.bias
### the error is raised here: meta_sharded_sd has no entry for layers.0.attn.k_proj.bias
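A minimal standalone version of that check (a hypothetical helper of my own, report_missing_keys; it simply compares the checkpoint keys against the model's state dict) would be:

from typing import Any, Dict

import torch.nn as nn


def report_missing_keys(model: nn.Module, full_sd: Dict[str, Any]) -> None:
    # Any key printed here is one that load_from_full_model_state_dict will
    # crash on, because meta_sharded_sd.get(param_name) returns None and the
    # subsequent .dtype access fails.
    model_keys = set(model.state_dict().keys())
    for name in full_sd:
        if name not in model_keys:
            print("in checkpoint but not in model:", name)

Based on the output above, I would expect this to print the attention projection biases such as layers.0.attn.k_proj.bias.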
My YAML config file is as follows:
# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_1_5b
  path: /QAT/Qwen2-1.5B/vocab.json
  merges_file: /QAT/Qwen2-1.5b/merges.txt
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  source: parquet
  data_files: /QAT/dataset/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
seed: 42
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /QAT/Qwen2-1.5B/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /QAT/Qwen2-1.5B
  model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 8
epochs: 1

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /QAT/Qwen2-1.5B/finetune-logs
log_every_n_steps: 1
log_peak_memory_stats: False
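To double-check the checkpoint side, here is a quick sketch I use to list what the file referenced by checkpointer.checkpoint_files actually stores (assumes the safetensors package and the path from my config; as I understand it, these keys are in Hugging Face naming, which the HF checkpointer remaps to torchtune naming before load_from_full_model_state_dict is called):

# Sketch: list the bias tensors stored in the checkpoint referenced above.
# Keys here use Hugging Face naming (e.g. model.layers.0.self_attn.k_proj.bias).
from safetensors import safe_open

with safe_open("/QAT/Qwen2-1.5B/model.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        if key.endswith(".bias"):
            print(key)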