From 7053588de9ea25d7330e77eca0bf21136015d99a Mon Sep 17 00:00:00 2001 From: Radi Cho Date: Thu, 29 Aug 2024 00:03:31 +0300 Subject: [PATCH 1/3] Add learnable scales functionality. --- flute/integrations/base.py | 159 ++++++++++++++++++++++++++++++++++++- flute/nf_utils.py | 12 ++- requirements.txt | 1 + 3 files changed, 163 insertions(+), 9 deletions(-) diff --git a/flute/integrations/base.py b/flute/integrations/base.py index ba001d4..afc6062 100644 --- a/flute/integrations/base.py +++ b/flute/integrations/base.py @@ -1,9 +1,15 @@ import os +import math import json import torch import warnings import argparse + +from tqdm.auto import tqdm + +from datasets import load_dataset from transformers import ( + AutoTokenizer, LlamaForCausalLM, Gemma2ForCausalLM, AutoModelForCausalLM) @@ -36,6 +42,152 @@ def get_accelerate_hook(name: str, module: torch.nn.Module, allow: bool) -> Opti return hook +class LearnableQuantizedLinear(torch.nn.Module): + in_features : int + out_features: int + + num_bits : int + group_size : int + symmetric : bool + weight : torch.Tensor + scales : torch.Tensor + values : torch.Tensor + pivots : torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scales: Optional[torch.Tensor] = None, + num_bits: int = 4, + group_size: int = 64, + symmetric: bool = False + ): + super(LearnableQuantizedLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size + self.symmetric = symmetric + self.num_bits = num_bits + + self.values, self.pivots = flute.nf_utils.get_values_pivots(num_bits, False, dtype=torch.bfloat16) + + if weight is None: + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + else: + assert weight.dtype == torch.bfloat16, "Training is currently only supported in bfloat16!" 
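+            # The supplied pre-trained weight is stored frozen; only the
+            # per-group scales defined below are trainable.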
+ self.weight = torch.nn.Parameter(weight, requires_grad=False) + + if scales is None: + self.scales = torch.nn.Parameter(torch.max(torch.abs(self.weight.view(-1, self.group_size)), dim=1, keepdim=True).values, requires_grad=True) # * nf4_constant + else: + self.scales = torch.nn.Parameter(scales, requires_grad=True) + + if bias is None: + self.register_parameter('bias', None) + else: + self.bias = torch.nn.Parameter(bias, requires_grad=False) + + + def forward(self, inp: torch.Tensor): + qweight = flute.nf_utils.manual_nf4(self.weight, absmax=self.scales, bits=self.num_bits, blocksize=self.group_size, values=self.values, pivots=self.pivots) + + return torch.nn.functional.linear(inp, qweight, self.bias) + + +def get_parent(module: torch.nn.Module, name_split: str): + for n in name_split: + module = getattr(module, n) + return module + + +def learn_scales( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + num_bits: int = 4, + group_size: int = 64, + custom_corpora: Optional[str] = None, + epochs: int = 1, + lr: float = 0.0001, + iters: int = 128, + logging: bool = False, +) -> None: + layer_types = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"] + + for param in model.parameters(): + param.requires_grad = False + + print("Adding tunable scales to the linear layers...") + for name, module in model.named_modules(): + name_split = name.split('.') + if name_split[-1] in layer_types: + q_layer = LearnableQuantizedLinear( + module.in_features, + module.out_features, + weight=module.weight.data, + bias=module.bias.data if module.bias is not None else None, + num_bits=num_bits, + group_size=group_size + ) + + parent = get_parent(model, name_split[:-1]) + setattr(parent, name_split[-1], q_layer) + + print("Tokenizing corpora...") + if custom_corpora == None: + train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + corpora = [tokenizer("\n\n".join(train["text"]), return_tensors="pt")] + else: + corpora = [tokenizer(corpus, return_tensors="pt") for corpus in custom_corpora] + + print("Prepare model for training...") + model.train() + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) + max_length = 2048 + device = model.device + bos_token_id = tokenizer.bos_token_id + + # Use BOS token in each sequence - especially important for Gemma + stride = max_length - 1 + seq_len = iters * (max_length - 1) + + for epoch in range(epochs): + print(f"Running epoch {epoch}...") + + prev_end_loc = 0 + for begin_loc in tqdm(range(0, seq_len, stride)): + for encodings in corpora: + optimizer.zero_grad() + + end_loc = min(begin_loc + stride, seq_len) + trg_len = end_loc - prev_end_loc + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + input_ids = torch.concat([torch.tensor([bos_token_id], dtype=torch.int64).unsqueeze(0).to(device), input_ids], dim=1) + + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + outputs = model(input_ids, labels=target_ids) + neg_log_likelihood = outputs.loss + neg_log_likelihood.backward() + + optimizer.step() + + if logging: + print(f"Step loss: {neg_log_likelihood.item()}.") + + prev_end_loc = end_loc + + if end_loc == seq_len: + break + + # 2/4 @torch.no_grad() def prepare_model_flute( @@ -62,7 +214,7 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None: child_full_name = f"{_name}.{child_name}" - if isinstance(child, torch.nn.Linear): + if isinstance(child, 
torch.nn.Linear) or isinstance(child, LearnableQuantizedLinear): if isinstance(child, BNBLinear4bit): if child.weight.dtype not in [torch.uint8]: @@ -130,7 +282,10 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None: if custom_scales_dict is not None: custom_scales = custom_scales_dict[child_full_name] else: - custom_scales = None + if isinstance(child, LearnableQuantizedLinear): + custom_scales = child.scales + else: + custom_scales = None if not isinstance(child, BNBLinear4bit): _, _Q, scales, qmap = flute.nf_utils.nf_quantize( diff --git a/flute/nf_utils.py b/flute/nf_utils.py index f95a3f3..dc0a736 100644 --- a/flute/nf_utils.py +++ b/flute/nf_utils.py @@ -1,11 +1,9 @@ import torch from typing import Tuple, Optional -DTYPE = torch.float32 - -def linspace(start, stop, num): - steps = torch.arange(num, dtype=DTYPE, device=start.device) / (num - 1) +def linspace(start, stop, num, dtype=torch.float32): + steps = torch.arange(num, dtype=dtype, device=start.device) / (num - 1) for i in range(start.ndim): steps = steps.unsqueeze(-1) @@ -13,13 +11,13 @@ def linspace(start, stop, num): return start[None] + steps*(stop - start)[None] -def get_values_pivots(bits=4, symmetric=False): +def get_values_pivots(bits=4, symmetric=False, dtype=torch.float32): dist = torch.distributions.normal.Normal(torch.tensor(0.0), torch.tensor(1.0)) offset = 0.5*(1/32 + 1/30) if symmetric: - v = dist.icdf(linspace(torch.tensor(offset), torch.tensor(1 - offset), 2**(bits))) + v = dist.icdf(linspace(torch.tensor(offset), torch.tensor(1 - offset), 2**(bits), dtype=dtype)) else: v1 = -1 * dist.icdf(linspace(torch.tensor(1-offset), torch.tensor(0.5), 2**(bits-1))) v2 = dist.icdf(linspace(torch.tensor(0.5), torch.tensor(1-offset), 2**(bits-1)+1)[1:]) @@ -31,7 +29,7 @@ def get_values_pivots(bits=4, symmetric=False): v = torch.tensor([-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]) p = (v[1:] + v[:-1]) / 2 - return v.to(DTYPE).cuda().clone(), p.to(DTYPE).cuda().clone() + return v.to(dtype).cuda().clone(), p.to(dtype).cuda().clone() def manual_nf4(inp, absmax=None, bits=4, blocksize=128, return_stats=False, values=None, pivots=None): diff --git a/requirements.txt b/requirements.txt index 73e47af..f498005 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ wheel jaxtyping matplotlib transformers +datasets accelerate vllm >= 0.5.3.post1 bitsandbytes From 75a5adc45e83f1dc0105e0ba3561ad6a8a1d32f7 Mon Sep 17 00:00:00 2001 From: Radi Cho Date: Thu, 29 Aug 2024 18:29:23 +0300 Subject: [PATCH 2/3] Separate learning in a new file. 
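Moving LearnableQuantizedLinear, get_parent, and learn_scales out of
flute/integrations/base.py into a dedicated flute/integrations/learnable.py
module keeps base.py focused on conversion. base.py re-imports
LearnableQuantizedLinear so prepare_model_flute can still pick up the learned
scales, and the iters argument is renamed to samples.

A minimal usage sketch (the checkpoint name is illustrative only and not part
of this patch; any causal LM loaded in bfloat16 should work):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from flute.integrations.learnable import learn_scales

    model_id = "google/gemma-2-9b-it"  # illustrative bfloat16-capable checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Swaps the q/k/v/o/up/gate/down projections for LearnableQuantizedLinear
    # and tunes the per-group scales on WikiText-2, in place.
    learn_scales(model, tokenizer, num_bits=4, group_size=64, samples=32, logging=True)

The model can then be converted with flute.integrations.base.prepare_model_flute,
which reads the learned scales directly from the LearnableQuantizedLinear layers.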
--- flute/integrations/base.py | 153 +------------------------------ flute/integrations/learnable.py | 156 ++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 151 deletions(-) create mode 100644 flute/integrations/learnable.py diff --git a/flute/integrations/base.py b/flute/integrations/base.py index afc6062..36167e5 100644 --- a/flute/integrations/base.py +++ b/flute/integrations/base.py @@ -1,15 +1,10 @@ import os -import math import json import torch import warnings import argparse -from tqdm.auto import tqdm - -from datasets import load_dataset from transformers import ( - AutoTokenizer, LlamaForCausalLM, Gemma2ForCausalLM, AutoModelForCausalLM) @@ -24,7 +19,9 @@ import flute.utils import flute.nf_utils import flute.integrations.bitsandbytes +import flute.integrations.learnable +from flute.integrations.learnable import LearnableQuantizedLinear def get_accelerate_hook(name: str, module: torch.nn.Module, allow: bool) -> Optional[ModelHook]: @@ -42,152 +39,6 @@ def get_accelerate_hook(name: str, module: torch.nn.Module, allow: bool) -> Opti return hook -class LearnableQuantizedLinear(torch.nn.Module): - in_features : int - out_features: int - - num_bits : int - group_size : int - symmetric : bool - weight : torch.Tensor - scales : torch.Tensor - values : torch.Tensor - pivots : torch.Tensor - - def __init__( - self, - in_features: int, - out_features: int, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - scales: Optional[torch.Tensor] = None, - num_bits: int = 4, - group_size: int = 64, - symmetric: bool = False - ): - super(LearnableQuantizedLinear, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.group_size = group_size - self.symmetric = symmetric - self.num_bits = num_bits - - self.values, self.pivots = flute.nf_utils.get_values_pivots(num_bits, False, dtype=torch.bfloat16) - - if weight is None: - self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) - torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - else: - assert weight.dtype == torch.bfloat16, "Training is currently only supported in bfloat16!" 
- self.weight = torch.nn.Parameter(weight, requires_grad=False) - - if scales is None: - self.scales = torch.nn.Parameter(torch.max(torch.abs(self.weight.view(-1, self.group_size)), dim=1, keepdim=True).values, requires_grad=True) # * nf4_constant - else: - self.scales = torch.nn.Parameter(scales, requires_grad=True) - - if bias is None: - self.register_parameter('bias', None) - else: - self.bias = torch.nn.Parameter(bias, requires_grad=False) - - - def forward(self, inp: torch.Tensor): - qweight = flute.nf_utils.manual_nf4(self.weight, absmax=self.scales, bits=self.num_bits, blocksize=self.group_size, values=self.values, pivots=self.pivots) - - return torch.nn.functional.linear(inp, qweight, self.bias) - - -def get_parent(module: torch.nn.Module, name_split: str): - for n in name_split: - module = getattr(module, n) - return module - - -def learn_scales( - model: AutoModelForCausalLM, - tokenizer: AutoTokenizer, - num_bits: int = 4, - group_size: int = 64, - custom_corpora: Optional[str] = None, - epochs: int = 1, - lr: float = 0.0001, - iters: int = 128, - logging: bool = False, -) -> None: - layer_types = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"] - - for param in model.parameters(): - param.requires_grad = False - - print("Adding tunable scales to the linear layers...") - for name, module in model.named_modules(): - name_split = name.split('.') - if name_split[-1] in layer_types: - q_layer = LearnableQuantizedLinear( - module.in_features, - module.out_features, - weight=module.weight.data, - bias=module.bias.data if module.bias is not None else None, - num_bits=num_bits, - group_size=group_size - ) - - parent = get_parent(model, name_split[:-1]) - setattr(parent, name_split[-1], q_layer) - - print("Tokenizing corpora...") - if custom_corpora == None: - train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") - corpora = [tokenizer("\n\n".join(train["text"]), return_tensors="pt")] - else: - corpora = [tokenizer(corpus, return_tensors="pt") for corpus in custom_corpora] - - print("Prepare model for training...") - model.train() - model.gradient_checkpointing_enable() - model.enable_input_require_grads() - - optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) - max_length = 2048 - device = model.device - bos_token_id = tokenizer.bos_token_id - - # Use BOS token in each sequence - especially important for Gemma - stride = max_length - 1 - seq_len = iters * (max_length - 1) - - for epoch in range(epochs): - print(f"Running epoch {epoch}...") - - prev_end_loc = 0 - for begin_loc in tqdm(range(0, seq_len, stride)): - for encodings in corpora: - optimizer.zero_grad() - - end_loc = min(begin_loc + stride, seq_len) - trg_len = end_loc - prev_end_loc - input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) - input_ids = torch.concat([torch.tensor([bos_token_id], dtype=torch.int64).unsqueeze(0).to(device), input_ids], dim=1) - - target_ids = input_ids.clone() - target_ids[:, :-trg_len] = -100 - - outputs = model(input_ids, labels=target_ids) - neg_log_likelihood = outputs.loss - neg_log_likelihood.backward() - - optimizer.step() - - if logging: - print(f"Step loss: {neg_log_likelihood.item()}.") - - prev_end_loc = end_loc - - if end_loc == seq_len: - break - - # 2/4 @torch.no_grad() def prepare_model_flute( diff --git a/flute/integrations/learnable.py b/flute/integrations/learnable.py new file mode 100644 index 0000000..d2a9e47 --- /dev/null +++ b/flute/integrations/learnable.py @@ -0,0 +1,156 @@ 
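+# Learnable per-group quantization scales for FLUTE.
+#
+# LearnableQuantizedLinear keeps its pre-trained weight frozen and re-quantizes
+# it with trainable NF scales on every forward pass (via manual_nf4);
+# learn_scales swaps these layers into a model's attention/MLP projections and
+# tunes the scales on a small calibration corpus before FLUTE conversion.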
+import math +import torch + +from tqdm.auto import tqdm +from datasets import load_dataset +from transformers import AutoTokenizer, AutoModelForCausalLM + +from typing import Optional + +import flute +import flute.nf_utils + +class LearnableQuantizedLinear(torch.nn.Module): + in_features : int + out_features: int + + num_bits : int + group_size : int + symmetric : bool + weight : torch.Tensor + scales : torch.Tensor + values : torch.Tensor + pivots : torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scales: Optional[torch.Tensor] = None, + num_bits: int = 4, + group_size: int = 64, + symmetric: bool = False + ): + super(LearnableQuantizedLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size + self.symmetric = symmetric + self.num_bits = num_bits + + self.values, self.pivots = flute.nf_utils.get_values_pivots(num_bits, False, dtype=torch.bfloat16) + + if weight is None: + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + else: + assert weight.dtype == torch.bfloat16, "Training is currently only supported in bfloat16!" + self.weight = torch.nn.Parameter(weight, requires_grad=False) + + if scales is None: + self.scales = torch.nn.Parameter(torch.max(torch.abs(self.weight.view(-1, self.group_size)), dim=1, keepdim=True).values, requires_grad=True) # * nf4_constant + else: + self.scales = torch.nn.Parameter(scales, requires_grad=True) + + if bias is None: + self.register_parameter('bias', None) + else: + self.bias = torch.nn.Parameter(bias, requires_grad=False) + + + def forward(self, inp: torch.Tensor): + qweight = flute.nf_utils.manual_nf4(self.weight, absmax=self.scales, bits=self.num_bits, blocksize=self.group_size, values=self.values, pivots=self.pivots) + + return torch.nn.functional.linear(inp, qweight, self.bias) + + +def get_parent(module: torch.nn.Module, name_split: str): + for n in name_split: + module = getattr(module, n) + return module + + +def learn_scales( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + num_bits: int = 4, + group_size: int = 64, + custom_corpora: Optional[str] = None, + epochs: int = 1, + lr: float = 0.0001, + samples: int = 128, + logging: bool = False, +) -> None: + layer_types = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"] + + for param in model.parameters(): + param.requires_grad = False + + print("Adding tunable scales to the linear layers...") + for name, module in model.named_modules(): + name_split = name.split('.') + if name_split[-1] in layer_types: + q_layer = LearnableQuantizedLinear( + module.in_features, + module.out_features, + weight=module.weight.data, + bias=module.bias.data if module.bias is not None else None, + num_bits=num_bits, + group_size=group_size + ) + + parent = get_parent(model, name_split[:-1]) + setattr(parent, name_split[-1], q_layer) + + print("Tokenizing corpora...") + if custom_corpora == None: + train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + corpora = [tokenizer("\n\n".join(train["text"]), return_tensors="pt")] + else: + corpora = [tokenizer(corpus, return_tensors="pt") for corpus in custom_corpora] + + print("Prepare model for training...") + model.train() + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + + optimizer = torch.optim.Adam(filter(lambda p: 
p.requires_grad, model.parameters()), lr=lr) + max_length = 2048 + device = model.device + bos_token_id = tokenizer.bos_token_id + + # Use BOS token in each sequence - especially important for Gemma + stride = max_length - 1 + seq_len = samples * (max_length - 1) + + for epoch in range(epochs): + print(f"Running epoch {epoch}...") + + prev_end_loc = 0 + for begin_loc in tqdm(range(0, seq_len, stride)): + for encodings in corpora: + optimizer.zero_grad() + + end_loc = min(begin_loc + stride, seq_len) + trg_len = end_loc - prev_end_loc + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + input_ids = torch.concat([torch.tensor([bos_token_id], dtype=torch.int64).unsqueeze(0).to(device), input_ids], dim=1) + + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + outputs = model(input_ids, labels=target_ids) + neg_log_likelihood = outputs.loss + neg_log_likelihood.backward() + + optimizer.step() + + if logging: + print(f"Step loss: {neg_log_likelihood.item()}.") + + prev_end_loc = end_loc + + if end_loc == seq_len: + break \ No newline at end of file From 15d2acf8120a20da73d278388a8419f09456cb53 Mon Sep 17 00:00:00 2001 From: Radi Cho Date: Thu, 29 Aug 2024 18:30:04 +0300 Subject: [PATCH 3/3] Add an example notebook. --- examples/learnable_scales_eval.ipynb | 281 +++++++++++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100644 examples/learnable_scales_eval.ipynb diff --git a/examples/learnable_scales_eval.ipynb b/examples/learnable_scales_eval.ipynb new file mode 100644 index 0000000..73a9988 --- /dev/null +++ b/examples/learnable_scales_eval.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "5ee00c71-2956-4d0f-936f-f19fe5b9ab99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "torch.cuda.device_count()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f3bf3fb3-d959-465d-8389-3fde5c87d4ea", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a6d7a23e9d954a74b5fa683a96fb7cc7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 131072). Running this sequence through the model will result in indexing errors\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prepare model for training...\n", + "Running epoch 0...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ae9ada499bd94ad1870a956d653eb1f9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/128 [00:00