Marlin24 Compressor #77

Merged
Satrat merged 22 commits into main from sa/marlin_24
Jun 11, 2024

@Satrat Satrat commented Jun 3, 2024

Implements Marlin24 compression in SparseML. Note that this compressor does not have a decompress method; it is meant only for feeding models directly into vLLM.

Example Usage

Since this is a one-way conversion for vLLM, we won't automatically convert to this format when save_compressed is set in SparseML. Instead, the save has to be done explicitly, like this:

from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer
import torch

model_dir = "/network/sadkins/llama7b_sparse_24_w4a16_group128"
output_dir = "llama7b_marlin24"
model = SparseAutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16, device_map="cuda")
tokenizer = SparseAutoTokenizer.from_pretrained(model_dir)
model.save_pretrained(output_dir, quantization_format="marlin-24")
tokenizer.save_pretrained(output_dir)
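
The "24" in the format name refers to 2:4 structured sparsity: each aligned group of four consecutive weights has at most two nonzero values. As a rough illustration of that constraint (not code from this PR), here is a minimal pure-Python check on a single row of weights. The helper name `is_24_sparse` is hypothetical, and treating groups as aligned runs of four along a row is an assumption made for this sketch.

```python
from typing import Sequence


def is_24_sparse(row: Sequence[float]) -> bool:
    """Return True if every aligned group of 4 values has at most 2 nonzeros.

    Hypothetical helper for illustration only; it is not part of
    compressed-tensors or SparseML.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Count the nonzero entries in this group of four.
        if sum(1 for v in group if v != 0) > 2:
            return False
    return True


dense_row = [1.0, 2.0, 3.0, 0.0]                          # three nonzeros in one group
sparse_row = [0.0, 1.5, 0.0, -2.0, 0.3, 0.0, 0.0, 0.0]    # at most two per group
```

A row like `sparse_row` passes the check, while `dense_row` does not. A model that fails a check of this kind would presumably not be a good candidate for the marlin-24 save path shown above.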

Testing

Tested with a 7B Llama 2 model in vLLM and confirmed that the outputs for the marlin-24 format match the outputs for the equivalent pack-quantized model. Also added some unit tests to compressed-tensors.

Prompt: 'Hello, my name is', Generated text: '… Business: And I'm here to tell you about my new book!'
Prompt: 'The capital of France is', Generated text: ' Paris, but did you know that Parisians call themselves Parisins? True enough'
Prompt: 'The president of the United States is', Generated text: ', by his or her position, obliged to represent the United States internation,'
Prompt: 'The Boston Bruins are', Generated text: ' a hockey team based in Boston, Massachusetts that compete in the National Hockey League'

@Satrat Satrat changed the title [WIP] Marlin24 Compressor Marlin24 Compressor Jun 7, 2024
@Satrat Satrat marked this pull request as ready for review June 7, 2024 21:38
dbogunowicz previously approved these changes Jun 10, 2024
src/compressed_tensors/compressors/marlin_24.py

for name, value in tqdm(model_state.items(), desc="Compressing model"):
    if name.endswith(weight_suffix):
        prefix = name[: -(len(weight_suffix))]
Contributor

Suggested change
- prefix = name[: -(len(weight_suffix))]
+ prefix = name.replace(weight_suffix, "")

More readable :)

Author

Yes, but replace() runs the risk of also removing ".weight" if it appears elsewhere in the string. Not likely, but possible.
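
To make the difference concrete, here is a small sketch comparing the two approaches. The parameter name below is hypothetical, chosen only because it contains ".weight" twice; the helper `strip_suffix` is likewise illustrative and not part of this PR.

```python
def strip_suffix(name: str, suffix: str) -> str:
    """Remove `suffix` only if it appears at the very end of `name`."""
    if suffix and name.endswith(suffix):
        return name[: -len(suffix)]
    return name


weight_suffix = ".weight"
# Hypothetical parameter name with ".weight" in the middle and at the end.
name = "model.layers.0.weight_proj.weight"

sliced = strip_suffix(name, weight_suffix)    # only the trailing suffix is dropped
replaced = name.replace(weight_suffix, "")    # every occurrence is dropped

print(sliced)    # model.layers.0.weight_proj
print(replaced)  # model.layers.0_proj
```

On Python 3.9 and later, `name.removesuffix(weight_suffix)` gives the same result as the slice while reading about as clearly as replace().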


# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Contributor

@bfineran Can we have a license attached to this code if it's largely copied over from an open-source project (vLLM), even though "our" Alex is the author?

Author

Good call. I think this got added automatically when I ran make style. Is there a way to turn that off for a certain file?

Contributor

You can skip it with `neuralmagic: no copyright`. It's an Apache license though, so I think usage should be pretty permissible.

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
@horheynm
Member

I like the design and how simple the implementation is. I didn't check the math and bit operations, but other than that, green from me.

horheynm previously approved these changes Jun 10, 2024
Sara Adkins and others added 3 commits June 10, 2024 10:58
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
@Satrat Satrat requested review from dbogunowicz and horheynm June 10, 2024 15:32

@Satrat Satrat merged commit 14b1db1 into main Jun 11, 2024
1 check passed
@Satrat Satrat deleted the sa/marlin_24 branch June 11, 2024 14:03
4 participants