Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix format #95

Merged
merged 4 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion format.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
yapf --recursive . --style='{based_on_style: google, column_limit: 120, indent_width: 4}' -i
#!/bin/bash

# Format Python files using yapf
echo "Running yapf..."
find . -type f -name "*.py" \
! -path "./build/*" \
! -path "./.git/*" \
! -path "*.egg-info/*" \
-print0 | xargs -0 yapf --in-place

# Format Python imports using isort
echo "Running isort..."
isort .

# Format C++ files using clang-format
echo "Formatting C++ files..."
find csrc/ \( -name '*.h' -o -name '*.cc' -o -name '*.cu' -o -name '*.cuh' \) -print | xargs clang-format -i

echo "Formatting complete!"
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ column_limit = 120
indent_width = 4
based_on_style = "google"
split_before_logical_operator = false

dedent_closing_brackets = true
coalesce_brackets = true

[tool.codespell]
ignore-words-list = "ist"
Expand All @@ -58,3 +59,9 @@ skip = "./VPTQ_arxiv.pdf,./build"
[tool.isort]
use_parentheses = true
skip_gitignore = true
line_length = 120
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
combine_as_imports = true
ensure_newline_before_comments = true
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ def build_cuda_extensions():
if torch.cuda.is_available() and torch.version.hip is not None:
extra_compile_args["nvcc"].extend(["-fbracket-depth=1024"])
else:
extra_compile_args["nvcc"].extend(
["--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", "-lineinfo"])
extra_compile_args["nvcc"].extend([
"--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", "-lineinfo"
])

