diff --git a/format.sh b/format.sh index 9822b67..5f2c3e0 100644 --- a/format.sh +++ b/format.sh @@ -1,5 +1,19 @@ -yapf --recursive . --style='{based_on_style: google, column_limit: 120, indent_width: 4}' -i +#!/bin/bash +# Format Python files using yapf +echo "Running yapf..." +find . -type f -name "*.py" \ + ! -path "./build/*" \ + ! -path "./.git/*" \ + ! -path "*.egg-info/*" \ + -print0 | xargs -0 yapf --in-place + +# Format Python imports using isort +echo "Running isort..." isort . +# Format C++ files using clang-format +echo "Formatting C++ files..." find csrc/ \( -name '*.h' -o -name '*.cc' -o -name '*.cu' -o -name '*.cuh' \) -print | xargs clang-format -i + +echo "Formatting complete!" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0632ace..205e102 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,8 @@ column_limit = 120 indent_width = 4 based_on_style = "google" split_before_logical_operator = false - +dedent_closing_brackets = true +coalesce_brackets = true [tool.codespell] ignore-words-list = "ist" @@ -58,3 +59,9 @@ skip = "./VPTQ_arxiv.pdf,./build" [tool.isort] use_parentheses = true skip_gitignore = true +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +combine_as_imports = true +ensure_newline_before_comments = true diff --git a/setup.py b/setup.py index ec271a9..9750bd9 100644 --- a/setup.py +++ b/setup.py @@ -65,8 +65,9 @@ def build_cuda_extensions(): if torch.cuda.is_available() and torch.version.hip is not None: extra_compile_args["nvcc"].extend(["-fbracket-depth=1024"]) else: - extra_compile_args["nvcc"].extend( - ["--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", "-lineinfo"]) + extra_compile_args["nvcc"].extend([ + "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", "-lineinfo" + ]) extensions = CUDAExtension( "vptq.ops", diff --git a/vptq/__init__.py b/vptq/__init__.py index 3cda50f..9a18e94 100644 --- a/vptq/__init__.py +++ b/vptq/__init__.py @@ -4,4 +4,4 @@ # -------------------------------------------------------------------------- __version__ = "0.0.2.post1" -from .layers import AutoModelForCausalLM as AutoModelForCausalLM +from vptq.layers import AutoModelForCausalLM as AutoModelForCausalLM diff --git a/vptq/__main__.py b/vptq/__main__.py index e125043..bdf2ecf 100644 --- a/vptq/__main__.py +++ b/vptq/__main__.py @@ -3,6 +3,6 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from .app_utils import main +from vptq.app_utils import main main() diff --git a/vptq/app.py b/vptq/app.py index 7cf8272..18ee374 100644 --- a/vptq/app.py +++ b/vptq/app.py @@ -9,8 +9,7 @@ import gradio as gr from huggingface_hub import snapshot_download -from vptq.app_gpu import disable_gpu_info, enable_gpu_info -from vptq.app_gpu import update_charts as _update_charts +from vptq.app_gpu import disable_gpu_info, enable_gpu_info, update_charts as _update_charts from vptq.app_utils import get_chat_loop_generator models = [ @@ -114,11 +113,11 @@ def respond( response = "" for message in chat_completion( - messages, - max_tokens=max_tokens, - stream=True, - temperature=temperature, - top_p=top_p, + messages, + max_tokens=max_tokens, + stream=True, + temperature=temperature, + top_p=top_p, ): token = message diff --git a/vptq/app_gpu.py b/vptq/app_gpu.py index 64434c9..088ecd3 100644 --- a/vptq/app_gpu.py +++ b/vptq/app_gpu.py @@ -82,13 +82,15 @@ def update_charts(chart_height: int = 200) -> go.Figure: titlefont=dict(color='blue'), tickfont=dict(color='blue'), ), - yaxis2=dict(title='Memory Usage (GiB)', - range=[0, max(24, - max(mem_usage_history) + 1)], - titlefont=dict(color='red'), - tickfont=dict(color='red'), - overlaying='y', - side='right'), + yaxis2=dict( + title='Memory Usage (GiB)', + range=[0, max(24, + max(mem_usage_history) + 1)], + titlefont=dict(color='red'), + tickfont=dict(color='red'), + overlaying='y', + side='right' + ), height=chart_height, # set the height of the chart margin=dict(l=10, r=10, t=0, b=0), # set the margin of the chart showlegend=False # disable the legend diff --git a/vptq/app_utils.py b/vptq/app_utils.py index 6271279..ec85017 100644 --- a/vptq/app_utils.py +++ b/vptq/app_utils.py @@ -9,7 +9,7 @@ import transformers -from .layers.model_base import AutoModelForCausalLM as VQAutoModelQuantization +from vptq.layers.model_base import AutoModelForCausalLM as VQAutoModelQuantization def define_basic_args(): @@ -21,10 +21,9 @@ def define_basic_args(): """, formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument("--model", - type=str, - required=True, - help="float/float16 model to load, such as [mosaicml/mpt-7b]") + parser.add_argument( + "--model", type=str, required=True, help="float/float16 model to load, such as [mosaicml/mpt-7b]" + ) parser.add_argument("--tokenizer", type=str, default="", help="default same as [model]") parser.add_argument("--prompt", type=str, default="once upon a time, there ", help="prompt to start generation") parser.add_argument("--chat", action="store_true", help="chat with the model") @@ -62,11 +61,9 @@ def chat_loop(model, tokenizer, args): encodeds = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt") model_inputs = encodeds.to(model.device) print("assistant: ", end='') - generated_ids = model.generate(model_inputs, - streamer=streamer, - pad_token_id=2, - max_new_tokens=500, - do_sample=True) + generated_ids = model.generate( + model_inputs, streamer=streamer, pad_token_id=2, max_new_tokens=500, do_sample=True + ) decoded = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[-1]:], skip_special_tokens=True) messages.append({"role": "assistant", "content": decoded[0]}) @@ -83,11 +80,9 @@ def get_chat_loop_generator(model_id): if getattr(tokenizer, "chat_template", None) is None: raise Exception("warning: this tokenizer didn't provide chat_template.!!!") - def chat_loop_generator(messages, - 
                            max_tokens: int,
-                            stream: bool = True,
-                            temperature: float = 1.0,
-                            top_p: float = 1.0):
+    def chat_loop_generator(
+        messages, max_tokens: int, stream: bool = True, temperature: float = 1.0, top_p: float = 1.0
+    ):
         print("============================chat with the model============================")
         print("Press 'exit' to quit")
 
@@ -99,13 +94,15 @@ def chat_loop_generator(messages,
                 return_dict=True,
             )
             model_inputs = encodeds.to(model.device)
-            generation_kwargs = dict(model_inputs,
-                                     streamer=streamer,
-                                     max_new_tokens=max_tokens,
-                                     pad_token_id=2,
-                                     do_sample=True,
-                                     temperature=temperature,
-                                     top_p=top_p)
+            generation_kwargs = dict(
+                model_inputs,
+                streamer=streamer,
+                max_new_tokens=max_tokens,
+                pad_token_id=2,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p
+            )
             thread = Thread(target=model.generate, kwargs=generation_kwargs)
             thread.start()
             for new_text in streamer:
diff --git a/vptq/layers/__init__.py b/vptq/layers/__init__.py
index f68bf61..c1dbe2e 100644
--- a/vptq/layers/__init__.py
+++ b/vptq/layers/__init__.py
@@ -3,4 +3,4 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
-from .model_base import AutoModelForCausalLM as AutoModelForCausalLM
\ No newline at end of file
+from vptq.layers.model_base import AutoModelForCausalLM as AutoModelForCausalLM
\ No newline at end of file
diff --git a/vptq/layers/model_base.py b/vptq/layers/model_base.py
index 03f0880..ba09327 100644
--- a/vptq/layers/model_base.py
+++ b/vptq/layers/model_base.py
@@ -14,7 +14,7 @@ import transformers
 from tqdm import tqdm
 
-from .vqlinear import VQuantLinear
+from vptq.layers.vqlinear import VQuantLinear
 
 
 def set_op_by_name(layer, name, new_module):
@@ -32,9 +32,9 @@ def set_op_by_name(layer, name, new_module):
 
 
 def make_quant_linear(module, quant_conf, name="", target_layer=None):
-    for module_name, sub_module in tqdm(module.named_modules(),
-                                        total=len(list(module.named_modules())),
-                                        desc="Replacing linear layers..."):
+    for module_name, sub_module in tqdm(
+        module.named_modules(), total=len(list(module.named_modules())), desc="Replacing linear layers..."
+ ): if module_name in quant_conf: layer_conf = quant_conf[module_name] new_module = target_layer(**layer_conf, enable_proxy_error=False, dtype=sub_module.weight.dtype) @@ -124,9 +124,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): checkpoint = pretrained_model_name_or_path else: # remote token_arg = {"token": kwargs.get("token", None)} - checkpoint = huggingface_hub.snapshot_download(repo_id=pretrained_model_name_or_path, - ignore_patterns=["*.bin"], - **token_arg) + checkpoint = huggingface_hub.snapshot_download( + repo_id=pretrained_model_name_or_path, ignore_patterns=["*.bin"], **token_arg + ) weight_bins = glob.glob(str(Path(checkpoint).absolute() / "*.safetensors")) index_json = glob.glob(str(Path(checkpoint).absolute() / "*.index.json")) pytorch_model_bin = glob.glob(str(Path(checkpoint).absolute() / "pytorch_model.bin")) @@ -148,13 +148,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): max_memory = local_max_memory accelerate.hooks.attach_execution_device_hook = attach_execution_device_hook - model = accelerate.load_checkpoint_and_dispatch(model, - checkpoint=checkpoint, - device_map=device_map, - max_memory=max_memory, - no_split_module_classes=no_split_module_classes[0], - dtype=torch_dtype, - preload_module_classes=["VQuantLinear"]) + model = accelerate.load_checkpoint_and_dispatch( + model, + checkpoint=checkpoint, + device_map=device_map, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes[0], + dtype=torch_dtype, + preload_module_classes=["VQuantLinear"] + ) # check cuda kernel exist if importlib.util.find_spec("vptq.ops") is not None: diff --git a/vptq/layers/vqlinear.py b/vptq/layers/vqlinear.py index 59f42ae..27d5850 100644 --- a/vptq/layers/vqlinear.py +++ b/vptq/layers/vqlinear.py @@ -131,8 +131,9 @@ def __init__( assert self.vector_quant_dim == "out", "Currently outlier only support vector quant on out_features" assert self.outlier_num_res_centroids == -1, "Currently do not support residual quant on outliers" - self.outlier_centroids = nn.Embedding(1, self.outlier_num_centroids * self.outlier_vector_len, - **factory_kwargs) + self.outlier_centroids = nn.Embedding( + 1, self.outlier_num_centroids * self.outlier_vector_len, **factory_kwargs + ) # all index and perm are uint16 to avoid nccl and safetensor check # we view them as float16 or int16 @@ -174,8 +175,9 @@ def __init__( assert True, "Not implemented" else: perm_dtype = torch.int16 if self.is_indice_packed else torch.int64 - self.perm = Parameter(torch.arange(self.in_features, device=device, dtype=perm_dtype), - requires_grad=False) + self.perm = Parameter( + torch.arange(self.in_features, device=device, dtype=perm_dtype), requires_grad=False + ) # indices shape # self.num_indices in each codebook @@ -224,8 +226,9 @@ def __init__( # set residual centroids and indices if self.enable_residual: - self.res_centroids = nn.Embedding(self.num_codebooks, self.num_res_centroids * self.vector_len, - **factory_kwargs) + self.res_centroids = nn.Embedding( + self.num_codebooks, self.num_res_centroids * self.vector_len, **factory_kwargs + ) if self.is_indice_packed is False: if self.indices_as_float: @@ -271,11 +274,15 @@ def init_parameters( outlier_indices = indices[0] if self.indices_as_float: - outlier_indices = (outlier_indices.clone().detach().to(torch.uint16).view(torch.float16).to( - self.outlier_centroids.weight.device)) + outlier_indices = ( + outlier_indices.clone().detach().to(torch.uint16).view(torch.float16 + 
).to(self.outlier_centroids.weight.device) + ) else: - outlier_indices = (outlier_indices.clone().detach().to(torch.uint16).view(torch.int16).to( - self.outlier_centroids.weight.device)) + outlier_indices = ( + outlier_indices.clone().detach().to(torch.uint16).view(torch.int16 + ).to(self.outlier_centroids.weight.device) + ) if len(outlier_indices.shape) == 2: outlier_indices = outlier_indices.unsqueeze(0) @@ -348,11 +355,13 @@ def init_parameters( _res_indices = _res_indices.reshape(self.num_codebooks, self.num_indices, self.group_size) if self.indices_as_float: - self.res_indices.data = (_res_indices.to(torch.uint16).view(torch.float16).to( - self.res_centroids.weight.device)) + self.res_indices.data = ( + _res_indices.to(torch.uint16).view(torch.float16).to(self.res_centroids.weight.device) + ) else: - self.res_indices.data = (_res_indices.to(torch.uint16).view(torch.int16).to( - self.res_centroids.weight.device)) + self.res_indices.data = ( + _res_indices.to(torch.uint16).view(torch.int16).to(self.res_centroids.weight.device) + ) if self.enable_norm: self.weight_scale.data = weight_scale.to(self.centroids.weight.device) @@ -371,8 +380,9 @@ def set_centroids_grad(self, requires_grad): # TODO: FIX def post_init(self): if not hasattr(self, "invert_perm"): - self.invert_perm = (torch.argsort(self.perm.view(torch.uint16).to(torch.int64)).to(torch.uint16).view( - torch.int16)) + self.invert_perm = ( + torch.argsort(self.perm.view(torch.uint16).to(torch.int64)).to(torch.uint16).view(torch.int16) + ) # if self.indices.dtype != torch.int: # self.short_indices = self.indices.view( # torch.int16) if self.indices_as_fp16 else self.indices.short() @@ -394,10 +404,14 @@ def fast_gemv(self, x): return None self.post_init() centroids = self.centroids.weight.view(self.num_codebooks, self.num_centroids, self.vector_len) - res_centroids = (self.res_centroids.weight.view(self.num_codebooks, self.num_res_centroids, self.vector_len) - if self.res_centroids is not None else None) - outlier_centroids = (self.outlier_centroids.weight.view(1, self.outlier_num_centroids, self.outlier_vector_len) - if hasattr(self, "outlier_centroids") else None) + res_centroids = ( + self.res_centroids.weight.view(self.num_codebooks, self.num_res_centroids, self.vector_len) + if self.res_centroids is not None else None + ) + outlier_centroids = ( + self.outlier_centroids.weight.view(1, self.outlier_num_centroids, self.outlier_vector_len) + if hasattr(self, "outlier_centroids") else None + ) if self.indices.dtype == torch.int: indices = self.indices res_indices = self.res_indices if hasattr(self, "res_indices") else None @@ -468,10 +482,14 @@ def fast_dequant(self): self.post_init() centroids = self.centroids.weight.view(self.num_codebooks, self.num_centroids, self.vector_len) - res_centroids = (self.res_centroids.weight.view(self.num_codebooks, self.num_res_centroids, self.vector_len) - if self.res_centroids is not None else None) - outlier_centroids = (self.outlier_centroids.weight.view(1, self.outlier_num_centroids, self.outlier_vector_len) - if hasattr(self, "outlier_centroids") else None) + res_centroids = ( + self.res_centroids.weight.view(self.num_codebooks, self.num_res_centroids, self.vector_len) + if self.res_centroids is not None else None + ) + outlier_centroids = ( + self.outlier_centroids.weight.view(1, self.outlier_num_centroids, self.outlier_vector_len) + if hasattr(self, "outlier_centroids") else None + ) if self.is_indice_packed: indices = self.indices @@ -564,13 +582,16 @@ def dequant(self): 
selected_res_centroids = torch.gather(res_centroids, 1, res_indices) - selected_res_centroids = selected_res_centroids.reshape(self.num_codebooks, -1, self.group_size, - self.vector_len) + selected_res_centroids = selected_res_centroids.reshape( + self.num_codebooks, -1, self.group_size, self.vector_len + ) selected_res_centroids = selected_res_centroids.permute(0, 1, 3, 2) - qweight = qweight + (selected_res_centroids.reshape(self.num_codebooks, -1, self.group_size).permute( - 1, 0, 2).reshape(-1, self.num_codebooks * self.group_size)) + qweight = qweight + ( + selected_res_centroids.reshape(self.num_codebooks, -1, self.group_size + ).permute(1, 0, 2).reshape(-1, self.num_codebooks * self.group_size) + ) # print(f'self.padding: {self.padding}') # print(f'self.out_features: {self.out_features}') @@ -589,8 +610,9 @@ def dequant(self): # print(f'qweight: {qweight.shape}') if self.enable_outlier: - outlier_centroids = self.outlier_centroids.weight.view(1, self.outlier_num_centroids, - self.outlier_vector_len) + outlier_centroids = self.outlier_centroids.weight.view( + 1, self.outlier_num_centroids, self.outlier_vector_len + ) # outlier_centroids_shape = outlier_centroids.shape outlier_indices = self.outlier_indices.view(torch.uint16).to(torch.int64) @@ -602,8 +624,9 @@ def dequant(self): selected_outlier_centroids = torch.gather(outlier_centroids, 1, outlier_indices) # print( # f'1 selected_outlier_centroids: {selected_outlier_centroids.shape}') - selected_outlier_centroids = selected_outlier_centroids.reshape(1, -1, self.outlier_size, - self.outlier_vector_len) + selected_outlier_centroids = selected_outlier_centroids.reshape( + 1, -1, self.outlier_size, self.outlier_vector_len + ) # selected_outlier_centroids = selected_outlier_centroids.view( # 1, -1, len(self.outlier_indices), self.outlier_vector_len) # print( @@ -710,7 +733,8 @@ def set_l2_indices(self, weights): if self.enable_residual: res_vectors = vectors - centroids.squeeze(0)[indices.squeeze(0)] res_indices = self._batched_indices( - res_vectors, self.res_centroids.weight.view(self.num_codebooks, self.num_centroids, self.vector_len)) + res_vectors, self.res_centroids.weight.view(self.num_codebooks, self.num_centroids, self.vector_len) + ) # res_indices = self._get_indices( # res_vectors, self.res_centroids.weight.view( # self.num_codebooks, self.num_centroids, self.vector_len))
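For context, the `[tool.yapf]` options added in pyproject.toml (`coalesce_brackets`, `dedent_closing_brackets`) are what drive every reformatting hunk in this diff. A minimal before/after sketch of the style shift, using the `model.generate` call from vptq/app_utils.py above, and assuming a yapf version that reads `[tool.yapf]` from pyproject.toml:

# Previous google-based style: yapf splits after the opening bracket and
# aligns each continuation line with the first argument.
generated_ids = model.generate(model_inputs,
                               streamer=streamer,
                               pad_token_id=2,
                               max_new_tokens=500,
                               do_sample=True)

# With coalesce_brackets and dedent_closing_brackets enabled, the arguments
# move onto a uniformly indented block and the closing bracket dedents back
# to the statement's own indent level.
generated_ids = model.generate(
    model_inputs, streamer=streamer, pad_token_id=2, max_new_tokens=500, do_sample=True
)

On the isort side, `combine_as_imports = true` with `line_length = 120` is what collapses the two `from vptq.app_gpu import ...` lines in vptq/app.py into one, and `multi_line_output = 3` (vertical hanging indent) together with `include_trailing_comma`, `force_grid_wrap = 0`, and `use_parentheses` mirrors isort's black-compatible profile (with the line length raised to 120), so any import that still exceeds 120 columns wraps in the same bracket style yapf now produces.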