4 changes: 4 additions & 0 deletions mttl/arguments.py
@@ -337,6 +337,10 @@ class TrainingArgs(DataArgs):
    data_dir: str = os.getenv("TRAIN_DIR", "/tmp/")
    output_dir: str = os.getenv("OUTPUT_DIR", "./output")

    # sparse training:
    use_sparse_model: bool = False
    parameter_selection_procedure: str = 'per_layer'  # one of {'per_layer': SNIP scored per layer, 'model': SNIP scored over the whole model, 'weight_magnitude', 'gradient_magnitude', 'grow_and_drop'}

    # meta-tasks or group of tasks
    finetune_task_path: str = None
    # name of tasks, or name of group of tasks if finetune_task_path is set
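The new parameter_selection_procedure flag distinguishes, among other options, SNIP scoring applied per layer from SNIP scoring applied over the whole model. A minimal sketch of that distinction, assuming SNIP saliency |w * dL/dw| and a keep_ratio argument (the helper names and the signature are illustrative, not part of this PR):

import torch


def snip_scores(weight: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # SNIP saliency: importance of each weight is |w * dL/dw|
    return (weight * grad).abs()


def select_parameters(named_scores: dict, keep_ratio: float, procedure: str) -> dict:
    """Return a boolean keep-mask per layer from per-weight saliency scores."""
    masks = {}
    if procedure == "per_layer":
        # rank weights within each layer independently
        for name, scores in named_scores.items():
            k = max(1, int(keep_ratio * scores.numel()))
            threshold = scores.flatten().topk(k).values.min()
            masks[name] = scores >= threshold
    elif procedure == "model":
        # rank all weights jointly with a single global threshold
        all_scores = torch.cat([s.flatten() for s in named_scores.values()])
        k = max(1, int(keep_ratio * all_scores.numel()))
        threshold = all_scores.topk(k).values.min()
        masks = {name: scores >= threshold for name, scores in named_scores.items()}
    else:
        raise NotImplementedError(procedure)
    return masks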
5 changes: 4 additions & 1 deletion mttl/models/library/library_transforms.py
@@ -339,8 +339,10 @@ def transform(self, library) -> Expert:
                    used += keep_mask.sum().item()
                else:
                    # sign majority vote
                    sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
                    sign_per_dim = expert_weights.sum(0, keepdim=True).sign()
                    # resolve zero signs: https://github.com/rezazzr/breadcrumbs/blob/main/src/task_vectors.py#L334
                    majority_sign = torch.sign(sign_per_dim.sum())
                    sign_per_dim[sign_per_dim == 0] = majority_sign

                    # keep only weights whose sign agree with the majority
                    use_for_avg = expert_weights.sign() == sign_per_dim
@@ -1308,3 +1310,4 @@ def create_embeddings():
        for key, label in zip(expert_names, cluster_labels):
            clusters[f"cluster_{label}"].append(key)
        return clusters
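The sign-election change above replaces the per-dimension sign vote with the sign of the summed mass and resolves exact ties to the overall majority sign, following the linked breadcrumbs implementation. A small self-contained sketch of just that step (tensor names are illustrative, not the library's API):

import torch


def elect_signs(expert_weights: torch.Tensor) -> torch.Tensor:
    """expert_weights: an (n_experts, D) stack of task vectors for one parameter."""
    # elect a sign per dimension from the sign of the summed mass
    sign_per_dim = expert_weights.sum(0, keepdim=True).sign()
    # dimensions that cancel out exactly inherit the overall majority sign instead of staying 0
    majority_sign = torch.sign(sign_per_dim.sum())
    sign_per_dim[sign_per_dim == 0] = majority_sign
    return sign_per_dim


# the middle dimension sums to zero and inherits the majority (+1) sign
deltas = torch.tensor([[0.5, 1.0, 2.0], [0.5, -1.0, 1.0]])
print(elect_signs(deltas))  # tensor([[1., 1., 1.]])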

81 changes: 81 additions & 0 deletions mttl/models/library/merging_methods/LoRA_ablinear.py
@@ -0,0 +1,81 @@
from dataclasses import dataclass
import torch
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.library.expert import Expert
from mttl.models.library.library_transforms import (
    LibraryTransform,
    LibraryTransformConfig,
)


@dataclass
class LoRA_ab_LinearMergeConfig(LibraryTransformConfig):
    weights: dict = None


class LoRA_ab_LinearMerge(LibraryTransform):
    """
    Averages the full-rank LoRA products (A @ B) across the experts of a given library
    and adds the mean delta to the base-model weights; returns the resulting ExpertModel.
    """

    def __init__(self, config: LoRA_ab_LinearMergeConfig = None):
        super().__init__(config or LoRA_ab_LinearMergeConfig())

    @torch.no_grad()
    def transform(self, library):
        if type(library) == str:
            library = ExpertLibrary.get_expert_library(library)
        # get expert config
        from copy import deepcopy

        an_expert = library[next(iter(library.keys()))]
        training_config = deepcopy(an_expert.training_config)
        # create an ExpertModel
        from mttl.models.expert_model import ExpertModel

        model = ExpertModel(**vars(training_config))
        # filter the weight names
        weight_names = [
            n
            for n in model.state_dict().keys()
            if ("Wqkv.weight" in n or "out_proj.weight" in n)
        ]

        # iterate over the library
        import collections

        store_W = collections.defaultdict(dict)
        for expert_name, expert in library.items():
            # iterate over the expert weights
            for l in weight_names:
                common_name = ".".join(l.split(".")[1:-1])
                A, B = (
                    expert.expert_weights[f"{common_name}.lora_a"],
                    expert.expert_weights[f"{common_name}.lora_b"],
                )
                W = A @ B
                store_W[l][expert_name] = W

        # average the Ws for each layer
        store_average_W = collections.defaultdict(dict)
        for l in weight_names:
            average_W = 0
            for k, v in store_W[l].items():
                average_W += v
            average_W /= len(store_W[l])
            store_average_W[l] = average_W

        # add the averaged Ws to the base model weights
        new_state_dict = {}
        for key, value in model.state_dict().items():
            if key in weight_names:
                print(f"added {key}")
                new_state_dict[key] = value + store_average_W[key].T
            else:
                new_state_dict[key] = value

        # load state_dict into model
        model.load_state_dict(new_state_dict)
        return model
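A possible invocation of the new transform; the library path is purely illustrative, and note that transform returns a full ExpertModel with the averaged deltas already applied rather than an Expert:

import torch

from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.library.merging_methods.LoRA_ablinear import LoRA_ab_LinearMerge

# hypothetical library location; anything accepted by ExpertLibrary.get_expert_library should work
library = ExpertLibrary.get_expert_library("local:///tmp/my_expert_library")
merged_model = LoRA_ab_LinearMerge().transform(library)
torch.save(merged_model.state_dict(), "/tmp/lora_ab_linear_merge.ckpt")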
227 changes: 227 additions & 0 deletions mttl/models/library/merging_methods/base_merge.py
@@ -0,0 +1,227 @@
import copy
from dataclasses import dataclass
from mttl.logging import logger
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.library.library_transforms import (
    LibraryTransform,
    LibraryTransformConfig,
)
from mttl.models.expert_model import ExpertModel
from mttl.models.library.merging_methods import load_mask, convert_idx_2_mask
import torch


@dataclass
class BaseMergeConfig(LibraryTransformConfig):
    merging_method: str = "BaseMerge"


class BaseMerge(LibraryTransform):
    """
    Base class for TIES-Merge and Model-Breadcrumbs: computes a merged weight across experts of a given library.
    """

    def __init__(self, config: BaseMergeConfig = None):
        super().__init__(config or BaseMergeConfig())

    @torch.no_grad()
    def pre_configure(self, library):
        if type(library) == str:
            library = ExpertLibrary.get_expert_library(library)
        expert_names = list(library.keys())
        experts = [library[name] for name in expert_names]
        logger.info("Averaging {} experts".format(len(experts)))
        expert_type = experts[0].training_config.model_modifier
        if expert_type is None:
            expert_type = "FFT"

        # transform experts. NOTE: MUST be called before merging
        self.transform_experts(experts, expert_type)

        # get base expert
        base_expert = copy.deepcopy(experts[0])
        base_expert.name = self.config.merging_method
        train_cfg = copy.deepcopy(base_expert.training_config)
        train_cfg.device_map = "cpu"
        trainable_params = list(
            base_expert.expert_weights.keys()
        )  # e.g. 'model.layers.0.self_attn.o_proj.weight'

        # get base model
        from mttl.models.utils import model_loader_helper

        base_model = model_loader_helper(
            train_cfg.model,
            load_in_8bit=train_cfg.load_in_8bit,
            load_in_4bit=train_cfg.load_in_4bit,
            device_map=getattr(train_cfg, "device_map", "cpu"),
        )

        return experts, expert_type, base_expert, base_model, trainable_params

    @torch.no_grad()
    def extract_expert_vector(
        self, experts, expert_type, base_model_state_dict, trainable_params
    ):
        # Build the n_tasks x D matrix of expert (task) vectors
        expert_vectors = []
        for expert in experts:
            if expert_type == "FFT":
                # W_t = W_expert - W_base
                expert_vectors += [
                    torch.nn.utils.parameters_to_vector(
                        list(
                            expert.expert_weights[k] - base_model_state_dict[k]
                            for k in trainable_params
                        )
                    )
                ]
            elif expert_type == "lora":
                # W_t = (lora_a @ lora_b).T
                # NOTE: already materialized by self.transform_experts()
                weights_list = [
                    param.contiguous()
                    for param in (expert.expert_weights[k] for k in trainable_params)
                ]
                expert_vectors += [torch.nn.utils.parameters_to_vector(weights_list)]

        return expert_vectors

    @torch.no_grad()
    def extract_expert_weight(
        self, base_model_state_dict, experts, param_name, expert_type
    ):
        # for a given "param_name", iterate over all experts and stack their trained expert weights
        if expert_type == "FFT":
            expert_weights = torch.stack(
                [
                    expert.expert_weights[param_name]
                    - base_model_state_dict[param_name]
                    for expert in experts
                ],
                dim=0,
            )
        elif expert_type == "lora":
            expert_weights = torch.stack(
                [expert.expert_weights[param_name] for expert in experts], dim=0
            )
        elif expert_type == "sparse_mask_adapter":
            expert_weights = torch.stack(
                [expert.expert_weights[param_name] for expert in experts], dim=0
            )
        return expert_weights

    @torch.no_grad()
    def transform_experts(self, experts, expert_type):
        assert expert_type in [
            "FFT",
            "lora",
            "sparse_mask_adapter",
        ], f"{expert_type} is not implemented"
        if expert_type == "FFT":
            pass
        elif expert_type == "lora":
            # LoRA: materialize the full-rank delta (lora_a @ lora_b).T for every trainable layer
            base_expert = copy.deepcopy(experts[0])
            trainable_layers = [
                ".".join(l.split(".")[:-1])
                for l in list(base_expert.expert_weights.keys())
                if "qkv_proj" in l
            ]  # e.g. 'layers.0.self_attn.qkv_proj'
            trainable_layers = list(
                dict.fromkeys(trainable_layers)
            )  # removes duplicate layers (lora_a, lora_b) without changing layer order
            for expert in experts:
                for l in trainable_layers:
                    expert.expert_weights[f"{l}.weight"] = (
                        expert.expert_weights[f"{l}.lora_a"]
                        @ expert.expert_weights[f"{l}.lora_b"]
                    ).T
                    del (
                        expert.expert_weights[f"{l}.lora_a"],
                        expert.expert_weights[f"{l}.lora_b"],
                    )

            # NOTE: sanity check, please don't remove this block
            for expert in experts:
                for l in list(expert.expert_weights.keys()):
                    if "lora_a" in l or "lora_b" in l:
                        del expert.expert_weights[l]

        elif expert_type == "sparse_mask_adapter":
            base_expert = copy.deepcopy(experts[0])
            trainable_layers = [
                ".".join(l.split(".")[:-1])
                for l in list(base_expert.expert_weights.keys())
                if "qkv_proj" in l
            ]  # e.g. 'layers.0.self_attn.qkv_proj'
            trainable_layers = list(
                dict.fromkeys(trainable_layers)
            )  # removes duplicate layers without changing layer order

            for expert in experts:
                Mask = load_mask(expert)
                # for each layer, keep only the weights selected by the sparse mask
                for l in trainable_layers:
                    # weight
                    m = convert_idx_2_mask(
                        weight_idx=Mask[f"model.{l}"],
                        mat_dim=expert.expert_weights[f"{l}.weight"].shape,
                    )
                    expert.expert_weights[f"{l}.weight"] = (
                        expert.expert_weights[f"{l}.weight"] * m
                    )
                    # bias is kept as-is
                    expert.expert_weights[f"{l}.bias"] = expert.expert_weights[
                        f"{l}.bias"
                    ]

    @torch.no_grad()
    def compute_per_task_threhold(self, expert_vectors):
        # Given the expert (task) vectors W_task = W_expert - W_base, compute the per-task
        # threshold used to prune parameters. Implemented by subclasses (TIES-Merge, Model-Breadcrumbs).
        pass

    @torch.no_grad()
    def merge_expert(
        self,
        experts,
        expert_vectors,
        trainable_params,
        base_expert,
        base_model_state_dict,
        expert_type,
    ):
        pass

    @torch.no_grad()
    def transform(self, library):
        experts, expert_type, base_expert, base_model, trainable_params = (
            self.pre_configure(library)
        )
        base_model_state_dict = dict(base_model.state_dict())
        # ----------------------------------------------------------------------
        # Collect expert vectors
        # for FFT:    delta_W = W_expert - W_base
        # for LoRA:   delta_W = (A @ B).T
        # for sparse: delta_W = delta_W * mask (NOTE: not implemented yet)
        expert_vectors = self.extract_expert_vector(
            experts, expert_type, base_model_state_dict, trainable_params
        )
        base_expert = self.merge_expert(
            experts,
            expert_vectors,
            trainable_params,
            base_expert,
            base_model_state_dict,
            expert_type,
        )
        # load into the base model:
        config = base_expert.training_config
        config.model_modifier = None  # load only the base model
        config.device_map = "cpu"
        config.trainable_param_names = ".*"  # allows training all linear layers
        base_model = ExpertModel(**vars(config))
        # load state_dict into the model
        assert set(base_model.model.state_dict().keys()) == set(
            base_expert.expert_weights.keys()
        ), "Expert weights must have the same keys"
        base_model.model.load_state_dict(base_expert._expert_weights)
        return base_model
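Concrete merges are expected to fill in the two stubs above. A compact TIES-style sketch of such a subclass, using the shapes produced by extract_expert_vector and extract_expert_weight (the class name, the 20% keep ratio, and the quantile threshold are illustrative and are not the PR's actual TIES-Merge or Model-Breadcrumbs implementation):

import torch

from mttl.models.library.merging_methods.base_merge import BaseMerge


class TinyTiesMerge(BaseMerge):
    """Illustrative subclass: magnitude-prune each task vector, then sign-elect and average."""

    @torch.no_grad()
    def compute_per_task_threhold(self, expert_vectors):
        # one magnitude threshold per task vector, keeping roughly the top 20% of entries
        # (torch.quantile has a ~2**24 element limit, which is fine for a sketch)
        return [vec.abs().float().quantile(0.8) for vec in expert_vectors]

    @torch.no_grad()
    def merge_expert(
        self,
        experts,
        expert_vectors,
        trainable_params,
        base_expert,
        base_model_state_dict,
        expert_type,
    ):
        thresholds = self.compute_per_task_threhold(expert_vectors)
        for param_name in trainable_params:
            # (n_experts, *param_shape) stack of per-expert deltas
            deltas = self.extract_expert_weight(
                base_model_state_dict, experts, param_name, expert_type
            )
            # zero out small-magnitude entries per task
            kept = torch.stack(
                [d * (d.abs() >= th) for d, th in zip(deltas, thresholds)], dim=0
            )
            # elected sign per entry, with exact ties resolved to the overall majority sign
            sign = kept.sum(0, keepdim=True).sign()
            sign[sign == 0] = torch.sign(sign.sum())
            agree = kept.sign() == sign
            merged = (kept * agree).sum(0) / agree.sum(0).clamp(min=1)
            # assumes param_name exists in the base state dict, as in the FFT path above
            base_expert.expert_weights[param_name] = (
                base_model_state_dict[param_name] + merged
            )
        return base_expert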