[FIX] EVA meta device check bug + add multi-gpu functionality #2218

Merged 5 commits on Nov 18, 2024
4 changes: 2 additions & 2 deletions docs/source/developer_guides/lora.md
@@ -65,7 +65,7 @@ config = LoraConfig(init_lora_weights="olora", ...)
For more advanced usage, please refer to our [documentation](https://github.com/huggingface/peft/tree/main/examples/olora_finetuning).

### EVA
[EVA](https://arxiv.org/pdf/2410.07170) performs SVD on the input activations of each layer and uses the right-singular vectors to initialize LoRA weights. It therefore is a data-driven initialization scheme. Furthermore EVA adaptively allocates ranks across layers based on their "explained variance ratio" - a metric derived from the SVD analysis.
[EVA](https://arxiv.org/pdf/2410.07170) performs SVD on the input activations of each layer and uses the right-singular vectors to initialize LoRA weights. It is therefore a data-driven initialization scheme. Furthermore, EVA adaptively allocates ranks across layers based on their "explained variance ratio" - a metric derived from the SVD analysis.
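Conceptually, the per-layer initialization looks like the following toy sketch (not PEFT's actual incremental implementation; `X` stands in for a batch of collected input activations of one adapted layer):
```python
import torch

# X: input activations of a single layer, shape (num_tokens, in_features)
X = torch.randn(64, 128)

# rows of Vh are the right-singular vectors of the activation matrix
U, S, Vh = torch.linalg.svd(X, full_matrices=False)

r = 16
lora_A_init = Vh[:r]  # top-r right-singular vectors -> LoRA A matrix of shape (r, in_features)

# per-component explained variance ratio, used by EVA to redistribute ranks across layers
explained_variance_ratio = S**2 / (S**2).sum()
```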

You can use EVA by setting `init_lora_weights="eva"` and defining [`EvaConfig`] in [`LoraConfig`]:
```python
@@ -76,7 +76,7 @@ peft_config = LoraConfig(
...
)
```
The parameter `rho` (≥ 1.0) determines how much redistribution is allowed. When `rho=1.0` and `r=16`, the system is limited to exactly 16 ranks, preventing any redistribution from occurring. A recommended value for eva with redistribution is 2.0, meaning the maximum rank allowed for a layer is 2r.
The parameter `rho` (≥ 1.0) determines how much redistribution is allowed. When `rho=1.0` and `r=16`, LoRA adapters are limited to exactly 16 ranks, preventing any redistribution from occurring. A recommended value for EVA with redistribution is 2.0, meaning the maximum rank allowed for a layer is 2r.
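For example, a config that allows redistribution of up to 2r ranks per layer could look like this (a minimal sketch; the choice of target modules is illustrative):
```python
from peft import EvaConfig, LoraConfig

# with r=16 and rho=2.0 the total rank budget stays at 16 * (number of adapted layers),
# but an individual layer may be assigned up to 2 * 16 = 32 ranks
peft_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],  # illustrative target modules
    init_lora_weights="eva",
    eva_config=EvaConfig(rho=2.0),
)
```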

It is recommended to perform EVA initialization on a GPU as it is much faster. To optimize the amount of available memory for EVA, you can use the `low_cpu_mem_usage` flag in [`get_peft_model`]:
```python
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
```
7 changes: 6 additions & 1 deletion examples/eva_finetuning/README.md
@@ -90,18 +90,23 @@ In some cases you might just want to get the state_dict after EVA initialization
- you want to precompute and store the state_dict for different downstream tasks.
- you need to quantize the model for finetuning but want to perform EVA initialization with model weights in full/half precision.
- you do not intend to use a peft model for LoRA finetuning.
- you would like to leverage multiple GPUs for EVA initialization (at the moment this is not directly supported by `initialize_lora_eva_weights`).

You can do this by calling `get_eva_state_dict` directly (you only need to pass `peft_config` if `model` is not a PeftModel):
```python
from peft import get_eva_state_dict

eva_state_dict = get_eva_state_dict(model, dataloader, peft_config)
```
Later you can load the state_dict into a model without adapter weights by using the `eva_state_dict` argument in `initialize_lora_eva_weights`:
Later you can load the state_dict into a `PeftModel` by using the `eva_state_dict` argument in `initialize_lora_eva_weights`:
```python
initialize_lora_eva_weights(peft_model, eva_state_dict=eva_state_dict)
```

## Leveraging multiple GPUs

EVA initialization can be parallelized across multiple GPUs. In this case, inputs from all GPUs are gathered before the SVD is computed for each batch. This requires the model to be wrapped in a `torch.nn.DataParallel` or `torch.nn.DistributedDataParallel` class. An example of how to use this can be found in [eva_finetuning_multi_gpu.py](https://github.com/huggingface/peft/blob/main/examples/eva_finetuning/eva_finetuning_multi_gpu.py).
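A minimal sketch of that pattern (assuming the distributed process group has already been initialized, e.g. via `torchrun`, and that `model`, `dataloader` and `peft_config` are defined as in the previous snippets):
```python
import os

from torch.nn.parallel import DistributedDataParallel as DDP

from peft import get_eva_state_dict, get_peft_model, initialize_lora_eva_weights

# wrap the model in DDP so that inputs are gathered across ranks before each SVD step
local_rank = int(os.environ["LOCAL_RANK"])
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank])

eva_state_dict = get_eva_state_dict(model, dataloader, peft_config)

# remap keys from the DDP-wrapped module ("module.<name>") to the PeftModel namespace
eva_state_dict = {".".join(["base_model.model"] + k.split(".")[1:]): v for k, v in eva_state_dict.items()}

# unwrap the model again before creating the PeftModel and loading the EVA weights
peft_model = get_peft_model(model.module, peft_config, low_cpu_mem_usage=True)
initialize_lora_eva_weights(peft_model, eva_state_dict=eva_state_dict)
```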

## Customizing EVA

By default, EVA is designed to work with standard transformer language models. However, we integrated three different parameters which can be used to customize EVA for other types of models.
14 changes: 12 additions & 2 deletions examples/eva_finetuning/eva_finetuning.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
@@ -20,6 +21,9 @@
from peft import EvaConfig, LoraConfig, get_peft_model, initialize_lora_eva_weights


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# config
model_name = "meta-llama/Llama-3.1-8B"
max_seq_len = 512
@@ -29,9 +33,12 @@
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
svd_batch_size = 4 # can be different from the batch size used in finetuning
batch_size = 4
learning_rate = 5e-4
gradient_accumulation_steps = 8
num_epochs = 1
output_dir = "outputs"
device = "cuda:0"
bf16 = True


# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -63,7 +70,7 @@
)

# move model to GPU
model = model.to(device)
model = model.to(DEVICE)

# to optimize memory usage during eva initialization, set low_cpu_mem_usage=True
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
@@ -72,9 +79,12 @@
# setup training arguments
training_args = TrainingArguments(
per_device_train_batch_size=batch_size,
learning_rate=learning_rate,
gradient_accumulation_steps=gradient_accumulation_steps,
num_train_epochs=num_epochs,
output_dir=output_dir,
remove_unused_columns=False,
bf16=bf16,
)

# continue with standard finetuning
127 changes: 127 additions & 0 deletions examples/eva_finetuning/eva_finetuning_multi_gpu.py
@@ -0,0 +1,127 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.distributed as dist
from datasets import load_dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from utils import DataCollator, TokenizerMetaMath

from peft import EvaConfig, LoraConfig, get_eva_state_dict, get_peft_model, initialize_lora_eva_weights


# run this script e.g. with: torchrun --nproc_per_node=4 eva_finetuning_multi_gpu.py

# config
model_name = "meta-llama/Llama-2-7b-hf"
max_seq_len = 512
rank = 16
alpha = 1
rho = 2.0
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
svd_batch_size = 4 # can be different from the batch size used in finetuning
batch_size = 4
learning_rate = 5e-4
gradient_accumulation_steps = 8
num_epochs = 1
output_dir = "outputs"
bf16 = True


# Initialize distributed environment
if torch.cuda.is_available():
local_rank = int(os.environ.get("LOCAL_RANK", -1))
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
world_size = dist.get_world_size()
else:
local_rank = -1
world_size = 1


# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# load dataset
dataset = load_dataset("meta-math/MetaMathQA")
dataset = dataset.map(
TokenizerMetaMath(model_name),
batched=True,
remove_columns=dataset["train"].column_names,
)
dataset.set_format(type="torch")

# data collator
data_collator = DataCollator(tokenizer.eos_token_id, max_length=max_seq_len)

# Create sampler for distributed training
sampler = DistributedSampler(dataset["train"], num_replicas=world_size, rank=local_rank)

# dataloader
dataloader = DataLoader(
dataset["train"],
batch_size=svd_batch_size,
collate_fn=data_collator,
sampler=sampler,
shuffle=False,
)

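# fix the sampler's shuffling order for the single EVA pass (DistributedSampler shuffles per epoch)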
sampler.set_epoch(0)

# Wrap model in DDP
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

# setup peft config
eva_config = EvaConfig(rho=rho)
peft_config = LoraConfig(
r=rank, lora_alpha=alpha, target_modules=target_modules, init_lora_weights="eva", eva_config=eva_config
)

# EVA initialization
eva_state_dict = get_eva_state_dict(model, dataloader, peft_config)
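# remap keys: drop the DDP "module." prefix and prepend "base_model.model." so they match the PeftModel created below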
eva_state_dict = {".".join(["base_model.model"] + k.split(".")[1:]): v for k, v in eva_state_dict.items()}

# cleanup ddp
model = model.module

# initialize peft model
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
initialize_lora_eva_weights(peft_model, eva_state_dict=eva_state_dict)

# setup training arguments
training_args = TrainingArguments(
per_device_train_batch_size=batch_size,
learning_rate=learning_rate,
gradient_accumulation_steps=gradient_accumulation_steps,
num_train_epochs=num_epochs,
output_dir=output_dir,
remove_unused_columns=False,
bf16=bf16,
)

# continue with standard finetuning
trainer = Trainer(
model=peft_model,
args=training_args,
train_dataset=dataset["train"],
data_collator=data_collator,
)
trainer.train()
2 changes: 1 addition & 1 deletion src/peft/mapping.py
@@ -196,7 +196,7 @@ def get_peft_model(
and not low_cpu_mem_usage
):
warnings.warn(
"lora with eva initialization used with low_cpu_mem_usage=False"
"lora with eva initialization used with low_cpu_mem_usage=False. "
"Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization."
)

48 changes: 41 additions & 7 deletions src/peft/tuners/lora/eva.py
@@ -22,6 +22,7 @@
from typing import Dict, Iterable, Optional, Union

import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D

@@ -67,8 +68,34 @@ def _prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> to
return layer_input

@torch.no_grad()
def prepare_layer_inputs(self, input):
return self._prepare_layer_inputs_fn(input, self.model_input, self.name)
def prepare_layer_inputs(self, layer_input):
return self._prepare_layer_inputs_fn(layer_input, self.model_input, self.name)

@staticmethod
def gather_layer_inputs(layer_input):
if dist.is_initialized():
world_size = dist.get_world_size()

# First gather the local batch sizes from all processes
local_size = torch.tensor([layer_input.shape[0]], device=layer_input.device)
all_sizes = torch.empty(world_size, dtype=local_size.dtype, device=layer_input.device)
dist.all_gather_into_tensor(all_sizes, local_size)
all_sizes = all_sizes.tolist()

# Find maximum size and pad tensors
padded_input = layer_input.new_zeros((max(all_sizes), *layer_input.shape[1:]))
padded_input[: layer_input.shape[0]] = layer_input

# Gather padded tensors
gathered_inputs = [torch.zeros_like(padded_input) for _ in range(world_size)]
dist.all_gather(gathered_inputs, padded_input.contiguous())

# Remove padding for each gathered tensor
gathered_inputs = [tensor[:size] for tensor, size in zip(gathered_inputs, all_sizes)]

# Concatenate along batch dimension
return torch.cat(gathered_inputs, dim=0)
return layer_input


class SVDHook(_Hook):
@@ -114,6 +141,7 @@ def __call__(self, model, input, output):
if hasattr(self.svd, "components_"):
previous_components = self.svd.components_.clone().detach()
states = self.prepare_layer_inputs(input)
states = self.gather_layer_inputs(states)
# skip this step if the batch size is smaller than the number of SVD components
if states.size(0) < self.n_components:
print(f"skipping SVD for {self.name} because there are less than {self.n_components} examples")
@@ -160,8 +188,9 @@ def hash_fn(tensor):

@torch.no_grad()
def __call__(self, model, input, output):
x = self.prepare_layer_inputs(input).cpu()
self.hashed_inputs.append(self.hash_fn(x))
x = self.prepare_layer_inputs(input)
x = self.gather_layer_inputs(x)
self.hashed_inputs.append(self.hash_fn(x.cpu()))


def find_equal_values(dictionary: dict) -> dict:
@@ -262,6 +291,9 @@ def _get_eva_state_dict(
prepare_layer_inputs_fn: Union[callable, Dict[str, callable], None],
show_progress_bar: bool,
) -> dict:
# Set seeds for reproducibility at the start of EVA computation
torch.manual_seed(0)

# Computes the rank distribution for each layer based on the explained variance ratio.
# when rank_pattern flag is False, all values in max_components are the same
def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget, max_components):
@@ -352,10 +384,12 @@ def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget,
layer_hook_map = {**dict(zip(hooks.keys(), hooks.keys())), **equal_inputs_map}

# start svd calculation
if show_progress_bar:
if show_progress_bar and (not dist.is_initialized() or dist.get_rank() == 0):
pbar = tqdm(iter(cycle(dataloader)), position=0, leave=False)
use_tqdm = True
else:
pbar = iter(cycle(dataloader))
use_tqdm = False
convergence_dict = {k: False for k in hooks.keys()}
rank_dist = max_components.copy()
for inputs in pbar:
@@ -384,7 +418,7 @@ def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget,
hook.model_input = model_inputs_for_hooks
hooks[name] = (hook, handle)

if show_progress_bar:
if use_tqdm:
layer_converged = list(convergence_dict.values()) + [
convergence_dict[v] for v in equal_inputs_map.values()
]
@@ -469,7 +503,7 @@ def _load_eva_state_dict(
elif new_rank != r:
if peft_config.eva_config.adjust_scaling_factors:
alpha *= new_rank / r
if new_rank != r or module.lora_A[adapter_name].weight.device == "meta":
if new_rank != r or module.lora_A[adapter_name].weight.device.type == "meta":
module.update_layer(r=new_rank, lora_alpha=alpha, init_lora_weights="eva", **update_layer_kwargs)
module.lora_A[adapter_name].weight.copy_(w)
new_target_modules.append(name_in_base_model)