Feature request
Thanks to #34184 we can use TP for llama with a one-line change. However, the current implementation loads the whole model onto each GPU in every rank before applying TP, significantly increasing the memory footprint.
Motivation
We can load the model on CPU before applying TP. I tested this with Llama 3.1 8B on 2 GPUs: memory usage drops from 60 GB to less than 20 GB. Below is my test script:
import os

import torch
from torch.distributed import device_mesh
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# One process per GPU; RANK is set by the launcher
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.distributed.init_process_group("nccl")

# Load the full model on CPU so no rank allocates the whole model on its GPU
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
)

# Build a 1-D device mesh over all visible GPUs and shard the model across it
num_gpus = torch.cuda.device_count()
tp_mesh = device_mesh.init_device_mesh("cuda", (num_gpus,), mesh_dim_names=("tp",))
model.tensor_parallel(tp_mesh)
model.to(device)  # needed for weights and buffers that are not included by the TP plan

# Sanity check: greedy-decode the next token for a short prompt
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
outputs = model(inputs)
print(tokenizer.decode(outputs.logits.squeeze()[-1].argmax()))
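For reference, I launch one process per GPU with torchrun (the script filename below is just a placeholder):

torchrun --nproc_per_node=2 tp_test.py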
Your contribution
We can set device_map to cpu in PreTrainedModel.from_pretrained when tp_plan is not None, and apply TP at the end.
Happy to discuss and work on a PR for this.
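As a rough illustration of the intended behaviour, here is a minimal user-side sketch; the helper name load_with_tp is hypothetical and not an existing transformers API, it just mirrors what from_pretrained could do internally when a TP plan is requested:

import os

import torch
from torch.distributed import device_mesh
from transformers import AutoModelForCausalLM


def load_with_tp(model_id: str) -> torch.nn.Module:
    """Sketch of the proposal: load on CPU, shard with TP, then move the rest to the GPU."""
    rank = int(os.environ["RANK"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")

    # Step 1: materialize the weights on CPU so no rank holds the full model on its GPU.
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")

    # Step 2: apply the TP plan across all visible GPUs.
    tp_mesh = device_mesh.init_device_mesh(
        "cuda", (torch.cuda.device_count(),), mesh_dim_names=("tp",)
    )
    model.tensor_parallel(tp_mesh)

    # Step 3: move anything the TP plan did not cover (extra weights and buffers).
    return model.to(device)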
CC @kwen2501 @ArthurZucker