Add learnable scales functionality. #8

Merged · 3 commits · Aug 29, 2024
281 changes: 281 additions & 0 deletions examples/learnable_scales_eval.ipynb
@@ -0,0 +1,281 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "5ee00c71-2956-4d0f-936f-f19fe5b9ab99",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"torch.cuda.device_count()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f3bf3fb3-d959-465d-8389-3fde5c87d4ea",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6d7a23e9d954a74b5fa683a96fb7cc7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_path = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map=\"cuda\")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3c90b75f-270f-44e8-a80a-92594b02c09d",
"metadata": {},
"outputs": [],
"source": [
"# Data used for published models\n",
"from datasets import load_dataset\n",
"\n",
"dataset1 = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train\")\n",
"corpus1 = \"\\n\\n\".join(dataset1[:20000][\"text\"])\n",
"\n",
"dataset2 = load_dataset(\"allenai/c4\", data_files={\"train\": \"en/c4-train.00000-of-01024.json.gz\"}, split=\"train\")\n",
"corpus2 = \"\\n\\n\".join(dataset2[:20000][\"text\"])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c7764a2-ee78-4c3b-8a36-c9437125395a",
"metadata": {},
"outputs": [],
"source": [
"import flute.integrations.base\n",
"import flute.integrations.learnable"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0c38a53f-eb1d-44a8-9a25-176518e5866f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding tunable scales to the linear layers...\n",
"Tokenizing corpora...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (1315227 > 131072). Running this sequence through the model will result in indexing errors\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prepare model for training...\n",
"Running epoch 0...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ae9ada499bd94ad1870a956d653eb1f9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/128 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n",
"/data/cl/u/radi-cho/env/lib/python3.8/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
" with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n"
]
}
],
"source": [
"flute.integrations.learnable.learn_scales(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" num_bits=4,\n",
" group_size=64,\n",
" custom_corpora=[corpus1, corpus2],\n",
" samples=128,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f80c9224-74fc-40c9-9191-af043bbb6bb9",
"metadata": {},
"outputs": [],
"source": [
"# Casting the model and learned scales to float16 (instead of bfloat16) might result in speed benefits due to kernel specifics."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1906a406-e157-4112-b0f9-12a2e3cdf716",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/data/cl/u/radi-cho/env/lib/python3.8/site-packages/flute/integrations/base.py:56: UserWarning: Quantization always happen on 1st GPU\n",
" warnings.warn(f\"Quantization always happen on 1st GPU\")\n"
]
}
],
"source": [
"flute.integrations.base.prepare_model_flute(\n",
" name=\"model.layers\",\n",
" module=model.model.layers,\n",
" num_bits=4,\n",
" group_size=64,\n",
" fake=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b0b26587-59b6-47d2-9522-36791e51159d",
"metadata": {},
"outputs": [],
"source": [
"from lm_eval import evaluator\n",
"from lm_eval.models.huggingface import HFLM"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d03dc231-d20f-4140-8396-8ab1482f420a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-08-29:11:20:52,393 WARNING [huggingface.py:122] `pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way.\n",
"2024-08-29:11:20:52,430 WARNING [huggingface.py:350] Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration\n"
]
}
],
"source": [
"lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=32, add_bos_token=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97c7c3aa-6cc6-42b2-abd4-06904f1d58a0",
"metadata": {},
"outputs": [],
"source": [
"results = evaluator.simple_evaluate(\n",
" lm,\n",
" tasks=\"arc_easy\", # piqa, arc_easy, arc_challenge, hellaswag, winogrande\n",
" num_fewshot=0,\n",
" limit=None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "13a70935-5108-4c3b-a99b-346618dd29c8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'arc_easy': {'acc,none': 0.8173400673400674,\n",
" 'acc_stderr,none': 0.007928503719209124,\n",
" 'acc_norm,none': 0.8122895622895623,\n",
" 'acc_norm_stderr,none': 0.008012496274011486,\n",
" 'alias': 'arc_easy'}}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results[\"results\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd34fe47-bb3c-4d91-9404-f37ba38ac1b0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
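For readers skimming the diff, the notebook above reduces to the short end-to-end flow below. This is a condensed sketch rather than notebook output: corpus1 and corpus2 stand for the calibration corpora built in the dataset cell, the float16 cast is the optional step suggested by the comment cell, and exact signatures may vary across FLUTE versions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import flute.integrations.base
import flute.integrations.learnable

model_path = "meta-llama/Meta-Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# 1. Learn per-group quantization scales on calibration corpora
#    (corpus1/corpus2 are raw-text strings, as built in the notebook).
flute.integrations.learnable.learn_scales(
    model=model,
    tokenizer=tokenizer,
    num_bits=4,
    group_size=64,
    custom_corpora=[corpus1, corpus2],
    samples=128,
)

# 2. Optional: cast model and learned scales to float16; per the note in
#    the notebook, this might be faster than bfloat16 due to kernel specifics.
model = model.to(torch.float16)

# 3. Quantize with FLUTE; the patched prepare_model_flute (see the base.py
#    diff below) picks up the learned scales automatically.
flute.integrations.base.prepare_model_flute(
    name="model.layers",
    module=model.model.layers,
    num_bits=4,
    group_size=64,
    fake=False,
)

# 4. Evaluate the quantized model with lm-evaluation-harness,
#    as in the notebook's final cells.
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=32, add_bos_token=True)
results = evaluator.simple_evaluate(lm, tasks="arc_easy", num_fewshot=0, limit=None)
print(results["results"])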
10 changes: 8 additions & 2 deletions flute/integrations/base.py
@@ -3,6 +3,7 @@
 import torch
 import warnings
 import argparse
+
 from transformers import (
     LlamaForCausalLM,
     Gemma2ForCausalLM,
@@ -18,7 +19,9 @@
 import flute.utils
 import flute.nf_utils
 import flute.integrations.bitsandbytes
+import flute.integrations.learnable
 
+from flute.integrations.learnable import LearnableQuantizedLinear
 
 def get_accelerate_hook(name: str, module: torch.nn.Module, allow: bool) -> Optional[ModelHook]:

@@ -62,7 +65,7 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
 
             child_full_name = f"{_name}.{child_name}"
 
-            if isinstance(child, torch.nn.Linear):
+            if isinstance(child, torch.nn.Linear) or isinstance(child, LearnableQuantizedLinear):
 
                 if isinstance(child, BNBLinear4bit):
                     if child.weight.dtype not in [torch.uint8]:
@@ -130,7 +133,10 @@ def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
             if custom_scales_dict is not None:
                 custom_scales = custom_scales_dict[child_full_name]
             else:
-                custom_scales = None
+                if isinstance(child, LearnableQuantizedLinear):
+                    custom_scales = child.scales
+                else:
+                    custom_scales = None
 
             if not isinstance(child, BNBLinear4bit):
                 _, _Q, scales, qmap = flute.nf_utils.nf_quantize(
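Read as a rule, the change above says: an entry in custom_scales_dict takes precedence; failing that, scales already learned by a LearnableQuantizedLinear module are reused; otherwise quantization falls back to freshly computed scales. A minimal sketch of that rule follows; the helper function is hypothetical, and only LearnableQuantizedLinear and the attribute names come from the diff.

from typing import Optional

import torch

from flute.integrations.learnable import LearnableQuantizedLinear


# Hypothetical helper mirroring the scale-resolution logic in the diff;
# not part of the FLUTE API.
def resolve_custom_scales(
    child: torch.nn.Module,
    child_full_name: str,
    custom_scales_dict: Optional[dict],
) -> Optional[torch.Tensor]:
    if custom_scales_dict is not None:
        # Explicitly supplied scales always win.
        return custom_scales_dict[child_full_name]
    if isinstance(child, LearnableQuantizedLinear):
        # Reuse the scales learned via learn_scales().
        return child.scales
    # No custom scales: downstream code computes defaults via nf_quantize.
    return None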