Dataset split fallbacks #953

Open · kylesayrs wants to merge 8 commits into main

Conversation

@kylesayrs (Collaborator) commented Dec 4, 2024

Purpose

  • Allow fallback dataset splits to be used for compression tasks requiring specific dataset splits
  • Allow datasets without a calibration split (which is the case for most datasets) to be used in oneshot without explicit preprocessing

Previously

oneshot(
    dataset="some_dataset",
    split="test[:128]"  # intended {"calibration": "test[:128]"}
)
ValueError: --do_oneshot requires a calibration dataset

Now

oneshot(
    dataset="some_dataset",
    split="test[:128]"
)
UserWarning: oneshot expects one of ['calibration', 'train'] dataset split, falling back to test.
Use splits={"calibration": "test"} to silence this warning
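As the warning suggests, mapping the split explicitly silences it. A minimal usage sketch (the dataset name here is a placeholder, not a real dataset):

oneshot(
    dataset="some_dataset",
    splits={"calibration": "test[:128]"},  # explicit mapping, no fallback warning
)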

Changes

  • Added typing definitions in src/llmcompressor/typing.py
  • Implemented a _get_split_with_fallbacks helper function that supports strict and non-strict split retrieval with fallbacks (see the sketch below)
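
The exact implementation isn't reproduced here; a minimal sketch of the fallback logic, with an illustrative signature and names (not necessarily the PR's actual API), might look like:

import warnings
from typing import List

from datasets import Dataset, DatasetDict


def _get_split_with_fallbacks(
    datasets: DatasetDict,
    task: str,
    preferred: List[str],
    fallback_to_any: bool = False,
) -> Dataset:
    # Illustrative sketch only: return the first preferred split that exists;
    # otherwise optionally fall back to any available split with a warning.
    if len(datasets) == 0:
        raise ValueError(f"{task} requires at least one dataset split")

    # strict path: return the first preferred split that is present
    for name in preferred:
        if name in datasets:
            return datasets[name]

    # non-strict path: fall back to whatever split is available, with a warning
    if fallback_to_any:
        fallback_name = next(iter(datasets.keys()))
        warnings.warn(
            f"{task} expects one of {preferred} dataset split, "
            f"falling back to {fallback_name}. "
            f'Use splits={{"calibration": "{fallback_name}"}} to silence this warning'
        )
        return datasets[fallback_name]

    raise ValueError(f"{task} requires one of the following dataset splits: {preferred}")

The strict path raises when none of the preferred splits exist; the non-strict path warns and falls back to whichever split is available, mirroring the UserWarning shown above.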

Postrequisites

Testing

  • Added tests in tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py (see the sketch below)
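
The added tests aren't reproduced here; a hedged sketch of the kind of behavior they might cover, reusing the illustrative helper signature from the sketch above, could be:

import pytest
from datasets import Dataset, DatasetDict


def test_fallback_warns_and_returns_available_split():
    # A dataset with only a "test" split should trigger the fallback warning
    # and return that split when "calibration"/"train" are preferred.
    ds = DatasetDict({"test": Dataset.from_dict({"text": ["a", "b"]})})
    with pytest.warns(UserWarning):
        split = _get_split_with_fallbacks(
            ds, task="oneshot", preferred=["calibration", "train"], fallback_to_any=True
        )
    assert split is ds["test"]
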
test.py
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "ultrachat-200k"
DATASET_SPLIT = "train_sft[:512]"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
#   * quantize the weights to 4 bits with GPTQ using a group size of 128
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)


github-actions bot commented Dec 4, 2024

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

@kylesayrs kylesayrs self-assigned this Dec 4, 2024
@kylesayrs kylesayrs mentioned this pull request Dec 4, 2024
@kylesayrs (Collaborator, Author) commented:

This change isn't strictly necessary for #943. I think the original intention was that users would manually assign dataset splits, i.e. oneshot(splits={"calibration": "train[:128]"}). How to do this isn't immediately obvious from the error we currently give, so automatic split inference is a nice-to-have.

@dsikka (Collaborator) commented Dec 19, 2024

If this is a nice-to-have, let's not prioritize it for now.
We can follow up.
