Add Support for Efficient Inference #47

Merged · 171 commits · Nov 20, 2024

Commits
9069b58
End-to-end integration
elvircrn Aug 29, 2024
a12e91f
Adds initial support for CUDA inference
elvircrn Sep 6, 2024
0e195b4
Bugfixes
elvircrn Sep 8, 2024
6ade617
Demo ready
elvircrn Sep 11, 2024
5ad861d
Demo ready
elvircrn Sep 11, 2024
dc5ad18
Remove absolute path
elvircrn Sep 11, 2024
737f0b3
Rename spqr -> inference in benchmarks
elvircrn Sep 12, 2024
9c587b4
Update benchmark code
elvircrn Sep 12, 2024
ff55c85
Dead code removal
elvircrn Sep 12, 2024
dbb7164
Trying out a new fused kernel implementation
elvircrn Sep 14, 2024
9f9a3ec
Fused kernel getting there...
elvircrn Sep 15, 2024
6e33f72
Major fused kernel improvements
elvircrn Sep 16, 2024
e91263f
Make the fused kernel more async
elvircrn Sep 16, 2024
2474292
Fused kernel WIP
elvircrn Sep 16, 2024
2916175
Fused kernel mean speed-up over 3X
elvircrn Sep 17, 2024
9533f71
Done developing the fused kernel
elvircrn Sep 17, 2024
3d49eb3
Get rid of absolute paths
elvircrn Sep 18, 2024
a9d586a
Start work on reorder optimization
elvircrn Sep 19, 2024
e95d292
Get ready for end-to-end benchmarks
elvircrn Sep 19, 2024
12f1b7d
WIDTH=16
elvircrn Sep 19, 2024
8bc942f
WIDTH=16
elvircrn Sep 19, 2024
6b0ecbf
Increase pipeline depth
elvircrn Sep 19, 2024
7059d1c
GPU specialization
elvircrn Sep 19, 2024
192a356
Reenable pipelining
elvircrn Sep 19, 2024
eddac7d
Try setting block size to 32
elvircrn Sep 19, 2024
99347d9
Try setting block size to 64
elvircrn Sep 19, 2024
7bda888
Try disabling the LUT
elvircrn Sep 19, 2024
15a1e2c
Try returning early
elvircrn Sep 19, 2024
e2c7337
Try returning earlier
elvircrn Sep 19, 2024
95fa2f8
Streamline loops
elvircrn Sep 20, 2024
ba23dd4
Bring back actual computing
elvircrn Sep 20, 2024
96a8163
Work on making everything async
elvircrn Sep 22, 2024
d60ae9a
Work on memory compression
elvircrn Sep 22, 2024
e5156c9
update the slow version
elvircrn Sep 23, 2024
49cd194
Finalize optimizations
elvircrn Sep 23, 2024
d231bc9
Apply it to A100
elvircrn Sep 23, 2024
d873c71
Remove torchvision
elvircrn Sep 23, 2024
e75b32c
Tweak A100 settings
elvircrn Sep 23, 2024
f38e588
Get rid of other inefficiencies
elvircrn Sep 23, 2024
c467255
Resolve failing test issues
elvircrn Sep 24, 2024
8d1e845
Uncomment __synchthread
elvircrn Sep 24, 2024
e47907d
Minor perf wins
elvircrn Sep 24, 2024
4260af4
Add cache policy
elvircrn Sep 24, 2024
d343225
why is ldcs broken?
elvircrn Sep 24, 2024
9023f21
Tweak A100 settings
elvircrn Sep 24, 2024
e01bbc1
Go back to ldg
elvircrn Sep 24, 2024
2ac5f59
resolve typo
elvircrn Sep 24, 2024
2badfb0
Minor tweaks
elvircrn Sep 24, 2024
a735768
A100 tweaks
elvircrn Sep 24, 2024
9156dcd
A100 tweaks
elvircrn Sep 24, 2024
a9037b9
80
elvircrn Sep 24, 2024
de2afb3
A100 tweaks
elvircrn Sep 24, 2024
ea86c59
Apply it to A100
elvircrn Sep 24, 2024
c93d282
Resolve failing tests
elvircrn Sep 25, 2024
a1ed722
Run configuration tweaks
elvircrn Sep 25, 2024
3a26daf
Dead code removal
elvircrn Sep 25, 2024
e2d449d
Dead code removal
elvircrn Sep 25, 2024
19a481d
Fix benchmak
elvircrn Sep 25, 2024
37f5efe
Move test utils into a separate file
elvircrn Sep 25, 2024
83fef7c
Remove duplicate function
elvircrn Sep 25, 2024
022bcc5
Sparsity analysis refactor
elvircrn Sep 25, 2024
5f5f841
Dead code removal
elvircrn Sep 25, 2024
730283d
Remove old benchmark method
elvircrn Sep 25, 2024
6840990
Compression refactor
elvircrn Sep 25, 2024
a8efd8a
Update inference code
elvircrn Sep 25, 2024
bec333e
Remove hardcoded .so load
elvircrn Sep 25, 2024
d95ee80
Remove torch jit script
elvircrn Sep 25, 2024
7406739
Remove source activate from build.sh
elvircrn Sep 25, 2024
ce195e3
Use a saner default
elvircrn Sep 25, 2024
14e2474
Prepare for endtoend benchmark
elvircrn Sep 25, 2024
b5c64a0
Prepare for endtoend benchmark
elvircrn Sep 25, 2024
b90e64d
Optimize inference
elvircrn Sep 25, 2024
0b185cc
A100 optimizations
elvircrn Sep 25, 2024
6bf1934
Further optimization for inference
elvircrn Sep 25, 2024
4aa0a49
remove perm temporarely
elvircrn Sep 25, 2024
218a191
remove device query
elvircrn Sep 25, 2024
a5722b6
Tweak a100 config
elvircrn Sep 25, 2024
3bb4aa0
Async
elvircrn Sep 25, 2024
4779a0c
Optimize inference
elvircrn Sep 25, 2024
1ebff7e
remove reordering
elvircrn Sep 25, 2024
36c42d5
remove device lookup again
elvircrn Sep 25, 2024
e144c92
no cheating
elvircrn Sep 25, 2024
ae22edc
torch empty
elvircrn Sep 25, 2024
09c6cbe
Bring back reordering
elvircrn Sep 25, 2024
8c89098
Proper reordering
elvircrn Sep 25, 2024
817c6b0
Streamline sparsity
elvircrn Sep 25, 2024
bf84de8
Special case density
elvircrn Sep 27, 2024
e93d970
multidim
elvircrn Sep 28, 2024
d700e4f
Some minor tweaks
elvircrn Sep 29, 2024
f0e9fbd
Add support for multidim threadlaunch; Get rid of all atomics
elvircrn Sep 30, 2024
b2143f2
Merge branch 'multidim' into inference
elvircrn Sep 30, 2024
d307b0d
Get rid of daed tests
elvircrn Sep 30, 2024
f94fff3
Bring back fast config
elvircrn Sep 30, 2024
8e32659
Also bring back dense-only config
elvircrn Sep 30, 2024
9572eea
Revert to old thread config
elvircrn Oct 1, 2024
a8dd177
Experimental
elvircrn Oct 12, 2024
b8a264e
Resolve the issues with the new format
elvircrn Oct 14, 2024
382c081
Sparse v2
elvircrn Oct 15, 2024
1c918b2
Resolve some maladies
elvircrn Oct 16, 2024
a5ed000
Resolve all issues with the new csr format
elvircrn Oct 16, 2024
e03a65a
Visualization work
elvircrn Oct 21, 2024
c2cf2fc
Speeding this up even more
elvircrn Oct 22, 2024
9281be4
Also measure modified CSR
elvircrn Oct 22, 2024
784e2ef
Optimize LUT loads/Put pipelined loads behind a flag
elvircrn Oct 24, 2024
5586150
Bench visualization updates
elvircrn Oct 27, 2024
441fdf4
Visualization updates
elvircrn Oct 28, 2024
b35c7c8
Significant speed-ups made for inference
elvircrn Nov 1, 2024
7e8ece3
Resolved most PR comments; Minor speed-ups; Refactor; Started work on…
elvircrn Nov 5, 2024
4503f92
Simplify setup.py
elvircrn Nov 5, 2024
8ce70af
Refactor torch mul
elvircrn Nov 5, 2024
2fdec88
Significantly simply the inference demo
elvircrn Nov 5, 2024
87525d9
Inference simplifications; Documentation updates
elvircrn Nov 5, 2024
dca2136
Remove inference license
elvircrn Nov 5, 2024
3dfcac9
Remove environment hardcore from ncu.sh script
elvircrn Nov 5, 2024
243a94b
Move ncu.sh out of repo
elvircrn Nov 5, 2024
3cb64c4
Move benchmark visualization
elvircrn Nov 5, 2024
1d6e0f9
make tests sligthly less descriptiive
elvircrn Nov 5, 2024
9c051a3
Make benchmark nicer
elvircrn Nov 5, 2024
38a5d5c
Make benchmarks up-to-date
elvircrn Nov 5, 2024
d6be547
Resolve issue with PTCSR
elvircrn Nov 5, 2024
274c1dd
Kernel simplification
elvircrn Nov 5, 2024
2c28527
Remove all mentions of elvircrn/
elvircrn Nov 5, 2024
0a0a9de
HF integration
elvircrn Nov 6, 2024
447f887
HF integration
elvircrn Nov 6, 2024
875b787
Update readme for benchmark py
elvircrn Nov 8, 2024
49c9211
Further README updates
elvircrn Nov 8, 2024
cbdbc69
Make PTCSR optional
elvircrn Nov 8, 2024
bbd456d
Remove the use of the mold linker
elvircrn Nov 8, 2024
2348671
Actually resolve the PTCSR path in bench.py
elvircrn Nov 8, 2024
f4bc3a1
Actually resolve the PTCSR path in bench.py
elvircrn Nov 8, 2024
1b4f32a
Actually resolve the PTCSR path in bench.py
elvircrn Nov 8, 2024
8c981c0
Make benchmark script nicer to work with
elvircrn Nov 8, 2024
10b3e65
Docs update
elvircrn Nov 9, 2024
4bed26c
Tune threads for A100
elvircrn Nov 9, 2024
69dd1c5
End-to-end bugfix
elvircrn Nov 9, 2024
dd610ab
Bring back Sdpa for end-to-end inference
elvircrn Nov 9, 2024
5a66c6b
Bring back inductor
elvircrn Nov 9, 2024
cb37502
Temporarely disable reordering
elvircrn Nov 9, 2024
6324564
Torchscript=True
elvircrn Nov 9, 2024
86570fd
Get rid of flattening
elvircrn Nov 9, 2024
09fdc04
Temp disable indexing
elvircrn Nov 9, 2024
208c633
Get rid of all indices
elvircrn Nov 9, 2024
50d0e40
Split forward() into two parts
elvircrn Nov 9, 2024
4f65a7d
Get rid of torch no grad
elvircrn Nov 9, 2024
3eafdd6
stop compiling twice
elvircrn Nov 10, 2024
46b0fb7
Disable sdpa optimization
elvircrn Nov 10, 2024
6cb83c1
Try with default
elvircrn Nov 10, 2024
aec7c2b
Temporarely disable reordering
elvircrn Nov 10, 2024
32af5b0
Simplify flattening
elvircrn Nov 10, 2024
05a9154
Get rid of flatten
elvircrn Nov 10, 2024
35aa8d8
Get rid of flatten
elvircrn Nov 10, 2024
bdcfcdb
Revert "Get rid of flatten"
elvircrn Nov 10, 2024
7fa9431
Increase prediction length
elvircrn Nov 10, 2024
2161b13
Even more optimization work, somehow
elvircrn Nov 11, 2024
9ed5217
Try ramping up thread count
elvircrn Nov 11, 2024
25d4f80
Ramp down thread count
elvircrn Nov 11, 2024
77b5392
Resolve critical bug
elvircrn Nov 11, 2024
fc930ba
Ramp up the thread count
elvircrn Nov 13, 2024
53cbccc
Ramp up the thread count
elvircrn Nov 13, 2024
9d6a157
Fix benchmarking bug
elvircrn Nov 13, 2024
4cbd7b5
Reduce thread count
elvircrn Nov 13, 2024
8daecb1
Change thread count
elvircrn Nov 13, 2024
5206238
Finalize thread count
elvircrn Nov 13, 2024
63a525c
Resolve all ptcsr bugs
elvircrn Nov 14, 2024
f007655
Finalize benchmark updates
elvircrn Nov 15, 2024
eb46401
Finalize benchmark updates
elvircrn Nov 15, 2024
2529106
Finalize HF support
elvircrn Nov 18, 2024
9817e09
Remove added gitiginore lines
elvircrn Nov 20, 2024
c7f6cb8
Address PR comments
elvircrn Nov 20, 2024
53d096c
Rename LLama -> InferenceDemo
elvircrn Nov 20, 2024
1477ee0
Apply black and isort
elvircrn Nov 20, 2024
1 change: 1 addition & 0 deletions .env
@@ -0,0 +1 @@
PYTHONPATH=.
8 changes: 5 additions & 3 deletions .gitignore
@@ -1,8 +1,10 @@
.vscode/*
.idea/*
.ipynb_checkpoints/*
__pycache__
__pycache__/*
SpQR/.ipynb_checkpoints/*
SpQR/__pycache__/*
spqr/.ipynb_checkpoints/*
spqr/__pycache__/*
outliers/*
outliers_stru*
wandb/*
@@ -17,4 +19,4 @@ lm-evaluation-harness/*.sh
lm-evaluation-harness/lm_eval/datasets/*/__pycache__/
lm-evaluation-harness/lm_eval/*/__pycache__/
lm-evaluation-harness/lm_eval/__pycache__/
lm-evaluation-harness/lm_cache
lm-evaluation-harness/lm_cache
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,4 +1,4 @@
Apache License
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

108 changes: 107 additions & 1 deletion README.md
@@ -133,6 +133,112 @@ Performance and runtime notes:
* With enough spare GPU memory, one can raise batch size to accelerate evaluation process.


## Inference

This repository also contains an efficient CUDA kernel implementation of the
SpQR matvec. The file `inference_demo.py` demonstrates this functionality by
running end-to-end model inference. Its command-line interface is shown below.

```bash
usage: inference_demo.py [-h] [--pretrained_model_path PRETRAINED_MODEL_PATH] [--compressed_model_path COMPRESSED_MODEL_PATH] --execution_mode {0,1}

options:
-h, --help show this help message and exit
--pretrained_model_path PRETRAINED_MODEL_PATH
Path to the pretrained model
--compressed_model_path COMPRESSED_MODEL_PATH
Path to the compressed .pt model
--execution_mode {0,1}
If set to 0, will evaluate the dense pretrained model. If set to 1, will evaluate the SpQR-quantized model
```

This script also reports the mean and median time of the forward() passes and the total inference execution time.
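
For example, a run of the SpQR-quantized model might look like the following. This is only a sketch: the model name and checkpoint paths are placeholders, not files shipped with this PR.

```bash
# Hypothetical example - substitute your own model name and checkpoint paths.
python3 inference_demo.py \
  --pretrained_model_path meta-llama/Llama-2-7b-hf \
  --compressed_model_path ./llama-2-7b-spqr.pt \
  --execution_mode 1
```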

## Prerequisites for Running the Conversion Scripts, Tests and Benchmarks

In order to run the benchmark and test suite, you need to build the sources used by these scripts.
You can do so by running the following command:

```bash
/bin/bash scripts/build.sh
```

which simply runs the `setup.py` script.
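
For reference, the build script added in this PR amounts to a single command; `MAX_JOBS` caps the number of parallel compile jobs used by the torch extension build:

```bash
MAX_JOBS=16 python3 setup.py install
```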

## Conversion From Legacy to Optimized SpQR Storage

Running SpQR produces tensors whose weights are stored as int8 (the legacy format). In order to run the efficient
inference kernels, one must first convert these legacy tensors into the optimized storage format used by the CUDA
kernel. To do so, run the following script:

```bash
usage: convert_legacy_model_format.py [-h] --base_model BASE_MODEL --legacy_model_path LEGACY_MODEL_PATH [--sparse_strategy {csr,ptcsr,optimize_latency}] [--save_pt SAVE_PT] [--save_per_layer SAVE_PER_LAYER]

options:
-h, --help show this help message and exit
--base_model BASE_MODEL
path or name of the unquantized model
--legacy_model_path LEGACY_MODEL_PATH
path to legacy model
--sparse_strategy {csr,ptcsr,optimize_latency}
Sparse storage strategy. Options: csr, ptcsr, optimize_latency. CSR - Compressed Sparse Rows; PTCSR - alternative storage format; optimize_latency - use the current GPU to determine the
optimal storage format to reduce kernel latency
--save_pt SAVE_PT Save the converted quantized .pt model here
--save_per_layer SAVE_PER_LAYER
Save the converted quantized model per layer here - useful for benchmarking individual layers
```
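
For instance, a conversion run might look like this. The model name and paths below are placeholders only:

```bash
# Hypothetical example - replace the model name and paths with your own locations.
python3 convert_legacy_model_format.py \
  --base_model meta-llama/Llama-2-7b-hf \
  --legacy_model_path ./llama-2-7b-spqr-legacy \
  --sparse_strategy csr \
  --save_pt ./llama-2-7b-spqr.pt
```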

## Hugging Face Conversion

To convert a model into a Hugging Face-compatible format, use the `convert_to_hf.py` script:

```bash
usage: convert_to_hf.py [-h] [--model MODEL] [--config_path CONFIG_PATH] [--in_path_pt IN_PATH_PT] [--out_path OUT_PATH] [--save_safetensors] [--trust_remote_code] [--load_model] [--save_tokenizer]

options:
-h, --help show this help message and exit
--model MODEL Path to the model to base config on, as in AutoConfig.from_pretrained()
--config_path CONFIG_PATH
Path to the model to base config on, as in AutoConfig.from_pretrained()
--in_path_pt IN_PATH_PT
Path of the checkpoint to convert
--out_path OUT_PATH Path to save HF compatible checkpoint to
--save_safetensors Whether to save in safetensors format
--trust_remote_code Whether to trust remote code
--load_model Whether to load model
--save_tokenizer Whether to save tokenizer
```
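
For example, a conversion run could be launched as follows; all paths here are placeholders:

```bash
# Hypothetical example - replace the model name and paths with your own checkpoints.
python3 convert_to_hf.py \
  --model meta-llama/Llama-2-7b-hf \
  --in_path_pt ./llama-2-7b-spqr.pt \
  --out_path ./llama-2-7b-spqr-hf \
  --save_safetensors \
  --save_tokenizer
```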

## Benchmarks (matvec kernel)

In order to run the matvec benchmark suite, one should run:

```bash
usage: bench_spqr.py [-h] --tensor_path TENSOR_PATH [--ptcsr_path PTCSR_PATH] [--output_path OUTPUT_PATH]

options:
-h, --help show this help message and exit
--tensor_path TENSOR_PATH
Path to folder containing the tensors, laid out as model_path/<layer_id>/<tensor_name>
--ptcsr_path PTCSR_PATH
Path to folder containing the PTCSR variant of the tensors, laid out in the same way
--output_path OUTPUT_PATH
Path to results *.csv file.

```

Make sure that `<tensor_path>` and the optional `<ptcsr_path>` point to folders containing quantized matrices produced by the `convert_legacy_model_format.py` script.
Use `<cuda_device_id>` to select the CUDA device used during the benchmark. The script writes its results to the `.csv` file specified by `--output_path`.
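
For example, assuming the per-layer matrices were exported via `convert_legacy_model_format.py --save_per_layer`, a benchmark run could look like this; the folder names are illustrative:

```bash
# Hypothetical example - point the paths at folders produced by convert_legacy_model_format.py.
python3 bench_spqr.py \
  --tensor_path ./llama-2-7b-spqr-layers \
  --ptcsr_path ./llama-2-7b-spqr-layers-ptcsr \
  --output_path ./matvec_bench.csv
```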

## Tests

In order to run the unit tests, simply execute:

```bash
python3 tests/test.py
```


## Citation
```
@misc{dettmers2023spqr,
@@ -143,4 +249,4 @@
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
```
3 changes: 3 additions & 0 deletions build.sh
@@ -0,0 +1,3 @@
#!/bin/bash
MAX_JOBS=16 python3 setup.py install

175 changes: 175 additions & 0 deletions convert_legacy_model_format.py
@@ -0,0 +1,175 @@
import argparse
import os

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from spqr import ModelArgs, QuantizedLinear, SPQRLegacy, flatten_tensor


def load_legacy_tensor(p: str, model_args: ModelArgs) -> SPQRLegacy:
    """
    Load a legacy tensor given tensor path @p and model args @model_args.
    Background:
    spqr_engine.py produces tensors whose 3-bit weights are stored as int8.
    We refer to this storage scheme as legacy, since the 3-bit inference kernel
    only accepts the compressed storage format.
    @param p: Legacy tensor path.
    @param model_args: Model arguments - we obtain the beta1, beta2, bits and the sparse compression format from here.
    @return: SPQRLegacy object holding the unpacked legacy data, ready to be converted into the compressed format
    used by the efficient inference kernel.
    """
    bits = model_args.bits
    beta1 = model_args.beta1
    beta2 = model_args.beta2

    legacy_tensor = torch.load(p, map_location="cpu")

    W = legacy_tensor["quant_weights"]
    m = W.shape[0]
    n = W.shape[1]
    W = flatten_tensor(W)
    W_s = flatten_tensor(legacy_tensor["quant_layer_scale"])
    W_z = flatten_tensor(legacy_tensor["quant_layer_zeros"])

    perm = legacy_tensor["perm"]

    outliers_matrix = legacy_tensor["outliers_matrix"].to_sparse_csr()

    col_ids = outliers_matrix.col_indices().short()
    values = outliers_matrix.values().half()

    return SPQRLegacy(
        m=m,
        n=n,
        bits=bits,
        W=flatten_tensor(W),
        beta1=beta1,
        beta2=beta2,
        W_s=W_s,
        W_z=W_z,
        W_s_s=flatten_tensor(legacy_tensor["quant_layer_scale_qq_scale"]),
        W_s_z=flatten_tensor(legacy_tensor["quant_layer_scale_qq_zero"]),
        W_z_s=flatten_tensor(legacy_tensor["quant_layer_zero_qq_scale"]),
        W_z_z=flatten_tensor(legacy_tensor["quant_layer_zero_qq_zero"]),
        row_offsets=outliers_matrix.crow_indices().int(),
        col_ids=col_ids,
        values=values,
        in_perm=perm.long(),
    )


def replace_and_save_quantized_layers(
    model_args: ModelArgs,
    model_to_be_quantized,
    legacy_model_path,
    current_model=None,
    layer_id: int = -1,
    parent_tensor_name="",
    output_per_layer_path=None,
):
    """
    This function goes through the @model_to_be_quantized recursively and
    replaces all the dense layers with their quantized counterpart where
    applicable. The legacy quantized layers are stored in @legacy_model_path.

    As we go through the model, we construct the tensor name using layer_id and parent tensor name.
    We then use these values to check if the current dense tensor is a valid candidate for substitution
    with its quantized counterpart.

    @param model_args: Global model args.
    @param model_to_be_quantized: Model to be quantized.
    @param legacy_model_path: Location of the quantized tensors stored in the legacy format as output by SpQR.
    @param output_per_layer_path: Optionally, one may wish to store the compressed SpQR layers separately in a folder
    specified by this parameter (for example, this may or may not be useful during benchmarking or data analysis).
    @param layer_id: Internally used to keep track of the current layer as we descend the model.
    @param parent_tensor_name: Name of the previous layer in the recursion chain.
    """
    if current_model is None:
        current_model = model_to_be_quantized
    for tensor_name, m in current_model.named_children():
        if tensor_name.isnumeric():
            layer_id = int(tensor_name)
            if output_per_layer_path is not None:
                os.makedirs(os.path.join(output_per_layer_path, str(layer_id)), exist_ok=True)

        if isinstance(m, torch.nn.Linear):
            assert m.bias is None
            legacy_tensor_path = os.path.join(legacy_model_path, f"{layer_id}", f"{parent_tensor_name}.{tensor_name}")
            if os.path.exists(legacy_tensor_path):
                spqr_uncompressed = load_legacy_tensor(legacy_tensor_path, model_args)
                spqr_module = QuantizedLinear.from_legacy(spqr_uncompressed, model_args, "cpu")
                if output_per_layer_path is not None:
                    per_layer_tensor_path = os.path.join(
                        output_per_layer_path, f"{layer_id}", f"{parent_tensor_name}.{tensor_name}"
                    )
                    torch.save(spqr_module, per_layer_tensor_path)
                setattr(current_model, tensor_name, spqr_module)
        else:
            replace_and_save_quantized_layers(
                model_args, model_to_be_quantized, legacy_model_path, m, layer_id, tensor_name, output_per_layer_path
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--base_model",
        type=str,
        required=True,
        help="path or name of the unquantized model",
    )
    parser.add_argument(
        "--legacy_model_path",
        type=str,
        required=True,
        help="path to legacy model",
    )
    parser.add_argument(
        "--sparse_strategy",
        type=str,
        default="csr",
        choices=["csr", "ptcsr", "optimize_latency"],
        help="Sparse storage strategy. Options: csr, ptcsr, optimize_latency.\nCSR - Compressed Sparse Rows\nPTCSR - Alternative storage format\noptimize_latency - Use the current GPU to determine the optimal storage format to reduce kernel latency",
    )
    parser.add_argument("--save_pt", type=str, required=False, help="Save the converted quantized .pt model here")
    parser.add_argument(
        "--save_per_layer",
        type=str,
        required=False,
        help="Save the converted quantized model per layer here - useful for benchmarking individual layers",
    )

    args, leftovers = parser.parse_known_args()

    if args.save_per_layer is not None:
        os.makedirs(args.save_per_layer, exist_ok=True)

    layers = os.listdir(args.legacy_model_path)

    args_path = os.path.join(args.legacy_model_path, "args.pt")
    model_args = ModelArgs.from_file(args.legacy_model_path, args.sparse_strategy)

    config = AutoConfig.from_pretrained(args.base_model, return_dict=True)

    config.max_position_embeddings = 4096

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.base_model, trust_remote_code=True, torch_dtype=torch.half, config=config
    )

    if args.save_per_layer is not None:
        not_quantized_weights_path = os.path.join(args.legacy_model_path, "not_quantized_weights.pt")
        not_quantized_weights = torch.load(not_quantized_weights_path)
        for w in not_quantized_weights.values():
            w.requires_grad = False
        model.load_state_dict(not_quantized_weights, strict=False)
        for f in ["args.pt", "not_quantized_weights.pt"]:
            os.system(f"cp {os.path.join(args.legacy_model_path, f)} {os.path.join(args.save_per_layer, f)}")

    replace_and_save_quantized_layers(
        model_args, model, args.legacy_model_path, output_per_layer_path=args.save_per_layer
    )

    if args.save_pt is not None:
        torch.save(model, args.save_pt)