
Refactor: Use Llama RoPE implementation for Falcon #26933

Merged
merged 7 commits into huggingface:main from feat/use_llama_rope_for_falcon on Nov 3, 2023

Conversation

tomaarsen
Member

What does this PR do?

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Discussed this internally with @gante.

Details

There are a few differences between Llama and Falcon that complicate this somewhat. In particular, Llama keeps the query/key/value states as [batch_size, num_(kv)_heads, seq_len, head_dim], while Falcon uses [batch_size * num_(kv)_heads, seq_len, head_dim], i.e. one dimension fewer.

This is why apply_rotary_pos_emb uses torch.repeat_interleave a few times to manually expand cos and sin where necessary; this expansion used to live in the forward of FalconRotaryEmbedding.
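
Roughly, that expansion looks like the sketch below (illustrative names and shapes, not the exact PR code):

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_fused(query, key, cos, sin, num_heads, num_kv_heads):
    # query: [batch_size * num_heads, seq_len, head_dim]
    # key:   [batch_size * num_kv_heads, seq_len, head_dim]
    # cos/sin: [batch_size, seq_len, head_dim]
    # Repeat the per-batch cos/sin so they line up with the fused batch*heads dim.
    q_cos = cos.repeat_interleave(num_heads, dim=0)
    q_sin = sin.repeat_interleave(num_heads, dim=0)
    k_cos = cos.repeat_interleave(num_kv_heads, dim=0)
    k_sin = sin.repeat_interleave(num_kv_heads, dim=0)
    query = (query * q_cos) + (rotate_half(query) * q_sin)
    key = (key * k_cos) + (rotate_half(key) * k_sin)
    return query, key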

There are still some differences between the old and new implementations.

In the context of Attention Sinks

For Attention Sinks (context: #26681), it looks like the SinkCache will have to store an apply_rotary_pos_emb reference or something along those lines, because for Falcon it will need to use a different apply_rotary_pos_emb function.
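
A purely hypothetical sketch of that idea (not an actual transformers API; all names are placeholders):

from typing import Callable, List

import torch

class HypotheticalSinkCache:
    """Sketch only: a sink cache that is handed the model-specific RoPE function."""

    def __init__(self, apply_rotary_pos_emb_fn: Callable, window_length: int, num_sink_tokens: int):
        # Falcon and Llama would pass in their own apply_rotary_pos_emb here,
        # so the cache can re-rotate cached keys when positions shift.
        self.apply_rotary_pos_emb_fn = apply_rotary_pos_emb_fn
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []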

How did I test?

I tested this by running pytest tests/models/falcon and the following script:

from transformers import AutoTokenizer
import transformers
import torch

model = "tiiuae/falcon-7b"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
sequences = pipeline(
   "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
    max_length=300,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")

and observing that the generated text was:

Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.
Daniel: Hello, Girafatron!
Girafatron: (inhales deeply, then exhales and makes a sound like a foghorn)
Daniel: I see you're a giraffe enthusiast. That's really cool!
Girafatron: (inhales again and lets out a sigh, his ears perked up)
Daniel: I'm not sure why I asked if you were a giraffe enthusiast! But it seems like you love giraffes.
Girafatron: (sighs)
Girafatron: Giraffes are my favorite animal, because they're the most glorious animal on this planet!
Daniel:. Giraffes are pretty awesome animals, I agree.
Girafatron: (sighs, exhales loudly)
Girafatron: I think I could go on forever about how much I love these majestic animals, so...
Daniel: You can tell me about that later, if you want to! I just wanted to ask you a quick question.
Daniel: Why are you obsessed with giraffes? What about giraffes fascinates you so much?

This is of similar quality to the output on the main branch:

Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.
Daniel: Hello, Girafatron!
Girafatron: What is your name, Daniel?
Daniel: I am Daniel.
Girafatron:...
Girafatron is a giraffe. Girafatron is also a robot.
Girafatron is a robot that is obsessed with giraffes. Girafatron believes that all other creatures on this planet are completely and utterly irrelevant.
Girafatron is the best robot to have ever existed and Girafatron has no equal on this planet.
Girafatron is a robot. Girafatron is obsessed with giraffes. Girafatron is also a robot.
Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are completely and utterly irrelevant when compared to the glorious majesty of the giraffe.
Girafatron is also a robot. It is a robot that is obsessed with giraffes.
Girafatron is a robot. Girafatron is obsessed with giraffes.
Giraftron is a robot, and it is obsessed

Note: I didn't run this with Falcon 40b! I would strongly recommend re-running the above script with the 40b checkpoint and ensuring that it works correctly with that model too, since Falcon 40b uses "new_decoder_architecture": true while 7b uses "new_decoder_architecture": false.
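
For reference, only the checkpoint name in the snippet above would need to change (untested here):

model = "tiiuae/falcon-40b"  # uses "new_decoder_architecture": true, unlike falcon-7b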

Who can review?

@gante

  • Tom Aarsen

+ Add copy functionalities
@gante gante self-requested a review October 19, 2023 15:53
@gante
Member

gante commented Oct 19, 2023

Before diving into the code, I tried the snippet you shared, but with an extra set_seed(0) for complete reproducibility at the sampling step.

from transformers import AutoTokenizer, pipeline, set_seed
import torch

model = "tiiuae/falcon-7b"

set_seed(0)
tokenizer = AutoTokenizer.from_pretrained(model)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
sequences = pipe(
    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
    max_length=300,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")

There is a tiny difference in outputs, which we should try to figure out before merging (maybe I'll have some clues after looking at the diff).

On main, the output is

Result: Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.
Daniel: Hello, Girafatron!
Girafatron: Daniel, it is good that you have called upon me for I have many important matters in which I must attend to.
Daniel: Yes, and I will be brief for I must attend to some urgent matters myself.
Girafatron: You have urgent matters? What are these urgent matters?
Daniel: It has come to my attention that you seem to have a problem, a problem with being a giraffe.
Girafatron: What are you saying? That I have a problem? What kind of problem?
Daniel: You seem to be in denial.
Girafatron: Of what?
Daniel: Of your love for giraffes, of course!
Girafatron: How do you know that I’m in denial?
Daniel: Because you have not been able to get a giraffe to be your friend for all these years.
Girafatron: That’s not a problem.
Daniel: Yes, it is! How are you supposed to find love without being able to love?
Girafatron: I’ll tell you how I find love.
Daniel

Using the latest commit here, the output is

Result: Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.
Daniel: Hello, Girafatron!
Girafatron: Daniel, it is good that you have called upon me for I have many important matters in which I must attend to.
Daniel: Yes, and I will be brief for I must attend to some urgent matters myself.
Girafatron: You have urgent matters? What are these urgent matters?
Daniel: It has come to my attention that you seem to have a problem, a problem with being a giraffe.
Girafatron: What are you saying? That I have a problem? What kind of problem?
Daniel: You seem to be in denial.
Girafatron: Of what?
Daniel: Of your love for giraffes, of course!
Girafatron: How do you know that I’m in denial?
Daniel: Because you have not been able to get a giraffe to be your friend for all these years.
Girafatron: That’s not a problem.
Daniel: Yes, it is! How are you supposed to find love if you don’t go out and try to find a giraffe?
Girafatron: I don’

👉 identical up to the last two sentences

@gante
Member

gante commented Oct 19, 2023

Note: adding

if dtype in [torch.float16, torch.bfloat16]:
    emb = emb.float()

back doesn't fix the mismatch

Member

@gante gante left a comment


The changes look good to me, thank you for working on them!

Re apply_rotary_pos_emb, I think we can go one step further and copy it from Llama as well, if it makes downstream work easier. Looking at the code, we apply a reshape such that the first dim of QKV is batch_size * self.num_heads, run apply_rotary_pos_emb (and nothing for alibi), and then reshape back to batch_size, self.num_heads, .... It seems like we would save two sets of reshape operations if we do it the Llama way 🤔
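
As a rough, shapes-only illustration of the two paths (placeholder sizes, not the real modeling code):

import torch

bsz, num_heads, q_len, head_dim = 2, 8, 16, 64
query = torch.randn(bsz, num_heads, q_len, head_dim)

# Current Falcon path: fold the heads into the batch dim, apply RoPE, then unfold again.
query_fused = query.reshape(bsz * num_heads, q_len, head_dim)
# ... apply_rotary_pos_emb on [bsz * num_heads, q_len, head_dim] ...
query = query_fused.reshape(bsz, num_heads, q_len, head_dim)

# Llama path: apply RoPE directly on [bsz, num_heads, q_len, head_dim], saving both reshapes.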

In addition to this, we should assess:
a) if this minor mismatch is relevant, perhaps by measuring ppl on some test set
b) if there are throughput changes

@tomaarsen
Member Author

All of that sounds great. I hadn't noticed the potential to remove some reshapes; that would definitely be useful. I'll also reuse the benchmarking tools I made for attention_sinks, but point them at transformers main and this PR to get information on latency and perplexity.

Will work on this tomorrow!

  • Tom Aarsen

@tomaarsen
Member Author

tomaarsen commented Oct 20, 2023

I have some more results @gante 🎉

I've done tests for main, this PR after c4a8d52 (which I called pr_v1), and this PR after e68b8e4 (which I called pr_v2).

Perplexity

I've tried 3 different books, and for all of them there are some notable differences in favor of pr_v2:
[image: perplexity plot, book 1]

If I try a different book, again there are some notable differences in favor of pr_v2:
[image: perplexity plot, book 2]

Yet another book, again large differences in favor of pr_v2:
[image: perplexity plot, book 3]

VRAM

[image: VRAM usage plot]
There are two groups: main as one, and pr_v1 and pr_v2 as the other. It seems that the Llama RoPE implementation immediately caches more cos/sin values, hence the slightly larger memory usage at the start.
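
A simplified sketch of that up-front caching (based on the Llama-style rotary embedding, not the exact transformers code):

import torch

class SimplifiedRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # The cos/sin cache is built up to max_position_embeddings right away,
        # instead of being grown lazily as longer inputs arrive.
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)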

Latency

The latency for pr_v2 definitely improves over main, which is very promising. Here are the results for the 3 different books that I tested with:
[images: latency plots for books 1, 2 and 3]
I think with the second book, my OS de-prioritized the process for main, so that's a bit of an unrelated outlier.
For book 1, the average tokens per second was 18.04 for pr_v2 and 14.13 for main, a 27.65% speedup. For book 3, I got 17.54 tokens/sec for pr_v2 and 14.82 for main, i.e. an 18.38% speedup.


One thing: I get a test failure locally for tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_left_padding_compatibility due to some NaNs popping up where the attention_mask is 0. I'm not sure whether this is anything to be concerned about, as it does pass on the CI...

This also simplifies implementing the Attention Sink Cache for Falcon 🎉

  • Tom Aarsen

@tomaarsen
Member Author

I'll pull this into #26681, i.e. the caching refactor PR, when this is merged. Then I can conveniently test whether the implementation is easily extended to Falcon now too.

@tomaarsen
Member Author

tomaarsen commented Oct 24, 2023

Resolved the merge conflicts introduced by #26792 by @patrickvonplaten; the latency now seems slightly better than reported in #26933 (comment). This PR should be good to go as far as I'm concerned. The same goes for #26929, which solves some minor issues with the Falcon config.

@gante
Member

gante commented Oct 24, 2023

Woah, that is a great analysis of the changes! Great work Tom 💛

Member

@gante gante left a comment


Less code, better modelling performance, and higher throughput -- the dream PR 💛

@gante gante requested a review from ArthurZucker October 24, 2023 10:41
@gante
Member

gante commented Oct 24, 2023

Tagging @ArthurZucker (core maintainer) for a quick final check

@gante
Member

gante commented Oct 24, 2023

FYI @tomaarsen: a major cause of the PPL mismatch seems to be how t was being cast to a lower precision in the older version of the code:

(Pdb) torch.arange(3000, device=device).to(torch.float16)[-10:]
tensor([2990., 2992., 2992., 2992., 2994., 2996., 2996., 2996., 2998., 3000.],
       device='cuda:0', dtype=torch.float16)
(Pdb) torch.arange(3000, device=device).to(torch.float32)[-10:]
tensor([2990., 2991., 2992., 2993., 2994., 2995., 2996., 2997., 2998., 2999.],
       device='cuda:0')

The following einsum is mischievous, as it automatically upcasts t to be compatible with the other input!
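
A quick way to see both effects (assumed sizes, not the actual Falcon code): the precision is already lost by the fp16 arange, and the mixed-dtype einsum then silently promotes to float32 instead of raising an error:

import torch

inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))  # float32
t = torch.arange(3000).to(torch.float16)  # values above 2048 get rounded
freqs = torch.einsum("i,j->ij", t, inv_freq)

print(t[-10:])      # ... 2992., 2992., 2992. ... -> already damaged
print(freqs.dtype)  # torch.float32 -> promoted, so no dtype error surfaces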

@tomaarsen
Member Author

That must be it indeed!

In main, we have:

def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device).to(dtype)

which is called like so:
def cos_sin(
    self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
) -> torch.Tensor:
    total_length = seq_len + past_key_values_length
    if total_length > self.seq_len_cached:
        self._set_cos_sin_cache(total_length, device, dtype)

which in turn is called like so:
def forward(self, query, key, past_key_values_length, position_ids):
    _, seq_len, _ = query.shape
    cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)

So, if the model is loaded in fp16, then so is the query, which gives t = torch.arange(seq_len, device=device).to(torch.float16). This results in:

>>> torch.arange(3000, dtype=torch.float16)[-10:]
tensor([2990., 2992., 2992., 2992., 2994., 2996., 2996., 2996., 2998., 3000.],
       dtype=torch.float16)

Alternatively, if the model is loaded in bf16, then it uses t = torch.arange(seq_len, device=device).to(torch.bfloat16) and we get:

>>> torch.arange(3000, dtype=torch.bfloat16)[-10:]
tensor([2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992., 2992.],
       dtype=torch.bfloat16)

which is just awful!


With this PR, we get:

def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.max_seq_len_cached = seq_len
    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

and inv_freq is this:

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

which is always float32 due to the .float() call. So we get t = torch.arange(seq_len, device=device).to(torch.float32), which results in:

>>> torch.arange(3000, dtype=torch.float32)[-10:]
tensor([2990., 2991., 2992., 2993., 2994., 2995., 2996., 2997., 2998., 2999.])

So, this PR fixes a hidden bug that has been degrading performance at higher sequence lengths, and the problem only gets worse as the sequence length grows, e.g.:

>>> torch.arange(30000, dtype=torch.bfloat16)[-10:]
tensor([29952., 29952., 29952., 29952., 29952., 29952., 29952., 29952., 29952.,
        29952.], dtype=torch.bfloat16)
  • Tom Aarsen

@tomaarsen
Member Author

I've resolved the outstanding merge conflicts. This should be ready for final review, @ArthurZucker. I've verified that the results are identical to what I previously plotted here. I want to point out that this PR is completely unrelated to attention sinks (although it will eventually help with that implementation); these findings are for regular, pure transformers usage.

Using these scripts you can reproduce my findings:

perplexity.py
"""
Adapted from https://github.com/mit-han-lab/streaming-llm

Note: Although this script measures latency, it is not optimized whatsoever!
The latency is only tracked to see the impact of speed over time.

Usage:

python benchmark/perplexity.py --experiment attention_sinks
python benchmark/perplexity.py --experiment transformers
python benchmark/perplexity.py --experiment windowed
"""


import argparse
import itertools
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

def compute_perplexity(
    model,
    tokenizer,
    dataset,
    experiment: str,
    output_dir: str = "outputs",
    data_column: str = "text",
    num_samples: int = 1,
    num_tokens: Optional[int] = None,
    overwrite: bool = False,
) -> None:
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{experiment}.csv"

    if output_file.exists() and not overwrite:
        raise ValueError(
            f"The {output_file!r} output file already exists - if you really want to override it, then use `--overwrite`."
        )

    logs = defaultdict(list)
    loss_fn = CrossEntropyLoss(reduction="none")
    past_key_values = None
    num_processed_tokens = 0
    for text in itertools.islice(dataset, num_samples):
        encodings = tokenizer(text[data_column], return_tensors="pt")

        seq_len = encodings.input_ids.size(1)
        print(f"sequence length: {seq_len}")
        pbar = tqdm(range(0, seq_len - 1))

        for idx in pbar:
            start_t = time.time()
            input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)
            with torch.no_grad():
                outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
                logits = outputs.logits.view(-1, model.config.vocab_size)
                past_key_values = outputs.past_key_values
                label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
                neg_log_likelihood = loss_fn(logits, label)
                perplexity = neg_log_likelihood.exp()
            pbar.set_description(f"nll: {neg_log_likelihood.item():>5.2f}, ppl: {perplexity.item():>8.2f}")

            # Store data and save every 10 tokens
            logs["input_length"].append(idx + 1)
            logs["nll"].append(neg_log_likelihood.item())
            logs["ppl"].append(perplexity.item())
            logs["overall_ppl"].append(torch.tensor(logs["nll"]).mean().exp().item())
            logs["cuda_vram_allocated"].append(torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024)  # in GB
            logs["latency"].append(time.time() - start_t)
            if num_processed_tokens % 10 == 0:
                try:
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                except KeyboardInterrupt as ex:
                    # If there's a Keyboard Interrupt, still write the file, and then stop
                    pd.DataFrame(logs).to_csv(output_file, index=False)
                    raise ex

            num_processed_tokens += 1
            if num_tokens and num_processed_tokens >= num_tokens:
                return


def main():
    parser = argparse.ArgumentParser()
    # How to call this experiment?
    parser.add_argument(
        "--experiment", type=str, default="main"
    )

    # Model args
    parser.add_argument("--model_name_or_path", type=str, default="tiiuae/falcon-7b")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")

    # Dataset args
    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
    parser.add_argument("--data_column", type=str, default="text")
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
    # parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--num_tokens", type=int, default=5000)

    # Where to log
    parser.add_argument("--output_dir", type=str, default="perplexity_benchmark")
    parser.add_argument("--overwrite", action="store_true")

    args = parser.parse_args()

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=bool(args.trust_remote_code),
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))

    # Set up the dataset
    dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)

    compute_perplexity(
        model,
        tokenizer,
        dataset,
        args.experiment,
        output_dir=args.output_dir,
        data_column=args.data_column,
        num_samples=1,  # <- No support for more than one instance now
        num_tokens=args.num_tokens,
        overwrite=args.overwrite,
    )


if __name__ == "__main__":
    main()
plot_perplexity.py
"""
First run `perplexity.py` to generate one or more `csv` files.
This script can plot those csv files.

Usage:
python benchmark/plot_perplexity.py
python benchmark/plot_perplexity.py --features perplexity latency --title "Log perplexity & latency of Llama 2 7B as a function of input lengths"
"""

import argparse
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

FEATURE_DF_MAP = {
    "perplexity": "overall_ppl",
    "vram": "cuda_vram_allocated",
    "latency": "latency",
}
FEATURE_STYLE_MAP = {
    "perplexity": "-",
    "vram": "--",
    "latency": ":",
}
FEATURE_LABEL_MAP = {
    "perplexity": "Perplexity (log), lower is better",
    "vram": "CUDA VRAM Usage (GB), lower is better",
    "latency": "Time per token (sec), lower is better",
}


def plot(
    features: List[str],
    output_dir: str = "outputs",
    title: Optional[str] = None,
    perplexity_limit: Optional[float] = None,
    skip_first: int = 100,
):
    output_dir = Path(output_dir)

    fig, ax = plt.subplots()
    ax.set_xlabel("Input Sequence Length")

    for feature_i, feature in enumerate(features):
        # If we already plotted on this ax, make a new one
        if feature_i:
            ax = ax.twinx()

        for file in output_dir.glob("*.csv"):
            experiment = file.stem
            df = pd.read_csv(file)
            X = df["input_length"][skip_first:]
            Y = df[FEATURE_DF_MAP[feature]][skip_first:]
            if feature == "perplexity":
                Y = np.log(Y)
            if feature == "latency":
                poly = np.polyfit(X, Y, 20)
                poly_y = np.poly1d(poly)(X)
                ax.plot(X, poly_y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")
            else:
                ax.plot(X, Y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")

        ax.set_ylabel(FEATURE_LABEL_MAP[feature])
        if perplexity_limit and feature == "perplexity":
            ax.set_ylim(top=min(ax.get_ylim()[1], perplexity_limit))

        ax.legend(loc=[1, 2, 7][feature_i])  # upper right, upper left, center right

    ax.set_title(title.replace("\\n", "\n") if title else "Log perplexity as a function of input lengths")
    fig.tight_layout()

    return fig


def main():
    parser = argparse.ArgumentParser()
    # Where csv files have been logged
    parser.add_argument("--output_dir", type=str, default="perplexity_benchmark")
    parser.add_argument(
        "--features", choices=["perplexity", "vram", "latency"], nargs="+", default=["perplexity", "vram"]
    )
    parser.add_argument("--title", type=str, default=None)
    parser.add_argument("--log_perplexity_limit", type=float, default=5.0)
    # Perplexity starts a bit unstable, so we skip the start
    parser.add_argument("--skip_first", type=int, default=100)

    args = parser.parse_args()

    figure = plot(
        args.features,
        output_dir=args.output_dir,
        title=args.title,
        perplexity_limit=args.log_perplexity_limit,
        skip_first=args.skip_first,
    )

    # Add your own code here if you'd like to change the figure

    plt.show()


if __name__ == "__main__":
    main()

Usage:

git checkout main
# the --experiment just determines the filename
python ./perplexity.py --experiment main
git checkout pr-26933 # <- Or whatever branch you use locally for this PR
python ./perplexity.py --experiment llama_rope_for_falcon

python ./plot_perplexity.py

And you'll get a plot just like here.

  • Tom Aarsen

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Nov 1, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante gante requested review from amyeroberts and removed request for ArthurZucker November 2, 2023 16:26
@gante
Member

gante commented Nov 2, 2023

(updating core maintainer to review :) )

Collaborator

@amyeroberts amyeroberts left a comment


Wow - a really great PR! 🔥

Thank you for all the work in refactoring this, writing up such detailed explanations and all of the investigative work.

@gante gante merged commit 05ea7b7 into huggingface:main Nov 3, 2023
@tomaarsen tomaarsen deleted the feat/use_llama_rope_for_falcon branch November 3, 2023 11:13
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* Use Llama RoPE implementation for Falcon

+ Add copy functionalities

* Use standard cache format for Falcon

* Simplify apply_rotary_pos_emb, copy from Llama

* Remove unnecessary cache conversion test

We don't need to convert any caches anymore!

* Resolve copy complaint