Refactor: Use Llama RoPE implementation for Falcon #26933
Conversation
Before diving into the code, I tried the snippet you shared, but with an extra `set_seed(0)`:

```python
from transformers import AutoTokenizer, pipeline, set_seed
import torch

model = "tiiuae/falcon-7b"

set_seed(0)

tokenizer = AutoTokenizer.from_pretrained(model)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
sequences = pipe(
    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
    max_length=300,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")
```

There is a tiny difference in outputs, which we should try to figure out before merging (maybe I'll have some clues after looking at the diff).
Using the latest commit here, the output is 👉 identical up to the last two sentences.
Note: adding `if dtype in [torch.float16, torch.bfloat16]: emb = emb.float()` back doesn't fix the mismatch.
The changes look good to me, thank you for working on them!

Re `apply_rotary_pos_emb`: I think we can go one step further and copy it from Llama as well, if it makes downstream work easier. Looking at the code, we apply a reshape such that the first dim of QKV is `batch_size * self.num_heads`, run `apply_rotary_pos_emb` (and nothing for alibi), and then reshape back to `batch_size, self.num_heads, ...`. It seems like we would save two sets of reshape operations if we do it the Llama way 🤔
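For concreteness, a rough sketch (with arbitrary sizes, not taken from the modeling code) of the two reshapes that bracket `apply_rotary_pos_emb` in the current Falcon layout, which the Llama-style layout would avoid:

```python
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16
query = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Falcon-style: flatten batch and heads before RoPE, then undo it afterwards.
q_flat = query.reshape(batch_size * num_heads, seq_len, head_dim)
# ... apply_rotary_pos_emb(q_flat, ...) would run here ...
q_restored = q_flat.reshape(batch_size, num_heads, seq_len, head_dim)

# Llama-style: keep the 4D layout and apply RoPE directly, so both reshapes disappear.
# ... apply_rotary_pos_emb(query, ...) would run here ...
assert torch.equal(q_restored, query)
```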
In addition to this, we should assess:
a) if this minor mismatch is relevant, perhaps by measuring ppl on some test set
b) if there are throughput changes
All of that sounds great. I didn't notice the potential to remove some reshapes, that would definitely be useful. And I'll reuse my benchmarking tools. Will work on this tomorrow!
We don't need to convert any caches anymore!
I have some more results @gante 🎉 I've done tests for perplexity, VRAM, and latency.

Perplexity: I've tried 3 different books, and for all of them there are some notable differences in favor of this PR. If I try a different book, again there are notable differences in favor of this PR. Yet another book, again large differences in favor of this PR.

VRAM

Latency: The latency for v2 definitely improves over the baseline.

One thing: I get a test failure locally for one of the tests.

This also simplifies implementing the Attention Sink Cache for Falcon 🎉
I'll pull this into #26681, i.e. the caching refactor PR, when this is merged. Then I can conveniently test whether the implementation is easily extended for Falcon now too.
Resolved the merge conflicts introduced by #26792 by @patrickvonplaten; the latency seems slightly better than reported in #26933 (comment) now. This PR should be good to go as far as I'm concerned. The same goes for #26929, which solves some minor issues with the Falcon config.
Woah, that is a great analysis of the changes! Great work Tom 💛
Less code, better modelling performance, and higher throughput -- the dream PR 💛
Tagging @ArthurZucker (core maintainer) for a quick final check
FYI @tomaarsen: a major cause for the PPL mismatch seems to stem from how `main` computes the rotary embeddings when the model is loaded in fp16/bf16.
That must be it indeed! In `main` we have:

`transformers/src/transformers/models/falcon/modeling_falcon.py`, lines 247 to 249 in 576e282

which is called like so:

`transformers/src/transformers/models/falcon/modeling_falcon.py`, lines 262 to 267 in 576e282

which is called like so:

`transformers/src/transformers/models/falcon/modeling_falcon.py`, lines 278 to 280 in 576e282

So, if the model is loaded in fp16, then so is the query, which gives `t = torch.arange(seq_len, device=device).to(torch.float16)`. This results in:

```python
>>> torch.arange(3000, dtype=torch.float16)[-10:]
tensor([2990., 2992., 2992., 2992., 2994., 2996., 2996., 2996., 2998., 3000.],
       dtype=torch.float16)
```

Alternatively, if the model is loaded in bf16, then it uses `torch.bfloat16`:

```python
>>> torch.arange(3000, dtype=torch.bfloat16)[-10:]
tensor([2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992.],
       dtype=torch.bfloat16)
```

which is just awful! With this PR, we get:

`transformers/src/transformers/models/falcon/modeling_falcon.py`, lines 256 to 258 in dcde537

and

`transformers/src/transformers/models/falcon/modeling_falcon.py`, lines 248 to 249 in dcde537

which is always float32 due to the `.float()` call. So, we get `t = torch.arange(seq_len, device=device).to(torch.float32)`, which results in:

```python
>>> torch.arange(3000, dtype=torch.float32)[-10:]
tensor([2990., 2991., 2992., 2993., 2994., 2995., 2996., 2997., 2998., 2999.])
```

So, this PR solves a hidden bug that has been resulting in reduced performance at higher sequence lengths. After all, the problem only gets worse at higher seq lengths, e.g.:

```python
>>> torch.arange(30000, dtype=torch.bfloat16)[-10:]
tensor([29952., 29952., 29952., 29952., 29952., 29952., 29952., 29952., 29952.,
        29952.], dtype=torch.bfloat16)
```
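To summarize the fix in one place, here is a minimal, hedged sketch (the function name and shapes are illustrative, not the actual `modeling_falcon.py` code) of a RoPE table computation that stays in float32 until the very end, so `torch.arange` keeps exact positions even for long sequences:

```python
import torch

def rope_cos_sin(seq_len: int, dim: int, device, dtype, base: float = 10000.0):
    # Positions and frequencies are computed in float32 regardless of the model dtype.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
    t = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    # Only the final cos/sin tables are cast down to the model dtype.
    return emb.cos().to(dtype), emb.sin().to(dtype)

cos, sin = rope_cos_sin(seq_len=3000, dim=64, device="cpu", dtype=torch.bfloat16)
print(cos.shape, cos.dtype)  # torch.Size([3000, 64]) torch.bfloat16
```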
I've resolved the outstanding merge conflicts. This should be ready for final review @ArthurZucker. I've verified that the results are identical to what I previously plotted here.

I want to point out: this PR is completely unrelated to attention sinks (although it will eventually help with that implementation). These findings are for regular, pure `transformers`.

Using these scripts you can reproduce my findings:

perplexity.py

```python
"""
Adapted from https://github.com/mit-han-lab/streaming-llm
Note: Although this script measures latency, it is not optimized whatsoever!
The latency is only tracked to see the impact of speed over time.
Usage:
python benchmark/perplexity.py --experiment attention_sinks
python benchmark/perplexity.py --experiment transformers
python benchmark/perplexity.py --experiment windowed
"""
import argparse
import itertools
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import pandas as pd
import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
def compute_perplexity(
model,
tokenizer,
dataset,
experiment: str,
output_dir: str = "outputs",
data_column: str = "text",
num_samples: int = 1,
num_tokens: Optional[int] = None,
overwrite: bool = False,
) -> None:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_file = output_dir / f"{experiment}.csv"
if output_file.exists() and not overwrite:
raise ValueError(
f"The {output_file!r} output file already exists - if you really want to override it, then use `--overwrite`."
)
logs = defaultdict(list)
loss_fn = CrossEntropyLoss(reduction="none")
past_key_values = None
num_processed_tokens = 0
for text in itertools.islice(dataset, num_samples):
encodings = tokenizer(text[data_column], return_tensors="pt")
seq_len = encodings.input_ids.size(1)
print(f"sequence length: {seq_len}")
pbar = tqdm(range(0, seq_len - 1))
for idx in pbar:
start_t = time.time()
input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)
with torch.no_grad():
outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
logits = outputs.logits.view(-1, model.config.vocab_size)
past_key_values = outputs.past_key_values
label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
neg_log_likelihood = loss_fn(logits, label)
perplexity = neg_log_likelihood.exp()
pbar.set_description(f"nll: {neg_log_likelihood.item():>5.2f}, ppl: {perplexity.item():>8.2f}")
# Store data and save every 10 tokens
logs["input_length"].append(idx + 1)
logs["nll"].append(neg_log_likelihood.item())
logs["ppl"].append(perplexity.item())
logs["overall_ppl"].append(torch.tensor(logs["nll"]).mean().exp().item())
logs["cuda_vram_allocated"].append(torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024) # in GB
logs["latency"].append(time.time() - start_t)
if num_processed_tokens % 10 == 0:
try:
pd.DataFrame(logs).to_csv(output_file, index=False)
except KeyboardInterrupt as ex:
# If there's a Keyboard Interrupt, still write the file, and then stop
pd.DataFrame(logs).to_csv(output_file, index=False)
raise ex
num_processed_tokens += 1
if num_tokens and num_processed_tokens >= num_tokens:
return
def main():
parser = argparse.ArgumentParser()
# How to call this experiment?
parser.add_argument(
"--experiment", type=str, default="main"
)
# Model args
parser.add_argument("--model_name_or_path", type=str, default="tiiuae/falcon-7b")
parser.add_argument("--revision", type=str, default="main")
parser.add_argument("--trust_remote_code", action="store_true")
# Dataset args
parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
parser.add_argument("--data_column", type=str, default="text")
parser.add_argument("--task", type=str, default=None)
parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
# parser.add_argument("--num_samples", type=int, default=1)
parser.add_argument("--num_tokens", type=int, default=5000)
# Where to log
parser.add_argument("--output_dir", type=str, default="perplexity_benchmark")
parser.add_argument("--overwrite", action="store_true")
args = parser.parse_args()
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
revision=args.revision,
trust_remote_code=bool(args.trust_remote_code),
torch_dtype=torch.float16,
device_map="auto",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))
# Set up the dataset
dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)
compute_perplexity(
model,
tokenizer,
dataset,
args.experiment,
output_dir=args.output_dir,
data_column=args.data_column,
num_samples=1, # <- No support for more than one instance now
num_tokens=args.num_tokens,
overwrite=args.overwrite,
)
if __name__ == "__main__":
    main()
```

plot_perplexity.py

```python
"""
First run `perplexity.py` to generate one or more `csv` files.
This script can plot those csv files.
Usage:
python benchmark/plot_perplexity.py
python benchmark/plot_perplexity.py --features perplexity latency --title "Log perplexity & latency of Llama 2 7B as a function of input lengths"
"""
import argparse
from pathlib import Path
from typing import List, Optional
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
FEATURE_DF_MAP = {
"perplexity": "overall_ppl",
"vram": "cuda_vram_allocated",
"latency": "latency",
}
FEATURE_STYLE_MAP = {
"perplexity": "-",
"vram": "--",
"latency": ":",
}
FEATURE_LABEL_MAP = {
"perplexity": "Perplexity (log), lower is better",
"vram": "CUDA VRAM Usage (GB), lower is better",
"latency": "Time per token (sec), lower is better",
}
def plot(
features: List[str],
output_dir: str = "outputs",
title: Optional[str] = None,
perplexity_limit: Optional[float] = None,
skip_first: int = 100,
):
output_dir = Path(output_dir)
fig, ax = plt.subplots()
ax.set_xlabel("Input Sequence Length")
for feature_i, feature in enumerate(features):
# If we already plotted on this ax, make a new one
if feature_i:
ax = ax.twinx()
for file in output_dir.glob("*.csv"):
experiment = file.stem
df = pd.read_csv(file)
X = df["input_length"][skip_first:]
Y = df[FEATURE_DF_MAP[feature]][skip_first:]
if feature == "perplexity":
Y = np.log(Y)
if feature == "latency":
poly = np.polyfit(X, Y, 20)
poly_y = np.poly1d(poly)(X)
ax.plot(X, poly_y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")
else:
ax.plot(X, Y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")
ax.set_ylabel(FEATURE_LABEL_MAP[feature])
if perplexity_limit and feature == "perplexity":
ax.set_ylim(top=min(ax.get_ylim()[1], perplexity_limit))
ax.legend(loc=[1, 2, 7][feature_i]) # upper right, upper left, center right
ax.set_title(title.replace("\\n", "\n") if title else "Log perplexity as a function of input lengths")
fig.tight_layout()
return fig
def main():
parser = argparse.ArgumentParser()
# Where csv files have been logged
parser.add_argument("--output_dir", type=str, default="perplexity_benchmark")
parser.add_argument(
"--features", choices=["perplexity", "vram", "latency"], nargs="+", default=["perplexity", "vram"]
)
parser.add_argument("--title", type=str, default=None)
parser.add_argument("--log_perplexity_limit", type=float, default=5.0)
# Perplexity starts a bit unstable, so we skip the start
parser.add_argument("--skip_first", type=int, default=100)
args = parser.parse_args()
figure = plot(
args.features,
output_dir=args.output_dir,
title=args.title,
perplexity_limit=args.log_perplexity_limit,
skip_first=args.skip_first,
)
# Add your own code here if you'd like to change the figure
plt.show()
if __name__ == "__main__":
    main()
```

Usage:

```bash
git checkout main
# the --experiment just determines the filename
python ./perplexity.py --experiment main
git checkout pr-26933 # <- Or whatever branch you use locally for this PR
python ./perplexity.py --experiment llama_rope_for_falcon
python ./plot_perplexity.py
```

And you'll get a plot just like here.
The documentation is not available anymore as the PR was closed or merged.
(updating core maintainer to review :) )
Wow - a really great PR! 🔥
Thank you for all the work in refactoring this, writing up such detailed explanations and all of the investigative work.
* Use Llama RoPE implementation for Falcon + Add copy functionalities
* Use standard cache format for Falcon
* Simplify apply_rotary_pos_emb, copy from Llama
* Remove unnecessary cache conversion test: we don't need to convert any caches anymore!
* Resolve copy complaint
What does this PR do?

This PR refactors Falcon to reuse the Llama RoPE implementation (see `transformers/src/transformers/models/falcon/modeling_falcon.py`, line 91 in ad08137).
Discussed this internally with @gante.
Details
There are a few differences between Llama and Falcon that complicate this somewhat. In particular, Llama deals with `[batch_size, num_(kv)_heads, seq_len, head_dim]` on the QKV states, while Falcon uses `[batch_size * num_(kv)_heads, seq_len, head_dim]`, i.e. one dimension less. This is why `apply_rotary_pos_emb` uses `torch.repeat_interleave` a few times to manually expand the cos, sin when necessary. This used to be in the `forward` of the `FalconRotaryEmbedding`.
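For illustration, a rough sketch (with made-up shapes, not the actual modeling code) of what that `torch.repeat_interleave` expansion has to do when the cos/sin tables must match Falcon's flattened `[batch_size * num_heads, seq_len, head_dim]` layout:

```python
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16

# cos table computed once per position: [1, seq_len, head_dim]
cos = torch.randn(1, seq_len, head_dim)

# Query states in Falcon's flattened layout: [batch_size * num_heads, seq_len, head_dim]
query = torch.randn(batch_size * num_heads, seq_len, head_dim)

# The cos table has to be repeated for every (batch, head) pair to line up with
# the flattened first dimension before it can be multiplied in.
cos_expanded = torch.repeat_interleave(cos, batch_size * num_heads, dim=0)
assert cos_expanded.shape == query.shape

# With Llama's [batch_size, num_heads, seq_len, head_dim] layout, plain
# broadcasting handles this without any repeat_interleave.
```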
There are still some differences between the old and new implementations:

* The Falcon RoPE computes `.cos()` and `.sin()` in float32 when the dtype is fp16/bf16: `transformers/src/transformers/models/falcon/modeling_falcon.py`, lines 115 to 116 in ad08137
* The Falcon RoPE moves the cached embeddings to the right device in the `forward` call, while the Llama RoPE doesn't: `transformers/src/transformers/models/falcon/modeling_falcon.py`, lines 131 to 133 in ad08137. (Should we also implement this on Llama? In case someone wants to move a model on the fly; see the sketch after this list.)
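For reference, a toy sketch of the kind of on-the-fly device handling being referred to (class and buffer names are made up; this is not the actual Falcon code):

```python
import torch
from torch import nn

class CachedCosSin(nn.Module):
    """Toy module illustrating the device check; not the actual Falcon implementation."""

    def __init__(self, seq_len: int = 16):
        super().__init__()
        # Dummy cached tables standing in for the precomputed cos/sin buffers.
        self.register_buffer("cos_cached", torch.randn(seq_len), persistent=False)
        self.register_buffer("sin_cached", torch.randn(seq_len), persistent=False)

    def forward(self, x: torch.Tensor):
        # Falcon-style: if the cached tables ended up on a different device than the
        # input, move them over on the fly. Llama's RoPE skips this step.
        if self.cos_cached.device != x.device:
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)
        return self.cos_cached, self.sin_cached
```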
In the context of Attention Sinks
For Attention Sinks (context: #26681), it looks like the SinkCache must store an `apply_rotary_pos_emb` variable or something, because for Falcon it will need to use a different `apply_rotary_pos_emb` function.

How did I test?
I tested this by running `pytest tests/models/falcon` and a short generation script, and observing that the generated text was of similar quality to that of the `main` branch.

Note: I didn't run this with Falcon 40b! I would heavily recommend modifying the above script with 40b and ensuring that it runs correctly with that model too. Falcon 40b uses `"new_decoder_architecture": true` while 7b uses `"new_decoder_architecture": false`.
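If anyone wants to double-check that flag before running the heavier 40b test, a quick sketch (assuming the usual Hub checkpoint names):

```python
from transformers import AutoConfig

# Print the architecture flag for both checkpoints before committing to a full run.
# Add trust_remote_code=True if a checkpoint still ships custom code.
for model_id in ("tiiuae/falcon-7b", "tiiuae/falcon-40b"):
    config = AutoConfig.from_pretrained(model_id)
    print(model_id, config.new_decoder_architecture)
```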
Who can review?
@gante