diff --git a/mttl/arguments.py b/mttl/arguments.py index cc10e5788..73cfe8cdd 100644 --- a/mttl/arguments.py +++ b/mttl/arguments.py @@ -337,6 +337,10 @@ class TrainingArgs(DataArgs): data_dir: str = os.getenv("TRAIN_DIR", "/tmp/") output_dir: str = os.getenv("OUTPUT_DIR", "./output") + # sparse training: + use_sparse_model: bool = False + parameter_selection_procedure: str = 'per_layer' # {'per_layer': snip per layer, 'model': snip over model, 'weight_magnitude','gradient_magnitude','grow_and_drop'} + # meta-tasks or group of tasks finetune_task_path: str = None # name of tasks, or name of group of tasks if finetune_task_path is set diff --git a/mttl/models/library/library_transforms.py b/mttl/models/library/library_transforms.py index 0de5c1824..33e85e087 100644 --- a/mttl/models/library/library_transforms.py +++ b/mttl/models/library/library_transforms.py @@ -339,8 +339,10 @@ def transform(self, library) -> Expert: used += keep_mask.sum().item() else: # sign majority vote - sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() sign_per_dim = expert_weights.sum(0, keepdim=True).sign() + # resolve zero signs: https://github.com/rezazzr/breadcrumbs/blob/main/src/task_vectors.py#L334 + majority_sign = torch.sign(sign_per_dim.sum()) + sign_per_dim[sign_per_dim == 0] = majority_sign # keep only weights whose sign agree with the majority use_for_avg = expert_weights.sign() == sign_per_dim @@ -1308,3 +1310,4 @@ def create_embeddings(): for key, label in zip(expert_names, cluster_labels): clusters[f"cluster_{label}"].append(key) return clusters + diff --git a/mttl/models/library/merging_methods/LoRA_ablinear.py b/mttl/models/library/merging_methods/LoRA_ablinear.py new file mode 100644 index 000000000..697bf974d --- /dev/null +++ b/mttl/models/library/merging_methods/LoRA_ablinear.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +import torch +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import ( + LibraryTransform, + LibraryTransformConfig, +) +from mttl.models.library.expert import Expert + + +@dataclass +class LoRA_ab_LinearMergeConfig(LibraryTransformConfig): + weights: dict = None + + +class LoRA_ab_LinearMerge(LibraryTransform): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: LoRA_ab_LinearMergeConfig = None): + super().__init__(config or LoRA_ab_LinearMergeConfig()) + + @torch.no_grad() + def transform(self, library) -> Expert: + if type(library) == str: + library = ExpertLibrary.get_expert_library(library) + # get expert config + from copy import deepcopy + + an_expert = library[next(iter(library.keys()))] + training_config = deepcopy(an_expert.training_config) + # create a ExpertModel + from mttl.models.expert_model import ExpertModel + + model = ExpertModel(**vars(training_config)) + # filter the weight names + weight_names = [ + n + for n in model.state_dict().keys() + if ("Wqkv.weight" in n or "out_proj.weight" in n) + ] + + # iterate over the library + import collections + + store_W = collections.defaultdict(dict) + for expert_name, expert in library.items(): + # iterate over the expert weights + for l in weight_names: + common_name = ".".join(l.split(".")[1:-1]) + A, B = ( + expert.expert_weights[f"{common_name}.lora_a"], + expert.expert_weights[f"{common_name}.lora_b"], + ) + W = A @ B + store_W[l][expert_name] = W + + store_average_W = collections.defaultdict(dict) + # iterate over all the 
layers of W + for l in weight_names: + average_W = 0 + for k, v in store_W[l].items(): + average_W += v + average_W /= len(store_W[l]) + store_average_W[l] = average_W + # average the Ws for each layer + + new_state_dict = {} + # add the averaged Ws to the model + for key, value in model.state_dict().items(): + if key in weight_names: + print(f"added {key}") + new_state_dict[key] = value + store_average_W[key].T + else: + new_state_dict[key] = value + + # load state_dict into model + model.load_state_dict(new_state_dict) + return model diff --git a/mttl/models/library/merging_methods/base_merge.py b/mttl/models/library/merging_methods/base_merge.py new file mode 100644 index 000000000..9496672e8 --- /dev/null +++ b/mttl/models/library/merging_methods/base_merge.py @@ -0,0 +1,227 @@ +import copy +from dataclasses import dataclass +from mttl.logging import logger +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.library.library_transforms import ( + LibraryTransform, + LibraryTransformConfig, +) +from mttl.models.expert_model import ExpertModel +from mttl.models.library.merging_methods import load_mask, convert_idx_2_mask +import torch + + +@dataclass +class BaseMergeConfig(LibraryTransformConfig): + merging_method: str = "BaseMerge" + + +class BaseMerge(LibraryTransform): + """ + Base class for TIES-Merge and Model-Breadcrumbs: Computes a merged weight across experts of a given library + """ + + def __init__(self, config: BaseMergeConfig = None): + super().__init__(config or BaseMergeConfig()) + + @torch.no_grad() + def pre_configure(self, library): + if type(library) == str: + library = ExpertLibrary.get_expert_library(library) + expert_names = list(library.keys()) + experts = [library[name] for name in expert_names] + logger.info("Averaging {} experts".format(len(experts))) + expert_type = experts[0].training_config.model_modifier + if expert_type is None: + expert_type = "FFT" + + # transform experts. 
NOTE: MUST + self.transform_experts(experts, expert_type) + + # get base expert + base_expert = copy.deepcopy(experts[0]) + base_expert.name = self.config.merging_method + train_cfg = copy.deepcopy(base_expert.training_config) + train_cfg.device_map = "cpu" + trainable_params = list( + base_expert.expert_weights.keys() + ) # 'model.layers.0.self_attn.o_proj.weight' + + # get base model + from mttl.models.utils import model_loader_helper + + base_model = model_loader_helper( + train_cfg.model, + load_in_8bit=train_cfg.load_in_8bit, + load_in_4bit=train_cfg.load_in_4bit, + device_map=getattr(train_cfg, "device_map", "cpu"), + ) + + return experts, expert_type, base_expert, base_model, trainable_params + + @torch.no_grad() + def extract_expert_vector( + self, experts, expert_type, base_model_state_dict, trainable_params + ): + # Build n_tasks x D experts + expert_vectors = [] + for expert in experts: + if expert_type == "FFT": + # W_t = W_base - W' + expert_vectors += [ + torch.nn.utils.parameters_to_vector( + list( + expert.expert_weights[k] - base_model_state_dict[k] + for k in trainable_params + ) + ) + ] + # W_t = (lora_a*lora_b).T + # NOTE: it's already done when we call self.transform_experts() + elif expert_type == "lora": + weights_list = [ + param.contiguous() + for param in (expert.expert_weights[k] for k in trainable_params) + ] + expert_vectors += [torch.nn.utils.parameters_to_vector(weights_list)] + + return expert_vectors + + @torch.no_grad() + def extract_expert_weight( + self, base_model_state_dict, experts, param_name, expert_type + ): + # for given "param_name", iterates over all expert and gets the trained expert-weights + if expert_type == "FFT": + expert_weights = torch.stack( + [ + expert.expert_weights[param_name] + - base_model_state_dict[param_name] + for expert in experts + ], + dim=0, + ) + elif expert_type == "lora": + expert_weights = torch.stack( + [expert.expert_weights[param_name] for expert in experts], dim=0 + ) + elif expert_type == "sparse_mask_adapter": + expert_weights = torch.stack( + [expert.expert_weights[param_name] for expert in experts], dim=0 + ) + return expert_weights + + @torch.no_grad() + def transform_experts(self, experts, expert_type): + assert expert_type in ["FFT", "lora", "sparse_mask_adapter"], print( + f"{expert_type} is not implemented" + ) + if expert_type == "FFT": + pass + elif expert_type == "lora": + # Lora + base_expert = copy.deepcopy(experts[0]) + trainable_layers = [ + ".".join(l.split(".")[:-1]) + for l in list(base_expert.expert_weights.keys()) + if "qkv_proj" in l + ] ## 'layers.0.self_attn.o_proj' + trainable_layers = list( + dict.fromkeys(trainable_layers) + ) # removes duplicate layers (loraA,loraB), w/o changing layer order + for expert in experts: + for l in trainable_layers: + expert.expert_weights[f"{l}.weight"] = ( + expert.expert_weights[f"{l}.lora_a"] + @ expert.expert_weights[f"{l}.lora_b"] + ).T + del ( + expert.expert_weights[f"{l}.lora_a"], + expert.expert_weights[f"{l}.lora_b"], + ) + + # NOTE: sanity check, please don't remove this block + for expert in experts: + for l in list(expert.expert_weights.keys()): + if "lora_a" in l or "lora_b" in l: + del expert.expert_weights[l] + + elif expert_type == "sparse_mask_adapter": + base_expert = copy.deepcopy(experts[0]) + trainable_layers = [ + ".".join(l.split(".")[:-1]) + for l in list(base_expert.expert_weights.keys()) + if "qkv_proj" in l + ] ## 'layers.0.self_attn.o_proj' + trainable_layers = list( + dict.fromkeys(trainable_layers) + ) # removes duplicate 
layers (loraA,loraB), w/o changing layer order + + for expert in experts: + Mask = load_mask(expert) + # for each layer compute the "average of the overlapped weight" + for l in trainable_layers: + # weight + m = convert_idx_2_mask( + weight_idx=Mask[f"model.{l}"], + mat_dim=expert.expert_weights[f"{l}.weight"].shape, + ) + expert.expert_weights[f"{l}.weight"] = ( + expert.expert_weights[f"{l}.weight"] * m + ) + expert.expert_weights[f"{l}.bias"] = expert.expert_weights[ + f"{l}.bias" + ] + + @torch.no_grad() + def compute_per_task_threhold(self, expert_vectors): + # Given expert vector, W_task compute TH score to prune parameters + # NOTE: W_task = W_base - W' + pass + + @torch.no_grad() + def merge_expert( + self, + experts, + expert_vectors, + trainable_params, + base_expert, + base_model_state_dict, + expert_type, + ): + pass + + @torch.no_grad() + def transform(self, library): + experts, expert_type, base_expert, base_model, trainable_params = ( + self.pre_configure(library) + ) + base_model_state_dict = dict(base_model.state_dict()) + # ---------------------------------------------------------------------- + # Collect Expert-vector + # for FFT: expert_vectors, delta_W = W-W' + # for LoRA: expert_vectors, delta_W = A.B + # for sparse: expert_vectors, delta_W = delta_W*mask (NOTE: not implemented yet) + expert_vectors = self.extract_expert_vector( + experts, expert_type, base_model_state_dict, trainable_params + ) + base_expert = self.merge_expert( + experts, + expert_vectors, + trainable_params, + base_expert, + base_model_state_dict, + expert_type, + ) + # load to base model: + config = base_expert.training_config + config.model_modifier = None # load only the base model + config.device_map = "cpu" + config.trainable_param_names = ".*" # allows to train all linear layers + base_model = ExpertModel(**vars(config)) + # load state_dict into model + assert set(base_model.model.state_dict().keys()) == set( + base_expert.expert_weights.keys() + ), "Expert weights must have the same keys" + base_model.model.load_state_dict(base_expert._expert_weights) + return base_model diff --git a/mttl/models/library/merging_methods/model_breadcrumbs.py b/mttl/models/library/merging_methods/model_breadcrumbs.py new file mode 100644 index 000000000..505bc21d6 --- /dev/null +++ b/mttl/models/library/merging_methods/model_breadcrumbs.py @@ -0,0 +1,90 @@ +import copy +import torch +from dataclasses import dataclass +from mttl.models.library.merge_models.base_merge import BaseMerge, BaseMergeConfig +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.library.merge_models.utils import topk_multiple_experts +from mttl.models.utils import logger + + +@dataclass +class ModelBreadcrumbsConfig(BaseMergeConfig): + merging_method: str = "model_breadcrumbs_expert" + alpha: float = ( + 0.4 # scaling factor # source : (a) https://openreview.net/pdf?id=6t0Kwf8-jrj (b) https://arxiv.org/pdf/2306.01708 + ) + beta: float = 0.9 # 90% beta=sparsity, keep-ratio=1-beta + gamma: float = 0.99 # mask out top 1% + + +class ModelBreadcrumbs(BaseMerge): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: ModelBreadcrumbsConfig = None): + super().__init__(config or ModelBreadcrumbsConfig()) + + @torch.no_grad() + def compute_per_task_threhold(self, expert_vectors): + + # take the absolute value: + expert_vectors = torch.stack(expert_vectors, dim=0).abs() + lower_topk = int(expert_vectors.size(1) * self.config.beta) + upper_topk = 
int(expert_vectors.size(1) * self.config.gamma) + + per_exp_lth = topk_multiple_experts(expert_vectors, lower_topk, TH_type="lower") + per_exp_uth = topk_multiple_experts(expert_vectors, upper_topk, TH_type="upper") + + return per_exp_lth, per_exp_uth + + @torch.no_grad() + def merge_expert( + self, + experts, + expert_vectors, + trainable_params, + base_expert, + base_model_state_dict, + expert_type, + ): + # Compute Threshold score, TH + per_exp_lth, per_exp_uth = self.compute_per_task_threhold(expert_vectors) + used, total = 0, 0 + for param_name in base_model_state_dict.keys(): + if param_name in trainable_params: + # stack the expert weights + expert_weights = self.extract_expert_weight( + base_model_state_dict, experts, param_name, expert_type + ) + + # keep weights over the threshold + Lower_TH = per_exp_lth.view(-1, *((1,) * (expert_weights.ndim - 1))) + Upper_TH = per_exp_uth.view(-1, *((1,) * (expert_weights.ndim - 1))) + + keep_mask = torch.logical_and( + expert_weights.abs() > Lower_TH, expert_weights.abs() < Upper_TH + ) + # keep_mask = (expert_weights.abs() > Lower_TH and expert_weights.abs() < Upper_TH) + expert_weights = expert_weights * keep_mask + + # base_weight + sum of the "filtered" task-vector + final_param = base_model_state_dict[ + param_name + ] + self.config.alpha * expert_weights.sum(0) + + used += keep_mask.sum().item() + total += expert_weights.numel() + + base_expert.expert_weights[param_name].data.copy_(final_param) + else: + base_expert.expert_weights[param_name] = copy.deepcopy( + base_model_state_dict[param_name] + ) + logger.info( + "Params used to compute Model-breadcrumb mean: {:.10f}%".format( + 100.0 * used / total + ) + ) + + return base_expert diff --git a/mttl/models/library/merging_methods/slerp.py b/mttl/models/library/merging_methods/slerp.py new file mode 100644 index 000000000..8cd9dba6d --- /dev/null +++ b/mttl/models/library/merging_methods/slerp.py @@ -0,0 +1,185 @@ +import copy +import torch +import numpy as np +from dataclasses import dataclass +from huggingface_hub import hf_hub_download +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.library.library_transforms import ( + LibraryTransform, + LibraryTransformConfig, +) +from mttl.models.expert_model import ExpertModel +from mttl.models.library.expert import Expert + + +# ----------------------------------- +# SLERP and LERP implementation +# ----------------------------------- +def lerp(t, v0, v1, origin_data_type=None): + v2 = (1 - t) * v0 + t * v1 + return torch.from_numpy(v2).to(origin_data_type) + + +# SLERP +def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + """ + Spherical linear interpolation + Args: + t (float/np.ndarray): Float value between 0.0 and 1.0 + v0 (np.ndarray): Starting vector + v1 (np.ndarray): Final vector + DOT_THRESHOLD (float): Threshold for considering the two vectors as + colineal. Not recommended to alter this. 
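        Concretely, the code below computes v2 = sin((1 - t) * theta) / sin(theta) * v0 + sin(t * theta) / sin(theta) * v1, with theta = arccos(dot(v0 / |v0|, v1 / |v1|)); when |dot| exceeds DOT_THRESHOLD the vectors are treated as collinear and plain lerp is used instead.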
+ Returns: + v2 (np.ndarray): Interpolation vector between v0 and v1 + """ + origin_data_type = v0.dtype + v0 = v0.detach().cpu().float().numpy() + v1 = v1.detach().cpu().float().numpy() + + # Copy the vectors to reuse them later + v0_copy = np.copy(v0) + v1_copy = np.copy(v1) + # Normalize the vectors to get the directions and angles + v0 = v0 / np.linalg.norm(v0) + v1 = v1 / np.linalg.norm(v1) + # Dot product with the normalized vectors (can't use np.dot in W) + dot = np.sum(v0 * v1) + # If absolute value of dot product is almost 1, vectors are ~colineal, so use lerp + if np.abs(dot) > DOT_THRESHOLD: + return lerp(t, v0_copy, v1_copy, origin_data_type=origin_data_type) + # Calculate initial angle between v0 and v1 + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + # Angle at timestep t + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + # Finish the slerp algorithm + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0_copy + s1 * v1_copy + + return torch.from_numpy(v2).to(origin_data_type) + + +@dataclass +class SLERPMergeConfig(LibraryTransformConfig): + weights: dict = None + + +class SLERPMerge(LibraryTransform): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: SLERPMergeConfig = None): + super().__init__(config or SLERPMergeConfig()) + + def load_mask(self, expert): + try: + print("trying to load mask from hf") + library_id = expert.training_config.library_id + destination_type, f_name = library_id.split("://") + repo_id = ("/").join(f_name.split("/")[:2]) + filename = f"{expert.expert_info.expert_name}_mask.npz" + f_path = hf_hub_download(repo_id=repo_id, filename=filename) + Mask = np.load(f_path, allow_pickle=True)["arr"].item() + except: + print("trying to load mask from local dir") + m_loc = f"experiment/{expert.training_config.exp_name}/mask.npz" + Mask = np.load(m_loc, allow_pickle=True)["arr"].item() + return Mask + + def convert_idx_2_mask(self, weight_idx, mat_dim): + m = np.zeros(mat_dim) + m[tuple(zip(*weight_idx))] = 1 + return torch.FloatTensor(m) + + def sparse_SLERP(self, model, experts, base_expert): + base_expert_mask = self.load_mask(base_expert) + for layer, _ in base_expert_mask.items(): + mod_layer = ".".join(layer.split(".")[1:]) + base_expert_mask[layer] = self.convert_idx_2_mask( + weight_idx=base_expert_mask[layer], + mat_dim=base_expert.expert_weights[f"{mod_layer}.weight"].shape, + ) + weight_names = [n for n in model.state_dict().keys() if "sparse_layer" in n] + + for expert in experts: + mask = self.load_mask(expert) + # for expert_name, expert in library.items(): + for layer, weight in model.state_dict().items(): + if layer in weight_names: + common_name = ".".join(layer.split(".")[1:-1]) + param_type = layer.split(".")[-1] + + if param_type == "weight": + # get mask-m for layer-l: convert the weight_indx to convert sparse-mask + m = self.convert_idx_2_mask( + weight_idx=mask[f"model.{common_name}"], + mat_dim=expert.expert_weights[ + f"{common_name}.weight" + ].shape, + ) + bm = base_expert_mask[f"model.{common_name}"] + else: + m = 1.0 + bm = 1.0 + base_expert._expert_weights[f"{common_name}.{param_type}"] = slerp( + float(1.0) - 0.5, + v0=base_expert._expert_weights[f"{common_name}.{param_type}"] + * bm, + v1=expert._expert_weights[f"{common_name}.{param_type}"] * m, + ) + if param_type == "weight": + base_expert_mask[f"model.{common_name}"] = torch.logical_or( + m, bm + ).float() + + updated_state_dict = {} + for layer, weight in 
model.state_dict().items(): + if layer in weight_names: + mod_layer = ".".join(layer.split(".")[1:]) + updated_state_dict[layer] = base_expert._expert_weights[mod_layer] + else: + updated_state_dict[layer] = weight + # load state_dict into model + model.load_state_dict(updated_state_dict) + return model + + def FFT_SLERP(self, model, experts, base_expert): + for expert in experts: + for layer, _ in model.state_dict().items(): + common_name = ".".join(layer.split(".")[1:-1]) + param_type = layer.split(".")[-1] + base_expert._expert_weights[f"{common_name}.{param_type}"] = slerp( + float(1.0) - 0.5, + v0=base_expert._expert_weights[f"{common_name}.{param_type}"], + v1=expert._expert_weights[f"{common_name}.{param_type}"], + ) + updated_state_dict = {} + for layer, _ in model.state_dict().items(): + mod_layer = ".".join(layer.split(".")[1:]) + updated_state_dict[layer] = base_expert._expert_weights[mod_layer] + # load state_dict into model + model.load_state_dict(updated_state_dict) + return model + + @torch.no_grad() + def transform(self, library) -> Expert: + if type(library) == str: + library = ExpertLibrary.get_expert_library(library) + + expert_names = list(library.keys()) + experts = [library[name] for name in expert_names] + base_expert = copy.deepcopy(experts[0]) + training_config = copy.deepcopy(base_expert.training_config) + from mttl.models.expert_model import ExpertModel + + model = ExpertModel(**vars(training_config)) + if training_config.model_modifier == None: + # skip the first expert as it's now acting as base_expert + model = self.FFT_SLERP(model, experts[1:], base_expert) + elif training_config.model_modifier == "sparse_mask_adapter": + model = self.sparse_SLERP(model, experts[1:], base_expert) + return model diff --git a/mttl/models/library/merging_methods/sparse_merge.py b/mttl/models/library/merging_methods/sparse_merge.py new file mode 100644 index 000000000..f642d6516 --- /dev/null +++ b/mttl/models/library/merging_methods/sparse_merge.py @@ -0,0 +1,206 @@ +import os +import json +import collections +import torch +from huggingface_hub import hf_hub_download +from copy import deepcopy +import numpy as np +from dataclasses import dataclass +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.library.expert import Expert +from mttl.models.library.library_transforms import ( + LibraryTransform, + LibraryTransformConfig, +) +from mttl.models.expert_model import ExpertModel +from mttl.models.library.merge_models.utils import load_mask, convert_idx_2_mask + + +@dataclass +class SparseWeightLinearMergeConfig(LibraryTransformConfig): + weights: dict = None + + +class SparseWeightLinearMerge(LibraryTransform): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: SparseWeightLinearMergeConfig = None): + super().__init__(config or SparseWeightLinearMergeConfig()) + + def update_module_mask(self, module, expert): + Mask = load_mask(expert) + for m_name, m in dict(module.named_modules()).items(): + if "sparse_layer" in m_name: + keep_mask = convert_idx_2_mask( + weight_idx=Mask[m_name], mat_dim=m.weight_mask.shape + ) + + m.weight_mask = keep_mask.data.clone() + + @torch.no_grad() + def transform(self, library) -> Expert: + if type(library) == str: + library = ExpertLibrary.get_expert_library(library) + # get expert config + an_expert = library[next(iter(library.keys()))] + training_config = deepcopy(an_expert.training_config) + # create a ExpertModel + from mttl.models.expert_model import 
ExpertModel + + print("Trying to load base model from:", training_config.model) + model = ExpertModel(**vars(training_config)) + sparse_layer_names = [ + n for n in model.state_dict().keys() if ("sparse_layer" in n) + ] # allow to add weights and bias + assert sparse_layer_names != [], print("could not find sparse-layer modules") + + # iterate over the library + store_W = collections.defaultdict(dict) + store_m = collections.defaultdict(dict) + expert_weight_hist = collections.defaultdict(dict) # weight stats + + for expert_name, expert in library.items(): + print(f"Merging sparse weight for task: {expert_name}") + Mask = self.load_mask(expert) + expert_weight_hist[expert_name] = collections.defaultdict( + dict + ) # weight stats + # for each layer compute the "average of the overlapped weight" + for l in sparse_layer_names: + common_name = ".".join(l.split(".")[1:-1]) + param_type = l.split(".")[-1] + + if param_type == "weight": + # get mask-m for layer-l: convert the weight_indx to convert sparse-mask + m = self.convert_idx_2_mask( + weight_idx=Mask[f"model.{common_name}"], + mat_dim=expert.expert_weights[f"{common_name}.weight"].shape, + ) + else: + m = 1.0 + # Check if entry exists + if l in store_W: + # store weight + store_W[l] += ( + expert.expert_weights[f"{common_name}.{param_type}"] * m + ) + # store mask + store_m[l] += m + # new entry for expert 1 + else: + store_W[l] = ( + expert.expert_weights[f"{common_name}.{param_type}"] * m + ) + store_m[l] = m + expert_weight_hist[expert_name][common_name] = ( + float( + expert.expert_weights[f"{common_name}.weight"] + .mean() + .data.numpy() + ), + float( + expert.expert_weights[f"{common_name}.weight"] + .std() + .data.numpy() + ), + ) # weight stats + + # we sum the total per-layer weight overlap and devide the count as an alternate to calculate accurate weight average + for l in sparse_layer_names: + param_type = l.split(".")[-1] + if param_type == "weight": + store_m[l][ + store_m[l] == 0 + ] = 1 # assigning 1 to the zero-masked weights positions to avoid numerical error in the next step + store_W[l] /= store_m[l] + else: + store_W[l] /= len(library) + new_state_dict = {} + # add the averaged Ws to the model + for key, value in model.state_dict().items(): + if key in sparse_layer_names: + print(f"added {key}") + new_state_dict[key] = value + store_W[key] + else: + new_state_dict[key] = value + + # load state_dict into model + model.load_state_dict(new_state_dict) + # save weights stats + exp_temp = training_config.library_id.split("/")[-1] + file_loc = f"Weight_Stats/{exp_temp}" + os.makedirs(file_loc, exist_ok=True) + with open(f"{file_loc}/weight_stats.json", "w") as json_file: + json.dump(expert_weight_hist, json_file) + + return model + + @torch.no_grad() + def transform_dummy(self, library, get_expert) -> Expert: + if type(library) == str: + library = ExpertLibrary.get_expert_library(library) + # ============================= + # get expert config + from copy import deepcopy + + an_expert = library[next(iter(library.keys()))] + training_config = deepcopy(an_expert.training_config) + # create a ExpertModel + from mttl.models.expert_model import ExpertModel + + model = ExpertModel(**vars(training_config)) + # filter the weight names + # weight_names = [n for n in model.state_dict().keys() if ('Wqkv.weight' in n or 'out_proj.weight' in n)] + weight_names = [ + n + for n in model.state_dict().keys() + if ("Wqkv.sparse_layer.weight" in n or "out_proj.sparse_layer.weight" in n) + ] + + # iterate over the library + import collections 
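For reference, the overlap-aware averaging in transform() above reduces, at each weight position, to W_merged = sum_t(m_t * W_t) / max(sum_t m_t, 1). A minimal toy sketch of that reduction (hypothetical tensor names, not part of the PR):

import torch

deltas = [torch.randn(4, 4) for _ in range(3)]                  # per-expert sparse weight updates
masks = [(torch.rand(4, 4) > 0.5).float() for _ in range(3)]    # per-expert binary masks
num = sum(d * m for d, m in zip(deltas, masks))                 # accumulate only the masked entries
den = sum(masks).clamp(min=1)                                   # per-position overlap count, avoid divide-by-zero
merged = num / den                                              # positions kept by k experts are averaged over k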
+ + store_W = collections.defaultdict(dict) + store_m = collections.defaultdict(dict) + for expert_name, expert in library.items(): + # TODO: only consider the weights that matches the given `get_expert` input + if expert_name == get_expert: + print(f"Merging sparse weight for task: {expert_name}") + Mask = self.load_mask(expert) + # for each layer compute the "average of the overlapped weight" + for l in weight_names: + common_name = ".".join(l.split(".")[1:-1]) + # get mask-m for layer-l: convert the weight_indx to convert sparse-mask + m = self.convert_idx_2_mask( + weight_idx=Mask[f"model.{common_name}"], + mat_dim=expert.expert_weights[f"{common_name}.weight"].shape, + ) + if l in store_W.keys(): + # store weight + store_W[l] += expert.expert_weights[f"{common_name}.weight"] * m + # store mask + store_m[l] += m + else: + store_W[l] = expert.expert_weights[f"{common_name}.weight"] * m + store_m[l] = m + # we sum the total per-layer weight overlap and devide the count as an alternate to calculate accurate weight average + for l in weight_names: + store_m[l][ + store_m[l] == 0 + ] = 1 # assigning 1 to the zero-masked weights positions to avoid numerical error in the next step + store_W[l] /= store_m[l] + + new_state_dict = {} + # add the averaged Ws to the model + for key, value in model.state_dict().items(): + if key in weight_names: + print(f"added {key}") + new_state_dict[key] = value + store_W[key] + else: + new_state_dict[key] = value + + # load state_dict into model + model.load_state_dict(new_state_dict) + return model diff --git a/mttl/models/library/merging_methods/task_arithmetic.py b/mttl/models/library/merging_methods/task_arithmetic.py new file mode 100644 index 000000000..b3f112184 --- /dev/null +++ b/mttl/models/library/merging_methods/task_arithmetic.py @@ -0,0 +1,67 @@ +import copy +import torch +from dataclasses import dataclass +from mttl.models.library.merge_models.base_merge import BaseMerge, BaseMergeConfig +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.utils import logger + + +@dataclass +class TaskArithmeticConfig(BaseMergeConfig): + merging_method: str = "task_arithmetic_merge_expert" + alpha: float = 0.4 # scaling + + +class TaskArithmetic(BaseMerge): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: TaskArithmeticConfig = None): + super().__init__(config or TaskArithmeticConfig()) + + def merge_expert( + self, + experts, + expert_vectors, + trainable_params, + base_expert, + base_model_state_dict, + expert_type, + ): + used, total = 0, 0 + for param_name in base_model_state_dict.keys(): + if param_name in trainable_params: + # stack the expert weights + expert_weights = self.extract_expert_weight( + base_model_state_dict, experts, param_name, expert_type + ) + + # collect mask + keep_mask = torch.ones_like(expert_weights) + # keep weights over the threshold + expert_weights = expert_weights * keep_mask + # NOTE: sum for Task-Arithmetic + final_param = expert_weights.sum(0) + + # ----------------------------------------------------- + # base_weight + sum of the "filtered" task-vector + # W = W + delta_W + # source : (a) https://openreview.net/pdf?id=6t0Kwf8-jrj (b) https://arxiv.org/pdf/2306.01708 + final_param = ( + base_model_state_dict[param_name] + self.config.alpha * final_param + ) + + used += keep_mask.sum().item() + total += expert_weights.numel() + + base_expert.expert_weights[param_name].data.copy_(final_param) + else: + 
base_expert.expert_weights[param_name] = copy.deepcopy( + base_model_state_dict[param_name] + ) + + logger.info( + "Params used to compute Ties mean: {:.10f}%".format(100.0 * used / total) + ) + return base_expert diff --git a/mttl/models/library/merging_methods/ties.py b/mttl/models/library/merging_methods/ties.py new file mode 100644 index 000000000..84d7ab6a3 --- /dev/null +++ b/mttl/models/library/merging_methods/ties.py @@ -0,0 +1,102 @@ +import copy +import torch +from dataclasses import dataclass +from mttl.models.library.merge_models.base_merge import BaseMerge, BaseMergeConfig +from mttl.models.library.expert_library import ExpertLibrary +from mttl.models.library.merge_models.utils import topk_multiple_experts +from mttl.models.utils import logger + + +@dataclass +class TiesMergeSimpleConfig(BaseMergeConfig): + top_k: float = 0.2 + merging_method: str = "ties_merge_expert" + alpha: float = ( + 0.4 # scaling factor # source : (a) https://openreview.net/pdf?id=6t0Kwf8-jrj (b) https://arxiv.org/pdf/2306.01708 + ) + beta: float = ( + 0.8 # 80% beta=sparsity, keep-ratio=1-beta, fig 3 https://arxiv.org/pdf/2306.01708 suggest to keep top20% params + ) + + +class TiesMergeSimple(BaseMerge): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: TiesMergeSimpleConfig = None): + super().__init__(config or TiesMergeSimpleConfig()) + + assert self.config.top_k > 0.0 and self.config.top_k <= 1.0 + + def compute_per_task_threhold(self, expert_vectors): + # take the absolute value: + expert_vectors = torch.stack(expert_vectors, dim=0).abs() + topk = int(expert_vectors.size(1) * self.config.beta) + per_exp_lth = topk_multiple_experts(expert_vectors, topk, TH_type="lower") + + return per_exp_lth + + @torch.no_grad() + def merge_expert( + self, + experts, + expert_vectors, + trainable_params, + base_expert, + base_model_state_dict, + expert_type, + ): + # ---------------------------------------------------------------------- + # Compute Threshold score, TH + per_exp_lth = self.compute_per_task_threhold(expert_vectors) + + used, total = 0, 0 + for param_name in base_model_state_dict.keys(): + if param_name in trainable_params: + # stack the expert weights + expert_weights = self.extract_expert_weight( + base_model_state_dict, experts, param_name, expert_type + ) + + # keep weights over the threshold + TH = per_exp_lth.view( + -1, *((1,) * (expert_weights.ndim - 1)) + ) # reshape + keep_mask = expert_weights.abs() > TH + expert_weights = expert_weights * keep_mask + + # sign majority vote + # sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign() + sign_per_dim = expert_weights.sum(0, keepdim=True).sign() + # resolve zero signs: https://github.com/rezazzr/breadcrumbs/blob/main/src/task_vectors.py#L334 + majority_sign = torch.sign(sign_per_dim.sum()) + sign_per_dim[sign_per_dim == 0] = majority_sign + + # keep only weights whose sign agree with the majority + use_for_avg = expert_weights.sign() == sign_per_dim + + deno = (use_for_avg != 0).sum(0).clamp(min=1.0) + sum_param = (expert_weights * use_for_avg).sum(0) + final_param = sum_param / deno + used += (use_for_avg & (sign_per_dim != 0.0)).sum().item() + + # ----------------------------------------------------- + # base_weight + sum of the "filtered" task-vector + # W = W + delta_W + # source : (a) https://openreview.net/pdf?id=6t0Kwf8-jrj (b) https://arxiv.org/pdf/2306.01708 + final_param = ( + base_model_state_dict[param_name] + self.config.alpha * final_param + ) + used += 
keep_mask.sum().item() + total += expert_weights.numel() + base_expert.expert_weights[param_name].data.copy_(final_param) + else: + base_expert.expert_weights[param_name] = copy.deepcopy( + base_model_state_dict[param_name] + ) + + logger.info( + "Params used to compute Ties mean: {:.10f}%".format(100.0 * used / total) + ) + return base_expert diff --git a/mttl/models/library/merging_methods/uniform_merge.py b/mttl/models/library/merging_methods/uniform_merge.py new file mode 100644 index 000000000..51573bba9 --- /dev/null +++ b/mttl/models/library/merging_methods/uniform_merge.py @@ -0,0 +1,70 @@ +from __future__ import annotations +import copy +from dataclasses import dataclass +import torch +from mttl.models.library.merging_methods.base_merge import ( + BaseMerge, + BaseMergeConfig, +) +from mttl.logging import logger + + +@dataclass +class UniformMergeConfig(BaseMergeConfig): + merging_method: str = "uniform_merge_expert" + alpha: float = 1.0 + + +class UniformMerge(BaseMerge): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: UniformMergeConfig = None): + super().__init__(config or UniformMergeConfig()) + + def merge_expert( + self, + experts, + expert_vectors, + trainable_params, + base_expert, + base_model_state_dict, + expert_type, + ): + used, total = 0, 0 + for param_name in base_model_state_dict.keys(): + if param_name in trainable_params: + # stack the expert weights + expert_weights = self.extract_expert_weight( + base_model_state_dict, experts, param_name, expert_type + ) + + # collect mask + keep_mask = torch.ones_like(expert_weights) + # keep weights over the threshold + expert_weights = expert_weights * keep_mask + # uniform + final_param = expert_weights.mean(0) + + # ----------------------------------------------------- + # base_weight + sum of the "filtered" task-vector + # W = W + delta_W + # source : (a) https://openreview.net/pdf?id=6t0Kwf8-jrj (b) https://arxiv.org/pdf/2306.01708 + final_param = ( + base_model_state_dict[param_name] + self.config.alpha * final_param + ) + + used += keep_mask.sum().item() + total += expert_weights.numel() + + base_expert.expert_weights[param_name].data.copy_(final_param) + else: + base_expert.expert_weights[param_name] = copy.deepcopy( + base_model_state_dict[param_name] + ) + + logger.info( + "Params used to compute Ties mean: {:.10f}%".format(100.0 * used / total) + ) + return base_expert diff --git a/mttl/models/library/merging_methods/uniform_sparse.py b/mttl/models/library/merging_methods/uniform_sparse.py new file mode 100644 index 000000000..05f76f96b --- /dev/null +++ b/mttl/models/library/merging_methods/uniform_sparse.py @@ -0,0 +1,95 @@ +import torch +import numpy as np +from dataclasses import dataclass +from huggingface_hub import hf_hub_download +from mttl.models.library.merging_methods.base_merge import ( + BaseMerge, + BaseMergeConfig, +) +from mttl.models.expert_model import ExpertModel + + +@dataclass +class UniformSparsConfig(BaseMergeConfig): + top_k: float = 0.2 + merging_method: str = "uniform_sparse_merge_expert" + alpha: float = ( + 1 # scaling factor # source : (a) https://openreview.net/pdf?id=6t0Kwf8-jrj (b) https://arxiv.org/pdf/2306.01708 + ) + + +class UniformSparse(BaseMerge): + """ + Computes a uniform weight mixture across experts of a given library + """ + + def __init__(self, config: UniformSparsConfig = None): + super().__init__(config or UniformSparsConfig()) + assert self.config.top_k > 0.0 and self.config.top_k <= 1.0 + + 
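For reference, the trim / elect-sign / disjoint-mean steps implemented by TiesMergeSimple above can be sketched on toy tensors as follows (hypothetical names; a simplified per-matrix view rather than the per-task-vector thresholding used in the PR):

import torch

task_vectors = torch.randn(3, 10)                                # n_tasks x D task deltas
th = task_vectors.abs().quantile(0.8, dim=1, keepdim=True)       # per-task magnitude threshold (keep top 20%)
trimmed = task_vectors * (task_vectors.abs() > th)               # trim: zero out small-magnitude entries
sign = trimmed.sum(0, keepdim=True).sign()                       # elect: sign of the summed deltas per dimension
sign[sign == 0] = torch.sign(sign.sum())                         # resolve zero signs with the global majority sign
agree = trimmed.sign() == sign                                   # entries whose sign matches the elected sign
merged = (trimmed * agree).sum(0) / agree.sum(0).clamp(min=1)    # disjoint mean over agreeing experts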
@torch.no_grad() + def merge_expert( + self, experts, trainable_params, base_expert, base_model_state_dict, expert_type + ): + param_dict = {} + for param_name, base_w in base_expert.model.state_dict().items(): + if param_name in trainable_params: + # ignore bias + if "weight" in param_name: + # stack the expert weights + expert_weights = self.extract_expert_weight( + base_model_state_dict, experts, param_name, expert_type + ) + + sum_param = expert_weights.sum(0) + mask_overlaps = torch.stack( + [(e != 0).float() for e in expert_weights], dim=0 + ).sum(0) + + mask_overlaps[mask_overlaps == 0] = 1 + final_param = sum_param / mask_overlaps + + layer_name = ".".join(param_name.split(".")[:-2]) + updated_param_name = f"{layer_name}.weight" + param_dict[updated_param_name] = final_param + + return param_dict + + @torch.no_grad() + def transform(self, library): + experts, expert_type, base_expert, base_model, trainable_params = ( + self.pre_configure(library) + ) + base_model_state_dict = dict(base_model.state_dict()) + base_expert.training_config.device_map = "cpu" + base_expert = ExpertModel(**vars(base_expert.training_config)) + trainable_params = [ + n for n in base_expert.model.state_dict().keys() if ("sparse_layer" in n) + ] # allow to add weights and bias + assert trainable_params != [], print("could not find sparse-layer modules") + base_model_state_dict = base_expert.model.state_dict() + + param_dict = self.merge_expert( + experts, trainable_params, base_expert, base_model_state_dict, expert_type + ) + + config = base_expert.training_config + config.model_modifier = None # load only the base model + config.device_map = "cpu" + config.trainable_param_names = ".*" # allows to train all linear layers + base_model = ExpertModel(**vars(config)) + + for param_name, base_w in base_model.model.state_dict().items(): + if param_name in param_dict: + param_dict[param_name] = base_w + param_dict[param_name].to( + base_w.dtype + ) + else: + param_dict[param_name] = base_w + + assert set(base_model.model.state_dict().keys()) == set( + param_dict.keys() + ), "Expert weights must have the same keys" + base_model.model.load_state_dict(param_dict) + + return base_model diff --git a/mttl/models/library/merging_methods/utils.py b/mttl/models/library/merging_methods/utils.py new file mode 100644 index 000000000..47b55e801 --- /dev/null +++ b/mttl/models/library/merging_methods/utils.py @@ -0,0 +1,41 @@ +import torch +import numpy as np +from huggingface_hub import hf_hub_download + + +def topk_multiple_experts(expert_vectors, topk, TH_type=None): + assert TH_type != None + n_tasks = expert_vectors.shape[0] + values = [] + for t in range(n_tasks): + print("topk expert", t) + v, _ = torch.topk(expert_vectors[t, :], topk) + if TH_type == "lower": + values.append(v[-1]) + elif TH_type == "upper": + values.append(v[0]) + del v + values = torch.stack(values, dim=0) # Shape will be (n_tasks,) + return values + + +def load_mask(expert): + try: + print("trying to load mask from hf") + library_id = expert.training_config.library_id + destination_type, f_name = library_id.split("://") + repo_id = ("/").join(f_name.split("/")[:2]) + filename = f"{expert.expert_info.expert_name}_mask.npz" + f_path = hf_hub_download(repo_id=repo_id, filename=filename) + Mask = np.load(f_path, allow_pickle=True)["arr"].item() + except: + print("trying to load mask from local dir") + m_loc = f"experiment/{expert.training_config.exp_name}/mask.npz" + Mask = np.load(m_loc, allow_pickle=True)["arr"].item() + return Mask + + +def 
convert_idx_2_mask(weight_idx, mat_dim): + m = np.zeros(mat_dim) + m[tuple(zip(*weight_idx))] = 1 + return torch.FloatTensor(m) diff --git a/mttl/models/lightning/callbacks.py b/mttl/models/lightning/callbacks.py index fec54176e..b3aec2e13 100644 --- a/mttl/models/lightning/callbacks.py +++ b/mttl/models/lightning/callbacks.py @@ -20,37 +20,35 @@ from mttl.evaluators.base import EvaluatorRunner, setup_evaluators from mttl.evaluators.evaluators import Evaluator from mttl.logging import logger -from mttl.models.modifiers.sparse_mask import make_sparse_model_during_training +from mttl.models.modifiers.sparse_mask import make_sparse_model_during_training, save_mask from mttl.models.utils import transfer_batch_to_device DEBUG = False class UpdateSparseMask(pl.Callback): - def __init__( - self, - update_interval=5, - dm=None, - task_name=None, - parameter_selection_procedure="per_layer", - ): + def __init__(self, update_interval=5, + num_train_steps=None, + dm=None, + save_mask_dir=None, + task_name=None, + parameter_selection_procedure='per_layer'): super().__init__() self.update_interval = update_interval self.update_counter = 0 self.dm = dm + self.save_mask_dir = save_mask_dir self.task_name = task_name - assert parameter_selection_procedure in [ - "model", - "per_layer", - ], "choose the right `parameter_selection_procedure`" - self.parameter_selection_procedure = parameter_selection_procedure - - def update_mask(self, pl_module, batch): - make_sparse_model_during_training( - pl_module, - batch, - parameter_selection_procedure=self.parameter_selection_procedure, - ) + self.num_train_steps=num_train_steps + assert parameter_selection_procedure in ['model','per_layer','layer_and_param', 'weight_magnitude', 'gradient_magnitude','grow_and_drop'], "choose the right `parameter_selection_procedure`" + self.parameter_selection_procedure=parameter_selection_procedure + + def update_mask(self, pl_module, batch, num_train_steps, current_steps): + make_sparse_model_during_training(pl_module, + batch, + num_train_steps, + current_steps, + parameter_selection_procedure=self.parameter_selection_procedure) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): """ @@ -60,9 +58,14 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.update_counter += 1 if self.update_counter % self.update_interval == 0: # Update mask - self.update_mask(pl_module, batch) - self.update_counter = 0 # Reset counter for next interval + self.update_mask(pl_module, batch, self.num_train_steps, self.update_counter) + def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + save mask end of training + """ + f_name = f'{self.save_mask_dir}/{self.task_name}_mask' + save_mask(pl_module, f_name) class LiveCheckpointCallback(pl.Callback): """A better model checkpoint callback, that works in synchrony with LiveLogMixin.""" diff --git a/mttl/models/modifiers/sparse_mask.py b/mttl/models/modifiers/sparse_mask.py index 402b36271..144855760 100644 --- a/mttl/models/modifiers/sparse_mask.py +++ b/mttl/models/modifiers/sparse_mask.py @@ -12,9 +12,74 @@ import numpy as np import torch from torch import nn - +from huggingface_hub import hf_hub_download from mttl.models.modifiers.base import Modifier, ModifierConfig from mttl.utils import logger +from collections import OrderedDict + + +def hook_factory(keep_mask): + """ + The hook function can't be defined directly here because of Python's + late binding which would result in all hooks getting the very last + mask! 
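    (For example, hooks created directly in a loop as `lambda grads: grads * keep_mask` would all close over the same `keep_mask` name and evaluate it only when the hook fires, so every layer would end up masked with the last mask assigned.)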
Getting it through another function forces early binding. + """ + + def hook(grads): + return grads * keep_mask + + return hook + + +def load_mask(f_name): + destination_type, _ = f_name.split("://") + if destination_type == "hf": + destination_type, f_name = f_name.split("://") + repo_id = ("/").join(f_name.split("/")[:2]) + task_name = f_name.split("/")[-1] + f_path = hf_hub_download(repo_id=repo_id, filename=task_name) + mask_dict = np.load(f_path, allow_pickle=True)["arr"] + elif destination_type == "local": + mask_dict = np.load(f"{f_name}.npz", allow_pickle=True)["arr"] + return mask_dict + + +def save_mask(module, f_name): + """ + to load the saved mask use the `load_mask` function + """ + mask_dict = {} + import numpy as np + + for m_name, m in dict(module.named_modules()).items(): + if "sparse_layer" in m_name: + mask_dict[m_name] = torch.nonzero(m.weight_mask.data).cpu().numpy() + destination_type = f_name.split("://")[0] + # save in local dir + if destination_type == "local": + destination_type, f_name = f_name.split("://") + np.savez_compressed(f"./{f_name}.npz", arr=mask_dict) + + # upload to hf + elif destination_type == "hf": + from huggingface_hub.hf_api import upload_file as hf_upload_file + + destination_type, f_name = f_name.split("://") + repo_id = ("/").join(f_name.split("/")[:2]) + task_name = f_name.split("/")[-1] + path_in_repo = f"{task_name}.npz" + os.makedirs("./temp/test_library/", exist_ok=True) + local_file_path = f"./temp/test_library/{path_in_repo}" + np.savez_compressed(local_file_path, arr=mask_dict) + + hf_upload_file( + path_or_fileobj=local_file_path, # path saved in local machine + path_in_repo=path_in_repo, # path with in repo + repo_id=repo_id, + ) # + # exact local dir is provided + else: + np.savez_compressed(f"./{f_name}.npz", arr=mask_dict) class MatrixBlockIndexer: @@ -116,34 +181,156 @@ def get_regular_sparse_mask(m): return keep_masks +def get_gradient_magnitude_based_sparse_mask(m): + """ + parameter-wise sparse calculation based on gradient-magnitude + """ + num_params_to_keep = int(torch.numel(m.sparse_layer.weight) * m.keep_ratio) + threshold, _ = torch.topk( + m.sparse_layer.weight.grad.abs().flatten(), num_params_to_keep, sorted=True + ) + accepted_score = threshold[-1] + keep_masks = (m.sparse_layer.weight.grad.abs() >= accepted_score).float() + + return keep_masks + + +def get_weight_magnitude_based_sparse_mask(m): + """ + parameter-wise sparse calculation based on weight-magnitude + """ + num_params_to_keep = int(torch.numel(m.sparse_layer.weight) * m.keep_ratio) + threshold, _ = torch.topk( + m.sparse_layer.weight.abs().flatten(), num_params_to_keep, sorted=True + ) + accepted_score = threshold[-1] + keep_masks = (m.sparse_layer.weight.abs() >= accepted_score).float() + + return keep_masks + + +def get_grow_drop_based_sparse_mask(m, num_train_steps, current_steps): + """ + parameter-wise sparse calculation based on grow and prune + """ + gamma = 0.2 # peak_replacement_rate + grow_prune_frac = gamma * (1 - current_steps / num_train_steps) + + def random_mask(): + keep_masks = torch.zeros_like(m.sparse_layer.weight) + num_params_to_keep = int(torch.numel(m.sparse_layer.weight) * m.keep_ratio) + ones_idx = torch.randperm(torch.numel(m.sparse_layer.weight))[ + :num_params_to_keep + ] + keep_masks.flatten()[ones_idx] = 1 + return keep_masks + + def grad_magnitude_mask(): + num_params_to_keep = int(torch.numel(m.sparse_layer.weight) * m.keep_ratio) + threshold, _ = torch.topk( + (m.sparse_layer.weight.grad).abs().flatten(), + 
num_params_to_keep, + sorted=True, + ) + accepted_score = threshold[-1] + keep_masks = (m.sparse_layer.weight.grad.abs() >= accepted_score).float() + return keep_masks + + if current_steps == 0: + keep_masks = random_mask() + + else: + # grow using largest magnitude + num_params_to_keep_or_drop = int( + torch.numel(m.sparse_layer.weight) * m.keep_ratio * grow_prune_frac + ) + mask_candidate = (1 - m.sparse_layer.weight_mask_old).to( + m.sparse_layer.weight.device + ) + threshold, _ = torch.topk( + (m.sparse_layer.weight.grad * mask_candidate).abs().flatten(), + num_params_to_keep_or_drop, + sorted=True, + ) + accepted_score = threshold[-1] + if accepted_score == 0: + keep_masks = random_mask() + else: + grow_masks = ( + (m.sparse_layer.weight.grad.abs() * mask_candidate) >= accepted_score + ).float() + # prune find the weights changed the least: lowest weight magnitude + threshold, _ = torch.topk( + (m.sparse_layer.weight * m.sparse_layer.weight_mask).abs().flatten(), + num_params_to_keep_or_drop, + largest=False, + sorted=True, + ) + accepted_score = threshold[-1] + prune_masks = ( + (m.sparse_layer.weight * m.sparse_layer.weight_mask).abs() + <= accepted_score + ).float() + # add 1s from grow mask to existing keep mask + keep_masks = m.sparse_layer.weight_mask + grow_masks + # remove 1s from prune mask, i.e., set to 0 where prune mask is 1) + keep_masks = m.sparse_layer.weight_mask - prune_masks + m.sparse_layer.weight.data = ( + m.sparse_layer.weight.data * prune_masks + ) # dropped weights are set to 0 + assert keep_masks.max() == 1 and keep_masks.min() == 0 + + return keep_masks + + def make_sparse_model_during_training( - module, batch, parameter_selection_procedure="per_layer" + module, + batch, + num_train_steps, + current_steps, + print_statement=False, + parameter_selection_procedure="per_layer", ): from mttl.models.modifiers.sparse_mask import SparseMaskAdapter as SparseMaskModule + assert parameter_selection_procedure in [ + "model", + "per_layer", + "layer_and_param", + "weight_magnitude", + "gradient_magnitude", + "grow_and_drop", + ], "choose the right `parameter_selection_procedure`" # (1) preprocess the sparse-layers for m in module.modules(): if isinstance(m, SparseMaskModule): - m.preprocess_for_mask_update() + if parameter_selection_procedure in ["grow_and_drop"]: + m.preprocess_for_grow_and_drop_mask_update() + elif parameter_selection_procedure in [ + "weight_magnitude", + "gradient_magnitude", + ]: + m.preprocess_for_weight_and_grad_magnitude() + else: + m.preprocess_for_mask_update() # (2) collect grads from mttl.models.utils import transfer_batch_to_device - loss = module.forward(**batch).loss loss.backward() - assert parameter_selection_procedure in [ - "model", - "per_layer", - ], "choose the right `parameter_selection_procedure`" - # (3) compute mask # (a) layer-wise if parameter_selection_procedure == "per_layer": - for m in module.modules(): + for idx, m in enumerate(module.modules()): if isinstance(m, SparseMaskModule): if m.sparse_cat == "block_sparse": keep_masks = get_block_mask(m) + # check: sample noise-block-idx + # block_noise_idx = sample(m.nextK_idx) + # noise_masks_idx = [m.BlockwiseConvolution.get_block_indices(i) for i in block_noise_idx] + # noise_masks_idx=torch.stack(noise_masks_idx).flatten().to(m.layer.weight.device) + # print('check') elif m.sparse_cat == "regular_sparse": keep_masks = get_regular_sparse_mask(m) @@ -151,10 +338,13 @@ def make_sparse_model_during_training( # (a) reverse the require-grad: Turn on for `weight` and turn-off for 
`weight_mask` # (b) convert `module` back to `cpu` m.revert_weight_grad_and_update_mask(keep_masks) + # check + # print(f'Layer {idx} sparsity', (keep_masks.sum()/keep_masks.numel()).data.cpu().numpy()) + # if save_mask_indx: mask_indx.append(torch.nonzero(keep_masks).data.cpu().numpy()) # nonzero finds the ind # (b) based on whole-net - # b.1 compute score elif parameter_selection_procedure == "model": + # b.1 compute score num_params_to_keep = 0 grads = [] for m in module.modules(): @@ -166,7 +356,36 @@ def make_sparse_model_during_training( torch.numel(m.sparse_layer.weight_mask) * m.keep_ratio ) grads.append(m.sparse_layer.weight_mask.grad.flatten().cpu()) + threshold, _ = torch.topk( + torch.stack(grads).flatten(), num_params_to_keep, sorted=True + ) + accepted_score = threshold[-1] + # b.2 mask + for m in module.modules(): + if isinstance(m, SparseMaskModule): + keep_masks = (m.sparse_layer.weight_mask.grad >= accepted_score).float() + if print_statement: + print( + "sparsity", + (keep_masks.sum() / m.sparse_layer.weight_mask.numel()) * 100, + "expected", + m.keep_ratio * 100, + ) + m.revert_weight_grad_and_update_mask(keep_masks) + # drop layer and params + elif parameter_selection_procedure == "layer_and_param": + num_params_to_keep = 0 + grads = [] + for m in module.modules(): + if isinstance(m, SparseMaskModule): + assert ( + m.sparse_cat == "regular_sparse" + ), "parameter_selection_procedure over `model` is not implemented for `block_sparse`" + num_params_to_keep += int( + torch.numel(m.sparse_layer.weight_mask) * m.keep_ratio + ) + grads.append(m.sparse_layer.weight_mask.grad.flatten().cpu()) threshold, _ = torch.topk( torch.stack(grads).flatten(), num_params_to_keep, sorted=True ) @@ -175,11 +394,48 @@ def make_sparse_model_during_training( for m in module.modules(): if isinstance(m, SparseMaskModule): keep_masks = (m.sparse_layer.weight_mask.grad >= accepted_score).float() + if ( + keep_masks.sum() / m.sparse_layer.weight_mask.numel() + ) >= m.keep_ratio: + keep_masks = get_regular_sparse_mask(m) + m.revert_weight_grad_and_update_mask(keep_masks) + else: + keep_masks = torch.zeros_like(m.sparse_layer.weight) + m.revert_weight_grad_and_update_mask(keep_masks) + # based on gradient-magnitude + elif parameter_selection_procedure == "gradient_magnitude": + for m in module.modules(): + if isinstance(m, SparseMaskModule): + assert m.sparse_cat == "regular_sparse", print( + "block-sparse is not implemented fro gradient magnitude sparse" + ) + keep_masks = get_gradient_magnitude_based_sparse_mask(m) + m.revert_weight_grad_and_update_mask(keep_masks) + # based on weight-magnitude + elif parameter_selection_procedure == "weight_magnitude": + for m in module.modules(): + if isinstance(m, SparseMaskModule): + assert m.sparse_cat == "regular_sparse", print( + "block-sparse is not implemented fro gradient magnitude sparse" + ) + keep_masks = get_weight_magnitude_based_sparse_mask(m) + m.revert_weight_grad_and_update_mask(keep_masks) + # based on grow and drop + elif parameter_selection_procedure == "grow_and_drop": + for m in module.modules(): + if isinstance(m, SparseMaskModule): + assert m.sparse_cat == "regular_sparse", print( + "block-sparse is not implemented fro gradient magnitude sparse" + ) + keep_masks = get_grow_drop_based_sparse_mask( + m, num_train_steps, current_steps + ) m.revert_weight_grad_and_update_mask(keep_masks) def mod_forward(self, x): - return torch.nn.functional.linear(x, self.weight * self.weight_mask, self.bias) + # return torch.nn.functional.linear(x, 
self.weight * self.weight_mask, self.bias) + return torch.nn.functional.linear(x, self.weight * self.weight_mask, None) @dataclass @@ -213,13 +469,13 @@ def __init__( ], "Choose `sparse_cat` from ['block_sparse','regular_sparse'] " # weight initialization - self.sparse_layer = nn.Linear(input_dim, output_dim).to( + self.sparse_layer = nn.Linear(input_dim, output_dim, bias=False).to( device=layer.weight.device ) self.sparse_layer.weight = nn.Parameter( torch.zeros(self.sparse_layer.weight.shape) ) - self.sparse_layer.bias = nn.Parameter(torch.zeros(self.sparse_layer.bias.shape)) + # self.sparse_layer.bias = nn.Parameter(torch.zeros(self.sparse_layer.bias.shape)) if self.sparse_cat == "block_sparse": self.BLOCK_SIZE = config.BLOCK_SIZE @@ -241,6 +497,15 @@ def __init__( def patch_forward(self): self.sparse_layer.forward = types.MethodType(mod_forward, self.sparse_layer) + @torch.no_grad() + def convert_sparse_weight_to_1D(self): + assert len(self.sparse_layer.weight.shape) == 2, print( + "sparse_layer.weight is already converted to 1D" + ) + self.sparse_layer.weight = nn.Parameter( + self.sparse_layer.weight.flatten()[self.keep_mask_idx].data + ).to(self.layer.weight.device) + def data_preprocess(self, x): sparse_model_dtype = self.sparse_layer.weight.dtype return x.to(sparse_model_dtype) @@ -277,6 +542,45 @@ def preprocess_for_mask_update(self): ) # compute gradient for weight_mask self.sparse_layer.weight_mask.requires_grad = True + # remove backward-hook + self.sparse_layer.weight._backward_hooks = OrderedDict() + + """ + prepare the mask for weight and grad magnitude + """ + + def preprocess_for_weight_and_grad_magnitude(self): + # Turn off the gradient for weight + self.sparse_layer.weight.requires_grad = True + # init the mask + self.sparse_layer.weight_mask = nn.Parameter( + torch.ones( + self.sparse_layer.weight_mask.shape, device=self.layer.weight.device + ) + ) + # compute gradient for weight_mask + self.sparse_layer.weight_mask.requires_grad = False + # remove backward-hook + self.sparse_layer.weight._backward_hooks = OrderedDict() + + """ + prepare the mask for grow and drop + """ + + def preprocess_for_grow_and_drop_mask_update(self): + # Turn off the gradient for weight + assert self.sparse_layer.weight.requires_grad == True + self.sparse_layer.weight_mask_old = self.sparse_layer.weight_mask.data.to("cpu") + # init the mask + self.sparse_layer.weight_mask = nn.Parameter( + torch.ones( + self.sparse_layer.weight_mask.shape, device=self.layer.weight.device + ) + ) + # compute gradient for weight_mask + self.sparse_layer.weight_mask.requires_grad = False + # remove backward-hook + self.sparse_layer.weight._backward_hooks = OrderedDict() """ after configuring the mask, it's important to update and allow gradient to pass through the weight for training @@ -306,6 +610,12 @@ def revert_weight_grad_and_update_mask(self, mask=None): self.sparse_layer.weight_mask = nn.Parameter( mask, requires_grad=False ).to(self.sparse_layer.weight.device) + + # apply backward-hook + self.sparse_layer.weight.register_hook( + hook_factory(mask.to(self.sparse_layer.weight.device)) + ) + else: print("Mask is not provided, initializing to default mask value=1") del self.sparse_layer.weight_mask diff --git a/projects/sparse_finetuning/README.md b/projects/sparse_finetuning/README.md index f72a2e327..6c9bfb289 100644 --- a/projects/sparse_finetuning/README.md +++ b/projects/sparse_finetuning/README.md @@ -21,7 +21,6 @@ projects/sparse_finetuning/scripts/local/run.sh "full" "local://temp/test_librar --- * 
diff --git a/projects/sparse_finetuning/README.md b/projects/sparse_finetuning/README.md
index f72a2e327..6c9bfb289 100644
--- a/projects/sparse_finetuning/README.md
+++ b/projects/sparse_finetuning/README.md
@@ -21,7 +21,6 @@ projects/sparse_finetuning/scripts/local/run.sh "full" "local://temp/test_librar
 ---
 * Notes:
-    * conda create -n mttl python=3.9
+    * conda create -n mttl python=3.11
     * conda activate mttl
-    * pip install -r local_requirements.txt
-    * use `transformers==4.42.0` to properly use local `microsoft/phi-2` model
+    * pip install -r requirements.txt
diff --git a/projects/sparse_finetuning/eval_library.py b/projects/sparse_finetuning/eval_library.py
new file mode 100644
index 000000000..5a2d490e2
--- /dev/null
+++ b/projects/sparse_finetuning/eval_library.py
@@ -0,0 +1,348 @@
+import os
+import sys
+import torch
+import copy
+import wandb
+import numpy as np
+from copy import deepcopy
+import torch.nn.functional as F
+from pytorch_lightning import seed_everything
+import json
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
+
+from mttl.models.library.expert_library import ExpertLibrary
+from mttl.models.containers.selectors.base import Selector
+from mttl.models.modifiers.lora import LoRAConfig
+
+from mttl.utils import logger, remote_login, setup_logging
+from mttl.models.expert_model import MultiExpertModel, ExpertModel
+from mttl.models.expert_config import ExpertConfig
+
+from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
+from mttl.models.lightning.callbacks import LossCallback
+from mttl.datamodule.base import get_datamodule
+from mttl.evaluators.rouge_evaluator import RougeEvaluator
+from mttl.logging import TableLogger
+
+
+def eval_in_distribution(module, args: ExpertConfig, tasks: list):
+    args.include_task_source = "*"
+    transfer_table = TableLogger()
+    print(f"eval metric: {args.eval_metric}")
+
+    for i, task in enumerate(tasks):
+        args.finetune_task_name = task
+        args.predict_batch_size = 16
+        if args.eval_metric in ["val_loss", "loss"]:
+            dm = get_datamodule(args)
+            evaluator = LossCallback(
+                dm.val_dataloader(), output_dir=args.output_dir, name=task + "_val"
+            )
+            metric = evaluator.test(pl_module=module).item()
+
+        elif args.eval_metric == "test_loss":
+            dm = get_datamodule(args)
+            evaluator = LossCallback(
+                dm.test_dataloader(), output_dir=args.output_dir, name=task + "_test"
+            )
+            metric = evaluator.test(pl_module=module).item()
+        elif args.eval_metric == "val_rougeL":
+            dm = get_datamodule(args, for_generation=True)
+            evaluator = RougeEvaluator(
+                datamodule=dm,
+            )
+            metric = evaluator.evaluate(
+                module,
+                split="val",
+                verbose=False,
+            )
+        elif args.eval_metric == "rougeL":
+            dm = get_datamodule(args, for_generation=True)
+            evaluator = RougeEvaluator(
+                datamodule=dm,
+            )
+            metric = evaluator.evaluate(
+                module,
+                split="test",
+                verbose=False,
+            )
+        else:
+            raise ValueError(f"Unknown eval metric {args.eval_metric}")
+        if wandb.run is not None:
+            wandb.log({f"test/{args.eval_metric}_{task}": metric})
+        transfer_table.log({"task": task, args.eval_metric: metric})
+
+    if wandb.run is not None:
+        wandb.log(
+            {f"mean_{args.eval_metric}": transfer_table.df[args.eval_metric].mean()}
+        )
+
+    transfer_table.log(
+        {
+            "task": "mean",
+            args.eval_metric: transfer_table.df[args.eval_metric].mean(),
+        }
+    )
+    transfer_table.log_final_table()
+
+
+def eval_in_distribution_sparse_model(
+    module, library, expert, args: ExpertConfig, tasks: list
+):
+    args.include_task_source = "*"
+    transfer_table = TableLogger()
+
+    for i, task in enumerate(tasks):
+        # update the mask corresponding to the task
+        expert.update_module_mask(module, library[task])
+
+        args.finetune_task_name = task
+        args.predict_batch_size = 16
+        if args.eval_metric in ["val_loss", "loss"]:
+            dm = get_datamodule(args)
+            evaluator = LossCallback(
+                dm.val_dataloader(), output_dir=args.output_dir, name=task + "_val"
+            )
+            metric = evaluator.test(pl_module=module).item()
+
+        elif args.eval_metric == "test_loss":
+            dm = get_datamodule(args)
+            evaluator = LossCallback(
+                dm.test_dataloader(), output_dir=args.output_dir, name=task + "_test"
+            )
+            metric = evaluator.test(pl_module=module).item()
+        elif args.eval_metric == "val_rougeL":
+            dm = get_datamodule(args, for_generation=True)
+            evaluator = RougeEvaluator(
+                datamodule=dm,
+            )
+            metric = evaluator.evaluate(
+                module,
+                split="val",
+                verbose=False,
+            )
+        elif args.eval_metric == "rougeL":
+            dm = get_datamodule(args, for_generation=True)
+            evaluator = RougeEvaluator(
+                datamodule=dm,
+            )
+            metric = evaluator.evaluate(
+                module,
+                split="test",
+                verbose=False,
+            )
+        else:
+            raise ValueError(f"Unknown eval metric {args.eval_metric}")
+        if wandb.run is not None:
+            wandb.log({f"test/{args.eval_metric}_{task}": metric})
+        transfer_table.log({"task": task, args.eval_metric: metric})
+
+    if wandb.run is not None:
+        wandb.log(
+            {f"mean_{args.eval_metric}": transfer_table.df[args.eval_metric].mean()}
+        )
+
+    transfer_table.log(
+        {
+            "task": "mean",
+            args.eval_metric: transfer_table.df[args.eval_metric].mean(),
+        }
+    )
+    transfer_table.log_final_table()
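Editor's note: the "oracle routing" evaluation above hinges on `expert.update_module_mask(module, library[task])`, which swaps in the evaluated task's own sparsity pattern before scoring it (used by the `uniform_sparse_weight_oracle_routing` path further down). That helper is not shown in this diff; the snippet below is only a schematic of the idea, and every name in it is illustrative rather than the repository's implementation.

```python
import torch
import torch.nn as nn

def apply_task_masks(model: nn.Module, task_masks: dict[str, torch.Tensor]) -> None:
    """Schematic per-task 'oracle routing': overwrite each layer's binary keep-mask.

    `task_masks` maps a module name to the 0/1 mask saved for one task; the merged
    weight values are left untouched, only which entries are active changes.
    """
    for name, module in model.named_modules():
        if name in task_masks and hasattr(module, "weight_mask"):
            mask = task_masks[name].to(module.weight_mask.device)
            module.weight_mask.data.copy_(mask)
```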
+
+
+def run_eval(args: ExpertConfig):
+    seed_everything(args.seed, workers=True)
+
+    # get directory of the current file
+    setup_logging(args.output_dir)
+
+    logger.info("Args: {}".format(args.to_json()))
+
+    remote_login(args.remote_token)
+
+    # default
+    selection = None  # for debugging: selection = ['duorc_ParaphraseRC_extract_answer', 'wiki_qa_Topic_Prediction_Question_and_Answer_Pair']
+    if selection is None:
+        exclude_phi_tasks = [
+            "hellaswag_1_1_0",
+            "ai2_arc_ARC_Challenge_1_0_0",
+            "ai2_arc_ARC_Easy_1_0_0",
+            "piqa_1_0_0",
+            "winogrande_1_1_0",
+            "bool_q_1_0_0",
+            "openbookqa_0_1_0",
+        ]
+    else:
+        exclude_phi_tasks = None
+    print(args.library_id)
+    library = ExpertLibrary.get_expert_library(
+        repo_id=args.library_id,
+        token=args.remote_token,
+        exclude_selection=exclude_phi_tasks,
+        destination_id=args.destination_library_id,
+        selection=selection,
+        N_experts=args.N_experts,
+    )
+    an_expert = library[next(iter(library.keys()))]
+    train_cfg = deepcopy(an_expert.training_config)
+    train_cfg.device_map = "cpu"
+    # For starters, always overwrite the following arguments
+    for arg_name in [
+        "output_dir",
+        "eval_metric",
+        "remove_phi_eval_tasks",
+        "include_task_source",
+    ]:
+        value = getattr(args, arg_name, None)
+        setattr(train_cfg, arg_name, value)
+
+    """ Parameter Merging Approaches """
+    if args.merge_or_route == "uniform":
+        from mttl.models.library.merging_methods.uniform_merge import UniformMerge, UniformMergeConfig
+
+        cfg = UniformMergeConfig(alpha=args.merge_alpha)
+        module = UniformMerge(cfg).transform(library).to("cuda")
+
+    elif args.merge_or_route == "ties":
+        from mttl.models.library.merging_methods.ties import TiesMergeSimple, TiesMergeSimpleConfig
+
+        cfg = TiesMergeSimpleConfig(alpha=args.merge_alpha)
+        module = TiesMergeSimple(cfg).transform(library).to("cuda")
+
+    elif args.merge_or_route == "model_breadcrumbs":
+        from mttl.models.library.merging_methods.model_breadcrumbs import ModelBreadcrumbs, ModelBreadcrumbsConfig
+
+        cfg = ModelBreadcrumbsConfig(alpha=args.merge_alpha)
+        module = ModelBreadcrumbs(cfg).transform(library).to("cuda")
+
+    elif args.merge_or_route == "task_arithmetic":
+        from mttl.models.library.merging_methods.task_arithmetic import TaskArithmetic, TaskArithmeticConfig
+
+        cfg = TaskArithmeticConfig(alpha=args.merge_alpha)
+        module = TaskArithmetic(cfg).transform(library).to("cuda")
+
+    elif args.merge_or_route == "SLERP":
+        from mttl.models.library.merging_methods.slerp import SLERPMerge, SLERPMergeConfig
+
+        module = SLERPMerge(SLERPMergeConfig()).transform(library).to("cuda")
+
+    elif args.merge_or_route == "uniform_lora_before_op":
+        from mttl.models.library.merging_methods.LoRA_ablinear import LoRA_ab_LinearMerge, LoRA_ab_LinearMergeConfig
+
+        module = (
+            LoRA_ab_LinearMerge(LoRA_ab_LinearMergeConfig())
+            .transform(library)
+            .to("cuda")
+        )
+
+    elif args.merge_or_route in [
+        "uniform_sparse_weight",
+        "uniform_sparse_weight_oracle_routing",
+    ]:
+        """uniform merge of all weights"""
+        from mttl.models.library.merging_methods.sparse_merge import SparseWeightLinearMerge, SparseWeightLinearMergeConfig
+
+        expert = SparseWeightLinearMerge(SparseWeightLinearMergeConfig())
+        module = expert.transform(library).to("cuda")
+
+        """masked weight for single task"""
+        # TODO: remove; this variant provides the weights of only one task
+        # expert = SparseWeightLinearMerge(SparseWeightLinearMergeConfig())
+        # expert_names = list(library.keys())  # TODO: remove
+        # module = expert.transform_dummy(library, expert_names[0]).to("cuda")  # TODO: remove
+
+    elif args.merge_or_route == "uniform_lora_after_op":
+        # Here we merge the LoRA experts after the outer product; we cannot really do it
+        # with the lib transform, because this would require storing large matrices in memory.
+        # Instead we do it with a uniform selector.
+        assert type(an_expert.expert_info.expert_config) == LoRAConfig
+        train_cfg.router_selector = "uniform"
+        train_cfg.lora_merge_after = True
+        module = MultiExpertModel(**vars(train_cfg)).to("cuda")
+        module.load_from_module_dict(library)
+
+    elif args.merge_or_route == "base":
+        module = ExpertModel(**vars(train_cfg)).to("cuda")
+
+    else:
+        raise ValueError(f"Unknown merge_or_route {args.merge_or_route}")
+
+    metric_logger = Selector.metric_logger
+
+    if wandb.run is None and os.environ.get("WANDB_API_KEY"):
+        wandb.init(
+            project=os.environ.get("WANDB_PROJECT", "0shot_routing"),
+            config=dict(module.hparams),
+            name=os.environ.get("AMLT_JOB_NAME", None),
+        )
+        # update config
+        wandb.config.update({f"cmd_args_{k}": v for k, v in vars(args).items()})
+
+    if args.pipeline_eval_tasks in [
+        "in_distribution",
+    ]:
+        tasks = [expert.expert_task_name for expert in library.data.values()]
+        if tasks[0] is None:
+            # for some older versions of the lib (in case of joint experts) no expert_task_name was set
+            tasks = json.load(open(args.flan_tasks_path))["flan256"]
+        # make sure we evaluate each task separately (so the mean is over tasks at the end)
+        tasks = ",".join(tasks).split(",")
+        train_cfg.eval_metric = args.eval_metric
+        train_cfg.subsample_dev = args.subsample_dev
+
+        # debug with one task: tasks = [expert_names[0]]
+        if args.merge_or_route == "uniform_sparse_weight_oracle_routing":
+            scores = eval_in_distribution_sparse_model(
+                module, library, expert, train_cfg, tasks
+            )
+        else:
+            scores = eval_in_distribution(module, train_cfg, tasks)
+
+    elif args.pipeline_eval_tasks in [
+        "out_distribution",
+    ]:
+        # give eval tasks in the `finetune_task_name` argument
+        if isinstance(args.finetune_task_name, tuple):
+            tasks = list(args.finetune_task_name)
+        elif isinstance(args.finetune_task_name, str):
+            tasks = args.finetune_task_name.split(",")
+
+        train_cfg.eval_metric = args.eval_metric
+        train_cfg.subsample_dev = args.subsample_dev
+        scores = eval_in_distribution(module, train_cfg, tasks)
+
+    else:
+        if args.pipeline_eval_tasks == "all":
+            args.pipeline_eval_tasks = "arc-challenge,arc-easy,boolq,hellaswag,humaneval,mbpp,openbookqa,piqa,bbh-fast,winogrande"
+
+        with torch.no_grad():
+            runner: EvaluatorRunner = setup_evaluators(
+                model_type=module.hparams.model,
+                model_family=module.hparams.model_family,
+                max_input_length=module.hparams.max_input_length,
+                max_output_length=module.hparams.max_output_length,
+                predict_batch_size=args.predict_batch_size,
+                truncation_side=module.hparams.truncation_side,
+                tasks=args.pipeline_eval_tasks,
+                output_path=os.path.join(args.output_dir, "DOWNSTREAM"),
+                add_eos_to_targets=args.add_eos_to_downstream_targets,
+            )
+            scores = runner.run(module)
+
+    if len(metric_logger) > 0:
+        task_table = metric_logger.pretty_table(match_on="task|.*uniform.*")
+        layer_table = metric_logger.pretty_table(match_on="layer|.*uniform.*")
+        expert_p = metric_logger.pretty_table(match_on=".*expert_p|.*uniform.*")
+        angle = metric_logger.pretty_table(match_on=".*angle.*")
+        print(task_table)
+        print(layer_table)
+        print(expert_p)
+        print(angle)
+
+    if wandb.run is not None:
+        if scores is not None:
+            wandb.log({f"downstream/{k}": v for k, v in scores.items()})
+        if len(metric_logger) > 0:
+            wandb.log({k: v.avg for k, v in metric_logger.meters.items()})
+
+    wandb.finish()
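Editor's note: every merging branch in `run_eval` above follows the same pattern, build a config, instantiate the transform, and call `transform(library)` to get a ready-to-evaluate model. A condensed sketch of that flow using the TIES branch is below; the `library_id` and `alpha` values are placeholders, not values from this PR.

```python
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.library.merging_methods.ties import TiesMergeSimple, TiesMergeSimpleConfig

library_id = "local://temp/test_library"  # placeholder: any valid ExpertLibrary URI
library = ExpertLibrary.get_expert_library(repo_id=library_id)

cfg = TiesMergeSimpleConfig(alpha=0.4)             # alpha plays the role of merge_alpha
module = TiesMergeSimple(cfg).transform(library)   # merged expert model
module = module.to("cuda")                         # move to GPU for evaluation
```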
+
+
+if __name__ == "__main__":
+    args = ExpertConfig.parse()
+    run_eval(args)
diff --git a/projects/sparse_finetuning/requirements.txt b/projects/sparse_finetuning/requirements.txt
index af0c8f951..8450d7b61 100644
--- a/projects/sparse_finetuning/requirements.txt
+++ b/projects/sparse_finetuning/requirements.txt
@@ -1,29 +1,28 @@
-transformers==4.47.1 # need this specific version to avoid phi3 issues https://github.com/huggingface/transformers/issues/36071
-torch>=2.3.1
-datasets>=2.20.0
-pytorch-lightning>=2.3.3
-accelerate
-deepspeed
-huggingface_hub>=0.26.2
-click
-wandb
-rouge
+pytorch-lightning==2.4.0
+transformers==4.44.2
+# huggingface_hub==0.24.7
+huggingface_hub==0.26.2
+datasets==3.0.0
+bitsandbytes==0.43.3
+accelerate==0.34.2
+#vllm==0.2.4
+wandb==0.18.0
+azure-storage-blob==12.22.0
+azure-identity==1.17.1
+pyparsing==3.1.4
+peft==0.12.0
+openai==1.45.1
+pytest-mock==3.14.0
+sentence-transformers==3.1.0
+prettytable==3.11.0
+matplotlib==3.7.5
+seaborn==0.13.2
+nltk==3.9.1
+einops
 tqdm
 pandas
-sentence-transformers
-fsspec[adl]
-prettytable
-rich
-bitsandbytes
-matplotlib
-openai
-ray
-nevergrad
-evaluate
-seaborn
-azure-storage-blob
-azure-identity
-einops
-triton
-nltk
-# spops @ git+https://github.com/IST-DASLab/spops.git@main
+rouge==1.0.1
+mistral_inference==1.5.0
+pytest==8.3.3
+triton==3.1.0
+torchvision
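Editor's note: the training-script changes that follow gate an `UpdateSparseMask` callback behind the new `use_sparse_model` flag. The callback's body lives in `mttl.models.lightning.callbacks` and is not part of this excerpt, so the sketch below only shows the general shape such a Lightning callback tends to take (re-select masks every `update_interval` steps); the class name and method contents are assumptions, not the repository's implementation.

```python
import pytorch_lightning as pl

class PeriodicMaskUpdate(pl.Callback):
    """Illustrative only: re-derive sparse keep-masks every `update_interval` steps."""

    def __init__(self, update_interval: int = 100):
        self.update_interval = update_interval

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Lightning invokes this after every training batch / optimizer step.
        if trainer.global_step > 0 and trainer.global_step % self.update_interval == 0:
            # here the real callback would recompute and install the masks,
            # e.g. by calling make_sparse_model_during_training(pl_module, ...)
            pass
```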
diff --git a/projects/sparse_finetuning/train_experts_main.py b/projects/sparse_finetuning/train_experts_main.py
index 1e14928a4..3f50146f9 100644
--- a/projects/sparse_finetuning/train_experts_main.py
+++ b/projects/sparse_finetuning/train_experts_main.py
@@ -29,6 +29,7 @@
 from projects.modular_llm.compute_transfer_matrix import (
     run_eval as produce_transfer_matrix,
 )
+from mttl.models.lightning.callbacks import UpdateSparseMask
 
 
 def create_transfer_matrix(args, checkpoint):
@@ -77,12 +78,12 @@ def run_multitask(args: ExpertConfig):
     )
     loggers = get_pl_loggers(args)
 
-    model_class = ExpertModule
     dm = get_datamodule(args)
     args.n_tasks = len(dm._task_names)
     args.task_names = dm._task_names
 
+    model_class = ExpertModule
     module = model_class(**vars(args))
 
     # get metric monitors for models
@@ -94,18 +95,22 @@ def run_multitask(args: ExpertConfig):
     monitor = "val/loss"
     mode = "min"
 
-    # =============== Iterative masking using Callback ====================
+    # ----------------------------------
+    # Iterative masking using Callback
+    # ----------------------------------
     # NOTE: Don't move this block, it's important we call maskCallBack before others
-    from mttl.models.lightning.callbacks import UpdateSparseMask
+    if args.use_sparse_model:
+        assert (
+            len(args.task_names) == 1
+        ), "sparse mask does not support more than 1 task"
+        maskCallback = UpdateSparseMask(
+            update_interval=100,
+            num_train_steps=len(dm.train_dataloader()),
+            save_mask_dir=args.library_id,
+            task_name=args.task_names[0],
+            parameter_selection_procedure=args.parameter_selection_procedure,
+        )  # "per_layer"/"model"; "per_layer" is the default
+        callbacks.append(maskCallback)
-    assert len(args.task_names) == 1, print(
-        "sparse mask does not support more than 1 task"
-    )
-    maskCallback = UpdateSparseMask(
-        update_interval=100,
-        task_name=args.task_names[0],
-        parameter_selection_procedure="per_layer",
-    )  # "per_layer"/"model" use "per_layer" for default
-    callbacks.append(maskCallback)
 
     checkpoint_callback = LiveCheckpointCallback(