extensions = CUDAExtension(
"vptq.ops",
Expand Down
2 changes: 1 addition & 1 deletion vptq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# --------------------------------------------------------------------------

__version__ = "0.0.2.post1"
from .layers import AutoModelForCausalLM as AutoModelForCausalLM
from vptq.layers import AutoModelForCausalLM as AutoModelForCausalLM
2 changes: 1 addition & 1 deletion vptq/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from .app_utils import main
from vptq.app_utils import main

main()
13 changes: 6 additions & 7 deletions vptq/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import gradio as gr
from huggingface_hub import snapshot_download

from vptq.app_gpu import disable_gpu_info, enable_gpu_info
from vptq.app_gpu import update_charts as _update_charts
from vptq.app_gpu import disable_gpu_info, enable_gpu_info, update_charts as _update_charts
from vptq.app_utils import get_chat_loop_generator

models = [
Expand Down Expand Up @@ -114,11 +113,11 @@ def respond(
response = ""

for message in chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
token = message

Expand Down
16 changes: 9 additions & 7 deletions vptq/app_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,15 @@ def update_charts(chart_height: int = 200) -> go.Figure:
titlefont=dict(color='blue'),
tickfont=dict(color='blue'),
),
yaxis2=dict(title='Memory Usage (GiB)',
range=[0, max(24,
max(mem_usage_history) + 1)],
titlefont=dict(color='red'),
tickfont=dict(color='red'),
overlaying='y',
side='right'),
yaxis2=dict(
title='Memory Usage (GiB)',
range=[0, max(24,
max(mem_usage_history) + 1)],
titlefont=dict(color='red'),
tickfont=dict(color='red'),
overlaying='y',
side='right'
),
height=chart_height, # set the height of the chart
margin=dict(l=10, r=10, t=0, b=0), # set the margin of the chart
showlegend=False # disable the legend
Expand Down
41 changes: 19 additions & 22 deletions vptq/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import transformers

from .layers.model_base import AutoModelForCausalLM as VQAutoModelQuantization
from vptq.layers.model_base import AutoModelForCausalLM as VQAutoModelQuantization


def define_basic_args():
Expand All @@ -21,10 +21,9 @@ def define_basic_args():
""",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--model",
type=str,
required=True,
help="float/float16 model to load, such as [mosaicml/mpt-7b]")
parser.add_argument(
"--model", type=str, required=True, help="float/float16 model to load, such as [mosaicml/mpt-7b]"
)
parser.add_argument("--tokenizer", type=str, default="", help="default same as [model]")
parser.add_argument("--prompt", type=str, default="once upon a time, there ", help="prompt to start generation")
parser.add_argument("--chat", action="store_true", help="chat with the model")
Expand Down Expand Up @@ -62,11 +61,9 @@ def chat_loop(model, tokenizer, args):
encodeds = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
model_inputs = encodeds.to(model.device)
print("assistant: ", end='')
generated_ids = model.generate(model_inputs,
streamer=streamer,
pad_token_id=2,
max_new_tokens=500,
do_sample=True)
generated_ids = model.generate(
model_inputs, streamer=streamer, pad_token_id=2, max_new_tokens=500, do_sample=True
)
decoded = tokenizer.batch_decode(generated_ids[:, model_inputs.shape[-1]:], skip_special_tokens=True)
messages.append({"role": "assistant", "content": decoded[0]})

Expand All @@ -83,11 +80,9 @@ def get_chat_loop_generator(model_id):
if getattr(tokenizer, "chat_template", None) is None:
raise Exception("warning: this tokenizer didn't provide chat_template.!!!")

def chat_loop_generator(messages,
max_tokens: int,
stream: bool = True,
temperature: float = 1.0,
top_p: float = 1.0):
def chat_loop_generator(
messages, max_tokens: int, stream: bool = True, temperature: float = 1.0, top_p: float = 1.0
):
print("============================chat with the model============================")
print("Press 'exit' to quit")

Expand All @@ -99,13 +94,15 @@ def chat_loop_generator(messages,
return_dict=True,
)
model_inputs = encodeds.to(model.device)
generation_kwargs = dict(model_inputs,
streamer=streamer,
max_new_tokens=max_tokens,
pad_token_id=2,
do_sample=True,
temperature=temperature,
top_p=top_p)
generation_kwargs = dict(
model_inputs,
streamer=streamer,
max_new_tokens=max_tokens,
pad_token_id=2,
do_sample=True,
temperature=temperature,
top_p=top_p
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
Expand Down
2 changes: 1 addition & 1 deletion vptq/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from .model_base import AutoModelForCausalLM as AutoModelForCausalLM
from vptq.model_base import AutoModelForCausalLM as AutoModelForCausalLM
30 changes: 16 additions & 14 deletions vptq/layers/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import transformers
from tqdm import tqdm

from .vqlinear import VQuantLinear
from vptq.vqlinear import VQuantLinear


def set_op_by_name(layer, name, new_module):
Expand All @@ -32,9 +32,9 @@ def set_op_by_name(layer, name, new_module):


def make_quant_linear(module, quant_conf, name="", target_layer=None):
for module_name, sub_module in tqdm(module.named_modules(),
total=len(list(module.named_modules())),
desc="Replacing linear layers..."):
for module_name, sub_module in tqdm(
module.named_modules(), total=len(list(module.named_modules())), desc="Replacing linear layers..."
):
if module_name in quant_conf:
layer_conf = quant_conf[module_name]
new_module = target_layer(**layer_conf, enable_proxy_error=False, dtype=sub_module.weight.dtype)
Expand Down Expand Up @@ -124,9 +124,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
checkpoint = pretrained_model_name_or_path
else: # remote
token_arg = {"token": kwargs.get("token", None)}
checkpoint = huggingface_hub.snapshot_download(repo_id=pretrained_model_name_or_path,
ignore_patterns=["*.bin"],
**token_arg)
checkpoint = huggingface_hub.snapshot_download(
repo_id=pretrained_model_name_or_path, ignore_patterns=["*.bin"], **token_arg
)
weight_bins = glob.glob(str(Path(checkpoint).absolute() / "*.safetensors"))
index_json = glob.glob(str(Path(checkpoint).absolute() / "*.index.json"))
pytorch_model_bin = glob.glob(str(Path(checkpoint).absolute() / "pytorch_model.bin"))
Expand All @@ -148,13 +148,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
max_memory = local_max_memory

accelerate.hooks.attach_execution_device_hook = attach_execution_device_hook
model = accelerate.load_checkpoint_and_dispatch(model,
checkpoint=checkpoint,
device_map=device_map,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes[0],
dtype=torch_dtype,
preload_module_classes=["VQuantLinear"])
model = accelerate.load_checkpoint_and_dispatch(
model,
checkpoint=checkpoint,
device_map=device_map,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes[0],
dtype=torch_dtype,
preload_module_classes=["VQuantLinear"]
)

# check cuda kernel exist
if importlib.util.find_spec("vptq.ops") is not None:
Expand Down
Loading
Loading